class GradScaler

GradScaler(init_scale: float = 2.0 ** 16, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True)

Dynamic loss-scaling helper for mixed-precision training.

Mixed-precision training runs much of the forward and backward pass in fp16 to roughly halve memory traffic and to use fast fp16 hardware paths, but fp16's narrow dynamic range causes small gradient values to underflow to zero, so the affected parameters stop receiving updates. GradScaler works around this by multiplying the loss by a large factor $s$ before backpropagation:

$$\tilde{L} = s \cdot L, \qquad \frac{\partial \tilde{L}}{\partial \theta} = s \cdot \frac{\partial L}{\partial \theta}.$$

The scaled gradients now sit inside fp16's representable range. Before the optimizer step they are multiplied by $1/s$ in fp32, so the update is mathematically equivalent to ordinary, unscaled training.
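The underflow problem, and how scaling avoids it, can be reproduced without a GPU. The sketch below emulates fp16 rounding with Python's standard struct module (format code "e" is IEEE 754 half precision); the helper name to_fp16 and the gradient value 1e-8 are illustrative assumptions, not part of GradScaler.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (illustrative helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8        # below half of fp16's smallest subnormal (about 6e-8)
scale = 2.0 ** 16  # the default init_scale

unscaled_fp16 = to_fp16(grad)        # the gradient underflows and is lost
scaled_fp16 = to_fp16(grad * scale)  # about 6.55e-4, comfortably representable in fp16
recovered = scaled_fp16 / scale      # unscale in higher precision (Python floats are fp64)

print(unscaled_fp16)  # 0.0
print(recovered)      # approximately 1e-8, up to fp16 rounding of the scaled value
```

The recovered value differs from the original only by the rounding error of a 10-bit fp16 mantissa, whereas the unscaled cast loses the gradient entirely.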

The scale itself is adapted dynamically. After every step the unscaled gradients are checked for inf / NaN:

  • Overflow detected: the optimizer step is skipped and $s$ is multiplied by backoff_factor (default 0.5).
  • No overflow for growth_interval consecutive steps: $s$ is multiplied by growth_factor (default 2.0).

This produces a sawtooth schedule that tracks the largest scale the current gradient distribution can tolerate.
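A minimal pure-Python model of this schedule follows. It is a sketch, not GradScaler's actual implementation: the class name ScaleSchedule is hypothetical, and resetting the overflow-free counter after every growth or backoff is an assumption (it is what "consecutive" implies).

```python
class ScaleSchedule:
    """Hypothetical stand-alone model of the growth/backoff rule."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive overflow-free steps since the last change

    def update(self, found_inf: bool) -> float:
        if found_inf:
            # Overflow: shrink the scale and restart the overflow-free count.
            self.scale *= self.backoff_factor
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.growth_interval:
                # growth_interval clean steps in a row: try a larger scale.
                self.scale *= self.growth_factor
                self.good_steps = 0
        return self.scale


# A short run with growth_interval=3 shows the sawtooth: grow, back off, grow again.
sched = ScaleSchedule(init_scale=8.0, growth_interval=3)
overflows = [False, False, False, True, False, False, False]
history = [sched.update(found) for found in overflows]
print(history)  # [8.0, 8.0, 16.0, 8.0, 8.0, 8.0, 16.0]
```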

Parameters

init_scale: float = 2**16
Initial loss scaling factor applied by scale().
growth_factor: float = 2.0
Multiplier applied to the scale after growth_interval consecutive non-overflowing steps. Must be > 1.0.
backoff_factor: float = 0.5
Multiplier applied when an inf / NaN gradient is detected. Must be in (0, 1).
growth_interval: int = 2000
Number of consecutive overflow-free steps required before the scale grows.
enabled: bool = True
When False, the scaler degenerates into a transparent pass-through: scale() returns its input unchanged, step() calls the optimizer directly, and update() is a no-op.
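The documented constraints can be checked before training starts. The helper below is a hypothetical sketch (validate_scaler_args is not part of the documented API); the checks on init_scale and growth_interval are added assumptions beyond the ranges stated above.

```python
def validate_scaler_args(init_scale: float = 2.0 ** 16, growth_factor: float = 2.0,
                         backoff_factor: float = 0.5, growth_interval: int = 2000) -> None:
    """Raise ValueError for an invalid GradScaler-style configuration (hypothetical helper)."""
    # Constraints stated in the parameter list.
    if not growth_factor > 1.0:
        raise ValueError(f"growth_factor must be > 1.0, got {growth_factor}")
    if not 0.0 < backoff_factor < 1.0:
        raise ValueError(f"backoff_factor must be in (0, 1), got {backoff_factor}")
    # Assumed constraints, not stated in the parameter list.
    if not init_scale > 0.0:
        raise ValueError(f"init_scale must be positive, got {init_scale}")
    if growth_interval < 1:
        raise ValueError(f"growth_interval must be >= 1, got {growth_interval}")


validate_scaler_args()  # the documented defaults pass

message = None
try:
    validate_scaler_args(backoff_factor=1.5)
except ValueError as err:
    message = str(err)
print(message)  # backoff_factor must be in (0, 1), got 1.5
```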

Notes

The canonical training loop calls scale(), step(), and update() in that order on every iteration:

  1. scale() multiplies the loss by $s$ before backward(), so the gradients land safely inside fp16 range.
  2. step() unscales the gradients, checks for inf / NaN, and either runs optimizer.step() or skips the update.
  3. update() adjusts $s$ according to the growth / backoff schedule for the next iteration.
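The three steps can be traced on plain Python floats. This sketch rests on simplifying assumptions: a single scalar parameter, plain SGD on the toy loss (p - 3)^2, an overflow injected by hand, and the hypothetical helper names scaled_sgd_step and update_scale. It mirrors the control flow described above, not GradScaler's source.

```python
import math


def scaled_sgd_step(param: float, scaled_grad: float, scale: float, lr: float):
    """step(): unscale, check for inf/NaN, then apply SGD or skip (hypothetical helper)."""
    grad = scaled_grad / scale
    if not math.isfinite(grad):
        return param, True  # skip the update and report the overflow
    return param - lr * grad, False


def update_scale(scale: float, found_inf: bool, backoff_factor: float = 0.5) -> float:
    """update(): back off after an overflow (growth omitted to keep the sketch short)."""
    return scale * backoff_factor if found_inf else scale


param, scale, lr = 0.0, 1024.0, 0.1
skipped_iterations = []
for it in range(3):
    true_grad = 2.0 * (param - 3.0)  # dL/dp for the toy loss L = (p - 3)^2
    scaled_grad = scale * true_grad  # scale() + backward(): the gradient arrives scaled
    if it == 1:
        scaled_grad = math.inf       # pretend this iteration overflowed in fp16
    param, found_inf = scaled_sgd_step(param, scaled_grad, scale, lr)
    if found_inf:
        skipped_iterations.append(it)
    scale = update_scale(scale, found_inf)

print(param, scale, skipped_iterations)
```

Iteration 1 leaves the parameter untouched and halves the scale; iterations 0 and 2 apply ordinary SGD updates because the unscaled gradient equals the true one.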

Examples

>>> scaler = GradScaler()
>>> for x, y in dataloader:
...     optimizer.zero_grad()  # clear gradients left over from the previous iteration
...     with autocast():
...         out = model(x)
...         loss = loss_fn(out, y)
...     scaler.scale(loss).backward()
...     scaler.step(optimizer)
...     scaler.update()

Methods (8)

__init__

__init__(init_scale: float = 2.0 ** 16, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True) -> None

Initialize the scaler state.

Parameters

init_scale: float = 2**16
Initial loss scaling factor applied by scale().
growth_factor: float = 2.0
Multiplier applied to the scale after growth_interval consecutive non-overflowing steps.
backoff_factor: float = 0.5
Multiplier applied when an inf / NaN gradient is detected.
growth_interval: int = 2000
Number of consecutive overflow-free steps required before the scale grows.
enabled: bool = True
When False, the scaler is a transparent pass-through.

scale

scale(outputs: Tensor | list[Tensor]) -> Tensor | list[Tensor]

Multiply outputs by the current scale factor.

Args:
  outputs: A Tensor or list of Tensors to scale.

Returns:
  Scaled Tensor(s), with the same structure as the input.

unscale_

unscale_(optimizer: Optimizer) -> None

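Divide the gradients of the optimizer's parameters by the current scale, in place.

Call this before gradient clipping, so that the clipping threshold is compared against the true (unscaled) gradients rather than gradients inflated by the scale factor.

Args:
  optimizer: The optimizer whose parameters' grads will be unscaled.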
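Why the order matters: a clipping threshold such as max_norm=1.0 is expressed in units of the true gradient. The sketch below uses plain Python floats and a hypothetical clip_by_global_norm helper (not part of this API) to compare clipping after unscaling with clipping the still-scaled gradients.

```python
import math


def clip_by_global_norm(grads: list[float], max_norm: float) -> list[float]:
    """Rescale grads so their L2 norm is at most max_norm (hypothetical helper)."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    return [g * (max_norm / norm) for g in grads]


scale = 1024.0
true_grads = [0.3, 0.4]                   # L2 norm 0.5, already below max_norm
scaled = [g * scale for g in true_grads]  # what the parameters hold after backward()

# Unscale first, then clip: the threshold is compared with the true norm (0.5).
right = clip_by_global_norm([g / scale for g in scaled], max_norm=1.0)
# Clip first, then unscale: the threshold is compared with the scaled norm (512).
wrong = [g / scale for g in clip_by_global_norm(scaled, max_norm=1.0)]

print(right)  # [0.3, 0.4]
print(wrong)  # roughly 512x smaller than the true gradients
```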

step

step(optimizer: Optimizer, args: object = (), kwargs: object = {}) -> Tensor | None

Unscale the gradients and call optimizer.step() if no inf / NaN values are detected; otherwise skip the optimizer step.

Args:
  optimizer: The optimizer to step.

Returns:
  The return value of optimizer.step(), or None if the step was skipped.
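The inf / NaN check matters because a scale that is too large pushes scaled gradients past fp16's largest finite value, 65504, where they become inf. The sketch below emulates a round-to-nearest fp16 cast with the standard struct module, mapping struct's OverflowError to inf, and repeats the skip-and-back-off cycle until the scaled gradients fit. The helper cast_fp16 and the gradient values are illustrative assumptions.

```python
import math
import struct


def cast_fp16(x: float) -> float:
    """Emulate a round-to-nearest cast to fp16; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


true_grads = [2.0, -0.5, 1e-3]  # illustrative unscaled gradient values
scale = 2.0 ** 16               # the default init_scale
skipped = 0
while True:
    scaled = [cast_fp16(g * scale) for g in true_grads]
    found_inf = not all(math.isfinite(g) for g in scaled)
    if not found_inf:
        break                   # step() would call optimizer.step() here
    skipped += 1                # step() skips the update...
    scale *= 0.5                # ...and update() applies backoff_factor

print(skipped, scale)  # 2 16384.0
```

At $2^{16}$ and $2^{15}$ the gradient 2.0 scales to 131072 and 65536, both beyond fp16's range, so two steps are skipped before the scale settles at $2^{14}$.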

update

update(new_scale: float | None = None) -> None

Update the scale factor for the next iteration.

If new_scale is given, the scale is set to that value directly. Otherwise the scale is multiplied by growth_factor if no overflow has occurred for growth_interval consecutive steps, or by backoff_factor if an overflow was found.

Args:
  new_scale: Explicit new scale value (optional).

get_scale

get_scale() -> float

Return the current scale factor.

state_dict

state_dict() -> dict[str, float]

Return the scaler state as a serializable dict.

load_state_dict

load_state_dict(state_dict: dict[str, float]) -> None

Load scaler state from a dict, such as one returned by state_dict().
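Because the state is a flat dict[str, float], it can be stored alongside a model checkpoint in any format that handles floats, so a resumed run continues from the adapted scale rather than restarting at init_scale. The sketch below round-trips such a dict through JSON; the key names are hypothetical, since the actual keys are not documented here.

```python
import json

# Hypothetical contents; the real keys returned by state_dict() may differ.
scaler_state = {"scale": 32768.0, "growth_tracker": 137.0}

# Save: store the scaler state next to the rest of the checkpoint.
blob = json.dumps({"step": 12000, "scaler": scaler_state})

# Resume: read it back; this dict is what would be passed to load_state_dict().
restored = json.loads(blob)["scaler"]
print(restored == scaler_state)  # True
```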