class RMSNorm extends Module
RMSNorm(normalized_shape: int | list[int] | tuple[int, ...], eps: float = 1e-08, device: DeviceLike = None, dtype: DTypeLike = None)
Root Mean Square Layer Normalization.
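Normalises the input by its root-mean-square value and applies a learnable per-element scale, but intentionally omits the mean subtraction step of standard layer norm:
$$\mathrm{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \epsilon}, \qquad y_i = \frac{x_i}{\mathrm{RMS}(x)}\,\gamma_i$$
where $n$ is the number of elements in normalized_shape, $\epsilon$ is eps, and $\gamma$ is weight.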
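The core computation (mean of squares, eps inside the square root, divide by the result, then scale per element) can be checked numerically on a single vector. The sketch below is a plain-Python reference, assuming a 1-D input whose length equals n and a weight list of the same length; the function name `rms_norm_1d` is hypothetical and not part of lucid.

```python
import math


def rms_norm_1d(x: list[float], weight: list[float], eps: float = 1e-8) -> list[float]:
    """Hypothetical reference RMSNorm over one vector (not lucid's API)."""
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    n = len(x)
    # Mean of squares; unlike LayerNorm, the mean is never subtracted.
    mean_sq = sum(v * v for v in x) / n
    # eps is added inside the square root for numerical stability.
    rms = math.sqrt(mean_sq + eps)
    # Divide by the RMS, then apply the per-element scale.
    return [v / rms * g for v, g in zip(x, weight)]


out = rms_norm_1d([3.0, 4.0], [1.0, 1.0])
print([round(v, 6) for v in out])  # → [0.848528, 1.131371]
```

Here the mean of squares is (9 + 16) / 2 = 12.5, so each element is divided by roughly sqrt(12.5) ≈ 3.5355.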
Skipping mean subtraction reduces computation while retaining the re-scaling invariance property that makes normalization useful. This formulation is widely used in large language model architectures such as LLaMA.
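The re-scaling invariance mentioned above can be observed directly: multiplying the input by a positive constant leaves the output essentially unchanged (up to the tiny influence of eps), whereas adding a constant changes it, because the mean is never removed. The helper below is a hypothetical plain-Python RMSNorm with the weight fixed to ones, a sketch for illustration rather than lucid's implementation.

```python
import math


def rms_normalize(x: list[float], eps: float = 1e-8) -> list[float]:
    """Hypothetical RMSNorm with weight fixed to ones (illustration only)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Largest element-wise absolute difference between two lists."""
    return max(abs(p - q) for p, q in zip(a, b))


x = [1.0, -2.0, 0.5, 3.0]
scaled = [10.0 * v for v in x]   # positive rescaling of the input
shifted = [v + 5.0 for v in x]   # constant shift of the input

base = rms_normalize(x)
# Rescaling: outputs agree up to the effect of eps.
print(max_abs_diff(base, rms_normalize(scaled)))
# Shifting: outputs differ, since RMSNorm does not subtract the mean.
print(max_abs_diff(base, rms_normalize(shifted)))
```

LayerNorm would be invariant to the shift as well; RMSNorm trades that property away for a cheaper computation.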
Parameters
normalized_shape : int or list[int] or tuple[int, ...]
    Shape of the trailing dimensions to normalize over. An integer d is equivalent to the single-element tuple (d,).
eps : float = 1e-08
    Small constant added inside the square root for numerical stability. Default: 1e-8.
device : DeviceLike = None
    Device on which to allocate weight. Default: None.
dtype : DTypeLike = None
    Data type of weight. Default: None.
Attributes
weight : Parameter
    Learnable per-element scale of shape normalized_shape, initialised to ones. Unlike LayerNorm, RMSNorm has no bias parameter by design.
Notes
- Input: (*, *normalized_shape), i.e. any number of leading dimensions followed by trailing dimensions equal to normalized_shape.
- Output: same shape as the input.
- RMSNorm has no bias term; if a shift is needed, add a separate bias or use LayerNorm.
- The default eps (1e-8) is intentionally smaller than that of LayerNorm (1e-5), since RMS values can be very small for zero-mean inputs.
- The weight is initialised to all ones, so the transformation starts as a pure normalisation.
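Several of the notes above can be checked with a small plain-Python sketch: because eps sits inside the square root, an all-zero input maps to zeros instead of causing a division by zero, and with the weight at its initial value of ones the output has an RMS of almost exactly 1. The optional `bias` argument shows one hypothetical way to add the separate shift the notes suggest; it is not a lucid parameter, and `rms_norm` here is an illustrative helper, not the library's function.

```python
import math
from typing import Optional


def rms_norm(
    x: list[float],
    weight: list[float],
    eps: float = 1e-8,
    bias: Optional[list[float]] = None,
) -> list[float]:
    """Hypothetical plain-Python RMSNorm; `bias` is an illustrative extra, not lucid's API."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [v / rms * g for v, g in zip(x, weight)]
    if bias is not None:
        # A separate shift applied after normalisation.
        out = [o + b for o, b in zip(out, bias)]
    return out


ones = [1.0] * 4

# All-zero input: the denominator is sqrt(eps) > 0, so the output is all zeros.
zero_out = rms_norm([0.0] * 4, ones)

# Weight of ones: a pure normalisation whose output RMS is close to 1.
y = rms_norm([2.0, -1.0, 0.5, 4.0], ones)
y_rms = math.sqrt(sum(v * v for v in y) / len(y))

# A separate bias shifts every normalised element by the same amount.
shifted = rms_norm([2.0, -1.0, 0.5, 4.0], ones, bias=[0.1] * 4)
```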
Examples
Typical use on transformer activations of shape (batch, seq_len, dim):
>>> import lucid
>>> import lucid.nn as nn
>>> norm = nn.RMSNorm(256)
>>> x = lucid.randn(4, 32, 256) # (batch, seq_len, dim)
>>> out = norm(x)
>>> out.shape
(4, 32, 256)
Normalize over two trailing dimensions:
>>> norm2d = nn.RMSNorm((16, 16))
>>> x2d = lucid.randn(2, 8, 16, 16)
>>> out2d = norm2d(x2d)
>>> out2d.shape
(2, 8, 16, 16)
Methods (3)
dunder __init__
__init__(normalized_shape: int | list[int] | tuple[int, ...], eps: float = 1e-08, device: DeviceLike = None, dtype: DTypeLike = None) → None
Initialise the RMSNorm module. See the class docstring for parameter semantics.
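A minimal sketch of the setup that the parameter descriptions imply, assuming nothing about lucid's internals: an integer normalized_shape becomes a one-element tuple, a list is converted to a tuple, and the weight holds one scale per element, all starting at one. The class `RMSNormSketch` and its flat-list weight are hypothetical stand-ins for the real module and its Parameter; device and dtype handling is omitted.

```python
import math


class RMSNormSketch:
    """Hypothetical stand-in for RMSNorm's constructor (not lucid's implementation)."""

    def __init__(self, normalized_shape, eps: float = 1e-8) -> None:
        # An int d is treated as the single-element tuple (d,).
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        # n: number of elements normalised together as one group.
        self.n = math.prod(self.normalized_shape)
        # Weight starts at ones (stored flat here); deliberately no bias.
        self.weight = [1.0] * self.n


a = RMSNormSketch(256)
b = RMSNormSketch([16, 16])
print(a.normalized_shape, a.n)  # → (256,) 256
print(b.normalized_shape, b.n)  # → (16, 16) 256
```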
fn forward
forward(x: Tensor) → Tensor
Apply normalisation to the input tensor.
Parameters
x : Tensor
    Input tensor whose shape is documented in the class docstring.
Returns
Tensor
    Normalised tensor of the same shape as x.
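To show how the forward pass treats leading versus trailing dimensions, the sketch below works on a flat, row-major buffer plus its shape (a hypothetical representation chosen so the example needs only the standard library). Every contiguous group of n = prod(normalized_shape) elements is normalised as one unit, so with normalized_shape=(16, 16) all 256 values share one RMS, while each index in the leading dimensions is handled independently. The function name `rms_norm_forward` is an assumption, not lucid's API.

```python
import math


def rms_norm_forward(
    data: list[float],
    shape: tuple[int, ...],
    normalized_shape: tuple[int, ...],
    weight: list[float],
    eps: float = 1e-8,
) -> list[float]:
    """Hypothetical forward pass over a flat row-major buffer (not lucid's API)."""
    k = len(normalized_shape)
    # The trailing dimensions of the input must match normalized_shape.
    if tuple(shape[len(shape) - k:]) != tuple(normalized_shape):
        raise ValueError(f"trailing dims of {shape} do not match {normalized_shape}")
    n = math.prod(normalized_shape)
    if len(data) != math.prod(shape) or len(weight) != n:
        raise ValueError("buffer or weight size mismatch")
    out: list[float] = []
    # Each contiguous block of n elements belongs to one leading index.
    for start in range(0, len(data), n):
        group = data[start:start + n]
        rms = math.sqrt(sum(v * v for v in group) / n + eps)
        out.extend(v / rms * g for v, g in zip(group, weight))
    return out  # same element count, so the logical shape is unchanged


# Shape (2, 2, 2) normalised over the last two dims: two groups of 4 values.
x = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]
y = rms_norm_forward(x, (2, 2, 2), (2, 2), [1.0] * 4)
```

The second group is ten times the first, yet both normalise to the same values, because each group gets its own RMS. In an array library the same grouping is usually expressed as a reduction over the last k axes with kept dimensions, followed by broadcasting; the explicit loop here only makes the grouping visible.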
fn extra_repr
extra_repr() → str
Return a string representation of the layer's configuration.