Source code for inFairness.postprocessing.distance_ds
import torch
[docs]
class DistanceStructure(object):
    """Data structure to store and track the distance matrix between data points

    The matrix is grown incrementally: each call to
    :meth:`build_distance_matrix` reuses the block already stored in
    ``distance_matrix`` and only evaluates ``distance_x`` for the rows of
    ``data_X`` beyond the number of samples already tracked.

    Parameters
    -------------
        distance_x: inFairness.distances.Distance
            Distance metric in the input space. It is called as
            ``distance_x(new_samples, data_X, itemwise_dist=False)``.
    """

    def __init__(self, distance_x):
        self.distance_x = distance_x
        # Square matrix of distances between tracked samples; None until
        # build_distance_matrix() has been called (or after reset()).
        self.distance_matrix = None

    def reset(self):
        """Reset the state of the data structure back to its initial state"""
        self.distance_matrix = None

    def build_distance_matrix(self, data_X):
        """Build the distance matrix between input data samples `data_X`

        The first ``self.distance_matrix.shape[0]`` rows of `data_X` are
        treated as already covered by the stored matrix, which is copied
        into the top-left block of the new matrix. Distances are computed
        only between the remaining rows of `data_X` and all of `data_X`.

        Parameters
        -------------
            data_X: torch.Tensor
                Data points between which the distance matrix is to be computed
        """
        nsamples_old = (
            0 if self.distance_matrix is None else self.distance_matrix.shape[0]
        )
        nsamples_total = data_X.shape[0]

        distance_matrix_new = torch.zeros(
            size=(nsamples_total, nsamples_total), device=data_X.device
        )

        if self.distance_matrix is not None:
            distance_matrix_new[:nsamples_old, :nsamples_old] = self.distance_matrix

        new_samples = data_X[nsamples_old:nsamples_total]
        nsamples_new = new_samples.shape[0]

        # Reshape to an explicit (new, total) matrix rather than squeezing:
        # squeeze() also collapses the row axis when exactly one new sample
        # is added, which made the transposed column assignment below fail
        # to broadcast. reshape() still drops any extra singleton axes.
        dist = (
            self.distance_x(new_samples, data_X, itemwise_dist=False)
            .detach()
            .reshape(nsamples_new, nsamples_total)
        )

        # Fill the new rows, then mirror them into the new columns.
        distance_matrix_new[nsamples_old:, :] = dist
        distance_matrix_new[:, nsamples_old:] = dist.T

        # distance_matrix_new is a freshly allocated local tensor, so it can
        # be stored directly without an extra copy.
        self.distance_matrix = distance_matrix_new