fruit_project.utils.trainer

Module Contents

class fruit_project.utils.trainer.Trainer(model: torch.nn.Module, processor: transformers.AutoImageProcessor, device: torch.device, cfg: omegaconf.DictConfig, name: str, run: wandb.sdk.wandb_run.Run, train_dl: torch.utils.data.DataLoader, test_dl: torch.utils.data.DataLoader, val_dl: torch.utils.data.DataLoader)[source]
model: torch.nn.Module[source]
device: torch.device[source]
scaler[source]
cfg: omegaconf.DictConfig[source]
optimizer: torch.optim.Optimizer[source]
processor: transformers.AutoImageProcessor[source]
name: str[source]
early_stopping: fruit_project.utils.early_stop.EarlyStopping[source]
run: wandb.sdk.wandb_run.Run[source]
train_dl: torch.utils.data.DataLoader[source]
test_dl: torch.utils.data.DataLoader[source]
val_dl: torch.utils.data.DataLoader[source]
start_epoch: int = 0[source]
map_evaluator[source]
accum_steps: int[source]
scheduler: torch.optim.lr_scheduler.SequentialLR[source]
ema: torch_ema.ExponentialMovingAverage = None[source]
get_scheduler() torch.optim.lr_scheduler.SequentialLR[source]

Creates a learning rate scheduler with a warmup phase.

Returns:

The learning rate scheduler.

Return type:

SequentialLR
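
A warmup-then-decay schedule of this shape can be assembled from torch's built-in schedulers. A minimal sketch, assuming a linear warmup followed by cosine decay (the real phase lengths and factors come from cfg):

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

    model = torch.nn.Linear(16, 4)  # stand-in model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    warmup_steps = 500    # assumed values; the Trainer reads these from cfg
    total_steps = 10_000

    # Phase 1: ramp the LR from 10% to 100% of its base value.
    warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
    # Phase 2: cosine decay over the remaining steps.
    decay = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)
    # SequentialLR hands off from warmup to decay at the milestone.
    scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])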

get_optimizer() torch.optim.AdamW[source]

Creates an AdamW optimizer with a differential learning rate for the backbone and the rest of the model (head), following standard fine-tuning practices.

Returns:

The configured optimizer.

Return type:

AdamW
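
Differential learning rates are normally expressed as AdamW parameter groups. A minimal sketch, assuming backbone parameters are identified by name and using placeholder LR values:

    import torch
    import torch.nn as nn

    class Detector(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(16, 16)  # stands in for a pretrained backbone
            self.head = nn.Linear(16, 4)       # stands in for the detection head

    model = Detector()
    backbone_params = [p for n, p in model.named_parameters() if n.startswith("backbone")]
    head_params = [p for n, p in model.named_parameters() if not n.startswith("backbone")]

    optimizer = torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": 1e-5},  # small LR: backbone is already trained
            {"params": head_params, "lr": 1e-4},      # larger LR: head starts from scratch
        ],
        weight_decay=1e-4,
    )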

move_labels_to_device(batch: transformers.BatchEncoding) transformers.BatchEncoding[source]

Moves label tensors within a batch to the specified device.

Parameters:

batch (BatchEncoding) – The batch containing labels.

Returns:

The batch with labels moved to the device.

Return type:

BatchEncoding
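
With Hugging Face detection models, batch["labels"] is typically a list of per-image dicts of tensors, so the move has to walk that structure. A minimal sketch under that assumption:

    import torch

    def move_labels_to_device(batch: dict, device: torch.device) -> dict:
        # One dict of tensors per image (HF detection convention).
        batch["labels"] = [
            {k: v.to(device) for k, v in target.items()} for target in batch["labels"]
        ]
        return batch

    batch = {"labels": [{"class_labels": torch.tensor([1]), "boxes": torch.rand(1, 4)}]}
    batch = move_labels_to_device(batch, torch.device("cpu"))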

nested_to_cpu(obj: Any) Any[source]

Recursively moves tensors in a nested structure (dict, list, tuple) to CPU.

Parameters:

obj – The object containing tensors to move.

Returns:

The object with all tensors moved to CPU.
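
Such a recursive move is usually a small type dispatch. A minimal sketch:

    from typing import Any

    import torch

    def nested_to_cpu(obj: Any) -> Any:
        if isinstance(obj, torch.Tensor):
            return obj.cpu()
        if isinstance(obj, dict):
            return {k: nested_to_cpu(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(nested_to_cpu(v) for v in obj)
        return obj  # non-tensor leaves pass through unchanged

    moved = nested_to_cpu({"scores": torch.rand(3), "extras": [torch.rand(2), "tag"]})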

format_targets_for_cm(targets: List[Dict]) List[Dict][source]

Formats raw targets for torchmetrics and the confusion matrix. This helper serves only the confusion matrix, since MAPEvaluator handles its own formatting.
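
torchmetrics' detection metrics expect each target as a dict with "boxes" and integer "labels" tensors. A minimal sketch of such a conversion, assuming the raw targets carry HF-style "class_labels" (the real helper may also convert box formats):

    from typing import Dict, List

    import torch

    def format_targets_for_cm(targets: List[Dict]) -> List[Dict]:
        # Rename HF-style keys to the "boxes"/"labels" pair torchmetrics expects.
        return [{"boxes": t["boxes"], "labels": t["class_labels"]} for t in targets]

    formatted = format_targets_for_cm(
        [{"boxes": torch.rand(2, 4), "class_labels": torch.tensor([0, 3])}]
    )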

train(current_epoch: int) Dict[str, float][source]

Performs one epoch of training.

Parameters:

current_epoch (int) – The current epoch number.

Returns:

A dictionary of average training losses for the epoch.

Return type:

Dict[str, float]
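
Given the scaler and accum_steps attributes above, the epoch loop presumably combines mixed precision with gradient accumulation. A minimal sketch of that pattern (generic names, not the project's exact loop):

    import torch

    def train_one_epoch(model, loader, optimizer, scaler, device, accum_steps=4):
        model.train()
        optimizer.zero_grad()
        for step, batch in enumerate(loader):
            with torch.autocast(device_type=device.type):
                # HF-style models return an output object with a .loss attribute.
                loss = model(**batch).loss / accum_steps  # scale for accumulation
            scaler.scale(loss).backward()
            if (step + 1) % accum_steps == 0:
                scaler.step(optimizer)  # unscales gradients, then steps
                scaler.update()
                optimizer.zero_grad()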

eval(val_dl: torch.utils.data.DataLoader, current_epoch: int, calc_cm: bool = False) Tuple[dict[str, float], dict[str, Any], fruit_project.utils.metrics.ConfusionMatrix | None][source]
_run_eval(val_dl: torch.utils.data.DataLoader, current_epoch: int, calc_cm: bool = False) Tuple[dict[str, float], dict[str, Any], fruit_project.utils.metrics.ConfusionMatrix | None][source]

Evaluates the model on a given dataloader.

Parameters:
  • val_dl (DataLoader) – The dataloader for evaluation.

  • current_epoch (int) – The current epoch number.

  • calc_cm (bool, optional) – Whether to calculate and return a confusion matrix. Defaults to False.

Returns:

A tuple containing:
  • losses (dict[str, float]): The evaluation losses.

  • metrics (dict[str, Any]): The evaluation metrics, including mAP@.50:.95, mAP@.50, and the mAP@.50 for each class.

  • cm (ConfusionMatrix | None): The confusion matrix if calc_cm is True, else None.

Return type:

tuple
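
Based on the signature, a call site might unpack the result like this (hypothetical usage; trainer is an existing Trainer instance):

    losses, metrics, cm = trainer.eval(trainer.val_dl, current_epoch=0, calc_cm=True)
    if cm is not None:
        pass  # e.g. log the confusion matrix to the wandb run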

fit() None[source]

Runs the main training loop for the specified number of epochs.
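
Hypothetical usage, mirroring the constructor signature above (all arguments are assumed to be built elsewhere):

    trainer = Trainer(
        model, processor, device, cfg,
        name="run-1", run=run,
        train_dl=train_dl, test_dl=test_dl, val_dl=val_dl,
    )
    trainer.fit()  # trains for the configured epochs, evaluating and checkpointing as it goes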

_build_save_dict(epoch)[source]
_save_checkpoint(epoch: int) str[source]

Saves a checkpoint of the model, optimizer, scheduler, and scaler states.

Parameters:

epoch (int) – The current epoch number.

Returns:

The path to the saved checkpoint file.

Return type:

str
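
The states listed suggest a checkpoint dict of the usual shape. A minimal sketch (the file name and key names are assumptions):

    import torch

    def save_checkpoint(epoch, model, optimizer, scheduler, scaler, path="checkpoint.pt"):
        # Bundle the state dicts named in the docstring into a single file.
        torch.save(
            {
                "epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "scaler": scaler.state_dict(),
            },
            path,
        )
        return path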

_load_checkpoint(path: str, model_only: bool = True) None[source]

Loads a checkpoint and restores the state of the model, optimizer, scheduler, and scaler.

Parameters:

  • path (str) – The path to the checkpoint file.

  • model_only (bool, optional) – If True, restores only the model weights; otherwise the optimizer, scheduler, and scaler states are restored as well. Defaults to True.
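
The model_only flag suggests restoring just the weights by default and the full training state on demand. A minimal sketch under that assumption (key names match the save sketch above):

    import torch

    def load_checkpoint(path, model, optimizer=None, scheduler=None, scaler=None,
                        model_only=True):
        ckpt = torch.load(path, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        if not model_only:
            # Restore the remaining training state for a full resume.
            optimizer.load_state_dict(ckpt["optimizer"])
            scheduler.load_state_dict(ckpt["scheduler"])
            scaler.load_state_dict(ckpt["scaler"])
        return ckpt.get("epoch", 0)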