class

Unflatten

extends Module
Unflatten(dim: int, unflattened_size: tuple[int, ...])

Expand one dimension of a tensor into multiple dimensions.

Unflatten is the inverse of Flatten: it takes a single dimension of the input and splits it into the shape given by unflattened_size. The product of unflattened_size must equal the size of the target dimension.

$$\text{If } x.\text{shape}[\texttt{dim}] = \prod_i s_i \quad \text{then} \quad \text{output}.\text{shape}[\texttt{dim}:\texttt{dim}+k] = (s_0, s_1, \dots, s_{k-1})$$

where $k = \text{len(unflattened\_size)}$.
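The shape rule above can be checked without any tensor library. The following is a minimal sketch in plain Python; the helper name `unflattened_shape` is hypothetical and is not part of lucid, and the error types are an assumption made only for illustration.

```python
from math import prod


def unflattened_shape(shape, dim, unflattened_size):
    """Return the shape obtained by expanding shape[dim] into unflattened_size.

    Hypothetical helper for illustration only; not part of the lucid API.
    """
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a {ndim}-D shape")
    dim %= ndim  # normalise a negative index to its positive equivalent

    # The product of the target sizes must equal the size being expanded.
    if prod(unflattened_size) != shape[dim]:
        raise ValueError(
            f"product of {tuple(unflattened_size)} does not equal "
            f"shape[{dim}] = {shape[dim]}"
        )

    # Keep everything before and after dim; splice the new sizes in between.
    return tuple(shape[:dim]) + tuple(unflattened_size) + tuple(shape[dim + 1:])


print(unflattened_shape((8, 1024), 1, (64, 4, 4)))   # (8, 64, 4, 4)
print(unflattened_shape((8, 1024), -1, (64, 4, 4)))  # (8, 64, 4, 4)
```

The output has `k - 1` more dimensions than the input, since one dimension is replaced by `k`.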

Parameters

dim: int
The dimension to expand. Negative indices are supported.
unflattened_size: tuple[int, ...]
The target shape for the expanded dimension. The product of all elements must equal x.shape[dim].

Attributes

dim: int
Stored value of the dim constructor argument.
unflattened_size: tuple[int, ...]
Stored target shape for the expanded dimension.

Notes

  • Input: $(\dots, d, \dots)$ where $d$ is at position dim.
  • Output: $(\dots, s_0, s_1, \dots, s_{k-1}, \dots)$ where $s_0 \cdot s_1 \cdots s_{k-1} = d$.
  • Internally implemented via a reshape; no data is copied.
  • A common pattern is Flatten in the encoder and Unflatten in the decoder to reconstruct spatial structure from a flat bottleneck.
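The reshape behaviour described in these notes can be illustrated with NumPy as a stand-in. This is a sketch under that assumption: it shows how a reshape of a contiguous array returns a view rather than a copy, and how flattening and unflattening undo each other. It does not exercise lucid's own implementation.

```python
import numpy as np

# A contiguous batch of flat feature vectors: 64 * 4 * 4 = 1024.
x = np.zeros((8, 1024), dtype=np.float32)

# "Unflatten" dimension 1 into (64, 4, 4) with a reshape.
y = x.reshape(8, 64, 4, 4)

# For a contiguous array the reshape is a view onto the same buffer.
shares = np.shares_memory(x, y)

# Writing through the view is therefore visible in the original array.
y[0, 0, 0, 0] = 1.0
first = float(x[0, 0])

# "Flatten" from dimension 1 onwards restores the original shape.
flat = y.reshape(8, -1)

print(shares, first, flat.shape)
```

Because both arrays share one buffer, the round trip through the flat bottleneck shape costs no extra memory in this NumPy setting.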

Examples

**Reconstruct (C, H, W) spatial structure from a flat feature vector:**
>>> import lucid
>>> import lucid.nn as nn
>>>
>>> unflat = nn.Unflatten(dim=1, unflattened_size=(64, 4, 4))
>>> x = lucid.zeros(8, 1024)      # 64*4*4 = 1024
>>> unflat(x).shape
(8, 64, 4, 4)
**Paired Flatten / Unflatten round-trip (encoder–decoder bottleneck):**
>>> encoder = nn.Sequential(
...     nn.Conv2d(1, 16, 3, padding=1),
...     nn.ReLU(),
...     nn.Flatten(start_dim=1),   # (N, 16, H, W) -> (N, 16*H*W)
...     nn.Linear(16 * 28 * 28, 128),
... )
>>> decoder = nn.Sequential(
...     nn.Linear(128, 16 * 28 * 28),
...     nn.Unflatten(1, (16, 28, 28)),
...     nn.Conv2d(16, 1, 3, padding=1),
... )

Methods (3)

dunder

__init__

None
__init__(dim: int, unflattened_size: tuple[int, ...])

Initialise the Unflatten module. See the class docstring for parameter semantics.

fn

forward

Tensor
forward(x: Tensor)

Split the configured dimension of the input into the shape given by unflattened_size.

Parameters

x: Tensor
Input tensor. Its size along dim must equal the product of unflattened_size.

Returns

Tensor

Tensor with dimension dim expanded into the shape unflattened_size. All other dimensions are unchanged.

fn

extra_repr

str
extra_repr()

Return a string representation of the layer's configuration.