
trunc_normal_

trunc_normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0) -> Tensor

Initialise tensor in place with samples from a truncated normal distribution.

Each entry is drawn from $\mathcal{N}(\text{mean}, \text{std}^2)$ and rejected if it falls outside $[a, b]$; rejection sampling is repeated until tensor is fully populated. This initialiser is commonly used for transformer weights (e.g. ViT, BERT), where the unbounded tails of a plain Gaussian would otherwise produce rare but disruptive activations.
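
A minimal pure-Python sketch of the rejection loop described above, assuming a flat list of floats in place of a lucid Tensor. The name trunc_normal_reference is hypothetical and this is not lucid's actual implementation; it only illustrates the technique with the standard-library random module.

```python
import random
from typing import List, Optional


def trunc_normal_reference(
    n: int,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
    seed: Optional[int] = None,
) -> List[float]:
    """Return n samples from N(mean, std^2) restricted to [a, b] by rejection.

    Hypothetical helper for illustration; works on a plain list, not a lucid Tensor.
    """
    if std <= 0:
        raise ValueError("std must be positive")
    if a >= b:
        raise ValueError("a must be smaller than b")
    rng = random.Random(seed)
    samples: List[float] = []
    while len(samples) < n:
        x = rng.gauss(mean, std)  # draw from the untruncated Gaussian
        if a <= x <= b:           # accept only draws that land inside [a, b]
            samples.append(x)
    return samples


if __name__ == "__main__":
    # Same settings as the Examples call, flattened to 64 * 32 entries.
    w = trunc_normal_reference(64 * 32, mean=0.0, std=0.02, a=-0.04, b=0.04, seed=0)
    print(len(w), min(w), max(w))
```

Each accepted sample costs on average 1/Z draws, where Z is the probability that a single untruncated draw lands in [a, b]. If [a, b] sits far out in a tail, Z is tiny and a loop of this kind becomes very slow.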

Parameters

tensor : Tensor
Tensor to initialise in place; any shape is accepted.
mean : float, optional
Mean of the underlying (untruncated) Gaussian. Default 0.0.
std : float, optional
Standard deviation of the underlying Gaussian. Default 1.0.
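a : float, optional
Lower truncation bound, expressed in the same units as the samples (an absolute value, not a multiple of std). Default -2.0.
b : float, optional
Upper truncation bound, in the same absolute units as a. Default 2.0.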

Returns

Tensor

The same tensor object, modified in place, returned to allow call chaining.

Notes

The conditional density is

$$p(x \mid a \le x \le b) = \frac{1}{Z}\, \frac{1}{\sqrt{2\pi}\,\text{std}}\, \exp\!\left(-\frac{(x - \text{mean})^2}{2\,\text{std}^2}\right), \quad x \in [a, b],$$

where $Z = \Phi((b - \text{mean})/\text{std}) - \Phi((a - \text{mean})/\text{std})$ is the normalising constant and $\Phi$ is the standard normal CDF.
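
The normalising constant can be checked numerically. The sketch below uses only the standard-library math module: it evaluates Φ through math.erf, builds the truncated density, and integrates it over [a, b] with a midpoint rule. The helper names are hypothetical. For the defaults (mean 0, std 1, a = -2, b = 2), Z = Φ(2) − Φ(−2) ≈ 0.9545, so a rejection sampler accepts roughly 95% of draws.

```python
import math
from typing import Callable


def std_normal_cdf(z: float) -> float:
    """Phi(z): CDF of the standard normal, expressed via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def normalising_constant(mean: float, std: float, a: float, b: float) -> float:
    """Z = Phi((b - mean) / std) - Phi((a - mean) / std)."""
    return std_normal_cdf((b - mean) / std) - std_normal_cdf((a - mean) / std)


def truncated_normal_pdf(x: float, mean: float, std: float, a: float, b: float) -> float:
    """Density p(x | a <= x <= b); zero outside [a, b]."""
    if x < a or x > b:
        return 0.0
    gauss = math.exp(-((x - mean) ** 2) / (2.0 * std**2)) / (math.sqrt(2.0 * math.pi) * std)
    return gauss / normalising_constant(mean, std, a, b)


def midpoint_integral(f: Callable[[float], float], lo: float, hi: float, steps: int = 10_000) -> float:
    """Approximate the integral of f over [lo, hi] with the midpoint rule."""
    h = (hi - lo) / steps
    return h * sum(f(lo + (i + 0.5) * h) for i in range(steps))


if __name__ == "__main__":
    z = normalising_constant(0.0, 1.0, -2.0, 2.0)
    total = midpoint_integral(lambda x: truncated_normal_pdf(x, 0.0, 1.0, -2.0, 2.0), -2.0, 2.0)
    print(f"Z = {z:.4f}, integral of p over [a, b] = {total:.6f}")
```

Dividing by Z is exactly what makes the truncated density integrate to one over [a, b]; Z depends only on the standardised bounds, so the Examples call (±2 std with std = 0.02) has the same Z as the defaults.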

Examples

>>> import lucid
>>> from lucid.nn.init import trunc_normal_
>>> w = lucid.empty(64, 32)
>>> trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)