take_along_dim

take_along_dim(x: Tensor, indices: Tensor, dim: int) -> Tensor

Gather elements from x at positions indices along dim.

Advanced indexing primitive analogous to np.take_along_axis: selects one element from x for every entry in indices, broadcasting the remaining (non-dim) axes between the two tensors. It is a thin wrapper around lucid.gather that exposes the reference framework's spelling of the name.

Parameters

x : Tensor
Source tensor.
indices : Tensor
Integer tensor of positions along dim. Its shape must be broadcast-compatible with x on every axis other than dim; the size along dim controls the size of the output along that axis.
dim : int
Axis along which to gather. Negative values count from the end.

Returns

Tensor

Tensor with the broadcast shape of x and indices (with the size along dim taken from indices). Dtype matches x.
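
The shape rule above can be illustrated with a small pure-Python helper. take_along_dim_shape is hypothetical and not part of lucid; it assumes x and indices have the same rank and that "broadcast-compatible" means the usual NumPy rule (sizes equal, or one of them is 1). It also shows how a negative dim is normalized.

```python
def take_along_dim_shape(x_shape, idx_shape, dim):
    """Hypothetical helper: output shape of take_along_dim.

    Assumption: x and indices have the same rank, and non-dim axes
    broadcast by the NumPy rule (equal sizes, or one side is 1).
    """
    if len(x_shape) != len(idx_shape):
        raise ValueError("x and indices must have the same rank (assumed)")
    ndim = len(x_shape)
    # Negative dims count from the end, e.g. -1 is the last axis.
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for rank {ndim}")
    dim %= ndim
    out = []
    for axis, (a, b) in enumerate(zip(x_shape, idx_shape)):
        if axis == dim:
            # The size along dim is taken from indices.
            out.append(b)
        elif a == b or b == 1:
            # Equal sizes, or indices broadcasts up to x.
            out.append(a)
        elif a == 1:
            # x broadcasts up to indices.
            out.append(b)
        else:
            raise ValueError(f"axis {axis}: sizes {a} and {b} do not broadcast")
    return tuple(out)
```

For the example at the end of this page, take_along_dim_shape((2, 3), (2, 2), dim=1) gives (2, 2); with x of shape (4, 1, 5), indices of shape (4, 3, 2) and dim=-1, the sketch gives (4, 3, 2).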

Notes

For 1-D inputs the operation reduces to plain integer indexing, out[j] = x[indices[j]]. For higher-rank inputs, writing d for dim, the result satisfies

$$\text{out}[i_0, \dots, i_{d-1}, j, i_{d+1}, \dots] = x[i_0, \dots, i_{d-1},\; \text{indices}[i_0, \dots, i_{d-1}, j, i_{d+1}, \dots],\; i_{d+1}, \dots].$$
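
To make the indexing formula concrete, here is a pure-Python reference on nested lists. take_along_dim_ref and _gather_leading are hypothetical names, not lucid functions; the sketch assumes dim is non-negative and that indices already has the output shape, so no broadcasting is performed.

```python
def _gather_leading(x, idx):
    """out has idx's shape; out[p] = x[idx[p]][p] for every multi-index p."""
    if isinstance(idx, int):
        # No trailing axes left: plain integer indexing into the leading axis.
        return x[idx]
    # Walk trailing axis k on both sides: slice x at k along its second axis.
    return [_gather_leading([row[k] for row in x], idx[k]) for k in range(len(idx))]


def take_along_dim_ref(x, indices, dim):
    """Hypothetical reference for the formula above (nested lists only).

    Assumes dim >= 0 and that indices already has the output shape.
    """
    if dim > 0:
        # Axes i_0, ..., i_{d-1} are matched position by position.
        return [take_along_dim_ref(xi, ii, dim - 1) for xi, ii in zip(x, indices)]
    # dim == 0: out[j, rest] = x[indices[j, rest], rest]
    return [_gather_leading(x, idx_j) for idx_j in indices]
```

Called on the data from the example below with dim=1, the sketch returns [[3.0, 1.0], [5.0, 6.0]]; for a 1-D list it reduces to x[indices[j]] for each j.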

A typical use is gathering the top-k elements of each row, using indices produced by lucid.argsort or lucid.topk.
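
The top-k pattern can be sketched in pure Python on nested lists: sort each row's positions by value, keep the first k, then gather with dim=1. This does not use lucid; argsort_desc_rows, gather_dim1 and topk_rows are hypothetical stand-ins, and the exact signatures of lucid.argsort and lucid.topk are not assumed here.

```python
def argsort_desc_rows(x):
    # Per-row argsort, largest value first (stand-in for lucid.argsort).
    return [sorted(range(len(row)), key=lambda i: row[i], reverse=True) for row in x]


def gather_dim1(x, indices):
    # 2-D take_along_dim with dim=1: out[i][j] = x[i][indices[i][j]].
    return [[row[j] for j in idx_row] for row, idx_row in zip(x, indices)]


def topk_rows(x, k):
    # Keep the positions of the k largest entries per row, then gather them.
    top_idx = [order[:k] for order in argsort_desc_rows(x)]
    return gather_dim1(x, top_idx), top_idx
```

For x = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.9]] and k=2, topk_rows returns the values [[0.7, 0.2], [0.9, 0.5]] and the indices [[1, 2], [2, 0]]. Keeping the indices around is what makes the gather step reusable, for example to pick matching entries from a second tensor of the same shape.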

Examples

>>> import lucid
>>> x = lucid.tensor([[1.0, 2.0, 3.0],
...                   [4.0, 5.0, 6.0]])
>>> idx = lucid.tensor([[2, 0], [1, 2]])
>>> lucid.take_along_dim(x, idx, dim=1)
Tensor([[3., 1.],
        [5., 6.]])