class ChannelShuffle(groups: int)

extends Module

Shuffle channels across groups to enable cross-group information flow.

ChannelShuffle implements the channel shuffle operation introduced in ShuffleNet (Zhang et al., 2018). When group convolutions are stacked, each group's output depends only on the input channels of that same group. Channel shuffling breaks this isolation by interleaving channels from different groups before the next grouped layer.
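To illustrate the cross-group flow, the sketch below tags each of $C = 12$ channels with the index of the group it belongs to ($g = 3$), then applies the operation's reshape, transpose and flatten steps to a plain Python list. `shuffle_channels` is a hypothetical helper written for illustration and is not part of the lucid API. After the shuffle, every contiguous block of $C/g$ channels (the input of one group in the next grouped layer) contains channels from all three original groups. This holds whenever $C/g \ge g$.

```python
def shuffle_channels(channels: list, groups: int) -> list:
    """Hypothetical pure-Python channel shuffle over a flat list of channels."""
    if len(channels) % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    per_group = len(channels) // groups
    # Reshape: split the channel list into `groups` rows of `per_group` channels.
    rows = [channels[k * per_group:(k + 1) * per_group] for k in range(groups)]
    # Transpose: zip(*rows) yields per_group columns, each with one channel per row.
    # Flatten: concatenate those columns back into a single list of channels.
    return [ch for column in zip(*rows) for ch in column]


C, g = 12, 3
per_group = C // g
source_group = [c // per_group for c in range(C)]  # group index of each input channel
shuffled = shuffle_channels(source_group, g)
print(shuffled)  # [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]

# Which original groups feed each group of the next grouped layer?
for k in range(g):
    block = shuffled[k * per_group:(k + 1) * per_group]
    print(k, sorted(set(block)))  # every block contains groups [0, 1, 2]
```

Without the shuffle, block $k$ would contain only channels from group $k$.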

The operation reshapes the channel dimension into $(\text{groups}, C / \text{groups})$, transposes those two axes, and then flattens back to $C$ channels:

$$\text{reshape:} \quad (N,\, C,\, H,\, W) \;\to\; (N,\, g,\, C/g,\, H,\, W)$$

$$\text{transpose axes 1 and 2:} \quad (N,\, g,\, C/g,\, H,\, W) \;\to\; (N,\, C/g,\, g,\, H,\, W)$$

$$\text{flatten back:} \quad (N,\, C/g,\, g,\, H,\, W) \;\to\; (N,\, C,\, H,\, W)$$
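Read element-wise, these three steps define a fixed permutation of channel indices. Output channel $i$ sits at slot $s = \lfloor i / g \rfloor$ and group $k = i \bmod g$ of the transposed $(C/g, g)$ grid, so it is copied from input channel $k \cdot (C/g) + s$. A minimal sketch of that index map follows; `shuffle_index_map` is a hypothetical helper, not a lucid function.

```python
def shuffle_index_map(channels: int, groups: int) -> list[int]:
    """Hypothetical helper: src[i] is the input channel copied to output channel i."""
    if channels % groups != 0:
        raise ValueError("channels must be divisible by groups")
    per_group = channels // groups
    src = []
    for i in range(channels):
        slot, group = divmod(i, groups)       # position in the transposed (C/g, g) grid
        src.append(group * per_group + slot)  # same element in the original (g, C/g) grid
    return src


print(shuffle_index_map(8, 4))  # [0, 2, 4, 6, 1, 3, 5, 7]
print(shuffle_index_map(6, 2))  # [0, 3, 1, 4, 2, 5]
```

For $C = 8$, $g = 4$ the input groups are $\{0,1\}, \{2,3\}, \{4,5\}, \{6,7\}$, and the output takes the first channel of every group, then the second.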

Parameters

groups : int
Number of channel groups $g$. The input channel count $C$ must be divisible by groups.

Attributes

groups : int
Stored value of the groups constructor argument.

Notes

  • Input: $(N, C, H, W)$, where $C$ is divisible by groups.
  • Output: $(N, C, H, W)$, the same shape as the input.
  • When groups == 1, the output equals the input (the shuffle is the identity).
  • The operation has no learnable parameters and is implemented as a reshape, a transpose, and a second reshape executed by the C++ engine.
  • In a ShuffleNet unit, channel shuffle follows the first pointwise group convolution, so the next pointwise group convolution receives channels from every group without the cost of a dense 1×1 convolution.
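The identity case follows from the reshape/transpose definition: with groups = 1 the channel grid is $(1, C)$, and transposing it to $(C, 1)$ leaves the flattened order unchanged. The same argument shows that groups = $C$ (one channel per group) is also an identity. A quick check on channel indices, using a hypothetical pure-Python helper rather than the lucid module:

```python
def shuffle_index_map(channels: int, groups: int) -> list[int]:
    # Hypothetical helper: output channel i is copied from input channel src[i].
    if channels % groups != 0:
        raise ValueError("channels must be divisible by groups")
    per_group = channels // groups
    return [(i % groups) * per_group + i // groups for i in range(channels)]


C = 12
identity = list(range(C))
print(shuffle_index_map(C, 1) == identity)  # True: groups == 1
print(shuffle_index_map(C, C) == identity)  # True: groups == C
print(shuffle_index_map(C, 3) == identity)  # False: a genuine reordering
```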

Examples

**ShuffleNet-style block with group convolutions:**
>>> import lucid
>>> import lucid.nn as nn
>>>
>>> g = 4
>>> block = nn.Sequential(
...     nn.Conv2d(64, 64, kernel_size=1, groups=g),   # group-wise PW
...     nn.BatchNorm2d(64),
...     nn.ReLU(),
...     nn.ChannelShuffle(groups=g),                  # shuffle across groups
...     nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=64),  # DW
...     nn.BatchNorm2d(64),
...     nn.Conv2d(64, 64, kernel_size=1, groups=g),   # group-wise PW
...     nn.BatchNorm2d(64),
... )
>>> x = lucid.zeros(2, 64, 28, 28)
>>> block(x).shape
(2, 64, 28, 28)

**Shuffle preserves the input shape (it only reorders channels):**
>>> import lucid
>>> import lucid.nn as nn
>>>
>>> x  = lucid.randn(1, 8, 4, 4)
>>> cs = nn.ChannelShuffle(groups=4)
>>> y  = cs(x)
>>> y.shape == x.shape
True
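The shape check above does not by itself show that no channel is lost or duplicated. Under the reshape/transpose/flatten definition, the shuffle is a bijection on channels, and a shuffle with groups = $C/g$ undoes a shuffle with groups = $g$, because the second transpose swaps the $(C/g, g)$ grid back to $(g, C/g)$. The sketch below checks both properties on channel indices in plain Python; it does not call lucid, and `shuffle` is a hypothetical helper.

```python
def shuffle(channels: list, groups: int) -> list:
    # Hypothetical helper: reshape -> transpose -> flatten on a flat list of channels.
    if len(channels) % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    per_group = len(channels) // groups
    rows = [channels[k * per_group:(k + 1) * per_group] for k in range(groups)]
    return [ch for column in zip(*rows) for ch in column]


C = 24
x = list(range(C))  # channel i holds the value i
for g in (1, 2, 3, 4, 6, 8, 12, 24):
    y = shuffle(x, g)
    assert sorted(y) == x           # every channel appears exactly once
    assert shuffle(y, C // g) == x  # groups = C // g inverts groups = g
print("all checks passed")
```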

Methods (3)

__init__(groups: int) -> None

Initialise the ChannelShuffle module. See the class docstring for parameter semantics.

forward(x: Tensor) -> Tensor

Apply the channel shuffle to the input tensor.

Parameters

x : Tensor
Input tensor of shape $(N, C, H, W)$, where $C$ is divisible by groups.

Returns

Tensor

Channel-shuffled tensor with the same shape as x.
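For testing, a reference forward pass can be written without any tensor library by permuting the channel axis of a nested list. This is a sketch that assumes `forward` follows the reshape/transpose/flatten definition in the class docstring; `reference_channel_shuffle` is hypothetical, and comparing it against `ChannelShuffle` would additionally require converting a lucid tensor to nested lists, which is not shown here.

```python
import random


def reference_channel_shuffle(x: list, groups: int) -> list:
    """Hypothetical reference: x is a nested list of shape (N, C, H, W)."""
    out = []
    for sample in x:  # each sample has shape (C, H, W)
        channels = len(sample)
        if channels % groups != 0:
            raise ValueError("C must be divisible by groups")
        per_group = channels // groups
        # Output channel i comes from input channel (i % groups) * per_group + i // groups.
        out.append([sample[(i % groups) * per_group + i // groups] for i in range(channels)])
    return out


N, C, H, W = 2, 8, 3, 3
x = [[[[random.random() for _ in range(W)] for _ in range(H)]
      for _ in range(C)] for _ in range(N)]
y = reference_channel_shuffle(x, groups=4)

# The shape is unchanged, and each output channel is an input channel, unmodified.
assert len(y) == N and all(len(sample) == C for sample in y)
assert y[0][1] is x[0][2]  # with C = 8, groups = 4: output channel 1 <- input channel 2
```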

extra_repr() -> str

Return a string representation of the layer's configuration.