Module datatap.metrics.precision_recall_curve
Expand source code
from __future__ import annotations
from typing import Iterable, Sequence, TYPE_CHECKING, List, NamedTuple, Optional, cast
import numpy as np
from scipy.optimize import linear_sum_assignment
from sortedcontainers import SortedDict
from datatap.droplet import ImageAnnotation
from ._types import GroundTruthBox, PredictionBox
if TYPE_CHECKING:
import matplotlib.pyplot as plt
class MaximizeF1Result(NamedTuple):
    """
    The operating point of a `PrecisionRecallCurve` that maximizes f1: the detection threshold together with the
    precision, recall, and f1 measured at that threshold.
    """

    # Detection (confidence) threshold at which f1 is maximized.
    threshold: float
    # Precision at `threshold`.
    precision: float
    # Recall at `threshold`.
    recall: float
    # Harmonic mean of `precision` and `recall`.
    f1: float
class _PrecisionRecallPoint(NamedTuple):
    """A single point on a precision-recall curve: the precision and recall obtained at one detection threshold."""

    # Detection (confidence) threshold for this point.
    threshold: float
    # Fraction of detections at or above `threshold` that are true positives.
    precision: float
    # Fraction of ground truth boxes matched by a true positive at or above `threshold`.
    recall: float
class _DetectionEvent(NamedTuple):
    """
    The change in true positive and false positive counts recorded at a single detection threshold.

    Events recorded at the same threshold are merged with `+`.
    """

    # Change in the number of true positives.
    true_positive_delta: int
    # Change in the number of false positives.
    false_positive_delta: int

    def __add__(self, other: _DetectionEvent) -> _DetectionEvent:
        # Unlike plain tuple concatenation, adding two events sums their deltas element-wise.
        if not isinstance(other, _DetectionEvent): # type: ignore - pyright complains about the isinstance check being redundant
            return NotImplemented
        summed_true_positives = self.true_positive_delta + other.true_positive_delta
        summed_false_positives = self.false_positive_delta + other.false_positive_delta
        return _DetectionEvent(summed_true_positives, summed_false_positives)
