Source code for fruit_project.models.transforms_factory

# SPDX-FileCopyrightText: 2025 Mohamed Khayat
# SPDX-License-Identifier: AGPL-3.0-or-later

import os
from typing import Dict

# Albumentations checks this variable when it is first imported (to skip its
# online update check), so it must be set BEFORE the import below; assigning it
# after `import albumentations` has no effect on that first import.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import albumentations as A  # noqa: E402  (must follow the env var above)
from omegaconf import DictConfig  # noqa: E402


# Constant grey used to fill pixels introduced by the affine/perspective warps
# and by coarse dropout.
_FILL_VALUE = (114, 114, 114)


def get_transforms(cfg: DictConfig, id2label: Dict[int, str]) -> Dict[str, A.Compose]:
    """
    Generates a dictionary of Albumentations transformations for training and testing.

    All pipelines share the bbox parameters produced by ``get_bbox_params(cfg)``.

    Args:
        cfg (DictConfig): Configuration object. Attributes read here:
            ``cfg.model.input_size`` (height/width of the bbox-safe crop in the
            "hard" pipeline) and ``cfg.aug`` (``"hard"`` selects the heavy
            training pipeline; any other value selects the light one). The
            thresholds read by ``get_bbox_params`` (``min_viz``, ``min_area``,
            ``min_width``, ``min_height``) must also be present.
        id2label (Dict[int, str]): Mapping from class id to class name. Its keys
            are passed as ``bbox_labels`` to ``ConstrainedCoarseDropout``, so
            boxes of every class are eligible targets.

    Returns:
        Dict[str, A.Compose]: A dictionary with keys
            - ``"train"``: the hard pipeline if ``cfg.aug == "hard"``, otherwise
              the safe pipeline;
            - ``"train_easy"``: always the safe (light) training pipeline;
            - ``"test"``: a pipeline containing only ``A.NoOp`` but with the
              same ``bbox_params``, so boxes still go through the same bbox
              processing.
    """
    bbox_params = get_bbox_params(cfg)
    box_labels = list(id2label)

    # Both pipelines are always built (as before), so a config missing
    # e.g. ``model.input_size`` fails here regardless of ``cfg.aug``.
    hard_train_transforms = _build_hard_train_transforms(cfg, bbox_params, box_labels)
    safe_train_transforms = _build_safe_train_transforms(bbox_params)

    return {
        "train": hard_train_transforms if cfg.aug == "hard" else safe_train_transforms,
        "train_easy": safe_train_transforms,
        "test": A.Compose([A.NoOp()], bbox_params=bbox_params),
    }


def _build_hard_train_transforms(cfg, bbox_params, box_labels) -> A.Compose:
    """Heavy training pipeline: bbox-safe crop, flips, warps, dropout, colour and blur jitter."""
    # NOTE(review): output size is not fixed by this pipeline. The crop/resize to
    # cfg.model.input_size runs with p=0.4 only, and Perspective(fit_output=True)
    # can change the image size. Presumably a later resize handles this; confirm.
    # Inside A.OneOf, the children's ``p`` values act as relative selection weights.
    return A.Compose(
        [
            A.RandomSizedBBoxSafeCrop(
                height=cfg.model.input_size,
                width=cfg.model.input_size,
                erosion_rate=0.0,
                p=0.4,
            ),
            A.HorizontalFlip(p=0.5),
            # Mild geometric jitter: either an affine warp or a perspective warp.
            A.OneOf(
                [
                    A.Affine(
                        scale=(0.8, 1.2),
                        translate_percent={"x": (-0.02, 0.02), "y": (-0.02, 0.02)},
                        rotate=(-5, 5),
                        fill=_FILL_VALUE,
                        p=0.5,
                    ),
                    A.Perspective(
                        scale=(0.02, 0.05),
                        fit_output=True,
                        fill=_FILL_VALUE,
                        p=0.5,
                    ),
                ],
                p=0.4,
            ),
            A.ConstrainedCoarseDropout(
                num_holes_range=(1, 2),
                hole_height_range=(0.05, 0.15),
                hole_width_range=(0.05, 0.15),
                fill=_FILL_VALUE,
                bbox_labels=box_labels,
                p=0.2,
            ),
            # Photometric jitter: brightness/contrast, gamma or tone curve.
            A.OneOf(
                [
                    A.RandomBrightnessContrast(
                        brightness_limit=0.2,
                        contrast_limit=0.2,
                        ensure_safe_range=True,
                        p=0.5,
                    ),
                    A.RandomGamma(gamma_limit=(80, 120), p=0.5),
                    A.RandomToneCurve(p=0.5),
                ],
                p=0.4,
            ),
            # Colour jitter: HSV shift or per-channel RGB shift.
            A.OneOf(
                [
                    A.HueSaturationValue(
                        hue_shift_limit=20,
                        sat_shift_limit=30,
                        val_shift_limit=20,
                        p=0.5,
                    ),
                    A.RGBShift(
                        r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.5
                    ),
                ],
                p=0.4,
            ),
            # Occasional light blur of one kind.
            A.OneOf(
                [
                    A.Blur(blur_limit=3, p=0.5),
                    A.MotionBlur(blur_limit=3, p=0.5),
                    A.Defocus(radius=(1, 3), alias_blur=(0.1, 0.25), p=0.1),
                    A.MedianBlur(blur_limit=3, p=0.2),
                ],
                p=0.1,
            ),
            A.CLAHE(clip_limit=1.5, p=0.3),
        ],
        bbox_params=bbox_params,
    )


def _build_safe_train_transforms(bbox_params) -> A.Compose:
    """Light training pipeline: flip, rare perspective warp and blur, mild colour jitter."""
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.Perspective(
                scale=(0.02, 0.05), fit_output=True, fill=_FILL_VALUE, p=0.1
            ),
            A.OneOf(
                [
                    A.Blur(blur_limit=3, p=0.5),
                    A.MotionBlur(blur_limit=3, p=0.5),
                    A.Defocus(radius=(1, 5), alias_blur=(0.1, 0.25), p=0.1),
                ],
                p=0.1,
            ),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.1),
        ],
        bbox_params=bbox_params,
    )
def get_bbox_params(cfg: DictConfig) -> A.BboxParams:
    """
    Builds the bounding-box parameters shared by every pipeline in this module.

    Boxes are expected in Albumentations' ``"coco"`` format
    (``[x_min, y_min, width, height]`` in pixels). Their class labels must be
    passed to the pipeline under the ``labels`` keyword, because
    ``label_fields=["labels"]``. Boxes are clipped to the image. Invalid boxes,
    and boxes falling below the configured thresholds after transformation,
    are dropped together with their labels.

    Args:
        cfg (DictConfig): Configuration providing the filtering thresholds:
            ``min_viz`` (minimum visible fraction of the original box area),
            ``min_area``, ``min_width`` and ``min_height``.

    Returns:
        A.BboxParams: Parameters to pass as ``bbox_params`` to ``A.Compose``.
    """
    return A.BboxParams(
        format="coco",
        label_fields=["labels"],
        clip=True,
        filter_invalid_bboxes=True,
        min_visibility=cfg.min_viz,
        min_area=cfg.min_area,
        min_width=cfg.min_width,
        min_height=cfg.min_height,
    )