Source code for inFairness.distances.wasserstein_distance

import torch
from ot import emd2

from inFairness.distances import MahalanobisDistances


[docs] class WassersteinDistance(MahalanobisDistances): """computes a batched Wasserstein Distance for pairs of sets of items on each batch in the tensors with dimensions B, N, D and B, M, D where B and D are the batch and feature sizes and N and M are the number of items on each batch. Currently only supporting distances inheriting from :class: `MahalanobisDistances`. transforms an Mahalanobis Distance object so that the forward method becomes a differentiable batched Wasserstein distance between sets of items. This Wasserstein distance will use the underlying Mahalanobis distance as pairwise cost function to solve the optimal transport problem. for more information see equation 2.5 of the reference bellow References ---------- `Amanda Bower, Hamid Eftekhari, Mikhail Yurochkin, Yuekai Sun: Individually Fair Rankings. ICLR 2021` """ def __init__(self): super().__init__()
[docs] def forward(self, X1: torch.Tensor, X2: torch.Tensor): """computes a batch wasserstein distance implied by the cost function represented by an underlying mahalanobis distance. Parameters -------------- X1: torch.Tensor Data sample of shape (B, N, D) X2: torch.Tensor Data sample of shape (B, M, D) Returns -------- dist: torch.Tensor Wasserstein distance of shape (B) between batch samples in X1 and X2 """ costs = super().forward(X1, X2, itemwise_dist=False) uniform_x1 = torch.ones(X1.shape[1]) / X1.shape[1] uniform_x2 = torch.ones(X2.shape[1]) / X2.shape[1] num_batches = X1.shape[0] dist = torch.stack( [emd2(uniform_x1, uniform_x2, costs[j]) for j in range(num_batches)] ) return dist