Source code for inFairness.distances.euclidean_dists

import torch

from inFairness.distances.distance import Distance


[docs] class EuclideanDistance(Distance): def __init__(self): super().__init__()
[docs] def forward(self, x, y, itemwise_dist=True): if itemwise_dist: return torch.cdist(x.unsqueeze(1), y.unsqueeze(1)).reshape(-1, 1) else: return torch.cdist(x, y)
[docs] class ProtectedEuclideanDistance(Distance): def __init__(self): super().__init__() self._protected_attributes = None self._num_attributes = None self.register_buffer("protected_vector", torch.Tensor())
[docs] def to(self, device): """Moves distance metric to a particular device Parameters ------------ device: torch.device """ assert ( self.protected_vector is not None and len(self.protected_vector.size()) != 0 ), "Please fit the metric before moving parameters to device" self.device = device self.protected_vector = self.protected_vector.to(self.device)
[docs] def fit(self, protected_attributes, num_attributes): """Fit Protected Euclidean Distance metric Parameters ------------ protected_attributes: Iterable[int] List of attribute indices considered to be protected. The metric would ignore these protected attributes while computing distance between data points. num_attributes: int Total number of attributes in the data points. """ self._protected_attributes = protected_attributes self._num_attributes = num_attributes self.protected_vector = torch.ones(num_attributes) self.protected_vector[protected_attributes] = 0.0
[docs] def forward(self, x, y, itemwise_dist=True): """ :param x, y: a B x D matrices :return: B x 1 matrix with the protected distance camputed between x and y """ protected_x = (x * self.protected_vector).unsqueeze(1) protected_y = (y * self.protected_vector).unsqueeze(1) if itemwise_dist: return torch.cdist(protected_x, protected_y).reshape(-1, 1) else: return torch.cdist(protected_x, protected_y)