# SPDX-FileCopyrightText: 2025 Mohamed Khayat
# SPDX-License-Identifier: AGPL-3.0-or-later
import os
from typing import List, Optional, Tuple, Dict, Any
import torch
import torch.nn as nn
from torch.amp import GradScaler
from tqdm import tqdm
from transformers.image_transforms import center_to_corners_format
from omegaconf import DictConfig
from fruit_project.utils.datasets.alb_mosaic_dataset import AlbumentationsMosaicDataset
from fruit_project.utils.early_stop import EarlyStopping
from torch.optim.lr_scheduler import CosineAnnealingLR, SequentialLR, LinearLR
from fruit_project.utils.metrics import ConfusionMatrix, MAPEvaluator
from transformers import AutoImageProcessor, BatchEncoding
from wandb.sdk.wandb_run import Run
from torch.utils.data import DataLoader
from torch.optim import AdamW, Optimizer
from torch_ema import ExponentialMovingAverage
from fruit_project.utils.logging import (
log_checkpoint_artifact,
log_epoch_data,
log_test_data,
)
class Trainer:
def __init__(
    self,
    model: nn.Module,
    processor: AutoImageProcessor,
    device: torch.device,
    cfg: DictConfig,
    name: str,
    run: Run,
    train_dl: DataLoader,
    test_dl: DataLoader,
    val_dl: DataLoader,
):
    """
    Initializes the Trainer object.

    Args:
        model (nn.Module): The model to train.
        processor (AutoImageProcessor): The processor for preprocessing data.
        device (torch.device): The device to run training on.
        cfg (DictConfig): The configuration object.
        name (str): The name of the training run.
        run (Run): The wandb run object.
        train_dl (DataLoader): The training dataloader.
        test_dl (DataLoader): The testing dataloader.
        val_dl (DataLoader): The validation dataloader.
    """
    self.model: nn.Module = model
    self.device: torch.device = device
    # NOTE(review): scaler is hard-wired to "cuda" even though `device` is a
    # parameter — confirm CPU-only runs are not expected here.
    self.scaler = GradScaler("cuda")
    self.cfg: DictConfig = cfg
    # Fix: store `name` and `run`; they were never assigned, yet fit() reads
    # self.run and _save_checkpoint() reads self.name.
    self.name: str = name
    self.run: Run = run
    self.optimizer: Optimizer = self.get_optimizer()
    self.processor: AutoImageProcessor = processor
    self.early_stopping: EarlyStopping = EarlyStopping(
        cfg.patience, cfg.delta, "checkpoints", name, run, cfg.log, cfg.upload
    )
    self.train_dl: DataLoader = train_dl
    self.test_dl: DataLoader = test_dl
    self.val_dl: DataLoader = val_dl
    self.start_epoch: int = 0
    self.map_evaluator = MAPEvaluator(
        image_processor=processor,
        device=self.device,
        threshold=self.cfg.threshold,
        id2label=train_dl.dataset.id2lbl,
    )
    # Validate divisibility BEFORE computing the ratio, so the error message
    # reports the two configured values rather than a truncated quotient
    # (the original message printed accum_steps while labeling it batch_size).
    assert self.cfg.effective_batch_size % self.cfg.step_batch_size == 0, (
        f"effective_batch_size ({self.cfg.effective_batch_size}) must be divisible "
        f"by step_batch_size ({self.cfg.step_batch_size})."
    )
    # Number of micro-batches accumulated per optimizer step.
    self.accum_steps: int = (
        self.cfg.effective_batch_size // self.cfg.step_batch_size
    )
    self.scheduler: SequentialLR = self.get_scheduler()
    # Optional EMA of model weights; None when disabled via cfg.ema.use.
    self.ema: ExponentialMovingAverage = (
        None
        if not self.cfg.ema.use
        else ExponentialMovingAverage(
            self.model.parameters(), decay=self.cfg.ema.decay
        )
    )
