fruit_project.utils.trainer
===========================

.. py:module:: fruit_project.utils.trainer


Classes
-------

.. autoapisummary::

   fruit_project.utils.trainer.Trainer


Module Contents
---------------

.. py:class:: Trainer(model: torch.nn.Module, processor: transformers.AutoImageProcessor, device: torch.device, cfg: omegaconf.DictConfig, name: str, run: wandb.sdk.wandb_run.Run, train_dl: torch.utils.data.DataLoader, test_dl: torch.utils.data.DataLoader, val_dl: torch.utils.data.DataLoader, loading_info: Dict)

   .. py:attribute:: model
      :type: torch.nn.Module

   .. py:attribute:: device
      :type: torch.device

   .. py:attribute:: scaler

   .. py:attribute:: cfg
      :type: omegaconf.DictConfig

   .. py:attribute:: optimizer
      :type: torch.optim.Optimizer

   .. py:attribute:: processor
      :type: transformers.AutoImageProcessor

   .. py:attribute:: name
      :type: str

   .. py:attribute:: early_stopping
      :type: fruit_project.utils.early_stop.EarlyStopping

   .. py:attribute:: run
      :type: wandb.sdk.wandb_run.Run

   .. py:attribute:: train_dl
      :type: torch.utils.data.DataLoader

   .. py:attribute:: test_dl
      :type: torch.utils.data.DataLoader

   .. py:attribute:: val_dl
      :type: torch.utils.data.DataLoader

   .. py:attribute:: start_epoch
      :type: int
      :value: 0

   .. py:attribute:: map_evaluator

   .. py:attribute:: accum_steps
      :type: int

   .. py:attribute:: scheduler
      :type: torch.optim.lr_scheduler.SequentialLR

   .. py:attribute:: ema
      :type: torch_ema.ExponentialMovingAverage
      :value: None

   .. py:method:: get_scheduler() -> torch.optim.lr_scheduler.SequentialLR

      Creates a learning rate scheduler with a warmup phase.

      :returns: The learning rate scheduler.
      :rtype: SequentialLR

   .. py:method:: get_optimizer(loading_info=None) -> torch.optim.AdamW

      Creates an AdamW optimizer with differential learning rates for the
      backbone and the rest of the model (the head), following standard
      fine-tuning practice.

      :returns: The configured optimizer.
      :rtype: AdamW

   .. py:method:: move_batch_to_device(batch: transformers.BatchEncoding) -> transformers.BatchEncoding

      Moves label tensors within a batch to the specified device.

      :param batch: The batch containing labels.
      :type batch: BatchEncoding
      :returns: The batch with labels moved to the device.
      :rtype: BatchEncoding

   .. py:method:: nested_to_cpu(obj: Any) -> Any

      Recursively moves tensors in a nested structure (dict, list, tuple) to the CPU.

      :param obj: The object containing tensors to move.
      :returns: The object with all tensors moved to the CPU.

   .. py:method:: format_targets_for_cm(targets: List[Dict]) -> List[Dict]

      Formats raw targets for torchmetrics and the confusion matrix. This is
      a helper for the confusion matrix only, as MAPEvaluator handles its own
      formatting.

   .. py:method:: forward_step_map(batch_idx, batch)

   .. py:method:: forward_step_fp32(batch_idx, batch)

   .. py:method:: forward_step(batch_idx, batch, use_fp16, current_epoch)

   .. py:method:: train(current_epoch: int) -> Dict[str, float]

      Performs one epoch of training.

      :param current_epoch: The current epoch number.
      :type current_epoch: int
      :returns: The average training losses for the epoch.
      :rtype: Dict[str, float]

   .. py:method:: eval(val_dl: torch.utils.data.DataLoader, current_epoch: int, calc_cm: bool = False) -> Tuple[dict[str, float], dict[str, Any], Optional[fruit_project.utils.metrics.ConfusionMatrix]]

   .. py:method:: _run_eval(val_dl: torch.utils.data.DataLoader, current_epoch: int, calc_cm: bool = False) -> Tuple[dict[str, float], dict[str, Any], Optional[fruit_project.utils.metrics.ConfusionMatrix]]

      Evaluates the model on a given dataloader.

      :param val_dl: The dataloader for evaluation.
      :type val_dl: DataLoader
      :param current_epoch: The current epoch number.
      :type current_epoch: int
      :param calc_cm: Whether to calculate and return a confusion matrix.
                      Defaults to False.
      :type calc_cm: bool, optional
      :returns: A tuple containing:

                - loss (dict): The evaluation losses.
                - metrics (dict): The evaluation metrics, including the
                  mAP@.5-.95, the mAP@.50, and the mAP@.50 for each class.
                - cm (ConfusionMatrix | None): The confusion matrix if
                  ``calc_cm`` is True, else None.
      :rtype: tuple

   .. py:method:: fit() -> None

      Runs the main training loop for the specified number of epochs.

   .. py:method:: _build_save_dict(epoch)

   .. py:method:: _save_checkpoint(epoch: int) -> str

      Saves a checkpoint of the model, optimizer, scheduler, and scaler states.

      :param epoch: The current epoch number.
      :type epoch: int
      :returns: The path to the saved checkpoint file.
      :rtype: str

   .. py:method:: _load_checkpoint(path: str, model_only: bool = True) -> None

      Loads a checkpoint and restores the state of the model, optimizer,
      scheduler, and scaler.

      :param path: The path to the checkpoint file.
      :type path: str
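The exact parameter split used by ``get_optimizer`` is not shown in this
reference; a minimal sketch of the differential-learning-rate idea, assuming
backbone parameters are identifiable by a ``"backbone."`` name prefix (an
assumption, not the Trainer's confirmed naming scheme), might look like:

```python
def build_param_groups(named_params, backbone_lr=1e-5, head_lr=1e-4):
    """Partition parameters into backbone/head groups with separate LRs.

    ``named_params`` is an iterable of (name, parameter) pairs, as produced
    by ``model.named_parameters()`` in PyTorch. The pretrained backbone gets
    the smaller fine-tuning rate; everything else (the head) trains faster.
    """
    backbone, head = [], []
    for name, param in named_params:
        (backbone if name.startswith("backbone.") else head).append(param)
    # Per-parameter-group options like this are accepted directly by
    # torch.optim.AdamW(param_groups, ...).
    return [
        {"params": backbone, "lr": backbone_lr},
        {"params": head, "lr": head_lr},
    ]
```

The returned list can then be passed to
``torch.optim.AdamW(build_param_groups(model.named_parameters()))``, which
applies each group's learning rate independently.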
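``get_scheduler`` returns a ``SequentialLR``, which in PyTorch chains a warmup
scheduler (e.g. ``LinearLR``) into a main schedule at a given milestone. The
Trainer's post-warmup schedule is not specified here; as an illustration only,
the shape of such a schedule can be sketched as a multiplicative LR factor
with linear warmup followed by an assumed linear decay:

```python
def lr_factor(step, warmup_steps, total_steps):
    """Multiplicative LR factor: linear warmup, then linear decay.

    The decay half is illustrative; only the warmup phase is documented
    for the actual Trainer.
    """
    if step < warmup_steps:
        # Ramp the base learning rate up from ~0 to 1.0 over warmup_steps.
        return (step + 1) / warmup_steps
    # Anneal linearly from 1.0 down to 0.0 over the remaining steps.
    remaining = total_steps - warmup_steps
    return max(0.0, (total_steps - step) / remaining)
```

A function like this could drive ``torch.optim.lr_scheduler.LambdaLR``;
``SequentialLR(optimizer, schedulers=[warmup, main], milestones=[warmup_steps])``
expresses the same two-phase structure with separate scheduler objects.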
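The recursion behind ``nested_to_cpu`` can be sketched in a few lines. This is
a duck-typed stand-in, not the Trainer's actual implementation: anything
exposing a ``.cpu()`` method (such as a ``torch.Tensor``) is moved, containers
are rebuilt with the same type, and everything else passes through unchanged.

```python
def nested_to_cpu(obj):
    """Recursively move tensors in dicts/lists/tuples to the CPU."""
    if hasattr(obj, "cpu"):
        return obj.cpu()  # torch.Tensor and tensor-like objects
    if isinstance(obj, dict):
        return {k: nested_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Rebuild the container with its original type (list or tuple).
        return type(obj)(nested_to_cpu(v) for v in obj)
    return obj  # scalars, strings, None, ... are returned as-is
```

Note that ``type(obj)(generator)`` rebuilds plain lists and tuples but would
not survive a ``namedtuple``, which constructs from positional arguments; a
production version would need an extra branch for that case.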