bmm

bmm(input: Tensor, other: Tensor | Scalar) -> Tensor

Batched matrix multiplication.

Computes a matmul over a single explicit batch dimension. Both inputs must be 3-D with matching leading batch size: shapes (B, n, k) and (B, k, m). No broadcasting is performed on the batch axis; use matmul for full broadcasting semantics.
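The shape rules can be written as a small checker. `bmm_output_shape` below is a hypothetical pure-Python helper that exists only to illustrate the constraints. It is not part of lucid, and lucid's own error messages may differ.

```python
def bmm_output_shape(a_shape, b_shape):
    """Return the result shape of a batched matmul under the rules above.

    Hypothetical helper for illustration only; not a lucid function.
    """
    if len(a_shape) != 3 or len(b_shape) != 3:
        raise ValueError(f"bmm expects 3-D inputs, got {len(a_shape)}-D and {len(b_shape)}-D")
    batch_a, n, k_a = a_shape
    batch_b, k_b, m = b_shape
    # No broadcasting: batch sizes must match exactly (a batch of 1 is not stretched).
    if batch_a != batch_b:
        raise ValueError(f"batch sizes differ: {batch_a} vs {batch_b}")
    # The inner (contraction) dimensions must agree.
    if k_a != k_b:
        raise ValueError(f"inner dimensions differ: {k_a} vs {k_b}")
    return (batch_a, n, m)


print(bmm_output_shape((4, 3, 5), (4, 5, 2)))  # (4, 3, 2)
try:
    bmm_output_shape((1, 3, 5), (4, 5, 2))
except ValueError as err:
    print("rejected:", err)  # rejected: batch sizes differ: 1 vs 4
```

The second call is the case that `matmul`'s broadcasting would accept but `bmm` rejects.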

Parameters

input : Tensor
Left operand of shape (B, n, k).
other : Tensor or scalar
Right operand of shape (B, k, m).

Returns

Tensor

Stack of B matrix products with shape (B, n, m).

Notes

Mathematical definition for each batch index b, writing A for input, B for other and C for the result:

\mathbf{C}_{b,ij} = \sum_{r=1}^{k} \mathbf{A}_{b,ir}\,\mathbf{B}_{b,rj}
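A loop-based reading of the sum can make the indices concrete. `bmm_reference` is a hypothetical pure-Python reference on nested lists. It is not lucid's implementation and is far slower than a real kernel.

```python
def bmm_reference(a, b):
    """Naive batched matmul on nested lists, following the per-batch sum.

    a has shape (B, n, k), b has shape (B, k, m); returns shape (B, n, m).
    Hypothetical reference code, not lucid's implementation.
    """
    if len(a) != len(b):
        raise ValueError("batch sizes must match")
    out = []
    for a_b, b_b in zip(a, b):  # one independent product per batch index b
        k = len(b_b)
        m = len(b_b[0])
        c_b = []
        for row in a_b:  # row index i of A_b
            if len(row) != k:
                raise ValueError("inner dimensions must match")
            # C_b[i][j] = sum over r of A_b[i][r] * B_b[r][j]
            c_b.append([sum(row[r] * b_b[r][j] for r in range(k)) for j in range(m)])
        out.append(c_b)
    return out


a = [[[1, 2], [3, 4]], [[1, 0], [0, 1]]]
b = [[[5, 6], [7, 8]], [[2, 3], [4, 5]]]
print(bmm_reference(a, b))  # [[[19, 22], [43, 50]], [[2, 3], [4, 5]]]
```

The second batch multiplies by the identity, so its output is simply B_1.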

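Under autograd, each batch slice is differentiated on its own. The gradient for index b depends only on that slice's operands and its upstream gradient, so nothing flows between batch indices. bmm is common in attention layers, where the batch axis typically carries attention heads or sequences (often flattened together with the outer batch).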
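For C_b = A_b B_b, the standard matrix-calculus gradients are dL/dA_b = G_b B_bᵀ and dL/dB_b = A_bᵀ G_b, where G_b = dL/dC_b. The sketch below applies them slice by slice. `bmm_backward` is a hypothetical pure-Python illustration of the math, not lucid's autograd code.

```python
def matmul(x, y):
    """Plain 2-D matrix product on nested lists."""
    return [[sum(x[i][r] * y[r][j] for r in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def transpose(x):
    return [list(col) for col in zip(*x)]


def bmm_backward(a, b, grad_out):
    """Gradients of C_b = A_b @ B_b for every batch index b.

    grad_out holds dL/dC with shape (B, n, m). Each slice uses only its own
    A_b, B_b and G_b, so no gradient mixes across batch indices.
    Hypothetical illustration, not lucid's autograd implementation.
    """
    grad_a = [matmul(g_b, transpose(b_b)) for g_b, b_b in zip(grad_out, b)]  # G_b B_b^T
    grad_b = [matmul(transpose(a_b), g_b) for a_b, g_b in zip(a, grad_out)]  # A_b^T G_b
    return grad_a, grad_b


a = [[[1, 2], [3, 4]], [[1, 0], [0, 1]]]
b = [[[5, 6], [7, 8]], [[2, 3], [4, 5]]]
# Upstream gradient of ones for batch 0 (loss = sum of C_0), zeros for batch 1.
grad_out = [[[1, 1], [1, 1]], [[0, 0], [0, 0]]]
grad_a, grad_b = bmm_backward(a, b, grad_out)
print(grad_a)  # [[[11, 15], [11, 15]], [[0, 0], [0, 0]]]
print(grad_b)  # [[[4, 4], [6, 6]], [[0, 0], [0, 0]]]
```

Batch 1 receives zero gradient even though its operands are nonzero, because its slice of the upstream gradient is zero. Batch 0's gradients are the row sums of B_0 and the column sums of A_0, as expected for a sum loss.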

Examples

>>> import lucid
>>> a = lucid.zeros((4, 3, 5))
>>> b = lucid.zeros((4, 5, 2))
>>> lucid.bmm(a, b).shape
(4, 3, 2)