EmbeddingBag
Module
EmbeddingBag(num_embeddings: int, embedding_dim: int, max_norm: float | None = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, padding_idx: int | None = None, device: DeviceLike = None, dtype: DTypeLike = None)
Embedding lookup table with per-bag reduction.
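Computes a summary embedding for each bag (a variable-length set of token indices) without materialising the full per-token embedding tensor. Given an embedding matrix $W \in \mathbb{R}^{N \times D}$ (with $N$ = num_embeddings and $D$ = embedding_dim) and a bag of indices $\mathcal{I} = \{i_1, \dots, i_n\}$, the output embedding for that bag is:
$$
y = \operatorname{reduce}_{i \in \mathcal{I}} W_i, \qquad \operatorname{reduce} \in \{\operatorname{sum}, \operatorname{mean}, \max\}
$$
where the $\max$ is taken element-wise across the embedding dimension.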
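The reduction above can be sketched in plain Python. This is only a reference for the semantics, not the library's implementation: `reduce_bag` is a hypothetical helper, the weight table is a nested list rather than a tensor, and the sketch gathers the rows first, which the real module avoids. Empty bags are not handled.

```python
def reduce_bag(weight, bag, mode="mean"):
    """Reduce the rows of `weight` selected by `bag` into one vector.

    `reduce_bag` is a hypothetical helper written for illustration only.
    """
    rows = [weight[i] for i in bag]    # per-token lookup
    columns = list(zip(*rows))         # group values by embedding dimension
    if mode == "sum":
        return [sum(col) for col in columns]
    if mode == "mean":
        return [sum(col) / len(bag) for col in columns]
    if mode == "max":
        return [max(col) for col in columns]  # element-wise max
    raise ValueError(f"unknown mode: {mode!r}")


weight = [[1.0, 2.0], [3.0, 0.0], [5.0, -4.0]]  # N=3 rows, D=2
bag = [0, 2]
print(reduce_bag(weight, bag, "sum"))   # [6.0, -2.0]
print(reduce_bag(weight, bag, "mean"))  # [3.0, -1.0]
print(reduce_bag(weight, bag, "max"))   # [5.0, 2.0]
```

Note that in 'max' mode the two output components come from different rows (row 2 for the first, row 0 for the second), which is what "element-wise" means here.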
2-D input (fixed-length bags). When offsets is None the
input x must have shape (B, L) — B bags each containing
exactly L indices. The reduction is applied over the L
axis, yielding an output of shape (B, D).
1-D input with offsets (variable-length bags). When offsets
is provided, x is a flat 1-D integer tensor of all indices
concatenated, and offsets is a 1-D integer tensor of length
B marking the start position of each bag. Bag b
consists of indices x[offsets[b] : offsets[b+1]] (the last bag
runs to the end of x).
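The bag boundaries implied by offsets can be illustrated with ordinary list slicing. The sketch below assumes plain Python lists in place of tensors, and `split_bags` is a hypothetical name used only for this illustration.

```python
def split_bags(x, offsets):
    """Split the flat index list `x` into bags using start positions `offsets`.

    Hypothetical helper for illustration; not part of the library.
    """
    ends = list(offsets[1:]) + [len(x)]  # the last bag runs to the end of x
    return [x[start:end] for start, end in zip(offsets, ends)]


flat_idx = [0, 1, 2, 3, 4, 5]
offsets = [0, 2, 5]
print(split_bags(flat_idx, offsets))  # [[0, 1], [2, 3, 4], [5]]
```

With three offsets the result has B = 3 bags of lengths 2, 3 and 1, so the reduced output would have shape (3, D).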
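Parameters
num_embeddings : int
    Number of rows in weight (the size of the lookup table).
embedding_dim : int
    Number of columns in weight; the output dimension D.
max_norm : float or None = None
    See Embedding. Default: None.
norm_type : float = 2.0
    Norm used for max_norm renormalisation. Default: 2.0.
scale_grad_by_freq : bool = False
    Default: False.
mode : {'sum', 'mean', 'max'} = 'mean'
    Reduction applied within each bag. Default: 'mean'.
sparse : bool = False
    Default: False.
padding_idx : int or None = None
    Indices equal to padding_idx contribute zero to the
    reduction and do not receive gradient updates.
    Default: None.
device : DeviceLike = None
dtype : DTypeLike = None
Attributes
weight : Parameter, shape (num_embeddings, embedding_dim)
Notes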
- x (2-D path): (B, L) integer indices → output (B, D).
- x (1-D + offsets path): (total_indices,) integer indices, offsets (B,) → output (B, D).
EmbeddingBag is more memory-efficient than computing
Embedding(x).sum(dim=1) because it fuses the lookup and
reduction into a single kernel call, avoiding the intermediate
(B, L, D) tensor. This is especially beneficial for large
vocabularies and long bags.
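The difference between the two strategies can be sketched with nested lists. This is an illustration of the idea only, under the assumption that "fused" means accumulating each looked-up row straight into the (B, D) output; the actual kernel is not shown in this reference, and both function names are hypothetical.

```python
def bag_sum_materialised(weight, x):
    """Gather a (B, L, D) nested list first, then reduce over the L axis."""
    gathered = [[weight[i] for i in bag] for bag in x]  # (B, L, D)
    return [[sum(col) for col in zip(*rows)] for rows in gathered]


def bag_sum_fused(weight, x):
    """Accumulate each looked-up row directly into a (B, D) output."""
    dim = len(weight[0])
    out = [[0.0] * dim for _ in x]  # only the (B, D) result is allocated
    for b, bag in enumerate(x):
        for i in bag:
            for d in range(dim):
                out[b][d] += weight[i][d]
    return out


# Toy table: 6 rows, D=4, with row r holding [10r, 10r+1, 10r+2, 10r+3].
weight = [[float(r * 10 + c) for c in range(4)] for r in range(6)]
x = [[1, 3, 5], [0, 2, 4]]  # B=2 bags, L=3 indices each

assert bag_sum_fused(weight, x) == bag_sum_materialised(weight, x)

B, L, D = len(x), len(x[0]), len(weight[0])
print(B * L * D, B * D)  # 24 8
```

Both functions return the same values; the intermediate holds B·L·D elements in the first and only B·D in the second, so the saving grows linearly with the bag length L.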
'max' mode is not differentiable where two or more indices in a
bag tie for the maximum; the gradient is propagated only through
the index that achieved the maximum.
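The gradient routing in 'max' mode can be sketched as follows. The tie-breaking rule is an assumption: this reference does not say which index receives the gradient on a tie, and the sketch picks the first one in the bag. `max_bag_backward` is a hypothetical helper, not a library function.

```python
def max_bag_backward(weight, bag, grad_out):
    """Route `grad_out` to the rows that achieved the element-wise max.

    Hypothetical helper for illustration; not part of the library.
    """
    grad_weight = [[0.0] * len(row) for row in weight]
    for d, g in enumerate(grad_out):
        # ASSUMPTION: on ties, the first index in the bag wins
        # (Python's max returns the first maximal element).
        winner = max(bag, key=lambda i: weight[i][d])
        grad_weight[winner][d] += g
    return grad_weight


weight = [[1.0, 7.0], [4.0, 7.0], [2.0, 3.0]]
bag = [0, 1, 2]
print(max_bag_backward(weight, bag, grad_out=[1.0, 1.0]))
# [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
```

In the second column rows 0 and 1 tie at 7.0, and only row 0 receives gradient under the assumed rule; row 2 never wins and gets no update.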
See Also
Embedding : Per-token embedding lookup without reduction.
Examples
Fixed-length bags (2-D input), mean pooling:
>>> import lucid, lucid.nn as nn
>>> emb_bag = nn.EmbeddingBag(num_embeddings=20, embedding_dim=8, mode='mean')
>>> idx = lucid.tensor([[1, 3, 5], [0, 2, 4]], dtype=lucid.int64) # (B=2, L=3)
>>> y = emb_bag(idx)
>>> y.shape # (B=2, D=8)
(2, 8)
Variable-length bags via offsets (sum pooling):
>>> emb_sum = nn.EmbeddingBag(10, 4, mode='sum')
>>> flat_idx = lucid.tensor([0, 1, 2, 3, 4, 5], dtype=lucid.int64) # 6 total
>>> offsets = lucid.tensor([0, 2, 5], dtype=lucid.int64) # bags: [0,1], [2,3,4], [5]
>>> y2 = emb_sum(flat_idx, offsets)
>>> y2.shape # (B=3, D=4)
(3, 4)
Methods (3)
__init__
__init__(num_embeddings: int, embedding_dim: int, max_norm: float | None = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, padding_idx: int | None = None, device: DeviceLike = None, dtype: DTypeLike = None) → None
Initialise the EmbeddingBag module. See the class docstring for parameter semantics.
forward
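forward(x: Tensor, offsets: Tensor | None = None) → Tensor
Look up the embeddings for the given indices and reduce them per bag.
Parameters
x : Tensor
    Integer indices, of shape (B, L) when offsets is None, or a
    flat 1-D tensor of all indices when offsets is provided.
offsets : Tensor or None = None
    1-D integer tensor of length B marking the start position of
    each bag.
Returns
Tensor
    One reduced embedding per bag, of shape (B, embedding_dim).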
extra_repr
extra_repr() → str
Return a string representation of the layer's configuration.