
fuse_linear_bn_weights

fuse_linear_bn_weights(linear_w: Tensor, linear_b: Tensor | None, bn_rm: Tensor, bn_rv: Tensor, bn_eps: float, bn_w: Tensor | None, bn_b: Tensor | None) -> (Tensor, Tensor)

Low-level form of fuse_linear_bn_eval operating on raw tensors.

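Internally identical to fuse_conv_bn_weights. From BatchNorm's perspective, a Linear weight of shape (out_features, in_features) is just a conv weight with no spatial dimensions: in both cases the per-channel rescaling is applied along dimension 0. The function is exposed under its own name so that call sites in graph-rewrite passes document themselves.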
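The claim that one routine serves both layer types can be illustrated with a minimal NumPy sketch. This is not lucid's implementation: `fuse_weights_any_rank` is a hypothetical name, and treating a `None` γ, β, or bias as ones/zeros is an assumption made for the sketch. The per-channel scale is reshaped to broadcast along dimension 0 only, so a 2-D Linear weight and the equivalent 4-D 1×1 conv weight fuse to the same values.

```python
import numpy as np

def fuse_weights_any_rank(w, b, rm, rv, eps, gamma=None, beta=None):
    """Hypothetical sketch: fold eval-mode BatchNorm into the preceding layer.

    Works for any weight whose output-channel axis is dim 0, e.g. a Linear
    weight (out, in) or a Conv2d weight (out, in, kh, kw).
    """
    out = w.shape[0]
    # Assumption for this sketch: missing parameters act as identity values.
    gamma = np.ones(out) if gamma is None else gamma
    beta = np.zeros(out) if beta is None else beta
    b = np.zeros(out) if b is None else b
    scale = gamma / np.sqrt(rv + eps)  # one factor per output channel
    # Reshape to (out, 1, 1, ...) so the scale broadcasts along dim 0 only.
    w_fused = w * scale.reshape((-1,) + (1,) * (w.ndim - 1))
    b_fused = (b - rm) * scale + beta
    return w_fused, b_fused

rng = np.random.default_rng(0)
w = rng.normal(size=(3, 4))                      # Linear weight (out=3, in=4)
rm, rv = rng.normal(size=3), rng.uniform(0.5, 2.0, size=3)

# Same weights viewed as a Linear weight and as a 1x1 conv weight.
wl, bl = fuse_weights_any_rank(w, None, rm, rv, 1e-5)
wc, bc = fuse_weights_any_rank(w.reshape(3, 4, 1, 1), None, rm, rv, 1e-5)
print(np.allclose(wl, wc.reshape(3, 4)), np.allclose(bl, bc))  # True True
```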

Parameters

linear_w : Tensor
Linear weight matrix, shape (out_features, in_features).
linear_b : Tensor or None
Linear bias, shape (out_features,), or None.
bn_rm : Tensor
BatchNorm1d running mean, shape (out_features,).
bn_rv : Tensor
BatchNorm1d running variance, shape (out_features,).
bn_eps : float
BatchNorm numerical-stability epsilon.
bn_w : Tensor or None
BatchNorm affine weight γ, shape (out_features,), or None.
bn_b : Tensor or None
BatchNorm affine bias β, shape (out_features,), or None.

Returns

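(Tensor, Tensor)

The fused (weight, bias) pair, with the same dtype and device as linear_w. The weight has shape (out_features, in_features) and the bias has shape (out_features,).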

Notes

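Math:

$$\mathbf{W}_{\text{fused}} = \mathbf{W} \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}, \qquad b_{\text{fused}} = \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta,$$

where μ and σ² are the running mean and variance (bn_rm, bn_rv), γ and β are the BatchNorm affine parameters (bn_w, bn_b), and ε is bn_eps. All operations are element-wise over the out_features axis, so row i of W is multiplied by γᵢ / √(σᵢ² + ε).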
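The formula can be checked numerically with a short NumPy sketch. `fuse_linear_bn_ref` is a hypothetical reference written for this illustration, not lucid's code; it compares a Linear layer followed by eval-mode BatchNorm1d against the single fused Linear layer on random data.

```python
import numpy as np

def fuse_linear_bn_ref(w, b, rm, rv, eps, gamma, beta):
    # Hypothetical reference implementation of the formula above.
    scale = gamma / np.sqrt(rv + eps)     # shape (out_features,)
    w_fused = w * scale[:, None]          # scale row i of W by scale[i]
    b_fused = gamma * (b - rm) / np.sqrt(rv + eps) + beta
    return w_fused, b_fused

rng = np.random.default_rng(42)
out_f, in_f, eps = 3, 5, 1e-5
w, b = rng.normal(size=(out_f, in_f)), rng.normal(size=out_f)
rm, rv = rng.normal(size=out_f), rng.uniform(0.5, 2.0, size=out_f)
gamma, beta = rng.normal(size=out_f), rng.normal(size=out_f)
x = rng.normal(size=(8, in_f))            # batch of 8 inputs

# Unfused: Linear (y = x W^T + b), then eval-mode BatchNorm1d.
y = x @ w.T + b
ref = gamma * (y - rm) / np.sqrt(rv + eps) + beta

# Fused: one Linear layer with the folded parameters.
w_f, b_f = fuse_linear_bn_ref(w, b, rm, rv, eps, gamma, beta)
fused = x @ w_f.T + b_f
print(np.allclose(ref, fused))  # True
```

Because BatchNorm in eval mode is a fixed per-feature affine map, composing it with the Linear layer yields another affine map, which is why the fused layer reproduces the original output exactly (up to floating-point rounding).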

Examples

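>>> from lucid.nn.utils import fuse_linear_bn_weights
>>> # `linear` is an existing Linear layer and `bn` the BatchNorm1d applied to
>>> # its output; the fused parameters reproduce bn's eval-mode behavior.
>>> W, b = fuse_linear_bn_weights(
...     linear.weight, linear.bias,
...     bn.running_mean, bn.running_var, bn.eps,
...     bn.weight, bn.bias,
... )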