fruit_project.utils.data
Attributes
Functions
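download_dataset() | Downloads the dataset from Kaggle using the Kaggle API.
make_datasets(cfg) | Creates training, testing, and validation datasets.
get_sampler(train_ds, strat) | Creates a WeightedRandomSampler for the training dataset.
make_dataloaders(cfg, train_ds, test_ds, val_ds, generator, processor, transforms) | Creates dataloaders for training, testing, and validation datasets.
get_labels_and_mappings(train_labels, test_labels) | Generates labels and mappings for class IDs and names.
collate_fn(batch, processor) | Collates a batch of data for the dataloader.
set_transforms(train_dl, test_dl, val_dl, transforms) | Sets transformations for the datasets in the dataloaders.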
Module Contents
- fruit_project.utils.data.download_dataset()
Downloads the dataset from Kaggle using the Kaggle API.
- Raises:
RuntimeError – If the required environment variables for Kaggle API are not set.
- Returns:
None
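A minimal sketch of the credential check and download, assuming the function looks for the standard Kaggle API variables KAGGLE_USERNAME and KAGGLE_KEY. The names require_kaggle_env and download_dataset_sketch, and the dataset_slug and out_dir arguments, are hypothetical: the actual dataset slug and target directory are not documented here.

```python
import os
from typing import Mapping

# Assumption: the standard Kaggle API credential variables.
REQUIRED_VARS = ("KAGGLE_USERNAME", "KAGGLE_KEY")


def require_kaggle_env(env: Mapping[str, str] = os.environ) -> None:
    """Raise RuntimeError if any required Kaggle variable is missing or empty."""
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError(
            "Missing Kaggle environment variables: " + ", ".join(missing)
        )


def download_dataset_sketch(dataset_slug: str, out_dir: str) -> None:
    """Hypothetical stand-in for download_dataset(); slug and out_dir are placeholders."""
    require_kaggle_env()  # fail fast, before touching the network
    # Imported lazily so the check above works without the kaggle package installed.
    from kaggle.api.kaggle_api_extended import KaggleApi

    api = KaggleApi()
    api.authenticate()
    api.dataset_download_files(dataset_slug, path=out_dir, unzip=True)
```

Checking the variables up front turns a missing credential into the documented RuntimeError instead of a failure deep inside the Kaggle client.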
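- fruit_project.utils.data.make_datasets(cfg: omegaconf.DictConfig) → Tuple[fruit_project.utils.datasets.det_dataset.DET_DS, fruit_project.utils.datasets.det_dataset.DET_DS, fruit_project.utils.datasets.det_dataset.DET_DS]
Creates training, testing, and validation datasets.
- Parameters:
cfg (DictConfig) – Configuration object used to build the datasets.
- Returns:
The training, testing, and validation datasets.
- Return type:
Tuple[DET_DS, DET_DS, DET_DS]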
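This entry does not say which configuration keys are read, so the following is only a rough sketch. The DataCfg dataclass stands in for the DictConfig (its root and *_split fields are hypothetical), and a caller-supplied build factory stands in for the DET_DS constructor. The sketch returns the splits in the same train, test, val order that make_dataloaders takes them.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Tuple, TypeVar

D = TypeVar("D")


@dataclass
class DataCfg:
    """Hypothetical stand-in for the DictConfig; field names are illustrative."""

    root: str
    train_split: str = "train"
    test_split: str = "test"
    val_split: str = "val"


def make_datasets_sketch(cfg: DataCfg, build: Callable[[Path], D]) -> Tuple[D, D, D]:
    """Build one dataset per split directory, returned as (train, test, val)."""
    root = Path(cfg.root)
    return (
        build(root / cfg.train_split),
        build(root / cfg.test_split),
        build(root / cfg.val_split),
    )
```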
- fruit_project.utils.data.get_sampler(train_ds: fruit_project.utils.datasets.det_dataset.DET_DS, strat: str) → torch.utils.data.WeightedRandomSampler
Creates a WeightedRandomSampler for the training dataset.
- Parameters:
train_ds (DET_DS) – The training dataset.
strat (str) – The strategy for weighting (‘max’ or ‘mean’).
- Returns:
A sampler for the training dataset.
- Return type:
WeightedRandomSampler
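One plausible reading of the 'max' and 'mean' strategies, assumed here rather than taken from the source: each class gets an inverse-frequency weight (1 / number of boxes of that class), and each image's weight is the maximum or the mean of the weights of the classes it contains. image_weights and build_sampler are hypothetical helpers; only WeightedRandomSampler is a real PyTorch API.

```python
from collections import Counter
from typing import Any, List, Sequence


def image_weights(labels_per_image: Sequence[Sequence[int]], strat: str) -> List[float]:
    """Per-image sampling weights from inverse class frequencies (assumed scheme)."""
    if strat not in ("max", "mean"):
        raise ValueError(f"strat must be 'max' or 'mean', got {strat!r}")
    # Count every box, so a class appearing twice in one image counts twice.
    counts = Counter(c for labels in labels_per_image for c in labels)
    class_w = {c: 1.0 / n for c, n in counts.items()}
    # Assumption: images with no boxes get the smallest class weight.
    fallback = min(class_w.values(), default=1.0)
    weights: List[float] = []
    for labels in labels_per_image:
        ws = [class_w[c] for c in labels]
        if not ws:
            weights.append(fallback)
        elif strat == "max":
            weights.append(max(ws))  # driven by the rarest class in the image
        else:
            weights.append(sum(ws) / len(ws))  # averaged over all boxes
    return weights


def build_sampler(
    labels_per_image: Sequence[Sequence[int]], strat: str, generator: Any = None
) -> Any:
    """Wrap the weights in torch's WeightedRandomSampler (torch imported lazily)."""
    from torch.utils.data import WeightedRandomSampler

    weights = image_weights(labels_per_image, strat)
    return WeightedRandomSampler(
        weights, num_samples=len(weights), replacement=True, generator=generator
    )
```

Under this scheme, 'max' favors any image containing a rare class, while 'mean' dilutes that boost when the rare object shares the image with many common ones. WeightedRandomSampler draws with its own generator, so passing a seeded torch.Generator keeps the draw order reproducible.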
- fruit_project.utils.data.make_dataloaders(cfg: omegaconf.DictConfig, train_ds: fruit_project.utils.datasets.det_dataset.DET_DS, test_ds: fruit_project.utils.datasets.det_dataset.DET_DS, val_ds: fruit_project.utils.datasets.det_dataset.DET_DS, generator: torch.Generator, processor: transformers.AutoImageProcessor, transforms: albumentations.Compose) → Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]
Creates dataloaders for training, testing, and validation datasets.
- Parameters:
cfg (DictConfig) – Configuration object containing dataloader parameters.
train_ds (DET_DS) – The training dataset.
test_ds (DET_DS) – The testing dataset.
val_ds (DET_DS) – The validation dataset.
generator (torch.Generator) – A PyTorch generator for reproducibility.
processor (AutoImageProcessor) – Processor for image preprocessing.
transforms (Compose) – Transformations to apply to the datasets.
- Returns:
The training, testing, and validation dataloaders.
- Return type:
Tuple[DataLoader, DataLoader, DataLoader]
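DataLoader calls its collate_fn with a single argument (the list of samples), so the processor has to be bound in advance, for example with functools.partial. The sketch below assumes that pattern. bind_collate, make_dataloaders_sketch, and the explicit batch_size, num_workers, and sampler arguments are hypothetical stand-ins for values the real function presumably reads from cfg; the transforms argument is left out.

```python
from functools import partial
from typing import Any, Callable, Optional, Tuple


def bind_collate(collate: Callable[..., Any], processor: Any) -> Callable[[Any], Any]:
    """Pre-bind the processor so DataLoader can call collate_fn(batch)."""
    return partial(collate, processor=processor)


def make_dataloaders_sketch(
    train_ds: Any,
    test_ds: Any,
    val_ds: Any,
    collate: Callable[..., Any],
    processor: Any,
    generator: Any,
    batch_size: int = 8,
    num_workers: int = 0,
    sampler: Optional[Any] = None,
) -> Tuple[Any, Any, Any]:
    from torch.utils.data import DataLoader  # imported lazily

    common = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=bind_collate(collate, processor),
    )
    # DataLoader rejects a sampler combined with shuffle=True,
    # so the training split only shuffles when no sampler is given.
    train_dl = DataLoader(
        train_ds, sampler=sampler, shuffle=sampler is None, generator=generator, **common
    )
    # Evaluation splits keep a fixed order.
    test_dl = DataLoader(test_ds, shuffle=False, **common)
    val_dl = DataLoader(val_ds, shuffle=False, **common)
    return train_dl, test_dl, val_dl
```

The generator passed to DataLoader seeds shuffling and the workers' base seed; a sampler such as WeightedRandomSampler uses its own generator argument instead.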
- fruit_project.utils.data.get_labels_and_mappings(train_labels: List, test_labels: List) → Tuple[List, Dict, Dict]
Generates labels and mappings for class IDs and names.
- Parameters:
train_labels (List) – List of labels from the training dataset.
test_labels (List) – List of labels from the testing dataset.
- Returns:
- A tuple containing:
labels (List): Sorted list of unique labels.
id2lbl (Dict): Mapping from class IDs to labels.
lbl2id (Dict): Mapping from labels to class IDs.
- Return type:
Tuple[List, Dict, Dict]
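A minimal sketch, assuming train_labels and test_labels are flat lists of class names and that IDs follow sorted order. The integer-keyed id2lbl and name-keyed lbl2id have the same shape as the id2label and label2id fields of a Hugging Face model config.

```python
from typing import Dict, List, Sequence, Tuple


def get_labels_and_mappings_sketch(
    train_labels: Sequence[str], test_labels: Sequence[str]
) -> Tuple[List[str], Dict[int, str], Dict[str, int]]:
    # Union of both splits, so a class seen only in test still gets an ID.
    labels = sorted(set(train_labels) | set(test_labels))
    id2lbl = dict(enumerate(labels))  # 0 -> first name in sorted order
    lbl2id = {lbl: idx for idx, lbl in id2lbl.items()}
    return labels, id2lbl, lbl2id
```

Sorting matters: iteration order of a set of strings can change between Python runs because of hash randomization, whereas sorted labels give the same IDs every time.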
- fruit_project.utils.data.collate_fn(batch: transformers.BatchEncoding, processor: transformers.AutoImageProcessor) → Tuple[transformers.BatchEncoding, List]
Collates a batch of data for the dataloader.
- Parameters:
batch (BatchEncoding) – A batch of data containing images and targets.
processor (AutoImageProcessor) – Processor for image preprocessing.
- Returns:
Processed batch and list of targets.
- Return type:
Tuple[BatchEncoding, List]
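A minimal sketch, assuming each sample is an (image, target) pair and that only the images go through the processor while the targets are returned unchanged, which fits the documented return of a processed batch plus a list of targets. Hugging Face image processors accept images= and return_tensors= keyword arguments; collate_fn_sketch itself is a hypothetical name.

```python
from typing import Any, Callable, List, Sequence, Tuple


def collate_fn_sketch(
    batch: Sequence[Tuple[Any, Any]], processor: Callable[..., Any]
) -> Tuple[Any, List[Any]]:
    # Split the list of (image, target) pairs into two parallel lists.
    images = [image for image, _ in batch]
    targets = [target for _, target in batch]
    # A Hugging Face image processor would resize, normalize, and batch the images here.
    encoding = processor(images=images, return_tensors="pt")
    return encoding, targets
```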
- fruit_project.utils.data.set_transforms(train_dl: torch.utils.data.DataLoader, test_dl: torch.utils.data.DataLoader, val_dl: torch.utils.data.DataLoader, transforms: albumentations.Compose) → Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]
Sets transformations for the datasets in the dataloaders.
- Parameters:
train_dl (DataLoader) – Training dataloader.
test_dl (DataLoader) – Testing dataloader.
val_dl (DataLoader) – Validation dataloader.
transforms (Compose) – Transformations to apply.
- Returns:
Updated dataloaders with transformations applied.
- Return type:
Tuple[DataLoader, DataLoader, DataLoader]
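A minimal sketch, assuming each dataset exposes a writable transforms attribute and that the same pipeline is applied to all three splits; the real function may treat the splits differently. set_transforms_sketch is a hypothetical name.

```python
from typing import Any, Tuple


def set_transforms_sketch(
    train_dl: Any, test_dl: Any, val_dl: Any, transforms: Any
) -> Tuple[Any, Any, Any]:
    for dl in (train_dl, test_dl, val_dl):
        # DataLoader keeps a reference to its dataset, so mutating it here
        # affects every batch loaded afterwards in this process.
        dl.dataset.transforms = transforms
    return train_dl, test_dl, val_dl
```

With num_workers > 0 and persistent_workers=True, worker processes keep the dataset copy they received when they started, so a later change does not reach them; non-persistent workers are recreated on each new iteration and pick it up.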