class PrecisionRecallCurve:
    """
    Represents a curve relating a chosen detection threshold to precision and recall. Internally, this is actually
    stored as a sorted mapping from confidence thresholds to detection events, which are used to compute metrics on
    the fly when needed.

    Curves are updated in place by `add_annotation` / `batch_add_annotation`, and can be combined with `+`.
    """

    # TODO(mdsavage): make this accept matching strategies other than bounding box IOU

    # Maps each confidence threshold to the change in true/false positive counts that occurs when the detection
    # threshold is lowered to that value.
    events: SortedDict[float, _DetectionEvent]
    # Total number of ground truth boxes across every annotation added to this curve.
    ground_truth_positives: int

    def __init__(self, events: Optional[SortedDict[float, _DetectionEvent]] = None, ground_truth_positives: int = 0):
        # A new SortedDict is created per instance rather than used as a default argument value.
        self.events = SortedDict() if events is None else events
        self.ground_truth_positives = ground_truth_positives

    def clone(self) -> PrecisionRecallCurve:
        """
        Returns an independent copy of this curve. The event mapping is copied, so adding annotations to the copy does
        not affect this curve.
        """
        copied_events = self.events.copy()
        return PrecisionRecallCurve(copied_events, self.ground_truth_positives)

    def maximize_f1(self) -> MaximizeF1Result:
        """
        Returns the threshold, precision, recall, and f1 at the point on this curve with the highest f1.

        The curve is scanned from the highest threshold to the lowest, and a point replaces the current best when its
        f1 is greater than *or equal to* it, so ties resolve to the lowest threshold. If the curve has no points, the
        result is threshold 1 with zero precision, recall, and f1.

        Fails an assertion if no ground truth boxes have been added to this curve.
        """
        best = MaximizeF1Result(threshold = 1, precision = 0, recall = 0, f1 = 0)
        for point in self._compute_curve():
            # f1 is the harmonic mean of precision and recall; it is defined as 0 when either is 0.
            if point.precision > 0 and point.recall > 0:
                f1 = 2 / ((1 / point.precision) + (1 / point.recall))
            else:
                f1 = 0
            if f1 >= best.f1:
                best = MaximizeF1Result(
                    threshold = point.threshold,
                    precision = point.precision,
                    recall = point.recall,
                    f1 = f1,
                )
        return best

    def plot(self) -> plt.Figure:
        """
        Plots precision (y axis) against recall (x axis), one marker per threshold, in a new matplotlib figure, and
        returns that figure.

        Fails an assertion if no ground truth boxes have been added to this curve.
        """
        # Imported lazily so matplotlib is only needed when plotting (the module-level import is type-checking only).
        import matplotlib.pyplot as plt

        fig = plt.figure()
        curve = self._compute_curve()
        recalls = [point.recall for point in curve]
        precisions = [point.precision for point in curve]
        plt.plot(recalls, precisions, "o-")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        return fig

    def add_annotation(
        self: PrecisionRecallCurve,
        ground_truth: ImageAnnotation,
        prediction: ImageAnnotation,
        iou_threshold: float
    ) -> None:
        """
        Updates this precision-recall curve in place with the given ground truth and prediction annotations, evaluated
        with the given IOU threshold.

        Prediction boxes are assigned to ground truth boxes by maximizing total IOU. An assigned prediction is a true
        positive when its IOU is at least `iou_threshold` and its class matches the ground truth box's class; every
        other prediction (including unassigned ones) is a false positive. Predictions without a confidence are treated
        as having confidence 1.

        Note: this handles instances only; multi-instances are ignored.
        """
        ground_truth_boxes = [
            GroundTruthBox(class_name, instance.bounding_box.rectangle)
            for class_name in ground_truth.classes.keys()
            for instance in ground_truth.classes[class_name].instances
            if instance.bounding_box is not None
        ]

        # Missing confidences default to 1. This is an explicit `None` check rather than `confidence or 1`, so a real
        # confidence of 0 is not silently promoted to 1 (which would rank the weakest predictions first).
        prediction_boxes = sorted([
            PredictionBox(
                1 if instance.bounding_box.confidence is None else instance.bounding_box.confidence,
                class_name,
                instance.bounding_box.rectangle,
            )
            for class_name in prediction.classes.keys()
            for instance in prediction.classes[class_name].instances
            if instance.bounding_box is not None
        ], reverse = True, key = lambda p: p.confidence)

        # Rows are predictions (in descending confidence order); columns are ground truth boxes.
        iou_matrix = np.array([
            [ground_truth_box.box.iou(prediction_box.box) for ground_truth_box in ground_truth_boxes]
            for prediction_box in prediction_boxes
        ])

        self._add_ground_truth_positives(len(ground_truth_boxes))

        previous_true_positives = 0
        previous_false_positives = 0

        for i, current_box in enumerate(prediction_boxes):
            confidence_threshold = current_box.confidence

            # Predictions that share a confidence cannot be separated by any threshold, so only record one event, once
            # the last of them has been included.
            if i < len(prediction_boxes) - 1 and prediction_boxes[i + 1].confidence == confidence_threshold:
                continue

            # Assign the top i + 1 predictions to ground truth boxes.
            # NOTE(review): the assignment maximizes total IOU without regard to class or `iou_threshold`, so it is
            # not guaranteed to maximize the number of true positives.
            prediction_indices, ground_truth_indices = linear_sum_assignment(iou_matrix[:i + 1], maximize = True)

            # linear_sum_assignment pairs at most min(rows, cols) predictions; any extra predictions beyond the number
            # of ground truth boxes are unassigned and therefore false positives.
            true_positives = 0
            false_positives = max(0, i + 1 - len(ground_truth_boxes))
            for prediction_index, ground_truth_index in zip(cast(Iterable[int], prediction_indices), cast(Iterable[int], ground_truth_indices)):
                if (
                    iou_matrix[prediction_index, ground_truth_index] >= iou_threshold
                    and prediction_boxes[prediction_index].class_name == ground_truth_boxes[ground_truth_index].class_name
                ):
                    true_positives += 1
                else:
                    false_positives += 1

            # Store only the change relative to the previously recorded (higher) threshold; `_compute_curve`
            # accumulates these deltas.
            self._add_event(confidence_threshold, _DetectionEvent(
                true_positive_delta = true_positives - previous_true_positives,
                false_positive_delta = false_positives - previous_false_positives
            ))

            previous_true_positives = true_positives
            previous_false_positives = false_positives

    def batch_add_annotation(
        self: PrecisionRecallCurve,
        ground_truths: Sequence[ImageAnnotation],
        predictions: Sequence[ImageAnnotation],
        iou_threshold: float
    ) -> None:
        """
        Updates this precision-recall curve with the values from several annotations simultaneously.

        `ground_truths[i]` is paired with `predictions[i]`, and each pair is passed to `add_annotation`.

        Raises ValueError if `ground_truths` and `predictions` have different lengths, rather than silently ignoring
        the unpaired annotations.
        """
        if len(ground_truths) != len(predictions):
            raise ValueError(
                f"Expected the same number of ground truths and predictions, "
                f"got {len(ground_truths)} and {len(predictions)}"
            )
        for ground_truth, prediction in zip(ground_truths, predictions):
            self.add_annotation(ground_truth, prediction, iou_threshold)

    def _compute_curve(self) -> List[_PrecisionRecallPoint]:
        """
        Returns one precision/recall point per recorded threshold, ordered from the highest threshold to the lowest.

        Each point reflects every event at or above its threshold. Asserts that at least one ground truth box has been
        added, since recall is undefined otherwise.
        """
        assert self.ground_truth_positives > 0

        precision_recall_points: List[_PrecisionRecallPoint] = []
        true_positives = 0
        detections = 0

        # `self.events` is sorted by ascending threshold; iterate in reverse to sweep the threshold downwards while
        # accumulating counts.
        for threshold in reversed(self.events):
            true_positive_delta, false_positive_delta = self.events[threshold]
            true_positives += true_positive_delta
            detections += true_positive_delta + false_positive_delta
            # Events produced by `add_annotation` always add at least one detection; this guards against
            # hand-constructed event maps.
            assert detections > 0
            precision_recall_points.append(_PrecisionRecallPoint(
                threshold = threshold,
                precision = true_positives / detections,
                recall = true_positives / self.ground_truth_positives
            ))

        return precision_recall_points

    def _add_event(self, threshold: float, event: _DetectionEvent) -> None:
        """Merges `event` into whatever event is already recorded at `threshold` (starting from an empty event)."""
        if threshold not in self.events:
            self.events[threshold] = _DetectionEvent(0, 0)
        self.events[threshold] += event

    def _add_ground_truth_positives(self, count: int) -> None:
        """Increases the ground truth box count by `count`."""
        self.ground_truth_positives += count

    def __add__(self, other: PrecisionRecallCurve) -> PrecisionRecallCurve:
        """
        Returns a new curve whose ground truth count and events are the sums of both operands'. Neither operand is
        modified.
        """
        if isinstance(other, PrecisionRecallCurve): # type: ignore - pyright complains about the isinstance check being redundant
            ret = self.clone()
            ret._add_ground_truth_positives(other.ground_truth_positives)
            for threshold, event in other.events.items():
                ret._add_event(threshold, event)
            return ret
        return NotImplemented
