fruit_project.utils.general
Functions
- set_seed(SEED) – Sets the seed for reproducibility across various libraries.
- seed_worker(worker_id, base_seed) – Seeds a worker for multiprocessing to ensure reproducibility.
- plot_img(img[, label]) – Plots an image using matplotlib.
- unnormalize(img_tensor, mean, std) – Unnormalizes an image tensor by reversing normalization.
- Checks if the given model is a Hugging Face PreTrainedModel.
Module Contents
- fruit_project.utils.general.set_seed(SEED: int) → torch.Generator
Sets the seed for reproducibility across various libraries.
- Parameters:
SEED (int) – The seed value to use.
- Returns:
A PyTorch generator seeded with the given value.
- Return type:
torch.Generator
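A minimal sketch of what set_seed might do. The docstring only says "various libraries", so seeding Python's random, NumPy and PyTorch specifically is an assumption, as is the use of a separate torch.Generator for the return value. The example shows that re-seeding with the same value reproduces the same draws.

```python
import random

import numpy as np
import torch


def set_seed(SEED: int) -> torch.Generator:
    # Sketch only: the set of libraries seeded here is an assumption,
    # not taken from fruit_project's actual source.
    random.seed(SEED)        # Python's built-in RNG
    np.random.seed(SEED)     # NumPy's global (legacy) RNG
    torch.manual_seed(SEED)  # PyTorch default generators on all devices
    # A dedicated generator that can be handed to e.g. a DataLoader for shuffling.
    generator = torch.Generator()
    generator.manual_seed(SEED)
    return generator


# Re-seeding with the same value reproduces the same random draws.
g1 = set_seed(42)
first = torch.randn(3, generator=g1)
g2 = set_seed(42)
second = torch.randn(3, generator=g2)
```

Returning a generator rather than relying only on the global state lets a caller isolate one random stream (for example, DataLoader shuffling) from other consumers of the global RNG.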
- fruit_project.utils.general.seed_worker(worker_id: int, base_seed: int) → None
Seeds a worker for multiprocessing to ensure reproducibility.
- Parameters:
worker_id (int) – The ID of the worker.
base_seed (int) – The base seed value.
- Returns:
None
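A hedged sketch of seed_worker and how it could be wired into a PyTorch DataLoader. Deriving each worker's seed as `base_seed + worker_id` is an assumption; the source only documents the two parameters. Because DataLoader's `worker_init_fn` receives only the worker id, `base_seed` is bound beforehand with `functools.partial`.

```python
import random
from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id: int, base_seed: int) -> None:
    # Sketch only: deriving the worker seed as base_seed + worker_id is an
    # assumption; the source only documents the two parameters.
    worker_seed = (base_seed + worker_id) % 2**32  # NumPy seeds must fit in 32 bits
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)


# DataLoader calls worker_init_fn with the worker id only, so base_seed is
# bound in advance. Worker processes are started (and seed_worker is called
# in each of them) only when the loader is iterated.
loader = DataLoader(
    TensorDataset(torch.arange(8)),
    batch_size=2,
    num_workers=2,
    worker_init_fn=partial(seed_worker, base_seed=42),
)
```

Giving each worker a distinct but deterministic seed avoids the common pitfall of every worker producing identical "random" augmentations while keeping runs repeatable.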
- fruit_project.utils.general.plot_img(img, label: str | None = None) → None
Plots an image using matplotlib.
- Parameters:
img (torch.Tensor) – The image tensor to plot (shape: C x H x W).
label (str, optional) – The label to display as the title. Defaults to None.
- Returns:
None
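A minimal sketch of plot_img, assuming a 3-channel C x H x W tensor whose values are already displayable (floats in [0, 1]). The key step is moving the channel axis last, because matplotlib's `imshow` expects H x W x C. The headless `Agg` backend is an assumption made only so the sketch runs without a display.

```python
from typing import Optional

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import torch


def plot_img(img: torch.Tensor, label: Optional[str] = None) -> None:
    # Sketch only: assumes a 3-channel C x H x W tensor with values in a
    # displayable range ([0, 1] for floats), e.g. after unnormalizing.
    array = img.detach().cpu().permute(1, 2, 0).numpy()  # C x H x W -> H x W x C
    plt.imshow(array)
    if label is not None:
        plt.title(label)
    plt.axis("off")
    plt.show()


plot_img(torch.rand(3, 8, 8), label="apple")
```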
- fruit_project.utils.general.unnormalize(img_tensor: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) → torch.Tensor
Unnormalizes an image tensor by reversing normalization.
- Parameters:
img_tensor (torch.Tensor) – The normalized image tensor (shape: N x C x H x W or C x H x W).
mean (torch.Tensor) – The mean used for normalization.
std (torch.Tensor) – The standard deviation used for normalization.
- Returns:
The unnormalized image tensor.
- Return type:
torch.Tensor
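Normalization computes `(x - mean) / std` per channel, so reversing it is `x_norm * std + mean`. The sketch below assumes `mean` and `std` hold one value per channel (shape C), as with torchvision's `Normalize`; the source does not state their shape. Reshaping them to (C, 1, 1) lets the same code broadcast over both C x H x W and N x C x H x W inputs. The ImageNet statistics are used purely as example values.

```python
import torch


def unnormalize(img_tensor: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # Sketch only: assumes per-channel mean/std of shape (C,).
    # Reshaping to (C, 1, 1) broadcasts over H x W and any leading N dimension.
    mean = mean.reshape(-1, 1, 1)
    std = std.reshape(-1, 1, 1)
    return img_tensor * std + mean  # inverse of (x - mean) / std


# Example values only (ImageNet statistics), not taken from fruit_project.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
original = torch.rand(2, 3, 4, 4)  # N x C x H x W
normalized = (original - mean.reshape(-1, 1, 1)) / std.reshape(-1, 1, 1)
restored = unnormalize(normalized, mean, std)
```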