def get_scheduler(self) -> SequentialLR:
"""
Creates a learning rate scheduler with a warmup phase.
Returns:
SequentialLR: The learning rate scheduler.
"""
train_steps = len(self.train_dl) // self.accum_steps
if self.cfg.phase == 1:
total_warmup_steps = self.cfg.warmup_epochs * train_steps
warmup_scheduler = LinearLR(
self.optimizer,
start_factor=self.cfg.lin_start_factor,
total_iters=total_warmup_steps,
)
main_steps = (self.cfg.epochs - self.cfg.warmup_epochs) * train_steps
main_scheduler = CosineAnnealingLR(
self.optimizer,
T_max=main_steps,
eta_min=self.cfg.lr / self.cfg.eta_min_factor,
)
scheduler = SequentialLR(
self.optimizer,
schedulers=[warmup_scheduler, main_scheduler],
milestones=[
total_warmup_steps,
],
)
else:
scheduler = CosineAnnealingLR(
self.optimizer, T_max=self.cfg.epochs * train_steps, eta_min=0
)
return scheduler
def get_optimizer(self) -> AdamW:
"""
Creates an AdamW optimizer with a differential learning rate for the backbone
and the rest of the model (head), following standard fine-tuning practices.
Returns:
AdamW: The configured optimizer.
"""
head_params = [p for p in self.model.parameters() if p.requires_grad]
backbone_params = [
p
for n, p in self.model.named_parameters()
if n.startswith("model.backbone") or n.startswith("vit") and p.requires_grad
]
backbone_param_ids = {id(p) for p in backbone_params}
head_params_final = [p for p in head_params if id(p) not in backbone_param_ids]
param_dicts = [
{
"params": head_params_final,
"lr": self.cfg.lr,
},
{
"params": backbone_params,
"lr": self.cfg.lr / self.cfg.lr_back_factor,
},
]
print(
f"Backbone params: {sum(p.numel() for p in backbone_params)} parameters at LR {self.cfg.lr / self.cfg.lr_back_factor}"
)
print(
f"Head (Encoder, Decoder, Neck, etc.) params: {sum(p.numel() for p in head_params_final)} parameters at LR {self.cfg.lr}"
)
if self.cfg.optim == "8bit":
tqdm.write("using AdamW8bit")
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(
param_dicts,
weight_decay=self.cfg.weight_decay,
)
elif self.cfg.optim == "torch":
tqdm.write("using torch AdamW")
optimizer = AdamW(
param_dicts,
weight_decay=self.cfg.weight_decay,
fused=True,
)
else:
raise KeyError("invalid optim type, use 8bit or torch")
return optimizer
def move_labels_to_device(self, batch: BatchEncoding) -> BatchEncoding:
    """
    Moves every label tensor in the batch onto ``self.device`` in place.

    Args:
        batch (BatchEncoding): Batch whose "labels" entry is a list of
            dicts mapping keys to tensors.

    Returns:
        BatchEncoding: The same batch object, with label tensors on the device.
    """
    for target in batch["labels"]:
        for key in list(target.keys()):
            target[key] = target[key].to(self.device)
    return batch
