Source code for inFairness.distances.distance
from abc import ABCMeta, abstractmethod
from torch import nn
[docs]
class Distance(nn.Module, metaclass=ABCMeta):
    """
    Abstract base class for model distances.

    Subclasses must implement :meth:`forward` and may override :meth:`fit`
    when the metric has learnable parameters.
    """

    def __init__(self):
        super().__init__()

    def fit(self, **kwargs):
        """
        Fits the metric parameters for learnable metrics.

        The default implementation does nothing. Subclasses should override
        this method to implement custom fit logic.
        """
        pass

    def load_state_dict(self, state_dict, strict=True):
        """
        Restore metric state by assigning each entry of ``state_dict``.

        A key that names a buffer registered on this module or on one of its
        submodules (dotted names such as ``"inner.sigma"``) is assigned to the
        module that owns that buffer. Any other key is only accepted when
        ``strict`` is False, in which case it is set as a plain attribute on
        this module.

        Args:
            state_dict (dict): mapping from buffer/attribute name to value.
            strict (bool): if True, reject keys that do not name a registered
                buffer.

        Raises:
            AssertionError: if ``strict`` is True and ``state_dict`` contains a
                key that is not a registered buffer. All keys are checked before
                anything is assigned, so the metric is left unchanged.
        """
        # Map every registered buffer name to (owning module, local name).
        # Reading ``_buffers`` directly also covers buffers registered as None
        # and buffers sharing a tensor, both of which ``named_buffers()`` skips.
        buffer_owners = {}
        for prefix, module in self.named_modules():
            for local_name in module._buffers:
                full_name = f"{prefix}.{local_name}" if prefix else local_name
                buffer_owners[full_name] = (module, local_name)

        # Validate up front so a rejected state dict cannot leave the metric
        # partially overwritten.
        if strict:
            for key in state_dict:
                if key not in buffer_owners:
                    raise AssertionError(
                        f"{key} not found in metric state and strict parameter is set to True. Either set strict parameter to False or remove extra entries from the state dictionary."
                    )

        for key, val in state_dict.items():
            # Unknown keys (only reachable when strict=False) fall back to a
            # plain attribute on this module, as before.
            owner, attr_name = buffer_owners.get(key, (self, key))
            setattr(owner, attr_name, val)

    @abstractmethod
    def forward(self, x, y):
        """
        Subclasses must override this method to compute particular distances.

        Args:
            x: first input (type/shape defined by the subclass).
            y: second input (type/shape defined by the subclass).

        Returns:
            Tensor: distance between two inputs
        """