
fuse_linear_bn_eval

fuse_linear_bn_eval(linear: object, bn: object) -> Module

Fold a BatchNorm1d into the preceding Linear weights (inference-only).

The 1-D analogue of fuse_conv_bn_eval: it absorbs the BN's eval-time affine transform into a Linear layer's weight and bias. The fused Linear produces the same output as BN(Linear(x)) in eval mode (up to floating-point rounding) while launching one fewer kernel.
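The equivalence claim can be checked on raw arrays. The NumPy sketch below only mirrors the arithmetic and is not lucid's implementation: the helper names fold_bn_into_linear and linear_then_bn_eval are hypothetical, eps=1e-5 is an assumed default, and treating a missing Linear bias as zeros is an assumption made for illustration.

```python
import numpy as np


def fold_bn_into_linear(W, b, gamma, beta, running_mean, running_var, eps=1e-5):
    """Return new (W_fused, b_fused) with W_fused @ x + b_fused == BN_eval(W @ x + b).

    W has shape (out_features, in_features); every other array has shape
    (out_features,). b may be None for a bias-free Linear (assumed: treated as zeros).
    The input arrays are never modified.
    """
    if b is None:
        b = np.zeros(W.shape[0], dtype=W.dtype)
    scale = gamma / np.sqrt(running_var + eps)   # one factor per output feature
    W_fused = W * scale[:, None]                 # new array: row i of W times scale[i]
    b_fused = (b - running_mean) * scale + beta  # absorb mean shift and BN bias
    return W_fused, b_fused


def linear_then_bn_eval(x, W, b, gamma, beta, running_mean, running_var, eps=1e-5):
    """Unfused reference: Linear on a (batch, in_features) input, then eval-mode BN."""
    z = x @ W.T + b
    return gamma * (z - running_mean) / np.sqrt(running_var + eps) + beta


rng = np.random.default_rng(0)
W, b = rng.standard_normal((64, 128)), rng.standard_normal(64)
gamma, beta = rng.standard_normal(64), rng.standard_normal(64)
running_mean, running_var = rng.standard_normal(64), rng.uniform(0.5, 2.0, 64)
x = rng.standard_normal((8, 128))

W_fused, b_fused = fold_bn_into_linear(W, b, gamma, beta, running_mean, running_var)
fused_out = x @ W_fused.T + b_fused
reference = linear_then_bn_eval(x, W, b, gamma, beta, running_mean, running_var)
print(np.allclose(fused_out, reference))  # True
```

Scaling rows rather than columns is what makes the fold work: row i of W produces output feature i, which is exactly the feature BN rescales by scale[i].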

Parameters

linear : Linear
Linear layer to absorb the BN into.
bn : BatchNorm1d
BatchNorm whose feature dimension matches linear.out_features. Must be in eval mode (or otherwise be using its frozen statistics).
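The eval-mode requirement exists because only eval-mode BN is a fixed per-sample affine map. The NumPy sketch below illustrates this with hypothetical helpers bn1d_eval and bn1d_train, assuming the usual BatchNorm convention (training mode normalizes with the batch's biased variance) and eps=1e-5; it does not call lucid.

```python
import numpy as np


def bn1d_eval(z, gamma, beta, running_mean, running_var, eps=1e-5):
    # Eval mode: a fixed per-feature affine map built from frozen statistics.
    return gamma * (z - running_mean) / np.sqrt(running_var + eps) + beta


def bn1d_train(z, gamma, beta, eps=1e-5):
    # Train mode: normalize with the current batch's own mean and
    # (biased) variance over the batch axis, ignoring the running estimates.
    mu = z.mean(axis=0)
    var = z.var(axis=0)
    return gamma * (z - mu) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
z = rng.standard_normal((16, 4)) * 3.0 + 1.0   # stand-in for Linear outputs
gamma, beta = rng.standard_normal(4), rng.standard_normal(4)
running_mean, running_var = np.zeros(4), np.ones(4)

# Eval mode: sample 0 gets the same output alone or inside the full batch,
# so the layer can be folded into the Linear's weight and bias.
alone = bn1d_eval(z[:1], gamma, beta, running_mean, running_var)
in_batch = bn1d_eval(z, gamma, beta, running_mean, running_var)[:1]
print(np.allclose(alone, in_batch))  # True

# Train mode: sample 0's output changes with the batch it belongs to,
# so no fixed weight/bias pair can reproduce it.
half = bn1d_train(z[:8], gamma, beta)[0]
full = bn1d_train(z, gamma, beta)[0]
print(np.allclose(half, full))
```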

Returns

Module

Deep copy of linear with fused parameters. Originals are not mutated.

Raises

TypeError
If linear is not lucid.nn.modules.linear.Linear or bn is not lucid.nn.modules.normalization.BatchNorm1d.

Notes

The math is identical to the conv case: each row of the Linear's weight matrix is scaled by $\gamma / \sqrt{\sigma^2 + \epsilon}$ (one factor per output feature), and the bias absorbs the running-mean shift and the BN bias.
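Written out, with $W, b$ the Linear's weight and bias, $\mu, \sigma^2$ the BN's running mean and variance, $\gamma, \beta$ its weight and bias, and all divisions and square roots taken per output feature (a missing Linear bias is treated here as $b = 0$), the fold is:

$$
\begin{aligned}
y &= \gamma \odot \frac{W x + b - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \\
  &= \underbrace{\operatorname{diag}(s)\, W}_{W'}\, x + \underbrace{s \odot (b - \mu) + \beta}_{b'},
  \qquad s = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}
\end{aligned}
$$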

Examples

>>> import lucid.nn as nn
>>> from lucid.nn.utils import fuse_linear_bn_eval
>>> linear = nn.Linear(128, 64); bn = nn.BatchNorm1d(64)
>>> linear.eval(); bn.eval()
>>> fused = fuse_linear_bn_eval(linear, bn)