class

RMSNorm

extends Module
RMSNorm(normalized_shape: int | list[int] | tuple[int, ...], eps: float = 1e-08, device: DeviceLike = None, dtype: DTypeLike = None)

Root Mean Square Layer Normalization.

Normalises the input by its root-mean-square value and applies a learnable per-element scale, but intentionally omits the mean subtraction step of standard layer norm:

$$y = \frac{x}{\mathrm{RMS}(x)} \cdot \gamma, \qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \varepsilon}$$

where $d$ is the number of elements spanned by normalized_shape (the product of its entries).
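The formula can be checked by hand with a short pure-Python sketch. This is an illustrative reference for a single vector, not the library's implementation; the function name `rms_norm` and the use of plain lists are assumptions made for the example.

```python
import math

def rms_norm(x, gamma, eps=1e-8):
    """Reference RMSNorm for one vector given as a list of floats (illustrative only)."""
    d = len(x)
    # eps is added inside the square root, matching the formula
    rms = math.sqrt(sum(v * v for v in x) / d + eps)
    # divide by the RMS, then apply the per-element scale gamma
    return [v / rms * g for v, g in zip(x, gamma)]

x = [3.0, 4.0]
gamma = [1.0, 1.0]          # scale of all ones, as at initialisation
y = rms_norm(x, gamma)      # RMS(x) = sqrt((9 + 16) / 2) ~= 3.5355
print(y)
```

For this input the result is approximately `[0.8485, 1.1314]`, and the mean of the squared outputs is approximately 1.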

Skipping mean subtraction reduces computation while retaining the re-scaling invariance property that makes normalization useful. This formulation is widely used in large language model architectures such as LLaMA.
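A small sketch makes the invariance concrete: multiplying the input by a constant leaves the normalised output essentially unchanged, while adding a constant does change it, because no mean is subtracted. The helper `rms_norm` below is a hypothetical pure-Python stand-in with the scale fixed to one; it is not part of the library.

```python
import math

def rms_norm(x, eps=1e-8):
    """Illustrative RMS normalisation of one vector with gamma = 1."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

x = [1.0, -2.0, 3.0, 0.5]
base = rms_norm(x)
scaled = rms_norm([10.0 * v for v in x])    # re-scaled input
shifted = rms_norm([v + 5.0 for v in x])    # shifted input

# Largest element-wise difference from the baseline output
scale_gap = max(abs(a - b) for a, b in zip(base, scaled))
shift_gap = max(abs(a - b) for a, b in zip(base, shifted))
print(scale_gap, shift_gap)
```

Here `scale_gap` is negligible (it is non-zero only because of `eps`), whereas `shift_gap` is large.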

Parameters

normalized_shape: int or list[int] or tuple[int, ...]
Shape of the trailing dimensions to normalize over. An integer d is equivalent to the single-element tuple (d,).
eps: float = 1e-08
Small constant added inside the square root for numerical stability. Default: 1e-8.
device: DeviceLike = None
Device on which to allocate weight. Default: None.
dtype: DTypeLike = None
Data type of weight. Default: None.
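The equivalence between an integer and a single-element tuple for normalized_shape can be sketched as follows. The helpers `as_shape_tuple` and `num_elements` are hypothetical names introduced for illustration; how the library canonicalises the argument internally is not documented here.

```python
import math

def as_shape_tuple(normalized_shape):
    """Return normalized_shape as a tuple; an int d becomes (d,). Illustrative only."""
    if isinstance(normalized_shape, int):
        return (normalized_shape,)
    return tuple(normalized_shape)

def num_elements(normalized_shape):
    """Number of elements d normalised together: the product of the shape entries."""
    return math.prod(as_shape_tuple(normalized_shape))

print(as_shape_tuple(256))        # (256,)
print(as_shape_tuple([16, 16]))   # (16, 16)
print(num_elements((16, 16)))     # 256
```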

Attributes

weight: Parameter
Learnable per-element scale $\gamma$ of shape normalized_shape, initialised to ones. Unlike LayerNorm, RMSNorm has no bias parameter by design.

Notes

  • Input: $(*, \text{normalized\_shape})$, where $*$ denotes any number of leading dimensions.
  • Output: same shape as the input.
  • RMSNorm has no bias term; if a shift is needed, add a separate bias or use LayerNorm.
  • The default eps (1e-8) is intentionally smaller than that of LayerNorm (1e-5), since RMS values can be very small for zero-mean inputs.
  • The weight is initialised to all ones so the transformation starts as a pure normalisation.
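The effect of eps can be sketched in pure Python. The helper `rms_norm` is a hypothetical stand-in with the scale fixed to one, and the input values are chosen only for illustration. It shows that an all-zero input stays finite, and that for small inputs the smaller eps keeps the output close to unit RMS while a larger eps shrinks it.

```python
import math

def rms_norm(x, eps):
    """Illustrative RMS normalisation of one vector with gamma = 1."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

# All-zero input: RMS = sqrt(eps) > 0, so the output is 0 rather than NaN
zeros = rms_norm([0.0, 0.0, 0.0], eps=1e-8)

# Small-magnitude input: mean square is 1e-6
tiny = [1e-3, -1e-3]
small_eps = rms_norm(tiny, eps=1e-8)   # first element ~= 1 / sqrt(1.01) ~= 0.995
large_eps = rms_norm(tiny, eps=1e-5)   # first element ~= 1 / sqrt(11)   ~= 0.302
print(zeros, small_eps, large_eps)
```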

Examples

Typical use in a transformer feed-forward block:
>>> import lucid
>>> import lucid.nn as nn
>>> norm = nn.RMSNorm(256)
>>> x = lucid.randn(4, 32, 256)   # (batch, seq_len, dim)
>>> out = norm(x)
>>> out.shape
(4, 32, 256)
Normalize over two trailing dimensions:
>>> norm2d = nn.RMSNorm((16, 16))
>>> x2d = lucid.randn(2, 8, 16, 16)
>>> out2d = norm2d(x2d)
>>> out2d.shape
(2, 8, 16, 16)

Methods (3)

dunder

__init__

None
__init__(normalized_shape: int | list[int] | tuple[int, ...], eps: float = 1e-08, device: DeviceLike = None, dtype: DTypeLike = None)

Initialise the RMSNorm module. See the class docstring for parameter semantics.

fn

forward

Tensor
forward(x: Tensor)

Apply normalisation to the input tensor.

Parameters

x: Tensor
Input tensor whose shape is documented in the class docstring.

Returns

Tensor

Normalised tensor of the same shape as x.
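The shape contract can be illustrated with a pure-Python sketch in which nested lists stand in for a 2-D tensor. Each row (one index of the leading dimension) is normalised independently, and the output has the same shape as the input. The function `rms_norm_rows` is a hypothetical helper, not the library's forward implementation.

```python
import math

def rms_norm_rows(batch, gamma, eps=1e-8):
    """Normalise each row of a list-of-lists independently (illustrative only)."""
    out = []
    for row in batch:
        # The statistic is computed per row, over the trailing dimension only
        rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v / rms * g for v, g in zip(row, gamma)])
    return out

batch = [[1.0, 2.0, 3.0], [0.0, -4.0, 4.0]]        # shape (2, 3)
out = rms_norm_rows(batch, gamma=[1.0, 1.0, 1.0])  # shape (2, 3)

# With gamma = 1, every output row has a mean square of approximately 1
mean_squares = [sum(v * v for v in row) / len(row) for row in out]
print(mean_squares)
```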

fn

extra_repr

str
extra_repr()

Return a string representation of the layer's configuration.