inFairness.distances.mahalanobis_distance module

class inFairness.distances.mahalanobis_distance.MahalanobisDistances

Bases: Distance

Base class implementing the generalized Mahalanobis distance.

The Mahalanobis distance between two points X1 and X2 is computed as:

\[\text{dist}(X_1, X_2) = (X_1 - X_2) \Sigma (X_1 - X_2)^{T}\]
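
As a quick numeric check of the formula, the sketch below evaluates it for a single pair of feature vectors in plain PyTorch. The helper `mahalanobis_pair` is hypothetical and not part of inFairness; with X1 and X2 treated as 1-D vectors, the product reduces to a scalar.

```python
import torch


def mahalanobis_pair(x1: torch.Tensor, x2: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (x1 - x2) @ sigma @ (x1 - x2)^T for two 1-D feature vectors."""
    diff = x1 - x2               # shape (n_features,)
    return diff @ sigma @ diff   # 1-D @ 2-D @ 1-D -> 0-dim tensor


x1 = torch.tensor([1.0, 2.0])
x2 = torch.tensor([0.0, 0.0])
sigma = torch.tensor([[2.0, 0.0],
                      [0.0, 1.0]])
print(mahalanobis_pair(x1, x2, sigma))  # tensor(6.)
```

With sigma set to the identity, the same expression gives the squared Euclidean distance (5 for this pair).
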
fit(sigma)

Fit the Mahalanobis distance metric with the matrix Σ.

Parameters:

sigma (torch.Tensor) – Covariance matrix Σ used in the distance formula above, of shape (n_features, n_features).
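
The formula above needs Σ to be a square matrix whose side matches n_features, and a symmetric positive semi-definite Σ keeps the distance non-negative. The hypothetical `check_sigma` helper below (not part of inFairness) encodes only those requirements. Building Σ as a sample covariance with `torch.cov` is just one illustrative choice; which Σ is appropriate depends on how the metric is meant to be learned.

```python
import torch


def check_sigma(sigma: torch.Tensor, n_features: int, atol: float = 1e-6) -> None:
    """Hypothetical helper: validate a candidate sigma before passing it to fit()."""
    if sigma.shape != (n_features, n_features):
        raise ValueError(
            f"sigma must have shape ({n_features}, {n_features}), got {tuple(sigma.shape)}"
        )
    if not torch.allclose(sigma, sigma.T, atol=atol):
        raise ValueError("sigma must be symmetric")
    # A symmetric matrix is PSD iff all its eigenvalues are >= 0,
    # which guarantees (x1 - x2) @ sigma @ (x1 - x2)^T >= 0.
    if torch.linalg.eigvalsh(sigma).min() < -atol:
        raise ValueError("sigma must be positive semi-definite")


# Sample covariance of a data matrix X with shape (n_samples, n_features).
# torch.cov treats rows as variables, so the transpose is passed in.
torch.manual_seed(0)
X = torch.randn(100, 3)
sigma = torch.cov(X.T)
check_sigma(sigma, n_features=3)
print(sigma.shape)  # torch.Size([3, 3])
```
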

forward(X1, X2, itemwise_dist=True)

Computes the distance between data samples X1 and X2.

Parameters:
  • X1 (torch.Tensor) – Data samples from batch 1 of shape (n_samples_1, n_features)

  • X2 (torch.Tensor) – Data samples from batch 2 of shape (n_samples_2, n_features)

  • itemwise_dist (bool, default: True) –

    Compute the distance in an itemwise or pairwise manner.

    In the itemwise fashion (itemwise_dist=True), the distance is computed between the ith data sample in X1 and the ith data sample in X2. Thus, X1 and X2 must have the same shape.

    In the pairwise fashion (itemwise_dist=False), the distance is computed between every sample in X1 and every sample in X2. In this case, X1 and X2 may contain different numbers of samples, but must have the same number of features.

Returns:

dist – Distance between samples of batch 1 and batch 2.

If itemwise_dist=True, the item-wise distance is returned, of shape (n_samples, 1), where n_samples = n_samples_1 = n_samples_2.

If itemwise_dist=False, the pair-wise distance is returned, of shape (n_samples_1, n_samples_2).

Return type:

torch.Tensor
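
The two modes of forward can be sketched in plain PyTorch as follows. This is an illustrative reimplementation of the formula above under the documented shapes, not the library's code; the function name `mahalanobis` is hypothetical. The pairwise branch broadcasts X1 against X2 and contracts with Σ via `torch.einsum`.

```python
import torch


def mahalanobis(X1: torch.Tensor, X2: torch.Tensor, sigma: torch.Tensor,
                itemwise_dist: bool = True) -> torch.Tensor:
    """Hypothetical sketch of (X1 - X2) @ sigma @ (X1 - X2)^T in both modes."""
    if itemwise_dist:
        # Row i of X1 against row i of X2; both have shape (n_samples, n_features).
        diff = X1 - X2
        # Row-wise quadratic form, kept 2-D to match the (n_samples, 1) return shape.
        return ((diff @ sigma) * diff).sum(dim=1, keepdim=True)
    # Every row of X1 against every row of X2: (n_samples_1, n_samples_2, n_features).
    diff = X1.unsqueeze(1) - X2.unsqueeze(0)
    # Quadratic form for each (i, j) pair -> (n_samples_1, n_samples_2).
    return torch.einsum("ijd,de,ije->ij", diff, sigma, diff)


X1 = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
X2 = torch.tensor([[1.0, 0.0], [1.0, 3.0]])
sigma = torch.eye(2)

item = mahalanobis(X1, X2, sigma)                       # [[1.], [4.]]
pair = mahalanobis(X1, X2, sigma, itemwise_dist=False)  # [[1., 10.], [1., 4.]]
print(item.shape, pair.shape)  # torch.Size([2, 1]) torch.Size([2, 2])
```

When n_samples_1 = n_samples_2, the diagonal of the pairwise result equals the item-wise result. The broadcasting approach builds an (n_samples_1, n_samples_2, n_features) intermediate, which is simple but memory-hungry for large batches.
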

to(device)

Moves the distance metric to the given device.

Parameters:

device (torch.device) – Device to move the distance metric to.

training: bool

class inFairness.distances.mahalanobis_distance.SquaredEuclideanDistance

Bases: MahalanobisDistances

Computes the squared Euclidean distance as a special case of the Mahalanobis distance, where:

\[\Sigma = I_{\text{num\_dims}}\]

fit(num_dims: int)

Fit the squared Euclidean distance metric.

Parameters:

num_dims (int) – The number of dimensions of the space in which the squared Euclidean distance will be used.
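
Setting Σ to the identity turns the Mahalanobis formula into a sum of squared coordinate differences. The sketch below checks this numerically in plain PyTorch, without calling inFairness; `torch.cdist` is used only as an independent reference for the pairwise case.

```python
import torch

torch.manual_seed(0)
num_dims = 4
X1 = torch.randn(5, num_dims)
X2 = torch.randn(5, num_dims)

sigma = torch.eye(num_dims)  # Sigma = I_{num_dims}
diff = X1 - X2

# Item-wise: the quadratic form with the identity reduces to a sum of squares.
mahal_item = ((diff @ sigma) * diff).sum(dim=1, keepdim=True)   # (5, 1)
sq_euclid_item = (diff ** 2).sum(dim=1, keepdim=True)           # (5, 1)

# Pair-wise: compare against squared torch.cdist as an independent reference.
pair_diff = X1.unsqueeze(1) - X2.unsqueeze(0)                   # (5, 5, num_dims)
mahal_pair = torch.einsum("ijd,de,ije->ij", pair_diff, sigma, pair_diff)
sq_euclid_pair = torch.cdist(X1, X2) ** 2                       # (5, 5)

print(torch.allclose(mahal_item, sq_euclid_item))                # True
print(torch.allclose(mahal_pair, sq_euclid_pair, atol=1e-5))     # True
```
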

training: bool