inFairness.distances.distance module
- class inFairness.distances.distance.Distance
Bases: torch.nn.Module
Abstract base class for model distances.
- fit(**kwargs)
Fits the metric parameters for learnable metrics. The default implementation does nothing; subclasses should override this method to implement custom fit logic.
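As a hedged sketch of the override described above, the block below defines a minimal stand-in for Distance (an assumption made so the example runs with PyTorch alone; in real code you would subclass inFairness.distances.distance.Distance) and a hypothetical learnable metric, StandardizedEuclidean, whose fit() estimates per-feature scales from data. The class name, the X keyword argument, and the inv_std buffer are illustrative and not part of the library.

```python
import abc

import torch
from torch import nn


class Distance(nn.Module, metaclass=abc.ABCMeta):
    """Stand-in mirroring the documented interface (assumption, not the inFairness source)."""

    def fit(self, **kwargs):
        # Default behavior: nothing to learn.
        pass

    @abc.abstractmethod
    def forward(self, x, y):
        """Return the distance between x and y."""


class StandardizedEuclidean(Distance):
    """Hypothetical learnable metric: Euclidean distance after per-feature scaling."""

    def __init__(self, num_features):
        super().__init__()
        # Registered as a buffer so the learned scales appear in state_dict().
        self.register_buffer("inv_std", torch.ones(num_features))

    def fit(self, X):
        # Estimate per-feature population standard deviations from X of shape (N, D).
        centered = X - X.mean(dim=0)
        std = centered.pow(2).mean(dim=0).sqrt().clamp(min=1e-8)
        self.inv_std.copy_(1.0 / std)

    def forward(self, x, y):
        # Scale each feature difference, then take the row-wise L2 norm.
        return torch.linalg.norm((x - y) * self.inv_std, dim=-1)


metric = StandardizedEuclidean(num_features=2)
data = torch.tensor([[0.0, 0.0], [2.0, 4.0]])
metric.fit(X=data)  # per-feature std is [1, 2], so inv_std becomes [1, 0.5]
print(metric(data[:1], data[1:]))  # sqrt(2**2 + 2**2), about 2.8284
```

Storing the learned values as buffers rather than plain attributes is a design choice here: it keeps them on the module's device and includes them in state_dict().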
- abstract forward(x, y)
Subclasses must override this method to compute a particular distance.
- Returns:
The distance between the two inputs.
- Return type:
Tensor
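A minimal sketch of a concrete subclass, again using a local stand-in for the abstract base class (an assumption so the example runs on PyTorch alone). The hypothetical EuclideanDistance compares corresponding rows of x and y; returning one distance per row is a choice made for this illustration, not a convention taken from inFairness.

```python
import abc

import torch
from torch import nn


class Distance(nn.Module, metaclass=abc.ABCMeta):
    """Stand-in for the documented abstract base class (assumption)."""

    @abc.abstractmethod
    def forward(self, x, y):
        """Subclasses return the distance between x and y."""


class EuclideanDistance(Distance):
    """Hypothetical subclass: row-wise Euclidean (L2) distance."""

    def forward(self, x, y):
        # One distance per row: inputs of shape (N, D) give a result of shape (N,).
        return torch.linalg.norm(x - y, dim=-1)


dist = EuclideanDistance()
x = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
y = torch.tensor([[3.0, 4.0], [1.0, 1.0]])
print(dist(x, y))  # distances 5.0 and 0.0

# Because forward() is abstract, the base class itself cannot be instantiated.
try:
    Distance()
except TypeError as err:
    print("TypeError:", err)
```

Calling dist(x, y) rather than dist.forward(x, y) goes through torch.nn.Module.__call__, which runs any registered hooks before dispatching to forward().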
- load_state_dict(state_dict, strict=True)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
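The contract described above is the one documented for torch.nn.Module.load_state_dict, the base class of Distance. The hedged sketch below exercises that contract on a plain torch.nn.Module (an assumption: it does not import inFairness, so it does not check any Distance-specific override). The ScaledEuclidean class and its scale buffer are hypothetical.

```python
import torch
from torch import nn


class ScaledEuclidean(nn.Module):
    """Hypothetical metric with one persistent buffer named "scale"."""

    def __init__(self, dim):
        super().__init__()
        self.register_buffer("scale", torch.ones(dim))

    def forward(self, x, y):
        return torch.linalg.norm((x - y) * self.scale, dim=-1)


src = ScaledEuclidean(3)
src.scale.fill_(2.0)
dst = ScaledEuclidean(3)

# Keys match exactly, so the default strict=True load succeeds.
result = dst.load_state_dict(src.state_dict())
print(result.missing_keys, result.unexpected_keys)  # [] []

# Mismatched keys: strict=False loads what it can and reports the rest.
partial = {"bias": torch.zeros(3)}
result = dst.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['scale']
print(result.unexpected_keys)  # ['bias']

# The same mismatch with strict=True raises instead of returning.
try:
    dst.load_state_dict(partial)
except RuntimeError:
    print("strict load rejected mismatched keys")
```

With strict=False the existing scale buffer is left untouched, so dst keeps the values copied from src by the first load.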
- training: bool