class

EmbeddingBag

extends Module
EmbeddingBag(num_embeddings: int, embedding_dim: int, max_norm: float | None = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, padding_idx: int | None = None, device: DeviceLike = None, dtype: DTypeLike = None)

Embedding lookup table with per-bag reduction.

Computes a summary embedding for each bag (a variable-length set of token indices) without materialising the full per-token embedding tensor. Given an embedding matrix $W \in \mathbb{R}^{V \times D}$ and a bag of indices $\mathcal{B} = \{i_1, i_2, \ldots, i_k\}$, the output embedding for that bag is:

$$
y = \begin{cases}
\displaystyle\sum_{j \in \mathcal{B}} W[j] & \text{mode} = \texttt{'sum'} \\[6pt]
\displaystyle\frac{1}{|\mathcal{B}|}\sum_{j \in \mathcal{B}} W[j] & \text{mode} = \texttt{'mean'} \\[6pt]
\displaystyle\max_{j \in \mathcal{B}} W[j] & \text{mode} = \texttt{'max'}
\end{cases}
$$

where the $\max$ is taken element-wise across the embedding dimension.
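
The three reductions can be illustrated with a minimal pure-Python sketch. It is a reference for the formula only, not the library's fused implementation; the helper name `reduce_bag` and the list-of-lists representation of $W$ are assumptions made for illustration.

```python
# Reference sketch of the per-bag reduction; not the library's fused kernel.
def reduce_bag(weight, bag, mode="mean"):
    """Reduce the rows of `weight` selected by `bag` into one vector.

    weight: list of V rows, each a list of D floats (the matrix W).
    bag:    non-empty list of integer row indices.
    """
    rows = [weight[j] for j in bag]
    columns = list(zip(*rows))  # one tuple per embedding dimension
    if mode == "sum":
        return [sum(col) for col in columns]
    if mode == "mean":
        return [sum(col) / len(bag) for col in columns]
    if mode == "max":
        return [max(col) for col in columns]  # element-wise maximum
    raise ValueError(f"unknown mode: {mode!r}")


W = [[1.0, 4.0], [3.0, 2.0], [5.0, 0.0]]  # V=3, D=2
print(reduce_bag(W, [0, 1, 2], "sum"))   # [9.0, 6.0]
print(reduce_bag(W, [0, 1, 2], "mean"))  # [3.0, 2.0]
print(reduce_bag(W, [0, 1], "max"))      # [3.0, 4.0]
```

Note that in 'max' mode the two output components come from different rows (row 1 for the first dimension, row 0 for the second), which is what "element-wise" means here.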

2-D input (fixed-length bags). When offsets is None, the input x must have shape (B, L): B bags, each containing exactly L indices. The reduction is applied over the L axis, yielding an output of shape (B, D).

1-D input with offsets (variable-length bags). When offsets is provided, x is a flat 1-D integer tensor of all indices concatenated, and offsets is a 1-D integer tensor of length B marking the start position of each bag. Bag $b$ consists of indices x[offsets[b] : offsets[b+1]] (the last bag runs to the end of x).
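
The slicing rule for the offsets path can be sketched in plain Python. The helper `split_bags` is hypothetical and operates on ordinary lists rather than tensors; it only shows how the bag boundaries are derived.

```python
def split_bags(flat_indices, offsets):
    """Return the list of bags described by `offsets` over `flat_indices`."""
    # Each bag ends where the next one starts; the last bag runs to the end.
    ends = list(offsets[1:]) + [len(flat_indices)]
    return [list(flat_indices[start:end]) for start, end in zip(offsets, ends)]


flat_idx = [0, 1, 2, 3, 4, 5]
offsets = [0, 2, 5]
print(split_bags(flat_idx, offsets))  # [[0, 1], [2, 3, 4], [5]]
```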

Parameters

num_embeddings : int
Size of the embedding dictionary (vocabulary size $V$).
embedding_dim : int
Dimensionality of each embedding vector ($D$).
max_norm : float or None = None
If provided, rows with $L_p$-norm exceeding this value are renormalised before the lookup (see Embedding). Default: None.
norm_type : float = 2.0
The $p$ of the $L_p$-norm used for max_norm renormalisation. Default: 2.0.
scale_grad_by_freq : bool = False
Accepted for API compatibility; not yet implemented. Default: False.
mode : {'sum', 'mean', 'max'} = 'mean'
Reduction applied over each bag. Default: 'mean'.
sparse : bool = False
Accepted for API compatibility; sparse gradient emission is not yet supported. Default: False.
padding_idx : int or None = None
Indices equal to padding_idx contribute zero to the reduction and do not receive gradient updates. Default: None.
device : DeviceLike = None
Device for the weight tensor.
dtype : DTypeLike = None
Data type for the weight tensor.
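
The effect of max_norm and padding_idx can be sketched in pure Python. Both helpers below are hypothetical and rest on assumptions: `renorm_row` assumes an over-long row is scaled to exactly max_norm, and `sum_bag_with_padding` covers only 'sum' mode, since this page does not say whether padding entries count towards the denominator in 'mean' mode.

```python
def renorm_row(row, max_norm, norm_type=2.0):
    """Scale `row` down so its L_p norm is at most `max_norm` (assumed rule)."""
    norm = sum(abs(v) ** norm_type for v in row) ** (1.0 / norm_type)
    if norm <= max_norm:
        return list(row)  # already within the limit, left unchanged
    scale = max_norm / norm
    return [v * scale for v in row]


def sum_bag_with_padding(weight, bag, padding_idx=None):
    """Sum the rows in `bag`, skipping entries equal to `padding_idx`."""
    total = [0.0] * len(weight[0])
    for j in bag:
        if j == padding_idx:
            continue  # padding contributes zero to the reduction
        total = [t + w for t, w in zip(total, weight[j])]
    return total


print(renorm_row([3.0, 4.0], max_norm=1.0))  # L2 norm 5, scaled to roughly [0.6, 0.8]
W = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]
print(sum_bag_with_padding(W, [1, 0, 2, 0], padding_idx=0))  # [4.0, 6.0]
```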

Attributes

weight : Parameter, shape (num_embeddings, embedding_dim)
The embedding matrix $W$, initialised from $\mathcal{N}(0, 1)$.

Notes

  • x (2-D path): (B, L) integer indices → output (B, D).
  • x (1-D + offsets path): (total_indices,) integer indices, offsets (B,) → output (B, D).

EmbeddingBag is more memory-efficient than computing Embedding(x).sum(dim=1) because it fuses the lookup and reduction into a single kernel call, avoiding the intermediate (B, L, D) tensor. This is especially beneficial for large vocabularies and long bags.
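
To put a number on the saving, the following back-of-the-envelope sketch compares the size of the intermediate (B, L, D) tensor with the (B, D) output. The shapes and the float32 element size are illustrative assumptions, not measurements of the library.

```python
def tensor_bytes(shape, itemsize=4):
    """Bytes needed for a dense tensor of `shape` (float32 by default)."""
    count = 1
    for dim in shape:
        count *= dim
    return count * itemsize


B, L, D = 1024, 200, 128  # illustrative batch size, bag length, embedding dim
intermediate = tensor_bytes((B, L, D))  # what Embedding(x) would materialise
output = tensor_bytes((B, D))           # what EmbeddingBag returns

print(intermediate / 2**20)    # 100.0 (MiB)
print(output / 2**20)          # 0.5 (MiB)
print(intermediate // output)  # 200, i.e. a factor of L
```

The ratio between the two is exactly the bag length L, so the saving grows linearly with the number of indices per bag.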

'max' mode is not differentiable where two or more indices in a bag tie for the maximum; in that case the gradient is propagated only through the index that achieved the maximum.
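
The gradient routing in 'max' mode can be sketched as follows. The helper `max_bag_backward` is hypothetical, and its tie-breaking rule (the earliest index in the bag wins) is an assumption of this sketch; this page does not state which tied index the library selects.

```python
def max_bag_backward(weight, bag, grad_out):
    """Route `grad_out` to the rows that produced each per-dimension maximum.

    Returns {row_index: gradient_vector} for the rows that receive gradient.
    """
    dim = len(grad_out)
    grads = {}
    for d in range(dim):
        # max() returns the first maximal element, so ties go to the
        # earliest index in the bag (an assumption, see the lead-in).
        winner = max(bag, key=lambda j: weight[j][d])
        grads.setdefault(winner, [0.0] * dim)[d] += grad_out[d]
    return grads


W = [[1.0, 5.0], [2.0, 5.0], [0.0, 3.0]]
print(max_bag_backward(W, [0, 1, 2], [1.0, 1.0]))
# {1: [1.0, 0.0], 0: [0.0, 1.0]}
```

Rows 0 and 1 tie in the second dimension (both 5.0), yet only row 0 receives gradient there, and row 2 receives none at all.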

See also: Embedding, the per-token embedding lookup without reduction.

Examples

Fixed-length bags (2-D input), mean pooling:
>>> import lucid, lucid.nn as nn
>>> emb_bag = nn.EmbeddingBag(num_embeddings=20, embedding_dim=8, mode='mean')
>>> idx = lucid.tensor([[1, 3, 5], [0, 2, 4]], dtype=lucid.int64)  # (B=2, L=3)
>>> y = emb_bag(idx)
>>> y.shape    # (B=2, D=8)
(2, 8)
Variable-length bags via offsets (sum pooling):
>>> emb_sum = nn.EmbeddingBag(10, 4, mode='sum')
>>> flat_idx = lucid.tensor([0, 1, 2, 3, 4, 5], dtype=lucid.int64)  # 6 total
>>> offsets = lucid.tensor([0, 2, 5], dtype=lucid.int64)  # bags: [0,1], [2,3,4], [5]
>>> y2 = emb_sum(flat_idx, offsets)
>>> y2.shape    # (B=3, D=4)
(3, 4)

Methods (3)

dunder

__init__

None
__init__(num_embeddings: int, embedding_dim: int, max_norm: float | None = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, padding_idx: int | None = None, device: DeviceLike = None, dtype: DTypeLike = None)

Initialise the EmbeddingBag module. See the class docstring for parameter semantics.

fn

forward

Tensor
forward(x: Tensor, offsets: Tensor | None = None)

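Look up the embeddings for the given indices and reduce them per bag.

Parameters

x : Tensor
Tensor of integer indices: shape (B, L) when offsets is None, or a flat 1-D tensor of shape (total_indices,) when offsets is provided.
offsets : Tensor or None = None
1-D integer tensor of length B marking the start position of each bag in x. Default: None.

Returns

Tensor

Tensor of reduced bag embeddings of shape (B, embedding_dim).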

fn

extra_repr

str
extra_repr()

Return a string representation of the layer's configuration.