"""Source code for ``inFairness.fairalgo.sensr``: the SenSR fair-training algorithm."""

import torch
from torch import nn

from inFairness.auditor import SenSRAuditor
from inFairness.fairalgo.datainterfaces import FairModelResponse
from inFairness.utils import datautils


class SenSR(nn.Module):
    r"""Implements the Sensitive Subspace Robustness (SenSR) algorithm.

    Proposed in `Training individually fair ML models with sensitive subspace
    robustness <https://arxiv.org/abs/1907.00020>`_

    Parameters
    ------------
        network: torch.nn.Module
            Network architecture
        distance_x: inFairness.distances.Distance
            Distance metric in the input space
        loss_fn: torch.nn.Module
            Loss function
        eps: float
            :math:`\epsilon` parameter in the SenSR algorithm. After every
            training step :math:`\lambda` is increased when the mean input-space
            distance between a batch and its worst-case examples exceeds
            ``eps``, and decreased (never below zero) otherwise.
        lr_lamb: float
            Step size of the update applied to the :math:`\lambda` parameter
            in the SenSR algorithm
        lr_param: float
            :math:`\alpha` parameter in the SenSR algorithm; multiplies the
            loss on the worst-case examples before it is averaged
        auditor_nsteps: int
            Number of update steps for the auditor to find worst-case examples
        auditor_lr: float
            Learning rate for the auditor
    """

    def __init__(
        self,
        network,
        distance_x,
        loss_fn,
        eps,
        lr_lamb,
        lr_param,
        auditor_nsteps,
        auditor_lr,
    ):
        super().__init__()
        self.distance_x = distance_x
        self.network = network
        self.loss_fn = loss_fn
        # Weight handed to the auditor and adapted after every training step.
        # Created lazily in ``forward_train`` so it lives on the batch's device.
        self.lambda_param = None
        self.eps = eps
        self.lr_lamb = lr_lamb
        self.lr_param = lr_param
        self.auditor_nsteps = auditor_nsteps
        self.auditor_lr = auditor_lr
        self.auditor = self.__init_auditor__()

    def __init_auditor__(self):
        """Build the ``SenSRAuditor`` used to generate worst-case examples.

        Returns
        ----------
            auditor: inFairness.auditor.SenSRAuditor
                Auditor sharing this model's loss function and input distance
        """
        auditor = SenSRAuditor(
            loss_fn=self.loss_fn,
            distance_x=self.distance_x,
            num_steps=self.auditor_nsteps,
            lr=self.auditor_lr,
        )
        return auditor

    def forward_train(self, X, Y):
        """Forward method during the training phase.

        Generates worst-case examples for ``X`` with the auditor, updates
        ``self.lambda_param`` and returns the scaled loss on the worst-case
        examples together with the network's predictions on ``X``.

        Parameters
        ------------
            X: torch.Tensor
                Input data
            Y: torch.Tensor
                Expected output data

        Returns
        ----------
            response: inFairness.fairalgo.datainterfaces.FairModelResponse
                ``loss`` is the mean of ``lr_param * loss_fn(network(X_worst), Y)``
                and ``y_pred`` is ``network(X)``

        Raises
        ----------
            ValueError
                If ``Y`` is ``None``
        """
        if Y is None:
            raise ValueError("SenSR requires target values Y in training mode")

        device = datautils.get_device(X)
        if self.lambda_param is None:
            self.lambda_param = torch.tensor(1.0, device=device)
        else:
            # lambda_param is a plain attribute, not a registered buffer, so
            # ``module.to(...)`` does not move it; follow the batch's device.
            self.lambda_param = self.lambda_param.to(device)

        Y_pred = self.network(X)
        X_worst = self.auditor.generate_worst_case_examples(
            self.network, X, Y, lambda_param=self.lambda_param
        )

        # Projected step on lambda: grow it when the worst-case examples are on
        # average farther than eps from X, shrink it otherwise, clamp at zero.
        # This is a manual update, so keep it out of autograd; otherwise a
        # graph could be chained onto the persistent lambda across iterations.
        with torch.no_grad():
            mean_dist = self.distance_x(X, X_worst).mean()
            self.lambda_param = torch.clamp(
                self.lambda_param - self.lr_lamb * (self.eps - mean_dist),
                min=0.0,
            )

        Y_pred_worst = self.network(X_worst)
        fair_loss = torch.mean(self.lr_param * self.loss_fn(Y_pred_worst, Y))

        response = FairModelResponse(loss=fair_loss, y_pred=Y_pred)
        return response

    def forward_test(self, X):
        """Forward method during the test phase.

        Parameters
        ------------
            X: torch.Tensor
                Input data

        Returns
        ----------
            response: inFairness.fairalgo.datainterfaces.FairModelResponse
                Response whose ``y_pred`` is ``network(X)``
        """
        Y_pred = self.network(X)
        response = FairModelResponse(y_pred=Y_pred)
        return response

    def forward(self, X, Y=None, *args, **kwargs):
        """Defines the computation performed at every call.

        Dispatches to ``forward_train`` when the module is in training mode
        and to ``forward_test`` otherwise. Extra positional and keyword
        arguments are accepted and ignored.

        Parameters
        ------------
            X: torch.Tensor
                Input data
            Y: torch.Tensor, optional
                Expected output data; required in training mode

        Returns
        ----------
            response: inFairness.fairalgo.datainterfaces.FairModelResponse
                Training mode: ``loss`` and ``y_pred``; eval mode: ``y_pred`` only
        """
        if self.training:
            return self.forward_train(X, Y)
        return self.forward_test(X)