import torch
import torch.nn as nn
from pathlib import Path
from omegaconf import DictConfig
from tqdm import tqdm
from wandb.sdk.wandb_run import Run


class EarlyStopping:
    def __init__(
        self,
        patience: int,
        delta: float,
        path: str,
        name: str,
        cfg: DictConfig,
        run: Run,
    ):
        """
        Initializes the EarlyStopping object.

        Args:
            patience (int): Number of epochs to wait before stopping if no improvement.
            delta (float): Minimum change in the monitored metric to qualify as an improvement.
            path (str): Directory path to save model checkpoints.
            name (str): Name prefix for saved model files.
            cfg (DictConfig): Configuration object.
            run (Run): WandB run object for logging artifacts.
        """
        # Attribute setup inferred from how these fields are used in the methods below.
        self.patience = patience
        self.delta = delta
        self.path = Path(path)
        self.name = name
        self.cfg = cfg
        self.run = run
        self.best_metric = None
        self.counter = 0
        self.earlystop = False
        self.saved_checkpoints = []
    def __call__(self, val_metric: float, model: nn.Module) -> bool:
        """
        Checks if early stopping criteria are met and saves the model if the metric improves.

        Args:
            val_metric (float): Validation metric to monitor (higher is better).
            model (nn.Module): PyTorch model to save.

        Returns:
            bool: True if early stopping criteria are met, False otherwise.
        """
        if self.best_metric is None:
            # First call: record the metric and save an initial checkpoint.
            self.best_metric = val_metric
            tqdm.write("saved model weights")
            self.save_model(model, val_metric)
        elif val_metric <= self.best_metric + self.delta:
            # No sufficient improvement: spend one epoch of patience.
            self.counter += 1
        else:
            # Improvement: checkpoint the model and reset the patience counter.
            self.best_metric = val_metric
            self.save_model(model, val_metric)
            self.counter = 0
            tqdm.write("saved model weights")
        if self.counter >= self.patience:
            self.earlystop = True
        return self.earlystop
    def save_model(self, model: nn.Module, val_metric: float):
        """
        Saves the model checkpoint.

        Args:
            model (nn.Module): PyTorch model to save.
            val_metric (float): Validation metric value used for naming the checkpoint file.

        Returns:
            None
        """
        filename = f"{self.name}_{val_metric:.4f}.pth"
        full_path = self.path / filename
        torch.save(model.state_dict(), full_path)
        self.saved_checkpoints.append((val_metric, full_path))
    def cleanup_checkpoints(self):
        """
        Deletes all saved checkpoints except the best one.

        Returns:
            None
        """
        if not self.saved_checkpoints:
            tqdm.write("No checkpoints to clean up.")
            return
        tqdm.write("cleaning up old checkpoints...")
        _, best_path = max(self.saved_checkpoints, key=lambda x: x[0])
        for _, path in self.saved_checkpoints:
            if path != best_path and path.exists():
                try:
                    path.unlink()
                    tqdm.write(f"deleted {path.name}")
                except Exception as e:
                    tqdm.write(f"could not delete {path.name}: {e}")
        tqdm.write(f"kept best model: {best_path.name}")
    def get_best_model(self, model: nn.Module) -> nn.Module:
        """
        Loads the best model checkpoint and sets the model to evaluation mode.

        Args:
            model (nn.Module): PyTorch model to load the best checkpoint into.

        Returns:
            nn.Module: The model with the best checkpoint loaded.
        """
        self.cleanup_checkpoints()
        tqdm.write("loading best model")
        model.eval()
        if len(self.saved_checkpoints) > 0:
            _, best_path = max(self.saved_checkpoints, key=lambda x: x[0])
            model.load_state_dict(torch.load(best_path, weights_only=True))
        """
        artifact = wandb.Artifact(
            name=f"{self.cfg.model.name}",
            type="model-earlystopping-bestmodel",
            description=f"best model at epoch",
        )
        artifact.add_file(best_path)
        self.run.log_artifact(artifact)
        artifact.wait()
        """
        return model
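
# Minimal usage sketch (not part of the class above): one way to drive the helper
# from a training loop. The tiny linear model, the synthetic validation metrics,
# and the None placeholders for cfg/run are assumptions made only for this example;
# a real setup would pass a DictConfig and a wandb Run. The metric is treated as
# "higher is better", matching __call__.
if __name__ == "__main__":
    import tempfile

    model = nn.Linear(4, 2)  # stand-in for a real network
    stopper = EarlyStopping(
        patience=3,
        delta=1e-4,
        path=tempfile.mkdtemp(),
        name="example",
        cfg=None,  # assumption: a DictConfig in real use
        run=None,  # assumption: a wandb Run in real use
    )
    # Synthetic validation metrics; normally these come from an eval pass each epoch.
    for epoch, val_metric in enumerate([0.70, 0.72, 0.71, 0.71, 0.71]):
        if stopper(val_metric, model):
            tqdm.write(f"stopping early at epoch {epoch}")
            break
    model = stopper.get_best_model(model)  # non-best checkpoints deleted, best weights reloaded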