lucid.argmax

lucid.argmax(a: Tensor, axis: int | None = None, keepdims: bool = False) → Tensor

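The argmax function returns the indices of the maximum values of a along the specified axis. If axis is None, it returns a single index into the flattened tensor.

Set keepdims=True to keep the reduced axis as a dimension of size 1, so the output has the same number of dimensions as the input.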

Function Signature

def argmax(
    a: Tensor,
    axis: int | None = None,
    keepdims: bool = False,
) -> Tensor

Parameters

  • a (Tensor): Input tensor in which to locate the maximum values.

  • axis (int or None, optional): Axis along which to find the index of the maximum. If None, the input is flattened. Defaults to None.

  • keepdims (bool, optional): If True, the reduced axis (every axis, when axis is None) is kept as a dimension of size 1. Defaults to False.
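To make the axis and keepdims semantics concrete, here is a minimal pure-Python sketch for 2-D nested lists. The helper reference_argmax is hypothetical and not part of lucid; it only models the behaviour described above. Resolving ties to the first occurrence is an assumption borrowed from NumPy, since this page does not state how lucid breaks ties.

```python
from typing import List, Optional, Union

Matrix = List[List[float]]
ArgmaxResult = Union[int, List[int], List[List[int]]]


def reference_argmax(rows: Matrix, axis: Optional[int] = None,
                     keepdims: bool = False) -> ArgmaxResult:
    """Hypothetical pure-Python model of argmax for a 2-D nested list.

    Ties resolve to the first occurrence; this mirrors NumPy and is an
    assumption, not documented lucid behaviour.
    """
    n_rows, n_cols = len(rows), len(rows[0])
    if axis is None:
        # Flatten in row-major order, then take the position of the largest value.
        flat = [v for row in rows for v in row]
        idx = max(range(len(flat)), key=flat.__getitem__)
        return [[idx]] if keepdims else idx
    if not -2 <= axis < 2:
        raise ValueError(f"axis {axis} is out of range for a 2-D input")
    axis %= 2  # normalise negative axes (-1 -> 1, -2 -> 0)
    if axis == 1:
        # One index per row: the column holding that row's maximum.
        out = [max(range(n_cols), key=row.__getitem__) for row in rows]
        return [[i] for i in out] if keepdims else out
    # axis == 0: one index per column, the row holding that column's maximum.
    out = [max(range(n_rows), key=lambda r: rows[r][c]) for c in range(n_cols)]
    return [out] if keepdims else out
```

With x = [[3, 2], [5, 4]], reference_argmax(x) returns 2 and reference_argmax(x, axis=1) returns [0, 0], which agree with the lucid examples on this page; reference_argmax(x, axis=1, keepdims=True) returns [[0], [0]], keeping the reduced axis with size 1.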

Returns

  • Tensor (Int64): Indices of the maximum values along the specified axis.

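For an input a of shape (d_0, …, d_{n-1}) and axis = k:

\[\begin{split}\operatorname{shape}(\text{out}) \;=\; \begin{cases} (d_0, \ldots, d_{k-1}, 1, d_{k+1}, \ldots, d_{n-1}) & \text{if keepdims=True} \\ (d_0, \ldots, d_{k-1}, d_{k+1}, \ldots, d_{n-1}) & \text{otherwise} \end{cases}\end{split}\]

When axis is None, the output is a 0-d tensor, or a tensor of shape (1, …, 1) with n ones if keepdims=True.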
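The shape rule can be written as a small function. The sketch below is illustrative only: argmax_out_shape is a hypothetical helper, not a lucid API, and accepting negative axes is an assumption (NumPy's convention) that this page does not confirm for lucid.

```python
from typing import Optional, Tuple


def argmax_out_shape(shape: Tuple[int, ...], axis: Optional[int] = None,
                     keepdims: bool = False) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of argmax for a given input shape."""
    ndim = len(shape)
    if axis is None:
        # Flattened reduction: a 0-d result, or all ones when keepdims=True.
        return (1,) * ndim if keepdims else ()
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for a {ndim}-D input")
    axis %= ndim  # normalise negative axes, e.g. -1 -> ndim - 1
    if keepdims:
        # The reduced axis stays in place with size 1.
        return shape[:axis] + (1,) + shape[axis + 1:]
    # Otherwise the reduced axis is dropped.
    return shape[:axis] + shape[axis + 1:]
```

For example, argmax_out_shape((2, 3, 4), axis=1) gives (2, 4), while argmax_out_shape((2, 3, 4), axis=1, keepdims=True) gives (2, 1, 4).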

Note

argmax is not differentiable; no gradients flow back to a through the returned indices during back-propagation.

Examples

Flattened maximum index

>>> import lucid
>>> x = lucid.Tensor([[3, 2], [5, 4]])
>>> lucid.argmax(x)
Tensor(2, grad=None)
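In the flattened example, index 2 counts positions in row-major order (3, 2, 5, 4), so it points at the value 5. To recover per-axis coordinates from such a flat index, a sketch like the following works; unravel_flat_index is a hypothetical helper, not part of lucid (NumPy offers numpy.unravel_index for the same job).

```python
import math
from typing import Tuple


def unravel_flat_index(flat_index: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical helper: map a row-major flat index to per-axis coordinates."""
    if not 0 <= flat_index < math.prod(shape):
        raise ValueError(f"index {flat_index} is out of range for shape {shape}")
    coords = []
    # Peel off the last (fastest-varying) axis first, as in row-major order.
    for dim in reversed(shape):
        flat_index, rem = divmod(flat_index, dim)
        coords.append(rem)
    return tuple(reversed(coords))
```

Here unravel_flat_index(2, (2, 2)) returns (1, 0), and row 1, column 0 of x holds 5.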

Max index by row

>>> lucid.argmax(x, axis=1)
Tensor([0, 0], grad=None)