nn.ScaledDotProductAttention¶
- class lucid.nn.ScaledDotProductAttention(attn_mask: Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: int | float | complex | None = None)¶
The ScaledDotProductAttention module encapsulates the scaled dot-product attention operation commonly used in transformer-based architectures. It supports configurable masking, dropout on the attention weights, and causal (autoregressive) attention.
Class Signature¶
class lucid.nn.ScaledDotProductAttention(
attn_mask: Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: _Scalar | None = None,
)
Parameters¶
attn_mask (Tensor | None, optional): An additive mask tensor of shape (N, H, L, S), where N is the batch size, H the number of attention heads, L the query (target) sequence length, and S the key/value (source) sequence length. The mask is added to the attention scores before the softmax, so positions to be ignored should hold large negative values such as -inf (see the sketch after this list). If None, no masking is applied. Default: None.
dropout_p (float, optional): Dropout probability applied to attention weights. Default: 0.0.
is_causal (bool, optional): If True, applies a causal mask to prevent attending to future positions. This is useful for autoregressive models. Default: False.
scale (_Scalar | None, optional): Multiplicative scaling factor applied to the dot product before the softmax. If None, the scale is set to 1 / sqrt(D), where D is the embedding dimension of the query. Default: None.
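To make the additive-mask semantics concrete, here is a minimal sketch of a padding mask of shape (N, H, L, S). It uses plain NumPy rather than lucid's API, and the shapes and the valid_lengths array are illustrative assumptions:

import numpy as np

N, H, L, S = 2, 4, 8, 8           # batch, heads, query length, key length (illustrative)
valid_lengths = np.array([8, 5])  # assume the second sequence has 3 padded key positions

mask = np.zeros((N, H, L, S), dtype=np.float32)
for n, length in enumerate(valid_lengths):
    mask[n, :, :, length:] = -np.inf  # block attention to padded key positions

The resulting array would then be wrapped in a lucid Tensor and passed as attn_mask; masked positions receive zero weight after the softmax.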
Forward Calculation¶
Given query, key, and value tensors, the module computes attention as follows:
1. Compute the scaled dot-product scores:

\[\text{Scores} = \text{scale} \cdot \left( \mathbf{Q} \mathbf{K}^\top \right)\]

With the default scale of 1 / sqrt(D), this is the standard scaled dot product.

2. Apply the attention mask if provided:

\[\text{Scores} = \text{Scores} + \text{attn\_mask}\]

3. Compute the attention weights using softmax:

\[\text{Attn Weights} = \text{softmax}(\text{Scores})\]

4. Apply dropout (if enabled):

\[\text{Attn Weights} = \text{Dropout}(\text{Attn Weights})\]

5. Compute the output:

\[\text{Output} = \text{Attn Weights} \cdot \mathbf{V}\]
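The steps above can be summarized in a short reference sketch. This is plain NumPy for illustration only, not lucid's implementation, and dropout is omitted for brevity:

import numpy as np

def scaled_dot_product_attention(q, k, v, attn_mask=None, scale=None):
    # q: (N, H, L, D), k: (N, H, S, D), v: (N, H, S, D)
    d = q.shape[-1]
    scale = scale if scale is not None else 1.0 / np.sqrt(d)
    scores = scale * (q @ k.transpose(0, 1, 3, 2))   # (N, H, L, S)
    if attn_mask is not None:
        scores = scores + attn_mask                   # additive mask
    scores -= scores.max(axis=-1, keepdims=True)      # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                # (N, H, L, D)

q = k = v = np.random.randn(2, 4, 8, 16).astype(np.float32)
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # (2, 4, 8, 16)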
Examples¶
Basic Usage
>>> import lucid.nn as nn
>>> from lucid import Tensor
>>> query = Tensor.randn(2, 4, 8, 16) # Batch=2, Heads=4, Seq_len=8, Dim=16
>>> key = Tensor.randn(2, 4, 8, 16)
>>> value = Tensor.randn(2, 4, 8, 16)
>>> attn = nn.ScaledDotProductAttention()
>>> output = attn(query, key, value)
>>> print(output.shape)
(2, 4, 8, 16)
Applying a Causal Mask
>>> attn = nn.ScaledDotProductAttention(is_causal=True)
>>> output = attn(query, key, value)
>>> print(output.shape)
(2, 4, 8, 16)
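Conceptually, is_causal=True corresponds to an additive mask that blocks attention to future positions. A minimal NumPy sketch of such a mask follows; the exact construction inside lucid is an assumption:

import numpy as np

L = S = 8
i = np.arange(L)[:, None]   # query positions
j = np.arange(S)[None, :]   # key positions
causal_mask = np.where(j > i, -np.inf, 0.0).astype(np.float32)
# Row i is 0 for j <= i and -inf for j > i, so each position can only
# attend to itself and earlier positions.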
Note
This module is a convenient building block for attention layers in transformer models.
Dropout regularization can be applied to the attention weights via dropout_p.
With is_causal=True, each query position attends only to itself and earlier positions, which enforces autoregressive behavior.