fruit_project.models.model_factory
Attributes
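Functions
- get_model(cfg, device, n_classes, id2lbl, lbl2id): Retrieves and initializes a model based on the provided configuration.
- get_RTDETRv2(device, n_classes, id2label, label2id, cfg): Loads the RT-DETRv2 model along with its configuration, processor, and transformations.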
Module Contents
- fruit_project.models.model_factory.get_model(cfg: omegaconf.DictConfig, device: torch.device, n_classes: int, id2lbl: Dict, lbl2id: Dict) → Tuple[torch.nn.Module, albumentations.Compose, List, List, transformers.AutoImageProcessor]
Retrieves and initializes a model based on the provided configuration.
- Parameters:
cfg (DictConfig) – Configuration object containing model specifications.
device (torch.device) – The device on which the model will be loaded (e.g., ‘cpu’ or ‘cuda’).
n_classes (int) – Number of classes for the model’s output.
id2lbl (dict) – Mapping from class IDs to labels.
lbl2id (dict) – Mapping from labels to class IDs.
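- Returns:
A tuple containing the initialized model, the image transformations, the normalization mean and standard deviation, and the image processor.
- Return type:
Tuple[torch.nn.Module, albumentations.Compose, List, List, transformers.AutoImageProcessor]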
- Raises:
ValueError – If the specified model name in the configuration is not supported.
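The dispatch described for get_model can be pictured as a lookup from a configured model name to a loader function, with a ValueError for unknown names. The sketch below is a minimal, self-contained illustration under stated assumptions: the registry MODEL_REGISTRY, the key "rtdetrv2", the cfg.model.name attribute path, get_model_sketch, load_rtdetrv2_stub, build_label_maps, and the fruit class names are all hypothetical, and a SimpleNamespace stands in for the real omegaconf.DictConfig. It also shows one way to derive the n_classes, id2lbl, and lbl2id arguments from an ordered class list.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Tuple


def build_label_maps(class_names: List[str]) -> Tuple[int, Dict[int, str], Dict[str, int]]:
    """Derive n_classes, id2lbl and lbl2id from an ordered list of unique class names."""
    if len(set(class_names)) != len(class_names):
        raise ValueError("class names must be unique")
    id2lbl = {i: name for i, name in enumerate(class_names)}
    lbl2id = {name: i for i, name in id2lbl.items()}
    return len(class_names), id2lbl, lbl2id


def load_rtdetrv2_stub(device, n_classes, id2label, label2id, cfg) -> Tuple[Any, ...]:
    """Hypothetical stand-in for get_RTDETRv2: placeholders in the same 5-tuple order."""
    return ("model", "transforms", [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], "processor")


# Hypothetical registry; the key "rtdetrv2" and the cfg.model.name path are assumptions.
MODEL_REGISTRY: Dict[str, Callable[..., Tuple[Any, ...]]] = {
    "rtdetrv2": load_rtdetrv2_stub,
}


def get_model_sketch(cfg, device, n_classes, id2lbl, lbl2id) -> Tuple[Any, ...]:
    """Look up the loader for cfg.model.name; raise ValueError if it is not registered."""
    name = cfg.model.name
    if name not in MODEL_REGISTRY:
        raise ValueError(f"Unsupported model name: {name!r}")
    # The loader follows get_RTDETRv2's documented argument order (cfg last).
    return MODEL_REGISTRY[name](device, n_classes, id2lbl, lbl2id, cfg)


# Usage with hypothetical fruit classes and a SimpleNamespace standing in for DictConfig.
n_classes, id2lbl, lbl2id = build_label_maps(["apple", "banana", "orange"])
cfg = SimpleNamespace(model=SimpleNamespace(name="rtdetrv2"))
model, transforms, image_mean, image_std, processor = get_model_sketch(
    cfg, "cpu", n_classes, id2lbl, lbl2id
)
```

A registry keeps the supported-model check in one place: adding a backbone means registering one more loader, and any unregistered name fails early with the ValueError documented above.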
- fruit_project.models.model_factory.get_RTDETRv2(device: torch.device, n_classes: int, id2label: dict, label2id: dict, cfg: omegaconf.DictConfig) → Tuple[torch.nn.Module, albumentations.Compose, List, List, transformers.AutoImageProcessor]
Loads the RT-DETRv2 model along with its configuration, processor, and transformations.
- Parameters:
device (torch.device) – The device to load the model onto (e.g., ‘cpu’, ‘cuda’).
n_classes (int) – The number of classes for the object detection task.
id2label (dict) – A dictionary mapping class IDs to class labels.
label2id (dict) – A dictionary mapping class labels to class IDs.
cfg (DictConfig) – Configuration object containing model settings, including the model name.
- Returns:
- A tuple containing:
model (torch.nn.Module): The loaded RT-DETRv2 model moved to the specified device.
transforms (albumentations.Compose): The transformation pipeline for preprocessing input images.
image_mean (list): The mean values used for image normalization.
image_std (list): The standard deviation values used for image normalization.
processor (AutoImageProcessor): The processor for handling image inputs.
- Return type:
tuple
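The image_mean and image_std lists returned above are per-channel statistics. Assuming the preprocessing applies the usual per-channel normalization (x - mean) / std to pixel values already rescaled to [0, 1], the sketch below shows that step and its inverse, which is handy when visualizing model inputs. The function names are hypothetical, and the ImageNet-style numbers are placeholders only; whether normalization is actually applied, and with which values, depends on the processor returned for the configured checkpoint.

```python
from typing import List, Sequence


def normalize_pixel(rgb: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """Apply per-channel (x - mean) / std to one pixel whose values are already in [0, 1]."""
    if not (len(rgb) == len(mean) == len(std)):
        raise ValueError("rgb, mean and std must have the same number of channels")
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]


def denormalize_pixel(values: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """Invert normalize_pixel (x = value * std + mean), e.g. before displaying an image."""
    if not (len(values) == len(mean) == len(std)):
        raise ValueError("values, mean and std must have the same number of channels")
    return [v * s + m for v, m, s in zip(values, mean, std)]


# Placeholder ImageNet-style statistics; substitute the image_mean / image_std
# actually returned by get_RTDETRv2.
image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]

pixel = [0.5, 0.25, 0.75]
normalized = normalize_pixel(pixel, image_mean, image_std)
restored = denormalize_pixel(normalized, image_mean, image_std)
```

Round-tripping a pixel through both functions should recover the original values up to floating-point error, which makes a quick sanity check that the mean and std lists are paired in the same channel order.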