# Source code for fruit_project.models.transforms_factory
# SPDX-FileCopyrightText: 2025 Mohamed Khayat
# SPDX-License-Identifier: AGPL-3.0-or-later
from typing import Dict

import os

# Must be set BEFORE importing albumentations: the library performs its
# update/version check at import time, so assigning the variable after the
# import (as the original code did) has no effect.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import albumentations as A
from omegaconf import DictConfig
# [docs]
def get_transforms(cfg: DictConfig, id2label: Dict[int, str]) -> Dict[str, A.Compose]:
    """
    Generate Albumentations pipelines for training and evaluation.

    Args:
        cfg (DictConfig): Configuration object. This function reads
            ``cfg.model.input_height``, ``cfg.model.input_width``,
            ``cfg.aug`` ("hard" selects the heavy augmentation pipeline,
            anything else selects the safe one) and, indirectly through
            :func:`get_bbox_params`, ``cfg.min_area``.
        id2label (Dict[int, str]): Mapping from class id to class name.
            Only the keys are used, as the bbox label set for
            ``ConstrainedCoarseDropout``.

    Returns:
        Dict[str, A.Compose]: A dictionary with keys:
            - "train": hard or safe pipeline depending on ``cfg.aug``,
            - "train_easy": always the safe pipeline,
            - "test": deterministic resize + bbox-safe crop.
    """
    train_bbox_params, test_bbox_params = get_bbox_params(cfg)
    # ConstrainedCoarseDropout needs the class ids whose boxes it must respect.
    box_labels = list(id2label)

    # Multi-scale cropping: resize the shorter side to 1x / 1.5x / 3x the
    # target size before a bbox-safe crop, so objects appear at varied scales.
    # Erosion grows with scale, allowing slightly tighter crops on zoomed views.
    multi_view_crop = A.OneOf(
        [
            A.Compose(
                [
                    A.SmallestMaxSize(
                        max_size_hw=(cfg.model.input_height, cfg.model.input_width),
                        p=1.0,
                    ),
                    A.RandomSizedBBoxSafeCrop(
                        height=cfg.model.input_height,
                        width=cfg.model.input_width,
                        erosion_rate=0.0,
                        p=1.0,
                    ),
                ]
            ),
            A.Compose(
                [
                    A.SmallestMaxSize(
                        max_size_hw=(
                            int(cfg.model.input_height * 1.5),
                            int(cfg.model.input_width * 1.5),
                        ),
                        p=1.0,
                    ),
                    A.RandomSizedBBoxSafeCrop(
                        height=cfg.model.input_height,
                        width=cfg.model.input_width,
                        erosion_rate=0.1,
                        p=1.0,
                    ),
                ]
            ),
            A.Compose(
                [
                    A.SmallestMaxSize(
                        max_size_hw=(
                            cfg.model.input_height * 3,
                            cfg.model.input_width * 3,
                        ),
                        p=1.0,
                    ),
                    A.RandomSizedBBoxSafeCrop(
                        height=cfg.model.input_height,
                        width=cfg.model.input_width,
                        erosion_rate=0.2,
                        p=1.0,
                    ),
                ]
            ),
        ],
        p=1.0,
    )

    # Heavy augmentation pipeline: multi-scale crop plus geometric,
    # photometric, occlusion and blur perturbations (mostly low-probability).
    hard_train_transforms = A.Compose(
        [
            multi_view_crop,
            A.HorizontalFlip(p=0.5),
            # fit_output keeps the whole warped image; gray (114) fill matches
            # the common detection-padding convention.
            A.Perspective(
                scale=(0.02, 0.05),
                fit_output=True,
                fill=(114, 114, 114),
                p=0.15,
            ),
            A.OneOf(
                [
                    A.Sharpen(p=0.5),
                    A.Emboss(p=0.5),
                    A.RandomToneCurve(p=0.5),
                ],
                p=0.2,
            ),
            # Rare, small occlusions that avoid erasing labeled boxes entirely.
            A.ConstrainedCoarseDropout(
                num_holes_range=(1, 2),
                hole_height_range=(0.02, 0.08),
                hole_width_range=(0.02, 0.08),
                fill=(114, 114, 114),
                bbox_labels=box_labels,
                p=0.01,
            ),
            A.OneOf(
                [
                    A.RandomBrightnessContrast(
                        brightness_limit=0.1,
                        contrast_limit=0.1,
                        ensure_safe_range=True,
                        p=0.5,
                    ),
                    A.RandomGamma(gamma_limit=(60, 100), p=0.5),
                    A.RandomToneCurve(p=0.5),
                ],
                p=0.2,
            ),
            A.OneOf(
                [
                    A.HueSaturationValue(
                        hue_shift_limit=15,
                        sat_shift_limit=25,
                        val_shift_limit=15,
                        p=0.5,
                    ),
                    A.RGBShift(
                        r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.5
                    ),
                ],
                p=0.2,
            ),
            A.OneOf(
                [
                    A.Blur(blur_limit=3, p=0.5),
                    A.MotionBlur(blur_limit=3, p=0.5),
                    A.Defocus(radius=(1, 3), alias_blur=(0.1, 0.25), p=0.1),
                    A.MedianBlur(blur_limit=3, p=0.2),
                ],
                p=0.1,
            ),
            A.CLAHE(clip_limit=1.5, p=0.1),
        ],
        bbox_params=train_bbox_params,
    )

    # Lighter pipeline: single-scale crop (applied 90% of the time) and a
    # small set of mild photometric / blur perturbations.
    safe_train_transforms = A.Compose(
        [
            A.Compose(
                [
                    A.SmallestMaxSize(
                        max_size_hw=(cfg.model.input_height, cfg.model.input_width),
                        p=1.0,
                    ),
                    A.RandomSizedBBoxSafeCrop(
                        height=cfg.model.input_height,
                        width=cfg.model.input_width,
                        erosion_rate=0.1,
                        p=1.0,
                    ),
                ],
                p=0.9,
            ),
            A.HorizontalFlip(p=0.5),
            A.Perspective(
                scale=(0.02, 0.05), fit_output=True, fill=(114, 114, 114), p=0.1
            ),
            A.OneOf(
                [
                    A.Blur(blur_limit=3, p=0.5),
                    A.MotionBlur(blur_limit=3, p=0.5),
                    A.Defocus(radius=(1, 5), alias_blur=(0.1, 0.25), p=0.1),
                ],
                p=0.1,
            ),
            A.RandomBrightnessContrast(p=0.2),
            A.HueSaturationValue(p=0.1),
        ],
        bbox_params=train_bbox_params,
    )

    transforms = {
        "train": hard_train_transforms if cfg.aug == "hard" else safe_train_transforms,
        "train_easy": safe_train_transforms,
        # Test-time: deterministic resize + bbox-safe crop only; uses the
        # bbox params WITHOUT the min_area filter so no GT boxes are dropped.
        "test": A.Compose(
            [
                A.SmallestMaxSize(
                    max_size_hw=(cfg.model.input_height, cfg.model.input_width),
                    p=1.0,
                ),
                A.RandomSizedBBoxSafeCrop(
                    height=cfg.model.input_height,
                    width=cfg.model.input_width,
                    erosion_rate=0.1,
                    p=1.0,
                ),
            ],
            p=1.0,
            bbox_params=test_bbox_params,
        ),
    }
    return transforms
# [docs]
def get_bbox_params(cfg):
    """
    Build the bbox handling settings for the train and test pipelines.

    Args:
        cfg: Configuration providing ``min_area``, the minimum pixel area a
            box must keep after augmentation to survive in the train pipeline.

    Returns:
        tuple: ``(train_params, test_params)`` as ``A.BboxParams``. Both use
        COCO format, a "labels" label field, and clip boxes to image bounds;
        only the train params additionally filter by ``min_area``.
    """
    params = {
        "format": "coco",
        "label_fields": ["labels"],
        "clip": True,
    }
    # Train filters tiny post-augmentation boxes; test keeps every GT box.
    return (
        A.BboxParams(**params, min_area=cfg.min_area),
        A.BboxParams(**params),
    )