fruit_project.utils.metrics

Classes

ConfusionMatrix

Object Detection Confusion Matrix inspired by Ultralytics.

MAPEvaluator

Mean Average Precision evaluator for RT-DETRv2, adapted for fruit_project.

Module Contents

class fruit_project.utils.metrics.ConfusionMatrix(nc: int, conf: float = 0.25, iou_thres: float = 0.45)

Object Detection Confusion Matrix inspired by Ultralytics.

Parameters:
  • nc (int) – Number of classes.

  • conf (float) – Confidence threshold for detections.

  • iou_thres (float) – IoU threshold for matching.

nc
conf = 0.25
iou_thres = 0.45
matrix
eps = 1e-06
process_batch(detections: torch.Tensor, labels: torch.Tensor) → None

Update the confusion matrix with a batch of detections and ground truths.

Parameters:
  • detections (torch.Tensor) – Tensor of detections, shape [N, 6] (x1, y1, x2, y2, conf, class).

  • labels (torch.Tensor) – Tensor of ground truths, shape [M, 5] (class, x1, y1, x2, y2).
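A minimal pure-Python sketch of the matching this method performs, assuming Ultralytics conventions that this page does not confirm: rows index the predicted class, columns the true class, and index nc stands for background; detections whose conf is not above the threshold are dropped; and matching is class-agnostic and one-to-one. The helper names box_iou and update_confusion_matrix are hypothetical, and the greedy highest-IoU-first pairing is a simplification of Ultralytics' deduplication step.

```python
from typing import List, Sequence


def box_iou(a: Sequence[float], b: Sequence[float]) -> float:
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def update_confusion_matrix(
    matrix: List[List[int]],
    detections: List[Sequence[float]],  # rows of (x1, y1, x2, y2, conf, class)
    labels: List[Sequence[float]],      # rows of (class, x1, y1, x2, y2)
    nc: int,
    conf: float = 0.25,
    iou_thres: float = 0.45,
) -> None:
    """Hypothetical Ultralytics-style update of an (nc + 1) x (nc + 1) matrix in place."""
    # Assumption: low-confidence detections are discarded before matching.
    dets = [d for d in detections if d[4] > conf]

    # Collect every (IoU, label index, detection index) pair above the IoU threshold.
    pairs = []
    for li, lab in enumerate(labels):
        for di, det in enumerate(dets):
            iou = box_iou(lab[1:5], det[0:4])
            if iou > iou_thres:
                pairs.append((iou, li, di))

    # Greedy one-to-one matching, highest IoU first.
    pairs.sort(key=lambda p: p[0], reverse=True)
    matched_labels, matched_dets = set(), set()
    for _, li, di in pairs:
        if li in matched_labels or di in matched_dets:
            continue
        matched_labels.add(li)
        matched_dets.add(di)
        matrix[int(dets[di][5])][int(labels[li][0])] += 1  # predicted row, true column

    # Unmatched ground truths are missed objects: background row.
    for li, lab in enumerate(labels):
        if li not in matched_labels:
            matrix[nc][int(lab[0])] += 1

    # Unmatched detections are false positives: background column.
    for di, det in enumerate(dets):
        if di not in matched_dets:
            matrix[int(det[5])][nc] += 1
```

Because matching ignores class in this sketch, a well-localized box with the wrong class lands off the diagonal rather than in the background row or column.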

update(preds, targets_for_cm)
plot(class_names: List[str], normalize: bool = True) → matplotlib.pyplot.Figure

Generates and returns a matplotlib figure of the confusion matrix.
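This page does not say how normalize=True scales the matrix. One common choice, used by Ultralytics, divides each column (true class) by its sum so that each column reads as a distribution over predictions. The sketch below assumes that convention, reuses the eps = 1e-06 attribute listed above as the guard against empty columns, and uses a hypothetical function name.

```python
from typing import List


def normalize_columns(matrix: List[List[float]], eps: float = 1e-6) -> List[List[float]]:
    """Divide each column by its sum; eps keeps all-zero columns at zero instead of dividing by zero."""
    n_rows, n_cols = len(matrix), len(matrix[0])
    col_sums = [sum(matrix[r][c] for r in range(n_rows)) for c in range(n_cols)]
    return [[matrix[r][c] / (col_sums[c] + eps) for c in range(n_cols)] for r in range(n_rows)]
```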

get_matrix() → torch.Tensor

Returns the raw confusion matrix tensor.

Returns:

The (nc + 1) × (nc + 1) confusion matrix.

Return type:

torch.Tensor
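Per-class counts can be read directly off the raw matrix. The sketch below assumes the Ultralytics layout (rows are predicted classes, columns are true classes, and the last row and column are background); if this class orients the matrix the other way, swap the roles of rows and columns. The function name is hypothetical.

```python
from typing import List, Tuple


def per_class_counts(matrix: List[List[int]]) -> Tuple[List[int], List[int], List[int]]:
    """Return (tp, fp, fn) per foreground class from an (nc + 1) x (nc + 1) matrix."""
    nc = len(matrix) - 1  # last index is background
    tp = [matrix[c][c] for c in range(nc)]
    # Row c counts everything predicted as c; off-diagonal entries are false positives.
    fp = [sum(matrix[c]) - matrix[c][c] for c in range(nc)]
    # Column c counts every true object of class c; off-diagonal entries were missed or mislabeled.
    fn = [sum(matrix[r][c] for r in range(nc + 1)) - matrix[c][c] for c in range(nc)]
    return tp, fp, fn
```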

class fruit_project.utils.metrics.MAPEvaluator(image_processor, device, threshold: float = 0.0, id2label: Dict[int, str] | None = None)

Mean Average Precision evaluator for RT-DETRv2, adapted for fruit_project.

image_processor
threshold = 0.0
id2label = None
map_metric
map_50_metric
device
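The page does not state how map_metric and map_50_metric compute their scores. Assuming they follow the COCO definition (101-point interpolated average precision, averaged over classes, and for the full mAP also over IoU thresholds 0.50:0.05:0.95, whereas mAP@50 uses IoU 0.50 only), the per-class, per-threshold step looks like the pure-Python sketch below. It mirrors the pycocotools procedure; the function name is hypothetical, and it assumes detections are already sorted by descending score and flagged as true or false positives at a fixed IoU threshold.

```python
import bisect
from typing import List


def average_precision_101(tp_flags: List[bool], num_gt: int) -> float:
    """COCO-style 101-point interpolated AP for one class at one IoU threshold.

    tp_flags: one entry per detection, sorted by descending confidence.
    num_gt: number of ground-truth objects of this class.
    """
    if num_gt == 0:
        return float("nan")  # COCO excludes classes without ground truth from the mean
    tp = fp = 0
    precisions: List[float] = []
    recalls: List[float] = []
    for is_tp in tp_flags:
        if is_tp:
            tp += 1
        else:
            fp += 1
        precisions.append(tp / (tp + fp))
        recalls.append(tp / num_gt)
    # Precision envelope: make precision non-increasing as recall grows.
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])
    # Sample precision at recall = 0.00, 0.01, ..., 1.00; unreachable recall levels score 0.
    total = 0.0
    for k in range(101):
        idx = bisect.bisect_left(recalls, k / 100)
        if idx < len(precisions):
            total += precisions[idx]
    return total / 101
```

Under this assumption, mAP@50 is the mean of this value over classes at IoU 0.50, and the full mAP additionally averages over the ten IoU thresholds.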
collect_image_sizes(targets)

Collect image sizes from targets.

collect_targets(targets, image_sizes)

Process ground-truth targets; handles targets already in the Hugging Face (HF) processed format.
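Hugging Face DETR-family image processors typically store target boxes as normalized (center_x, center_y, width, height), while mAP libraries commonly expect absolute (x1, y1, x2, y2) pixel coordinates. Assuming that is the conversion collect_targets performs, using the per-image (height, width) gathered by collect_image_sizes, the core arithmetic is the hypothetical helper below.

```python
from typing import List, Sequence, Tuple


def cxcywh_norm_to_xyxy_abs(
    boxes: List[Sequence[float]], image_size: Tuple[int, int]
) -> List[Tuple[float, float, float, float]]:
    """Convert normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2) pixels.

    image_size is (height, width), the order assumed for collect_image_sizes.
    """
    height, width = image_size
    out = []
    for cx, cy, w, h in boxes:
        out.append((
            (cx - w / 2) * width,   # x1
            (cy - h / 2) * height,  # y1
            (cx + w / 2) * width,   # x2
            (cy + h / 2) * height,  # y2
        ))
    return out
```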

collect_predictions(predictions, image_sizes)

Process model predictions using Hugging Face post-processing.
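For RT-DETR, the Hugging Face image processor's post_process_object_detection, with its default focal-loss scoring, applies a sigmoid to every (query, class) logit, keeps the top num_queries pairs, and then drops pairs scoring at or below threshold. The pure-Python sketch below mimics that selection for one image to show what the threshold does. It is an illustrative stand-in with a hypothetical name, not the library implementation, and it omits box rescaling.

```python
import math
from typing import List, Tuple


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def select_detections(
    logits: List[List[float]], threshold: float = 0.0
) -> List[Tuple[float, int, int]]:
    """Return (score, class, query) triples, best first, for one image.

    logits[q][c] is the raw logit of class c for query q.
    """
    num_queries = len(logits)
    scored = [
        (_sigmoid(logit), c, q)
        for q, row in enumerate(logits)
        for c, logit in enumerate(row)
    ]
    # Keep the num_queries best (query, class) pairs; a query may appear more than once.
    scored.sort(key=lambda t: t[0], reverse=True)
    top = scored[:num_queries]
    # Sigmoid scores are positive, so threshold=0.0 keeps practically every top-k pair.
    return [t for t in top if t[0] > threshold]
```

If this mirrors the evaluator, the documented default threshold = 0.0 passes all top-k candidates to the mAP metric, which suits mAP because the metric sweeps over confidence itself.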

get_per_class(map_50_metrics, metric)
get_optimal_f1_ultralytics_style(metrics_dict)
get_averaged_precision_recall_ultralytics_style(optimal_precisions: torch.Tensor, optimal_recalls: torch.Tensor, present_classes: torch.Tensor)

Calculate overall precision and recall…
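The method names suggest the Ultralytics approach: evaluate precision and recall on a common confidence grid, choose the grid point that maximizes mean F1 across classes, report each class's precision and recall at that point, and average them over the classes actually present. The sketch below assumes per-class curves as plain lists and present_classes as a boolean mask (the page does not say whether it is a mask or an index tensor). Function names are hypothetical, and Ultralytics additionally smooths the mean F1 curve before taking the argmax, which is omitted here.

```python
from typing import List, Sequence, Tuple


def optimal_f1_point(
    precision: List[List[float]],
    recall: List[List[float]],
    present: Sequence[bool],
    eps: float = 1e-16,
) -> Tuple[int, List[float], List[float]]:
    """precision[c][k] and recall[c][k] are class c's values at confidence grid point k."""
    classes = [c for c, keep in enumerate(present) if keep]
    if not classes:
        raise ValueError("no present classes")
    n_points = len(precision[0])
    # Mean F1 over present classes at each confidence grid point.
    mean_f1 = []
    for k in range(n_points):
        f1s = [
            2 * precision[c][k] * recall[c][k] / (precision[c][k] + recall[c][k] + eps)
            for c in classes
        ]
        mean_f1.append(sum(f1s) / len(f1s))
    best = max(range(n_points), key=lambda k: mean_f1[k])
    return best, [p[best] for p in precision], [r[best] for r in recall]


def average_over_present(values: Sequence[float], present: Sequence[bool]) -> float:
    """Mean of values restricted to present classes (0.0 if none)."""
    kept = [v for v, keep in zip(values, present) if keep]
    return sum(kept) / len(kept) if kept else 0.0
```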