PolyActivationConverter

class PolyActivationConverter(args, model, train_state=None, chp_was_completed=None, loss=None)

Bases: object

This class helps with the FHE conversion of the model, namely the replacement of activation functions. It holds the current model conversion state. Initialization can start either from scratch or from a given point, which is useful when loading a model from a checkpoint.

Our assumption is that a model’s activation functions can only be either relu or relu_range_aware, but not both at the same time. Furthermore, we assume that the non-trainable PolyRelu activation function is only utilized as a substitute for relu_range_aware activations.
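The first assumption above can be expressed as a simple check. This is a minimal sketch, assuming the model's activation types have already been collected as a list of name strings such as "relu" and "relu_range_aware"; the helper `check_activation_mix` is hypothetical and not part of the library.

```python
from collections.abc import Iterable


def check_activation_mix(activation_names: Iterable[str]) -> None:
    """Hypothetical helper: raise if a model mixes relu and relu_range_aware.

    `activation_names` is assumed to hold one type name per activation in the
    model, e.g. ["relu", "relu", "relu"].
    """
    names = set(activation_names)
    # Both kinds present at the same time violates the stated assumption.
    if {"relu", "relu_range_aware"} <= names:
        raise ValueError(
            "model mixes relu and relu_range_aware activations; "
            "only one of the two is supported at a time"
        )


check_activation_mix(["relu", "relu"])                            # accepted
check_activation_mix(["relu_range_aware", "non_trainable_poly"])  # accepted
```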

Supported replacement flows:
  1. relu -> relu_range_aware -> non_trainable_poly

  2. relu -> weighted_relu (used for smooth_transition, usually with the trainable_poly activation)
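The two supported flows can be viewed as ordered stages. The sketch below encodes them as lists and looks up the stage that follows a given activation; the flow labels "range_aware" and "smooth_transition" and the function `next_stage` are illustrative assumptions, not the library's internal representation.

```python
from typing import Dict, List, Optional

# Hypothetical encoding of the two documented replacement flows as ordered stages.
# The flow keys are illustrative labels only.
REPLACEMENT_FLOWS: Dict[str, List[str]] = {
    "range_aware": ["relu", "relu_range_aware", "non_trainable_poly"],
    "smooth_transition": ["relu", "weighted_relu"],
}


def next_stage(flow: str, current: str) -> Optional[str]:
    """Return the activation that follows `current` in `flow`, or None at the last stage."""
    stages = REPLACEMENT_FLOWS[flow]
    i = stages.index(current)  # raises ValueError if `current` is not part of this flow
    return stages[i + 1] if i + 1 < len(stages) else None


print(next_stage("range_aware", "relu"))  # relu_range_aware
```

Keeping each flow as a list makes the "no skipping" rule explicit: relu never jumps straight to non_trainable_poly in the first flow.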

__init__(args, model, train_state=None, chp_was_completed=None, loss=None)

Initializes the PolyActivationConverter, managing the activation function replacement process for Fully Homomorphic Encryption (FHE) conversion.

Parameters:
  • args (Arguments) – A set of user-defined arguments for the conversion process.

  • model (torch.nn.Module) – The model to be converted.

  • train_state (SimpleNamespace, optional) – A namespace holding the training state when resuming from a checkpoint. Defaults to None.

  • chp_was_completed (bool, optional) – Indicates whether the conversion stored in the checkpoint was fully completed, i.e. True if there are no ReLU activations in the model. Defaults to None.

  • loss (float, optional) – The loss value at train_state.epoch. Defaults to None.
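The two initialization modes differ only in the optional arguments. A minimal sketch, assuming a hypothetical checkpoint dictionary with the keys `start_epoch`, `epoch`, `wait_before_change`, `was_completed` and `loss`, and assuming train_state carries the same fields that create_train_state groups; the helper `converter_kwargs_from_checkpoint` is not part of the library.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional


def converter_kwargs_from_checkpoint(checkpoint: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the optional PolyActivationConverter arguments (hypothetical checkpoint layout)."""
    if checkpoint is None:
        # Starting from scratch: every optional argument keeps its default of None.
        return {"train_state": None, "chp_was_completed": None, "loss": None}
    # Resuming: regroup the saved counters into the namespace passed as train_state.
    train_state = SimpleNamespace(
        start_epoch=checkpoint["start_epoch"],
        epoch=checkpoint["epoch"],
        wait_before_change=checkpoint["wait_before_change"],
    )
    return {
        "train_state": train_state,
        "chp_was_completed": checkpoint["was_completed"],
        "loss": checkpoint["loss"],  # loss recorded at train_state.epoch
    }


kwargs = converter_kwargs_from_checkpoint(
    {"start_epoch": 0, "epoch": 12, "wait_before_change": 3,
     "was_completed": False, "loss": 0.41}
)
# converter = PolyActivationConverter(args, model, **kwargs)
```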

Methods

__init__(args, model[, train_state, ...])

Initializes the PolyActivationConverter, managing the activation function replacement process for Fully Homomorphic Encryption (FHE) conversion.

create_train_state(epoch)

Returns a namespace that groups start_epoch, current epoch and wait_before_change together

get_best_found_loss()

A getter for the best_loss attribute

get_start_epoch()

A getter for the start_epoch attribute

get_wait_before_change()

A getter for the wait_before_change attribute

get_was_completed()

A getter for the was_completed attribute

is_fhe_friendly(model)

Returns True if the model has no ReLU activations

replace_activations(trainer, epoch, scheduler)

Handles the entire replacement logic, depending on the argument values and the current epoch

set_best_found_loss_and_epoch(loss, epoch)

A setter for the best_loss and best_epoch attributes

create_train_state(epoch: int)

Returns a namespace that groups start_epoch, current epoch and wait_before_change together

Parameters:

epoch (int) – the current epoch

Returns:

a namespace grouping start_epoch, the given epoch, and wait_before_change

Return type:

SimpleNamespace
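The behavior described above amounts to bundling two stored values with the caller's epoch. A minimal sketch using a hypothetical stand-in class `TrainStateSketch` that holds only start_epoch and wait_before_change; the real converter keeps more state than this.

```python
from types import SimpleNamespace


class TrainStateSketch:
    """Hypothetical stand-in holding only the state that create_train_state groups."""

    def __init__(self, start_epoch: int, wait_before_change: int) -> None:
        self.start_epoch = start_epoch
        self.wait_before_change = wait_before_change

    def create_train_state(self, epoch: int) -> SimpleNamespace:
        # Bundle the stored start epoch and wait counter with the caller's epoch.
        return SimpleNamespace(
            start_epoch=self.start_epoch,
            epoch=epoch,
            wait_before_change=self.wait_before_change,
        )


state = TrainStateSketch(start_epoch=0, wait_before_change=3).create_train_state(epoch=5)
# state.epoch == 5, state.start_epoch == 0, state.wait_before_change == 3
```

Returning a SimpleNamespace gives attribute access (state.epoch) without defining a dedicated class, and the result has the same shape as the train_state argument of __init__.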

get_best_found_loss()

A getter for the best_loss attribute

get_start_epoch()

A getter for the start_epoch attribute

get_wait_before_change()

A getter for the wait_before_change attribute

get_was_completed()

A getter for the was_completed attribute
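The getters make it straightforward to save the conversion state alongside a model checkpoint. A minimal sketch, assuming `converter` is a PolyActivationConverter instance; the helper `conversion_state_dict` and its dictionary keys are hypothetical. The resulting dictionary could, for example, be stored with `torch.save` next to the model weights.

```python
from typing import Any, Dict


def conversion_state_dict(converter: Any) -> Dict[str, Any]:
    """Collect the converter state exposed through its documented getters.

    The key names are an illustrative choice, not a library format.
    """
    return {
        "start_epoch": converter.get_start_epoch(),
        "wait_before_change": converter.get_wait_before_change(),
        "was_completed": converter.get_was_completed(),
        "best_loss": converter.get_best_found_loss(),
    }
```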

is_fhe_friendly(model)

Parameters:

model (torch.nn.Module) – the model to check

Returns:

True if the model has no ReLU activations

Return type:

bool
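A check of this kind can be sketched by walking the model's submodules. This is an assumption about the approach, not the library's implementation: it treats any submodule whose class is named "ReLU" (such as torch.nn.ReLU) as a ReLU, and it cannot see ReLUs applied functionally (torch.nn.functional.relu). With PyTorch imported, `isinstance(m, torch.nn.ReLU)` would be the idiomatic test; the real method may also count ReLU variants such as relu_range_aware.

```python
from typing import Any


def has_no_relu(model: Any) -> bool:
    """Hypothetical sketch of an is_fhe_friendly-style check.

    `model` is assumed to expose `modules()`, as torch.nn.Module does,
    yielding the model itself and all nested submodules.
    """
    # Compare by class name so the sketch runs without importing torch.
    return not any(type(m).__name__ == "ReLU" for m in model.modules())
```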

replace_activations(trainer, epoch, scheduler)

Handles the entire replacement logic, depending on the argument values and the current epoch

Parameters:
  • trainer (Trainer) – trainer instance

  • epoch (int) – current epoch number

  • scheduler (ReduceLROnPlateau) – learning-rate reduction scheduler
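Because replace_activations receives the current epoch, it is presumably called once per epoch from the training loop. The sketch below shows one such loop under stated assumptions: the call is made at the start of each epoch, and the hypothetical callback `train_one_epoch(trainer, epoch)` returns a validation loss. `scheduler.step(metric)` is how PyTorch's ReduceLROnPlateau receives the monitored value; `run_conversion` itself is not a library function.

```python
from typing import Any, Callable


def run_conversion(
    converter: Any,
    trainer: Any,
    scheduler: Any,
    num_epochs: int,
    train_one_epoch: Callable[[Any, int], float],
) -> None:
    """Hypothetical outer loop that drives replace_activations once per epoch."""
    for epoch in range(num_epochs):
        # Let the converter decide whether to replace activations at this epoch.
        converter.replace_activations(trainer, epoch, scheduler)
        val_loss = train_one_epoch(trainer, epoch)
        # ReduceLROnPlateau lowers the learning rate when this metric stops improving.
        scheduler.step(val_loss)
```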

set_best_found_loss_and_epoch(loss, epoch)

A setter for the best_loss and best_epoch attributes
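Since the setter overwrites the stored values unconditionally, a caller that wants to keep only the best loss has to compare first. A minimal sketch, assuming get_best_found_loss() returns None before any loss has been recorded; the helper `record_if_best` is hypothetical.

```python
from typing import Any


def record_if_best(converter: Any, loss: float, epoch: int) -> bool:
    """Store (loss, epoch) through the documented setter only if the loss improves.

    Assumes get_best_found_loss() returns None until a loss has been stored.
    Returns True when the stored best was updated.
    """
    best = converter.get_best_found_loss()
    if best is None or loss < best:
        converter.set_best_found_loss_and_epoch(loss, epoch)
        return True
    return False
```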