MultiheadAttention
Module: MultiheadAttention(embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, add_bias_kv: bool = False, add_zero_attn: bool = False, kdim: int | None = None, vdim: int | None = None, batch_first: bool = False, device: DeviceLike = None, dtype: DTypeLike = None)
Multi-head scaled dot-product attention.
Implements the multi-head attention mechanism introduced in "Attention Is All You Need" (Vaswani et al., 2017). Each head independently computes scaled dot-product attention over a learned linear projection of the query, key and value inputs; the per-head outputs are then concatenated and projected once more to produce the final result.
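Scaled dot-product attention for a single head:
$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$
where $d_k$ is the per-head key dimension (head_dim). The $1/\sqrt{d_k}$ scaling prevents the dot-products from growing so large that the softmax function is pushed into regions with extremely small gradients.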
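A minimal single-head sketch of softmax(Q K^T / sqrt(d_k)) V, written in NumPy rather than Lucid so that no Lucid-specific calls have to be assumed. The helper names softmax and scaled_dot_product_attention are illustrative only, not part of Lucid's API.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """q: (L, d_k), k: (S, d_k), v: (S, d_v) -> (output (L, d_v), weights (L, S))."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)      # (L, S) similarity of every query to every key
    weights = softmax(scores, axis=-1)   # each row is a distribution over the S keys
    return weights @ v, weights

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))    # 4 queries, d_k = 8
k = rng.standard_normal((6, 8))    # 6 keys
v = rng.standard_normal((6, 16))   # 6 values, d_v = 16
out, w = scaled_dot_product_attention(q, k, v)
print(out.shape, w.shape)  # (4, 16) (4, 6)
```

Each output row is a convex combination of the value rows, weighted by how strongly that query matches each key.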
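Multi-head attention uses $h$ = num_heads parallel heads:
$$\operatorname{MultiHead}(Q, K, V) = \operatorname{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O, \qquad \text{head}_i = \operatorname{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$
where $W_i^Q$, $W_i^K$, $W_i^V$, and $W^O$ are learned projection matrices.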
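The head split and concatenation can be sketched for a single unbatched sequence as follows. This is a NumPy illustration, not Lucid's implementation: it assumes the row-vector convention x @ W used in the formula, with head i owning columns [i * head_dim, (i + 1) * head_dim) of each (E, E) input projection, and it returns head-averaged weights to mirror average_attn_weights=True.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(query, key, value, w_q, w_k, w_v, w_o, num_heads):
    """Unbatched sketch. query: (L, E); key/value: (S, E); all weights (E, E)."""
    L, E = query.shape
    S = key.shape[0]
    head_dim = E // num_heads
    # Project, then split the feature axis into heads: (h, seq, head_dim).
    q = (query @ w_q).reshape(L, num_heads, head_dim).transpose(1, 0, 2)
    k = (key @ w_k).reshape(S, num_heads, head_dim).transpose(1, 0, 2)
    v = (value @ w_v).reshape(S, num_heads, head_dim).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (h, L, S)
    weights = softmax(scores, axis=-1)
    heads = weights @ v                                     # (h, L, head_dim)
    # Concat(head_1, ..., head_h) back to (L, E), then apply W^O.
    concat = heads.transpose(1, 0, 2).reshape(L, E)
    return concat @ w_o, weights.mean(axis=0)               # weights averaged over heads

rng = np.random.default_rng(0)
E, h = 16, 4
w_q, w_k, w_v, w_o = (rng.standard_normal((E, E)) * 0.1 for _ in range(4))
x = rng.standard_normal((5, E))  # 5 tokens of self-attention
out, avg_w = multi_head_attention(x, x, x, w_q, w_k, w_v, w_o, num_heads=h)
print(out.shape, avg_w.shape)  # (5, 16) (5, 5)
```

Each head attends with only head_dim = E // h features, so the total cost is comparable to one full-width head while the heads can specialise on different patterns.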
When kdim and vdim both equal embed_dim, the three
input projections are stored as a single fused weight
in_proj_weight of shape (3 * embed_dim, embed_dim) and
split at runtime, which is more cache-friendly on Apple Silicon.
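As a rough illustration of that runtime split, the NumPy sketch below slices a fused (3 * E, E) weight into three (E, E) blocks. It assumes the rows are ordered query, key, value and that each block is stored as (out_features, in_features) and applied as x @ W.T; both are assumptions about the layout, not confirmed details of Lucid.

```python
import numpy as np

E = 8
rng = np.random.default_rng(0)
# Assumed layout: rows [0:E] -> query, [E:2E] -> key, [2E:3E] -> value.
in_proj_weight = rng.standard_normal((3 * E, E))
in_proj_bias = rng.standard_normal(3 * E)

w_q, w_k, w_v = np.split(in_proj_weight, 3, axis=0)  # three (E, E) views, no copy
b_q, b_k, b_v = np.split(in_proj_bias, 3)            # three (E,) views

x = rng.standard_normal((5, E))  # 5 tokens
q = x @ w_q.T + b_q              # query projection using only the first block
print(w_q.shape, q.shape)  # (8, 8) (5, 8)
```

Because the slices are views, splitting costs nothing beyond bookkeeping; the single contiguous buffer is what gives the memory-locality benefit.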
Parameters
- embed_dim (int): Total model dimension; must be divisible by num_heads.
- num_heads (int): Number of parallel attention heads; each head has head_dim = embed_dim // num_heads.
- dropout (float, default 0.0): Dropout probability applied to the attention weights.
- bias (bool, default True): If True, learnable bias terms are added to all input and output projection layers.
- add_bias_kv (bool, default False): If True, learnable bias rows bias_k and bias_v (each of shape (1, 1, embed_dim)) are appended to the key and value sequences along the sequence dimension before the attention computation. Useful for cross-attention scenarios where extra context tokens are desired.
- add_zero_attn (bool, default False): If True, a zero-valued row is appended to the key and value sequences. This can stabilise training in early steps by providing an "attend to nothing" option.
- kdim (int or None, default None): Feature dimension of the key input. If None, falls back to embed_dim. The fused in_proj_weight is used only when both kdim and vdim equal embed_dim.
- vdim (int or None, default None): Feature dimension of the value input. If None, falls back to embed_dim.
- batch_first (bool, default False): Layout of input and output tensors.
  - False (default): (seq_len, batch, embed_dim), the classic sequence-first convention.
  - True: (batch, seq_len, embed_dim), more intuitive for most modern use-cases.
- device (DeviceLike, default None): Device on which parameters are allocated. None defaults to the current default device.
- dtype (DTypeLike, default None): Parameter data type. None defaults to the current default floating-point type.
Attributes
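- embed_dim (int): Total model dimension.
- num_heads (int): Number of attention heads.
- head_dim (int): Per-head dimension, embed_dim // num_heads.
- kdim (int): Key feature dimension.
- vdim (int): Value feature dimension.
- dropout (float): Attention dropout probability.
- batch_first (bool): If True, inputs and outputs use the (batch, seq, feature) layout.
- in_proj_weight (Parameter or None): Fused (3 * embed_dim, embed_dim) projection weight used when kdim == vdim == embed_dim. Sliced at runtime into Q, K, V sub-weights. None when using separate weights.
- q_proj_weight (Parameter or None): Query projection weight of shape (embed_dim, embed_dim). Non-None only when kdim or vdim differs from embed_dim.
- k_proj_weight (Parameter or None): Key projection weight of shape (embed_dim, kdim).
- v_proj_weight (Parameter or None): Value projection weight of shape (embed_dim, vdim).
- in_proj_bias (Parameter or None): Input projection bias of shape (3 * embed_dim,). None when bias=False.
- out_proj_weight (Parameter): Output projection weight of shape (embed_dim, embed_dim).
- out_proj_bias (Parameter or None): Output projection bias of shape (embed_dim,). None when bias=False.
- bias_k (Parameter or None): Learnable key bias of shape (1, 1, embed_dim). Non-None when add_bias_kv=True.
- bias_v (Parameter or None): Learnable value bias of shape (1, 1, embed_dim). Non-None when add_bias_kv=True.
- add_zero_attn (bool): Whether a zero row is appended to the key and value sequences.
Notes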
The shapes below use the following notation:
- $N$: batch size
- $L$: target (query) sequence length
- $S$: source (key / value) sequence length
- $E$: embed_dim
When batch_first=False (default):
- query: $(L, N, E)$
- key: $(S, N, E_k)$, where $E_k$ = kdim
- value: $(S, N, E_v)$, where $E_v$ = vdim
- Output attn_output: $(L, N, E)$
- Output attn_weights: $(N, L, S)$ when need_weights=True and average_attn_weights=True; $(N, h, L, S)$ when average_attn_weights=False, where $h$ = num_heads.
When batch_first=True:
- query: $(N, L, E)$
- key / value: $(N, S, E_k)$ / $(N, S, E_v)$
- Output attn_output: $(N, L, E)$
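The shape rules can be summarised by a small helper; mha_shapes is a hypothetical name used only for illustration, with N the batch size, L the query length, S the key/value length and E = embed_dim. It assumes attn_weights keeps the same shape under batch_first=True, which the rules above do not state explicitly.

```python
def mha_shapes(L, S, N, E, num_heads, kdim=None, vdim=None,
               batch_first=False, average_attn_weights=True):
    """Hypothetical helper: expected tensor shapes for MultiheadAttention."""
    E_k = E if kdim is None else kdim  # kdim falls back to embed_dim
    E_v = E if vdim is None else vdim  # vdim falls back to embed_dim
    if batch_first:
        q, k, v, out = (N, L, E), (N, S, E_k), (N, S, E_v), (N, L, E)
    else:
        q, k, v, out = (L, N, E), (S, N, E_k), (S, N, E_v), (L, N, E)
    # Head-averaged weights drop the head axis; per-head weights keep it.
    w = (N, L, S) if average_attn_weights else (N, num_heads, L, S)
    return {"query": q, "key": k, "value": v,
            "attn_output": out, "attn_weights": w}

# The "Basic self-attention" example: 10 tokens, batch 2, embed_dim 64, 8 heads.
print(mha_shapes(10, 10, 2, 64, num_heads=8)["attn_weights"])  # (2, 10, 10)
```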
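Why scale by $\sqrt{d_k}$? As $d_k$ grows, the dot-products accumulate over more dimensions and their magnitude grows like $\sqrt{d_k}$ under the assumption of unit-variance inputs: if the components of $q$ and $k$ are independent with zero mean and unit variance, then $q \cdot k = \sum_{j=1}^{d_k} q_j k_j$ has variance $d_k$. Without the scale factor the softmax would saturate, producing near-one-hot distributions and vanishingly small gradients. Dividing by $\sqrt{d_k}$ restores roughly unit variance before the softmax.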
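The variance argument is easy to check numerically. This NumPy sketch draws independent standard-normal query/key vectors (the unit-variance assumption above) and compares the spread of raw and scaled dot-products as d_k, the per-head dimension, grows.

```python
import numpy as np

rng = np.random.default_rng(0)
for d_k in (16, 64, 256, 1024):
    # 2,000 pairs of independent, zero-mean, unit-variance vectors.
    q = rng.standard_normal((2_000, d_k))
    k = rng.standard_normal((2_000, d_k))
    dots = (q * k).sum(axis=1)  # one dot-product per pair
    print(f"d_k={d_k:5d}  std(q.k)={dots.std():7.2f}  "
          f"std(q.k / sqrt(d_k))={(dots / np.sqrt(d_k)).std():.2f}")
```

The unscaled standard deviation should track sqrt(d_k) (about 4, 8, 16, 32), while the scaled column stays close to 1 regardless of d_k.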
Causal masking (is_causal=True):
An upper-triangular mask, filled with $-\infty$ above the diagonal, is added to the score matrix so that position $i$ cannot attend to any position $j > i$. This implements the autoregressive constraint needed for language model decoding.
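A NumPy sketch of the additive causal mask, for the square self-attention case (L = S). After the softmax, every entry above the diagonal becomes exactly zero, so token i only mixes information from tokens 0..i. The causal_mask helper is illustrative, not a Lucid function.

```python
import numpy as np

def causal_mask(L: int) -> np.ndarray:
    # -inf strictly above the diagonal (j > i), 0 on and below it.
    return np.triu(np.full((L, L), -np.inf), k=1)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # row max is finite: the diagonal is unmasked
    e = np.exp(x)                            # exp(-inf) == 0
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
L, d_k = 5, 8
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
scores = q @ k.T / np.sqrt(d_k) + causal_mask(L)
weights = softmax(scores)
print(np.round(weights, 2))  # upper triangle is exactly 0
```

The first row always puts weight 1.0 on position 0, since it has nothing else to attend to.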
Fused vs. separate projections:
When kdim == vdim == embed_dim, the Q/K/V projections
share a single (3E, E) weight matrix. This layout
allows a single linear call plus a cheap split_at
on the result, which amortises kernel-launch overhead and
improves cache locality on the MLX / Accelerate backends.
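For self-attention, where query, key and value are the same tensor, the fused path amounts to one (T, E) x (E, 3E) matmul followed by a split of the last axis. The NumPy sketch below checks that this matches three separate projections; it assumes Q/K/V row order in in_proj_weight and the x @ W.T convention, and NumPy's np.split stands in for the split_at call mentioned above.

```python
import numpy as np

E, T = 8, 5
rng = np.random.default_rng(0)
in_proj_weight = rng.standard_normal((3 * E, E))  # assumed Q/K/V row order
in_proj_bias = rng.standard_normal(3 * E)
x = rng.standard_normal((T, E))                   # self-attention input

# Fused: one matmul producing all three projections, then a cheap split.
qkv = x @ in_proj_weight.T + in_proj_bias         # (T, 3E)
q, k, v = np.split(qkv, 3, axis=-1)               # three (T, E) arrays

# Separate: three smaller matmuls with the sliced sub-weights.
w_q, w_k, w_v = np.split(in_proj_weight, 3, axis=0)
b_q, b_k, b_v = np.split(in_proj_bias, 3)
q_ref, k_ref, v_ref = x @ w_q.T + b_q, x @ w_k.T + b_k, x @ w_v.T + b_v

print(np.allclose(q, q_ref), np.allclose(k, k_ref), np.allclose(v, v_ref))  # True True True
```

The results agree up to floating-point rounding; the fused form simply issues one larger kernel instead of three small ones.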
Checkpoint compatibility:
State-dicts from the reference framework store the output
projection under the key out_proj.weight / out_proj.bias
(a sub-module named out_proj). Lucid's _load_from_state_dict
hook transparently remaps those keys to the flat
out_proj_weight / out_proj_bias attributes used here,
so pre-trained weights can be loaded directly.
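The key remapping can be pictured as a plain dictionary rename. The function below is a hypothetical stand-in written for illustration, not Lucid's actual _load_from_state_dict hook, and the string values stand in for tensors.

```python
def remap_out_proj_keys(state_dict: dict, prefix: str = "") -> dict:
    """Hypothetical helper: rename '<prefix>out_proj.weight' / '.bias'
    to the flat '<prefix>out_proj_weight' / '<prefix>out_proj_bias' keys."""
    renames = {
        f"{prefix}out_proj.weight": f"{prefix}out_proj_weight",
        f"{prefix}out_proj.bias": f"{prefix}out_proj_bias",
    }
    # Keys that are not in the rename table pass through unchanged.
    return {renames.get(key, key): value for key, value in state_dict.items()}

ref_state = {
    "attn.in_proj_weight": "W_in",
    "attn.in_proj_bias": "b_in",
    "attn.out_proj.weight": "W_out",  # sub-module layout from the reference framework
    "attn.out_proj.bias": "b_out",
}
remapped = remap_out_proj_keys(ref_state, prefix="attn.")
print(remapped)
```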
Examples
**Basic self-attention** (sequence-first layout):
>>> import lucid
>>> import lucid.nn as nn
>>> mha = nn.MultiheadAttention(embed_dim=64, num_heads=8)
>>> # Sequence-first: (seq_len, batch, embed_dim)
>>> x = lucid.randn(10, 2, 64) # 10 tokens, batch=2
>>> out, weights = mha(x, x, x)
>>> out.shape
(10, 2, 64)
>>> weights.shape # averaged over heads
(2, 10, 10)
**Cross-attention with batch_first layout**:
>>> mha = nn.MultiheadAttention(embed_dim=64, num_heads=8,
... batch_first=True)
>>> q = lucid.randn(2, 6, 64) # batch=2, 6 query tokens
>>> kv = lucid.randn(2, 10, 64) # 10 key/value tokens
>>> out, _ = mha(q, kv, kv, need_weights=False)
>>> out.shape
(2, 6, 64)
**Cross-modal attention with different key/value dimensions**:
>>> mha = nn.MultiheadAttention(embed_dim=128, num_heads=4,
... kdim=64, vdim=64)
>>> q = lucid.randn(5, 1, 128)
>>> kv = lucid.randn(7, 1, 64)
>>> out, weights = mha(q, kv, kv)
>>> out.shape
(5, 1, 128)
Methods (3)
__init__
__init__(embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, add_bias_kv: bool = False, add_zero_attn: bool = False, kdim: int | None = None, vdim: int | None = None, batch_first: bool = False, device: DeviceLike = None, dtype: DTypeLike = None) → None
Initialise the MultiheadAttention module. See the class docstring for parameter semantics.
forward
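forward(query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Tensor | None = None, need_weights: bool = True, attn_mask: Tensor | None = None, average_attn_weights: bool = True, is_causal: bool = False) → Tensor
Run the forward pass of the module.
Parameters
- query (Tensor): Query input; see the class Notes for its shape.
- key (Tensor): Key input.
- value (Tensor): Value input.
- key_padding_mask (Tensor or None, default None): Optional mask over the key positions marking padding that should be ignored.
- need_weights (bool, default True): If True, the attention weights are also returned.
- attn_mask (Tensor or None, default None): Optional mask applied to the attention score matrix.
- average_attn_weights (bool, default True): If True, the returned weights are averaged over heads; otherwise per-head weights are returned.
- is_causal (bool, default False): If True, apply the causal (upper-triangular) mask described in the class Notes.
Returns
(attn_output, attn_weights): the examples above unpack two values. attn_output has the shape given in the class Notes; attn_weights has the shape given there when need_weights=True.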
extra_repr
extra_repr() → str
Return a string representation of the layer's configuration.