take_along_dim(x: Tensor, indices: Tensor, dim: int) → Tensor

Gather elements from x at positions indices along dim.
Advanced indexing primitive analogous to np.take_along_axis:
selects one element from x for every entry in indices,
broadcasting the remaining (non-dim) axes between the two
tensors. Thin wrapper around lucid.gather to align with
the reference-framework spelling.
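The gather-and-broadcast behaviour described above can be sketched in pure Python for the 2-D, dim=1 case. This is a hypothetical reference helper (`take_along_rows`), not lucid's implementation. Nested lists stand in for tensors, and a one-row x or indices is broadcast against the other operand, following the broadcasting rule stated above.

```python
def take_along_rows(x, indices):
    """Hypothetical reference for take_along_dim(x, indices, dim=1) on 2-D nested lists.

    The row axis (the non-dim axis) is broadcast: a one-row operand is
    reused for every row of the other operand.
    """
    n_rows = max(len(x), len(indices))
    for operand in (x, indices):
        if len(operand) not in (1, n_rows):
            raise ValueError("row counts are not broadcast-compatible")
    out = []
    for r in range(n_rows):
        x_row = x[r] if len(x) > 1 else x[0]
        idx_row = indices[r] if len(indices) > 1 else indices[0]
        # One output element per entry of indices, picked from the matching row of x.
        out.append([x_row[j] for j in idx_row])
    return out


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(take_along_rows(x, [[2, 0], [1, 2]]))  # [[3.0, 1.0], [5.0, 6.0]]
print(take_along_rows(x, [[0, 0]]))          # [[1.0, 1.0], [4.0, 4.0]]
```

The second call shows broadcasting: the single index row [0, 0] is applied to every row of x.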
Parameters
x : Tensor
    Source tensor.
indices : Tensor
    Integer tensor of positions along dim. Its shape must be
    broadcast-compatible with x on every axis other than dim; the
    size along dim controls the size of the output along that axis.
dim : int
    Axis along which to gather. Negative values count from the end.
Returns
Tensor
    Tensor with the broadcast shape of x and indices (with the size
    along dim taken from indices). Dtype matches x.
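The output-shape rule can be written as a small hypothetical helper, `take_along_dim_shape`, that normalizes a negative dim and broadcasts every axis except dim. It assumes x and indices have the same number of dimensions, as np.take_along_axis requires; whether lucid also accepts lower-rank indices is not stated here.

```python
def take_along_dim_shape(x_shape, idx_shape, dim):
    """Hypothetical helper: output shape of take_along_dim for same-rank inputs."""
    ndim = len(x_shape)
    if len(idx_shape) != ndim:
        raise ValueError("x and indices must have the same number of dims")
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for {ndim}-D input")
    dim %= ndim  # negative values count from the end
    out = []
    for axis, (sx, si) in enumerate(zip(x_shape, idx_shape)):
        if axis == dim:
            out.append(si)  # the size along dim comes from indices
        elif sx == 1:
            out.append(si)  # x is broadcast along this axis
        elif si == 1 or si == sx:
            out.append(sx)  # indices is broadcast, or sizes already agree
        else:
            raise ValueError(f"axis {axis}: sizes {sx} and {si} do not broadcast")
    return tuple(out)


print(take_along_dim_shape((2, 3), (2, 2), dim=1))         # (2, 2)
print(take_along_dim_shape((4, 1, 5), (1, 3, 2), dim=-1))  # (4, 3, 2)
```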
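Notes
For 1-D inputs the operation reduces to plain integer indexing,
out[j] = x[indices[j]]. For higher-rank inputs the result satisfies,
for example with dim = 1 on 3-D tensors (after broadcasting),

$$\mathrm{out}[i, j, k] = x\big[i,\ \mathrm{indices}[i, j, k],\ k\big]$$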
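The index relation for dim = 1 on 3-D tensors, out[i][j][k] = x[i][indices[i][j][k]][k], can be checked with a direct loop. The helper `take_along_dim1_3d` is hypothetical and assumes indices already matches x on the non-dim axes (no broadcasting); the last line shows the 1-D case as plain indexing.

```python
def take_along_dim1_3d(x, indices):
    """Hypothetical reference: out[i][j][k] = x[i][indices[i][j][k]][k] for dim=1."""
    return [
        [
            [x[i][indices[i][j][k]][k] for k in range(len(indices[i][j]))]
            for j in range(len(indices[i]))
        ]
        for i in range(len(indices))
    ]


# x has shape (2, 3, 2); indices has shape (2, 1, 2), so the output has shape (2, 1, 2).
x = [
    [[0, 1], [2, 3], [4, 5]],
    [[6, 7], [8, 9], [10, 11]],
]
indices = [
    [[2, 0]],
    [[1, 1]],
]
print(take_along_dim1_3d(x, indices))  # [[[4, 1]], [[8, 9]]]

# 1-D case: out[j] = x[indices[j]]
print([[10, 20, 30][j] for j in [2, 0]])  # [30, 10]
```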
Typical uses include gathering the top-k elements of each row using
indices produced by lucid.argsort or lucid.topk.
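The top-k-per-row pattern can be sketched in pure Python: compute per-row descending sort indices, keep the first k, then gather along the last axis. Here `argsort_desc` and `take_along_last` are hypothetical stand-ins for lucid.argsort and lucid.take_along_dim(..., dim=-1); the exact signatures of the lucid functions are not assumed.

```python
def argsort_desc(rows):
    """Per-row indices that sort each row in descending order (stand-in for an argsort)."""
    return [sorted(range(len(row)), key=row.__getitem__, reverse=True) for row in rows]


def take_along_last(x, indices):
    """Gather along the last axis of 2-D nested lists with matching row counts."""
    return [[row[j] for j in idx_row] for row, idx_row in zip(x, indices)]


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.9]]
k = 2
top_idx = [order[:k] for order in argsort_desc(scores)]
print(top_idx)                           # [[1, 2], [2, 0]]
print(take_along_last(scores, top_idx))  # [[0.7, 0.2], [0.9, 0.5]]
```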
Examples
>>> import lucid
>>> x = lucid.tensor([[1.0, 2.0, 3.0],
...                   [4.0, 5.0, 6.0]])
>>> idx = lucid.tensor([[2, 0], [1, 2]])
>>> lucid.take_along_dim(x, idx, dim=1)
Tensor([[3., 1.],
        [5., 6.]])