baddbmm
baddbmm(input: Tensor, batch1: Tensor, batch2: Tensor, beta: float = ..., alpha: float = ...) -> Tensor
Batched GEMM with a batched accumulator.
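Computes
$$\text{out} = \beta \cdot \text{input} + \alpha \cdot \operatorname{bmm}(\text{batch1}, \text{batch2}),$$
where the batch axis is preserved (unlike addbmm, which reduces it).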
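To make the "batch axis is preserved" distinction concrete, here is a minimal NumPy sketch of both operations. It is not lucid's implementation: `baddbmm_ref` and `addbmm_ref` are hypothetical names, and `addbmm_ref` assumes lucid's addbmm follows the PyTorch convention of summing the per-batch products over the batch axis.

```python
import numpy as np


def baddbmm_ref(inp, batch1, batch2, beta=1.0, alpha=1.0):
    # (B, M, K) @ (B, K, N) -> (B, M, N): np.matmul treats the leading axis as a batch axis
    return beta * inp + alpha * np.matmul(batch1, batch2)


def addbmm_ref(inp, batch1, batch2, beta=1.0, alpha=1.0):
    # Assumed PyTorch-style addbmm: the B per-batch products are summed, so inp is (M, N)
    return beta * inp + alpha * np.matmul(batch1, batch2).sum(axis=0)


rng = np.random.default_rng(0)
b1 = rng.standard_normal((3, 2, 4))
b2 = rng.standard_normal((3, 4, 2))

kept = baddbmm_ref(np.zeros((3, 2, 2)), b1, b2)
reduced = addbmm_ref(np.zeros((2, 2)), b1, b2)
print(kept.shape)     # (3, 2, 2)
print(reduced.shape)  # (2, 2)
# Summing the batched result over axis 0 recovers the addbmm-style result
print(np.allclose(kept.sum(axis=0), reduced))  # True
```

The only difference between the two is the final `.sum(axis=0)`: baddbmm keeps one (M, N) result per batch element, while the reducing variant collapses them into a single (M, N) matrix.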
Parameters
input : Tensor
    Accumulator of shape (B, M, N).
batch1 : Tensor
    Batched left matrices of shape (B, M, K).
batch2 : Tensor
    Batched right matrices of shape (B, K, N).
beta : float
    Scalar multiplier on input. Defaults to 1.0.
alpha : float
    Scalar multiplier on bmm(batch1, batch2). Defaults to 1.0.
Returns
Tensor
    Tensor of shape (B, M, N).
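The shape contract in the parameter list can be checked mechanically. The helper below, `check_baddbmm_shapes`, is a hypothetical sketch, not part of lucid. It enforces the shapes exactly as listed; whether lucid additionally broadcasts `input` is not stated here, so the strict equality check is an assumption.

```python
def check_baddbmm_shapes(inp_shape, b1_shape, b2_shape):
    """Return (B, M, K, N) if the shapes match the documented contract, else raise ValueError."""
    if len(inp_shape) != 3 or len(b1_shape) != 3 or len(b2_shape) != 3:
        raise ValueError("input, batch1 and batch2 must all be 3-D")
    B, M, K = b1_shape
    B2, K2, N = b2_shape
    # batch2 must share the batch size B and the inner dimension K with batch1
    if B2 != B or K2 != K:
        raise ValueError(f"batch2 must have shape ({B}, {K}, N), got {tuple(b2_shape)}")
    # The accumulator must already have the output shape (B, M, N)
    if tuple(inp_shape) != (B, M, N):
        raise ValueError(f"input must have shape ({B}, {M}, {N}), got {tuple(inp_shape)}")
    return B, M, K, N


print(check_baddbmm_shapes((3, 2, 2), (3, 2, 4), (3, 4, 2)))  # (3, 2, 4, 2)
try:
    # Inner dimensions disagree: batch1 has K = 4, batch2 has K = 5
    check_baddbmm_shapes((3, 2, 2), (3, 2, 4), (3, 5, 2))
except ValueError as err:
    print(err)
```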
Notes
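Per-batch definition, for each b = 0, ..., B - 1:
$$\text{out}_b = \beta \cdot \text{input}_b + \alpha \cdot (\text{batch1}_b \, \text{batch2}_b)$$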
The implementation defers to lucid.bmm, which dispatches to a batched GEMM from Apple's Accelerate framework on CPU and to MLX on GPU.
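The per-batch definition says each batch element is an independent GEMM. The sketch below, written in NumPy with the hypothetical name `baddbmm_loop`, spells that out as an explicit loop over b and checks it against the vectorized formula; it illustrates the math, not how lucid.bmm is actually executed.

```python
import numpy as np


def baddbmm_loop(inp, batch1, batch2, beta=1.0, alpha=1.0):
    """Reference loop over the batch axis; one independent GEMM per b (sketch, not lucid's code)."""
    B = batch1.shape[0]
    out = np.empty(inp.shape, dtype=np.result_type(inp, batch1, batch2, float))
    for b in range(B):
        # out_b = beta * input_b + alpha * (batch1_b @ batch2_b)
        out[b] = beta * inp[b] + alpha * (batch1[b] @ batch2[b])
    return out


rng = np.random.default_rng(1)
inp = rng.standard_normal((3, 2, 2))
b1 = rng.standard_normal((3, 2, 4))
b2 = rng.standard_normal((3, 4, 2))

looped = baddbmm_loop(inp, b1, b2, beta=0.5, alpha=2.0)
vectorized = 0.5 * inp + 2.0 * np.matmul(b1, b2)
print(np.allclose(looped, vectorized))  # True
```

Because no term mixes different values of b, a backend is free to run the B products in one batched call, which is what delegating to a batched GEMM exploits.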
Examples
>>> import lucid
>>> b1 = lucid.ones((3, 2, 4))
>>> b2 = lucid.ones((3, 4, 2))
>>> M = lucid.zeros((3, 2, 2))
>>> lucid.baddbmm(M, b1, b2).shape
(3, 2, 2)
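The example above only shows the shape. Reproducing the same inputs in NumPy (as a stand-in, not lucid itself) shows the values: every entry is a dot product of two length-4 vectors of ones, so with a zero accumulator and the default scalars each output entry is 4.0. The second case mirrors what lucid.baddbmm(lucid.ones((3, 2, 2)), b1, b2, beta=0.5, alpha=2.0) should compute, assuming beta and alpha are passed by keyword as the signature suggests.

```python
import numpy as np

# Same shapes as the lucid example, reproduced with NumPy as a stand-in
b1 = np.ones((3, 2, 4))
b2 = np.ones((3, 4, 2))
M = np.zeros((3, 2, 2))

# Defaults beta = alpha = 1.0: each entry is sum over K = 4 of 1 * 1, i.e. 4.0
out = 1.0 * M + 1.0 * np.matmul(b1, b2)
print(out.shape)                            # (3, 2, 2)
print(float(out.min()), float(out.max()))  # 4.0 4.0

# Non-default scalars with a ones accumulator: 0.5 * 1 + 2.0 * 4 = 8.5 everywhere
out2 = 0.5 * np.ones((3, 2, 2)) + 2.0 * np.matmul(b1, b2)
print(float(out2[0, 0, 0]))                # 8.5
```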