Source code for fruit_project.main

import glob
from omegaconf import DictConfig
import hydra
from fruit_project.utils.data import make_datasets, make_dataloaders
from fruit_project.models.model_factory import get_model
import torch
import cv2
from fruit_project.utils.trainer import Trainer
from fruit_project.utils.logging import (
    initwandb,
    get_run_name,
    log_transforms,
    log_images,
)
from fruit_project.utils.general import set_seed

cv2.setNumThreads(0)


@hydra.main(version_base=None, config_path="../../conf", config_name="config")
[docs] def main(cfg: DictConfig): if cfg.log: run = initwandb(cfg) name = run.name else: run = None name = get_run_name(cfg) generator = set_seed(cfg.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"using : {device}") train_ds, test_ds, val_ds = make_datasets(cfg) model, transforms, mean, std, processor = get_model( cfg, device, len(train_ds.labels), train_ds.id2lbl, train_ds.lbl2id, ) train_dl, test_dl, val_dl = make_dataloaders( cfg, train_ds, test_ds, val_ds, generator, processor, transforms ) log_images(run, next(iter(test_dl)), test_ds.id2lbl, (3, 3), mean, std) log_transforms( run, next(iter(train_dl)), (3, 3), train_ds.id2lbl, transforms, mean, std ) trainer = Trainer( model, processor, device, cfg, name, run, train_dl, test_dl, val_dl ) print("Setup complete.") if cfg.load_ckpt: existing = sorted(glob.glob("../checkpoints/*.pth")) if existing: trainer._load_checkpoint(existing[-1]) trainer.fit()
if __name__ == "__main__": main()