Source code for fruit_project.models.model_factory

# SPDX-FileCopyrightText: 2025 Mohamed Khayat
# SPDX-License-Identifier: AGPL-3.0-or-later

import torch
import torch.nn as nn
from fruit_project.models.transforms_factory import get_transforms
from fruit_project.config import SUPPORTED_MODELS
from omegaconf import DictConfig
from typing import Any, Dict, List, Tuple
from albumentations import Compose
from transformers import (
    AutoImageProcessor,
    AutoModelForObjectDetection,
    AutoConfig,
)

# Module-local alias for the model registry imported from fruit_project.config.
# Keys are the accepted values of ``cfg.model.name``; values are the checkpoint
# identifiers handed to the Hugging Face ``from_pretrained`` loaders below.
supported_models = SUPPORTED_MODELS
def get_model(
    cfg: DictConfig,
    device: torch.device,
    n_classes: int,
    id2lbl: Dict,
    lbl2id: Dict,
) -> Tuple[nn.Module, Compose, List, List, AutoImageProcessor, Any]:
    """
    Retrieves and initializes a model based on the provided configuration.

    Validates ``cfg.model.name`` against ``supported_models`` and delegates the
    actual loading to ``get_hf_model``.

    Args:
        cfg (DictConfig): Configuration object; ``cfg.model.name`` selects the model.
        device (torch.device): Device the model is moved to (e.g. 'cpu' or 'cuda').
        n_classes (int): Number of classes for the model's output.
        id2lbl (dict): Mapping from class IDs to labels.
        lbl2id (dict): Mapping from labels to class IDs.

    Returns:
        tuple: The 6-tuple produced by ``get_hf_model``:
        ``(model, transforms, image_mean, image_std, processor, loading_info)``.

    Raises:
        ValueError: If ``cfg.model.name`` is not a key of ``supported_models``.
    """
    if cfg.model.name not in supported_models:
        # list(...) so the message shows the names rather than a dict_keys repr.
        raise ValueError(
            f"model : {cfg.model.name} not supported, "
            f"use one of : {list(supported_models)}"
        )
    return get_hf_model(device, n_classes, id2lbl, lbl2id, cfg)
