nn.functional.scaled_dot_product_attention¶
- lucid.nn.functional.scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: int | float | complex | None = None) -> Tensor¶
The scaled_dot_product_attention function computes scaled dot-product attention, a fundamental operation in transformer-based models.
Function Signature¶
def scaled_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_mask: Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: _Scalar | None = None,
) -> Tensor
Parameters¶
query (Tensor): The query tensor of shape (N, H, L, D), where:
- N: Batch size
- H: Number of attention heads
- L: Target sequence length
- D: Embedding dimension per head
key (Tensor): The key tensor of shape (N, H, S, D), where S is the source sequence length.
value (Tensor): The value tensor of shape (N, H, S, D), matching the key tensor.
attn_mask (Tensor | None, optional): A mask tensor of shape (N, H, L, S) that is added to the attention scores, masking out selected positions. Default: None.
dropout_p (float, optional): Dropout probability applied to attention weights. Default: 0.0.
is_causal (bool, optional): If True, applies a causal mask to prevent attending to future positions. Default: False.
scale (_Scalar | None, optional): Scaling factor multiplied with the dot-product scores before the softmax operation. If None, the scale defaults to 1 / sqrt(D). Default: None.
Returns¶
Tensor: The output tensor of shape (N, H, L, D), containing the weighted sum of values.
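The source length S of key and value may differ from the query length L (as in cross-attention); the output always follows the query's (N, H, L, D) shape. A minimal sketch, with shapes chosen arbitrarily for illustration:
>>> import lucid
>>> import lucid.nn.functional as F
>>> query = lucid.random.randn(2, 4, 8, 16)   # (N=2, H=4, L=8, D=16)
>>> key = lucid.random.randn(2, 4, 32, 16)    # (N=2, H=4, S=32, D=16)
>>> value = lucid.random.randn(2, 4, 32, 16)  # (N=2, H=4, S=32, D=16)
>>> output = F.scaled_dot_product_attention(query, key, value)
>>> print(output.shape)
(2, 4, 8, 16)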
Attention Mechanism¶
The function performs the following operations:
1. Compute the scaled dot-product scores:
\[\text{Scores} = \text{scale} \cdot \mathbf{Q} \mathbf{K}^\top\]
2. Apply the attention mask if provided:
\[\text{Scores} = \text{Scores} + \text{attn\_mask}\]
3. Compute the attention weights using softmax:
\[\text{Attn Weights} = \text{softmax}(\text{Scores})\]
4. Apply dropout (if enabled):
\[\text{Attn Weights} = \text{Dropout}(\text{Attn Weights})\]
5. Compute the output:
\[\text{Output} = \text{Attn Weights} \cdot \mathbf{V}\]
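The sequence above can be reproduced with plain NumPy. The sketch below is for illustration only: the helper name reference_sdpa is hypothetical, NumPy arrays stand in for lucid tensors, and the dropout step is skipped to keep the result deterministic.
import numpy as np

def reference_sdpa(q, k, v, attn_mask=None, scale=None):
    # q: (N, H, L, D), k/v: (N, H, S, D)
    D = q.shape[-1]
    scale = 1.0 / np.sqrt(D) if scale is None else scale

    # 1. Scaled dot-product scores: scale * Q @ K^T -> (N, H, L, S)
    scores = scale * (q @ np.swapaxes(k, -1, -2))

    # 2. Additive attention mask (if provided)
    if attn_mask is not None:
        scores = scores + attn_mask

    # 3. Softmax over the source dimension S
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # 4. Dropout would be applied to `weights` here; omitted in this sketch.

    # 5. Weighted sum of values -> (N, H, L, D)
    return weights @ v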
Examples¶
Basic Attention Computation
>>> import lucid
>>> import lucid.nn.functional as F
>>> query = lucid.random.randn(2, 4, 8, 16) # Batch=2, Heads=4, Seq_len=8, Dim=16
>>> key = lucid.random.randn(2, 4, 8, 16)
>>> value = lucid.random.randn(2, 4, 8, 16)
...
>>> output = F.scaled_dot_product_attention(query, key, value)
>>> print(output.shape)
(2, 4, 8, 16)
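Causal Attention and Custom Scale
Building on the tensors from the basic example, the documented is_causal and scale arguments can be passed directly. This is a usage sketch; the values are random and only the output shapes are deterministic.
>>> causal_out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
>>> print(causal_out.shape)
(2, 4, 8, 16)
>>> # Override the default scale of 1 / sqrt(D) = 1 / sqrt(16) = 0.25
>>> scaled_out = F.scaled_dot_product_attention(query, key, value, scale=0.25)
>>> print(scaled_out.shape)
(2, 4, 8, 16)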
Note
- The function supports multi-head attention computations.
- The optional causal mask ensures autoregressive behavior in transformers.
- Dropout is applied to the attention weights, with probability dropout_p, before the final output is computed.
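As a usage sketch of the dropout note above: a nonzero dropout_p randomly zeroes attention weights, so repeated calls may return different values while the output shape stays fixed (whether lucid additionally gates this on a training/evaluation mode is not specified here and is left as an assumption).
>>> noisy_out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.1)
>>> print(noisy_out.shape)
(2, 4, 8, 16)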