Classes
class MaximizeF1Result (threshold: float, precision: float, recall: float, f1: float)
-
Represents the precision, recall, and f1 for a given `PrecisionRecallCurve` at the threshold that maximizes f1.
Expand source code
class MaximizeF1Result(NamedTuple):
    """
    The operating point of a `PrecisionRecallCurve` that maximizes f1: the detection threshold together with the
    precision, recall, and f1 measured at that threshold.
    """

    # Detection (confidence) threshold at which f1 is maximized.
    threshold: float
    # Precision at `threshold`.
    precision: float
    # Recall at `threshold`.
    recall: float
    # Harmonic mean of `precision` and `recall`.
    f1: float
Ancestors
- builtins.tuple
Instance variables
var f1 : float
-
Alias for field number 3
var precision : float
-
Alias for field number 1
var recall : float
-
Alias for field number 2
var threshold : float
-
Alias for field number 0
class PrecisionRecallCurve (events: Optional[SortedDict[float, _DetectionEvent]] = None, ground_truth_positives: int = 0)
-
Represents a curve relating a chosen detection threshold to precision and recall. Internally, this is stored as a sorted mapping from confidence thresholds to detection events, which are used to compute metrics on the fly when needed.
Expand source code
class PrecisionRecallCurve:
    """
    Represents a curve relating a chosen detection threshold to precision and recall. Internally, this is actually
    stored as a sorted mapping from confidence thresholds to detection events, which are used to compute metrics on
    the fly when needed.

    Curves are updated in place by `add_annotation` / `batch_add_annotation`, and can be combined with `+`.
    """

    # TODO(mdsavage): make this accept matching strategies other than bounding box IOU

    # Maps each confidence threshold to the change in true/false positive counts that occurs when the detection
    # threshold is lowered to that value.
    events: SortedDict[float, _DetectionEvent]
    # Total number of ground truth boxes across every annotation added to this curve.
    ground_truth_positives: int

    def __init__(self, events: Optional[SortedDict[float, _DetectionEvent]] = None, ground_truth_positives: int = 0):
        # A new SortedDict is created per instance rather than used as a default argument value.
        self.events = SortedDict() if events is None else events
        self.ground_truth_positives = ground_truth_positives

    def clone(self) -> PrecisionRecallCurve:
        """
        Returns an independent copy of this curve. The event mapping is copied, so adding annotations to the copy does
        not affect this curve.
        """
        copied_events = self.events.copy()
        return PrecisionRecallCurve(copied_events, self.ground_truth_positives)

    def maximize_f1(self) -> MaximizeF1Result:
        """
        Returns the threshold, precision, recall, and f1 at the point on this curve with the highest f1.

        The curve is scanned from the highest threshold to the lowest, and a point replaces the current best when its
        f1 is greater than *or equal to* it, so ties resolve to the lowest threshold. If the curve has no points, the
        result is threshold 1 with zero precision, recall, and f1.

        Fails an assertion if no ground truth boxes have been added to this curve.
        """
        best = MaximizeF1Result(threshold = 1, precision = 0, recall = 0, f1 = 0)
        for point in self._compute_curve():
            # f1 is the harmonic mean of precision and recall; it is defined as 0 when either is 0.
            if point.precision > 0 and point.recall > 0:
                f1 = 2 / ((1 / point.precision) + (1 / point.recall))
            else:
                f1 = 0
            if f1 >= best.f1:
                best = MaximizeF1Result(
                    threshold = point.threshold,
                    precision = point.precision,
                    recall = point.recall,
                    f1 = f1,
                )
        return best

    def plot(self) -> plt.Figure:
        """
        Plots precision (y axis) against recall (x axis), one marker per threshold, in a new matplotlib figure, and
        returns that figure.

        Fails an assertion if no ground truth boxes have been added to this curve.
        """
        # Imported lazily so matplotlib is only needed when plotting (the module-level import is type-checking only).
        import matplotlib.pyplot as plt

        fig = plt.figure()
        curve = self._compute_curve()
        recalls = [point.recall for point in curve]
        precisions = [point.precision for point in curve]
        plt.plot(recalls, precisions, "o-")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        return fig

    def add_annotation(
        self: PrecisionRecallCurve,
        ground_truth: ImageAnnotation,
        prediction: ImageAnnotation,
        iou_threshold: float
    ) -> None:
        """
        Updates this precision-recall curve in place with the given ground truth and prediction annotations, evaluated
        with the given IOU threshold.

        Prediction boxes are assigned to ground truth boxes by maximizing total IOU. An assigned prediction is a true
        positive when its IOU is at least `iou_threshold` and its class matches the ground truth box's class; every
        other prediction (including unassigned ones) is a false positive. Predictions without a confidence are treated
        as having confidence 1.

        Note: this handles instances only; multi-instances are ignored.
        """
        ground_truth_boxes = [
            GroundTruthBox(class_name, instance.bounding_box.rectangle)
            for class_name in ground_truth.classes.keys()
            for instance in ground_truth.classes[class_name].instances
            if instance.bounding_box is not None
        ]

        # Missing confidences default to 1. This is an explicit `None` check rather than `confidence or 1`, so a real
        # confidence of 0 is not silently promoted to 1 (which would rank the weakest predictions first).
        prediction_boxes = sorted([
            PredictionBox(
                1 if instance.bounding_box.confidence is None else instance.bounding_box.confidence,
                class_name,
                instance.bounding_box.rectangle,
            )
            for class_name in prediction.classes.keys()
            for instance in prediction.classes[class_name].instances
            if instance.bounding_box is not None
        ], reverse = True, key = lambda p: p.confidence)

        # Rows are predictions (in descending confidence order); columns are ground truth boxes.
        iou_matrix = np.array([
            [ground_truth_box.box.iou(prediction_box.box) for ground_truth_box in ground_truth_boxes]
            for prediction_box in prediction_boxes
        ])

        self._add_ground_truth_positives(len(ground_truth_boxes))

        previous_true_positives = 0
        previous_false_positives = 0

        for i, current_box in enumerate(prediction_boxes):
            confidence_threshold = current_box.confidence

            # Predictions that share a confidence cannot be separated by any threshold, so only record one event, once
            # the last of them has been included.
            if i < len(prediction_boxes) - 1 and prediction_boxes[i + 1].confidence == confidence_threshold:
                continue

            # Assign the top i + 1 predictions to ground truth boxes.
            # NOTE(review): the assignment maximizes total IOU without regard to class or `iou_threshold`, so it is
            # not guaranteed to maximize the number of true positives.
            prediction_indices, ground_truth_indices = linear_sum_assignment(iou_matrix[:i + 1], maximize = True)

            # linear_sum_assignment pairs at most min(rows, cols) predictions; any extra predictions beyond the number
            # of ground truth boxes are unassigned and therefore false positives.
            true_positives = 0
            false_positives = max(0, i + 1 - len(ground_truth_boxes))
            for prediction_index, ground_truth_index in zip(cast(Iterable[int], prediction_indices), cast(Iterable[int], ground_truth_indices)):
                if (
                    iou_matrix[prediction_index, ground_truth_index] >= iou_threshold
                    and prediction_boxes[prediction_index].class_name == ground_truth_boxes[ground_truth_index].class_name
                ):
                    true_positives += 1
                else:
                    false_positives += 1

            # Store only the change relative to the previously recorded (higher) threshold; `_compute_curve`
            # accumulates these deltas.
            self._add_event(confidence_threshold, _DetectionEvent(
                true_positive_delta = true_positives - previous_true_positives,
                false_positive_delta = false_positives - previous_false_positives
            ))

            previous_true_positives = true_positives
            previous_false_positives = false_positives

    def batch_add_annotation(
        self: PrecisionRecallCurve,
        ground_truths: Sequence[ImageAnnotation],
        predictions: Sequence[ImageAnnotation],
        iou_threshold: float
    ) -> None:
        """
        Updates this precision-recall curve with the values from several annotations simultaneously.

        `ground_truths[i]` is paired with `predictions[i]`, and each pair is passed to `add_annotation`.

        Raises ValueError if `ground_truths` and `predictions` have different lengths, rather than silently ignoring
        the unpaired annotations.
        """
        if len(ground_truths) != len(predictions):
            raise ValueError(
                f"Expected the same number of ground truths and predictions, "
                f"got {len(ground_truths)} and {len(predictions)}"
            )
        for ground_truth, prediction in zip(ground_truths, predictions):
            self.add_annotation(ground_truth, prediction, iou_threshold)

    def _compute_curve(self) -> List[_PrecisionRecallPoint]:
        """
        Returns one precision/recall point per recorded threshold, ordered from the highest threshold to the lowest.

        Each point reflects every event at or above its threshold. Asserts that at least one ground truth box has been
        added, since recall is undefined otherwise.
        """
        assert self.ground_truth_positives > 0

        precision_recall_points: List[_PrecisionRecallPoint] = []
        true_positives = 0
        detections = 0

        # `self.events` is sorted by ascending threshold; iterate in reverse to sweep the threshold downwards while
        # accumulating counts.
        for threshold in reversed(self.events):
            true_positive_delta, false_positive_delta = self.events[threshold]
            true_positives += true_positive_delta
            detections += true_positive_delta + false_positive_delta
            # Events produced by `add_annotation` always add at least one detection; this guards against
            # hand-constructed event maps.
            assert detections > 0
            precision_recall_points.append(_PrecisionRecallPoint(
                threshold = threshold,
                precision = true_positives / detections,
                recall = true_positives / self.ground_truth_positives
            ))

        return precision_recall_points

    def _add_event(self, threshold: float, event: _DetectionEvent) -> None:
        """Merges `event` into whatever event is already recorded at `threshold` (starting from an empty event)."""
        if threshold not in self.events:
            self.events[threshold] = _DetectionEvent(0, 0)
        self.events[threshold] += event

    def _add_ground_truth_positives(self, count: int) -> None:
        """Increases the ground truth box count by `count`."""
        self.ground_truth_positives += count

    def __add__(self, other: PrecisionRecallCurve) -> PrecisionRecallCurve:
        """
        Returns a new curve whose ground truth count and events are the sums of both operands'. Neither operand is
        modified.
        """
        if isinstance(other, PrecisionRecallCurve): # type: ignore - pyright complains about the isinstance check being redundant
            ret = self.clone()
            ret._add_ground_truth_positives(other.ground_truth_positives)
            for threshold, event in other.events.items():
                ret._add_event(threshold, event)
            return ret
        return NotImplemented
