
tensordot

tensordot(a: Tensor, b: Tensor, dims: _int | list[list[_int]] = 2) -> Tensor

Generalized tensor contraction over arbitrary axes.

Contracts a and b along the axes specified by dims, generalizing matrix multiplication to higher-rank tensors. dims may be:

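  • an int N: contract the last N axes of a with the first N axes of b, in order;
  • a pair of axis lists [[a_axes], [b_axes]]: contract each listed axis of a with the axis of b at the same position in its list. The two lists must have the same length, and each paired axis must have the same size.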
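Both forms describe the same contraction: an int N is shorthand for the axis lists [[a.ndim - N, ..., a.ndim - 1], [0, ..., N - 1]]. The sketch below assumes that reading and uses NumPy's np.tensordot as a stand-in (its keyword is axes, not dims, and it is not lucid's API). The helper normalize_dims is hypothetical.

```python
import numpy as np

def normalize_dims(ndim_a, dims):
    """Hypothetical helper: expand an int N into explicit [[a_axes], [b_axes]]."""
    if isinstance(dims, int):
        # The last N axes of a pair, in order, with the first N axes of b.
        return [list(range(ndim_a - dims, ndim_a)), list(range(dims))]
    a_axes, b_axes = dims
    if len(a_axes) != len(b_axes):
        raise ValueError("axis lists must have the same length")
    return [list(a_axes), list(b_axes)]

a = np.arange(60.0).reshape(3, 4, 5)
b = np.arange(40.0).reshape(4, 5, 2)

explicit = normalize_dims(a.ndim, 2)
print(explicit)  # [[1, 2], [0, 1]]

# NumPy stand-in: the int form and the explicit axis-list form agree.
same = np.allclose(np.tensordot(a, b, axes=2), np.tensordot(a, b, axes=explicit))
print(same)  # True
```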

Parameters

a : Tensor
First operand.
b : Tensor
Second operand.
dims : int or list of list of int, default 2
Which axes to contract: either an int N or a pair of axis lists [[a_axes], [b_axes]].

Returns

Tensor

Contracted result whose remaining axes are the un-contracted axes of a followed by those of b.
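The shape rule above can be checked with a small pure-Python helper. tensordot_shape is hypothetical, assumes non-negative axis indices, and assumes each group keeps its original axis order (the NumPy convention); the text above states only that the remaining axes of a precede those of b.

```python
def tensordot_shape(shape_a, shape_b, a_axes, b_axes):
    """Hypothetical helper: result shape of contracting a_axes of a with b_axes of b."""
    for i, j in zip(a_axes, b_axes):
        # Each contracted pair must have matching sizes.
        if shape_a[i] != shape_b[j]:
            raise ValueError(
                f"size mismatch: a axis {i} ({shape_a[i]}) vs b axis {j} ({shape_b[j]})"
            )
    keep_a = [d for k, d in enumerate(shape_a) if k not in a_axes]
    keep_b = [d for k, d in enumerate(shape_b) if k not in b_axes]
    # Uncontracted axes of a followed by uncontracted axes of b.
    return tuple(keep_a + keep_b)

print(tensordot_shape((3, 4, 5), (4, 5, 2), [1, 2], [0, 1]))  # (3, 2)
print(tensordot_shape((2, 3, 4), (4, 3, 6), [1, 2], [1, 0]))  # (2, 6)
```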

Notes

Mathematical definition for axis lists $I_a, I_b$:

$$\text{out}_{\bar I_a,\, \bar I_b} = \sum_{I_a = I_b} a_{\bar I_a,\, I_a}\, b_{I_b,\, \bar I_b}$$

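where $\bar I_a$ and $\bar I_b$ denote the axes that are not contracted, and the sum runs over all index values on which each contracted axis of a agrees with its paired axis of b. This is equivalent to permuting a so that its contracted axes come last and b so that its contracted axes come first, reshaping both to matrices, multiplying them with matmul, and reshaping the product back to the uncontracted axes.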
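A minimal NumPy sketch of the permute, reshape, matmul, reshape equivalence, cross-checked against the summation written with np.einsum. It illustrates the math only and is not lucid's implementation; tensordot_via_matmul is a hypothetical name.

```python
import numpy as np
from math import prod

def tensordot_via_matmul(a, b, a_axes, b_axes):
    """Sketch: contraction as permute -> reshape to 2-D -> matmul -> reshape."""
    free_a = [k for k in range(a.ndim) if k not in a_axes]
    free_b = [k for k in range(b.ndim) if k not in b_axes]
    # Move a's contracted axes to the end and b's to the front,
    # preserving the pairing a_axes[i] <-> b_axes[i].
    a_perm = np.transpose(a, free_a + list(a_axes))
    b_perm = np.transpose(b, list(b_axes) + free_b)
    k = prod(a.shape[i] for i in a_axes)  # total size of the contracted axes
    m = prod(a.shape[i] for i in free_a)
    n = prod(b.shape[j] for j in free_b)
    out = a_perm.reshape(m, k) @ b_perm.reshape(k, n)
    # Restore the uncontracted axes: those of a, then those of b.
    return out.reshape([a.shape[i] for i in free_a] + [b.shape[j] for j in free_b])

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4, 5))
b = rng.standard_normal((5, 2, 4))
# Pair a's axis 1 (size 4) with b's axis 2, and a's axis 2 (size 5) with b's axis 0.
result = tensordot_via_matmul(a, b, [1, 2], [2, 0])
# The summation definition, written with einsum.
reference = np.einsum("ijk,klj->il", a, b)
print(result.shape, np.allclose(result, reference))  # (3, 2) True
```

Flattening the contracted axes into a single dimension of size k turns the whole contraction into one matrix product, which is why the reshape step must use the same axis order for a and b.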

Examples

>>> import lucid
>>> a = lucid.zeros((3, 4, 5))
>>> b = lucid.zeros((4, 5, 2))
>>> lucid.tensordot(a, b, dims=2).shape
(3, 2)