multi_dot

multi_dot(tensors: list[Tensor]) -> Tensor

Multiply a sequence of matrices as a single chained product.

Computes the product of a list of matrices

$$A_1 \, A_2 \, \cdots \, A_n,$$

associating left-to-right. An optimal parenthesization can substantially reduce the FLOP count for chains whose inner dimensions vary widely, but the current implementation does not search for one; it always associates left-to-right, which is already locally optimal in the most common case.
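
As a rough illustration of what left-to-right association means, the chain can be folded from the left with a plain two-matrix product. The sketch below uses nested Python lists rather than lucid tensors; `matmul` and `chain_left_to_right` are hypothetical names introduced for this example and are not part of the library.

```python
from functools import reduce


def matmul(a, b):
    """Multiply two matrices given as nested lists (lists of rows)."""
    # zip(*b) iterates over the columns of b.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def chain_left_to_right(matrices):
    """Hypothetical reference: evaluate ((A1 A2) A3) ... An, folding from the left."""
    return reduce(matmul, matrices)


A = [[1.0, 2.0]]
B = [[3.0], [4.0]]
C = [[5.0]]
result = chain_left_to_right([A, B, C])
print(result)  # [[55.0]]
```

With a single matrix in the list, the fold returns that matrix unchanged.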

Parameters

tensors : list of Tensor
Sequence of at least one matrix. Inner dimensions must agree (tensors[i].shape[-1] == tensors[i+1].shape[-2]).
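
The inner-dimension rule can be checked on shapes alone before any arithmetic is done. The following is a minimal sketch, assuming shapes are given as tuples; `check_chain_shapes` is a hypothetical helper, and raising `ValueError` is an assumption of this example, not documented behavior of `multi_dot`.

```python
def check_chain_shapes(shapes):
    """Hypothetical helper: verify that consecutive matrix shapes can be chained.

    Returns the (rows, cols) shape of the product, or raises ValueError.
    """
    if not shapes:
        raise ValueError("expected at least one matrix")
    for i in range(len(shapes) - 1):
        # Columns of matrix i must equal rows of matrix i + 1.
        if shapes[i][-1] != shapes[i + 1][-2]:
            raise ValueError(
                f"shape mismatch between tensors[{i}] and tensors[{i + 1}]: "
                f"{shapes[i]} vs {shapes[i + 1]}"
            )
    return (shapes[0][-2], shapes[-1][-1])


out_shape = check_chain_shapes([(1, 2), (2, 1), (1, 1)])
print(out_shape)  # (1, 1)
```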

Returns

Tensor

The chained matrix product.

Notes

If $A_i$ has shape $d_{i-1} \times d_i$, left-to-right association costs $O\left(d_0 \sum_{i=1}^{n-1} d_i \, d_{i+1}\right)$ scalar multiplications. For long chains, choosing the optimal split points via dynamic programming can reduce this to a significantly smaller bound.
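
To make the gap concrete, the sketch below compares the left-to-right cost with the minimum found by the classic matrix-chain dynamic program. It is illustrative only and is not what `multi_dot` currently does; the function names and the `dims` convention (matrix `i` has shape `dims[i - 1] x dims[i]`, counting from 1) are assumptions of this example.

```python
def left_to_right_cost(dims):
    """Scalar multiplications for ((A1 A2) A3) ... An."""
    # The running product always has dims[0] rows.
    return sum(dims[0] * dims[i] * dims[i + 1] for i in range(1, len(dims) - 1))


def optimal_cost(dims):
    """Minimum scalar multiplications over all parenthesizations (O(n^3) DP)."""
    n = len(dims) - 1  # number of matrices in the chain
    # cost[i][j] = cheapest way to multiply matrices i..j (0-based, inclusive).
    cost = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            # Try every split point k: (i..k) times (k+1..j).
            cost[i][j] = min(
                cost[i][k] + cost[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                for k in range(i, j)
            )
    return cost[0][n - 1] if n > 0 else 0


# Shapes: 1000x10, 10x1000, 1000x1.
dims = [1000, 10, 1000, 1]
print(left_to_right_cost(dims))  # 11000000
print(optimal_cost(dims))        # 20000
```

Here the optimal order is `A1 (A2 A3)`, which avoids forming the 1000 x 1000 intermediate that left-to-right evaluation produces.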

Examples

>>> import lucid
>>> from lucid.linalg import multi_dot
>>> A = lucid.tensor([[1.0, 2.0]])
>>> B = lucid.tensor([[3.0], [4.0]])
>>> C = lucid.tensor([[5.0]])
>>> multi_dot([A, B, C])
Tensor([[55.0000]])