Class variables
var events : SortedDict[float, _DetectionEvent]
var ground_truth_positives : int
Methods
def add_annotation(self: PrecisionRecallCurve, ground_truth: ImageAnnotation, prediction: ImageAnnotation, iou_threshold: float) ‑> None
-
Updates this precision-recall curve in place with the given ground truth and prediction annotations, evaluated with the given IOU threshold.
Note: this handles instances only; multi-instances are ignored.
Expand source code
def add_annotation(
    self: PrecisionRecallCurve,
    ground_truth: ImageAnnotation,
    prediction: ImageAnnotation,
    iou_threshold: float
) -> None:
    """
    Updates this precision-recall curve in place with the given ground truth and prediction annotations, evaluated
    with the given IOU threshold.

    Prediction boxes are assigned to ground truth boxes by maximizing total IOU. An assigned prediction is a true
    positive when its IOU is at least `iou_threshold` and its class matches the ground truth box's class; every
    other prediction (including unassigned ones) is a false positive. Predictions without a confidence are treated
    as having confidence 1.

    Note: this handles instances only; multi-instances are ignored.
    """
    ground_truth_boxes = [
        GroundTruthBox(class_name, instance.bounding_box.rectangle)
        for class_name in ground_truth.classes.keys()
        for instance in ground_truth.classes[class_name].instances
        if instance.bounding_box is not None
    ]

    # Missing confidences default to 1. This is an explicit `None` check rather than `confidence or 1`, so a real
    # confidence of 0 is not silently promoted to 1 (which would rank the weakest predictions first).
    prediction_boxes = sorted([
        PredictionBox(
            1 if instance.bounding_box.confidence is None else instance.bounding_box.confidence,
            class_name,
            instance.bounding_box.rectangle,
        )
        for class_name in prediction.classes.keys()
        for instance in prediction.classes[class_name].instances
        if instance.bounding_box is not None
    ], reverse = True, key = lambda p: p.confidence)

    # Rows are predictions (in descending confidence order); columns are ground truth boxes.
    iou_matrix = np.array([
        [ground_truth_box.box.iou(prediction_box.box) for ground_truth_box in ground_truth_boxes]
        for prediction_box in prediction_boxes
    ])

    self._add_ground_truth_positives(len(ground_truth_boxes))

    previous_true_positives = 0
    previous_false_positives = 0

    for i, current_box in enumerate(prediction_boxes):
        confidence_threshold = current_box.confidence

        # Predictions that share a confidence cannot be separated by any threshold, so only record one event, once
        # the last of them has been included.
        if i < len(prediction_boxes) - 1 and prediction_boxes[i + 1].confidence == confidence_threshold:
            continue

        # Assign the top i + 1 predictions to ground truth boxes.
        # NOTE(review): the assignment maximizes total IOU without regard to class or `iou_threshold`, so it is not
        # guaranteed to maximize the number of true positives.
        prediction_indices, ground_truth_indices = linear_sum_assignment(iou_matrix[:i + 1], maximize = True)

        # linear_sum_assignment pairs at most min(rows, cols) predictions; any extra predictions beyond the number of
        # ground truth boxes are unassigned and therefore false positives.
        true_positives = 0
        false_positives = max(0, i + 1 - len(ground_truth_boxes))
        for prediction_index, ground_truth_index in zip(cast(Iterable[int], prediction_indices), cast(Iterable[int], ground_truth_indices)):
            if (
                iou_matrix[prediction_index, ground_truth_index] >= iou_threshold
                and prediction_boxes[prediction_index].class_name == ground_truth_boxes[ground_truth_index].class_name
            ):
                true_positives += 1
            else:
                false_positives += 1

        # Store only the change relative to the previously recorded (higher) threshold; `_compute_curve` accumulates
        # these deltas.
        self._add_event(confidence_threshold, _DetectionEvent(
            true_positive_delta = true_positives - previous_true_positives,
            false_positive_delta = false_positives - previous_false_positives
        ))

        previous_true_positives = true_positives
        previous_false_positives = false_positives
