lucid.argmax
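The argmax function returns the indices of the maximum values of a tensor along a given axis, or a single index into the flattened tensor when no axis is given.
Setting keepdims=True keeps each reduced axis as a dimension of size 1, so the result has the same number of dimensions as the input.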
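As a plain-Python illustration of the definition (not lucid's implementation), the index of the maximum over a flat sequence can be found in a single left-to-right scan. The helper name flat_argmax is hypothetical, and resolving ties to the first occurrence is an assumption borrowed from NumPy's argmax; this page does not state lucid's tie-breaking rule.

```python
from typing import Sequence


def flat_argmax(values: Sequence[float]) -> int:
    """Return the index of the largest value (hypothetical helper).

    Ties resolve to the first occurrence, mirroring NumPy's argmax;
    lucid's own tie-breaking is assumed, not documented here.
    """
    if len(values) == 0:
        raise ValueError("argmax of an empty sequence is undefined")
    best = 0
    for i in range(1, len(values)):
        # Strict '>' keeps the earliest index when two values are equal.
        if values[i] > values[best]:
            best = i
    return best


print(flat_argmax([3, 2, 5, 4]))  # 2
```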
Function Signature
def argmax(
    a: Tensor,
    axis: int | None = None,
    keepdims: bool = False,
) -> Tensor
Parameters
a (Tensor): Input tensor in which to locate the maximum values.
axis (int or None, optional): Axis along which to find the index of the maximum. If None, the input is flattened and a single index into the flattened tensor is returned. Defaults to None.
keepdims (bool, optional): If True, retains reduced dimensions with size 1. Defaults to False.
Returns
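Tensor (Int64): Indices of the maximum values along the specified axis, or a single index into the flattened input when axis is None. Writing \(\operatorname{shape}(a) = (d_0, \ldots, d_{n-1})\) and \(k\) for the reduced axis, the output shape is
\[\begin{split}\operatorname{shape}(\text{out}) \;=\;
\begin{cases}
(d_0, \ldots, d_{k-1}, d_{k+1}, \ldots, d_{n-1}) & \text{if axis} = k \text{ and keepdims=False} \\
(d_0, \ldots, d_{k-1}, 1, d_{k+1}, \ldots, d_{n-1}) & \text{if axis} = k \text{ and keepdims=True} \\
() & \text{if axis=None and keepdims=False} \\
(\underbrace{1, \ldots, 1}_{n}) & \text{if axis=None and keepdims=True}
\end{cases}\end{split}\]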
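The output shape depends only on the input shape, axis and keepdims, so it can be computed without touching the data. The sketch below is a hypothetical plain-Python helper, argmax_out_shape; accepting negative axes (counted from the end) is an assumption modelled on NumPy rather than something this page documents.

```python
from typing import Optional, Tuple


def argmax_out_shape(
    shape: Tuple[int, ...], axis: Optional[int] = None, keepdims: bool = False
) -> Tuple[int, ...]:
    """Output shape of an argmax-style reduction (hypothetical helper)."""
    ndim = len(shape)
    if axis is None:
        # Flattened reduction: a scalar, or all ones with the input's rank.
        return (1,) * ndim if keepdims else ()
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for {ndim}-D input")
    axis %= ndim  # normalise a negative axis (assumed NumPy-style), e.g. -1 -> ndim - 1
    if keepdims:
        # The reduced axis survives with length 1.
        return shape[:axis] + (1,) + shape[axis + 1:]
    # The reduced axis is dropped entirely.
    return shape[:axis] + shape[axis + 1:]


print(argmax_out_shape((2, 2), axis=1))  # (2,)
```

The (2,) result matches the row-wise example further down, which returns one index per row of a 2 x 2 tensor.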
Note
argmax is gradient-free: the returned indices are integers, so no gradient is propagated through them back to a during back-propagation.
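One way to see why no gradient is available: argmax returns integers that stay constant under small perturbations of the input, so its derivative is zero wherever the maximum is unique and undefined where two entries tie and the index jumps. The plain-Python sketch below estimates this with forward differences; it does not call lucid, and the helpers flat_argmax and argmax_finite_diff are hypothetical.

```python
def flat_argmax(values):
    # Index of the largest value; max() returns the first maximal item on ties.
    return max(range(len(values)), key=values.__getitem__)


def argmax_finite_diff(values, eps=1e-6):
    """Forward-difference slope of argmax w.r.t. each input entry."""
    base = flat_argmax(values)
    slopes = []
    for i in range(len(values)):
        bumped = list(values)
        bumped[i] += eps  # nudge a single entry upward
        slopes.append((flat_argmax(bumped) - base) / eps)
    return slopes


print(argmax_finite_diff([3.0, 2.0, 5.0, 4.0]))  # [0.0, 0.0, 0.0, 0.0]
print(argmax_finite_diff([5.0, 5.0]))  # second slope is huge: the index jumps at a tie
```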
Examples
Flattened maximum index
>>> import lucid
>>> x = lucid.Tensor([[3, 2], [5, 4]])
>>> lucid.argmax(x)
Tensor(2, grad=None)
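With axis=None the result indexes the flattened input. The example is consistent with row-major (C-order) flattening: [3, 2, 5, 4] has its maximum at position 2, which is row 1, column 0 of x. This page does not document a lucid routine for recovering coordinates, so the sketch below uses a hypothetical plain-Python helper, unravel_flat_index, playing the role of NumPy's numpy.unravel_index under the row-major assumption.

```python
from typing import Tuple


def unravel_flat_index(index: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Map a row-major flat index to per-axis coordinates (hypothetical helper)."""
    size = 1
    for dim in shape:
        size *= dim
    if not 0 <= index < size:
        raise IndexError(f"index {index} out of range for shape {shape}")
    coords = []
    # In row-major order the last axis varies fastest, so peel it off first.
    for dim in reversed(shape):
        index, rem = divmod(index, dim)
        coords.append(rem)
    return tuple(reversed(coords))


print(unravel_flat_index(2, (2, 2)))  # (1, 0), the position of the value 5 in x
```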
Max index by row
>>> lucid.argmax(x, axis=1)
Tensor([0, 0], grad=None)
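To check the axis and keepdims semantics by hand, the following plain-Python model for 2-D nested lists reproduces both examples above and adds the axis=0 and keepdims=True cases. It is not lucid's implementation: the helper name argmax_2d, the first-occurrence tie rule and the acceptance of negative axes are assumptions modelled on NumPy.

```python
from typing import List, Optional

Matrix = List[List[float]]


def argmax_2d(rows: Matrix, axis: Optional[int] = None, keepdims: bool = False):
    """Plain-Python model of argmax on a 2-D nested list (hypothetical helper).

    Ties resolve to the first occurrence (assumed, as in NumPy).
    """
    def first_max(values):
        # max() returns the first maximal index it encounters.
        return max(range(len(values)), key=values.__getitem__)

    if axis is None:
        flat = [v for row in rows for v in row]  # row-major flattening
        idx = first_max(flat)
        return [[idx]] if keepdims else idx
    if axis in (1, -1):
        # One index per row: search across the columns of each row.
        out = [first_max(row) for row in rows]
        return [[i] for i in out] if keepdims else out
    if axis in (0, -2):
        # One index per column: search down each column.
        cols = [list(col) for col in zip(*rows)]
        out = [first_max(col) for col in cols]
        return [out] if keepdims else out
    raise ValueError("axis must be None, 0, 1, -1 or -2 for 2-D input")


x = [[3, 2], [5, 4]]
print(argmax_2d(x))  # 2
print(argmax_2d(x, axis=1))  # [0, 0]
print(argmax_2d(x, axis=0))  # [1, 1]
print(argmax_2d(x, axis=0, keepdims=True))  # [[1, 1]]
```

Note that axis=1 yields one index per row (a column position), while axis=0 yields one index per column (a row position).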