nn.TransformerDecoder¶
- class lucid.nn.TransformerDecoder(decoder_layer: TransformerDecoderLayer | Module, num_layers: int, norm: Module | None = None)¶
Overview¶
The TransformerDecoder module stacks multiple TransformerDecoderLayer instances to form a complete Transformer decoder. It processes the target input sequentially through the decoder layers, with each layer attending to the encoder's memory output. An optional layer normalization can be applied to the final output.
Class Signature¶
class lucid.nn.TransformerDecoder(
    decoder_layer: TransformerDecoderLayer | nn.Module,
    num_layers: int,
    norm: nn.Module | None = None,
)
Parameters¶
decoder_layer (TransformerDecoderLayer | nn.Module): A single TransformerDecoderLayer instance that will be replicated num_layers times.
num_layers (int): The number of decoder layers in the stack.
norm (nn.Module | None, optional): An optional layer normalization module applied to the final output. Default is None.
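For example, a final normalization can be attached when constructing the stack. A minimal sketch, assuming lucid provides an nn.LayerNorm module that takes the feature dimension (that module is not documented on this page):
import lucid.nn as nn
decoder_layer = nn.TransformerDecoderLayer(d_model=512, num_heads=8)
# Apply a final normalization after the last decoder layer
# (assumes nn.LayerNorm(512) exists in lucid; adjust to the actual API)
transformer_decoder = nn.TransformerDecoder(
    decoder_layer,
    num_layers=6,
    norm=nn.LayerNorm(512),
)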
Forward Method¶
def forward(
    tgt: Tensor,
    memory: Tensor,
    tgt_mask: Tensor | None = None,
    mem_mask: Tensor | None = None,
    tgt_key_padding_mask: Tensor | None = None,
    mem_key_padding_mask: Tensor | None = None,
    tgt_is_causal: bool = False,
    mem_is_causal: bool = False,
) -> Tensor
Computes the forward pass of the Transformer decoder.
Inputs:
tgt (Tensor): The target input tensor of shape \((N, L_t, d_{model})\).
memory (Tensor): The encoder output tensor of shape \((N, L_m, d_{model})\).
tgt_mask (Tensor | None, optional): A mask of shape \((L_t, L_t)\) applied to self-attention weights. Default is None.
mem_mask (Tensor | None, optional): A mask of shape \((L_t, L_m)\) applied to cross-attention weights. Default is None.
tgt_key_padding_mask (Tensor | None, optional): A mask of shape \((N, L_t)\), where non-zero values indicate positions that should be ignored. Default is None.
mem_key_padding_mask (Tensor | None, optional): A mask of shape \((N, L_m)\), where non-zero values indicate positions that should be ignored. Default is None.
tgt_is_causal (bool, optional): If True, enforces a lower-triangular (causal) mask in self-attention (see the example below). Default is False.
mem_is_causal (bool, optional): If True, enforces a lower-triangular (causal) mask in cross-attention. Default is False.
Output:
Tensor: The output tensor of shape \((N, L_t, d_{model})\).
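For autoregressive decoding, self-attention over the target is typically causal. A minimal sketch of such a call, using only the arguments documented above:
import lucid
import lucid.nn as nn
decoder_layer = nn.TransformerDecoderLayer(d_model=512, num_heads=8)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
tgt = lucid.random.randn(16, 10, 512)     # (N, L_t, d_model)
memory = lucid.random.randn(16, 20, 512)  # (N, L_m, d_model)
# Enforce a lower-triangular mask in self-attention; cross-attention stays unmasked
output = decoder(tgt, memory, tgt_is_causal=True)
print(output.shape)  # (16, 10, 512)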
Mathematical Details¶
The Transformer decoder processes input through a sequence of decoder layers as follows:
Iterative Decoding
The target tensor \(T\) is passed through num_layers decoder layers while attending to the encoder memory:
\[T_0 = T\]
\[T_{i+1} = \operatorname{DecoderLayer}(T_i, M), \quad \forall i \in [0, \text{num\_layers}-1]\]
where \(M\) represents the memory from the encoder.
Optional Normalization
If norm is provided, it is applied to the final output:
\[Y = \operatorname{LayerNorm}(T_{\text{num\_layers}})\]
Otherwise, the final decoder layer output is returned.
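Putting both steps together, the decoder stack behaves like the following loop. This is a conceptual sketch of the equations above, not the library's actual implementation; layers stands for the num_layers copies of decoder_layer:
# T: target tensor, memory: encoder output, layers: num_layers copies of decoder_layer
T = tgt
for layer in layers:
    T = layer(T, memory)  # each layer attends to the same encoder memory
if norm is not None:
    T = norm(T)           # optional final normalization
output = T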
Usage Example¶
import lucid
import lucid.nn as nn
# Create a decoder layer
decoder_layer = nn.TransformerDecoderLayer(d_model=512, num_heads=8)
# Stack multiple decoder layers into a Transformer decoder
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
# Create random input tensors
tgt = lucid.random.randn(16, 10, 512) # (batch, seq_len, embed_dim)
memory = lucid.random.randn(16, 20, 512) # Encoder output
# Compute decoder output
output = transformer_decoder(tgt, memory)
print(output.shape) # Expected output: (16, 10, 512)