scaled_dot_product_attention

scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None) -> Tensor

Scaled dot-product attention — the core of every Transformer block.

Computes the attention-weighted aggregation of value vectors using query-key dot products as similarity scores:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}} + M\right) V$$

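where $M$ is an optional additive mask (-inf to disallow attention at a position, 0 to allow it). The $1/\sqrt{d_k}$ factor keeps the softmax in a usable temperature regime as the head dimension $d_k$ grows. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance $d_k$, so unscaled scores grow with the head dimension. Large scores saturate the softmax into a near one-hot distribution whose gradients vanish; dividing by $\sqrt{d_k}$ restores unit variance.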
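The formula can be checked against a small NumPy sketch. This is an illustrative reference under stated assumptions, not Lucid's implementation: the helper name sdpa_reference is hypothetical, dropout is omitted, and the softmax subtracts the row-wise maximum before exponentiating (the usual numerically stable form). The shapes match the Examples section below.

```python
import numpy as np


def sdpa_reference(q, k, v, mask=None, scale=None):
    """Hypothetical NumPy reference for softmax(q @ k^T * scale + mask) @ v.

    q: (..., T, E), k: (..., S, E), v: (..., S, V); mask broadcasts to (..., T, S).
    Dropout is intentionally omitted. A row masked entirely with -inf would
    produce NaN here, since its maximum is -inf.
    """
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])           # default 1/sqrt(d_k), d_k = E
    scores = (q @ np.swapaxes(k, -1, -2)) * scale    # (..., T, S)
    if mask is not None:
        scores = scores + mask                       # additive mask: 0 keeps, -inf drops
    # Subtract the row max so exp() cannot overflow; the softmax result is unchanged.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights                      # (..., T, V), (..., T, S)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 8, 16, 64))   # (B, H, T, E)
k = rng.standard_normal((2, 8, 16, 64))   # (B, H, S, E)
v = rng.standard_normal((2, 8, 16, 64))   # (B, H, S, V)
out, w = sdpa_reference(q, k, v)
print(out.shape)                          # (2, 8, 16, 64)
print(np.allclose(w.sum(axis=-1), 1.0))   # True: each row of weights is a distribution
```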

Parameters

query: Tensor
Shape (B, H, T, E): B batch size, H number of heads, T query positions, E head dimension $d_k$.
key: Tensor
Shape (B, H, S, E), where S is the number of key/value positions.
value: Tensor
Shape (B, H, S, V). The value dimension V may differ from E.
attn_mask: Tensor | None = None
Additive mask broadcast-compatible with (B, H, T, S): 0 at positions to keep, a large negative value (or -inf) at positions to mask out. Mutually exclusive with is_causal.
dropout_p: float = 0.0
Dropout probability applied to the attention weights during training.
is_causal: bool = False
If True, apply a causal mask that blocks the strict upper triangle of the (T, S) score matrix, so each query position attends only to keys at the same or earlier positions (autoregressive decoder self-attention).
scale: float | None = None
Multiplier applied to $QK^\top$ before the mask and softmax. If None, the default $1/\sqrt{d_k}$ (with $d_k = E$) is used.
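The causal mask described for is_causal can be written out explicitly as an additive mask. The sketch below uses NumPy, not Lucid; the helper name causal_additive_mask is hypothetical, and it assumes the square self-attention case T == S, where the mask is -inf strictly above the diagonal and 0 elsewhere. A (T, S) mask like this broadcasts over the batch and head axes of (B, H, T, S). Since attn_mask and is_causal are mutually exclusive, one would pass either such a mask or is_causal=True, not both.

```python
import numpy as np


def causal_additive_mask(t, s):
    """Hypothetical helper: 0 where key j <= query i, -inf where j > i (the future)."""
    future = np.triu(np.ones((t, s), dtype=bool), k=1)  # strict upper triangle
    return np.where(future, -np.inf, 0.0)


mask = causal_additive_mask(4, 4)
print(mask)

# With all-zero scores, the masked softmax spreads weight evenly over allowed keys.
scores = np.zeros((2, 3, 4, 4)) + mask        # (T, S) mask broadcasts over (B, H)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
# Row i puts weight 1/(i+1) on keys 0..i and 0 on later keys.
print(weights[0, 0])
```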

Returns

Tensor

Attention output of shape (B, H, T, V).
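The shape contract can be illustrated with cross-attention, where queries and keys come from sequences of different lengths (T != S) and the value width differs from the head dimension (V != E). This NumPy sketch with np.einsum uses arbitrarily chosen sizes and is not Lucid code; only the shape bookkeeping is the point.

```python
import numpy as np

B, H, T, S, E, V = 2, 4, 5, 7, 16, 32   # illustrative sizes: T != S, V != E
rng = np.random.default_rng(0)
q = rng.standard_normal((B, H, T, E))
k = rng.standard_normal((B, H, S, E))
v = rng.standard_normal((B, H, S, V))

scores = np.einsum("bhte,bhse->bhts", q, k) / np.sqrt(E)   # (B, H, T, S)
scores -= scores.max(axis=-1, keepdims=True)                # stable softmax
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)
out = np.einsum("bhts,bhsv->bhtv", weights, v)              # (B, H, T, V)
print(scores.shape, out.shape)   # (2, 4, 5, 7) (2, 4, 5, 32)
```

The key length S is contracted away, so the output keeps the query length T from query and the width V from value.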

Notes

Introduced in "Attention Is All You Need" (Vaswani et al., 2017). The implementation computes the softmax in log-sum-exp form for numerical stability under aggressive masking, and applies the scale to the score matrix before the softmax. Under a causal mask, each position's output depends only on earlier positions, so the keys and values computed for past tokens never change; storing them in a key/value cache enables efficient autoregressive decoding.
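The interaction between causal masking and a key/value cache can be checked numerically: decoding one position at a time, with each new query attending to all cached keys and values, reproduces the corresponding rows of full causal attention. The NumPy sketch below assumes a single head and no batch axis for brevity; attend is a hypothetical helper, and none of this is Lucid's cache API.

```python
import numpy as np


def attend(q, k, v, mask=None):
    """Hypothetical single-head attention: q (T, E), k (S, E), v (S, V)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return (w / w.sum(axis=-1, keepdims=True)) @ v


rng = np.random.default_rng(0)
T, E = 6, 8
q, k, v = (rng.standard_normal((T, E)) for _ in range(3))

# Whole sequence at once with a causal mask (-inf strictly above the diagonal).
causal = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, 0.0)
full = attend(q, k, v, causal)

# Incremental decoding: append step t's key/value to the cache, then attend
# the single new query to everything cached so far. No mask is needed here,
# because the cache never contains future positions.
k_cache, v_cache, steps = [], [], []
for t in range(T):
    k_cache.append(k[t])
    v_cache.append(v[t])
    steps.append(attend(q[t:t + 1], np.stack(k_cache), np.stack(v_cache))[0])
incremental = np.stack(steps)

print(np.allclose(full, incremental))   # True
```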

Examples

>>> import lucid
>>> from lucid.nn.functional import scaled_dot_product_attention
>>> q = lucid.randn(2, 8, 16, 64)          # (B, H, T, E)
>>> k = lucid.randn(2, 8, 16, 64)
>>> v = lucid.randn(2, 8, 16, 64)
>>> out = scaled_dot_product_attention(q, k, v, is_causal=True)
>>> out.shape
(2, 8, 16, 64)