# Source code for fruit_project.utils.data

# SPDX-FileCopyrightText: 2025 Mohamed Khayat
# SPDX-License-Identifier: AGPL-3.0-or-later

import functools
import os
from collections import Counter
from typing import Dict, List, Optional, Tuple

import torch
from albumentations import Compose
from omegaconf import DictConfig
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm
from transformers import AutoImageProcessor, BatchEncoding

from fruit_project.utils.datasets.alb_mosaic_dataset import (
    create_albumentations_mosaic_dataset,
)
from fruit_project.utils.datasets.det_dataset import DET_DS
from fruit_project.utils.general import seed_worker


def make_datasets(cfg: DictConfig) -> Tuple[DET_DS, DET_DS, DET_DS]:
    """
    Create the training, testing, and validation datasets.

    All three splits share the same root directory, config file, and
    normalization flag; only the split name differs.

    Args:
        cfg (DictConfig): Configuration with ``root_dir``,
            ``data_conf_file``, and ``model.do_normalize``.

    Returns:
        Tuple[DET_DS, DET_DS, DET_DS]: The (train, test, val) datasets,
        each built without transforms (set later via ``set_transforms``).
    """
    print("making datasets")

    def _build(split: str) -> DET_DS:
        # One place for the shared constructor arguments; only the split varies.
        return DET_DS(
            root_dir=cfg.root_dir,
            split=split,
            config_file=cfg.data_conf_file,
            transforms=None,
            normalize=cfg.model.do_normalize,
        )

    train_ds_base = _build("train")
    val_ds = _build("val")
    test_ds = _build("test")
    # NOTE: return order is (train, test, val), matching make_dataloaders.
    return train_ds_base, test_ds, val_ds
def get_sampler(
    train_ds: DET_DS, generator: torch.Generator
) -> Optional[WeightedRandomSampler]:
    """
    Create a WeightedRandomSampler that balances classes across images.

    Each class gets weight ``1 / image_count``; an image is weighted by the
    rarest class it contains, so images with under-represented classes are
    drawn more often.

    Args:
        train_ds (DET_DS): Training dataset exposing ``label_paths``, a list
            of YOLO-style label files (first token per line is the class id).
        generator (torch.Generator): Source of the base seed; the sampler
            derives its own seed at ``initial_seed() + 1`` so it does not
            consume the main generator's stream.

    Returns:
        Optional[WeightedRandomSampler]: Sampler (with replacement) over the
        dataset, or ``None`` when no labels exist (caller falls back to
        uniform sampling).
    """
    print("Creating weighted sampler...")
    class_counts: Counter = Counter()
    image_classes: List[set] = []
    for label_path in tqdm(
        train_ds.label_paths, desc="1/2: Counting classes for sampler"
    ):
        classes_in_image = set()
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f:
                    if line.strip():
                        class_id = int(line.strip().split()[0])
                        classes_in_image.add(class_id)
        # Update with a set so each class counts once per image, not per box.
        class_counts.update(classes_in_image)
        image_classes.append(classes_in_image)

    if not class_counts:
        print(
            "Warning: No classes found in dataset for sampler. Using uniform sampling."
        )
        return None

    class_weights = {c: 1.0 / cnt for c, cnt in class_counts.items()}
    # Hoisted loop invariant: weight for images without any labels. At this
    # point class_weights is guaranteed non-empty (we returned None above).
    fallback_weight = min(class_weights.values())
    weights = []
    for classes_in_image in tqdm(image_classes, desc="2/2: Assigning Class Weights"):
        if not classes_in_image:
            weights.append(fallback_weight)
            continue
        weights.append(max(class_weights[c] for c in classes_in_image))

    weight_tensor = torch.tensor(weights, dtype=torch.double)
    sampler_generator = torch.Generator().manual_seed(generator.initial_seed() + 1)
    sampler = WeightedRandomSampler(
        weight_tensor,
        num_samples=len(weight_tensor),
        replacement=True,
        generator=sampler_generator,
    )
    print("Weighted sampler created.")
    return sampler
