trunc_normal_
trunc_normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0) → Tensor
Initialise tensor in-place with truncated normal samples.
Each entry is drawn from $\mathcal{N}(\text{mean}, \text{std}^2)$ and rejected if it falls outside [a, b];
rejection sampling is repeated until tensor is fully populated.
This is the preferred initialiser for transformer weights (e.g.
ViT, BERT) where unbounded Gaussian tails would otherwise produce
rare but disruptive activations.
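The draw-and-reject procedure described above can be sketched in plain Python. This is only an illustration of the technique, not the library's implementation: `trunc_normal_list` is a hypothetical helper that fills a Python list rather than a tensor, and it uses the standard-library `random` module instead of any lucid API.

```python
import random


def trunc_normal_list(n, mean=0.0, std=1.0, a=-2.0, b=2.0, seed=None):
    """Hypothetical helper: n draws from N(mean, std^2) restricted to [a, b]."""
    if a >= b:
        raise ValueError("a must be strictly less than b")
    rng = random.Random(seed)
    out = []
    # Keep drawing until every slot holds an accepted sample.
    while len(out) < n:
        x = rng.gauss(mean, std)
        if a <= x <= b:  # accept only samples inside the truncation interval
            out.append(x)
    return out


# Same settings as the Examples section: bounds at two standard deviations.
samples = trunc_normal_list(1000, mean=0.0, std=0.02, a=-0.04, b=0.04, seed=0)
```

With this scheme the acceptance rate equals the Gaussian probability mass on [a, b], so an interval placed far out in one tail would make a loop like this slow.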
Parameters
tensor (Tensor): Tensor to initialise in place; any shape is accepted.
mean (float, default 0.0): Mean of the underlying (untruncated) Gaussian.
std (float, default 1.0): Standard deviation of the underlying Gaussian.
a (float, default -2.0): Lower truncation bound.
b (float, default 2.0): Upper truncation bound.
Returns
Tensor: tensor (mutated), returned for chaining.
Notes
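The conditional density is

$$p(x \mid a \le x \le b) = \frac{1}{Z} \cdot \frac{1}{\sigma\sqrt{2\pi}} \exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right), \qquad a \le x \le b,$$

where $\mu$ is mean, $\sigma$ is std, and $Z = \Phi\!\left(\frac{b-\mu}{\sigma}\right) - \Phi\!\left(\frac{a-\mu}{\sigma}\right)$ is the normalising constant, with $\Phi$ the standard normal CDF.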
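As a numerical check on the density, the sketch below evaluates it with the standard library only. It assumes the normalising constant is the Gaussian probability mass on [a, b], written here as `z_norm`; the helper names `std_normal_cdf` and `trunc_normal_pdf` are hypothetical and not part of lucid.

```python
import math


def std_normal_cdf(z):
    """Standard normal CDF, expressed through the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def trunc_normal_pdf(x, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Density of N(mean, std^2) conditioned on a <= x <= b (illustrative)."""
    if x < a or x > b:
        return 0.0  # no mass outside the truncation interval
    # Assumed normalising constant: Gaussian mass that falls inside [a, b].
    z_norm = std_normal_cdf((b - mean) / std) - std_normal_cdf((a - mean) / std)
    gauss = math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))
    return gauss / z_norm


# With the default arguments the constant is roughly 0.954.
z_default = std_normal_cdf(2.0) - std_normal_cdf(-2.0)
```

Under this assumption, about 95% of raw Gaussian draws fall inside the default interval, so truncation at two standard deviations rescales the density only slightly.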
Examples
>>> import lucid
>>> from lucid.nn.init import trunc_normal_
>>> w = lucid.empty(64, 32)
>>> trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)