def batch_add_annotation(self: PrecisionRecallCurve, ground_truths: Sequence[ImageAnnotation], predictions: Sequence[ImageAnnotation], iou_threshold: float) ‑> None
-
Updates this precision-recall curve with the values from several annotations simultaneously.
Expand source code
def batch_add_annotation(
    self: PrecisionRecallCurve,
    ground_truths: Sequence[ImageAnnotation],
    predictions: Sequence[ImageAnnotation],
    iou_threshold: float
) -> None:
    """
    Updates this precision-recall curve with the values from several annotations simultaneously.

    `ground_truths[i]` is paired with `predictions[i]`, and each pair is passed to `add_annotation`.

    Raises ValueError if `ground_truths` and `predictions` have different lengths, rather than silently ignoring the
    unpaired annotations.
    """
    if len(ground_truths) != len(predictions):
        raise ValueError(
            f"Expected the same number of ground truths and predictions, "
            f"got {len(ground_truths)} and {len(predictions)}"
        )
    for ground_truth, prediction in zip(ground_truths, predictions):
        self.add_annotation(ground_truth, prediction, iou_threshold)
def clone(self) ‑> PrecisionRecallCurve
-
Expand source code
def clone(self) -> PrecisionRecallCurve:
    """
    Returns an independent copy of this curve. The event mapping is copied, so adding annotations to the copy does not
    affect this curve.
    """
    copied_events = self.events.copy()
    return PrecisionRecallCurve(copied_events, self.ground_truth_positives)
