PolyActivationConverter#
- class PolyActivationConverter(args, model, train_state=None, chp_was_completed=None, loss=None)#
Bases:
object
This class manages the FHE conversion of the model, namely the replacement of its activation functions. It holds the current model-conversion state. Initialization can be either from scratch or from a given point - the latter is useful when loading a model from a checkpoint.
We assume that a model’s activation functions are either relu or relu_range_aware, but never both at the same time. Furthermore, we assume that the non-trainable PolyRelu activation function is used only as a substitute for relu_range_aware activations.
- Supported replacement flows:
relu -> relu_range_aware -> non_trainable_poly
relu -> weighted_relu (for smooth_transition, usually for the trainable_poly activation)
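The replacement flows above can be sketched as a small lookup table mapping each activation kind to its successor. This is an illustrative sketch only; the names `REPLACEMENT_FLOW`, `SMOOTH_FLOW`, and `next_activation` are not part of the library's API.

```python
from typing import Optional

# Default flow: relu -> relu_range_aware -> non_trainable_poly.
REPLACEMENT_FLOW = {
    "relu": "relu_range_aware",                # first replacement step
    "relu_range_aware": "non_trainable_poly",  # final FHE-friendly step
}

# Smooth-transition flow, usually used for the trainable_poly activation.
SMOOTH_FLOW = {"relu": "weighted_relu"}

def next_activation(current: str, smooth_transition: bool = False) -> Optional[str]:
    """Return the next activation kind in the flow, or None if conversion is done."""
    flow = SMOOTH_FLOW if smooth_transition else REPLACEMENT_FLOW
    return flow.get(current)
```

Note that `non_trainable_poly` has no successor: once reached, the model is fully converted and `next_activation` returns `None`.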
- __init__(args, model, train_state=None, chp_was_completed=None, loss=None)#
Methods
__init__(args, model[, train_state, ...]) - Initializes the converter.
create_train_state(epoch) - Returns a namespace that groups start_epoch, the current epoch, and wait_before_change together.
get_best_found_loss()
get_start_epoch()
get_wait_before_change()
get_was_completed()
is_fhe_friendly(model) - Returns True if the model has no ReLU activations.
replace_activations(trainer, epoch, scheduler) - Handles the entire replacement logic, depending on the argument values and the current epoch.
set_best_found_loss_and_epoch(loss, epoch)
- create_train_state(epoch: int)#
Returns a namespace that groups start_epoch, the current epoch, and wait_before_change together
- Parameters:
epoch (int) – the current epoch
- Returns:
a namespace that groups the passed-in arguments together
- Return type:
SimpleNamespace
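The returned object is a plain `types.SimpleNamespace`, so the grouped fields are accessed as attributes. The sketch below assumes illustrative default values for start_epoch and wait_before_change; in the real class these would come from the converter's arguments.

```python
from types import SimpleNamespace

def create_train_state(epoch: int,
                       start_epoch: int = 0,
                       wait_before_change: int = 2) -> SimpleNamespace:
    """Group start_epoch, the current epoch, and wait_before_change together.

    start_epoch and wait_before_change defaults are illustrative only.
    """
    return SimpleNamespace(
        start_epoch=start_epoch,
        epoch=epoch,
        wait_before_change=wait_before_change,
    )

state = create_train_state(epoch=12)
```

Callers can then read `state.epoch`, `state.start_epoch`, and `state.wait_before_change` directly.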
- is_fhe_friendly(model)#
- Parameters:
model (nn.Module) – the model to inspect
- Returns:
True if the model contains no ReLU activations (i.e., it is FHE-friendly)
- Return type:
bool
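A model is considered FHE-friendly once no ReLU activations remain. The torch-free sketch below illustrates the idea with stand-in classes (`ReLU`, `SquareActivation`) rather than real torch.nn modules; the actual method would walk the model's module tree instead of a flat list.

```python
class ReLU:
    """Stand-in for a ReLU activation module (not FHE-friendly)."""

class SquareActivation:
    """Stand-in for a polynomial, FHE-friendly activation."""

def is_fhe_friendly(modules) -> bool:
    """Return True if no ReLU activation appears among the modules."""
    return not any(isinstance(m, ReLU) for m in modules)

friendly = is_fhe_friendly([SquareActivation(), SquareActivation()])
unfriendly = is_fhe_friendly([ReLU(), SquareActivation()])
```

In the real implementation the check would traverse nested submodules (e.g. via `named_modules()` in PyTorch), so a ReLU buried anywhere in the hierarchy still makes the model non-FHE-friendly.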