class MultiheadAttention

extends Module

MultiheadAttention(embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, add_bias_kv: bool = False, add_zero_attn: bool = False, kdim: int | None = None, vdim: int | None = None, batch_first: bool = False, device: DeviceLike = None, dtype: DTypeLike = None)

Multi-head scaled dot-product attention.

Implements the multi-head attention mechanism introduced in "Attention Is All You Need" (Vaswani et al., 2017). Each head independently computes scaled dot-product attention over a learned linear projection of the query, key and value inputs; the per-head outputs are then concatenated and projected once more to produce the final result.

Scaled dot-product attention for a single head:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V$$

where $d_k$ is the per-head key dimension (head_dim). The $1/\sqrt{d_k}$ scaling prevents the dot products from growing so large that the softmax is pushed into regions with extremely small gradients.
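As a minimal sketch of the single-head formula, the NumPy code below computes $\text{softmax}(QK^\top/\sqrt{d_k})V$ for one unbatched head. It uses NumPy rather than Lucid tensors, and the helper names `softmax` and `scaled_dot_product_attention` are illustrative, not part of the Lucid API.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability; the result is unchanged.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray):
    """q: (L, d_k), k: (S, d_k), v: (S, d_v) -> (output (L, d_v), weights (L, S))."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)      # (L, S) scaled similarity scores
    weights = softmax(scores, axis=-1)   # each row is a distribution over the S keys
    return weights @ v, weights

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))   # L=4 queries, d_k=8
k = rng.standard_normal((6, 8))   # S=6 keys
v = rng.standard_normal((6, 5))   # S=6 values, d_v=5
out, w = scaled_dot_product_attention(q, k, v)
print(out.shape, w.shape)  # (4, 5) (4, 6)
```

Each output row is a convex combination of the value rows, weighted by how strongly that query matches each key.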

Multi-head attention uses $h$ parallel heads:

$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

where $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$, and $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ are learned projection matrices.
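The two formulas above can be transcribed almost literally. The sketch below is a NumPy illustration for one unbatched sequence, with per-head weight matrices kept as Python lists to mirror $W_i^Q, W_i^K, W_i^V$. Biases, dropout and masking are deliberately omitted, and none of this reflects how Lucid actually stores or loops over its weights.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(Q, K, V, W_q, W_k, W_v, W_o):
    """Q: (L, d_model); K, V: (S, d_model).
    W_q, W_k: lists of h matrices (d_model, d_k); W_v: list of h matrices (d_model, d_v);
    W_o: (h * d_v, d_model). Returns (L, d_model)."""
    heads = []
    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):
        q, k, v = Q @ Wq_i, K @ Wk_i, V @ Wv_i        # per-head projections
        scores = q @ k.T / np.sqrt(q.shape[-1])      # (L, S)
        heads.append(softmax(scores) @ v)            # head_i: (L, d_v)
    return np.concatenate(heads, axis=-1) @ W_o      # Concat(...) W^O: (L, d_model)

rng = np.random.default_rng(0)
d_model, h = 16, 4
d_k = d_v = d_model // h                             # head_dim = embed_dim // num_heads
W_q = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_k = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_v = [rng.standard_normal((d_model, d_v)) for _ in range(h)]
W_o = rng.standard_normal((h * d_v, d_model))
x = rng.standard_normal((10, d_model))
y = multi_head_attention(x, x, x, W_q, W_k, W_v, W_o)  # self-attention
print(y.shape)  # (10, 16)
```

A real implementation would batch the heads into one tensor operation instead of looping; the loop is kept here only so each line maps onto one term of the formulas.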

When kdim and vdim both equal embed_dim, the three input projections are stored as a single fused weight in_proj_weight of shape (3 * embed_dim, embed_dim) and split at runtime, which is more cache-friendly on Apple Silicon.

Parameters

embed_dim : int
Total dimension of the model, $d_{\text{model}}$. Must be divisible by num_heads.
num_heads : int
Number of parallel attention heads $h$. Each head operates on a subspace of dimension head_dim = embed_dim // num_heads.
dropout : float = 0.0
Dropout probability applied to the attention weight matrix during training. Default: 0.0.
bias : bool = True
If True, learnable bias terms are added to all input and output projection layers. Default: True.
add_bias_kv : bool = False
If True, learnable bias rows bias_k and bias_v (each of shape (1, 1, embed_dim)) are appended to the key and value sequences along the sequence dimension before the attention computation. Useful for cross-attention scenarios where extra context tokens are desired. Default: False.
add_zero_attn : bool = False
If True, a zero-valued row is appended to the key and value sequences. This can stabilise training in early steps by providing an "attend to nothing" option. Default: False.
kdim : int or None = None
Feature dimension of the key input. When None (default), falls back to embed_dim. The fused in_proj_weight is used only when both kdim and vdim resolve to embed_dim.
vdim : int or None = None
Feature dimension of the value input. When None (default), falls back to embed_dim.
batch_first : bool = False
Controls the expected layout of all input and output tensors.
  • False (default): (seq_len, batch, embed_dim), the classic sequence-first convention.
  • True: (batch, seq_len, embed_dim), more intuitive for most modern use-cases.
device : DeviceLike = None
Device on which to allocate parameters. None defaults to the current default device.
dtype : DTypeLike = None
Data type for all parameters. None defaults to the current default floating-point type.
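The add_bias_kv and add_zero_attn options both lengthen the source sequence by one position each. The NumPy sketch below shows the shape bookkeeping in the sequence-first layout. It assumes (as the (1, 1, embed_dim) shape of bias_k suggests) that the rows are appended to keys and values that have already been projected to width embed_dim; `append_kv_rows` is a hypothetical helper, not a Lucid function. Any attention or padding mask would also need one extra column per appended row.

```python
import numpy as np

def append_kv_rows(k, v, bias_k=None, bias_v=None, add_zero_attn=False):
    """Append extra source positions to projected keys/values.

    k, v: (S, N, E) in the sequence-first layout.
    bias_k, bias_v: (1, 1, E) learned rows, or None.
    Returns (k, v) with up to two extra positions along axis 0.
    """
    n = k.shape[1]
    if bias_k is not None and bias_v is not None:
        # Broadcast each (1, 1, E) row across the batch and append it as a new position.
        k = np.concatenate([k, np.broadcast_to(bias_k, (1, n, k.shape[2]))], axis=0)
        v = np.concatenate([v, np.broadcast_to(bias_v, (1, n, v.shape[2]))], axis=0)
    if add_zero_attn:
        # A zero key scores exactly 0 against every query and its zero value adds
        # nothing to the output, giving each query an "attend to nothing" slot.
        k = np.concatenate([k, np.zeros((1, n, k.shape[2]))], axis=0)
        v = np.concatenate([v, np.zeros((1, n, v.shape[2]))], axis=0)
    return k, v

rng = np.random.default_rng(0)
S, N, E = 7, 2, 8
k = rng.standard_normal((S, N, E))
v = rng.standard_normal((S, N, E))
bias_k = rng.standard_normal((1, 1, E))
bias_v = rng.standard_normal((1, 1, E))
k2, v2 = append_kv_rows(k, v, bias_k, bias_v, add_zero_attn=True)
print(k2.shape, v2.shape)  # (9, 2, 8) (9, 2, 8)
```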

Attributes

embed_dim : int
Total model dimension passed at construction.
num_heads : int
Number of attention heads.
head_dim : int
Per-head dimension: embed_dim // num_heads.
kdim : int
Effective key feature dimension.
vdim : int
Effective value feature dimension.
dropout : float
Attention weight dropout probability.
batch_first : bool
Whether inputs are (batch, seq, feature).
in_proj_weight : Parameter or None
Fused (3 * embed_dim, embed_dim) projection weight used when kdim == vdim == embed_dim. Sliced at runtime into Q, K, V sub-weights. None when using separate weights.
q_proj_weight : Parameter or None
Separate query projection weight (embed_dim, embed_dim). Non-None only when kdim or vdim differs from embed_dim.
k_proj_weight : Parameter or None
Separate key projection weight (embed_dim, kdim).
v_proj_weight : Parameter or None
Separate value projection weight (embed_dim, vdim).
in_proj_bias : Parameter or None
Bias for the fused input projection (3 * embed_dim,). None when bias=False.
out_proj_weight : Parameter
Output projection weight (embed_dim, embed_dim).
out_proj_bias : Parameter or None
Output projection bias (embed_dim,). None when bias=False.
bias_k : Parameter or None
Learnable key bias row (1, 1, embed_dim). Non-None when add_bias_kv=True.
bias_v : Parameter or None
Learnable value bias row (1, 1, embed_dim). Non-None when add_bias_kv=True.
add_zero_attn : bool
Whether a zero row is appended to K and V.

Notes

The shapes below use the following notation:

  • $N$: batch size
  • $L$: target (query) sequence length
  • $S$: source (key / value) sequence length
  • $E$: embed_dim

When batch_first=False (default):

  • query: $(L, N, E)$
  • key: $(S, N, E_k)$ where $E_k$ = kdim
  • value: $(S, N, E_v)$ where $E_v$ = vdim
  • Output attn_output: $(L, N, E)$
  • Output attn_weights: $(N, L, S)$ when need_weights=True and average_attn_weights=True; $(N, h, L, S)$ when average_attn_weights=False.
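The weight shapes above follow from splitting the embedding into heads. The NumPy walkthrough below starts from stand-ins for already-projected, sequence-first tensors, reshapes them to $(N, h, T, E/h)$, and produces per-head weights of shape $(N, h, L, S)$. It assumes that "averaged over heads" means a plain mean over the head axis; `split_heads` is an illustrative helper, and Lucid's internal layout may differ.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def split_heads(x, num_heads):
    """(T, N, E) sequence-first -> (N, num_heads, T, E // num_heads)."""
    t, n, e = x.shape
    return x.reshape(t, n, num_heads, e // num_heads).transpose(1, 2, 0, 3)

rng = np.random.default_rng(0)
L, S, N, E, h = 5, 7, 2, 16, 4
head_dim = E // h

# Stand-ins for the already-projected query, key and value.
q = split_heads(rng.standard_normal((L, N, E)), h)   # (N, h, L, head_dim)
k = split_heads(rng.standard_normal((S, N, E)), h)   # (N, h, S, head_dim)
v = split_heads(rng.standard_normal((S, N, E)), h)   # (N, h, S, head_dim)

weights = softmax(q @ k.swapaxes(-1, -2) / np.sqrt(head_dim))  # (N, h, L, S)
heads = weights @ v                                             # (N, h, L, head_dim)

# Concatenate heads back into (L, N, E); the output projection would follow.
concat = heads.transpose(2, 0, 1, 3).reshape(L, N, E)

per_head = weights                 # average_attn_weights=False: (N, h, L, S)
averaged = weights.mean(axis=1)    # average_attn_weights=True:  (N, L, S)
print(concat.shape, per_head.shape, averaged.shape)  # (5, 2, 16) (2, 4, 5, 7) (2, 5, 7)
```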

When batch_first=True:

  • query: $(N, L, E)$
  • key / value: $(N, S, E_{k/v})$
  • Output attn_output: $(N, L, E)$
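The two layouts differ only by a swap of the first two axes, so the attention each batch element receives is identical. The NumPy sketch below (single head, no projections, illustrative helper `attend_seq_first`) runs a sequence-first computation on batch-first data by swapping axes 0 and 1 on the way in and on the way out.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attend_seq_first(q, k, v):
    """Single-head attention without projections.
    q: (L, N, E); k, v: (S, N, E) -> (L, N, E)."""
    qb, kb, vb = (t.swapaxes(0, 1) for t in (q, k, v))             # (N, T, E)
    w = softmax(qb @ kb.swapaxes(-1, -2) / np.sqrt(q.shape[-1]))   # (N, L, S)
    return (w @ vb).swapaxes(0, 1)                                  # (L, N, E)

rng = np.random.default_rng(0)
q_bf = rng.standard_normal((2, 5, 8))    # batch-first query (N, L, E)
kv_bf = rng.standard_normal((2, 7, 8))   # batch-first key/value (N, S, E)

# Convert batch-first -> sequence-first, attend, then convert the result back.
kv_sf = kv_bf.swapaxes(0, 1)
out_bf = attend_seq_first(q_bf.swapaxes(0, 1), kv_sf, kv_sf).swapaxes(0, 1)
print(out_bf.shape)  # (2, 5, 8)
```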

Why scale by $1/\sqrt{d_k}$? If the components of a query vector $q$ and a key vector $k$ are independent with zero mean and unit variance, their dot product $q \cdot k = \sum_{j=1}^{d_k} q_j k_j$ is a sum of $d_k$ unit-variance terms, so its variance is $d_k$ and its typical magnitude grows like $\sqrt{d_k}$. Without the scale factor the softmax would saturate, producing near-one-hot distributions and vanishingly small gradients. Dividing by $\sqrt{d_k}$ restores roughly unit variance before the softmax.
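The variance argument is easy to check empirically. This NumPy sketch draws unit-variance query and key components, as the argument assumes, and compares the variance of the raw and scaled dot products for several values of $d_k$. The raw variance should track $d_k$ while the scaled variance should stay near 1; exact figures depend on the random draw.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 2_000
results = {}
for d_k in (4, 64, 1024):
    # Zero-mean, unit-variance query and key components.
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=1)                      # n_samples independent dot products
    results[d_k] = (dots.var(), (dots / np.sqrt(d_k)).var())
    raw, scaled = results[d_k]
    print(f"d_k={d_k:5d}  var(q.k)={raw:8.1f}  var(q.k/sqrt(d_k))={scaled:.2f}")
```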

Causal masking (is_causal=True): a strictly upper-triangular mask of $-\infty$ values is added to the score matrix so that position $i$ cannot attend to any position $j > i$; after the softmax those entries receive exactly zero weight. This implements the autoregressive constraint needed for language-model decoding.
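A minimal NumPy sketch of the additive causal mask follows. `causal_mask` is an illustrative helper; `np.triu(..., k=1)` keeps the $-\infty$ entries strictly above the diagonal and zeros everything on and below it, so each row still has at least one finite score and the softmax stays well defined.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def causal_mask(length):
    """Additive (length, length) mask: 0 on and below the diagonal, -inf strictly above."""
    return np.triu(np.full((length, length), -np.inf), k=1)

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 4))        # (L, L) self-attention scores
weights = softmax(scores + causal_mask(4))
print(np.round(weights, 2))                 # zeros above the diagonal; row 0 is [1, 0, 0, 0]
```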

Fused vs. separate projections: When kdim == vdim == embed_dim, the Q/K/V projections share a single (3E, E) weight matrix. This layout allows a single linear call plus a cheap split_at on the result, which amortises kernel-launch overhead and improves cache locality on the MLX / Accelerate backends.
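The fused layout can be illustrated with NumPy. This sketch assumes the rows of in_proj_weight are ordered as [W_q; W_k; W_v], that weights are stored as (out_features, in_features) and applied as `x @ W.T + b`, and that query, key and value are the same tensor (self-attention), which is the case where one matrix multiply covers all three projections. `np.split` stands in for whatever split primitive Lucid uses internally.

```python
import numpy as np

E = 8
rng = np.random.default_rng(0)
in_proj_weight = rng.standard_normal((3 * E, E))   # fused (3E, E); assumed row order [W_q; W_k; W_v]
in_proj_bias = rng.standard_normal(3 * E)
x = rng.standard_normal((5, E))                     # self-attention input, 5 tokens

# One matmul for all three projections, then split the last axis into three E-wide chunks.
qkv = x @ in_proj_weight.T + in_proj_bias           # (5, 3E)
q, k, v = np.split(qkv, 3, axis=-1)                 # each (5, E)

# Same result as three separate projections with row-sliced weights.
W_q, W_k, W_v = np.split(in_proj_weight, 3, axis=0)
b_q, b_k, b_v = np.split(in_proj_bias, 3)
assert np.allclose(q, x @ W_q.T + b_q)
```

When query, key and value are different tensors of the same width, the fused weight can still be sliced this way, but each slice then needs its own multiply.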

Checkpoint compatibility: State-dicts from the reference framework store the output projection under the key out_proj.weight / out_proj.bias (a sub-module named out_proj). Lucid's _load_from_state_dict hook transparently remaps those keys to the flat out_proj_weight / out_proj_bias attributes used here, so pre-trained weights can be loaded directly.
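The key remapping amounts to renaming two entries in a dictionary. The sketch below is a hypothetical stand-alone helper, `remap_out_proj_keys`, operating on a plain dict with placeholder string values; it only illustrates the renaming and is not Lucid's actual `_load_from_state_dict` hook.

```python
def remap_out_proj_keys(state_dict: dict, prefix: str = "") -> dict:
    """Hypothetical helper: rename '<prefix>out_proj.weight' / '<prefix>out_proj.bias'
    to the flat '<prefix>out_proj_weight' / '<prefix>out_proj_bias' names.
    All other keys are passed through unchanged, in their original order."""
    renames = {
        f"{prefix}out_proj.weight": f"{prefix}out_proj_weight",
        f"{prefix}out_proj.bias": f"{prefix}out_proj_bias",
    }
    return {renames.get(key, key): value for key, value in state_dict.items()}

sd = {"attn.in_proj_weight": "W_in", "attn.out_proj.weight": "W_out", "attn.out_proj.bias": "b_out"}
print(remap_out_proj_keys(sd, prefix="attn."))
# {'attn.in_proj_weight': 'W_in', 'attn.out_proj_weight': 'W_out', 'attn.out_proj_bias': 'b_out'}
```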

Examples

**Basic self-attention** (sequence-first layout):
>>> import lucid
>>> import lucid.nn as nn
>>> mha = nn.MultiheadAttention(embed_dim=64, num_heads=8)
>>> # Sequence-first: (seq_len, batch, embed_dim)
>>> x = lucid.randn(10, 2, 64)          # 10 tokens, batch=2
>>> out, weights = mha(x, x, x)
>>> out.shape
(10, 2, 64)
>>> weights.shape                        # averaged over heads
(2, 10, 10)

**Cross-attention with batch_first layout**:
>>> mha = nn.MultiheadAttention(embed_dim=64, num_heads=8,
...                             batch_first=True)
>>> q = lucid.randn(2, 6, 64)           # batch=2, 6 query tokens
>>> kv = lucid.randn(2, 10, 64)         # 10 key/value tokens
>>> out, _ = mha(q, kv, kv, need_weights=False)
>>> out.shape
(2, 6, 64)

**Cross-modal attention with different key/value dimensions**:
>>> mha = nn.MultiheadAttention(embed_dim=128, num_heads=4,
...                             kdim=64, vdim=64)
>>> q = lucid.randn(5, 1, 128)
>>> kv = lucid.randn(7, 1, 64)
>>> out, weights = mha(q, kv, kv)
>>> out.shape
(5, 1, 128)

Methods (3)

dunder

__init__

None
__init__(embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, add_bias_kv: bool = False, add_zero_attn: bool = False, kdim: int | None = None, vdim: int | None = None, batch_first: bool = False, device: DeviceLike = None, dtype: DTypeLike = None)

Initialise the MultiheadAttention module. See the class docstring for parameter semantics.

fn

forward

tuple
forward(query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Tensor | None = None, need_weights: bool = True, attn_mask: Tensor | None = None, average_attn_weights: bool = True, is_causal: bool = False)

Compute multi-head attention of query over key and value.

Parameters

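query : Tensor
Query sequence of shape (L, N, E), or (N, L, E) when batch_first=True.
key : Tensor
Key sequence of shape (S, N, E_k), or (N, S, E_k) when batch_first=True.
value : Tensor
Value sequence of shape (S, N, E_v), or (N, S, E_v) when batch_first=True.
key_padding_mask : Tensor or None = None
Optional mask over the key positions, used to exclude padded source positions from attention.
need_weights : bool = True
If True, also return the attention weights.
attn_mask : Tensor or None = None
Optional mask applied to the attention score matrix, for example to prevent specific queries from attending to specific keys.
average_attn_weights : bool = True
If True, the returned attention weights are averaged over heads, giving shape (N, L, S); otherwise they have shape (N, h, L, S). Only relevant when need_weights=True.
is_causal : bool = False
If True, apply a causal mask so that position $i$ cannot attend to any position $j > i$.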
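The exact format Lucid expects for key_padding_mask is not specified here. The NumPy sketch below assumes a boolean (N, S) mask in which True marks a padded key to be ignored, and shows how such a mask can be broadcast over heads and queries and turned into $-\infty$ scores. `apply_key_padding_mask` is a hypothetical helper, not part of the Lucid API.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def apply_key_padding_mask(scores, key_padding_mask):
    """scores: (N, h, L, S); key_padding_mask: bool (N, S), True marks a padded key.

    Hypothetical helper; assumes the True-means-ignore convention."""
    # Broadcast (N, S) -> (N, 1, 1, S) so every head and query sees the same padding.
    return np.where(key_padding_mask[:, None, None, :], -np.inf, scores)

rng = np.random.default_rng(0)
N, h, L, S = 2, 4, 3, 5
scores = rng.standard_normal((N, h, L, S))
pad = np.zeros((N, S), dtype=bool)
pad[1, 3:] = True                        # the second sequence has only 3 real keys
weights = softmax(apply_key_padding_mask(scores, pad))
print(weights[1, 0, 0].round(2))         # the last two entries are 0
```

Every row must keep at least one unpadded key; a fully padded row would produce NaN after the softmax.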

Returns

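tuple of (attn_output, attn_weights)

attn_output has the same layout as query: (L, N, E), or (N, L, E) when batch_first=True. attn_weights holds the attention weights, shaped as described in the class Notes, when need_weights=True.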

fn

extra_repr

str
extra_repr()

Return a string representation of the layer's configuration.