fruit_project.models.model_factory

Attributes

Functions

get_model(→ Tuple[torch.nn.Module, ...)

Retrieves and initializes a model based on the provided configuration.

get_RTDETRv2(→ Tuple[torch.nn.Module, ...)

Loads the RT-DETRv2 model along with its configuration, processor, and transformations.

Module Contents

fruit_project.models.model_factory.supported_models
fruit_project.models.model_factory.get_model(cfg: omegaconf.DictConfig, device: torch.device, n_classes: int, id2lbl: Dict, lbl2id: Dict) → Tuple[torch.nn.Module, albumentations.Compose, List, List, transformers.AutoImageProcessor]

Retrieves and initializes a model based on the provided configuration.

Parameters:
  • cfg (DictConfig) – Configuration object containing model specifications.

  • device (torch.device) – The device on which the model will be loaded (e.g., ‘cpu’ or ‘cuda’).

  • n_classes (int) – Number of classes for the model’s output.

  • id2lbl (dict) – Mapping from class IDs to labels.

  • lbl2id (dict) – Mapping from labels to class IDs.

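Returns:

A tuple containing the initialized model, its transforms, the image normalization mean and standard deviation, and the image processor, matching the signature above.

Return type:

Tuple[torch.nn.Module, albumentations.Compose, List, List, transformers.AutoImageProcessor]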

Raises:

ValueError – If the specified model name in the configuration is not supported.
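The lookup-then-raise behavior described above can be sketched as follows. This is an illustrative sketch only, not the module's actual code: the registry `_LOADERS`, the stub loader `_load_stub`, the function name `get_model_sketch`, and the model key "rtdetrv2" are all hypothetical, and the real function reads the model name from `cfg` rather than taking it as a string.

```python
from typing import Any, Callable, Dict, Tuple


# Hypothetical stand-in for a real loader such as get_RTDETRv2. It returns
# placeholders in the documented order:
# (model, transforms, image_mean, image_std, processor).
def _load_stub(n_classes: int) -> Tuple[Any, Any, list, list, Any]:
    return (f"stub-model-{n_classes}", None, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], None)


# Hypothetical registry mapping a supported model name to its loader.
_LOADERS: Dict[str, Callable[[int], tuple]] = {"rtdetrv2": _load_stub}


def get_model_sketch(model_name: str, n_classes: int) -> tuple:
    """Look up the loader for model_name, or raise ValueError if unsupported."""
    try:
        loader = _LOADERS[model_name]
    except KeyError:
        raise ValueError(
            f"Unsupported model {model_name!r}; expected one of {sorted(_LOADERS)}"
        ) from None
    return loader(n_classes)
```

Keeping the supported names in a single mapping means the error message and the dispatch logic cannot drift apart when a new model is added.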

fruit_project.models.model_factory.get_RTDETRv2(device: torch.device, n_classes: int, id2label: dict, label2id: dict, cfg: omegaconf.DictConfig) → Tuple[torch.nn.Module, albumentations.Compose, List, List, transformers.AutoImageProcessor]

Loads the RT-DETRv2 model along with its configuration, processor, and transformations.

Parameters:
  • device (torch.device) – The device to load the model onto (e.g., ‘cpu’, ‘cuda’).

  • n_classes (int) – The number of classes for the object detection task.

  • id2label (dict) – A dictionary mapping class IDs to class labels.

  • label2id (dict) – A dictionary mapping class labels to class IDs.

  • cfg (omegaconf.DictConfig) – Configuration object containing model settings, including the model name.

Returns:

A tuple containing:
  • model (torch.nn.Module): The loaded RT-DETRv2 model moved to the specified device.

  • transforms (albumentations.Compose): The transformation pipeline for preprocessing input images.

  • image_mean (list): The mean values used for image normalization.

  • image_std (list): The standard deviation values used for image normalization.

  • processor (AutoImageProcessor): The processor for handling image inputs.

Return type:

tuple
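As a rough illustration of how the returned tuple might be consumed, the sketch below unpacks a placeholder tuple in the documented order and applies per-channel normalization with `image_mean` and `image_std`. The placeholder values, the helper `normalize_pixel`, and the assumption that pixel values are already scaled to [0, 1] are hypothetical; in practice the returned processor or transforms would normally handle this step.

```python
from typing import List, Sequence


def normalize_pixel(
    pixel: Sequence[float],
    image_mean: Sequence[float],
    image_std: Sequence[float],
) -> List[float]:
    """Apply (value - mean) / std to each channel of a single pixel."""
    return [(v - m) / s for v, m, s in zip(pixel, image_mean, image_std)]


# Placeholder tuple standing in for the real return value; only the order
# (model, transforms, image_mean, image_std, processor) comes from the docs.
placeholder = ("model", "transforms", [0.5, 0.5, 0.5], [0.25, 0.25, 0.25], "processor")
model, transforms, image_mean, image_std, processor = placeholder

# One RGB pixel, assumed to be scaled to the [0, 1] range already.
normalized = normalize_pixel([1.0, 0.5, 0.0], image_mean, image_std)
```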