class ChannelShuffle extends Module
ChannelShuffle(groups: int)
Shuffle channels across groups to enable cross-group information flow.
ChannelShuffle implements the channel shuffle operation introduced in ShuffleNet (Zhang et al., 2018). When group convolutions are stacked, each output group depends only on the input channels of that same group. Channel shuffling breaks this isolation by interleaving the outputs of different groups before the next layer.
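To make the isolation argument concrete, the sketch below models each 1×1 group convolution as a boolean channel-connectivity matrix and the shuffle as a permutation matrix. It uses NumPy rather than lucid; the helper names `group_conv_mask` and `shuffle_matrix` and the sizes $C = 8$, $g = 2$ are illustrative assumptions.

```python
import numpy as np

def group_conv_mask(channels: int, groups: int) -> np.ndarray:
    """Connectivity of a 1x1 group convolution (hypothetical helper):
    mask[o, i] is True when output channel o can read input channel i."""
    group_of = np.arange(channels) // (channels // groups)
    return group_of[:, None] == group_of[None, :]

def shuffle_matrix(channels: int, groups: int) -> np.ndarray:
    """Permutation matrix S with (S @ x)[k] == x[perm[k]] (hypothetical helper)."""
    perm = np.arange(channels).reshape(groups, channels // groups).T.reshape(-1)
    s = np.zeros((channels, channels), dtype=bool)
    s[np.arange(channels), perm] = True
    return s

C, g = 8, 2
m = group_conv_mask(C, g).astype(int)
s = shuffle_matrix(C, g).astype(int)

# Composite connectivity, read right to left: first conv, (shuffle), second conv.
no_shuffle = (m @ m) > 0        # two stacked group convs
with_shuffle = (m @ s @ m) > 0  # group conv -> shuffle -> group conv

print(no_shuffle.all())    # False: each output still sees only its own group
print(with_shuffle.all())  # True: every output sees every input channel
```

With these sizes every post-shuffle group contains channels from both original groups; in general a single shuffle reaches every group only when $C/g \ge g$.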
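The operation reshapes the channel dimension $C$ into $(g, C/g)$, transposes those two axes, and then flattens back to $C$ channels:

$$(N, C, H, W) \rightarrow (N, g, C/g, H, W) \rightarrow (N, C/g, g, H, W) \rightarrow (N, C, H, W)$$

Equivalently, writing an input channel as $c = i \cdot (C/g) + j$ with group index $0 \le i < g$ and within-group index $0 \le j < C/g$, the output satisfies $y_{:,\, j g + i} = x_{:,\, i (C/g) + j}$.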
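A minimal NumPy sketch of the same reshape, transpose, reshape sequence follows. `channel_shuffle` is a hypothetical helper written for illustration; it is not lucid's C++ implementation, though it follows the mapping above for 4-D inputs.

```python
import numpy as np

def channel_shuffle(x: np.ndarray, groups: int) -> np.ndarray:
    """Hypothetical NumPy reference for channel shuffle on an (N, C, H, W) array."""
    n, c, h, w = x.shape
    if c % groups != 0:
        raise ValueError(f"channels ({c}) must be divisible by groups ({groups})")
    # (N, C, H, W) -> (N, g, C/g, H, W): split the channels into g groups
    x = x.reshape(n, groups, c // groups, h, w)
    # (N, g, C/g, H, W) -> (N, C/g, g, H, W): swap group and within-group axes
    x = x.transpose(0, 2, 1, 3, 4)
    # Flatten back to (N, C, H, W); channels from different groups now interleave
    return x.reshape(n, c, h, w)

# Label each channel with its own index so the new order is visible.
x = np.arange(8).reshape(1, 8, 1, 1)
y = channel_shuffle(x, groups=2)
print(y[0, :, 0, 0])  # [0 4 1 5 2 6 3 7]

# Check the closed-form mapping y[:, j*g + i] == x[:, i*(C/g) + j].
g, C = 2, 8
for i in range(g):
    for j in range(C // g):
        assert y[0, j * g + i, 0, 0] == x[0, i * (C // g) + j, 0, 0]
```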
Parameters
groups (int): Number of channel groups $g$. The channel count $C$ must be divisible by `groups`.
Attributes
groups (int): Stored value of the `groups` constructor argument.
Notes
- Input: $(N, C, H, W)$, where $C$ is divisible by `groups`.
- Output: $(N, C, H, W)$, the same shape as the input.
- When `groups == 1` (or `groups == C`), the output is identical to the input.
- The operation has no learnable parameters and is implemented on the C++ engine as a reshape, a transpose, and a second reshape.
- In ShuffleNet, channel shuffle is placed between two consecutive pointwise group convolutions, letting information flow across groups without the cost of a full dense convolution.
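The properties listed in the notes can be spot-checked with a small NumPy sketch. It again uses a hypothetical `channel_shuffle` helper rather than the lucid module, and checks that `groups == 1` and `groups == C` are identities, that every valid `groups` yields a pure channel permutation, and that shuffling with `C // groups` undoes a shuffle with `groups`.

```python
import numpy as np

def channel_shuffle(x: np.ndarray, groups: int) -> np.ndarray:
    # Hypothetical NumPy reference: reshape -> transpose -> reshape on (N, C, H, W).
    n, c, h, w = x.shape
    x = x.reshape(n, groups, c // groups, h, w)
    return x.transpose(0, 2, 1, 3, 4).reshape(n, c, h, w)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 12, 5, 5))
C = x.shape[1]

# groups == 1 and groups == C leave the channel order unchanged.
assert np.array_equal(channel_shuffle(x, 1), x)
assert np.array_equal(channel_shuffle(x, C), x)

for g in (2, 3, 4, 6):
    y = channel_shuffle(x, g)
    assert y.shape == x.shape
    # Output channel k is input channel perm[k]: a pure permutation.
    perm = np.arange(C).reshape(g, C // g).T.reshape(-1)
    assert np.array_equal(y, x[:, perm])
    # Nothing is lost: the inverse permutation recovers the input exactly.
    assert np.array_equal(y[:, np.argsort(perm)], x)
    # Shuffling again with C // g groups is also an inverse.
    assert np.array_equal(channel_shuffle(y, C // g), x)

print("all channel-shuffle checks passed")
```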
Examples
**ShuffleNet-style block with group convolutions:**
>>> import lucid
>>> import lucid.nn as nn
>>>
>>> g = 4
>>> block = nn.Sequential(
... nn.Conv2d(64, 64, kernel_size=1, groups=g), # group-wise PW
... nn.BatchNorm2d(64),
... nn.ReLU(),
... nn.ChannelShuffle(groups=g), # shuffle across groups
... nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=64), # DW
... nn.BatchNorm2d(64),
... nn.Conv2d(64, 64, kernel_size=1, groups=g), # group-wise PW
... nn.BatchNorm2d(64),
... )
>>> x = lucid.zeros(2, 64, 28, 28)
>>> block(x).shape
(2, 64, 28, 28)
**Shuffling preserves the input shape (it only permutes channels, so no information is lost):**
>>> import lucid
>>> import lucid.nn as nn
>>>
>>> x = lucid.randn(1, 8, 4, 4)
>>> cs = nn.ChannelShuffle(groups=4)
>>> y = cs(x)
>>> y.shape == x.shape
True

Methods (3)
dunder __init__ → None
__init__(groups: int)
Initialise the ChannelShuffle module. See the class docstring for parameter semantics.
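fn forward → Tensor
forward(x: Tensor)
Apply channel shuffle to the input tensor.
Parameters
x (Tensor): Input tensor of shape $(N, C, H, W)$, where $C$ is divisible by `groups`.
Returns
Tensor: The channel-shuffled tensor, with the same shape as the input.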
fn extra_repr → str
extra_repr()
Return a string representation of the layer's configuration.