# SPDX-FileCopyrightText: 2025 Mohamed Khayat
# SPDX-License-Identifier: AGPL-3.0-or-later
import torch
import torch.nn as nn
from fruit_project.models.transforms_factory import get_transforms
from fruit_project.config import SUPPORTED_MODELS
from omegaconf import DictConfig
from typing import Any, Dict, List, Tuple
from albumentations import Compose
from transformers import (
AutoImageProcessor,
AutoModelForObjectDetection,
AutoConfig,
)
supported_models = SUPPORTED_MODELS
def get_model(
    cfg: DictConfig, device: torch.device, n_classes: int, id2lbl: Dict, lbl2id: Dict
) -> Tuple[nn.Module, Compose, List, List, AutoImageProcessor, Any]:
    """
    Retrieve and initialize a model based on the provided configuration.

    Args:
        cfg (DictConfig): Configuration object; ``cfg.model.name`` selects the model.
        device (torch.device): Device the model will be loaded onto (e.g. 'cpu' or 'cuda').
        n_classes (int): Number of classes for the model's output.
        id2lbl (Dict): Mapping from class IDs to labels.
        lbl2id (Dict): Mapping from labels to class IDs.

    Returns:
        Tuple: the initialized model, the preprocessing transforms, the image
        normalization mean and std, the image processor, and the loading info
        (None unless ``cfg.smart_optim`` is set) — see ``get_hf_model``.

    Raises:
        ValueError: If ``cfg.model.name`` is not a supported model.
    """
    # Membership test against the dict itself is equivalent to `.keys()`
    # and is the idiomatic form.
    if cfg.model.name in supported_models:
        return get_hf_model(device, n_classes, id2lbl, lbl2id, cfg)
    raise ValueError(
        f"model : {cfg.model.name} not supported, use one of : {supported_models.keys()}"
    )
def get_hf_model(
    device: torch.device,
    n_classes: int,
    id2label: dict,
    label2id: dict,
    cfg: DictConfig,
) -> Tuple[nn.Module, Compose, List, List, AutoImageProcessor, Any]:
    """
    Load a Hugging Face object-detection model with its config, processor, and transforms.

    Args:
        device (torch.device): The device to load the model onto (e.g. 'cpu', 'cuda').
        n_classes (int): The number of classes for the object detection task.
        id2label (dict): A dictionary mapping class IDs to class labels.
        label2id (dict): A dictionary mapping class labels to class IDs.
        cfg (DictConfig): Configuration with the model name, input size,
            normalization flag, and the ``smart_optim`` switch.

    Returns:
        tuple: A tuple containing:
            - model (torch.nn.Module): The loaded detection model moved to ``device``.
            - transforms (Compose): The transformation pipeline for preprocessing input images.
            - image_mean (list): The mean values used for image normalization.
            - image_std (list): The standard deviation values used for image normalization.
            - processor (AutoImageProcessor): The processor for handling image inputs.
            - loading_info (dict | None): HF loading info (missing/mismatched keys)
              when ``cfg.smart_optim`` is truthy, otherwise None.
    """
    checkpoint = supported_models[cfg.model.name]
    print(f"getting : {checkpoint}")
    # Resize the classification head to this dataset's label set.
    config_kwargs: Dict[str, Any] = dict(
        trust_remote_code=True,
        num_labels=n_classes,
        id2label=id2label,
        label2id=label2id,
    )
    # Only some architectures (e.g. D-FINE/RT-DETR variants) expose this knob.
    if hasattr(cfg.model, "decoder_method"):
        config_kwargs["decoder_method"] = cfg.model.decoder_method
    config = AutoConfig.from_pretrained(
        checkpoint,
        **config_kwargs,
    )
    # Resize + pad to a fixed canvas so batches have uniform spatial size.
    processor_kwargs: Dict[str, Any] = {
        "trust_remote_code": True,
        "use_fast": True,
        "do_resize": True,
        "do_pad": True,
        "do_normalize": cfg.model.do_normalize,
        "size": {
            "max_height": cfg.model.input_height,
            "max_width": cfg.model.input_width,
        },
        "pad_size": {"height": cfg.model.input_height, "width": cfg.model.input_width},
    }
    # ignore_mismatched_sizes: the new head (n_classes) differs from the checkpoint's.
    # output_loading_info makes from_pretrained return (model, info) instead of model.
    model_kwargs: Dict[str, Any] = {
        "config": config,
        "ignore_mismatched_sizes": True,
        "output_loading_info": cfg.smart_optim,
    }
    if "yolos" in cfg.model.name:
        model_kwargs.update(
            {"attn_implementation": "sdpa", "torch_dtype": torch.float32}
        )
    model = AutoModelForObjectDetection.from_pretrained(
        checkpoint,
        **model_kwargs,
    )
    loading_info = None
    if cfg.smart_optim:
        # from_pretrained returned (model, loading_info) because of output_loading_info.
        model, loading_info = model
    processor = AutoImageProcessor.from_pretrained(checkpoint, **processor_kwargs)
    transforms = get_transforms(cfg, id2label)
    print("model loaded")
    return (
        model.to(device),
        transforms,
        processor.image_mean,
        processor.image_std,
        processor,
        loading_info,
    )
def freeze_weights(
    model: nn.Module,
    freeze_backbone: bool = True,
    partially_freeze_backbone: bool = False,
) -> nn.Module:
    """
    Toggle gradient computation on a detection model's parameters in place.

    Every parameter is first made trainable; then, if requested, all backbone
    parameters ("model.backbone*") are frozen, optionally re-enabling the last
    backbone stage for fine-tuning.

    Args:
        model (nn.Module): Model whose parameters are modified in place.
        freeze_backbone (bool): When True, freeze all backbone parameters.
        partially_freeze_backbone (bool): When True, re-enable gradients for the
            final backbone stage (applied after the freeze, so it wins).

    Returns:
        nn.Module: The same model instance, returned for call chaining.
    """
    # NOTE(review): prefix assumes an RT-DETR-style module layout — confirm for
    # other architectures before relying on the partial unfreeze.
    last_stage_prefix = "model.backbone.model.encoder.stages.3"
    for name, param in model.named_parameters():
        # Default: everything trainable.
        param.requires_grad = True
        if freeze_backbone and name.startswith("model.backbone"):
            param.requires_grad = False
        # Runs after the freeze so the last stage ends up trainable.
        if partially_freeze_backbone and name.startswith(last_stage_prefix):
            param.requires_grad = True
    return model