fruit_project.utils.trainer
Classes
Module Contents
- class fruit_project.utils.trainer.Trainer(model: torch.nn.Module, processor: transformers.AutoImageProcessor, device: torch.device, cfg: omegaconf.DictConfig, name: str, run: wandb.sdk.wandb_run.Run, train_dl: torch.utils.data.DataLoader, test_dl: torch.utils.data.DataLoader, val_dl: torch.utils.data.DataLoader)
- get_scheduler() -> torch.optim.lr_scheduler.SequentialLR
Creates a learning rate scheduler with a warmup phase.
- Returns:
The learning rate scheduler.
- Return type:
SequentialLR
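A warmup schedule of this kind is commonly built by chaining a `LinearLR` ramp and a main schedule with `SequentialLR`. The sketch below is illustrative only: the helper name `build_warmup_scheduler`, the cosine decay after warmup, and the `start_factor` value are assumptions, not taken from `Trainer`.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR


def build_warmup_scheduler(optimizer, warmup_epochs, total_epochs, start_factor=0.1):
    """Hypothetical helper: linear warmup followed by cosine decay."""
    # Ramp the learning rate from start_factor * lr up to lr.
    warmup = LinearLR(
        optimizer, start_factor=start_factor, end_factor=1.0, total_iters=warmup_epochs
    )
    # Assumed main schedule; the real Trainer may use a different one.
    main = CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)
    # Switch from the warmup schedule to the main schedule at the milestone.
    return SequentialLR(optimizer, schedulers=[warmup, main], milestones=[warmup_epochs])


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = build_warmup_scheduler(optimizer, warmup_epochs=3, total_epochs=10)

lrs = []
for _ in range(10):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # no gradients here; this only keeps the call order correct
    scheduler.step()
```

The recorded `lrs` rise during the first three epochs and decay afterwards.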
- get_optimizer() -> torch.optim.AdamW
Creates an AdamW optimizer with different learning rates for backbone and other parameters.
- Returns:
The optimizer.
- Return type:
AdamW
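Different learning rates per parameter set are usually expressed through AdamW parameter groups. The following is a minimal sketch under assumptions: the `TinyDetector` model, the rule of matching `"backbone"` in parameter names, and all hyperparameter values are hypothetical and may differ from what `Trainer` does.

```python
import torch
from torch import nn


class TinyDetector(nn.Module):
    """Hypothetical stand-in model with a 'backbone' and a 'head'."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)


def build_optimizer(model, lr=1e-4, backbone_lr=1e-5, weight_decay=1e-4):
    """Hypothetical helper: AdamW with a lower learning rate for the backbone."""
    backbone_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # skip frozen parameters
        # Assumed naming convention: backbone parameters contain "backbone".
        if "backbone" in name:
            backbone_params.append(param)
        else:
            other_params.append(param)
    return torch.optim.AdamW(
        [
            {"params": other_params, "lr": lr},
            {"params": backbone_params, "lr": backbone_lr},
        ],
        weight_decay=weight_decay,
    )


optimizer = build_optimizer(TinyDetector())
group_lrs = [group["lr"] for group in optimizer.param_groups]
```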
- move_labels_to_device(batch: transformers.BatchEncoding) -> transformers.BatchEncoding
Moves label tensors within a batch to the specified device.
- Parameters:
batch (BatchEncoding) – The batch containing labels.
- Returns:
The batch with labels moved to the device.
- Return type:
BatchEncoding
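As a rough illustration of moving only the labels, the sketch below assumes that labels are stored under a `"labels"` key as a list of per-image dictionaries of tensors, as is common for Hugging Face object-detection batches. The key names and the use of a plain `dict` instead of `BatchEncoding` are assumptions.

```python
import torch


def move_labels_to_device(batch, device):
    """Sketch: move every tensor inside batch["labels"] to the given device."""
    batch["labels"] = [
        {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in target.items()
        }
        for target in batch["labels"]
    ]
    return batch


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {
    "pixel_values": torch.zeros(1, 3, 4, 4),
    # Assumed label layout: one dict per image.
    "labels": [
        {
            "class_labels": torch.tensor([1]),
            "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]]),
        }
    ],
}
batch = move_labels_to_device(batch, device)
```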
- nested_to_cpu(obj: Any) -> Any
Recursively moves tensors in a nested structure (dict, list, tuple) to CPU.
- Parameters:
obj (Any) – The object containing tensors to move.
- Returns:
The object with all tensors moved to CPU.
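The recursive traversal described above can be sketched as follows. This is one possible implementation, not necessarily the one in `Trainer`; in particular, rebuilding containers with `type(obj)(...)` is an assumption that works for plain lists and tuples but not for named tuples.

```python
import torch


def nested_to_cpu(obj):
    """Sketch: recursively move tensors in dicts, lists, and tuples to CPU."""
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {key: nested_to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Rebuild the same container type around the converted items.
        return type(obj)(nested_to_cpu(item) for item in obj)
    # Anything that is not a tensor or a supported container is returned unchanged.
    return obj


nested = {"a": [torch.ones(2), (torch.zeros(1), "text")], "b": 3}
out = nested_to_cpu(nested)
```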
- format_targets_for_map(y: List) -> List
Formats target annotations for MeanAveragePrecision metric calculation.
- Parameters:
y (List) – A list of target dictionaries.
- Returns:
A list of formatted target dictionaries for the metric.
- Return type:
List
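The torchmetrics `MeanAveragePrecision` metric expects each target as a dictionary with `boxes` (by default in absolute `xyxy` format) and `labels`. The sketch below shows one plausible conversion. It assumes the input targets use the Hugging Face detection layout (normalized `cxcywh` boxes under `"boxes"`, classes under `"class_labels"`, and `(height, width)` under `"orig_size"`); the actual input format used by `Trainer` is not stated in this reference.

```python
import torch


def cxcywh_to_xyxy(boxes):
    """Convert (center_x, center_y, width, height) boxes to (x1, y1, x2, y2)."""
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)


def format_targets_for_map(targets):
    """Sketch: build the list of {"boxes", "labels"} dicts the metric expects."""
    formatted = []
    for target in targets:
        # Assumed key: "orig_size" holds (height, width) of the image.
        height, width = target["orig_size"].tolist()
        scale = torch.tensor([width, height, width, height], dtype=torch.float32)
        # Scale normalized coordinates to absolute pixel coordinates.
        boxes = cxcywh_to_xyxy(target["boxes"].float()) * scale
        formatted.append({"boxes": boxes, "labels": target["class_labels"]})
    return formatted


targets = [
    {
        "boxes": torch.tensor([[0.5, 0.5, 0.5, 0.5]]),
        "class_labels": torch.tensor([2]),
        "orig_size": torch.tensor([100, 200]),
    }
]
formatted = format_targets_for_map(targets)
```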
- train(current_epoch: int) -> float
Performs one epoch of training.
- Parameters:
current_epoch (int) – The current epoch number.
- Returns:
The average training loss for the epoch.
- Return type:
float
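For orientation, a generic one-epoch loop that returns the average loss might look like the sketch below. It is deliberately simplified: the standalone function `train_one_epoch`, the toy regression model, and the explicit `loss_fn` are assumptions, and mixed precision and scheduler stepping are omitted. The real method takes only `current_epoch` and presumably reads the model and dataloader from the `Trainer` instance.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    """Sketch: run one training epoch and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Average over batches; guard against an empty dataloader.
    return total_loss / max(len(dataloader), 1)


torch.manual_seed(0)
device = torch.device("cpu")
model = nn.Linear(4, 1).to(device)
dataset = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))
train_dl = DataLoader(dataset, batch_size=4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
avg_loss = train_one_epoch(model, train_dl, optimizer, nn.MSELoss(), device)
```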
- eval(test_dl: torch.utils.data.DataLoader, current_epoch: int, calc_cm: bool = False) -> Tuple[float, float, float, torch.Tensor, fruit_project.utils.metrics.ConfusionMatrix | None]
Evaluates the model on a given dataloader.
- Parameters:
test_dl (DataLoader) – The dataloader for evaluation.
current_epoch (int) – The current epoch number.
calc_cm (bool, optional) – Whether to calculate and return a confusion matrix. Defaults to False.
- Returns:
- A tuple containing:
loss (float): The average evaluation loss.
test_map (float): The mAP@.50:.95.
test_map50 (float): The mAP@.50.
test_map_50_per_class (torch.Tensor): The mAP@.50 for each class.
cm (ConfusionMatrix | None): The confusion matrix if calc_cm is True, else None.
- Return type:
tuple
- _save_checkpoint(epoch: int) -> str
Saves a checkpoint of the model, optimizer, scheduler, and scaler states.
- Parameters:
epoch (int) – The current epoch number.
- Returns:
The path to the saved checkpoint file.
- Return type:
str
- _load_checkpoint(path: str) -> None
Loads a checkpoint and restores the state of the model, optimizer, scheduler, and scaler.
- Parameters:
path (str) – The path to the checkpoint file.
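The save and restore steps described for `_save_checkpoint` and `_load_checkpoint` typically follow the `state_dict` round trip sketched below. The function names, dictionary keys, and file name are assumptions. A gradient scaler is left out for brevity; its `state_dict()` could be stored and restored in the same way.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, epoch, model, optimizer, scheduler):
    """Sketch: save training state to `path` and return the path."""
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        path,
    )
    return path


def load_checkpoint(path, model, optimizer, scheduler):
    """Sketch: restore training state in place and return the saved epoch."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    return checkpoint["epoch"]


def make_training_state():
    """Hypothetical factory for a toy model, optimizer, and scheduler."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, scheduler


model, optimizer, scheduler = make_training_state()
with tempfile.TemporaryDirectory() as tmp_dir:
    path = save_checkpoint(
        os.path.join(tmp_dir, "checkpoint.pt"), 3, model, optimizer, scheduler
    )
    # Restore into freshly initialized objects to show the round trip.
    restored_model, restored_optimizer, restored_scheduler = make_training_state()
    restored_epoch = load_checkpoint(
        path, restored_model, restored_optimizer, restored_scheduler
    )
```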
- get_epoch_log_data(epoch: int, train_loss: float, test_map: float, test_map50: float, test_loss: float, test_map_per_class: torch.Tensor) -> Dict[str, Any]
Constructs a dictionary of metrics for logging at the end of an epoch.
- Parameters:
epoch (int) – The current epoch number.
train_loss (float) – The training loss.
test_map (float) – The test mAP@.50:.95.
test_map50 (float) – The test mAP@.50.
test_loss (float) – The test loss.
test_map_per_class (torch.Tensor) – The test mAP@.50 for each class.
- Returns:
A dictionary of metrics for logging.
- Return type:
dict
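A flat metrics dictionary of this sort could be assembled as in the sketch below. The function name `build_epoch_log`, the key names, the `class_names` argument, and the numbers are all hypothetical; they only illustrate how scalar metrics and per-class values might be combined for a logger such as wandb.

```python
from typing import Any, Dict, Sequence


def build_epoch_log(
    epoch: int,
    train_loss: float,
    test_loss: float,
    test_map: float,
    test_map50: float,
    test_map_per_class: Sequence[float],
    class_names: Sequence[str],
) -> Dict[str, Any]:
    """Sketch: combine scalar and per-class metrics into one flat dict."""
    log: Dict[str, Any] = {
        "epoch": epoch,
        "train/loss": train_loss,
        "test/loss": test_loss,
        "test/map": test_map,
        "test/map50": test_map50,
    }
    # One entry per class; float() also accepts zero-dimensional tensors.
    for name, value in zip(class_names, test_map_per_class):
        log[f"test/map50/{name}"] = float(value)
    return log


# Illustrative numbers only, not real results.
log_data = build_epoch_log(
    epoch=1,
    train_loss=0.9,
    test_loss=1.1,
    test_map=0.30,
    test_map50=0.55,
    test_map_per_class=[0.6, 0.5],
    class_names=["apple", "banana"],
)
```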
- get_val_log_data(epoch: int, best_test_map: float) -> Dict[str, Any]
Performs final validation, logs metrics, and returns the log data.
- Parameters:
epoch (int) – The final epoch number.
best_test_map (float) – The best test mAP@.50 achieved during training.
- Returns:
A dictionary of validation metrics for logging.
- Return type:
dict