# Source code for inFairness.fairalgo.sensei

import torch
from torch import nn

from inFairness.auditor import SenSeIAuditor
from inFairness.fairalgo.datainterfaces import FairModelResponse
from inFairness.utils import datautils


class SenSeI(nn.Module):
    r"""Implements the Sensitive Set Invariance (SenSeI) algorithm.

    Proposed in `SenSeI: Sensitive Set Invariance for Enforcing Individual
    Fairness <https://arxiv.org/abs/2006.14168>`_

    Parameters
    ------------
        network: torch.nn.Module
            Network architecture
        distance_x: inFairness.distances.Distance
            Distance metric in the input space
        distance_y: inFairness.distances.Distance
            Distance metric in the output space
        loss_fn: torch.nn.Module
            Loss function
        rho: float
            :math:`\rho` parameter in the SenSeI algorithm. Weights the
            output-space distance term that is added to the loss.
        eps: float
            :math:`\epsilon` parameter in the SenSeI algorithm. The mean
            input-space distance to the worst-case examples is compared
            against this value when updating the internal lambda.
        auditor_nsteps: int
            Number of update steps for the auditor to find worst-case examples
        auditor_lr: float
            Learning rate for the auditor
    """

    def __init__(
        self,
        network,
        distance_x,
        distance_y,
        loss_fn,
        rho,
        eps,
        auditor_nsteps,
        auditor_lr,
    ):
        super().__init__()

        self.distance_x = distance_x
        self.distance_y = distance_y
        self.network = network
        self.loss_fn = loss_fn

        # Created lazily on the first training step, once the device of the
        # input batch is known.
        self.lamb = None
        self.rho = rho
        self.eps = eps

        self.auditor_nsteps = auditor_nsteps
        self.auditor_lr = auditor_lr
        self.auditor = self.__init_auditor__()

    def __init_auditor__(self):
        """Build the SenSeI auditor from the configured distances and
        auditor hyper-parameters."""
        auditor = SenSeIAuditor(
            distance_x=self.distance_x,
            distance_y=self.distance_y,
            num_steps=self.auditor_nsteps,
            lr=self.auditor_lr,
        )
        return auditor

    def _init_dual_state(self, device):
        """Make sure ``self.lamb`` and ``self.eps`` are tensors.

        ``lamb`` starts at 1.0. ``eps`` is converted whenever it is not
        already a tensor, so ints and NumPy scalars are accepted in addition
        to plain Python floats (``torch.maximum``/``torch.minimum`` used in
        the lambda update require tensor operands).
        """
        if self.lamb is None:
            self.lamb = torch.tensor(1.0, device=device)
        if not torch.is_tensor(self.eps):
            self.eps = torch.tensor(float(self.eps), device=device)

    def _update_lambda(self, mean_dist_x, device):
        """Update ``self.lamb`` from the mean input-space distance.

        The step is scaled by max(d, eps) / min(d, eps), and the result is
        floored at 1e-5.
        """
        minlambda = torch.tensor(1e-5, device=device)
        lr_factor = torch.maximum(mean_dist_x, self.eps) / torch.minimum(
            mean_dist_x, self.eps
        )
        self.lamb = torch.max(
            torch.stack(
                [minlambda, self.lamb + lr_factor * (mean_dist_x - self.eps)]
            )
        )

    def forward_train(self, X, Y):
        """Forward method during the training phase.

        Generates worst-case examples for ``X`` with the auditor, updates
        the internal lambda, and returns a ``FairModelResponse`` carrying
        the fair loss and the predictions on ``X``.
        """
        device = datautils.get_device(X)
        self._init_dual_state(device)

        Y_pred = self.network(X)
        X_worst = self.auditor.generate_worst_case_examples(
            self.network, X, lambda_param=self.lamb
        )

        mean_dist_x = self.distance_x(X, X_worst).mean()
        self._update_lambda(mean_dist_x, device)

        Y_pred_worst = self.network(X_worst)
        # Task loss plus rho-weighted output distance between the original
        # and the worst-case predictions.
        fair_loss = torch.mean(
            self.loss_fn(Y_pred, Y) + self.rho * self.distance_y(Y_pred, Y_pred_worst)
        )

        response = FairModelResponse(loss=fair_loss, y_pred=Y_pred)
        return response

    def forward_test(self, X):
        """Forward method during the test phase.

        Runs the wrapped network only; no loss is computed.
        """
        Y_pred = self.network(X)
        response = FairModelResponse(y_pred=Y_pred)
        return response

    def forward(self, X, Y=None, *args, **kwargs):
        """Defines the computation performed at every call.

        Dispatches to :meth:`forward_train` when the module is in training
        mode and to :meth:`forward_test` otherwise. Extra positional and
        keyword arguments are accepted but not used.

        Parameters
        ------------
            X: torch.Tensor
                Input data
            Y: torch.Tensor
                Expected output data

        Returns
        ----------
            output: FairModelResponse
                Response holding the predictions and, in training mode,
                the fair loss
        """
        if self.training:
            return self.forward_train(X, Y)
        else:
            return self.forward_test(X)