def maximize_f1(self) ‑> MaximizeF1Result
-
Expand source code
def maximize_f1(self) -> MaximizeF1Result:
    """
    Returns the threshold, precision, recall, and f1 at the point on this curve with the highest f1.

    The curve is scanned from the highest threshold to the lowest, and a point replaces the current best when its f1
    is greater than *or equal to* it, so ties resolve to the lowest threshold. If the curve has no points, the result
    is threshold 1 with zero precision, recall, and f1.
    """
    best = MaximizeF1Result(threshold = 1, precision = 0, recall = 0, f1 = 0)
    for point in self._compute_curve():
        # f1 is the harmonic mean of precision and recall; it is defined as 0 when either is 0.
        if point.precision > 0 and point.recall > 0:
            f1 = 2 / ((1 / point.precision) + (1 / point.recall))
        else:
            f1 = 0
        if f1 >= best.f1:
            best = MaximizeF1Result(
                threshold = point.threshold,
                precision = point.precision,
                recall = point.recall,
                f1 = f1,
            )
    return best
def plot(self) ‑> plt.Figure
-
Expand source code
def plot(self) -> plt.Figure:
    """
    Plots precision (y axis) against recall (x axis), one marker per threshold, in a new matplotlib figure, and
    returns that figure.
    """
    # Imported lazily so matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt

    fig = plt.figure()
    curve = self._compute_curve()
    recalls = [point.recall for point in curve]
    precisions = [point.precision for point in curve]
    plt.plot(recalls, precisions, "o-")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    return fig