tensor_split

tensor_split(x: Tensor, indices_or_sections: int | Sequence[int], dim: int = ...) -> list[Tensor]

Split a tensor along dim into consecutive pieces, allowing the pieces to differ in size.

More permissive than lucid.split: when indices_or_sections is an integer k and x.shape[dim] is not divisible by k, the first x.shape[dim] % k pieces receive one extra element along dim, while the remaining pieces take the floor size x.shape[dim] // k. This mirrors NumPy, where np.split raises on non-divisible counts but np.array_split returns uneven pieces; in the same way, lucid.split raises where tensor_split quietly returns ragged pieces.
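A plain-Python sketch of the integer rule, assuming only the behavior described above. near_equal_sizes is a hypothetical helper, not part of lucid; it computes the piece lengths along dim and never touches a tensor.

```python
def near_equal_sizes(n: int, k: int) -> list[int]:
    """Piece lengths for splitting an axis of length n into k near-equal parts.

    Hypothetical helper illustrating the documented rule, not lucid's code.
    """
    if k <= 0:
        # Assumption: a non-positive piece count is treated as an error here.
        raise ValueError("k must be a positive integer")
    base, extra = divmod(n, k)  # floor size, and how many pieces get +1
    # The first `extra` pieces are one element longer than the rest.
    return [base + 1 if i < extra else base for i in range(k)]


print(near_equal_sizes(10, 3))  # [4, 3, 3]
print(near_equal_sizes(7, 4))   # [2, 2, 2, 1]
```

When n is divisible by k, extra is 0 and every piece has the same length, so the rule reduces to an ordinary equal split.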

Parameters

x : Tensor
Input tensor.
indices_or_sections : int | Sequence[int]
  • int k: number of (near-)equal-sized pieces. The first x.shape[dim] % k pieces receive one extra element along dim.
  • Sequence[int]: indices along dim at which to cut; m indices produce m + 1 pieces.
dim : int, optional
Axis to split along. Defaults to 0.
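For the Sequence[int] form, the cut indices can be read as slice boundaries. The sketch below assumes NumPy-style cuts, which is consistent with the [2, 5] example on this page: indices [i1, ..., im] give the pieces [:i1], [i1:i2], ..., [im:] along dim. slices_from_indices is a hypothetical helper, and it does not handle negative or out-of-range indices.

```python
from collections.abc import Sequence


def slices_from_indices(n: int, indices: Sequence[int]) -> list[slice]:
    """Turn cut indices along an axis of length n into per-piece slices.

    Hypothetical helper; assumes NumPy-style cuts where m indices
    yield m + 1 pieces: [:i1], [i1:i2], ..., [im:].
    """
    bounds = [0, *indices, n]  # prepend the start and append the axis end
    return [slice(start, stop) for start, stop in zip(bounds[:-1], bounds[1:])]


pieces = slices_from_indices(10, [2, 5])
print([s.stop - s.start for s in pieces])  # [2, 3, 5]
```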

Returns

list[Tensor]

Sub-tensors whose concatenation along dim reproduces x.
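The concatenation property can be illustrated with a Python list standing in for a 1-D tensor. ragged_split is a hypothetical helper that applies the integer rule to consecutive chunks; it is an illustration of the invariant, not lucid's implementation.

```python
def ragged_split(items: list, k: int) -> list[list]:
    """Split a list into k near-equal consecutive chunks (first ones larger).

    A list stands in for a 1-D tensor; this is not lucid code.
    """
    base, extra = divmod(len(items), k)
    chunks, start = [], 0
    for i in range(k):
        # Chunks with index < extra take one extra element.
        stop = start + base + (1 if i < extra else 0)
        chunks.append(items[start:stop])
        start = stop
    return chunks


data = list(range(10))
chunks = ragged_split(data, 3)
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Concatenating the chunks in order reproduces the input exactly.
assert [v for chunk in chunks for v in chunk] == data
```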

Notes

For integer k, let $n = \text{shape}(x)[\text{dim}]$, $b = \lfloor n / k \rfloor$ and $r = n \bmod k$. The piece sizes are

$$
\underbrace{b + 1, \dots, b + 1}_{r\text{ terms}},\; \underbrace{b, \dots, b}_{k - r\text{ terms}}.
$$
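These sizes always add back up to $n$, since $n = kb + r$ by the definition of floor division:

$$
r(b + 1) + (k - r)\,b = kb + r = n.
$$

For example, with $n = 10$ and $k = 4$: $b = \lfloor 10 / 4 \rfloor = 2$ and $r = 10 \bmod 4 = 2$, giving piece sizes $3, 3, 2, 2$, which sum to $10$.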

Contrast with lucid.split, which raises ValueError when n is not a multiple of k.
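The difference can be sketched with two hypothetical size helpers: strict_sizes models the split-style check that rejects remainders, and permissive_sizes models the tensor_split rule. Neither is lucid code, and the error message below is illustrative only; lucid.split's actual message is not specified here.

```python
def strict_sizes(n: int, k: int) -> list[int]:
    """Equal-split sizes that refuse non-divisible lengths (split-style)."""
    if n % k != 0:
        raise ValueError(f"axis length {n} is not divisible by {k}")
    return [n // k] * k


def permissive_sizes(n: int, k: int) -> list[int]:
    """Near-equal sizes that absorb the remainder (tensor_split-style)."""
    base, extra = divmod(n, k)
    return [base + 1] * extra + [base] * (k - extra)


print(permissive_sizes(10, 3))  # [4, 3, 3]
try:
    strict_sizes(10, 3)
except ValueError as err:
    print("split-style:", err)  # split-style: axis length 10 is not divisible by 3
```

On divisible lengths the two agree, so the permissive rule only changes behavior in the cases where the strict one would raise.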

Examples

>>> import lucid
>>> x = lucid.arange(10)
>>> [s.shape for s in lucid.tensor_split(x, 3)]
[(4,), (3,), (3,)]
>>> [s.shape for s in lucid.tensor_split(x, [2, 5])]
[(2,), (3,), (5,)]