fruit_project.utils.trainer
Classes
Module Contents
- class fruit_project.utils.trainer.Trainer(model: torch.nn.Module, processor: transformers.AutoImageProcessor, device: torch.device, cfg: omegaconf.DictConfig, name: str, run: wandb.sdk.wandb_run.Run, train_dl: torch.utils.data.DataLoader, test_dl: torch.utils.data.DataLoader, val_dl: torch.utils.data.DataLoader)
- early_stopping: fruit_project.utils.early_stop.EarlyStopping
- get_scheduler() → torch.optim.lr_scheduler.SequentialLR
Creates a learning rate scheduler with a warmup phase.
- Returns:
The learning rate scheduler.
- Return type:
SequentialLR
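The warmup pattern can be sketched with PyTorch's built-in schedulers. This is a minimal sketch, not the Trainer's actual configuration: the helper name `build_warmup_scheduler`, the linear warmup, the cosine main phase, and the epoch counts are all assumptions.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR


def build_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_epochs: int,
    total_epochs: int,
    warmup_start_factor: float = 0.1,
) -> SequentialLR:
    """Hypothetical helper: linear warmup followed by cosine annealing."""
    # Phase 1: ramp the LR linearly from start_factor * base_lr up to base_lr.
    warmup = LinearLR(
        optimizer,
        start_factor=warmup_start_factor,
        end_factor=1.0,
        total_iters=warmup_epochs,
    )
    # Phase 2 (assumed): cosine decay over the remaining epochs.
    main = CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)
    # SequentialLR hands control from `warmup` to `main` at the milestone.
    return SequentialLR(optimizer, schedulers=[warmup, main], milestones=[warmup_epochs])


# Usage: record the LR at the start of each of 10 epochs.
model = torch.nn.Linear(4, 2)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
sched = build_warmup_scheduler(opt, warmup_epochs=3, total_epochs=10)
lrs = []
for _ in range(10):
    lrs.append(opt.param_groups[0]["lr"])
    opt.step()  # the real training step would go here
    sched.step()
```

With these settings the LR climbs from 1e-4 to 1e-3 over the first three epochs, then decays.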
- get_optimizer() → torch.optim.AdamW
Creates an AdamW optimizer with separate learning rates for the backbone and for the rest of the model (the head), a common fine-tuning practice.
- Returns:
The configured optimizer.
- Return type:
AdamW
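Differential learning rates are usually implemented with PyTorch parameter groups. The sketch below assumes backbone parameters can be identified by the substring `"backbone"` in their names; the helper name, the learning-rate values, and the weight decay are hypothetical, not values taken from the Trainer's config.

```python
import torch
from torch import nn


def build_differential_adamw(
    model: nn.Module,
    backbone_keyword: str = "backbone",
    backbone_lr: float = 1e-5,
    head_lr: float = 1e-4,
    weight_decay: float = 1e-4,
) -> torch.optim.AdamW:
    """Hypothetical helper: one AdamW param group per model part."""
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are left out of the optimizer
        # Assumption: backbone parameters contain the keyword in their name.
        if backbone_keyword in name:
            backbone_params.append(param)
        else:
            head_params.append(param)
    return torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": backbone_lr},  # pretrained: small LR
            {"params": head_params, "lr": head_lr},  # task head: larger LR
        ],
        weight_decay=weight_decay,
    )


class ToyDetector(nn.Module):
    """Stand-in model with a `backbone` and a `head` submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.head = nn.Linear(4, 2)


opt = build_differential_adamw(ToyDetector())
```

The smaller backbone rate limits how far pretrained features drift, while the freshly initialized head can adapt quickly.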
- move_labels_to_device(batch: transformers.BatchEncoding) → transformers.BatchEncoding
Moves the label tensors within a batch to the trainer's device.
- Parameters:
batch (BatchEncoding) – The batch containing labels.
- Returns:
The batch with labels moved to the device.
- Return type:
BatchEncoding
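A minimal sketch, assuming the detection-style layout many Hugging Face image processors produce: `batch["labels"]` is a list holding one dict of tensors per image. That layout and the explicit `device` argument are assumptions. A plain dict stands in for `BatchEncoding`, which supports the same item access.

```python
from typing import Any, Dict, List

import torch


def move_labels_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    # Assumed layout: batch["labels"] is a list with one dict per image.
    labels: List[Dict[str, Any]] = batch["labels"]
    batch["labels"] = [
        # Move tensor values; leave non-tensor metadata untouched.
        {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in target.items()}
        for target in labels
    ]
    return batch
```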
- nested_to_cpu(obj: Any) → Any
Recursively moves tensors in a nested structure (dict, list, tuple) to CPU.
- Parameters:
obj (Any) – The object containing tensors to move.
- Returns:
The object with all tensors moved to CPU.
- Return type:
Any
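The recursion can be sketched as follows. This is an assumed implementation, not the project's source. It also calls `.detach()`, which is an added choice that may differ from the original; plain lists and tuples keep their container type, and namedtuples are not handled specially.

```python
from typing import Any

import torch


def nested_to_cpu(obj: Any) -> Any:
    # Tensors: drop the autograd graph (assumed choice) and copy to host memory.
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu()
    # Dicts: recurse into values, keep keys as-is.
    if isinstance(obj, dict):
        return {k: nested_to_cpu(v) for k, v in obj.items()}
    # Lists and tuples: recurse element-wise, preserving the container kind.
    if isinstance(obj, list):
        return [nested_to_cpu(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(nested_to_cpu(v) for v in obj)
    # Anything else (ints, strings, None, ...) passes through unchanged.
    return obj
```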
- format_targets_for_cm(targets: List[Dict]) → List[Dict]
Formats raw targets into the structure expected by torchmetrics and the confusion matrix. It is only needed for the confusion matrix, because MAPEvaluator handles its own formatting.
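One plausible conversion is sketched below. It assumes Hugging Face DETR-style targets: normalized `(cx, cy, w, h)` boxes under `"boxes"`, class ids under `"class_labels"`, and the image's `(height, width)` under `"orig_size"`. It emits the torchmetrics-style `{"boxes", "labels"}` dicts with absolute `(x1, y1, x2, y2)` pixel boxes. The key names, box formats, and the helper name `to_xyxy_targets` are all hypothetical.

```python
from typing import Dict, List

import torch


def to_xyxy_targets(targets: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
    """Hypothetical helper: DETR-style targets -> torchmetrics-style targets."""
    formatted = []
    for t in targets:
        # Assumed: orig_size holds (height, width) of the source image.
        h, w = t["orig_size"].tolist()
        cx, cy, bw, bh = t["boxes"].unbind(-1)
        # Normalized (cx, cy, w, h) -> absolute (x1, y1, x2, y2) in pixels.
        boxes = torch.stack(
            [(cx - bw / 2) * w, (cy - bh / 2) * h, (cx + bw / 2) * w, (cy + bh / 2) * h],
            dim=-1,
        )
        formatted.append({"boxes": boxes, "labels": t["class_labels"]})
    return formatted
```

Images with no objects still work: a `(0, 4)` box tensor yields a `(0, 4)` output.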
- train(current_epoch: int) → Dict[str, float]
Performs one epoch of training.
- Parameters:
current_epoch (int) – The current epoch number.
- Returns:
A dictionary of average training loss values for the epoch.
- Return type:
Dict[str, float]
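A generic training epoch that returns averaged losses as a dictionary can be sketched as below. It is a simplified stand-in: it assumes a `(inputs, targets)` dataloader and an explicit loss function, whereas the Trainer presumably takes losses from a Hugging Face model's outputs. The helper name and the single `"loss"` key are assumptions.

```python
from typing import Dict, Iterable

import torch
from torch import nn


def train_one_epoch(
    model: nn.Module,
    dataloader: Iterable,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    device: torch.device,
) -> Dict[str, float]:
    """Hypothetical helper: one pass over the data, returning averaged losses."""
    model.train()
    total, n_batches = 0.0, 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total += loss.item()  # .item() avoids keeping the graph alive
        n_batches += 1
    # Average over batches; max() guards against an empty dataloader.
    return {"loss": total / max(n_batches, 1)}
```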
- eval(val_dl: torch.utils.data.DataLoader, current_epoch: int, calc_cm: bool = False) → Tuple[dict[str, float], dict[str, Any], fruit_project.utils.metrics.ConfusionMatrix | None]
- _run_eval(val_dl: torch.utils.data.DataLoader, current_epoch: int, calc_cm: bool = False) → Tuple[dict[str, float], dict[str, Any], fruit_project.utils.metrics.ConfusionMatrix | None]
Evaluates the model on a given dataloader.
- Parameters:
val_dl (DataLoader) – The dataloader for evaluation.
current_epoch (int) – The current epoch number.
calc_cm (bool, optional) – Whether to calculate and return a confusion matrix. Defaults to False.
- Returns:
- A tuple containing:
loss (dict[str, float]): The evaluation loss values.
metrics (dict[str, Any]): The evaluation metrics, such as mAP@.50:.95, mAP@.50, and the per-class mAP@.50.
cm (ConfusionMatrix | None): The confusion matrix if calc_cm is True, else None.
- Return type:
Tuple[dict[str, float], dict[str, Any], ConfusionMatrix | None]
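The shape of this API (gradient-free loop, a loss dict, a metrics dict, and an optional confusion matrix) can be sketched on a toy classifier. The stand-ins are deliberate and hypothetical: classification accuracy replaces mAP, a plain `torch` tensor replaces `fruit_project.utils.metrics.ConfusionMatrix`, and `run_eval` and its `num_classes` argument are illustrative names.

```python
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn


@torch.no_grad()  # no autograd graph is built during evaluation
def run_eval(
    model: nn.Module,
    dataloader: Iterable,
    loss_fn: nn.Module,
    device: torch.device,
    num_classes: int,
    calc_cm: bool = False,
) -> Tuple[Dict[str, float], Dict[str, Any], Optional[torch.Tensor]]:
    model.eval()
    total, n_batches, correct, seen = 0.0, 0, 0, 0
    # Rows = true class, columns = predicted class; only built when requested.
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long) if calc_cm else None
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        total += loss_fn(logits, targets).item()
        preds = logits.argmax(dim=-1)
        correct += (preds == targets).sum().item()
        seen += targets.numel()
        n_batches += 1
        if cm is not None:
            # Encode each (true, pred) pair as one index, then count occurrences.
            idx = targets.cpu() * num_classes + preds.cpu()
            cm += torch.bincount(idx, minlength=num_classes**2).reshape(num_classes, num_classes)
    loss = {"loss": total / max(n_batches, 1)}
    metrics = {"accuracy": correct / max(seen, 1)}
    return loss, metrics, cm
```

Returning `None` for the matrix when it is not requested mirrors the `calc_cm` flag in the signature above and skips the extra bookkeeping on routine validation passes.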