PolyActivationConverter#
- class PolyActivationConverter(args, model, train_state=None, chp_was_completed=None, loss=None)#
Bases:
object
This class assists in the FHE conversion of the model, namely the activation replacement. It holds the current model conversion state. Initialization can be either from scratch or from a given point; this can be useful when loading a model from a checkpoint.
Our assumption is that a model’s activation functions can be either relu or relu_range_aware, but not both at the same time. Furthermore, we assume that the non-trainable PolyRelu activation function is only used as a substitute for relu_range_aware activations.
- Supported replacement flows:
relu -> relu_range_aware -> non_trainable_poly
relu -> weighted_relu (for smooth_transition, usually for the trainable_poly activation)
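These flows amount to swapping activation submodules in place. Below is a minimal PyTorch sketch of such a swap; `PolyReluStandIn`, its coefficients, and `replace_relus` are illustrative stand-ins, not this library's API:

```python
import torch
import torch.nn as nn

class PolyReluStandIn(nn.Module):
    """Hypothetical stand-in for a polynomial ReLU approximation
    (coefficients are placeholders, not the library's fitted values)."""
    def __init__(self, coeffs=(0.25, 0.5, 0.125)):
        super().__init__()
        self.coeffs = coeffs  # c0 + c1*x + c2*x^2, a degree-2 fit

    def forward(self, x):
        c0, c1, c2 = self.coeffs
        return c0 + c1 * x + c2 * x * x

def replace_relus(module, new_act_factory):
    """Recursively swap every nn.ReLU child for a new activation module."""
    for name, child in module.named_children():
        if isinstance(child, nn.ReLU):
            setattr(module, name, new_act_factory())
        else:
            replace_relus(child, new_act_factory)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2), nn.ReLU())
replace_relus(model, PolyReluStandIn)
```

Reassigning through setattr re-registers the new module under the same child name, so the rest of the model is untouched.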
- __init__(args, model, train_state=None, chp_was_completed=None, loss=None)#
Initializes the PolyActivationConverter, managing the activation function replacement process for Fully Homomorphic Encryption (FHE) conversion.
- Parameters:
args (Arguments) – A set of user-defined arguments for the conversion process.
model (torch.nn.Module) – The model to be converted.
train_state (SimpleNamespace, optional) – A namespace holding the training state, if resuming from a checkpoint. Defaults to None.
chp_was_completed (bool, optional) – Indicates whether the checkpoint conversion was fully completed; True if there are no ReLU activations left in the model. Defaults to None.
loss (float, optional) – Loss value at train_state.epoch. Defaults to None.
Methods
__init__
(args, model[, train_state, ...])Initializes the PolyActivationConverter, managing the activation function replacement process for Fully Homomorphic Encryption (FHE) conversion.
create_train_state
(epoch)Returns a namespace that groups start_epoch, current epoch and wait_before_change together
get_best_found_loss
()A getter for the best_loss class attribute
get_start_epoch
()A getter for the start_epoch class attribute
get_wait_before_change
()A getter for the wait_before_change class attribute
get_was_completed
()A getter for the was_completed class attribute
is_fhe_friendly
(model)Returns True if the model has no ReLU activations
replace_activations
(trainer, epoch, scheduler)Handles the entire replacement logic, depending on the argument values and the current epoch
set_best_found_loss_and_epoch
(loss, epoch)A setter for the best_loss and best_epoch class attributes
- create_train_state(epoch: int)#
Returns a namespace that groups start_epoch, current epoch and wait_before_change together
- Parameters:
epoch (int) – the current epoch
- Returns:
a namespace that groups the passed-in arguments together
- Return type:
SimpleNamespace
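A sketch of what such a grouping might look like (a guess at the shape of the returned namespace, not the actual implementation):

```python
from types import SimpleNamespace

def create_train_state_sketch(start_epoch, epoch, wait_before_change):
    """Hypothetical sketch: bundle the three values into one namespace."""
    return SimpleNamespace(
        start_epoch=start_epoch,
        epoch=epoch,
        wait_before_change=wait_before_change,
    )

state = create_train_state_sketch(start_epoch=0, epoch=5, wait_before_change=3)
```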
- get_best_found_loss()#
A getter for the best_loss class attribute
- get_start_epoch()#
A getter for the start_epoch class attribute
- get_wait_before_change()#
A getter for the wait_before_change class attribute
- get_was_completed()#
A getter for the was_completed class attribute
- is_fhe_friendly(model)#
- Parameters:
model (nn.Module) – the model to check
- Returns:
True if the model has no ReLU activations
- Return type:
bool
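The check can be sketched as a walk over the model's submodules; `is_fhe_friendly_sketch` is a hypothetical stand-in for the method above:

```python
import torch.nn as nn

def is_fhe_friendly_sketch(model):
    """Hypothetical sketch: FHE-friendly iff no nn.ReLU module remains."""
    return not any(isinstance(m, nn.ReLU) for m in model.modules())

with_relu = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
without_relu = nn.Sequential(nn.Linear(4, 4), nn.Identity())
```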
- replace_activations(trainer, epoch, scheduler)#
Handles the entire replacement logic, depending on the argument values and the current epoch
- Parameters:
trainer (Trainer) – trainer instance
epoch (int) – current epoch number
scheduler (ReduceLROnPlateau) – Learning rate reduction scheduler
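The epoch gating implied by start_epoch and wait_before_change can be sketched as follows; `should_replace_sketch` is a hypothetical illustration of such a gate, not the method's actual logic:

```python
def should_replace_sketch(epoch, start_epoch, wait_before_change):
    """Hypothetical gate: only trigger a replacement once enough epochs
    have elapsed since the conversion started."""
    return epoch >= start_epoch + wait_before_change

# With start_epoch=0 and wait_before_change=3, replacement would begin
# no earlier than epoch 3 under this sketch.
```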
- set_best_found_loss_and_epoch(loss, epoch)#
A setter for the best_loss and best_epoch class attributes