fruit_project.utils.trainer

Classes

Module Contents

class fruit_project.utils.trainer.Trainer(model: torch.nn.Module, processor: transformers.AutoImageProcessor, device: torch.device, cfg: omegaconf.DictConfig, name: str, run: wandb.sdk.wandb_run.Run, train_dl: torch.utils.data.DataLoader, test_dl: torch.utils.data.DataLoader, val_dl: torch.utils.data.DataLoader)
model
device
scaler
cfg
optimizer
processor
name
early_stopping
scheduler
run
train_dl
test_dl
val_dl
start_epoch = 0
accum_steps
get_scheduler() → torch.optim.lr_scheduler.SequentialLR

Creates a learning rate scheduler with a warmup phase.

Returns:

The learning rate scheduler.

Return type:

SequentialLR
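A minimal sketch of a warmup scheduler built with `SequentialLR`, assuming a linear warmup followed by cosine annealing. The helper name `build_warmup_scheduler`, the warmup length, `start_factor`, and the choice of cosine decay are all assumptions; the values the trainer actually reads from `cfg` are not documented here.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR


def build_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_epochs: int,
    total_epochs: int,
    start_factor: float = 0.1,
) -> SequentialLR:
    """Hypothetical helper: linear warmup, then cosine decay (assumed schedule)."""
    # Phase 1: ramp the LR from start_factor * base_lr up to base_lr.
    warmup = LinearLR(
        optimizer, start_factor=start_factor, end_factor=1.0, total_iters=warmup_epochs
    )
    # Phase 2: cosine-anneal over the remaining epochs.
    main = CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)
    # Hand over from the warmup scheduler to the main one after warmup_epochs steps.
    return SequentialLR(optimizer, schedulers=[warmup, main], milestones=[warmup_epochs])
```

Calling `scheduler.step()` once per epoch (after `optimizer.step()`) walks through the warmup and then the decay.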

get_optimizer() → torch.optim.AdamW

Creates an AdamW optimizer that uses one learning rate for the backbone parameters and a separate learning rate for all other parameters.

Returns:

The optimizer.

Return type:

AdamW
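A sketch of per-group learning rates with AdamW parameter groups. It assumes backbone parameters can be recognized by the substring `"backbone"` in their names; the helper name and the `backbone_lr`/`lr` parameters are hypothetical, not the trainer's config keys.

```python
import torch
from torch import nn


def build_param_groups_optimizer(
    model: nn.Module, backbone_lr: float, lr: float
) -> torch.optim.AdamW:
    """Hypothetical helper: AdamW with a separate learning rate for the backbone."""
    backbone_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are left out of the optimizer
        # Assumption: backbone parameters carry "backbone" in their qualified name.
        (backbone_params if "backbone" in name else other_params).append(param)
    groups = [
        {"params": backbone_params, "lr": backbone_lr},
        {"params": other_params, "lr": lr},
    ]
    # Drop empty groups so models without a named backbone still work.
    return torch.optim.AdamW([g for g in groups if g["params"]])
```

A lower backbone learning rate is a common choice when fine-tuning a pretrained detector, since it limits how far the pretrained features drift.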

move_labels_to_device(batch: transformers.BatchEncoding) → transformers.BatchEncoding

Moves the label tensors within a batch to the trainer's device.

Parameters:

batch (BatchEncoding) – The batch containing labels.

Returns:

The batch with labels moved to the device.

Return type:

BatchEncoding
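A sketch of the idea, assuming `batch["labels"]` is a list of per-image dicts of tensors (the layout DETR-style image processors in transformers produce). A top-level `.to(device)` typically does not reach tensors nested inside that list, which is why a dedicated step is useful. The free function and its explicit `device` argument are hypothetical; it works on any mutable mapping, including `BatchEncoding`.

```python
from typing import Any, Dict

import torch


def move_labels_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """Hypothetical sketch: move every tensor inside batch["labels"] to device."""
    batch["labels"] = [
        # Tensors are moved; non-tensor entries (e.g. ids stored as ints) pass through.
        {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in target.items()}
        for target in batch["labels"]
    ]
    return batch
```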

nested_to_cpu(obj: Any) → Any

Recursively moves tensors in a nested structure (dict, list, tuple) to CPU.

Parameters:

obj – The object containing tensors to move.

Returns:

The object with all tensors moved to CPU.
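A minimal recursive sketch of this behavior. It handles plain dicts, lists and tuples; namedtuples or custom containers would need extra cases, and this is not necessarily the project's exact implementation.

```python
from typing import Any

import torch


def nested_to_cpu(obj: Any) -> Any:
    """Recursively move tensors in dicts, lists and tuples to the CPU."""
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: nested_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Rebuild the same container type (plain list or tuple) around the moved items.
        return type(obj)(nested_to_cpu(v) for v in obj)
    return obj  # non-tensor leaves (ints, strings, None) are returned unchanged
```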

format_targets_for_map(y: List) → List

Formats target annotations for MeanAveragePrecision metric calculation.

Parameters:

y (List) – A list of target dictionaries.

Returns:

A list of formatted target dictionaries for the metric.

Return type:

List
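A sketch under explicit assumptions: each input dict holds `"class_labels"`, `"boxes"` as normalized `(cx, cy, w, h)`, and `"orig_size"` as `(height, width)`, as DETR-style processors in transformers emit; the output uses the `{"boxes": (N, 4) xyxy, "labels": (N,)}` layout that torchmetrics' `MeanAveragePrecision` expects with its default box format. Whether the trainer scales to the original or the resized image size is not documented here.

```python
from typing import Dict, List

import torch


def format_targets_for_map(y: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
    """Hypothetical sketch: normalized cxcywh targets -> absolute xyxy metric targets."""
    formatted = []
    for target in y:
        h, w = target["orig_size"].tolist()  # assumed (height, width)
        cx, cy, bw, bh = target["boxes"].unbind(-1)
        # Convert centre/size in [0, 1] to absolute corner coordinates in pixels.
        boxes = torch.stack(
            [(cx - bw / 2) * w, (cy - bh / 2) * h, (cx + bw / 2) * w, (cy + bh / 2) * h],
            dim=-1,
        )
        formatted.append({"boxes": boxes, "labels": target["class_labels"].long()})
    return formatted
```

Images without annotations yield an empty `(0, 4)` box tensor, which the metric accepts.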

train(current_epoch: int) → float

Performs one epoch of training.

Parameters:

current_epoch (int) – The current epoch number.

Returns:

The average training loss for the epoch.

Return type:

float
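Because the trainer holds a `scaler` and an `accum_steps` attribute, one epoch plausibly combines mixed precision with gradient accumulation. The sketch below shows that pattern under assumptions: `train_one_epoch` is a hypothetical free function, the model is assumed to return an object with a `.loss` attribute (as transformers models do when given labels), and `scaler` is a `GradScaler` (a disabled one gives plain full-precision training).

```python
from typing import Any, Iterable, Mapping

import torch
from torch import nn


def train_one_epoch(
    model: nn.Module,
    dataloader: Iterable[Mapping[str, Any]],
    optimizer: torch.optim.Optimizer,
    scaler: Any,  # a GradScaler; pass a disabled one for full precision
    device: torch.device,
    accum_steps: int = 1,
    use_amp: bool = False,
) -> float:
    """Hypothetical sketch: one epoch with AMP and gradient accumulation."""
    model.train()
    optimizer.zero_grad(set_to_none=True)
    total_loss, num_batches = 0.0, 0
    for step, batch in enumerate(dataloader):
        # Top-level tensors only; nested label dicts would need move_labels_to_device.
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = model(**batch).loss
        # Divide so the accumulated gradient matches one large batch.
        scaler.scale(loss / accum_steps).backward()
        total_loss += loss.item()
        num_batches += 1
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
    # Flush leftover gradients when the epoch length is not a multiple of accum_steps.
    if num_batches % accum_steps != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
    return total_loss / max(num_batches, 1)
```

The returned value is the mean unscaled loss per batch, matching the "average training loss" described above.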

eval(test_dl: torch.utils.data.DataLoader, current_epoch: int, calc_cm: bool = False) → Tuple[float, float, float, torch.Tensor, fruit_project.utils.metrics.ConfusionMatrix | None]

Evaluates the model on a given dataloader.

Parameters:
  • test_dl (DataLoader) – The dataloader for evaluation.

  • current_epoch (int) – The current epoch number.

  • calc_cm (bool, optional) – Whether to calculate and return a confusion matrix. Defaults to False.

Returns:

A tuple containing:
  • loss (float): The average evaluation loss.

  • test_map (float): The mAP averaged over IoU thresholds 0.50 to 0.95 (mAP@.5-.95).

  • test_map50 (float): The mAP@.50.

  • test_map_50_per_class (torch.Tensor): The mAP@.50 for each class.

  • cm (ConfusionMatrix | None): The confusion matrix if calc_cm is True, else None.

Return type:

tuple
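To make the difference between the two summary metrics concrete, here is a pure-Python illustration of the COCO-style convention: mAP@.5-.95 averages AP over the ten IoU thresholds 0.50, 0.55, ..., 0.95, while mAP@.50 uses only the first. This illustrates the metric definitions only; it is not the trainer's metric code, and the boxes are made up.

```python
def box_iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


# The ten IoU thresholds averaged by mAP@.5-.95 (rounded to avoid float drift).
IOU_THRESHOLDS = [round(0.50 + 0.05 * i, 2) for i in range(10)]

gt = (0, 0, 10, 10)
pred = (0, 0, 10, 6)
iou = box_iou(gt, pred)  # intersection 60 / union 100
matched = [t for t in IOU_THRESHOLDS if iou >= t]
print(iou, matched)  # 0.6 [0.5, 0.55, 0.6]
```

A prediction with IoU 0.6 counts as a true positive at only 3 of the 10 thresholds, so it helps mAP@.50 fully but mAP@.5-.95 only partially; this is why the latter is usually the lower number.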

fit() → None

Runs the main training loop for the specified number of epochs.
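The trainer keeps an `early_stopping` attribute, so the loop presumably ends before the final epoch when the monitored metric stops improving. Its interface is not documented here; the class below is a hypothetical patience-based stopper on a metric to maximize (such as mAP), shown only to illustrate the pattern.

```python
class EarlyStopping:
    """Hypothetical patience-based early stopping on a metric to maximize."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.counter = 0  # epochs since the last improvement

    def step(self, metric: float) -> bool:
        """Record one epoch's metric; return True when training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopping(patience=2)
history = [0.30, 0.35, 0.34, 0.33, 0.40]  # made-up per-epoch mAP values
stopped_at = None
for epoch, val_map in enumerate(history):
    if stopper.step(val_map):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 3 0.35
```

With a patience of 2, two epochs without improvement after the 0.35 peak stop training at epoch 3, before the later 0.40 is ever seen; the patience value trades off compute against that risk.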

_save_checkpoint(epoch: int) → str

Saves a checkpoint of the model, optimizer, scheduler, and scaler states.

Parameters:

epoch (int) – The current epoch number.

Returns:

The path to the saved checkpoint file.

Return type:

str
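A sketch of bundling the four states into one file with `torch.save`. The free function, its explicit `path` argument, and the dictionary key names are assumptions; the real method takes only `epoch` and chooses the file path itself.

```python
import os
import tempfile

import torch


def save_checkpoint(path: str, epoch: int, model, optimizer, scheduler, scaler) -> str:
    """Hypothetical sketch: store all training state needed to resume in one file."""
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "scaler": scaler.state_dict(),  # empty dict when the scaler is disabled
        },
        path,
    )
    return path
```

Saving the optimizer, scheduler and scaler alongside the weights is what lets a resumed run continue with the same learning rate and AMP scale instead of restarting them.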

_load_checkpoint(path: str) → None

Loads a checkpoint and restores the state of the model, optimizer, scheduler, and scaler.

Parameters:

path (str) – The path to the checkpoint file.
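The matching restore step, sketched under the same assumed key names. Returning the next epoch is a hypothetical convenience; the real method returns `None` and, given the `start_epoch` attribute, presumably updates that instead.

```python
import torch


def load_checkpoint(path: str, model, optimizer, scheduler, scaler) -> int:
    """Hypothetical sketch: restore training state and return the epoch to resume from."""
    # Load onto the CPU first; load_state_dict copies tensors to each module's device.
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    scaler.load_state_dict(ckpt["scaler"])
    # Resume with the epoch after the one that was saved.
    return ckpt["epoch"] + 1
```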

get_epoch_log_data(epoch: int, train_loss: float, test_map: float, test_map50: float, test_loss: float, test_map_per_class: torch.Tensor) → Dict[str, Any]

Constructs a dictionary of metrics for logging at the end of an epoch.

Parameters:
  • epoch (int) – The current epoch number.

  • train_loss (float) – The training loss.

  • test_map (float) – The test mAP averaged over IoU thresholds 0.50 to 0.95 (mAP@.5-.95).

  • test_map50 (float) – The test mAP@.50.

  • test_loss (float) – The test loss.

  • test_map_per_class (torch.Tensor) – The test mAP@.50 for each class.

Returns:

A dictionary of metrics for logging.

Return type:

dict
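A sketch of how such a dictionary might be assembled. The key names and the extra `class_names` parameter are hypothetical (the real method has no such parameter, and where it gets class names is not documented). Slash-separated keys are a common choice because W&B groups charts by the prefix before the slash.

```python
from typing import Any, Dict, Sequence

import torch


def get_epoch_log_data(
    epoch: int,
    train_loss: float,
    test_map: float,
    test_map50: float,
    test_loss: float,
    test_map_per_class: torch.Tensor,
    class_names: Sequence[str],  # hypothetical extra argument
) -> Dict[str, Any]:
    """Hypothetical sketch: flatten epoch metrics into one loggable dict."""
    log = {
        "epoch": epoch,
        "train/loss": train_loss,
        "test/loss": test_loss,
        "test/map": test_map,
        "test/map50": test_map50,
    }
    # One scalar per class so each class gets its own chart.
    for name, value in zip(class_names, test_map_per_class.tolist()):
        log[f"test/map50/{name}"] = value
    return log
```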

get_val_log_data(epoch: int, best_test_map: float) → Dict[str, Any]

Performs final validation, logs metrics, and returns the log data.

Parameters:
  • epoch (int) – The final epoch number.

  • best_test_map (float) – The best test mAP@.50 achieved during training.

Returns:

A dictionary of validation metrics for logging.

Return type:

dict
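A sketch of how the final validation pass might consume `eval`'s five-element return value. To stay self-contained it takes a hypothetical `evaluate` callable with `eval`'s signature instead of a `Trainer`; requesting the confusion matrix here and the key names are assumptions.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# (loss, map, map50, map50_per_class, confusion_matrix_or_None), mirroring eval()
EvalResult = Tuple[float, float, float, Any, Optional[Any]]


def get_val_log_data(
    evaluate: Callable[..., EvalResult], val_dl: Any, epoch: int, best_test_map: float
) -> Dict[str, Any]:
    """Hypothetical sketch: run validation once and package the results for logging."""
    val_loss, val_map, val_map50, _per_class, cm = evaluate(val_dl, epoch, calc_cm=True)
    return {
        "val/loss": val_loss,
        "val/map": val_map,
        "val/map50": val_map50,
        "best_test_map": best_test_map,
        "val/confusion_matrix": cm,  # None if the evaluator skipped it
    }
```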