multi_head_attention_forward
multi_head_attention_forward(query: Tensor, key: Tensor, value: Tensor, embed_dim_to_check: int, num_heads: int, in_proj_weight: Tensor | None = None, in_proj_bias: Tensor | None = None, bias_k: Tensor | None = None, bias_v: Tensor | None = None, add_zero_attn: bool = False, dropout_p: float = 0.0, out_proj_weight: Tensor | None = None, out_proj_bias: Tensor | None = None, training: bool = True, key_padding_mask: Tensor | None = None, need_weights: bool = True, attn_mask: Tensor | None = None, use_separate_proj_weight: bool = False, q_proj_weight: Tensor | None = None, k_proj_weight: Tensor | None = None, v_proj_weight: Tensor | None = None, static_k: Tensor | None = None, static_v: Tensor | None = None, average_attn_weights: bool = True, is_causal: bool = False) → (Tensor, Tensor or None)
Stateless functional multi-head attention forward pass.
Performs the full Vaswani et al. (2017) multi-head attention computation as a single pure function — useful for porting code that builds attention layers without instantiating a stateful module.
The computation has four stages:
- Input projection. Linear projections produce $Q$, $K$, $V$ (when use_separate_proj_weight=False, these come from a single fused in_proj_weight).
- Head split. Reshape each of Q, K, V from (L, B, d_model) to (L, B·H, d_model/H) so each of the $H$ heads attends independently.
- Scaled dot-product attention. $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}} + M\right) V$, with $d_k = d_{\text{model}} / H$, where $M$ is the union of the optional attn_mask, key_padding_mask, and the implicit causal mask when is_causal=True.
- Output projection. Concatenate heads and apply the output linear projection.
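The four stages above can be sketched in plain NumPy. This is an illustrative reference only, not Lucid's implementation: `softmax` and `mha_reference` are hypothetical names, and the sketch assumes a PyTorch-style weight layout (outputs computed as `x @ W.T`, with the fused weight stacking the Q, K, V blocks row-wise). Biases, dropout, `bias_k` / `bias_v`, and weight averaging are left out.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the maximum along the axis for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mha_reference(query, key, value, num_heads,
                  in_proj_weight, out_proj_weight, mask=None):
    """Hypothetical NumPy reference for the four stages (no biases, no dropout)."""
    L, N, E = query.shape
    S = key.shape[0]
    d = E // num_heads  # per-head width, d_model / H

    # 1. Input projection: slice the fused (3E, E) weight (assumed Q, K, V order).
    w_q = in_proj_weight[:E]
    w_k = in_proj_weight[E:2 * E]
    w_v = in_proj_weight[2 * E:]
    q = query @ w_q.T  # (L, N, E)
    k = key @ w_k.T    # (S, N, E)
    v = value @ w_v.T  # (S, N, E)

    # 2. Head split: (L, N, E) -> (L, N*H, d), then batch-first (N*H, L, d).
    q = q.reshape(L, N * num_heads, d).transpose(1, 0, 2)
    k = k.reshape(S, N * num_heads, d).transpose(1, 0, 2)
    v = v.reshape(S, N * num_heads, d).transpose(1, 0, 2)

    # 3. Scaled dot-product attention with an optional additive mask M.
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)  # (N*H, L, S)
    if mask is not None:
        scores = scores + mask
    weights = softmax(scores, axis=-1)
    heads = weights @ v  # (N*H, L, d)

    # 4. Output projection: concatenate heads back to (L, N, E), then project.
    concat = heads.transpose(1, 0, 2).reshape(L, N, E)
    out = concat @ out_proj_weight.T
    return out, weights.reshape(N, num_heads, L, S)

rng = np.random.default_rng(0)
L, N, E, H = 5, 2, 8, 2
x = rng.standard_normal((L, N, E))
w_in = 0.1 * rng.standard_normal((3 * E, E))
w_out = 0.1 * rng.standard_normal((E, E))
out, per_head = mha_reference(x, x, x, H, w_in, w_out)
```

Here `out` has shape `(L, N, E)` and `per_head` has shape `(N, H, L, S)`, matching the shapes listed under Returns for the unaveraged case.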
Parameters
query, key, value : Tensor
    Inputs of shape (L, N, E) for query and (S, N, E) for key and value (target or source length × batch × embed_dim). Self-attention uses the same tensor for all three; cross-attention takes a separate query.
embed_dim_to_check : int
    Sanity check; must equal query.shape[-1].
num_heads : int
    Number of attention heads $H$. Must divide embed_dim.
in_proj_weight, in_proj_bias : Tensor = None
    Fused QKV projection. in_proj_weight has shape (3·embed_dim, embed_dim).
bias_k, bias_v : Tensor = None
    Optional learned bias vectors appended to K / V along the sequence axis.
add_zero_attn : bool = False
    If True, append a row of zeros to K / V. Not yet supported in Lucid (will raise NotImplementedError).
dropout_p : float = 0.0
    Dropout probability applied to attention weights during training.
out_proj_weight, out_proj_bias : Tensor = None
    Weight and bias of the output linear projection.
training : bool = True
    If True, dropout is active.
key_padding_mask : Tensor = None
    Boolean mask of shape (N, S) with True at padded positions to be ignored.
need_weights : bool = True
    If True, also return the attention weight matrix.
attn_mask : Tensor = None
    Additional additive mask, e.g. a causal triangular mask.
use_separate_proj_weight : bool = False
    Use q_proj_weight / k_proj_weight / v_proj_weight instead of a fused in_proj_weight. Not yet supported.
q_proj_weight, k_proj_weight, v_proj_weight : Tensor = None
    Separate Q / K / V projection weights (unsupported).
static_k, static_v : Tensor = None
    Precomputed K / V to attend over (unsupported).
average_attn_weights : bool = True
    If True, the returned weight tensor is averaged across heads; otherwise per-head weights are returned.
is_causal : bool = False
    Apply an upper-triangular causal mask. Mutually exclusive with a user-supplied attn_mask.
Returns
(Tensor, Tensor or None)
- The attention output of shape (L, N, E).
- The attention weights of shape (N, L, S) (or (N, H, L, S) if average_attn_weights=False), or None if need_weights=False.
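To make the two weight shapes concrete, the following NumPy sketch shows how per-head weights of shape (N, H, L, S) reduce to (N, L, S) when averaged over the head axis. The arrays are random stand-ins and the variable names are hypothetical; it illustrates the shape relationship only and assumes a plain arithmetic mean over heads.

```python
import numpy as np

N, H, L, S = 2, 8, 16, 16
rng = np.random.default_rng(0)

# Stand-in per-head attention weights: softmax over the last (source) axis.
logits = rng.standard_normal((N, H, L, S))
exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
per_head = exp / exp.sum(axis=-1, keepdims=True)  # (N, H, L, S)

# Averaging over the head axis gives one (L, S) map per batch element.
averaged = per_head.mean(axis=1)  # (N, L, S)
```

Because each per-head row sums to 1, each row of the averaged map also sums to 1.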
Notes
This functional is implemented by binding the supplied weight
tensors onto a transient MultiheadAttention module so the
forward behaviour is bit-identical to the module path.
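The mask described in the scaled dot-product stage merges up to three sources: attn_mask, key_padding_mask, and the causal mask implied by is_causal=True. A minimal NumPy sketch of one way to merge them into a single additive array follows. `build_additive_mask` is a hypothetical helper, not a Lucid function, and it assumes an additive (L, S) attn_mask and a boolean (N, S) key_padding_mask as documented under Parameters.

```python
import numpy as np

def build_additive_mask(L, S, N, num_heads,
                        attn_mask=None, key_padding_mask=None, is_causal=False):
    """Hypothetical helper: merge the masks into one additive (N*H, L, S) array."""
    if is_causal and attn_mask is not None:
        raise ValueError("is_causal and attn_mask are mutually exclusive")

    mask = np.zeros((N, num_heads, L, S))
    if is_causal:
        # Entries above the diagonal are future positions and get -inf.
        future = np.triu(np.ones((L, S), dtype=bool), k=1)
        mask = mask + np.where(future, -np.inf, 0.0)
    if attn_mask is not None:
        # An additive (L, S) mask broadcasts over batch and heads.
        mask = mask + attn_mask
    if key_padding_mask is not None:
        # True marks padded keys; broadcast (N, S) to (N, 1, 1, S).
        pad = key_padding_mask[:, None, None, :]
        mask = mask + np.where(pad, -np.inf, 0.0)
    return mask.reshape(N * num_heads, L, S)

# Batch element 0 has its last key padded; batch element 1 has no padding.
kpm = np.array([[False, False, True],
                [False, False, False]])
m = build_additive_mask(3, 3, N=2, num_heads=2,
                        key_padding_mask=kpm, is_causal=True)
```

Adding the result to the attention scores before the softmax drives masked positions to zero weight. A row in which every key is masked would produce NaN after the softmax, so callers of a helper like this need at least one visible key per query.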
Examples
>>> import lucid
>>> from lucid.nn.functional import multi_head_attention_forward
>>> L, N, E, H = 16, 2, 64, 8
>>> q = k = v = lucid.randn(L, N, E)
>>> Wi = lucid.randn(3 * E, E)
>>> Wo = lucid.randn(E, E)
>>> out, attn = multi_head_attention_forward(
... q, k, v, embed_dim_to_check=E, num_heads=H,
... in_proj_weight=Wi, out_proj_weight=Wo,
... need_weights=False,
... )
>>> out.shape
(16, 2, 64)