fn fuse_conv_bn_weights
fuse_conv_bn_weights(conv_w: Tensor, conv_b: Tensor | None, bn_rm: Tensor, bn_rv: Tensor, bn_eps: float, bn_w: Tensor | None, bn_b: Tensor | None) → (Tensor, Tensor)
Low-level form of fuse_conv_bn_eval operating on raw tensors.
Takes the relevant Conv and BN tensors as plain arguments and
returns the fused (weight, bias) pair. Useful for build tools
that walk a serialised graph (ONNX, ahead-of-time compilation) and
need to perform the fusion without instantiating Module objects.
Parameters
conv_w : Tensor
    Convolution weight, shape (out_channels, in_channels, *kernel).
conv_b : Tensor or None
    Convolution bias, shape (out_channels,), or None if the conv has no bias.
bn_rm : Tensor
    BatchNorm running mean, shape (out_channels,).
bn_rv : Tensor
    BatchNorm running variance, shape (out_channels,).
bn_eps : float
    BatchNorm numerical-stability epsilon.
bn_w : Tensor or None
    BatchNorm affine weight, or None if affine=False.
bn_b : Tensor or None
    BatchNorm affine bias, or None.
Returns
(Tensor, Tensor)
    Fused (weight, bias) on the same dtype / device as conv_w.
Notes
Identical math to fuse_conv_bn_eval:
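For reference, the standard Conv-BN folding identity can be sketched as below. The symbols are assumptions chosen for illustration and do not appear in the original: W and b stand for conv_w and conv_b (b taken as 0 when conv_b is None), μ and σ² for bn_rm and bn_rv, ε for bn_eps, and γ and β for bn_w and bn_b (taken as 1 and 0 when None). The scale factor is applied per output channel.

```latex
\begin{aligned}
W_{\text{fused}} &= W \cdot \frac{\gamma}{\sqrt{\sigma^{2} + \epsilon}}, \\
b_{\text{fused}} &= (b - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^{2} + \epsilon}} + \beta.
\end{aligned}
```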
Examples
>>> from lucid.nn.utils import fuse_conv_bn_weights
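>>> # conv and bn are assumed to be an existing convolution layer and the
>>> # BatchNorm layer that follows it; the fold uses BN running statistics.
>>> W, b = fuse_conv_bn_weights(
...     conv.weight, conv.bias,
...     bn.running_mean, bn.running_var, bn.eps,
...     bn.weight, bn.bias,
... )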
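
To make the fold concrete, here is a minimal pure-Python sketch of the same computation, checked at a single spatial position. It is not lucid's implementation: `fuse_reference`, `conv_at_one_position`, and `batchnorm_eval` are hypothetical helper names, tensors are stood in for by plain lists, and the kernel is assumed to be flattened to one list of weights per output channel.

```python
import math


def fuse_reference(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """Hypothetical pure-Python reference for the Conv-BN fold.

    conv_w holds one flat list of weights per output channel; every other
    argument is a per-channel list (or None where the signature allows it).
    """
    n = len(conv_w)
    # Missing bias / affine parameters behave like zeros and ones.
    conv_b = [0.0] * n if conv_b is None else conv_b
    bn_w = [1.0] * n if bn_w is None else bn_w
    bn_b = [0.0] * n if bn_b is None else bn_b

    fused_w, fused_b = [], []
    for c in range(n):
        # Per-channel scale: gamma / sqrt(running_var + eps).
        scale = bn_w[c] / math.sqrt(bn_rv[c] + bn_eps)
        fused_w.append([w * scale for w in conv_w[c]])
        fused_b.append((conv_b[c] - bn_rm[c]) * scale + bn_b[c])
    return fused_w, fused_b


def conv_at_one_position(patch, weight, bias):
    """Convolution output at one spatial position: a dot product per channel."""
    return [sum(w * x for w, x in zip(row, patch)) + b
            for row, b in zip(weight, bias)]


def batchnorm_eval(y, rm, rv, eps, gamma, beta):
    """Inference-mode BatchNorm using running statistics."""
    return [(v - m) / math.sqrt(var + eps) * g + b
            for v, m, var, g, b in zip(y, rm, rv, gamma, beta)]


# Two output channels, a flattened 3-element receptive field.
conv_w = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
conv_b = [0.1, -0.2]
bn_rm, bn_rv, bn_eps = [0.3, -0.4], [1.2, 0.8], 1e-5
bn_w, bn_b = [0.9, 1.1], [0.05, -0.05]
patch = [1.0, 2.0, -1.0]

# Conv followed by BN, versus a single conv with the fused parameters.
unfused = batchnorm_eval(conv_at_one_position(patch, conv_w, conv_b),
                         bn_rm, bn_rv, bn_eps, bn_w, bn_b)
fused_w, fused_b = fuse_reference(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)
fused = conv_at_one_position(patch, fused_w, fused_b)

assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for a, b in zip(unfused, fused))
```

The two paths agree only up to floating-point rounding, so the comparison uses a tolerance instead of exact equality. Because the fold relies on running statistics, the equivalence holds for inference-mode BatchNorm only.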