
dsplit

dsplit(x: Tensor, indices_or_sections: int | Sequence[int]) -> list[Tensor]

Split a tensor along its third (depth) axis.

NumPy-style depth-wise split: cuts the input into pieces along axis 2. Mirrors vsplit (axis 0) and hsplit (axis 1).

Parameters

x : Tensor
Input tensor with at least 3 dimensions.
indices_or_sections : int | Sequence[int]
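  • int: number of sections to split axis 2 into (near-equal in size; see Notes).
  • Sequence[int]: indices along axis 2 at which to cut. For example, [2, 3] yields x[:, :, :2], x[:, :, 2:3], and x[:, :, 3:].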
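To make the Sequence[int] form concrete, here is a small pure-Python sketch that lists which axis-2 positions end up in each piece. It assumes lucid follows numpy.split semantics (piece i is the slice between consecutive cut points, and indices at or past the end of the axis produce empty pieces); axis2_positions is a hypothetical helper, not part of lucid.

```python
from typing import Sequence


def axis2_positions(depth: int, indices: Sequence[int]) -> list[list[int]]:
    """Return the axis-2 positions that land in each piece.

    Hypothetical helper (not a lucid API). Assumes numpy.split-style
    semantics: piece i covers x[:, :, bounds[i]:bounds[i + 1]], where
    bounds = [0, *indices, depth].
    """
    positions = list(range(depth))
    bounds = [0, *indices, depth]
    # Ordinary Python slicing reproduces the cut behaviour, including
    # empty pieces when an index is at or past the end of the axis.
    return [positions[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]


print(axis2_positions(4, [1, 3]))   # [[0], [1, 2], [3]]
print(axis2_positions(4, [3, 10]))  # [[0, 1, 2], [3], []]
```

With a depth of 4, the cuts [1, 3] give three pieces of depth 1, 2, and 1.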

Returns

list[Tensor]

Sub-tensors whose concatenation along axis 2 reproduces x.

Raises

ValueError
If x.ndim < 3.

Notes

For a tensor of shape $(d_0, d_1, D, \dots)$ split into $k$ near-equal pieces along axis 2 (the depth axis), each piece has shape $(d_0, d_1, \lceil D/k \rceil, \dots)$ or $(d_0, d_1, \lfloor D/k \rfloor, \dots)$.
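The piece sizes described above can be checked with plain integer arithmetic. This sketch assumes the numpy.array_split convention that the first D mod k pieces receive the extra element; lucid's ordering of the larger pieces may differ, and near_equal_sizes is a hypothetical helper.

```python
import math


def near_equal_sizes(depth: int, sections: int) -> list[int]:
    """Sizes along axis 2 when splitting `depth` elements into `sections` pieces.

    Hypothetical helper (not a lucid API). Assumes the numpy.array_split
    ordering: the first depth % sections pieces get one extra element.
    """
    if sections <= 0:
        raise ValueError("sections must be a positive integer")
    base, extra = divmod(depth, sections)
    return [base + 1] * extra + [base] * (sections - extra)


sizes = near_equal_sizes(10, 3)
print(sizes)  # [4, 3, 3]
# Every size is ceil(D/k) or floor(D/k), and the sizes add back up to D.
assert set(sizes) <= {math.ceil(10 / 3), math.floor(10 / 3)}
assert sum(sizes) == 10
```

With D = 4 and k = 2 this gives [2, 2], consistent with the (2, 3, 2) pieces in the example that follows.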

Examples

>>> import lucid
>>> x = lucid.arange(24).reshape(2, 3, 4)
>>> [s.shape for s in lucid.dsplit(x, 2)]
[(2, 3, 2), (2, 3, 2)]