Source code for fruit_project.models.transforms_factory
import os

# This must be set BEFORE albumentations is imported: the library performs its
# online version check at import time and only skips it when the variable is
# already present in the environment.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import albumentations as A  # noqa: E402
import cv2  # noqa: E402
from omegaconf import DictConfig  # noqa: E402
[docs]
def _train_bbox_params(cfg: DictConfig) -> A.BboxParams:
    """Build the bounding-box settings shared by both training pipelines."""
    return A.BboxParams(
        format="coco",
        label_fields=["labels"],
        clip=True,
        min_visibility=cfg.min_viz,
        min_area=cfg.min_area,
        min_width=cfg.min_width,
        min_height=cfg.min_height,
    )


def _hard_train_transforms(cfg: DictConfig) -> A.Compose:
    """Build the "hard" training pipeline (always crops, grouped color ops)."""
    input_size = cfg.model.input_size
    return A.Compose(
        [
            A.RandomSizedBBoxSafeCrop(
                height=input_size,
                width=input_size,
                erosion_rate=0.2,
                p=1.0,
            ),
            A.HorizontalFlip(p=0.5),
            # At most one color-shift transform is applied, half of the time.
            A.OneOf(
                [
                    A.RGBShift(15, 15, 15, p=1.0),
                    A.HueSaturationValue(
                        hue_shift_limit=10,
                        sat_shift_limit=20,
                        val_shift_limit=10,
                        p=1.0,
                    ),
                    # A.ToGray(p=1.0),
                ],
                p=0.5,
            ),
            # At most one brightness/tone transform is applied, half of the time.
            A.OneOf(
                [
                    A.RandomBrightnessContrast(p=1.0),
                    A.RandomToneCurve(p=1.0),
                ],
                p=0.5,
            ),
            A.CLAHE(clip_limit=2.0, p=0.3),
        ],
        bbox_params=_train_bbox_params(cfg),
    )


def _safe_train_transforms(cfg: DictConfig) -> A.Compose:
    """Build the "safe" training pipeline (independent, explicitly bounded ops)."""
    input_size = cfg.model.input_size
    return A.Compose(
        [
            A.RandomSizedBBoxSafeCrop(
                height=input_size,
                width=input_size,
                p=0.8,
                erosion_rate=0.2,
            ),
            A.HorizontalFlip(p=0.5),
            A.RGBShift(
                p=0.5,
                b_shift_limit=(-15.0, 15.0),
                g_shift_limit=(-15.0, 15.0),
                r_shift_limit=(-15.0, 15.0),
            ),
            A.RandomBrightnessContrast(
                p=0.5,
                brightness_limit=(-0.2, 0.2),
                contrast_limit=(-0.2, 0.2),
            ),
            A.HueSaturationValue(
                p=0.3,
                hue_shift_limit=(-10.0, 10.0),
                sat_shift_limit=(-25.0, 25.0),
                val_shift_limit=(-10.0, 10.0),
            ),
            A.CLAHE(p=0.2, clip_limit=(1.0, 2.0)),
        ],
        bbox_params=_train_bbox_params(cfg),
    )


def _test_transforms(cfg: DictConfig) -> A.Compose:
    """Build the test pipeline: resize the longest side, then pad to a square."""
    input_size = cfg.model.input_size
    return A.Compose(
        [
            A.LongestMaxSize(max_size=input_size, p=1.0),
            A.PadIfNeeded(
                min_height=input_size,
                min_width=input_size,
                border_mode=cv2.BORDER_CONSTANT,
                fill=0,
                p=1.0,
            ),
        ],
        # No min-visibility/area/size thresholds are configured at test time.
        bbox_params=A.BboxParams(
            format="coco",
            label_fields=["labels"],
            clip=True,
        ),
    )


def get_transforms(cfg: DictConfig) -> dict[str, A.Compose]:
    """
    Generates a dictionary of Albumentations transformations for training and testing.

    Only the training pipeline selected by ``cfg.aug`` is constructed.

    Args:
        cfg (DictConfig): Configuration object containing the following attributes:
            - model.input_size: target height and width used by the crop,
              resize and pad transforms.
            - aug: ``"hard"`` selects the hard training pipeline; any other
              value selects the safe training pipeline.
            - min_viz: ``min_visibility`` passed to the training ``BboxParams``.
            - min_area: ``min_area`` passed to the training ``BboxParams``.
            - min_width: ``min_width`` passed to the training ``BboxParams``.
            - min_height: ``min_height`` passed to the training ``BboxParams``.

    Returns:
        dict: A dictionary with keys "train" and "test", each mapping to an
        ``A.Compose`` pipeline. Both pipelines are configured for bounding
        boxes in "coco" format with labels supplied under the ``labels`` field.
    """
    if cfg.aug == "hard":
        build_train = _hard_train_transforms
    else:
        build_train = _safe_train_transforms

    return {
        "train": build_train(cfg),
        "test": _test_transforms(cfg),
    }