vsplit

vsplit(x: Tensor, indices_or_sections: int | Sequence[int]) -> list[Tensor]

Split a tensor along its first (vertical) axis.

NumPy-style vertical split: cuts the input into pieces along axis 0. With an integer k, the tensor is divided into k near-equal pieces; with a sequence of indices, the cuts occur at those positions.
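Because this is described as a NumPy-style split, a NumPy reference can be handy for checking expectations. One difference matters: NumPy's own `np.vsplit` raises when an integer does not divide the axis evenly, whereas `np.array_split(..., axis=0)` produces the near-equal pieces described here. The sketch below assumes lucid's uneven integer case matches `np.array_split`; it is a comparison aid, not lucid code.

```python
import numpy as np

x = np.arange(30).reshape(10, 3)

# Integer sections, uneven: np.array_split gives the first 10 % 3 == 1 piece an extra row.
uneven = np.array_split(x, 3, axis=0)
print([p.shape for p in uneven])  # [(4, 3), (3, 3), (3, 3)]

# Index sequence: cuts before rows 2 and 5, i.e. x[:2], x[2:5], x[5:].
by_index = np.vsplit(x, [2, 5])
print([p.shape for p in by_index])  # [(2, 3), (3, 3), (5, 3)]

# np.vsplit with an integer requires an even division and raises otherwise.
try:
    np.vsplit(x, 3)
except ValueError as err:
    print("np.vsplit:", err)
```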

Parameters

x : Tensor
  Input tensor with at least 1 dimension.
indices_or_sections : int | Sequence[int]
  • int: number of near-equal pieces k. If the axis length is not divisible by k, the first axis_len % k pieces get one extra element.
  • Sequence[int]: indices along axis 0 at which to cut. For example, [2, 5] yields x[:2], x[2:5] and x[5:].
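Both parameter forms reduce to a list of piece boundaries along axis 0, which can be modeled on plain Python lists. The helper `vsplit_list` below is hypothetical and not part of lucid; it is a minimal sketch assuming that index-sequence cuts behave like NumPy's, with the boundaries bracketed by 0 and the axis length.

```python
from typing import Sequence, TypeVar, Union

T = TypeVar("T")


def vsplit_list(rows: Sequence[T], indices_or_sections: Union[int, Sequence[int]]) -> list[list[T]]:
    """Hypothetical pure-Python model of a vertical split over the outer axis."""
    n = len(rows)
    if isinstance(indices_or_sections, int):
        k = indices_or_sections
        if k <= 0:
            raise ValueError("number of sections must be positive")
        base, extra = divmod(n, k)
        # The first `extra` pieces get one more element than the rest.
        sizes = [base + 1 if i < extra else base for i in range(k)]
        bounds = [0]
        for size in sizes:
            bounds.append(bounds[-1] + size)
    else:
        # Cut positions become piece boundaries, bracketed by 0 and n.
        bounds = [0, *indices_or_sections, n]
    # Consecutive boundary pairs (start, stop) define each piece.
    return [list(rows[start:stop]) for start, stop in zip(bounds, bounds[1:])]
```

For instance, `vsplit_list(list(range(10)), 3)` returns pieces of lengths 4, 3 and 3, and `vsplit_list(list(range(6)), [2, 5])` returns `[[0, 1], [2, 3, 4], [5]]`.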

Returns

list[Tensor]

Sub-tensors whose concatenation along axis 0 reproduces x.
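The round-trip property can be checked with NumPy as a stand-in, assuming lucid's pieces behave like `np.array_split` output. Concatenation along axis 0 restores the input, while stacking would add a new axis and, for uneven pieces, fail outright.

```python
import numpy as np

x = np.arange(21).reshape(7, 3)
pieces = np.array_split(x, 3, axis=0)  # shapes (3, 3), (2, 3), (2, 3)

# Concatenating along axis 0 undoes the split exactly.
restored = np.concatenate(pieces, axis=0)
assert np.array_equal(restored, x)

# np.stack would instead add a new leading axis, and it requires equal shapes,
# so it fails here because the pieces have different lengths.
try:
    np.stack(pieces)
except ValueError as err:
    print("np.stack:", err)
```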

Raises

ValueError
If x.ndim < 1.

Notes

For a tensor of shape $(N, \dots)$ split into $k$ near-equal pieces, each piece has shape $(\lceil N/k \rceil, \dots)$ or $(\lfloor N/k \rfloor, \dots)$, with the first $N \bmod k$ pieces taking the larger size.
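A short worked check of the size rule: writing $N = kq + r$ with $q = \lfloor N/k \rfloor$ and $r = N \bmod k$, the piece lengths always add back up to $N$. For example, with $N = 10$ and $k = 3$:

```latex
\begin{aligned}
10 &= 3 \cdot 3 + 1 \quad\Rightarrow\quad q = \lfloor 10/3 \rfloor = 3,\ r = 1,\ \lceil 10/3 \rceil = 4,\\
\text{shapes} &: (4, \dots),\ (3, \dots),\ (3, \dots),\\
\text{in general} &: r\,(q + 1) + (k - r)\,q = kq + r = N .
\end{aligned}
```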

Examples

>>> import lucid
>>> x = lucid.arange(12).reshape(4, 3)
>>> [s.shape for s in lucid.vsplit(x, 2)]
[(2, 3), (2, 3)]