fn

baddbmm

Tensor
baddbmm(input: Tensor, batch1: Tensor, batch2: Tensor, beta: float = ..., alpha: float = ...)

Batched GEMM with a batched accumulator.

Computes $\beta \cdot \text{input} + \alpha \cdot \operatorname{bmm}(\text{batch1}, \text{batch2})$, where the batch axis is preserved (unlike addbmm, which reduces it).
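The difference from addbmm shows up most clearly in the output shapes. The sketch below uses NumPy as a stand-in rather than lucid itself. It assumes that addbmm sums the per-batch products into a single (M, N) accumulator, as the "reduces it" remark suggests. The variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
B, M, K, N = 3, 2, 4, 5
batch1 = rng.standard_normal((B, M, K))
batch2 = rng.standard_normal((B, K, N))
beta, alpha = 0.5, 2.0

# baddbmm-style: one (M, N) accumulator per batch element; batch axis kept.
acc_batched = rng.standard_normal((B, M, N))
out_batched = beta * acc_batched + alpha * np.matmul(batch1, batch2)

# addbmm-style (assumed semantics): a single (M, N) accumulator,
# with the per-batch products summed over the batch axis.
acc_single = rng.standard_normal((M, N))
out_reduced = beta * acc_single + alpha * np.matmul(batch1, batch2).sum(axis=0)

print(out_batched.shape)  # (3, 2, 5)
print(out_reduced.shape)  # (2, 5)
```

The batched form keeps B independent results, so it suits cases where each batch element carries its own bias or running sum.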

Parameters

input : Tensor
Accumulator of shape (B, M, N).
batch1 : Tensor
Batched left matrices of shape (B, M, K).
batch2 : Tensor
Batched right matrices of shape (B, K, N).
beta : float
Scalar multiplier on input. Defaults to 1.0.
alpha : float
Scalar multiplier on bmm(batch1, batch2). Defaults to 1.0.

Returns

Tensor

Tensor of shape (B, M, N).

Notes

Per-batch definition:

$$\text{out}[k] = \beta \cdot \text{input}[k] + \alpha \cdot (\text{batch1}[k] \cdot \text{batch2}[k]).$$
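The per-batch definition can be spelled out as a loop over the batch index. This is a minimal pure-Python sketch on nested lists, meant only to illustrate the formula; it is not how lucid computes the result. The helper names `matmul` and `baddbmm_reference` are hypothetical.

```python
def matmul(a, b):
    """Plain (M, K) x (K, N) matrix product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def baddbmm_reference(inp, batch1, batch2, beta=1.0, alpha=1.0):
    """out[k] = beta * inp[k] + alpha * (batch1[k] @ batch2[k]) for each k."""
    out = []
    for acc, left, right in zip(inp, batch1, batch2):
        prod = matmul(left, right)
        out.append([
            [beta * a + alpha * p for a, p in zip(acc_row, prod_row)]
            for acc_row, prod_row in zip(acc, prod)
        ])
    return out


# Shapes: B=2, M=1, K=2, N=1.
inp = [[[10.0]], [[20.0]]]
batch1 = [[[1.0, 2.0]], [[3.0, 4.0]]]
batch2 = [[[5.0], [6.0]], [[7.0], [8.0]]]

result = baddbmm_reference(inp, batch1, batch2, beta=0.5, alpha=2.0)
print(result)  # [[[39.0]], [[116.0]]]
```

For batch 0 the product is 1·5 + 2·6 = 17, so the output is 0.5·10 + 2·17 = 39. Batch 1 gives 0.5·20 + 2·53 = 116.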

The implementation defers to lucid.bmm, which dispatches to Accelerate-batched GEMM on CPU and MLX on GPU.

Examples

>>> import lucid
>>> b1 = lucid.ones((3, 2, 4))
>>> b2 = lucid.ones((3, 4, 2))
>>> M = lucid.zeros((3, 2, 2))
>>> lucid.baddbmm(M, b1, b2).shape
(3, 2, 2)