fn

multi_head_attention_forward

(Tensor, Tensor or None)
multi_head_attention_forward(query: Tensor, key: Tensor, value: Tensor, embed_dim_to_check: int, num_heads: int, in_proj_weight: Tensor | None = None, in_proj_bias: Tensor | None = None, bias_k: Tensor | None = None, bias_v: Tensor | None = None, add_zero_attn: bool = False, dropout_p: float = 0.0, out_proj_weight: Tensor | None = None, out_proj_bias: Tensor | None = None, training: bool = True, key_padding_mask: Tensor | None = None, need_weights: bool = True, attn_mask: Tensor | None = None, use_separate_proj_weight: bool = False, q_proj_weight: Tensor | None = None, k_proj_weight: Tensor | None = None, v_proj_weight: Tensor | None = None, static_k: Tensor | None = None, static_v: Tensor | None = None, average_attn_weights: bool = True, is_causal: bool = False)

Stateless functional multi-head attention forward pass.

Performs the full Vaswani et al. (2017) multi-head attention computation as a single pure function. This is useful when porting code that builds attention layers without instantiating a stateful module.

The computation has four stages:

  1. Input projection. Linear projections produce Q = X_q W_Q, K = X_k W_K, V = X_v W_V (when use_separate_proj_weight=False, all three come from a single fused in_proj_weight).
  2. Head split. Reshape each of Q, K, V from (L, B, d_model) to (L, B·H, d_model/H) so each of the H heads attends independently.
  3. Scaled dot-product attention.
\mathrm{Attn} = \mathrm{softmax}\!\left( \frac{Q K^\top}{\sqrt{d_k}} + M \right) V

where M combines the optional attn_mask, the key_padding_mask, and the implicit causal mask applied when is_causal=True.

  4. Output projection. Concatenate heads and apply \mathrm{out} = \mathrm{Attn}_{\text{concat}} W_O.
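The four stages can be sketched in plain NumPy as follows. This is an illustrative reference, not Lucid's implementation: the function name mha_reference is hypothetical, it assumes the fused weight stacks the Q, K, V blocks in that order and that weights are applied as x @ W.T (the PyTorch convention), and it omits biases and dropout.

```python
import numpy as np

def mha_reference(x_q, x_k, x_v, w_in, w_out, num_heads, mask=None):
    """Illustrative sketch of the four stages (hypothetical helper).

    x_q: (L, B, E); x_k, x_v: (S, B, E); w_in: (3E, E); w_out: (E, E).
    mask: optional additive mask broadcastable to (B*H, L, S).
    """
    L, B, E = x_q.shape
    S = x_k.shape[0]
    d_k = E // num_heads

    # 1. Input projection: slice the fused weight into W_Q, W_K, W_V
    #    (Q, K, V ordering is an assumption).
    w_q, w_k, w_v = w_in[:E], w_in[E:2 * E], w_in[2 * E:]
    q, k, v = x_q @ w_q.T, x_k @ w_k.T, x_v @ w_v.T

    # 2. Head split: (T, B, E) -> (T, B*H, d_k), then batch-first for matmul.
    def split(t):
        return t.reshape(t.shape[0], B * num_heads, d_k).transpose(1, 0, 2)
    q, k, v = split(q), split(k), split(v)

    # 3. Scaled dot-product attention with a numerically stable softmax.
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)  # (B*H, L, S)
    if mask is not None:
        scores = scores + mask
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    attn = weights @ v  # (B*H, L, d_k)

    # 4. Output projection: concatenate heads back to (L, B, E), apply W_O.
    concat = attn.transpose(1, 0, 2).reshape(L, B, E)
    out = concat @ w_out.T
    return out, weights.reshape(B, num_heads, L, S)

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 2, 8))       # L=5, B=2, E=8
w_in = rng.standard_normal((24, 8))      # fused (3E, E)
w_out = rng.standard_normal((8, 8))
out, weights = mha_reference(x, x, x, w_in, w_out, num_heads=2)
```

Here out has shape (5, 2, 8) and weights has shape (2, 2, 5, 5), with each row of weights summing to 1.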

Parameters

query: Tensor
Query input of shape (L, N, E): target sequence length × batch × embed_dim. In self-attention, query, key, and value are the same tensor; in cross-attention, query is a separate tensor.
key: Tensor
Key input of shape (S, N, E): source sequence length × batch × embed_dim.
value: Tensor
Value input of shape (S, N, E): source sequence length × batch × embed_dim.
embed_dim_to_check: int
Sanity check: must equal query.shape[-1].
num_heads: int
Number of attention heads H. Must divide embed_dim.
in_proj_weight: Tensor = None
Fused QKV projection weight of shape (3·embed_dim, embed_dim).
in_proj_bias: Tensor = None
Bias for the fused QKV projection.
bias_k: Tensor = None
Optional learned bias vector appended to K along the sequence axis.
bias_v: Tensor = None
Optional learned bias vector appended to V along the sequence axis.
add_zero_attn: bool = False
If True, append a row of zeros to K / V. Not yet supported in Lucid (raises NotImplementedError).
dropout_p: float = 0.0
Dropout probability applied to attention weights during training.
out_proj_weight: Tensor = None
Weight W_O of the output linear projection.
out_proj_bias: Tensor = None
Bias of the output linear projection.
training: bool = True
If True, dropout is active.
key_padding_mask: Tensor = None
Boolean mask of shape (N, S) with True at padded positions to be ignored.
need_weights: bool = True
If True, also return the attention weight matrix.
attn_mask: Tensor = None
Additive mask, e.g. a causal triangular mask.
use_separate_proj_weight: bool = False
Use q_proj_weight / k_proj_weight / v_proj_weight instead of a fused in_proj_weight. Not yet supported.
q_proj_weight: Tensor = None
Separate Q projection weight (unsupported).
k_proj_weight: Tensor = None
Separate K projection weight (unsupported).
v_proj_weight: Tensor = None
Separate V projection weight (unsupported).
static_k: Tensor = None
Precomputed K to attend over (unsupported).
static_v: Tensor = None
Precomputed V to attend over (unsupported).
average_attn_weights: bool = True
If True, the returned weight tensor is averaged across heads; otherwise per-head weights are returned.
is_causal: bool = False
Apply an upper-triangular causal mask. Mutually exclusive with a user-supplied attn_mask.
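To make the three mask arguments concrete, the NumPy sketch below folds them into a single additive mask, where 0 keeps a position and -inf removes it from the softmax. The helper build_additive_mask is hypothetical and not part of Lucid, and raising ValueError when both is_causal and attn_mask are given is an assumption about how the mutual exclusion might be enforced.

```python
import numpy as np

def build_additive_mask(L, S, key_padding_mask=None, attn_mask=None, is_causal=False):
    """Hypothetical helper: fold the mask arguments into one additive mask.

    Returns an array broadcastable to (N, L, S).
    """
    if is_causal and attn_mask is not None:
        # The parameter docs call these mutually exclusive; raising is an assumption.
        raise ValueError("is_causal and attn_mask are mutually exclusive")

    m = np.zeros((1, L, S))
    if attn_mask is not None:
        m = m + attn_mask  # additive (L, S) mask
    if is_causal:
        # -inf strictly above the diagonal: query i cannot attend to key j > i.
        m = m + np.triu(np.full((L, S), -np.inf), k=1)
    if key_padding_mask is not None:
        # Boolean (N, S), True = padded: mask that key for every query row.
        pad = np.where(key_padding_mask, -np.inf, 0.0)
        m = m + pad[:, None, :]
    return m

kpm = np.array([[False, False, True]])  # one sequence whose last key is padding
M = build_additive_mask(2, 3, key_padding_mask=kpm, is_causal=True)
```

In this example M has shape (1, 2, 3); the first query row can only see key 0, and the second can see keys 0 and 1 because key 2 is padding.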

Returns

(Tensor, Tensor or None)
  • The attention output of shape (L, N, E).
  • The attention weights of shape (N, L, S) (or (N, H, L, S) if average_attn_weights=False), or None if need_weights=False.
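The relationship between the two weight shapes can be illustrated with NumPy. This sketch assumes that "averaged across heads" means an arithmetic mean over the head axis; the random weights are placeholders, not output from Lucid.

```python
import numpy as np

N, H, L, S = 2, 8, 16, 16
rng = np.random.default_rng(0)

# Placeholder per-head attention weights: softmax over the source axis.
logits = rng.standard_normal((N, H, L, S))
per_head = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# average_attn_weights=True corresponds to reducing the head axis (assumed mean).
averaged = per_head.mean(axis=1)

print(per_head.shape)  # (2, 8, 16, 16)
print(averaged.shape)  # (2, 16, 16)
```

Because each per-head row sums to 1, the averaged rows also sum to 1, so the averaged tensor is still a valid attention distribution over source positions.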

Notes

This functional is implemented by binding the supplied weight tensors onto a transient MultiheadAttention module so the forward behaviour is bit-identical to the module path.

Examples

>>> import lucid
>>> from lucid.nn.functional import multi_head_attention_forward
>>> L, N, E, H = 16, 2, 64, 8
>>> q = k = v = lucid.randn(L, N, E)
>>> Wi = lucid.randn(3 * E, E)
>>> Wo = lucid.randn(E, E)
>>> out, attn = multi_head_attention_forward(
...     q, k, v, embed_dim_to_check=E, num_heads=H,
...     in_proj_weight=Wi, out_proj_weight=Wo,
...     need_weights=False,
... )
>>> out.shape
(16, 2, 64)