Source code for inFairness.postprocessing.data_ds

import torch

from inFairness.postprocessing.distance_ds import DistanceStructure


class PostProcessingDataStore:
    """Data structure to hold the data used for post-processing.

    Input datapoints ``X`` and output datapoints ``Y`` are accumulated
    row-wise (concatenated along the first, batch dimension). Whenever
    datapoints are added through :meth:`add_datapoints`, the distance
    structure built from ``distance_x`` is rebuilt over all of ``X``.

    Parameters
    -------------
    distance_x: inFairness.distances.Distance
        Distance metric in the input space
    """

    def __init__(self, distance_x):
        self.data_X = None  # accumulated input datapoints; None until first add
        self.data_Y = None  # accumulated output datapoints; None until first add
        self.n_samples = 0  # rows added through add_datapoints
        self.distance_ds = DistanceStructure(distance_x)

    @property
    def distance_matrix(self):
        """Distances between N data points. Shape: (N, N)"""
        return self.distance_ds.distance_matrix

    def add_datapoints_X(self, X: torch.Tensor) -> None:
        """Add datapoints to the input datapoints X.

        Note that this only appends to ``data_X``; it does not update
        ``n_samples`` or rebuild the distance structure. Use
        :meth:`add_datapoints` to keep those in sync.

        Parameters
        ------------
        X: torch.Tensor
            New data points to add to the input data
            `X` should have the same dimensions as previous data
            along all dimensions except the first (batch) dimension
        """
        if self.data_X is None:
            self.data_X = X
        else:
            self.data_X = torch.cat([self.data_X, X], dim=0)

    def add_datapoints_Y(self, Y: torch.Tensor) -> None:
        """Add datapoints to the output datapoints Y.

        Note that this only appends to ``data_Y``; it does not update
        ``n_samples`` or rebuild the distance structure. Use
        :meth:`add_datapoints` to keep those in sync.

        Parameters
        ------------
        Y: torch.Tensor
            New data points to add to the output data
            `Y` should have the same dimensions as previous data
            along all dimensions except the first (batch) dimension
        """
        if self.data_Y is None:
            self.data_Y = Y
        else:
            self.data_Y = torch.cat([self.data_Y, Y], dim=0)

    def add_datapoints(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """Add new datapoints to the existing datapoints.

        Appends `X` and `Y`, increments ``n_samples`` by the number of new
        rows and rebuilds the distance structure over all input datapoints.
        If appending fails, ``data_X`` and ``data_Y`` are left as they were
        before the call.

        Parameters
        ------------
        X: torch.Tensor
            New data points to add to the input data
            `X` should have the same dimensions as previous data
            along all dimensions except the first (batch) dimension
        Y: torch.Tensor
            New data points to add to the output data
            `Y` should have the same dimensions as previous data
            along all dimensions except the first (batch) dimension

        Raises
        ------------
        ValueError
            If `X` and `Y` do not contain the same number of datapoints
            (size of the first dimension).
        """
        # Validate before mutating anything so a bad call leaves state intact.
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                "X and Y must contain the same number of datapoints, got "
                f"{X.shape[0]} and {Y.shape[0]}"
            )

        # Remember the previous inputs so that a failure while appending Y
        # (e.g. a trailing-dimension mismatch raised by torch.cat) does not
        # leave data_X and data_Y out of sync.
        prev_X = self.data_X
        self.add_datapoints_X(X)
        try:
            self.add_datapoints_Y(Y)
        except Exception:
            self.data_X = prev_X
            raise

        self.n_samples = self.n_samples + X.shape[0]
        self.distance_ds.build_distance_matrix(self.data_X)

    def reset(self) -> None:
        """Reset the data structure holding the data points for post-processing.

        Invoking this operation removes all datapoints and resets the state
        back to the initial state.
        """
        self.data_X = None
        self.data_Y = None
        self.n_samples = 0
        self.distance_ds.reset()