Simulai models transformer

Transformer#

Bases: NetworkTemplate

Source code in simulai/models/_pytorch_models/_transformer.py
class Transformer(NetworkTemplate):
    def __init__(
        self,
        num_heads_encoder: int = 1,
        num_heads_decoder: int = 1,
        embed_dim_encoder: Union[int, Tuple] = None,
        embed_dim_decoder: Union[int, Tuple] = None,
        output_dim: Union[int, Tuple] = None,
        encoder_activation: Union[str, torch.nn.Module] = "relu",
        decoder_activation: Union[str, torch.nn.Module] = "relu",
        encoder_mlp_layer_config: dict = None,
        decoder_mlp_layer_config: dict = None,
        number_of_encoders: int = 1,
        number_of_decoders: int = 1,
        devices: Union[str, list] = "cpu",
    ) -> None:
        r"""A classical encoder-decoder transformer:

        Graphical example:

        Example::

             U -> ( Encoder_1 -> Encoder_2 -> ... -> Encoder_N ) -> u_e

            (u_e, U) -> ( Decoder_1 -> Decoder_2 -> ... Decoder_N ) -> V

        Args:
            num_heads_encoder (int, optional): The number of heads for the self-attention layer of the encoder. (Default value = 1)
            num_heads_decoder (int, optional): The number of heads for the self-attention layer of the decoder. (Default value = 1)
            embed_dim_encoder (Union[int, Tuple], optional): The dimension of the embedding for the encoder. (Default value = None)
            embed_dim_decoder (Union[int, Tuple], optional): The dimension of the embedding for the decoder. (Default value = None)
            output_dim (Union[int, Tuple], optional): The dimension of the final output; falls back to embed_dim_decoder when None. (Default value = None)
            encoder_activation (Union[str, torch.nn.Module], optional): The activation to be used in all the encoder layers. (Default value = 'relu')
            decoder_activation (Union[str, torch.nn.Module], optional): The activation to be used in all the decoder layers. (Default value = 'relu')
            encoder_mlp_layer_config (dict, optional): A configuration dictionary used to instantiate the encoder MLP layers. (Default value = None)
            decoder_mlp_layer_config (dict, optional): A configuration dictionary used to instantiate the decoder MLP layers. (Default value = None)
            number_of_encoders (int, optional): The number of encoders to be used. (Default value = 1)
            number_of_decoders (int, optional): The number of decoders to be used. (Default value = 1)
            devices (Union[str, list], optional): The device(s) on which the model will run. (Default value = 'cpu')

        """

        super(Transformer, self).__init__()

        self.num_heads_encoder = num_heads_encoder
        self.num_heads_decoder = num_heads_decoder

        self.embed_dim_encoder = embed_dim_encoder
        self.embed_dim_decoder = embed_dim_decoder

        if output_dim is None:
            self.output_dim = embed_dim_decoder
        else:
            self.output_dim = output_dim

        self.encoder_mlp_layer_dict = encoder_mlp_layer_config
        self.decoder_mlp_layer_dict = decoder_mlp_layer_config

        self.number_of_encoders = number_of_encoders
        self.number_of_decoders = number_of_decoders

        self.encoder_activation = encoder_activation
        self.decoder_activation = decoder_activation

        self.encoder_mlp_layers_list = list()
        self.decoder_mlp_layers_list = list()

        # Determining the kind of device on which the model will be executed
        self.device = self._set_device(devices=devices)

        # Creating independent copies for the MLP layers which will be used
        # by the multiple encoders/decoders.
        for e in range(self.number_of_encoders):
            self.encoder_mlp_layers_list.append(
                DenseNetwork(**self.encoder_mlp_layer_dict)
            )

        for d in range(self.number_of_decoders):
            self.decoder_mlp_layers_list.append(
                DenseNetwork(**self.decoder_mlp_layer_dict)
            )

        # Defining the encoder architecture
        self.EncoderStage = torch.nn.Sequential(
            *[
                BasicEncoder(
                    num_heads=self.num_heads_encoder,
                    activation=self.encoder_activation,
                    mlp_layer=self.encoder_mlp_layers_list[e],
                    embed_dim=self.embed_dim_encoder,
                    device=self.device,
                )
                for e in range(self.number_of_encoders)
            ]
        )

        # Defining the decoder architecture
        self.DecoderStage = torch.nn.ModuleList(
            [
                BasicDecoder(
                    num_heads=self.num_heads_decoder,
                    activation=self.decoder_activation,
                    mlp_layer=self.decoder_mlp_layers_list[d],
                    embed_dim=self.embed_dim_decoder,
                    device=self.device,
                )
                for d in range(self.number_of_decoders)
            ]
        )

        self.weights = list()

        for e, encoder_e in enumerate(self.EncoderStage):
            self.weights += encoder_e.weights
            self.add_module(f"encoder_{e}", encoder_e)

        for d, decoder_d in enumerate(self.DecoderStage):
            self.weights += decoder_d.weights
            self.add_module(f"decoder_{d}", decoder_d)


        self.final_layer = Linear(input_size=self.embed_dim_decoder, output_size=self.output_dim)
        self.add_module("final_linear_layer", self.final_layer)

        #  Sending everything to the proper device
        self.EncoderStage = self.EncoderStage.to(self.device)
        self.DecoderStage = self.DecoderStage.to(self.device)
        self.final_layer = self.final_layer.to(self.device)

    @as_tensor
    def forward(
        self, input_data: Union[torch.Tensor, np.ndarray] = None
    ) -> torch.Tensor:
        """

        Args:
            input_data (Union[torch.Tensor, np.ndarray], optional): The input dataset. (Default value = None)

        Returns:
            torch.Tensor: The transformer output.

        """

        encoder_output = self.EncoderStage(input_data)

        current_input = input_data
        for decoder in self.DecoderStage:
            output = decoder(input_data=current_input, encoder_output=encoder_output)
            current_input = output

        # Final linear operation
        final_output = self.final_layer(current_input)

        return final_output

    def summary(self):
        """It prints a general view of the architecture."""

        print(self)

__init__(num_heads_encoder=1, num_heads_decoder=1, embed_dim_encoder=None, embed_dim_decoder=None, output_dim=None, encoder_activation='relu', decoder_activation='relu', encoder_mlp_layer_config=None, decoder_mlp_layer_config=None, number_of_encoders=1, number_of_decoders=1, devices='cpu') #

A classical encoder-decoder transformer.

Graphical example::

     U -> ( Encoder_1 -> Encoder_2 -> ... -> Encoder_N ) -> u_e

    (u_e, U) -> ( Decoder_1 -> Decoder_2 -> ... -> Decoder_N ) -> V

Here U is the raw input, u_e the encoder embedding, and V the final output; each decoder stage receives both the previous stage's output and the encoder embedding.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_heads_encoder | int | The number of heads for the self-attention layer of the encoder. | 1 |
| num_heads_decoder | int | The number of heads for the self-attention layer of the decoder. | 1 |
| embed_dim_encoder | Union[int, Tuple] | The dimension of the embedding for the encoder. | None |
| embed_dim_decoder | Union[int, Tuple] | The dimension of the embedding for the decoder. | None |
| output_dim | Union[int, Tuple] | The dimension of the final output; falls back to embed_dim_decoder when None. | None |
| encoder_activation | Union[str, Module] | The activation to be used in all the encoder layers. | 'relu' |
| decoder_activation | Union[str, Module] | The activation to be used in all the decoder layers. | 'relu' |
| encoder_mlp_layer_config | dict | A configuration dictionary used to instantiate the encoder MLP layers. | None |
| decoder_mlp_layer_config | dict | A configuration dictionary used to instantiate the decoder MLP layers. | None |
| number_of_encoders | int | The number of encoders to be used. | 1 |
| number_of_decoders | int | The number of decoders to be used. | 1 |
| devices | Union[str, list] | The device(s) on which the model will run. | 'cpu' |
Source code in simulai/models/_pytorch_models/_transformer.py
def __init__(
    self,
    num_heads_encoder: int = 1,
    num_heads_decoder: int = 1,
    embed_dim_encoder: Union[int, Tuple] = None,
    embed_dim_decoder: Union[int, Tuple] = None,
    output_dim: Union[int, Tuple] = None,
    encoder_activation: Union[str, torch.nn.Module] = "relu",
    decoder_activation: Union[str, torch.nn.Module] = "relu",
    encoder_mlp_layer_config: dict = None,
    decoder_mlp_layer_config: dict = None,
    number_of_encoders: int = 1,
    number_of_decoders: int = 1,
    devices: Union[str, list] = "cpu",
) -> None:
    r"""A classical encoder-decoder transformer:

    Graphical example:

    Example::

         U -> ( Encoder_1 -> Encoder_2 -> ... -> Encoder_N ) -> u_e

        (u_e, U) -> ( Decoder_1 -> Decoder_2 -> ... Decoder_N ) -> V

    Args:
        num_heads_encoder (int, optional): The number of heads for the self-attention layer of the encoder. (Default value = 1)
        num_heads_decoder (int, optional): The number of heads for the self-attention layer of the decoder. (Default value = 1)
        embed_dim_encoder (Union[int, Tuple], optional): The dimension of the embedding for the encoder. (Default value = None)
        embed_dim_decoder (Union[int, Tuple], optional): The dimension of the embedding for the decoder. (Default value = None)
        output_dim (Union[int, Tuple], optional): The dimension of the final output; falls back to embed_dim_decoder when None. (Default value = None)
        encoder_activation (Union[str, torch.nn.Module], optional): The activation to be used in all the encoder layers. (Default value = 'relu')
        decoder_activation (Union[str, torch.nn.Module], optional): The activation to be used in all the decoder layers. (Default value = 'relu')
        encoder_mlp_layer_config (dict, optional): A configuration dictionary used to instantiate the encoder MLP layers. (Default value = None)
        decoder_mlp_layer_config (dict, optional): A configuration dictionary used to instantiate the decoder MLP layers. (Default value = None)
        number_of_encoders (int, optional): The number of encoders to be used. (Default value = 1)
        number_of_decoders (int, optional): The number of decoders to be used. (Default value = 1)
        devices (Union[str, list], optional): The device(s) on which the model will run. (Default value = 'cpu')

    """

    super(Transformer, self).__init__()

    self.num_heads_encoder = num_heads_encoder
    self.num_heads_decoder = num_heads_decoder

    self.embed_dim_encoder = embed_dim_encoder
    self.embed_dim_decoder = embed_dim_decoder

    if output_dim is None:
        self.output_dim = embed_dim_decoder
    else:
        self.output_dim = output_dim

    self.encoder_mlp_layer_dict = encoder_mlp_layer_config
    self.decoder_mlp_layer_dict = decoder_mlp_layer_config

    self.number_of_encoders = number_of_encoders
    self.number_of_decoders = number_of_decoders

    self.encoder_activation = encoder_activation
    self.decoder_activation = decoder_activation

    self.encoder_mlp_layers_list = list()
    self.decoder_mlp_layers_list = list()

    # Determining the kind of device on which the model will be executed
    self.device = self._set_device(devices=devices)

    # Creating independent copies for the MLP layers which will be used
    # by the multiple encoders/decoders.
    for e in range(self.number_of_encoders):
        self.encoder_mlp_layers_list.append(
            DenseNetwork(**self.encoder_mlp_layer_dict)
        )

    for d in range(self.number_of_decoders):
        self.decoder_mlp_layers_list.append(
            DenseNetwork(**self.decoder_mlp_layer_dict)
        )

    # Defining the encoder architecture
    self.EncoderStage = torch.nn.Sequential(
        *[
            BasicEncoder(
                num_heads=self.num_heads_encoder,
                activation=self.encoder_activation,
                mlp_layer=self.encoder_mlp_layers_list[e],
                embed_dim=self.embed_dim_encoder,
                device=self.device,
            )
            for e in range(self.number_of_encoders)
        ]
    )

    # Defining the decoder architecture
    self.DecoderStage = torch.nn.ModuleList(
        [
            BasicDecoder(
                num_heads=self.num_heads_decoder,
                activation=self.decoder_activation,
                mlp_layer=self.decoder_mlp_layers_list[d],
                embed_dim=self.embed_dim_decoder,
                device=self.device,
            )
            for d in range(self.number_of_decoders)
        ]
    )

    self.weights = list()

    for e, encoder_e in enumerate(self.EncoderStage):
        self.weights += encoder_e.weights
        self.add_module(f"encoder_{e}", encoder_e)

    for d, decoder_d in enumerate(self.DecoderStage):
        self.weights += decoder_d.weights
        self.add_module(f"decoder_{d}", decoder_d)


    self.final_layer = Linear(input_size=self.embed_dim_decoder, output_size=self.output_dim)
    self.add_module("final_linear_layer", self.final_layer)

    #  Sending everything to the proper device
    self.EncoderStage = self.EncoderStage.to(self.device)
    self.DecoderStage = self.DecoderStage.to(self.device)
    self.final_layer = self.final_layer.to(self.device)
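A minimal instantiation sketch. The import paths and the DenseNetwork configuration keys below (`layers_units`, `activations`, `input_size`, `output_size`, `name`) are assumptions for illustration; check the simulai documentation for the exact schema `DenseNetwork` expects.

```python
from simulai.models import Transformer       # assumed import path
from simulai.regression import DenseNetwork  # assumed import path

embed_dim = 64  # must be divisible by the number of attention heads

# Hypothetical MLP configuration; __init__ builds an independent
# DenseNetwork(**config) copy per encoder/decoder, so one dict can be reused.
mlp_config = {
    "layers_units": [128, 128],  # assumed key names
    "activations": "relu",
    "input_size": embed_dim,
    "output_size": embed_dim,
    "name": "mlp_layer",
}

transformer = Transformer(
    num_heads_encoder=4,
    num_heads_decoder=4,
    embed_dim_encoder=embed_dim,
    embed_dim_decoder=embed_dim,
    output_dim=embed_dim,
    encoder_mlp_layer_config=mlp_config,
    decoder_mlp_layer_config=mlp_config,
    number_of_encoders=2,
    number_of_decoders=2,
    devices="cpu",
)
```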

forward(input_data=None) #

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_data | Union[Tensor, ndarray] | The input dataset. | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The transformer output. |
Source code in simulai/models/_pytorch_models/_transformer.py
@as_tensor
def forward(
    self, input_data: Union[torch.Tensor, np.ndarray] = None
) -> torch.Tensor:
    """

    Args:
        input_data (Union[torch.Tensor, np.ndarray], optional): The input dataset. (Default value = None)

    Returns:
        torch.Tensor: The transformer output.

    """

    encoder_output = self.EncoderStage(input_data)

    current_input = input_data
    for decoder in self.DecoderStage:
        output = decoder(input_data=current_input, encoder_output=encoder_output)
        current_input = output

    # Final linear operation
    final_output = self.final_layer(current_input)

    return final_output
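
A usage sketch for `forward`, reusing the hypothetical `transformer` built in the instantiation sketch above; the trailing input dimension is assumed to match `embed_dim_encoder`:

```python
import numpy as np

# 100 hypothetical samples, each with embed_dim_encoder = 64 features.
U = np.random.rand(100, 64)

# The @as_tensor decorator converts NumPy input to a torch.Tensor,
# so both arrays and tensors are accepted.
V = transformer.forward(input_data=U)
print(V.shape)  # expected: (100, 64), since output_dim = 64 here
```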

summary() #

It prints a general view of the architecture.

Source code in simulai/models/_pytorch_models/_transformer.py
def summary(self):
    """It prints a general view of the architecture."""

    print(self)
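
Since `summary` simply delegates to `print(self)`, calling it lists the registered submodules (`encoder_0`, ..., `decoder_0`, ..., `final_linear_layer`):

```python
transformer.summary()  # prints the PyTorch module tree of the instance above
```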