Source code for inFairness.postprocessing.glif

import torch
import numpy as np

from inFairness.utils.postprocessing import (
    build_graph_from_dists,
    get_laplacian,
    laplacian_solve,
)
from inFairness.postprocessing.base_postprocessing import BasePostProcessing
from inFairness.postprocessing.datainterfaces import PostProcessingObjectiveResponse


class GraphLaplacianIF(BasePostProcessing):
    """Implements the Graph Laplacian Individual Fairness Post-Processing method.

    Proposed in `Post-processing for Individual Fairness
    <https://arxiv.org/abs/2110.13796>`_

    Parameters
    ------------
        distance_x: inFairness.distances.Distance
            Distance metric in the input space
        is_output_probas: bool
            True if the `data_Y` (model output) are probabilities, implying that
            this is a classification setting, and False if the `data_Y` are in
            Euclidean space, implying that this is a regression setting.
    """

    def __init__(self, distance_x, is_output_probas):
        super().__init__(distance_x, is_output_probas=is_output_probas)

        self._METHOD_COORDINATE_KEY = "coordinate-descent"
        self._METHOD_EXACT_KEY = "exact"

    def __exact_pp__(self, lambda_param, scale, threshold, normalize):
        """Implements the exact version of post-processing"""

        y_hat = self.__get_yhat__()

        W, idxs = build_graph_from_dists(
            self.distance_matrix, scale, threshold, normalize
        )

        data_y_idxs = y_hat[idxs]
        L = get_laplacian(W, normalize)

        if normalize:
            L = (L + L.T) / 2

        y = laplacian_solve(L, data_y_idxs, lambda_param)
        data_y_new = torch.clone(y_hat)
        data_y_new[idxs] = y

        objective = self.get_objective(
            data_y_new, lambda_param, scale, threshold, normalize, W, idxs, L
        )
        return data_y_new, objective

    def __coordinate_update__(
        self,
        yhat_batch,
        W_batch,
        y,
        batchidx,
        lambda_param,
        D_inv_batch=None,
        diag_W_batch=None,
        D_batch=None,
    ):
        """Single coordinate-descent update over a minibatch.

        Shapes:
            W_batch: (bsz, nsamples)
            y: (nsamples, ncls - 1)
            W_xy: (bsz, nsamples, ncls - 1)
            W_xy_corr: (bsz, ncls - 1)
            numerator: (bsz, ncls - 1)
            denominator: (bsz, 1)
        """

        W_xy = W_batch.unsqueeze(-1) * y.unsqueeze(0)

        if D_inv_batch is None:
            W_xy_corr = torch.diagonal(W_xy[:, batchidx], offset=0, dim1=0, dim2=1).T
            numerator = yhat_batch + lambda_param * (W_xy.sum(dim=1) - W_xy_corr)
            denominator = 1 + lambda_param * (
                W_batch.sum(dim=1, keepdim=True) - diag_W_batch.view(-1, 1)
            )
            y_new = numerator / denominator
        else:
            W_xy = W_xy * D_inv_batch.unsqueeze(-1)
            W_xy_corr = torch.diagonal(W_xy[:, batchidx], offset=0, dim1=0, dim2=1).T
            numerator = yhat_batch + (lambda_param * (W_xy.sum(dim=1) - W_xy_corr)) / 2
            denominator = (
                1
                + lambda_param
                - lambda_param * diag_W_batch.view(-1, 1) / D_batch.view(-1, 1)
            )
            y_new = numerator / denominator

        return y_new

    def __coordinate_pp__(
        self, lambda_param, scale, threshold, normalize, batchsize, epochs
    ):
        """Implements coordinate descent for large-scale data"""

        y_hat = self.__get_yhat__()
        y_copy = y_hat.clone()
        n_samples = self.datastore.n_samples

        W, idxs = build_graph_from_dists(
            self.distance_matrix, scale, threshold, normalize
        )

        data_y_idxs = y_hat[idxs]
        W_diag = torch.diag(W)

        if normalize:
            D = W.sum(dim=1)
            D_inv = 1 / D.reshape(1, -1) + 1 / D.reshape(-1, 1)

        for epoch_idx in range(epochs):
            epoch_idxs_random = np.random.permutation(n_samples)
            curridx = 0

            while curridx < n_samples:
                batchidxs = epoch_idxs_random[curridx : curridx + batchsize]

                if normalize:
                    y_copy[batchidxs] = self.__coordinate_update__(
                        data_y_idxs[batchidxs],
                        W[batchidxs],
                        y_copy,
                        batchidxs,
                        lambda_param=lambda_param,
                        D_inv_batch=D_inv[batchidxs],
                        diag_W_batch=W_diag[batchidxs],
                        D_batch=D[batchidxs],
                    )
                else:
                    y_copy[batchidxs] = self.__coordinate_update__(
                        data_y_idxs[batchidxs],
                        W[batchidxs],
                        y_copy,
                        batchidxs,
                        lambda_param=lambda_param,
                        diag_W_batch=W_diag[batchidxs],
                    )

                curridx += batchsize

        pp_sol = y_hat.clone()
        pp_sol[idxs] = y_copy

        objective = self.get_objective(
            pp_sol, lambda_param, scale, threshold, normalize, W, idxs
        )
        return pp_sol, objective
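    # A sketch of where the coordinate update comes from: for the unnormalized
    # Laplacian L = D - W, setting the gradient of the equation 3.1 objective
    # ||f - y_hat||^2 + lambda * f^T L f with respect to a single coordinate
    # f_i to zero gives the closed form
    #
    #     f_i = (y_hat_i + lambda * sum_{j != i} W_ij f_j)
    #           / (1 + lambda * sum_{j != i} W_ij)
    #
    # which is exactly what __coordinate_update__ computes batch-wise: W_xy.sum
    # accumulates sum_j W_ij f_j, and the W_xy_corr / diag_W_batch terms remove
    # the j == i contribution. The normalized branch follows the same pattern
    # for the symmetrized normalized Laplacian.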
    def get_objective(
        self,
        y_solution,
        lambda_param: float,
        scale: float,
        threshold: float,
        normalize: bool = False,
        W_graph=None,
        idxs=None,
        L=None,
    ):
        """Compute the objective value for individual fairness as follows:

        .. math::
            \\widehat{\\mathbf{f}} = \\arg \\min_{\\mathbf{f}} \\ \\|\\mathbf{f} - \\hat{\\mathbf{y}}\\|_2^2 + \\lambda \\ \\mathbf{f}^{\\top}\\mathbb{L_n}\\mathbf{f}

        Refer to equation 3.1 in the paper.

        Parameters
        ------------
            y_solution: torch.Tensor
                Post-processed solution values of shape (N, C)
            lambda_param: float
                Weight for the Laplacian regularizer
            scale: float
                Parameter used to scale the computed distances.
                Refer to equation 2.2 in the proposing paper.
            threshold: float
                Parameter used to construct the graph from distances.
                Distances below the provided threshold are considered to be
                connected edges, while those beyond the threshold are
                considered to be disconnected.
                Refer to equation 2.2 in the proposing paper.
            normalize: bool
                Whether to normalize the computed Laplacian or not
            W_graph: torch.Tensor
                Adjacency matrix of shape (N, N)
            idxs: torch.Tensor
                Indices of data points which are included in the adjacency matrix
            L: torch.Tensor
                Laplacian of the adjacency matrix

        Returns
        ---------
            objective: dict
                Dictionary with three entries: `y_dist`, the squared distance
                between the post-processed and original outputs; `L_objective`,
                the weighted Laplacian regularization term; and
                `overall_objective`, their sum. Refer to equation 3.1 in the
                paper for an explanation of the various parts.
        """

        if W_graph is None or idxs is None:
            W_graph, idxs = build_graph_from_dists(
                self.distance_matrix, scale, threshold, normalize
            )

        if L is None:
            L = get_laplacian(W_graph, normalize)

        y_hat = self.__get_yhat__()
        y_dist = ((y_hat - y_solution) ** 2).sum()
        L_obj = lambda_param * (y_solution[idxs] * (L @ y_solution[idxs])).sum()
        overall_objective = y_dist + L_obj

        result = {
            "y_dist": y_dist.item(),
            "L_objective": L_obj.item(),
            "overall_objective": overall_objective.item(),
        }
        return result
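    # Note: `get_objective` returns a plain dictionary of objective components,
    # which `postprocess` (below) passes along in a
    # PostProcessingObjectiveResponse. An illustrative (hypothetical) return
    # value:
    #
    #     {"y_dist": 0.42, "L_objective": 0.13, "overall_objective": 0.55}
    #
    # where overall_objective = y_dist + L_objective, matching equation 3.1.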
    def postprocess(
        self,
        method: str,
        lambda_param: float,
        scale: float,  # 0.001
        threshold: float,  # median of all distances if None
        normalize: bool = False,
        batchsize: int = None,
        epochs: int = None,
    ):
        """Implements the Graph Laplacian Individual Fairness post-processing algorithm

        Parameters
        -------------
            method: str
                GLIF method type. Possible values are:

                (a) `coordinate-descent`, which is more suitable for
                large-scale data and post-processes by batching data into
                minibatches (see section 3.2.2 of the paper), or

                (b) `exact`, which gets the exact solution but is not
                appropriate for large-scale data (refer to equation 3.3 in
                the paper).
            lambda_param: float
                Weight for the Laplacian regularizer
            scale: float
                Parameter used to scale the computed distances.
                Refer to equation 2.2 in the proposing paper.
            threshold: float
                Parameter used to construct the graph from distances.
                Distances below the provided threshold are considered to be
                connected edges, while those beyond the threshold are
                considered to be disconnected.
                Refer to equation 2.2 in the proposing paper.
            normalize: bool
                Whether to normalize the computed Laplacian or not
            batchsize: int
                Batch size. *Required when method=`coordinate-descent`*
            epochs: int
                Number of coordinate descent epochs.
                *Required when method=`coordinate-descent`*

        Returns
        -----------
            solution: PostProcessingObjectiveResponse
                Post-processed solution containing two parts:
                (a) post-processed output probabilities of shape (N, C), where
                N is the number of data samples and C is the number of output
                classes, and (b) objective values. Refer to equation 3.1 in
                the paper for an explanation of the various parts.
        """

        assert method in [
            self._METHOD_COORDINATE_KEY,
            self._METHOD_EXACT_KEY,
        ], f"`method` should be either `coordinate-descent` or `exact`. Value provided: {method}"

        if method == self._METHOD_COORDINATE_KEY:
            assert (
                batchsize is not None and epochs is not None
            ), "`batchsize` and `epochs` parameters are required but were not provided"

        if method == self._METHOD_EXACT_KEY:
            data_y_new, objective = self.__exact_pp__(
                lambda_param, scale, threshold, normalize
            )
        elif method == self._METHOD_COORDINATE_KEY:
            data_y_new, objective = self.__coordinate_pp__(
                lambda_param, scale, threshold, normalize, batchsize, epochs
            )

        if self.is_output_probas:
            # Map the (N, C-1) logit-space solution back onto the full (N, C)
            # probability simplex via the multinomial logistic transform
            pp_sol = torch.exp(data_y_new) / (
                1 + torch.exp(data_y_new).sum(dim=1, keepdim=True)
            )
            y_solution = torch.hstack((pp_sol, 1 - pp_sol.sum(dim=1, keepdim=True)))
        else:
            y_solution = data_y_new

        result = PostProcessingObjectiveResponse(
            y_solution=y_solution, objective=objective
        )
        return result
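# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library). It assumes the
# `add_datapoints` data-loading API of BasePostProcessing as used in the
# library's examples; the tensors and hyperparameter values below are
# placeholders, not recommendations.

if __name__ == "__main__":
    from inFairness.distances import EuclideanDistance

    X = torch.randn(100, 5)  # input features
    Y = torch.softmax(torch.randn(100, 3), dim=-1)  # predicted class probabilities

    pp = GraphLaplacianIF(distance_x=EuclideanDistance(), is_output_probas=True)
    pp.add_datapoints(X, Y)

    # Exact solver: suitable for small datasets (equation 3.3 in the paper)
    exact_result = pp.postprocess(
        method="exact", lambda_param=1.0, scale=0.001, threshold=0.5
    )

    # Coordinate descent: suitable for large datasets (section 3.2.2);
    # `batchsize` and `epochs` are required for this method
    cd_result = pp.postprocess(
        method="coordinate-descent",
        lambda_param=1.0,
        scale=0.001,
        threshold=0.5,
        normalize=False,
        batchsize=16,
        epochs=10,
    )

    print(cd_result.y_solution.shape)  # post-processed probabilities, (100, 3)
    print(cd_result.objective)  # objective components per equation 3.1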