batch_norm

batch_norm(x: Tensor, running_mean: Tensor | None, running_var: Tensor | None, weight: Tensor | None = None, bias: Tensor | None = None, training: bool = False, momentum: float = 0.1, eps: float = 1e-05) -> Tensor

Batch normalization (Ioffe & Szegedy, 2015).

Normalises each channel using statistics computed across the batch and all spatial axes, then applies a per-channel affine transform (γ, β), which is typically learnable. The method was introduced to reduce "internal covariate shift". In practice it stabilises training, permits higher learning rates, and can have a regularising effect.

Parameters

x : Tensor
Input of shape (N, C, *) with 2–5 dimensions: (N, C), (N, C, L), (N, C, H, W), or (N, C, D, H, W).
running_mean : Tensor or None
Running mean buffer of shape (C,); consulted in eval mode.
running_var : Tensor or None
Running variance buffer of shape (C,); consulted in eval mode.
weight : Tensor = None
Per-channel scale γ of shape (C,). Defaults to ones (no scaling).
bias : Tensor = None
Per-channel shift β of shape (C,). Defaults to zeros (no shift).
training : bool = False
When True, normalisation uses the statistics of the current batch. When False and both running buffers are supplied, running_mean and running_var are used instead.
momentum : float = 0.1
Exponential-moving-average coefficient for updating the running buffers from batch statistics.
eps : float = 1e-05
Small constant added to the variance inside the square root for numerical stability.
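The following minimal NumPy sketch shows what `momentum` controls: an exponential-moving-average update of the running buffers. It assumes the PyTorch-style convention, in which `momentum` weights the new batch statistic. This page does not spell out lucid's exact update rule, nor whether the stored variance is the biased or unbiased batch estimate. Treat `update_running_stats`, a hypothetical helper, as illustrative only.

```python
import numpy as np


def update_running_stats(running_mean, running_var, batch_mean, batch_var, momentum=0.1):
    """Return EMA-updated copies of the running buffers (hypothetical helper).

    Assumes the PyTorch-style rule: new = (1 - momentum) * old + momentum * batch.
    """
    new_mean = (1.0 - momentum) * running_mean + momentum * batch_mean
    new_var = (1.0 - momentum) * running_var + momentum * batch_var
    return new_mean, new_var


# Fresh buffers (zeros / ones) after one batch whose mean is 2.0 and variance is 5.0.
rm, rv = update_running_stats(np.zeros(3), np.ones(3), np.full(3, 2.0), np.full(3, 5.0))
print(rm)  # [0.2 0.2 0.2]
print(rv)  # [1.4 1.4 1.4]
```

With the default momentum of 0.1, each batch moves the buffers 10% of the way toward its own statistics. This gives the buffers an effective memory of roughly 1 / momentum = 10 batches.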

Returns

Tensor

Same shape as x.

Notes

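Math (computed independently for each channel; $\mathcal{B}$ is the set of all batch and spatial positions of that channel, so $|\mathcal{B}|$ is N times the product of the spatial sizes):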

$$
\begin{aligned}
\mu_{\mathcal{B}} &= \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} x_i \\
\sigma_{\mathcal{B}}^2 &= \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} (x_i - \mu_{\mathcal{B}})^2 \\
\hat{x}_i &= \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} \\
y_i &= \gamma\,\hat{x}_i + \beta
\end{aligned}
$$
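The four equations translate almost line for line into array code. The sketch below is a NumPy reference for training mode only: it uses batch statistics and does not update any running buffers. `batch_norm_reference` is a hypothetical helper written for illustration, not part of lucid. It uses the biased 1/|𝓑| variance, as in the formula.

```python
import numpy as np


def batch_norm_reference(x, weight=None, bias=None, eps=1e-5):
    """Training-mode batch norm on an (N, C, *) array (hypothetical NumPy helper)."""
    x = np.asarray(x, dtype=np.float64)
    if not 2 <= x.ndim <= 5:
        raise ValueError("expected 2 to 5 dimensions (N, C, *)")
    c = x.shape[1]
    # Reduce over the batch axis and every spatial axis, keeping only C.
    axes = (0,) + tuple(range(2, x.ndim))
    # (1, C, 1, ..., 1) so per-channel vectors broadcast against x.
    bshape = (1, c) + (1,) * (x.ndim - 2)

    mu = x.mean(axis=axes, keepdims=True)                 # mu_B
    var = ((x - mu) ** 2).mean(axis=axes, keepdims=True)  # sigma_B^2 with 1/|B|
    x_hat = (x - mu) / np.sqrt(var + eps)                 # normalised input

    gamma = np.ones(c) if weight is None else np.asarray(weight, dtype=np.float64)
    beta = np.zeros(c) if bias is None else np.asarray(bias, dtype=np.float64)
    return gamma.reshape(bshape) * x_hat + beta.reshape(bshape)  # y = gamma * x_hat + beta


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4, 5, 5))
y = batch_norm_reference(x)
print(y.mean(axis=(0, 2, 3)))  # close to 0 for every channel
print(y.var(axis=(0, 2, 3)))   # close to 1 (slightly below, because of eps)
```

Passing `weight` and `bias` then shifts each channel's mean to β and scales its standard deviation to (approximately) γ.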

Dispatches internally on ndim to the matching engine kernel:

- ndim == 2 is treated as (N, C, 1).
- ndim == 3 uses batch_norm1d.
- ndim == 4 uses the 2-D kernel.
- ndim == 5 uses batch_norm3d.

In eval mode (training=False with running buffers supplied), no batch statistics are computed. running_mean and running_var are used directly in place of $\mu_{\mathcal{B}}$ and $\sigma_{\mathcal{B}}^2$.
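The following NumPy sketch roughly mirrors that behaviour. A 2-D input gets a trailing singleton axis so it can share the (N, C, L) path, and eval mode substitutes the running buffers for the batch statistics. `batch_norm_sketch` is hypothetical. It omits γ, β and the running-buffer update, does not reproduce lucid's kernel dispatch, and assumes that training=False without running buffers falls back to batch statistics. That fallback is inferred from the `training` description, not confirmed.

```python
import numpy as np


def batch_norm_sketch(x, running_mean, running_var, training=False, eps=1e-5):
    """Hypothetical illustration of the 2-D handling and the eval path; not lucid's code."""
    x = np.asarray(x, dtype=np.float64)
    was_2d = x.ndim == 2
    if was_2d:
        x = x[:, :, None]  # treat (N, C) as (N, C, 1)

    if training or running_mean is None or running_var is None:
        # Batch statistics over N and every spatial axis (assumed fallback when buffers are missing).
        axes = (0,) + tuple(range(2, x.ndim))
        mean, var = x.mean(axis=axes), x.var(axis=axes)
    else:
        # Eval mode: no statistics are computed; the running buffers are used as-is.
        mean = np.asarray(running_mean, dtype=np.float64)
        var = np.asarray(running_var, dtype=np.float64)

    # Reshape the (C,) vectors to (1, C, 1, ..., 1) so they broadcast over x.
    shape = (1, x.shape[1]) + (1,) * (x.ndim - 2)
    y = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    return y[:, :, 0] if was_2d else y
```

The singleton axis changes nothing numerically: reducing over axes (0, 2) of an (N, C, 1) array gives the same statistics as reducing over axis 0 of the original (N, C) array.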

Examples

>>> import lucid
>>> from lucid.nn.functional import batch_norm
>>> x = lucid.randn(8, 16, 32, 32)
>>> rm = lucid.zeros(16); rv = lucid.ones(16)
>>> y = batch_norm(x, rm, rv, training=True)
>>> y.shape
(8, 16, 32, 32)