Source code for pygod.utils.early_stopping

# -*- coding: utf-8 -*-
"""
Early Stopping Counter
Adapted from DGL
"""
# Author: Kay Liu <zliu234@uic.edu>
# License: BSD 2 clause

import torch


class EarlyStopping:
    r"""Early Stopping Counter

    Tracks the best score seen so far, saves the model's ``state_dict``
    whenever the score does not get worse, and raises a stop flag once
    the score has been worse than the best for ``patience`` consecutive
    calls to :meth:`step`. Higher scores are treated as better.

    Parameters
    ----------
    patience : int
        The epoch number waiting after the highest score
        Default: 10
    verbose : bool
        Whether to print information
        Default: True
    checkpoint_path : str
        File the model ``state_dict`` is saved to whenever the score
        matches or exceeds the best one. Keyword-only.
        Default: ``'es_checkpoint.pt'``
    """

    def __init__(self, patience: int = 10, verbose: bool = True, *,
                 checkpoint_path: str = 'es_checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        # Number of consecutive steps whose score was below the best.
        self.counter = 0
        # Best score observed so far; None until the first call to step().
        self.best_score = None
        # Becomes True once counter reaches patience and stays True.
        self.stop = False

    def _save_checkpoint(self, model: torch.nn.Module) -> None:
        """Write ``model.state_dict()`` to ``self.checkpoint_path``."""
        torch.save(model.state_dict(), self.checkpoint_path)

    def step(self, score: float, model: torch.nn.Module) -> bool:
        """Record one evaluation score and update the stop flag.

        Parameters
        ----------
        score : float
            Score of the current epoch; higher is better.
        model : torch.nn.Module
            Model whose ``state_dict`` is checkpointed when ``score`` is
            not lower than the best score seen so far.

        Returns
        -------
        bool
            ``True`` once the score has been lower than the best for
            ``patience`` consecutive calls, ``False`` otherwise.
        """
        if self.best_score is None:
            # First observation: it is the best by definition.
            self.best_score = score
            self._save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            if self.verbose:
                print(f'Early Stopping Counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.stop = True
        else:
            # A score equal to the best also counts as an improvement:
            # it refreshes the checkpoint and resets the counter.
            self.best_score = score
            self._save_checkpoint(model)
            self.counter = 0
        return self.stop