Source code for pysad.evaluation.windowed_metric

from pysad.core.base_metric import BaseMetric
from pysad.statistics.average_meter import AverageMeter


class WindowedMetric(BaseMetric):
    """A helper class to evaluate windowed metrics.

    The distribution of a streaming model's scores often changes over time,
    either due to model collapse (i.e. drifting towards always producing
    loss=0) or due to nonstationarities appearing in the stream. As a result,
    metrics such as ROC or AUC scores may change drastically. To limit this
    effect, this class splits the stream into windows of size `window_size`.
    After each `window_size`-th object, a new instance of `metric_cls` is
    created. Lastly, the metrics from all windows are averaged
    :cite:`xstream,gokcesu2017online`.

    Args:
        metric_cls (class): The metric class to be windowed. Its instances
            must provide `update(y_true, y_pred)` and `get()`.
        window_size (int): The window size.
        ignore_nonempty_last (bool): Whether to ignore the score of the
            nonempty last window. Note that the empty last window is always
            ignored.
        **kwargs: Keyword arguments forwarded to `metric_cls` every time a
            window's metric is instantiated.
    """

    def __init__(self, metric_cls, window_size, ignore_nonempty_last=True, **kwargs):
        super().__init__()

        self.ignore_nonempty_last = ignore_nonempty_last
        self.window_size = window_size
        self.metric_cls = metric_cls

        # Keep the metric configuration so that every window (not only the
        # first one) is evaluated with an identically configured metric.
        self._metric_kwargs = dict(kwargs)

        self.metric = self._init_metric(**self._metric_kwargs)
        self.score_meter = AverageMeter()  # Averages the completed windows.
        self.step = 0  # Number of instances seen so far.
        # Completed windows plus the window that is currently being filled.
        self.num_windows = 1

    def _init_metric(self, **kwargs):
        """Creates a fresh instance of the windowed metric.

        Args:
            **kwargs: Keyword arguments passed to `metric_cls`.

        Returns:
            object: A new `metric_cls` instance.
        """
        return self.metric_cls(**kwargs)

    def update(self, y_true, y_pred):
        """Updates the score with new true label and predicted score/label.

        Args:
            y_true (float): The ground truth score for the incoming instance.
            y_pred (float): The predicted score for the incoming instance.

        Returns:
            object: self.
        """
        self.step += 1
        self.metric.update(y_true, y_pred)

        if self.step % self.window_size == 0:
            # The current window is full: record its score and start a new
            # window with the same configuration as the previous ones.
            self.num_windows += 1
            self.score_meter.update(self.metric.get())
            self.metric = self._init_metric(**self._metric_kwargs)

        return self

    def get(self):
        """Obtains the averaged score.

        Returns:
            float: The average score of the windows.
        """
        # No window has been completed yet; the running metric is all we have.
        if self.num_windows == 1:
            return self.metric.get()

        last_window_nonempty = self.step % self.window_size != 0
        if last_window_nonempty and not self.ignore_nonempty_last:
            # Average the completed windows together with the partial one.
            completed_windows = self.num_windows - 1
            completed_total = self.score_meter.get() * completed_windows
            return (self.metric.get() + completed_total) / self.num_windows

        return self.score_meter.get()