fuse_linear_bn_weights

fuse_linear_bn_weights(linear_w: Tensor, linear_b: Tensor | None, bn_rm: Tensor, bn_rv: Tensor, bn_eps: float, bn_w: Tensor | None, bn_b: Tensor | None) -> (Tensor, Tensor)

Low-level form of fuse_linear_bn_eval that operates on raw tensors.
Internally identical to fuse_conv_bn_weights: from BatchNorm's perspective,
a Linear weight is just a 2-D weight whose leading axis indexes output
channels, exactly like a conv weight. It is exposed under its own name to
keep call sites self-documenting in graph-rewrite passes.
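The reason one implementation can serve both layer types is that the BatchNorm scale only ever touches the leading (output-channel) axis of the weight. A minimal NumPy sketch of that shape-agnostic scaling follows; `scale_out_channels` is a hypothetical helper for illustration, not part of lucid, and plain arrays stand in for lucid Tensors.

```python
import numpy as np

def scale_out_channels(weight: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Hypothetical helper: multiply each output channel (axis 0) of `weight` by `scale`.

    Works for any weight rank: (out, in) for Linear, (out, in, kH, kW) for Conv2d.
    """
    # Reshape scale to (out, 1, 1, ...) so it broadcasts over every non-output axis.
    shape = (-1,) + (1,) * (weight.ndim - 1)
    return weight * scale.reshape(shape)

rng = np.random.default_rng(0)
scale = rng.uniform(0.5, 2.0, size=4)          # one factor per output channel

linear_w = rng.standard_normal((4, 8))         # (out_features, in_features)
conv_w = rng.standard_normal((4, 3, 3, 3))     # (out_channels, in_channels, kH, kW)

fused_linear_w = scale_out_channels(linear_w, scale)
fused_conv_w = scale_out_channels(conv_w, scale)
```

The same reshape rule covers both cases, which is why a dedicated Linear variant only needs a different name, not different math.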
Parameters
linear_w : Tensor
    Linear weight matrix, shape (out_features, in_features).
linear_b : Tensor or None
    Linear bias, shape (out_features,), or None.
bn_rm : Tensor
    BatchNorm1d running mean.
bn_rv : Tensor
    BatchNorm1d running variance.
bn_eps : float
    BatchNorm numerical-stability epsilon.
bn_w : Tensor or None
    BatchNorm affine scale (γ), or None.
bn_b : Tensor or None
    BatchNorm affine shift (β), or None.
Returns
(Tensor, Tensor)
    Fused (weight, bias), with the same dtype and device as linear_w.
Notes
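Math: with μ = bn_rm, σ² = bn_rv, ε = bn_eps, γ = bn_w, β = bn_b, and
b = linear_b (zero when linear_b is None),

$$W_{\text{fused}} = \operatorname{diag}\!\left(\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\right) W, \qquad b_{\text{fused}} = \frac{\gamma \odot (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta$$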
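The fusion folds BatchNorm's per-feature factor γ/√(σ² + ε) into each row of W and shifts and rescales the bias to match. The NumPy sketch below checks this numerically against "Linear, then BatchNorm1d in eval mode". It is an illustration under assumptions, not lucid's implementation: `fuse_linear_bn_np` is a hypothetical name, plain arrays stand in for Tensors, and a None bias / γ / β is assumed to behave like 0 / 1 / 0.

```python
import numpy as np

def fuse_linear_bn_np(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w=None, bn_b=None):
    """Hypothetical NumPy sketch of Linear + BatchNorm1d (eval) folding."""
    out_features = linear_w.shape[0]
    # Assumption: a missing bias / affine parameter behaves like its identity value.
    if linear_b is None:
        linear_b = np.zeros(out_features)
    if bn_w is None:
        bn_w = np.ones(out_features)
    if bn_b is None:
        bn_b = np.zeros(out_features)
    scale = bn_w / np.sqrt(bn_rv + bn_eps)        # gamma / sqrt(var + eps), one per output feature
    fused_w = linear_w * scale[:, None]            # scale each row of W
    fused_b = (linear_b - bn_rm) * scale + bn_b    # shift and rescale the bias
    return fused_w, fused_b

rng = np.random.default_rng(0)
in_f, out_f, batch = 5, 3, 7
W, b = rng.standard_normal((out_f, in_f)), rng.standard_normal(out_f)
mean, var = rng.standard_normal(out_f), rng.uniform(0.5, 2.0, out_f)
gamma, beta, eps = rng.standard_normal(out_f), rng.standard_normal(out_f), 1e-5
x = rng.standard_normal((batch, in_f))

# Reference: Linear, then BatchNorm1d using running statistics (eval mode).
y_ref = ((x @ W.T + b) - mean) / np.sqrt(var + eps) * gamma + beta

# Fused: a single Linear with the folded weight and bias.
W_f, b_f = fuse_linear_bn_np(W, b, mean, var, eps, gamma, beta)
y_fused = x @ W_f.T + b_f
```

Expanding x·W_fusedᵀ + b_fused gives (x·Wᵀ + b − μ)·γ/√(σ² + ε) + β, which is exactly the eval-mode BatchNorm output, so the two paths agree up to floating-point error.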
Examples
>>> from lucid.nn.utils import fuse_linear_bn_weights
>>> W, b = fuse_linear_bn_weights(
... linear.weight, linear.bias,
... bn.running_mean, bn.running_var, bn.eps,
... bn.weight, bn.bias,
... )