
unflatten

unflatten(input: Tensor, dim: int, sizes: Sequence[int]) -> Tensor

Inverse of flatten: splits a single dimension into several dimensions.

Replaces input's dim-th axis with the dimensions given by sizes, whose product must equal input.size(dim). At most one entry of sizes may be -1, in which case its value is inferred from input.size(dim) and the remaining entries.
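The shape rule above can be sketched in plain Python. The helper name `unflattened_shape` is hypothetical and not part of lucid; it only mirrors the rule stated here (the product of sizes must equal the size of dim, with at most one -1). It also assumes that a negative dim counts from the last axis, which the text above does not confirm.

```python
from math import prod
from typing import Sequence, Tuple


def unflattened_shape(shape: Sequence[int], dim: int, sizes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: shape obtained by splitting axis `dim` into `sizes`."""
    shape = tuple(shape)
    sizes = list(sizes)
    if not sizes:
        raise ValueError("sizes must not be empty")
    if not -len(shape) <= dim < len(shape):
        raise IndexError("dim out of range")
    if dim < 0:
        dim += len(shape)  # assumption: negative dims count from the end
    if any(s < -1 for s in sizes):
        raise ValueError("sizes entries must be non-negative or -1")

    target = shape[dim]
    unknown = [i for i, s in enumerate(sizes) if s == -1]
    if len(unknown) > 1:
        raise ValueError("at most one entry of sizes may be -1")
    if unknown:
        # Infer the -1 entry from the product of the known entries.
        known = prod(s for s in sizes if s != -1)
        if known == 0 or target % known != 0:
            raise ValueError("cannot infer the -1 entry")
        sizes[unknown[0]] = target // known
    if prod(sizes) != target:
        raise ValueError("product of sizes must equal the size of dim")

    # Replace the single axis with the new run of axes.
    return shape[:dim] + tuple(sizes) + shape[dim + 1:]


print(unflattened_shape((2, 12), 1, (3, 4)))   # (2, 3, 4)
print(unflattened_shape((2, 12), 1, (-1, 4)))  # (2, 3, 4)
```

The second call shows the inferred entry: with 12 elements along dim and a known size of 4, the -1 resolves to 3.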

Parameters

input (Tensor)
Source tensor.
dim (int)
Dimension to unflatten.
sizes (sequence of int)
Target sizes for the new dimensions.

Returns

Tensor

Tensor in which dimension dim is replaced by len(sizes) dimensions, so the result has len(sizes) - 1 more dimensions than input.

Notes

Returns a view when input is contiguous along dim; otherwise a copy is produced via reshape.

Examples

>>> import lucid
>>> x = lucid.zeros(2, 12)
>>> lucid.unflatten(x, dim=1, sizes=(3, 4)).shape
(2, 3, 4)
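To see where individual elements end up, the plain-Python sketch below splits a position along the original dimension into indices over the new dimensions. It assumes row-major (C-order) semantics, as in a reshape. That is consistent with unflatten being the inverse of flatten, but it is an assumption rather than something this page states. `split_index` is a hypothetical helper, not a lucid function.

```python
from typing import Sequence, Tuple


def split_index(flat: int, sizes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: map a position along the original dim to
    indices over `sizes`, assuming row-major ordering."""
    idx = []
    # Peel off the fastest-varying (last) dimension first.
    for size in reversed(sizes):
        flat, remainder = divmod(flat, size)
        idx.append(remainder)
    return tuple(reversed(idx))


# With sizes=(3, 4), position 7 along the original dim of length 12
# corresponds to row 1, column 3 under the row-major assumption.
print(split_index(7, (3, 4)))  # (1, 3)
```

Under this assumption, x[i, 7] in the example above would correspond to element [i, 1, 3] of the unflattened result.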