fuse_conv_bn_weights

fuse_conv_bn_weights(conv_w: Tensor, conv_b: Tensor | None, bn_rm: Tensor, bn_rv: Tensor, bn_eps: float, bn_w: Tensor | None, bn_b: Tensor | None) -> (Tensor, Tensor)

Low-level form of fuse_conv_bn_eval operating on raw tensors.

Takes the relevant Conv and BN tensors as plain arguments and returns the fused (weight, bias) pair. Useful for build tools that walk a serialised graph (for example, ONNX export or ahead-of-time compilation) and need to perform the fusion without instantiating Module objects.

Parameters

conv_w : Tensor
    Convolution weight, shape (out_channels, in_channels, *kernel).
conv_b : Tensor or None
    Convolution bias, shape (out_channels,), or None if the conv has no bias.
bn_rm : Tensor
    BatchNorm running mean, shape (out_channels,).
bn_rv : Tensor
    BatchNorm running variance, shape (out_channels,).
bn_eps : float
    BatchNorm numerical-stability epsilon.
bn_w : Tensor or None
    BatchNorm affine weight γ, shape (out_channels,), or None if affine=False.
bn_b : Tensor or None
    BatchNorm affine bias β, shape (out_channels,), or None if affine=False.
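When bn_w and bn_b are None, eval-mode BatchNorm only normalises, which is the same as applying γ = 1 and β = 0. The NumPy sketch below illustrates that equivalence. It assumes channel-first arrays of shape (N, C, H, W), and `bn_eval` is a hypothetical helper written for illustration, not part of lucid.

```python
import numpy as np


def bn_eval(y, running_mean, running_var, eps, weight=None, bias=None):
    """Hypothetical eval-mode BatchNorm over axis 1 of an (N, C, ...) array."""
    # Reshape per-channel stats to (1, C, 1, ...) so they broadcast over N and spatial dims.
    shape = (1, -1) + (1,) * (y.ndim - 2)
    out = (y - running_mean.reshape(shape)) / np.sqrt(running_var.reshape(shape) + eps)
    if weight is not None:  # affine=True: scale by gamma
        out = out * weight.reshape(shape)
    if bias is not None:  # affine=True: shift by beta
        out = out + bias.reshape(shape)
    return out


rng = np.random.default_rng(0)
y = rng.standard_normal((2, 3, 4, 4))
rm = rng.standard_normal(3)
rv = rng.uniform(0.5, 2.0, size=3)

# affine=False versus an explicit identity affine (gamma = 1, beta = 0).
no_affine = bn_eval(y, rm, rv, 1e-5)
identity_affine = bn_eval(y, rm, rv, 1e-5, weight=np.ones(3), bias=np.zeros(3))
```

The same reasoning lets a missing conv_b be treated as a zero bias.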

Returns

(Tensor, Tensor)

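Fused (weight, bias), with the same dtype and device as conv_w. The weight has the shape of conv_w and the bias has shape (out_channels,); a bias tensor is returned even when conv_b is None.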

Notes

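Identical math to fuse_conv_bn_eval, applied per output channel:

\mathbf{W}_{\text{fused}} = \mathbf{W} \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}, \qquad b_{\text{fused}} = \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta,

where μ = bn_rm, σ² = bn_rv, ε = bn_eps, γ = bn_w, β = bn_b and b = conv_b. An argument passed as None contributes its identity value (γ = 1, β = 0, b = 0).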
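To make the formula concrete, here is a NumPy transcription of it. This is a sketch under stated assumptions, not lucid's implementation: `fuse_conv_bn_weights_ref` is a hypothetical name, inputs are NumPy arrays rather than lucid Tensors, and None arguments fall back to the identity values listed above. The check uses a 1x1 kernel because convolving a single pixel with it reduces to a matrix-vector product.

```python
import numpy as np


def fuse_conv_bn_weights_ref(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w=None, bn_b=None):
    """Hypothetical NumPy transcription of the fusion formula (not lucid's code)."""
    c_out = conv_w.shape[0]
    dtype = conv_w.dtype
    # None falls back to the identity value for that term: b = 0, gamma = 1, beta = 0.
    b = np.zeros(c_out, dtype) if conv_b is None else conv_b
    gamma = np.ones(c_out, dtype) if bn_w is None else bn_w
    beta = np.zeros(c_out, dtype) if bn_b is None else bn_b

    # Per-output-channel factor gamma / sqrt(sigma^2 + eps).
    scale = gamma / np.sqrt(bn_rv + bn_eps)
    # Reshape to (C_out, 1, 1, ...) so the factor scales every weight of its output channel.
    w_fused = conv_w * scale.reshape((c_out,) + (1,) * (conv_w.ndim - 1))
    b_fused = (b - bn_rm) * scale + beta
    return w_fused.astype(dtype), b_fused.astype(dtype)


rng = np.random.default_rng(0)
w = rng.standard_normal((4, 3, 1, 1))
b = rng.standard_normal(4)
rm, rv = rng.standard_normal(4), rng.uniform(0.5, 2.0, 4)
gamma, beta = rng.standard_normal(4), rng.standard_normal(4)
x = rng.standard_normal(3)

# Unfused: 1x1 conv on one pixel (matrix-vector product), then eval-mode BatchNorm.
y = w[:, :, 0, 0] @ x + b
bn_out = gamma * (y - rm) / np.sqrt(rv + 1e-5) + beta

# Fused: a single 1x1 conv with the folded weight and bias.
fw, fb = fuse_conv_bn_weights_ref(w, b, rm, rv, 1e-5, gamma, beta)
fused_out = fw[:, :, 0, 0] @ x + fb
```

Folding the scale into the weight's leading axis works because each output channel of the conv feeds exactly one BatchNorm channel.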

Examples

>>> from lucid.nn.utils import fuse_conv_bn_weights
>>> # conv: an existing conv layer; bn: the BatchNorm applied to its output (eval mode)
>>> W, b = fuse_conv_bn_weights(
...     conv.weight, conv.bias,
...     bn.running_mean, bn.running_var, bn.eps,
...     bn.weight, bn.bias,
... )
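The example above assumes `conv` and `bn` already exist. The self-contained NumPy check below shows, for a spatial kernel, that the fused pair reproduces conv followed by eval-mode BatchNorm. `conv2d_naive` is a hypothetical stride-1, no-padding cross-correlation written for illustration, and the fusion is re-derived inline rather than imported from lucid.

```python
import numpy as np


def conv2d_naive(x, w, b):
    """Stride-1, no-padding 2-D cross-correlation: x (C_in, H, W), w (C_out, C_in, kh, kw)."""
    c_out, _, kh, kw = w.shape
    h_out, w_out = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    y = np.empty((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[:, i:i + kh, j:j + kw]
            # Contract over (C_in, kh, kw) for every output channel at once.
            y[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2]))
    return y + b[:, None, None]


rng = np.random.default_rng(1)
x = rng.standard_normal((3, 6, 6))
w = rng.standard_normal((5, 3, 3, 3))
b = rng.standard_normal(5)
rm, rv, eps = rng.standard_normal(5), rng.uniform(0.5, 2.0, 5), 1e-5
gamma, beta = rng.standard_normal(5), rng.standard_normal(5)

# Unfused: convolution followed by eval-mode BatchNorm (per-channel stats broadcast over H, W).
y = conv2d_naive(x, w, b)
inv_std = (1.0 / np.sqrt(rv + eps))[:, None, None]
reference = gamma[:, None, None] * (y - rm[:, None, None]) * inv_std + beta[:, None, None]

# Fused: scale each output channel's filter and fold the BN shift into the bias.
scale = gamma / np.sqrt(rv + eps)
w_fused = w * scale[:, None, None, None]
b_fused = (b - rm) * scale + beta
fused = conv2d_naive(x, w_fused, b_fused)
```

Because the fused layer is still an ordinary convolution, the BatchNorm node can be dropped from the graph at inference time.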