fruit_project.utils.early_stop

Classes

  • EarlyStopping

Module Contents

class fruit_project.utils.early_stop.EarlyStopping(patience: int, delta: float, path: str, name: str, cfg: omegaconf.DictConfig, run: wandb.sdk.wandb_run.Run)
patience
delta
path
name
cfg
run
best_metric: float | None = None
counter = 0
earlystop = False
saved_checkpoints: List[Tuple[float, pathlib.Path]] = []
__call__(val_metric: float, model: torch.nn.Module) -> bool

Checks if early stopping criteria are met and saves the model if the metric improves.

Parameters:
  • val_metric (float) – Validation metric to monitor.

  • model (nn.Module) – PyTorch model to save.

Returns:

True if early stopping criteria are met, False otherwise.

Return type:

bool
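The check that __call__ performs can be sketched in plain Python. This is a hypothetical stand-in (EarlyStoppingSketch is not part of fruit_project), and it assumes two things the docs above do not state: a higher val_metric is better, and delta is the minimum improvement needed to reset counter. Checkpoint saving is left out so the patience logic stays visible.

```python
from typing import Optional


class EarlyStoppingSketch:
    """Hypothetical early-stopping tracker, not the fruit_project implementation.

    Assumes higher metrics are better and that ``delta`` is the minimum
    improvement required to count as "better".
    """

    def __init__(self, patience: int, delta: float) -> None:
        self.patience = patience
        self.delta = delta
        self.best_metric: Optional[float] = None
        self.counter = 0
        self.earlystop = False

    def __call__(self, val_metric: float) -> bool:
        if self.best_metric is None or val_metric > self.best_metric + self.delta:
            # Sufficient improvement: remember the new best and reset the counter.
            # (The real class would also save a checkpoint at this point.)
            self.best_metric = val_metric
            self.counter = 0
        else:
            # No sufficient improvement: count one more stale epoch.
            self.counter += 1
            if self.counter >= self.patience:
                self.earlystop = True
        return self.earlystop


stopper = EarlyStoppingSketch(patience=2, delta=0.01)
history = [0.50, 0.55, 0.555, 0.552, 0.60]
for epoch, metric in enumerate(history):
    if stopper(metric):
        print(f"stopping after epoch {epoch}")  # prints: stopping after epoch 3
        break
```

With patience=2, the two epochs after 0.55 that fail to beat it by at least 0.01 trigger the stop, so the final 0.60 is never evaluated.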

save_model(model: torch.nn.Module, val_metric: float)

Saves the model checkpoint.

Parameters:
  • model (nn.Module) – PyTorch model to save.

  • val_metric (float) – Validation metric value used for naming the checkpoint file.

Returns:

None
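One possible shape for save_model, as a sketch only. It assumes the checkpoint is the model's state_dict written with torch.save; the helper name save_model_sketch and the f"{name}_{val_metric:.4f}.pth" filename pattern are hypothetical, since the docs only say that val_metric is used when naming the file. Recording (metric, path) pairs mirrors the saved_checkpoints attribute listed above.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple

import torch
from torch import nn


def save_model_sketch(
    model: nn.Module,
    val_metric: float,
    out_dir: Path,
    name: str,
    saved_checkpoints: List[Tuple[float, Path]],
) -> Path:
    """Hypothetical checkpoint writer; the filename pattern is an assumption."""
    out_dir.mkdir(parents=True, exist_ok=True)
    # Embed the metric in the filename so checkpoints are easy to tell apart.
    ckpt_path = out_dir / f"{name}_{val_metric:.4f}.pth"
    # Save only the weights (state_dict), the usual PyTorch practice.
    torch.save(model.state_dict(), ckpt_path)
    # Record (metric, path) so the best checkpoint can be located later.
    saved_checkpoints.append((val_metric, ckpt_path))
    return ckpt_path


with tempfile.TemporaryDirectory() as tmp:
    checkpoints: List[Tuple[float, Path]] = []
    path = save_model_sketch(nn.Linear(4, 2), 0.8123, Path(tmp), "detector", checkpoints)
    print(path.name)  # detector_0.8123.pth
```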

cleanup_checkpoints()

Deletes all saved checkpoints except the best one.

Returns:

None
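The cleanup step can be sketched with pathlib alone, assuming saved_checkpoints holds (metric, path) pairs as its type annotation suggests and that the highest metric marks the best checkpoint (the direction is an assumption). cleanup_checkpoints_sketch is a hypothetical helper, not the class's code; empty files stand in for real checkpoints.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple


def cleanup_checkpoints_sketch(saved: List[Tuple[float, Path]]) -> List[Tuple[float, Path]]:
    """Delete every checkpoint file except the best one (higher metric assumed better)."""
    if not saved:
        return []
    best = max(saved, key=lambda item: item[0])
    for _, ckpt_path in saved:
        if ckpt_path != best[1]:
            # missing_ok avoids an error if the file was already removed.
            ckpt_path.unlink(missing_ok=True)
    # Keep only the surviving entry in the bookkeeping list.
    return [best]


with tempfile.TemporaryDirectory() as tmp:
    saved = []
    for metric in (0.61, 0.74, 0.70):
        p = Path(tmp) / f"model_{metric:.2f}.pth"
        p.write_bytes(b"")  # stand-in for a real checkpoint file
        saved.append((metric, p))
    saved = cleanup_checkpoints_sketch(saved)
    remaining = sorted(f.name for f in Path(tmp).iterdir())
    print(remaining)  # ['model_0.74.pth']
```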

get_best_model(model: torch.nn.Module) -> torch.nn.Module

Loads the best model checkpoint and sets the model to evaluation mode.

Parameters:
  • model (nn.Module) – PyTorch model to load the best checkpoint into.

Returns:

The model with the best checkpoint loaded.

Return type:

nn.Module
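Loading the best checkpoint typically comes down to torch.load, load_state_dict and eval(). The sketch below assumes checkpoints are state_dicts tracked as (metric, path) pairs with higher metrics being better; get_best_model_sketch is a hypothetical helper, not the method's actual body. Note that the weights are loaded into the model object passed in, so the returned module is that same instance.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple

import torch
from torch import nn


def get_best_model_sketch(model: nn.Module, saved: List[Tuple[float, Path]]) -> nn.Module:
    """Load the highest-metric checkpoint into ``model`` and switch it to eval mode."""
    _, ckpt_path = max(saved, key=lambda item: item[0])
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # eval() disables dropout and makes batch norm use its running statistics.
    model.eval()
    return model


with tempfile.TemporaryDirectory() as tmp:
    trained = nn.Linear(3, 1)
    best_path = Path(tmp) / "best.pth"
    torch.save(trained.state_dict(), best_path)
    restored = get_best_model_sketch(nn.Linear(3, 1), [(0.9, best_path)])
    print(restored.training)  # False
    print(torch.equal(restored.weight, trained.weight))  # True
```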