def get_hf_model(
    device: torch.device,
    n_classes: int,
    id2label: dict,
    label2id: dict,
    cfg: DictConfig,
) -> Tuple[nn.Module, Compose, List, List, AutoImageProcessor, Any]:
    """
    Loads a Hugging Face object-detection model with its config, processor, and transforms.

    The checkpoint identifier is ``supported_models[cfg.model.name]``. It is passed
    to ``AutoConfig``, ``AutoModelForObjectDetection`` and ``AutoImageProcessor``.
    The name is not validated here; ``get_model`` performs that check before calling.

    Args:
        device (torch.device): Device the model is moved to (e.g. 'cpu', 'cuda').
        n_classes (int): Number of object classes; set as ``num_labels`` on the config.
        id2label (dict): Mapping from class IDs to class labels.
        label2id (dict): Mapping from class labels to class IDs.
        cfg (DictConfig): Configuration. Reads ``cfg.model.name``,
            ``cfg.model.do_normalize``, ``cfg.model.input_height``,
            ``cfg.model.input_width``, the optional ``cfg.model.decoder_method``
            and ``cfg.smart_optim``. It is also passed to ``get_transforms``.

    Returns:
        tuple: A 6-tuple
        ``(model, transforms, image_mean, image_std, processor, loading_info)``:
            - model (torch.nn.Module): The detection model, moved to ``device``.
            - transforms (Compose): Result of ``get_transforms(cfg, id2label)``.
            - image_mean (list): ``processor.image_mean``.
            - image_std (list): ``processor.image_std``.
            - processor (AutoImageProcessor): The image processor for the checkpoint.
            - loading_info (Any): The loading info that ``from_pretrained`` returns
              when ``cfg.smart_optim`` is truthy; ``None`` otherwise.
    """
    checkpoint = supported_models[cfg.model.name]
    print(f"getting : {checkpoint}")

    # SECURITY NOTE: trust_remote_code=True executes Python code shipped with the
    # checkpoint repository. Only register trusted checkpoints in SUPPORTED_MODELS.
    config_kwargs: Dict[str, Any] = {
        "trust_remote_code": True,
        "num_labels": n_classes,
        "id2label": id2label,
        "label2id": label2id,
    }
    # decoder_method is optional in the config; forward it only when cfg.model has it.
    if hasattr(cfg.model, "decoder_method"):
        config_kwargs["decoder_method"] = cfg.model.decoder_method
    config = AutoConfig.from_pretrained(checkpoint, **config_kwargs)

    # The processor is asked to resize within (input_height, input_width), pad to
    # exactly that size, and normalize only if cfg.model.do_normalize is set.
    processor_kwargs: Dict[str, Any] = {
        "trust_remote_code": True,
        "use_fast": True,
        "do_resize": True,
        "do_pad": True,
        "do_normalize": cfg.model.do_normalize,
        "size": {
            "max_height": cfg.model.input_height,
            "max_width": cfg.model.input_width,
        },
        "pad_size": {"height": cfg.model.input_height, "width": cfg.model.input_width},
    }

    model_kwargs: Dict[str, Any] = {
        "config": config,
        # num_labels may differ from the checkpoint's, so some head weights can
        # mismatch in shape; allow loading anyway.
        "ignore_mismatched_sizes": True,
        # When truthy, from_pretrained returns (model, loading_info) instead of a model.
        "output_loading_info": cfg.smart_optim,
    }
    if "yolos" in cfg.model.name:
        # YOLOS checkpoints are loaded with SDPA attention and float32 weights.
        model_kwargs.update(
            {"attn_implementation": "sdpa", "torch_dtype": torch.float32}
        )

    model = AutoModelForObjectDetection.from_pretrained(checkpoint, **model_kwargs)

    loading_info = None
    if cfg.smart_optim:
        model, loading_info = model

    # NOTE: backbone freezing via freeze_weights(model, cfg.freeze_backbone,
    # cfg.partially_freeze_backbone) is intentionally not applied here.

    processor = AutoImageProcessor.from_pretrained(checkpoint, **processor_kwargs)
    transforms = get_transforms(cfg, id2label)
    print("model loaded")

    return (
        model.to(device),
        transforms,
        processor.image_mean,
        processor.image_std,
        processor,
        loading_info,
    )
def freeze_weights(
    model: nn.Module,
    freeze_backbone: bool = True,
    partially_freeze_backbone: bool = False,
    backbone_prefix: str = "model.backbone",
    trainable_backbone_prefix: str = "model.backbone.model.encoder.stages.3",
) -> nn.Module:
    """
    Sets ``requires_grad`` on every parameter of ``model`` in place.

    Every parameter is made trainable, except:
    - Parameters under ``backbone_prefix`` are frozen when ``freeze_backbone`` is set.
    - Inside that frozen backbone, parameters under ``trainable_backbone_prefix``
      stay trainable when ``partially_freeze_backbone`` is set.

    Prefixes are matched on module-path boundaries. ``"model.backbone"`` matches
    ``"model.backbone.x"`` but not ``"model.backbone_neck.x"``.

    Note that parameters outside the backbone are always reset to trainable, even
    if they were frozen before the call.

    Args:
        model (nn.Module): Model whose parameters are updated in place.
        freeze_backbone (bool): Freeze parameters under ``backbone_prefix``.
        partially_freeze_backbone (bool): Keep parameters under
            ``trainable_backbone_prefix`` trainable.
        backbone_prefix (str): Dotted module path of the backbone.
        trainable_backbone_prefix (str): Dotted module path that stays trainable
            under partial freezing.

    Returns:
        nn.Module: The same ``model`` object, for chaining.
    """

    def _under(name: str, prefix: str) -> bool:
        # True if parameter `name` lives in module path `prefix` (or is it).
        prefix = prefix.rstrip(".")
        if not prefix:
            return True
        return name == prefix or name.startswith(prefix + ".")

    for name, param in model.named_parameters():
        if freeze_backbone and _under(name, backbone_prefix):
            trainable = partially_freeze_backbone and _under(
                name, trainable_backbone_prefix
            )
        else:
            trainable = True
        param.requires_grad = trainable
    return model