class BatchNorm2d extends _BatchNormBase
BatchNorm2d(num_features: int, eps: float = 1e-05, momentum: float | None = 0.1, affine: bool = True, track_running_stats: bool = True, device: DeviceLike = None, dtype: DTypeLike = None)
Batch normalization over a 4-D input (N, C, H, W).
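Normalises each channel across the batch and spatial dimensions:
$$y = \frac{x - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} \cdot \gamma_c + \beta_c$$
where $\mu_c$ and $\sigma_c^2$ are computed over the $(N, H, W)$ axes for each channel $c$.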
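The per-channel normalisation can be illustrated with a small NumPy reference. This is a sketch assuming the standard batch-norm formula (biased variance, optional scale `gamma` and shift `beta`); the function name `batch_norm_2d` is hypothetical and this is not lucid's implementation.

```python
import numpy as np

def batch_norm_2d(x, gamma=None, beta=None, eps=1e-5):
    """Reference normalisation of an (N, C, H, W) array using batch statistics."""
    # One mean and one (biased) variance per channel, taken over N, H and W.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    # Optional per-channel scale and shift, broadcast over N, H and W.
    if gamma is not None:
        x_hat = x_hat * gamma.reshape(1, -1, 1, 1)
    if beta is not None:
        x_hat = x_hat + beta.reshape(1, -1, 1, 1)
    return x_hat

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4, 5, 5))
y = batch_norm_2d(x)
```

Without `gamma` and `beta`, each channel of `y` has mean close to 0 and variance close to 1, whatever the scale of the input.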
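During training, batch statistics are used and running statistics are updated via an exponential moving average:
$$\hat{s} \leftarrow (1 - \alpha)\,\hat{s} + \alpha\, s_{\text{batch}}$$
where $\hat{s}$ is a running statistic (running_mean or running_var), $s_{\text{batch}}$ is the corresponding batch statistic, and $\alpha$ is momentum.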
During evaluation (model.eval()), the stored running statistics
running_mean and running_var are used instead.
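The training-time update of the running statistics can be sketched in NumPy as follows. The sketch assumes the PyTorch-style convention, in which `momentum` weights the new batch statistic, and assumes that `momentum=None` uses `1 / num_batches_tracked` as the factor (a cumulative average). The helper `update_running` is hypothetical, and the exact rule lucid applies should be checked against its source.

```python
import numpy as np

def update_running(running, batch_stat, momentum, num_batches_tracked):
    """One update step; num_batches_tracked includes the current batch."""
    # Assumption: momentum=None falls back to an equal-weight cumulative average.
    factor = 1.0 / num_batches_tracked if momentum is None else momentum
    return (1.0 - factor) * running + factor * batch_stat

rng = np.random.default_rng(0)
running_mean = np.zeros(3)      # exponential moving average, momentum=0.1
cumulative_mean = np.zeros(3)   # cumulative average, momentum=None
batch_means = []
for step in range(1, 6):
    x = rng.normal(loc=2.0, size=(16, 3, 4, 4))
    batch_mean = x.mean(axis=(0, 2, 3))  # per-channel batch mean
    batch_means.append(batch_mean)
    running_mean = update_running(running_mean, batch_mean, 0.1, step)
    cumulative_mean = update_running(cumulative_mean, batch_mean, None, step)
```

Under these assumptions the cumulative variant equals the plain average of all batch means seen so far, while the exponential variant with `momentum=0.1` moves only part of the way from its zero initialisation after five batches.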
Parameters
num_features : int
    Number of channels $C$.
eps : float, default 1e-05
    Small constant added to the variance for numerical stability.
momentum : float or None, default 0.1
    Exponential moving average factor for running statistics. If None, uses cumulative moving average.
affine : bool, default True
    If True, learns per-channel scale $\gamma$ and shift $\beta$.
track_running_stats : bool, default True
    If True, maintains running_mean, running_var, and num_batches_tracked.
device : DeviceLike, default None
    Device for parameters and buffers.
dtype : DTypeLike, default None
    Data type for parameters and buffers.
Attributes
weight : Parameter or None
    Learnable scale of shape (num_features,). None when affine=False.
bias : Parameter or None
    Learnable shift of shape (num_features,). None when affine=False.
running_mean : Tensor or None
    Running estimate of the per-channel mean, shape (num_features,). None when track_running_stats=False.
running_var : Tensor or None
    Running estimate of the per-channel variance, shape (num_features,). None when track_running_stats=False.
num_batches_tracked : Tensor or None
    Scalar counting the number of batches seen during training. None when track_running_stats=False.
Notes
- Input: (N, C, H, W)
- Output: (N, C, H, W), same shape.
- BatchNorm2d is the most commonly used normalization layer in convolutional neural networks. It stabilises training by keeping activations in a well-scaled range after each convolutional block.
- At small batch sizes the batch statistics become noisy. Consider GroupNorm or InstanceNorm2d in those settings.
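The noise at small batch sizes can be made concrete with a short NumPy simulation. This is an illustrative sketch only: it assumes independent unit-variance activations for a single channel, which real feature maps do not satisfy (spatially correlated activations make the estimate noisier still). The helper `batch_mean_spread` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def batch_mean_spread(batch_size, trials=2000, spatial=(4, 4)):
    """Standard deviation of the batch-mean estimate for one channel."""
    # Each trial draws one batch of activations with true mean 0 and variance 1.
    x = rng.normal(size=(trials, batch_size, *spatial))
    # One mean estimate per trial, taken over the batch and spatial axes.
    return x.mean(axis=(1, 2, 3)).std()

spread_small = batch_mean_spread(2)    # batch of 2 samples
spread_large = batch_mean_spread(64)   # batch of 64 samples
```

Under this independence assumption the spread of the estimated mean shrinks roughly with the square root of the number of values averaged, so the batch of 2 gives a several times noisier estimate than the batch of 64.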
Examples
>>> import lucid
>>> import lucid.nn as nn
>>> bn = nn.BatchNorm2d(64)
>>> x = lucid.randn(8, 64, 32, 32)
>>> out = bn(x) # normalised per channel
>>> out.shape
(8, 64, 32, 32)
>>> # Eval mode uses running statistics (no batch dependence)
>>> bn.eval()
>>> with lucid.no_grad():
... out = bn(x)