nn.ConstrainedConv3d¶
- class lucid.nn.ConstrainedConv3d(in_channels: int, out_channels: int, kernel_size: int | tuple[int, ...], stride: int | tuple[int, ...] = 1, padding: Literal['same', 'valid'] | int | tuple[int, ...] = 0, dilation: int | tuple[int, ...] = 1, groups: int = 1, bias: bool = True, *, constraint: Literal['none', 'nonneg', 'sum_to_one', 'zero_mean', 'nonneg_sum1', 'unit_l2', 'max_l2', 'fixed_center'] = 'none', enforce: Literal['forward', 'post_step'] = 'forward', eps: float = 1e-12, max_l2: float | None = None, center_value: float = -1.0, neighbor_sum: float = 1.0)¶
The ConstrainedConv3d module applies a 3D convolution whose kernels are projected onto a constraint set. It operates on volumetric inputs (or spatio-temporal tensors) and is useful when 3D kernels should satisfy explicit priors such as normalization, zero mean, or bounded energy.
Class Signature¶
class lucid.nn.ConstrainedConv3d(
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
padding: _PaddingStr | int | tuple[int, ...] = 0,
dilation: int | tuple[int, ...] = 1,
groups: int = 1,
bias: bool = True,
*,
constraint: Literal[
"none", "nonneg", "sum_to_one", "zero_mean", "nonneg_sum1",
"unit_l2", "max_l2", "fixed_center"
] = "none",
enforce: Literal["forward", "post_step"] = "forward",
eps: float = 1e-12,
max_l2: float | None = None,
center_value: float = -1.0,
neighbor_sum: float = 1.0,
)
Parameters¶
in_channels (int): Number of channels in the input volume.
out_channels (int): Number of output channels.
kernel_size (int or tuple[int, …]): 3D kernel size.
stride (int or tuple[int, …], optional): Stride. Default is 1.
padding (_PaddingStr or int or tuple[int, …], optional): Padding. Supports "same" and "valid". Default is 0.
dilation (int or tuple[int, …], optional): Dilation. Default is 1.
groups (int, optional): Grouped convolution factor. Default is 1.
bias (bool, optional): If True, adds learnable bias. Default is True.
constraint (str, optional): Kernel constraint mode. Default is "none".
enforce (str, optional): When the constraint is applied. "forward" projects the kernel on-the-fly in every forward pass; "post_step" leaves the stored kernel untouched until project_() is called explicitly, typically after each optimizer step. Default is "forward".
eps (float, optional): Stability constant added to denominators to avoid division by zero. Default is 1e-12.
max_l2 (float | None, optional): L2-norm ceiling; required when constraint="max_l2".
center_value (float, optional): Fixed center coefficient for "fixed_center". Default is -1.0.
neighbor_sum (float, optional): Target sum of the non-center coefficients for "fixed_center". Default is 1.0.
Mathematical Formulation¶
For input \(x\), kernel \(W\), and bias \(b\):

\[y = \tilde{W} \ast x + b\]

where the constrained kernel is

\[\tilde{W} = \Pi_{\text{constraint}}(W)\]

for enforce="forward", and the projection

\[W \leftarrow \Pi_{\text{constraint}}(W)\]

is executed explicitly with project_() for enforce="post_step".

For input shape \((N, C_{in}, D_{in}, H_{in}, W_{in})\), output shape is \((N, C_{out}, D_{out}, H_{out}, W_{out})\) where

\[D_{out} = \left\lfloor \frac{D_{in} + 2\,\text{pad}_D - \text{dil}_D (K_D - 1) - 1}{\text{stride}_D} \right\rfloor + 1\]

and \(H_{out}\), \(W_{out}\) follow the same relation along their own axes.
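The output spatial sizes follow the standard convolution relation \(\lfloor (n + 2p - d(k-1) - 1)/s \rfloor + 1\) along each axis. A minimal sketch for checking expected shapes (the helper name conv_out_len is illustrative, not part of lucid):

```python
from math import floor

def conv_out_len(n_in: int, kernel: int, stride: int = 1,
                 padding: int = 0, dilation: int = 1) -> int:
    """Standard convolution output length along one axis."""
    return floor((n_in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1

# Example 1 below: D=16, H=W=32, kernel_size=3, padding=1, stride=1
shape = tuple(conv_out_len(n, 3, padding=1) for n in (16, 32, 32))
print(shape)  # -> (16, 32, 32)
```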
Constraint Modes¶
Each mode is applied to each \((K_D, K_H, K_W)\) kernel block:
none

\[\tilde{W} = W\]

nonneg

\[\tilde{W} = \max(W, 0)\]

sum_to_one

\[\tilde{W} = \frac{W}{\sum_{p,q,r} W_{p,q,r} + \varepsilon}\]

zero_mean

\[\tilde{W} = W - \frac{1}{K_D K_H K_W}\sum_{p,q,r} W_{p,q,r}\]

nonneg_sum1

\[\tilde{W} = \frac{\max(W, 0)}{\sum_{p,q,r} \max(W_{p,q,r},0) + \varepsilon}\]

unit_l2

\[\tilde{W} = \frac{W}{\sqrt{\sum_{p,q,r} W_{p,q,r}^2 + \varepsilon}}\]

max_l2

\[\tilde{W} = W \cdot \min\left(1, \frac{c}{\sqrt{\sum_{p,q,r} W_{p,q,r}^2 + \varepsilon}}\right)\]

fixed_center (requires odd kernel sizes)

Let the center index be \((p_c, q_c, r_c)\):

\[\tilde{W}_{p_c,q_c,r_c} = v_c, \quad \sum_{(p,q,r)\neq(p_c,q_c,r_c)} \tilde{W}_{p,q,r} = s_n\]
Examples¶
1) Constrained volumetric convolution
>>> import lucid
>>> import lucid.nn as nn
>>> x = lucid.random.randn(2, 4, 16, 32, 32)
>>> conv = nn.ConstrainedConv3d(
... 4, 8, kernel_size=3, padding=1,
... constraint="unit_l2",
... )
>>> y = conv(x)
>>> y.shape
(2, 8, 16, 32, 32)
2) Spatio-temporal residual modeling with zero-mean kernels
>>> import lucid
>>> import lucid.nn as nn
>>> x = lucid.random.randn(1, 3, 8, 64, 64)
>>> conv = nn.ConstrainedConv3d(3, 6, kernel_size=3, padding=1, constraint="zero_mean")
>>> y = conv(x)
>>> y.shape
(1, 6, 8, 64, 64)
3) Hard projected max-L2 constrained training step
>>> import lucid
>>> import lucid.nn as nn
>>> import lucid.optim as optim
>>> conv = nn.ConstrainedConv3d(
... 6, 6, kernel_size=3, padding=1,
... constraint="max_l2", max_l2=0.8,
... enforce="post_step",
... )
>>> opt = optim.SGD(conv.parameters(), lr=1e-2)
>>> x = lucid.random.randn(2, 6, 6, 24, 24)
>>> loss = conv(x).mean()
>>> loss.backward()
>>> opt.step()
>>> conv.project_()
>>> opt.zero_grad()
4) Fixed-center 3D kernels
>>> import lucid
>>> import lucid.nn as nn
>>> x = lucid.random.randn(1, 1, 10, 20, 20)
>>> conv = nn.ConstrainedConv3d(
... 1, 4, kernel_size=5, padding=2,
... constraint="fixed_center",
... center_value=-1.0,
... neighbor_sum=1.0,
... )
>>> y = conv(x)
>>> y.shape
(1, 4, 10, 20, 20)
5) Grouped constrained Conv3d
>>> import lucid
>>> import lucid.nn as nn
>>> x = lucid.random.randn(2, 8, 8, 16, 16)
>>> conv = nn.ConstrainedConv3d(
... 8, 8, kernel_size=3, padding=1,
... groups=4,
... constraint="nonneg",
... )
>>> y = conv(x)
>>> y.shape
(2, 8, 8, 16, 16)