fruit_project.utils.data

Attributes

rlimit

Functions

download_dataset()

Downloads the dataset from Kaggle using the Kaggle API.

make_datasets(...)

Creates training, testing, and validation datasets.

get_sampler(→ torch.utils.data.WeightedRandomSampler)

Creates a WeightedRandomSampler for the training dataset.

make_dataloaders(→ Tuple[torch.utils.data.DataLoader, ...)

Creates dataloaders for training, testing, and validation datasets.

get_labels_and_mappings(→ Tuple[List, Dict, Dict])

Generates labels and mappings for class IDs and names.

collate_fn(→ Tuple[transformers.BatchEncoding, List])

Collates a batch of data for the dataloader.

set_transforms(→ Tuple[torch.utils.data.DataLoader, ...)

Sets transformations for the datasets in the dataloaders.

Module Contents

fruit_project.utils.data.rlimit
fruit_project.utils.data.download_dataset()

Downloads the dataset from Kaggle using the Kaggle API.

Raises:

RuntimeError – If the required environment variables for Kaggle API are not set.

Returns:

None
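
The credential check described above can be sketched as follows. This is a minimal illustration, not the module's code: the helper name `check_kaggle_env` is hypothetical, and it assumes the check looks at `KAGGLE_USERNAME` and `KAGGLE_KEY`, the two variables the Kaggle API client reads for credentials.

```python
import os

# Assumption: the module checks the two variables the Kaggle API client
# reads for credentials; its actual check may differ.
REQUIRED_VARS = ("KAGGLE_USERNAME", "KAGGLE_KEY")


def check_kaggle_env(env=None):
    """Raise RuntimeError if a required variable is missing or empty.

    `env` defaults to os.environ; passing a dict makes the check testable.
    """
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError("Missing Kaggle credentials: " + ", ".join(missing))
```

Calling such a check before any download attempt surfaces a configuration problem as a clear `RuntimeError` instead of a later authentication failure.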

fruit_project.utils.data.make_datasets(cfg: omegaconf.DictConfig) → Tuple[fruit_project.utils.datasets.det_dataset.DET_DS, fruit_project.utils.datasets.det_dataset.DET_DS, fruit_project.utils.datasets.det_dataset.DET_DS]

Creates training, testing, and validation datasets.

Parameters:

cfg (DictConfig) – Configuration object containing dataset parameters.

Returns:

The training, testing, and validation datasets.

Return type:

Tuple[DET_DS, DET_DS, DET_DS]

fruit_project.utils.data.get_sampler(train_ds: fruit_project.utils.datasets.det_dataset.DET_DS, strat: str) → torch.utils.data.WeightedRandomSampler

Creates a WeightedRandomSampler for the training dataset.

Parameters:
  • train_ds (DET_DS) – The training dataset.

  • strat (str) – The strategy for weighting (‘max’ or ‘mean’).

Returns:

A sampler for the training dataset.

Return type:

WeightedRandomSampler
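
The page does not say how ‘max’ and ‘mean’ turn class statistics into per-sample weights. One plausible reading, sketched below, is that each image holds several labels, each class is weighted by its inverse frequency, and the image's weight is the maximum or the mean of its labels' class weights. The helper `sample_weights` and its input format (one list of labels per image) are assumptions made for illustration only.

```python
from collections import Counter


def sample_weights(labels_per_image, strat="max"):
    """Return one sampling weight per image (hypothetical helper).

    Assumption: class weight = 1 / class count, and an image's weight is
    the max or the mean of the class weights of the labels it contains.
    """
    if strat not in ("max", "mean"):
        raise ValueError(f"unknown strategy: {strat!r}")
    counts = Counter(lbl for labels in labels_per_image for lbl in labels)
    class_w = {lbl: 1.0 / n for lbl, n in counts.items()}
    # Images without labels fall back to the smallest class weight.
    fallback = min(class_w.values(), default=1.0)
    weights = []
    for labels in labels_per_image:
        per_label = [class_w[lbl] for lbl in labels]
        if not per_label:
            weights.append(fallback)
        elif strat == "max":
            weights.append(max(per_label))
        else:
            weights.append(sum(per_label) / len(per_label))
    return weights
```

A list produced this way could then be handed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`, so that images containing rare classes are drawn more often.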

fruit_project.utils.data.make_dataloaders(cfg: omegaconf.DictConfig, train_ds: fruit_project.utils.datasets.det_dataset.DET_DS, test_ds: fruit_project.utils.datasets.det_dataset.DET_DS, val_ds: fruit_project.utils.datasets.det_dataset.DET_DS, generator: torch.Generator, processor: transformers.AutoImageProcessor, transforms: albumentations.Compose) → Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]

Creates dataloaders for training, testing, and validation datasets.

Parameters:
  • cfg (DictConfig) – Configuration object containing dataloader parameters.

  • train_ds (DET_DS) – The training dataset.

  • test_ds (DET_DS) – The testing dataset.

  • val_ds (DET_DS) – The validation dataset.

  • generator (torch.Generator) – A PyTorch generator for reproducibility.

  • processor (AutoImageProcessor) – Processor for image preprocessing.

  • transforms (Compose) – Transformations to apply to the datasets.

Returns:

The training, testing, and validation dataloaders.

Return type:

Tuple[DataLoader, DataLoader, DataLoader]

fruit_project.utils.data.get_labels_and_mappings(train_labels: List, test_labels: List) → Tuple[List, Dict, Dict]

Generates labels and mappings for class IDs and names.

Parameters:
  • train_labels (List) – List of labels from the training dataset.

  • test_labels (List) – List of labels from the testing dataset.

Returns:

A tuple containing:
  • labels (List): Sorted list of unique labels.

  • id2lbl (Dict): Mapping from class IDs to labels.

  • lbl2id (Dict): Mapping from labels to class IDs.

Return type:

Tuple[List, Dict, Dict]
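
The return value described above can be illustrated with a short sketch. It assumes both inputs are flat lists of hashable labels (for example class-name strings) and that class IDs are the 0-based positions in the sorted label list; the function name `labels_and_mappings` is a stand-in, not the module's implementation.

```python
def labels_and_mappings(train_labels, test_labels):
    """Build the sorted label list and both ID mappings (illustrative).

    Assumption: IDs are 0-based indices into the sorted list of unique
    labels drawn from the training and testing sets together.
    """
    labels = sorted(set(train_labels) | set(test_labels))
    id2lbl = dict(enumerate(labels))
    lbl2id = {lbl: idx for idx, lbl in id2lbl.items()}
    return labels, id2lbl, lbl2id


labels, id2lbl, lbl2id = labels_and_mappings(
    ["kiwi", "apple"], ["apple", "banana"]
)
# labels is ["apple", "banana", "kiwi"]
# id2lbl is {0: "apple", 1: "banana", 2: "kiwi"}
# lbl2id is {"apple": 0, "banana": 1, "kiwi": 2}
```

Sorting before numbering keeps the ID assignment stable across runs, which matters when a trained model's class indices must be mapped back to names.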

fruit_project.utils.data.collate_fn(batch: transformers.BatchEncoding, processor: transformers.AutoImageProcessor) → Tuple[transformers.BatchEncoding, List]

Collates a batch of data for the dataloader.

Parameters:
  • batch (BatchEncoding) – A batch of data containing images and targets.

  • processor (AutoImageProcessor) – Processor for image preprocessing.

Returns:

Processed batch and list of targets.

Return type:

Tuple[BatchEncoding, List]
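
A PyTorch `DataLoader` calls its `collate_fn` with a single argument, the list of samples, so a two-argument function like this one needs its `processor` bound in advance. The sketch below shows one way to do that with `functools.partial`. It assumes each sample is an `(image, target)` pair, and `fake_processor` is a hypothetical stand-in for a real image processor so the example runs without extra dependencies.

```python
from functools import partial


def collate(batch, processor):
    """Split (image, target) pairs and run the processor on the images."""
    images = [image for image, _ in batch]
    targets = [target for _, target in batch]
    return processor(images), targets


def fake_processor(images):
    # Hypothetical stand-in: a real processor would resize, normalize,
    # and stack the images into a tensor batch.
    return {"pixel_values": list(images)}


# Bind the processor so the result takes only the batch argument.
collate_one_arg = partial(collate, processor=fake_processor)

encoded, targets = collate_one_arg(
    [("img0", {"class_labels": [1]}), ("img1", {"class_labels": [0]})]
)
```

Keeping the targets as a plain list, rather than stacking them, suits detection data where each image can have a different number of boxes.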

fruit_project.utils.data.set_transforms(train_dl: torch.utils.data.DataLoader, test_dl: torch.utils.data.DataLoader, val_dl: torch.utils.data.DataLoader, transforms: albumentations.Compose) → Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]

Sets transformations for the datasets in the dataloaders.

Parameters:
  • train_dl (DataLoader) – Training dataloader.

  • test_dl (DataLoader) – Testing dataloader.

  • val_dl (DataLoader) – Validation dataloader.

  • transforms (Compose) – Transformations to apply.

Returns:

Updated dataloaders with transformations applied.

Return type:

Tuple[DataLoader, DataLoader, DataLoader]