fn

embedding_bag

embedding_bag(x: Tensor, weight: Tensor, offsets: Tensor | None = None, max_norm: float | None = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = 'mean', sparse: bool = False, per_sample_weights: Tensor | None = None, include_last_offset: bool = False, padding_idx: int | None = None) -> Tensor

Aggregate embeddings into per-bag pooled vectors.

Conceptually equivalent to looking up each index with embedding and then reducing across the bag axis, but fused into a single op that avoids materialising the per-token embedding matrix of shape (num_indices, embedding_dim). This is essential for very large vocabularies in recommendation and NLP models.

Given a flat index sequence x partitioned into bags by offsets, the i-th output row is

$$\mathrm{out}[i] = \mathrm{reduce}_{j \in \mathrm{bag}_i} W[\, x[j] \,]$$

where reduce is one of sum, mean, or max.
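The formula above can be spelled out as a small reference loop. The sketch below uses plain Python lists so it runs without any tensor library; the name `embedding_bag_ref` is hypothetical and this is not Lucid's implementation. Returning zeros for an empty bag is an assumption made for the sketch, not behaviour stated on this page.

```python
def embedding_bag_ref(ids, weight, offsets, mode="mean"):
    """Pure-Python reference: pool rows of `weight` over each bag of `ids`."""
    if mode not in ("sum", "mean", "max"):
        raise ValueError(f"unsupported mode: {mode!r}")
    dim = len(weight[0])
    # Bag i spans ids[offsets[i]:offsets[i + 1]]; the last bag runs to the end.
    bounds = list(offsets) + [len(ids)]
    out = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = [weight[j] for j in ids[start:end]]
        if not rows:
            # Assumption for this sketch: an empty bag pools to zeros.
            out.append([0.0] * dim)
            continue
        columns = list(zip(*rows))  # one tuple per embedding dimension
        if mode == "max":
            out.append([float(max(col)) for col in columns])
        else:
            pooled = [float(sum(col)) for col in columns]
            if mode == "mean":
                pooled = [value / len(rows) for value in pooled]
            out.append(pooled)
    return out


# Toy table: row r is [r, 10 * r], so pooled values are easy to check by hand.
weight = [[float(r), float(10 * r)] for r in range(10)]
ids = [1, 2, 4, 5, 4, 3, 2, 9]
offsets = [0, 4]

pooled_sum = embedding_bag_ref(ids, weight, offsets, mode="sum")
pooled_mean = embedding_bag_ref(ids, weight, offsets, mode="mean")
pooled_max = embedding_bag_ref(ids, weight, offsets, mode="max")
```

With these inputs, bag 0 holds indices 1, 2, 4, 5 and bag 1 holds indices 4, 3, 2, 9, so each mode reduces four rows per bag.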

Parameters

x: Tensor
Either a 1-D index tensor (use with offsets) or a 2-D (num_bags, seq_len) tensor of indices where each row is a bag of equal length.
weight: Tensor
Embedding table of shape (num_embeddings, embedding_dim).
offsets: Tensor = None
Required when x is 1-D. Integer tensor whose i-th element is the position in x at which bag i starts.
max_norm: float = None
Renormalise embedding rows whose L_p norm exceeds max_norm before lookup.
norm_type: float = 2.0
The exponent p of the L_p norm used with max_norm. Default 2.0.
scale_grad_by_freq: bool = False
Scale gradients of each embedding row by inverse mini-batch frequency.
mode: str = 'mean'
Bag reduction: "sum", "mean" (default), or "max".
sparse: bool = False
Request a sparse gradient (accepted for compatibility).
per_sample_weights: Tensor = None
Optional per-element weights applied before reduction. Must have the same shape as x. Most reference implementations accept it only with mode="sum".
include_last_offset: bool = False
If True, offsets has length num_bags + 1 and its last entry is the total number of indices in x.
padding_idx: int = None
Embedding row to mask out (its lookup result contributes zero).
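To make the interaction of per_sample_weights, include_last_offset, and padding_idx concrete, here is a minimal pure-Python sketch of the sum reduction only. The helper `weighted_bag_sum` is hypothetical and follows the parameter descriptions above; it is not taken from Lucid's source, and it leaves out the other modes.

```python
def weighted_bag_sum(ids, weight, offsets, per_sample_weights=None,
                     include_last_offset=False, padding_idx=None):
    """Sum-pool each bag, honouring per-sample weights and a padding index."""
    dim = len(weight[0])
    # With include_last_offset the final boundary is already in `offsets`;
    # otherwise the last bag implicitly ends at len(ids).
    bounds = list(offsets) if include_last_offset else list(offsets) + [len(ids)]
    if per_sample_weights is None:
        per_sample_weights = [1.0] * len(ids)
    out = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        acc = [0.0] * dim
        for pos in range(start, end):
            idx = ids[pos]
            if idx == padding_idx:
                continue  # masked row contributes zero
            scale = per_sample_weights[pos]
            for k in range(dim):
                acc[k] += scale * weight[idx][k]
        out.append(acc)
    return out


# Toy table: row r is [r + 1, 10 * (r + 1)], so row 0 is non-zero
# and the effect of masking it is visible.
weight = [[float(r + 1), float(10 * (r + 1))] for r in range(4)]
ids = [1, 0, 3, 2, 0]
psw = [1.0, 1.0, 0.5, 2.0, 1.0]

masked = weighted_bag_sum(ids, weight, [0, 3, 5], per_sample_weights=psw,
                          include_last_offset=True, padding_idx=0)
same_bags = weighted_bag_sum(ids, weight, [0, 3], per_sample_weights=psw,
                             padding_idx=0)
unmasked = weighted_bag_sum(ids, weight, [0, 3, 5], per_sample_weights=psw,
                            include_last_offset=True)
```

Here `[0, 3, 5]` with include_last_offset=True and `[0, 3]` without it describe the same two bags, so `masked` and `same_bags` agree, while `unmasked` additionally picks up row 0 in both bags.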

Returns

Tensor

Pooled output of shape (num_bags, embedding_dim).

Notes

Compared with embedding followed by a manual reduction, embedding_bag avoids materialising the intermediate per-token embedding tensor and fuses the reduction into a single scatter-add (or scatter-max) pass.
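The difference between the two strategies can be illustrated in plain Python. Both function names below are hypothetical, and the sketch assumes offsets is sorted and starts at 0. It only mirrors the idea of a scatter-add pass and says nothing about how Lucid's kernel is actually written.

```python
from bisect import bisect_right


def gather_then_sum(ids, weight, offsets):
    """Two-step baseline: build every per-token row, then reduce each bag."""
    dim = len(weight[0])
    gathered = [weight[i] for i in ids]  # len(ids) rows held at once
    bounds = list(offsets) + [len(ids)]
    return [
        [float(sum(row[k] for row in gathered[s:e])) for k in range(dim)]
        for s, e in zip(bounds[:-1], bounds[1:])
    ]


def fused_bag_sum(ids, weight, offsets):
    """Single pass: add each looked-up row straight into its bag's output."""
    dim = len(weight[0])
    out = [[0.0] * dim for _ in offsets]
    for pos, idx in enumerate(ids):
        # The bag owning position `pos` is the last one whose offset is <= pos.
        bag = bisect_right(offsets, pos) - 1
        row = weight[idx]
        for k in range(dim):
            out[bag][k] += row[k]  # scatter-add into the output row
    return out


weight = [[float(r), float(10 * r)] for r in range(10)]
ids = [1, 2, 4, 5, 4, 3, 2, 9]

two_step = gather_then_sum(ids, weight, [0, 4])
fused = fused_bag_sum(ids, weight, [0, 4])
```

Both functions return the same pooled rows; the fused version never holds more than the output buffer, whereas the baseline keeps one row per index before reducing.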

Examples

>>> import lucid
>>> from lucid.nn.functional import embedding_bag
>>> w = lucid.randn(10, 4)
>>> ids = lucid.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=lucid.int64)
>>> off = lucid.tensor([0, 4], dtype=lucid.int64)
>>> out = embedding_bag(ids, w, offsets=off, mode="mean")
>>> out.shape
(2, 4)
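The 2-D form of x described under Parameters carries the same information as a flat index sequence plus offsets. The sketch below shows that correspondence with plain Python lists; `flatten_bags` is a hypothetical helper for illustration, not part of Lucid.

```python
def flatten_bags(bags):
    """Convert a list of bags into (flat index list, bag start offsets)."""
    flat, offsets = [], []
    for bag in bags:
        offsets.append(len(flat))  # this bag starts where the previous ended
        flat.extend(bag)
    return flat, offsets


# Two bags of length 4, matching the ids and offsets used in the example above.
flat_ids, bag_offsets = flatten_bags([[1, 2, 4, 5], [4, 3, 2, 9]])
```

The resulting `flat_ids` and `bag_offsets` are the values one would pass as the 1-D x and offsets arguments.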