inFairness.distances.wasserstein_distance module

class inFairness.distances.wasserstein_distance.WassersteinDistance

Bases: MahalanobisDistances

Computes a batched Wasserstein distance between pairs of sets of items. The inputs are tensors of shape (B, N, D) and (B, M, D), where B is the batch size, D is the feature size, and N and M are the numbers of items in the two sets of each batch element.

Currently, only distances that inherit from MahalanobisDistances are supported.

Transforms a Mahalanobis distance object so that its forward method becomes a differentiable, batched Wasserstein distance between sets of items. The Wasserstein distance uses the underlying Mahalanobis distance as the pairwise cost function when solving the optimal transport problem.

For more information, see equation 2.5 of the reference below.
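
For orientation, the optimal transport problem described above can be written in the generic discrete form below. This is the standard textbook formulation, not a transcription of equation 2.5 of the reference. The uniform weights 1/N and 1/M, the item names x_i and y_j, and the symbol c_Σ for the pairwise Mahalanobis cost are assumptions made for illustration.

```latex
W(X_1, X_2) \;=\; \min_{\Pi \in \mathbb{R}_{\ge 0}^{N \times M}}
  \sum_{i=1}^{N} \sum_{j=1}^{M} \Pi_{ij}\, c_\Sigma(x_i, y_j)
\quad \text{s.t.} \quad
  \sum_{j=1}^{M} \Pi_{ij} = \tfrac{1}{N}, \qquad
  \sum_{i=1}^{N} \Pi_{ij} = \tfrac{1}{M}
```

Here x_i and y_j denote the items of one batch element of X1 and X2, and the problem is solved independently for each of the B batch elements.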

References

Amanda Bower, Hamid Eftekhari, Mikhail Yurochkin, Yuekai Sun: Individually Fair Rankings. ICLR 2021

forward(X1: Tensor, X2: Tensor)

Computes a batched Wasserstein distance implied by the cost function that the underlying Mahalanobis distance represents.

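Parameters:

X1 (torch.Tensor) – sets of items of shape (B, N, D)

X2 (torch.Tensor) – sets of items of shape (B, M, D)

Returns: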

dist – Wasserstein distance of shape (B) between batch samples in X1 and X2

Return type:

torch.Tensor
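
The input and output shapes can be illustrated with a small pure-Python sketch. It is not the library's implementation: it is neither differentiable nor tensor-based, and it assumes uniform weights, N == M, and a squared Mahalanobis-style cost with a hypothetical matrix `sigma`. Under those assumptions an optimal transport plan is a permutation, so brute force over permutations gives the exact value for tiny sets. All function names here are hypothetical.

```python
from itertools import permutations


def mahalanobis_sq(x, y, sigma):
    """Squared Mahalanobis-style cost (x - y)^T sigma (x - y)."""
    diff = [a - b for a, b in zip(x, y)]
    return sum(
        diff[i] * sigma[i][j] * diff[j]
        for i in range(len(diff))
        for j in range(len(diff))
    )


def wasserstein_equal_sets(xs, ys, sigma):
    """Exact transport cost between two equal-size sets with uniform weights.

    With uniform weights and N == M, an optimal plan is a permutation,
    so trying every permutation is exact (only viable for very small N).
    """
    n = len(xs)
    assert n == len(ys), "this sketch only handles N == M"
    best = min(
        sum(mahalanobis_sq(xs[i], ys[p[i]], sigma) for i in range(n))
        for p in permutations(range(n))
    )
    return best / n  # each matched pair carries mass 1/N


def batched_wasserstein(X1, X2, sigma):
    """Apply the set distance per batch element; returns a list of length B."""
    return [wasserstein_equal_sets(xs, ys, sigma) for xs, ys in zip(X1, X2)]


# B = 2 batch elements, N = M = 2 items per set, D = 2 features.
sigma = [[1.0, 0.0], [0.0, 1.0]]  # identity, so the cost is squared Euclidean
X1 = [
    [[0.0, 0.0], [1.0, 0.0]],
    [[0.0, 0.0], [2.0, 0.0]],
]
X2 = [
    [[1.0, 0.0], [0.0, 0.0]],  # same set as X1[0], reordered
    [[0.0, 1.0], [2.0, 1.0]],  # X1[1] shifted by 1 along the second feature
]
dist = batched_wasserstein(X1, X2, sigma)
print(dist)  # [0.0, 1.0]
```

The result has one entry per batch element, mirroring the shape (B) described above. The first entry is zero because a set distance ignores item order, and the second equals the squared length of the shift.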

training: bool