# SPDX-FileCopyrightText: 2025 Mohamed Khayat
# SPDX-License-Identifier: AGPL-3.0-or-later
from typing import Optional, Tuple, Dict, Any, List
import numpy as np
from torch.utils.data import Dataset
import albumentations as A
from albumentations import Compose
from tqdm import tqdm
from fruit_project.models.transforms_factory import get_bbox_params
from .det_dataset import DET_DS, format_for_hf_processor
[docs]
class AlbumentationsMosaicDataset(Dataset):
    """
    Dataset wrapper that applies Albumentations' Mosaic augmentation.

    Each item is produced either by the mosaic pipeline (mosaic transform,
    resize, then the optional "hard" transforms) or by the fallback path, and
    is then run through the wrapped dataset's HuggingFace processor.

    Mosaic is applied with probability ``cfg.mosaic.prob`` and is switched off
    once ``current_epoch`` reaches ``cfg.epochs - cfg.mosaic.disable_epoch``.
    """

    def __init__(
        self,
        dataset: DET_DS,
        current_epoch: int = 0,
        hard_transforms: Optional[Compose] = None,
        easy_transforms: Optional[Compose] = None,
        cfg=None,
    ):
        """
        Build the mosaic and easy augmentation pipelines around ``dataset``.

        Args:
            dataset: Wrapped detection dataset. Its processor, label maps and
                path lists are copied onto this wrapper.
            current_epoch: Epoch used for mosaic scheduling until
                ``update_epoch`` is called.
            hard_transforms: Optional transforms appended after mosaic + resize.
            easy_transforms: Optional transforms appended after a plain resize.
            cfg: Configuration object. The attributes read here are
                ``model.input_size``, ``model.do_normalize``, ``mosaic.prob``,
                ``mosaic.disable_epoch`` and ``epochs``; it is also passed to
                ``get_bbox_params``.

        Raises:
            ValueError: If ``cfg`` is not provided.
        """
        if cfg is None:
            # Fail early with a clear message instead of an opaque
            # "'NoneType' object has no attribute 'model'".
            raise ValueError("cfg is required to build AlbumentationsMosaicDataset")

        # Keep the wrapped dataset and the user-supplied transform groups;
        # they are read below and by the item-loading methods.
        self.dataset = dataset
        self.hard_transforms = hard_transforms
        self.easy_transforms = easy_transforms

        self.target_size = cfg.model.input_size
        self.mosaic_prob = cfg.mosaic.prob
        self.disable_mosaic_epochs = cfg.mosaic.disable_epoch
        self.current_epoch = current_epoch
        self.total_epochs = cfg.epochs

        # Copy dataset attributes
        self.processor = dataset.processor
        self.id2lbl = dataset.id2lbl
        self.lbl2id = dataset.lbl2id
        self.labels = dataset.labels
        self.image_paths = dataset.image_paths
        self.label_paths = dataset.label_paths
        self.config_dir = dataset.config_dir

        self.normalize = cfg.model.do_normalize
        self.bbox_params = get_bbox_params(cfg)

        # NOTE(review): `self.mosaic_transform` is read here but is not
        # assigned anywhere in the visible source. It has to exist before this
        # line runs (presumably an `A.Mosaic(...)` instance) — confirm.
        mosaic_pipeline = [
            self.mosaic_transform,
            A.Resize(self.target_size, self.target_size),
        ]
        if self.hard_transforms:
            mosaic_pipeline.extend(self.hard_transforms.transforms)
        self.mosaic_compose = A.Compose(
            mosaic_pipeline,
            bbox_params=self.bbox_params,
        )

        easy_pipeline = [A.Resize(self.target_size, self.target_size)]
        if self.easy_transforms:
            easy_pipeline.extend(self.easy_transforms.transforms)
        self.easy_compose = A.Compose(
            easy_pipeline,
            bbox_params=self.bbox_params,
        )

    def update_epoch(self, epoch: int):
        """Update current epoch for mosaic scheduling."""
        self.current_epoch = epoch

    def should_apply_mosaic(self) -> bool:
        """Determine if mosaic should be applied based on epoch and probability."""
        # Mosaic is disabled for the final `disable_mosaic_epochs` epochs.
        if self.current_epoch >= (self.total_epochs - self.disable_mosaic_epochs):
            return False
        return np.random.rand() < self.mosaic_prob

    def _validate_and_clip_bbox(
        self, bbox: List[float], img_width: int, img_height: int
    ) -> List[float]:
        """
        Clip an ``(x, y, w, h)`` box to the image bounds.

        The origin is clamped into ``[0, dim - 1]`` and the extent is limited
        to what remains of the image, with a minimum of 1 on each side.

        Returns:
            The clipped ``[x, y, w, h]`` as floats.
        """
        x, y, w, h = bbox
        x = max(0, min(x, img_width - 1))
        y = max(0, min(y, img_height - 1))
        w = max(1, min(w, img_width - x))
        h = max(1, min(h, img_height - y))
        return [float(x), float(y), float(w), float(h)]

    def _apply_mosaic_augmentation(self, idx: int) -> Tuple[np.ndarray, List, List]:
        """
        Apply the mosaic pipeline to the item at ``idx``.

        Returns:
            ``(image, bboxes, labels)`` from the mosaic pipeline, or the result
            of the fallback transform if the pipeline raises.
        """
        primary_img, primary_boxes, primary_labels = self.dataset.get_raw_item(idx)

        primary_coco_boxes = []
        valid_primary_labels = []
        for box, label in zip(primary_boxes, primary_labels):
            # Clipping via `_validate_and_clip_bbox` is currently disabled
            # here; boxes are forwarded unchanged, skipping only `None`.
            if box is not None:
                primary_coco_boxes.append(box)
                valid_primary_labels.append(int(label))

        # NOTE(review): `_prepare_mosaic_metadata` and
        # `_apply_fallback_transform` are not defined in the visible source.
        metadata_list = self._prepare_mosaic_metadata(idx)
        try:
            # Use the pre-composed mosaic+hard transform pipeline
            augmented = self.mosaic_compose(
                image=primary_img,
                bboxes=primary_coco_boxes,
                labels=valid_primary_labels,
                mosaic_metadata=metadata_list,
            )
            return augmented["image"], augmented["bboxes"], augmented["labels"]
        except Exception as e:
            # Best effort: a failed mosaic must not abort training, so log it
            # and fall back to the easy transform for this sample.
            tqdm.write(f"Mosaic augmentation failed for idx {idx}: {e}. Falling back.")
            return self._apply_fallback_transform(idx, use_easy_transforms=True)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Get item with optional mosaic augmentation.

        Returns:
            The processor output with the leading batch element selected for
            every key.

        Raises:
            AttributeError: If no processor is available, or if the processor
                call fails (the original exception is chained as the cause).
        """
        if self.should_apply_mosaic():
            img, boxes, labels = self._apply_mosaic_augmentation(idx)
        else:
            img, boxes, labels = self._apply_fallback_transform(idx)

        target = format_for_hf_processor(boxes, labels, idx)

        processor = getattr(self, "processor", None)
        if not processor:
            raise AttributeError("No processor found")

        try:
            result = processor(
                images=img,
                annotations=target,
                return_tensors="pt",
                do_normalize=self.normalize,
                # Both pipelines already resize to `target_size`, which holds
                # the configured model input size.
                size={"height": self.target_size, "width": self.target_size},
                do_pad=False,
            )
            # The processor returns batched values; keep the single sample.
            return {k: v[0] for k, v in result.items()}
        except Exception as e:
            tqdm.write(f"Processor failed for idx {idx}: {e}")
            raise AttributeError("HuggingFace Processor failed") from e
[docs]
def create_albumentations_mosaic_dataset(
    dataset: DET_DS,
    hard_transforms: Optional[Compose] = None,
    easy_transforms: Optional[Compose] = None,
    cfg=None,
) -> AlbumentationsMosaicDataset:
    """
    Wrap ``dataset`` in an ``AlbumentationsMosaicDataset`` starting at epoch 0.

    Args:
        dataset: Detection dataset to wrap.
        hard_transforms: Optional transforms forwarded to the wrapper.
        easy_transforms: Optional transforms forwarded to the wrapper.
        cfg: Configuration object forwarded to the wrapper.

    Returns:
        The newly constructed ``AlbumentationsMosaicDataset``.
    """
    return AlbumentationsMosaicDataset(
        dataset=dataset,
        current_epoch=0,
        hard_transforms=hard_transforms,
        easy_transforms=easy_transforms,
        cfg=cfg,
    )