PolyActivationConverter

class PolyActivationConverter(args, model, train_state=None, chp_was_completed=None, loss=None)

Bases: object

This class helps with the FHE conversion of a model, namely the replacement of its activation functions. It holds the current model conversion state. It can be initialized either from scratch or from a given point, which is useful when loading a model from a checkpoint.

We assume that a model's activation functions are either all relu or all relu_range_aware, never a mix of both at the same time. Furthermore, we assume that the non-trainable PolyRelu activation function is only used as a substitute for relu_range_aware activations.
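The first assumption can be expressed as a simple mutual-exclusion check. The sketch below is hypothetical: it represents a model's activations as a list of name strings instead of inspecting a real model, and check_activation_assumption is an illustrative helper, not part of this class's API.

```python
from collections import Counter
from typing import Iterable

# Activation names as used in the text above. Representing a model's
# activations as plain strings is a simplification for illustration only.
RELU = "relu"
RELU_RANGE_AWARE = "relu_range_aware"


def check_activation_assumption(activation_names: Iterable[str]) -> None:
    """Hypothetical helper: raise ValueError if relu and relu_range_aware
    appear in the same model."""
    counts = Counter(activation_names)
    # Counter returns 0 for missing keys, so both tests are safe
    if counts[RELU] and counts[RELU_RANGE_AWARE]:
        raise ValueError(
            "model mixes relu and relu_range_aware activations, "
            "which the converter does not support"
        )
```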

Supported replacement flows:
  1. relu -> relu_range_aware -> non_trainable_poly

  2. relu -> weighted_relu (for smooth_transition, usually for the trainable_poly activation)
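The two flows can be read as an allowed-transition table in which non_trainable_poly is only reachable from relu_range_aware, matching the assumption above. The table and the is_supported_flow helper are illustrative assumptions, not how the converter stores its flows internally.

```python
from typing import Sequence

# Hypothetical encoding of the supported replacement flows as
# (from_activation, to_activation) pairs.
ALLOWED_TRANSITIONS = {
    ("relu", "relu_range_aware"),                # flow 1, first step
    ("relu_range_aware", "non_trainable_poly"),  # flow 1, second step
    ("relu", "weighted_relu"),                   # flow 2 (smooth transition)
}


def is_supported_flow(flow: Sequence[str]) -> bool:
    """Return True if every consecutive step in `flow` is an allowed transition."""
    return all(
        (src, dst) in ALLOWED_TRANSITIONS
        for src, dst in zip(flow, flow[1:])
    )
```

For example, relu -> non_trainable_poly is rejected because it skips the relu_range_aware step.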

__init__(args, model, train_state=None, chp_was_completed=None, loss=None)

Methods

__init__(args, model[, train_state, ...])

create_train_state(epoch)

Returns a namespace that groups start_epoch, current epoch and wait_before_change together

get_best_found_loss()

get_start_epoch()

get_wait_before_change()

get_was_completed()

is_fhe_friendly(model)

Returns True if the model has no ReLU activations

replace_activations(trainer, epoch, scheduler)

Handles the entire activation replacement logic, depending on the argument values and the current epoch

set_best_found_loss_and_epoch(loss, epoch)

create_train_state(epoch: int)

Returns a namespace that groups start_epoch, current epoch and wait_before_change together

Parameters:

epoch (int) – the current epoch

Returns:

a namespace that groups start_epoch, the passed-in epoch and wait_before_change together

Return type:

SimpleNamespace
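A minimal sketch of what create_train_state produces, assuming the converter keeps start_epoch and wait_before_change as attributes. TrainStateSketch is a hypothetical stand-in rather than the real class, and the attribute names on the returned namespace (in particular epoch for the current epoch) are assumptions.

```python
from types import SimpleNamespace


class TrainStateSketch:
    """Hypothetical stand-in holding the values that create_train_state
    groups with the current epoch; not the real PolyActivationConverter."""

    def __init__(self, start_epoch: int, wait_before_change: int) -> None:
        self.start_epoch = start_epoch
        self.wait_before_change = wait_before_change

    def create_train_state(self, epoch: int) -> SimpleNamespace:
        # Group start_epoch, the current epoch and wait_before_change into
        # one object (attribute names are assumed for illustration)
        return SimpleNamespace(
            start_epoch=self.start_epoch,
            epoch=epoch,
            wait_before_change=self.wait_before_change,
        )


state = TrainStateSketch(start_epoch=0, wait_before_change=3).create_train_state(5)
```

Because SimpleNamespace supports vars() and equality, such a state is easy to store in a checkpoint and rebuild later, which would suit passing it back through the train_state argument of __init__ when resuming.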

is_fhe_friendly(model)

Checks whether the model is FHE-friendly, i.e. whether it contains no ReLU activations.

Parameters:

model (nn.Module) – the model to check

Returns:

True if the model has no ReLU activations

Return type:

bool
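One way to implement this contract for a PyTorch model is to scan every submodule for nn.ReLU. This is a sketch, not the converter's actual implementation: it only detects ReLU modules, so ReLUs applied functionally inside forward() (for example torch.relu) would be missed. The Square module is a hypothetical polynomial replacement used only for the demo.

```python
import torch
from torch import nn


def is_fhe_friendly(model: nn.Module) -> bool:
    """Return True if no nn.ReLU module appears anywhere in `model`.

    model.modules() yields the model itself and all nested submodules,
    so ReLUs inside nested containers are found as well.
    """
    return not any(isinstance(m, nn.ReLU) for m in model.modules())


class Square(nn.Module):
    """Hypothetical polynomial activation (x ** 2), FHE-friendly by construction."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * x
```

For instance, nn.Sequential(nn.Linear(4, 4), nn.ReLU()) is reported as not FHE-friendly, while the same model with Square() in place of the ReLU is.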

replace_activations(trainer, epoch, scheduler)

Handles the entire activation replacement logic, depending on the argument values and the current epoch.

Parameters:
  • trainer (Trainer) – trainer instance

  • epoch (int) – current epoch number

  • scheduler (ReduceLROnPlateau) – learning rate reduction scheduler
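The sketch below shows one plausible place for replace_activations in a training loop: it is called once per epoch and receives the same ReduceLROnPlateau scheduler that is later stepped on the validation loss. StubConverter, train and the loop ordering are assumptions for illustration; the real converter decides internally, based on its arguments and the current epoch, whether a given epoch triggers a replacement.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


class StubConverter:
    """Hypothetical stand-in for PolyActivationConverter that only records
    the epochs at which replace_activations was called."""

    def __init__(self) -> None:
        self.calls = []

    def replace_activations(self, trainer, epoch, scheduler) -> None:
        self.calls.append(epoch)


def train(converter, trainer, scheduler, start_epoch, num_epochs, val_losses):
    """Assumed loop shape: let the converter act at the start of each epoch,
    then train, validate and step the plateau scheduler on the validation loss."""
    for epoch in range(start_epoch, num_epochs):
        converter.replace_activations(trainer, epoch, scheduler)
        # ... one epoch of training with `trainer` would run here ...
        val_loss = val_losses[epoch]  # placeholder for a real validation pass
        scheduler.step(val_loss)


# Usage with a real ReduceLROnPlateau wrapping a dummy optimizer
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=1)
converter = StubConverter()
train(converter, trainer=None, scheduler=scheduler,
      start_epoch=2, num_epochs=5, val_losses=[1.0] * 5)
```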