def nested_to_cpu(self, obj: Any) -> Any:
"""
Recursively moves tensors in a nested structure (dict, list, tuple) to CPU.
Args:
obj: The object containing tensors to move.
Returns:
The object with all tensors moved to CPU.
"""
if torch.is_tensor(obj):
return obj.cpu()
if isinstance(obj, dict):
return {k: self.nested_to_cpu(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [self.nested_to_cpu(v) for v in obj]
return obj
def train(
    self,
    current_epoch: int,
) -> Dict[str, float]:
    """
    Performs one epoch of mixed-precision training with gradient accumulation.

    Args:
        current_epoch (int): The current epoch number.

    Returns:
        Dict[str, float]: Per-batch-averaged losses for the epoch, with keys
        "loss", "class_loss", "bbox_loss", and "giou_loss".
    """
    self.model.train()
    self.optimizer.zero_grad(set_to_none=True)
    epoch_loss = {"loss": 0.0}
    epoch_loss.update({k: 0.0 for k in ["class_loss", "bbox_loss", "giou_loss"]})
    progress_bar = tqdm(
        self.train_dl,
        desc=f"Epoch {current_epoch} Training",
        leave=False,
        bar_format="{l_bar}{bar:30}{r_bar}{bar:-30b}",
    )
    for batch_idx, batch in enumerate(progress_bar):
        batch["pixel_values"] = batch["pixel_values"].to(self.device)
        batch = self.move_labels_to_device(batch)
        # Forward pass under autocast (fp16) for memory/speed.
        with torch.autocast(device_type=self.device.type, dtype=torch.float16):
            out = self.model(**batch)
            # Divide so gradients accumulated over `accum_steps`
            # micro-batches sum to one effective-batch gradient.
            batch_loss = out.loss / self.accum_steps
            loss_dict = out.loss_dict
        self.scaler.scale(batch_loss).backward()
        # Optimizer steps only on accumulation boundaries.
        # NOTE(review): trailing batches that do not fill a complete
        # accumulation window leave gradients that are never stepped —
        # confirm len(train_dl) % accum_steps == 0 is expected.
        if (batch_idx + 1) % self.accum_steps == 0:
            # Unscale before clipping so max_norm applies to true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm=self.cfg.model.grad_max_norm
            )
            self.scaler.step(self.optimizer)
            # Scheduler advances once per optimizer step, not per batch.
            self.scheduler.step()
            self.scaler.update()
            self.optimizer.zero_grad(set_to_none=True)
            # EMA only tracks weights after warmup has finished.
            if self.ema and current_epoch >= self.cfg.warmup_epochs:
                self.ema.update()
        # Accumulate the UNdivided loss (out.loss) for reporting.
        epoch_loss["loss"] += out.loss.item()
        # Classification loss key differs by model family: VFL vs plain CE.
        epoch_loss["class_loss"] += loss_dict.get(
            "loss_vfl", loss_dict.get("loss_ce", 0.0)
        ).item()
        epoch_loss["bbox_loss"] += loss_dict["loss_bbox"].item()
        epoch_loss["giou_loss"] += loss_dict["loss_giou"].item()
        current_avg_loss = epoch_loss["loss"] / (batch_idx + 1)
        progress_bar.set_postfix(
            {
                "Loss": f"{current_avg_loss:.4f}",
                "Batch": f"{batch_idx + 1}/{len(self.train_dl)}",
            }
        )
        # Periodically release cached CUDA blocks to limit fragmentation.
        if batch_idx % 50 == 0:
            torch.cuda.empty_cache()
    num_batches = len(self.train_dl)
    epoch_loss = {k: v / num_batches for k, v in epoch_loss.items()}
    tqdm.write(f"Epoch : {current_epoch}")
    tqdm.write(
        f"\tTrain --- Loss: {epoch_loss['loss']:.4f}, Class Loss : {epoch_loss['class_loss']:.4f}, Bbox Loss : {epoch_loss['bbox_loss']:.4f}, Giou Loss : {epoch_loss['giou_loss']:.4f}"
    )
    return epoch_loss
@torch.no_grad()
def eval(
    self, val_dl: DataLoader, current_epoch: int, calc_cm: bool = False
) -> Tuple[dict[str, float], dict[str, Any], Optional[ConfusionMatrix]]:
    """
    Runs evaluation, using EMA-averaged weights once warmup has finished.

    Args:
        val_dl (DataLoader): Dataloader to evaluate on.
        current_epoch (int): The current epoch number.
        calc_cm (bool, optional): Whether to also build a confusion matrix.

    Returns:
        Tuple of (loss dict, metrics dict, confusion matrix or None).
    """
    use_ema = bool(self.ema) and current_epoch >= self.cfg.warmup_epochs
    if not use_ema:
        tqdm.write("evaluating with regular weights")
        return self._run_eval(val_dl, current_epoch, calc_cm)
    tqdm.write("evaluating with EMA weights")
    with self.ema.average_parameters():
        return self._run_eval(val_dl, current_epoch, calc_cm)
def _run_eval(
    self, val_dl: DataLoader, current_epoch: int, calc_cm: bool = False
) -> Tuple[dict[str, float], dict[str, Any], Optional[ConfusionMatrix]]:
    """
    Evaluates the model on a dataloader: losses, mAP metrics, and an
    optional confusion matrix.

    Args:
        val_dl (DataLoader): The dataloader for evaluation.
        current_epoch (int): The current epoch number.
        calc_cm (bool, optional): Whether to calculate and return a confusion matrix. Defaults to False.

    Returns:
        tuple: A tuple containing:
            - epoch_loss (dict[str, float]): Per-batch-averaged "loss",
              "class_loss", "bbox_loss", and "giou_loss".
            - val_metrics (dict[str, Any]): mAP@50:95, mAP@50, per-class
              mAP@50, and optimal-F1 precision/recall (overall and per class).
            - cm (ConfusionMatrix | None): The confusion matrix if calc_cm is True, else None.
    """
    self.model.eval()
    # Clear metric state accumulated by any previous evaluation pass.
    self.map_evaluator.map_metric.reset()
    self.map_evaluator.map_50_metric.reset()
    epoch_loss = {"loss": 0.0}
    epoch_loss.update({k: 0.0 for k in ["class_loss", "bbox_loss", "giou_loss"]})
    if calc_cm:
        # NOTE(review): conf=0.374 / iou_thres=0.45 are hard-coded magic
        # thresholds — confirm they match the intended operating point.
        cm = ConfusionMatrix(
            nc=len(val_dl.dataset.labels), conf=0.374, iou_thres=0.45
        )
    else:
        cm = None
    progress_bar = tqdm(
        val_dl,
        desc=f"Epoch {current_epoch} Evaluating",
        leave=False,
        bar_format="{l_bar}{bar:30}{r_bar}{bar:-30b}",
    )
    for batch_idx, batch in enumerate(progress_bar):
        batch["pixel_values"] = batch["pixel_values"].to(self.device)
        batch = self.move_labels_to_device(batch)
        out = self.model(**batch)
        batch_loss = out.loss
        loss_dict = out.loss_dict
        epoch_loss["loss"] += batch_loss.item()
        # Classification loss key differs by model family: VFL vs plain CE.
        epoch_loss["class_loss"] += loss_dict.get(
            "loss_vfl", loss_dict.get("loss_ce", 0.0)
        ).item()
        epoch_loss["bbox_loss"] += loss_dict["loss_bbox"].item()
        epoch_loss["giou_loss"] += loss_dict["loss_giou"].item()
        batch_targets = batch["labels"]
        # Post-process predictions/targets into the evaluator's format,
        # rescaled to original image sizes.
        image_sizes = self.map_evaluator.collect_image_sizes(batch_targets)
        batch_preds_processed = self.map_evaluator.collect_predictions(
            out, image_sizes
        )
        batch_targets_processed = self.map_evaluator.collect_targets(
            batch_targets, image_sizes
        )
        self.map_evaluator.map_metric.update(
            batch_preds_processed, batch_targets_processed
        )
        self.map_evaluator.map_50_metric.update(
            batch_preds_processed, batch_targets_processed
        )
        current_avg_loss = epoch_loss["loss"] / (batch_idx + 1)
        progress_bar.set_postfix(
            {
                "Loss": f"{current_avg_loss:.4f}",
                "Batch": f"{batch_idx + 1}/{len(val_dl)}",
            }
        )
        if calc_cm and cm:
            # Confusion matrix update works on CPU copies of the predictions.
            preds = self.nested_to_cpu(batch_preds_processed)
            targets_for_cm = self.format_targets_for_cm(batch["labels"])
            cm.update(preds, targets_for_cm)
    tqdm.write("Computing mAP metrics")
    map_50_95_metrics = self.map_evaluator.map_metric.compute()
    val_map = map_50_95_metrics.get("map", 0.0)
    val_map50 = map_50_95_metrics.get("map_50", 0.0)
    map_50_metrics = self.map_evaluator.map_50_metric.compute()
    # Precision/recall at the per-class optimal-F1 confidence threshold,
    # mirroring Ultralytics' reporting convention.
    optimal_precisions, optimal_recalls = (
        self.map_evaluator.get_optimal_f1_ultralytics_style(map_50_metrics)
    )
    # Classes actually present in the eval set; empty tensor if key missing.
    present_classes = map_50_metrics.get(
        "classes", torch.tensor([], device=self.device)
    )
    overall_precision, overall_recall = (
        self.map_evaluator.get_averaged_precision_recall_ultralytics_style(
            optimal_precisions, optimal_recalls, present_classes
        )
    )
    val_metrics = {
        "map@50:95": val_map,
        "map@50": val_map50,
        "map@50_per_class": self.map_evaluator.get_per_class(
            map_50_metrics, metric="map_per_class"
        ),
        "precision_per_class": optimal_precisions,
        "recall_per_class": optimal_recalls,
        "precision": overall_precision,
        "recall": overall_recall,
    }
    num_batches = len(val_dl)
    epoch_loss = {k: v / num_batches for k, v in epoch_loss.items()}
    tqdm.write(
        f"\tEval --- Loss: {epoch_loss['loss']:.4f}, Class Loss : {epoch_loss['class_loss']:.4f}, Bbox Loss : {epoch_loss['bbox_loss']:.4f}, Giou Loss : {epoch_loss['giou_loss']:.4f}"
    )
    tqdm.write(f"\tEval --- mAP50-95: {val_map:.4f}, mAP@50 : {val_map50:.4f}")
    tqdm.write("\t--- Per-class mAP@50 ---")
    class_names = val_dl.dataset.labels
    # Move the per-class tensor to CPU once so .item() below is cheap.
    if val_metrics["map@50_per_class"].is_cuda:
        val_metrics["map@50_per_class"] = val_metrics["map@50_per_class"].cpu()
    for i, class_name in enumerate(class_names):
        # Guard: the metric tensor may cover fewer classes than the label list.
        if i < len(val_metrics["map@50_per_class"]):
            tqdm.write(
                f"\t\t{class_name:<15}: {val_metrics['map@50_per_class'][i].item():.4f}"
            )
    return epoch_loss, val_metrics, cm
def fit(self) -> None:
    """
    Runs the main training loop: optional checkpointing, per-epoch train and
    validation passes, logging, and early stopping on mAP@50:95.
    """
    epoch_pbar = tqdm(total=self.cfg.epochs, desc="Epochs", position=0, leave=True)
    best_val_map = 0
    for epoch in range(self.start_epoch, self.cfg.epochs):
        epoch_pbar.set_description(f"Epoch {epoch}/{self.cfg.epochs}")
        # Checkpoint BEFORE the epoch runs, so a crash mid-epoch can resume
        # from this epoch's start.
        # NOTE(review): reads self.run / (via _save_checkpoint) self.name,
        # which __init__ never assigns — confirm they are set on the
        # instance before fit() is called.
        if self.run and self.cfg.ckpt.save:
            ckpt_path = self._save_checkpoint(epoch)
            if self.cfg.log:
                log_checkpoint_artifact(
                    self.run, ckpt_path, self.cfg.model.name, epoch, self.cfg.wait
                )
        # Mosaic augmentation re-schedules itself per epoch (e.g. to taper
        # off late in training).
        if isinstance(self.train_dl.dataset, AlbumentationsMosaicDataset):
            self.train_dl.dataset.update_epoch(epoch)
        train_loss = self.train(
            epoch,
        )
        val_loss, val_metrics, _ = self.eval(self.val_dl, epoch)
        epoch_pbar.update(1)
        # Track the best mAP@50 seen so far (reported at the end).
        best_val_map = max(val_metrics["map@50"], best_val_map)
        if self.cfg.log:
            log_epoch_data(
                epoch,
                train_loss,
                val_loss,
                val_metrics,
                self,
            )
        # Early stopping monitors mAP@50:95 (not the mAP@50 used above).
        if self.early_stopping(val_metrics["map@50:95"], self.model):
            tqdm.write(f"Early stopping triggered at epoch {epoch}.")
            break
    tqdm.write("Training finished.")
    # NOTE(review): `epoch` is unbound here if the loop body never ran
    # (start_epoch >= epochs) — confirm fit() is never called that way.
    if self.cfg.log:
        log_test_data(epoch, best_val_map, self)
        self.run.finish()
    epoch_pbar.close()
def _build_save_dict(self, epoch):
ckpt = {
"epoch": epoch,
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict(),
"scaler": self.scaler.state_dict(),
"ema": self.ema.state_dict(),
}
return ckpt
def _save_checkpoint(self, epoch: int) -> str:
    """
    Saves a checkpoint of the model, optimizer, scheduler, and scaler states
    under the "checkpoints" directory.

    Args:
        epoch (int): The current epoch number.

    Returns:
        str: The path to the saved checkpoint file.
    """
    tqdm.write("saving checkpoint")
    if self.ema:
        # Snapshot under EMA-averaged weights so "model" holds the averaged
        # parameters; the raw EMA state is stored alongside under "ema".
        with self.ema.average_parameters():
            ckpt = self._build_save_dict(epoch)
        ckpt.update({"ema": self.ema.state_dict()})
    else:
        ckpt = self._build_save_dict(epoch)
    # Fix: ensure the target directory exists so torch.save does not fail
    # with FileNotFoundError on a fresh working tree.
    os.makedirs("checkpoints", exist_ok=True)
    path = os.path.join("checkpoints", f"{self.name}_epoch{epoch}.pth")
    torch.save(ckpt, path)
    tqdm.write("done saving checkpoint")
    return path
def _load_checkpoint(self, path: str, model_only: bool = True) -> None:
    """
    Loads a checkpoint and restores the state of the model and, optionally,
    the optimizer, scheduler, scaler, EMA, and starting epoch.

    Args:
        path (str): The path to the checkpoint file.
        model_only (bool, optional): If True, `path` is treated as a bare
            model state dict; otherwise as a full training checkpoint.
            Defaults to True.
    """
    tqdm.write("loading checkpoint")
    # NOTE(review): torch.load on an untrusted file can execute arbitrary
    # code via pickle — only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location=self.device)
    if model_only:
        self.model.load_state_dict(ckpt)
        return
    self.model.load_state_dict(ckpt["model"])
    self.optimizer.load_state_dict(ckpt["optimizer"])
    self.scheduler.load_state_dict(ckpt["scheduler"])
    self.scaler.load_state_dict(ckpt["scaler"])
    # Fix: also guard on self.ema — restoring EMA state from a checkpoint
    # into a run with EMA disabled (self.ema is None) would crash.
    if "ema" in ckpt and self.ema is not None:
        self.ema.load_state_dict(ckpt["ema"])
    # Resume from the epoch after the one that was checkpointed.
    self.start_epoch = ckpt["epoch"] + 1