def make_dataloaders(
    cfg: DictConfig,
    train_ds_base: DET_DS,
    test_ds: DET_DS,
    val_ds: DET_DS,
    generator: torch.Generator,
    processor: AutoImageProcessor,
    transforms: Compose,
) -> Tuple[DataLoader, DataLoader, DataLoader, Dict]:
    """
    Build the training, testing, and validation dataloaders.

    Args:
        cfg (DictConfig): Configuration with batch size, worker count, seed,
            sampling and mosaic options, and the model name.
        train_ds_base (DET_DS): The training dataset.
        test_ds (DET_DS): The testing dataset.
        val_ds (DET_DS): The validation dataset.
        generator (torch.Generator): Generator for reproducibility; the eval
            loaders derive their own seeds from its ``initial_seed()``.
        processor (AutoImageProcessor): Image preprocessor attached to every
            dataset.
        transforms (Compose): Split-keyed transform pipelines.

    Returns:
        Tuple[DataLoader, DataLoader, DataLoader, Dict]: Train, test, and
        val dataloaders plus one batch drawn from the test loader.
    """
    print("making dataloaders")
    worker_init = functools.partial(seed_worker, base_seed=cfg.seed)
    collate = functools.partial(collate_fn, model_name=cfg.model.name)

    # Attach the processor to every split before transforms are assigned.
    train_ds_base.processor = processor
    test_ds.processor = processor
    val_ds.processor = processor

    train_ds_base, test_ds, val_ds = set_transforms(
        train_ds_base, test_ds, val_ds, transforms
    )

    if cfg.mosaic.use:
        print("Applying Mosaic augmentation to the training dataset.")
        train_ds = create_albumentations_mosaic_dataset(
            train_ds_base, transforms["train"], transforms["train_easy"], cfg
        )
    else:
        train_ds = train_ds_base

    # Weighted sampling and shuffling are mutually exclusive in DataLoader.
    sampler = get_sampler(train_ds, generator) if cfg.do_sample else None

    train_dl = DataLoader(
        train_ds,
        batch_size=cfg.step_batch_size,
        shuffle=not cfg.do_sample,
        sampler=sampler,
        num_workers=cfg.num_workers,
        persistent_workers=True,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=worker_init,
        generator=generator,
        collate_fn=collate,
    )

    def _eval_loader(dataset, seed_offset: int) -> DataLoader:
        # Test and val loaders share every setting except the dataset and
        # a distinct, reproducible seed offset.
        return DataLoader(
            dataset,
            batch_size=cfg.step_batch_size,
            num_workers=cfg.num_workers,
            persistent_workers=True,
            pin_memory=True,
            worker_init_fn=worker_init,
            generator=torch.Generator().manual_seed(
                generator.initial_seed() + seed_offset
            ),
            collate_fn=collate,
        )

    test_dl = _eval_loader(test_ds, 2)
    val_dl = _eval_loader(val_ds, 3)

    test_sample = next(iter(test_dl))
    return train_dl, test_dl, val_dl, test_sample
def get_labels_and_mappings(
    train_labels: List, test_labels: List
) -> Tuple[List, Dict, Dict]:
    """
    Build the sorted label vocabulary and its id<->name mappings.

    Args:
        train_labels (List): Labels observed in the training dataset.
        test_labels (List): Labels observed in the testing dataset.

    Returns:
        Tuple[List, Dict, Dict]: A tuple containing:
            - labels (List): Sorted unique labels from both splits.
            - id2lbl (Dict): Class id -> label name.
            - lbl2id (Dict): Label name -> class id.
    """
    labels = sorted(set(train_labels) | set(test_labels))
    id2lbl = dict(enumerate(labels))
    lbl2id = {name: idx for idx, name in enumerate(labels)}
    return labels, id2lbl, lbl2id
def collate_fn(batch: List[Dict], model_name: str) -> Dict:
    """
    Collate per-sample dicts into one batched dict for the dataloader.

    Args:
        batch (List[Dict]): Samples, each containing "pixel_values"
            (a tensor), "labels", and optionally "pixel_mask" (a tensor).
        model_name (str): Model identifier; YOLOS variants take no pixel
            mask, so it is dropped for them.

    Returns:
        Dict: "pixel_values" stacked into a single tensor, "labels" kept
        as a list, and "pixel_mask" stacked when present and applicable.
    """
    data = {}
    data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
    data["labels"] = [x["labels"] for x in batch]
    # YOLOS models have no pixel_mask input; only batch it for other models.
    if "pixel_mask" in batch[0] and "yolos" not in model_name:
        data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
    return data
def set_transforms(
    train_ds,
    test_ds,
    val_ds,
    transforms: Dict,
):
    """
    Attach split-specific transforms to the datasets.

    Args:
        train_ds: Training dataset; receives ``transforms["train"]``.
        test_ds: Testing dataset; receives ``transforms["test"]``.
        val_ds: Validation dataset; also receives ``transforms["test"]``,
            so validation uses the deterministic test-time pipeline.
        transforms (Dict): Mapping from split name ("train"/"test") to a
            transform pipeline.

    Returns:
        Tuple: The same three dataset objects, mutated in place.
    """
    train_ds.transforms = transforms["train"]
    test_ds.transforms = transforms["test"]
    val_ds.transforms = transforms["test"]
    return train_ds, test_ds, val_ds