GradScaler
GradScaler(init_scale: float = 2.0 ** 16, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True)
Dynamic loss-scaling helper for mixed-precision training.
Mixed-precision training keeps activations and weights in fp16 to
halve memory bandwidth and exploit fast fp16 hardware paths, but
fp16's narrow dynamic range causes small gradients to underflow to
zero, and the network stops learning. GradScaler works
around this by multiplying the loss by a large constant $S$
before backpropagation:
$$\tilde{L} = S \cdot L, \qquad \nabla \tilde{L} = S \cdot \nabla L$$
The scaled gradients sit comfortably inside fp16's representable range; before the optimizer step they are unscaled by multiplying by $1/S$ in fp32, so the update is mathematically equivalent to ordinary training.
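To see the underflow problem concretely, the sketch below rounds a Python float to IEEE 754 half precision using the standard library's `struct` module (format code `e`) rather than any tensor library. The gradient value 1e-8 is an arbitrary illustrative choice, and the final division runs in Python's fp64 floats, standing in for the fp32 unscale described above.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8          # illustrative tiny gradient, well below fp16's smallest subnormal (~6e-8)
scale = 2.0 ** 16    # same value as the documented init_scale default

plain_fp16 = to_fp16(grad)            # underflows to 0.0
scaled_fp16 = to_fp16(grad * scale)   # about 6.55e-4, comfortably inside fp16 range
recovered = scaled_fp16 / scale       # unscale in higher precision (fp64 here)

print(plain_fp16)   # 0.0
print(recovered)    # close to 1e-8 (fp16 rounding error only)
```

The unscaled gradient is lost entirely, while the scaled one survives the round trip with only fp16's roughly three decimal digits of rounding error.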
The scale itself is adapted dynamically. After every step the
unscaled gradients are checked for inf / NaN:
- Overflow detected: the optimizer step is skipped and the scale $S$ is
  multiplied by backoff_factor (typically 0.5).
- No overflow for growth_interval consecutive steps: $S$ is multiplied
  by growth_factor (typically 2.0).
This produces a sawtooth schedule that tracks the largest scale the current gradient distribution can tolerate.
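The growth / backoff rule can be modelled in a few lines of plain Python. The class name `ScaleSchedule`, the `good_steps` counter, and the `found_inf` flag are hypothetical names chosen for illustration; this is a sketch of the bookkeeping described above, not how any particular library stores its state.

```python
from dataclasses import dataclass


@dataclass
class ScaleSchedule:
    """Hypothetical model of the dynamic loss-scale update rule."""
    scale: float = 2.0 ** 16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    good_steps: int = 0  # consecutive steps without inf / NaN

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # Overflow: shrink the scale and restart the streak.
            self.scale *= self.backoff_factor
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.growth_interval:
                # A full interval without overflow: try a larger scale.
                self.scale *= self.growth_factor
                self.good_steps = 0


# Tiny growth_interval so the sawtooth is visible in a few steps.
sched = ScaleSchedule(scale=8.0, growth_interval=3)
history = []
for found_inf in [False, False, False, True, False, False, False]:
    sched.update(found_inf)
    history.append(sched.scale)
print(history)  # [8.0, 8.0, 16.0, 8.0, 8.0, 8.0, 16.0]
```

Each overflow halves the scale immediately, while growth has to be earned by a full clean interval; that asymmetry is what produces the sawtooth.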
Parameters
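init_scale : float = 2**16
    Initial value of the scale.
growth_factor : float = 2.0
    Factor by which the scale is multiplied after growth_interval
    consecutive non-overflowing steps. Must be > 1.0.
backoff_factor : float = 0.5
    Factor by which the scale is multiplied when an inf / NaN gradient is
    detected. Must be in (0, 1).
growth_interval : int = 2000
    Number of consecutive non-overflowing steps required before the scale is grown.
enabled : bool = True
    If False the scaler degenerates into a transparent
    pass-through: scale returns its input unchanged,
    step calls the optimizer directly, and update
    is a no-op.
Notes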
The canonical training-loop pattern is scale-loss, then step, then update:
- scale multiplies the loss by the current scale $S$ before backward(), so the gradients land safely inside fp16 range.
- step unscales the gradients, checks for inf / NaN, and either runs optimizer.step() or skips the update.
- update adjusts $S$ according to the growth / backoff schedule for the next iteration.
Examples
>>> scaler = GradScaler()
>>> for x, y in dataloader:
...     optimizer.zero_grad()
...     with autocast():
...         out = model(x)
...         loss = loss_fn(out, y)
...     scaler.scale(loss).backward()
...     scaler.step(optimizer)
...     scaler.update()
Methods (8)
__init__
__init__(init_scale: float = 2.0 ** 16, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True) → None
Initialize the scaler state.
Parameters
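init_scale : float = 2**16
    Initial value of the scale.
growth_factor : float = 2.0
    Factor by which the scale is multiplied after growth_interval
    consecutive non-overflowing steps.
backoff_factor : float = 0.5
    Factor by which the scale is multiplied when an inf / NaN gradient is detected.
growth_interval : int = 2000
    Number of consecutive non-overflowing steps required before the scale is grown.
enabled : bool = True
    If False the scaler is a transparent pass-through.
scale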
scale(outputs: Tensor | list[Tensor]) → Tensor | list[Tensor]
Multiply outputs by the current scale factor.
Args: outputs: A Tensor or list of Tensors to scale.
Returns: Scaled Tensor(s), with the same structure as the input.
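The "same structure as the input" contract can be illustrated with a hypothetical helper, `scale_outputs`, that uses plain floats as stand-ins for tensors: a single value is scaled directly, and a list is scaled element-wise into a new list of the same length. This sketches the documented behaviour only; it is not the library's code.

```python
from typing import List, Union

Value = Union[float, List[float]]


def scale_outputs(outputs: Value, scale: float) -> Value:
    """Hypothetical stand-in for scale(), using floats instead of tensors."""
    if isinstance(outputs, list):
        # Preserve the list structure: one scaled entry per input entry.
        return [o * scale for o in outputs]
    return outputs * scale


print(scale_outputs(0.25, 4.0))         # 1.0
print(scale_outputs([0.25, 0.5], 4.0))  # [1.0, 2.0]
```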
unscale_
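unscale_(optimizer: Optimizer) → None
Divide the optimizer's gradients by the current scale, in place.
Call this before gradient clipping, so that the clipping threshold is compared against the true (unscaled) gradients rather than gradients inflated by the scale.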
Args: optimizer: The optimizer whose parameters' grads will be unscaled.
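Why the order matters: clipping by global norm compares the norm against a fixed threshold, and a norm computed on scaled gradients is inflated by the scale factor. The plain-Python sketch below uses hypothetical helpers (`global_norm`, `clip_by_global_norm`) and made-up gradient values to show that clipping while still scaled shrinks the update drastically, while unscaling first leaves small gradients untouched.

```python
import math
from typing import List


def global_norm(grads: List[float]) -> float:
    """L2 norm over all gradient entries."""
    return math.sqrt(sum(g * g for g in grads))


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale grads so their global L2 norm is at most max_norm."""
    norm = global_norm(grads)
    if norm <= max_norm:
        return list(grads)
    factor = max_norm / norm
    return [g * factor for g in grads]


scale = 2.0 ** 16
true_grads = [3e-3, 4e-3]                      # norm 5e-3, well under max_norm
scaled_grads = [g * scale for g in true_grads]

# Wrong order: clip while still scaled, then unscale.
wrong = [g / scale for g in clip_by_global_norm(scaled_grads, max_norm=1.0)]

# Right order (what calling unscale_ first enables): unscale, then clip.
right = clip_by_global_norm([g / scale for g in scaled_grads], max_norm=1.0)

print(global_norm(wrong))  # about 1.5e-5: update shrunk by roughly the scale factor
print(global_norm(right))  # about 5e-3: gradients left unchanged
```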
step
step(optimizer: Optimizer, args: object = (), kwargs: object = {}) → Tensor | None
Unscale gradients and call optimizer.step() if no inf / NaN is detected.
If inf / NaN is detected in the gradients, the optimizer step is skipped.
Args: optimizer: The optimizer to step.
Returns: The return value of optimizer.step(), or None if the step was skipped.
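A minimal sketch of the unscale, check, then step-or-skip logic. `ToySGD` and `scaled_step` are hypothetical names, and plain float lists stand in for parameters and gradients; only the return contract (the optimizer's return value, or None when skipped) follows the documentation above.

```python
import math
from typing import List, Optional


class ToySGD:
    """Hypothetical optimizer over float parameters, for illustration only."""

    def __init__(self, params: List[float], lr: float = 0.1) -> None:
        self.params = list(params)
        self.grads = [0.0] * len(self.params)
        self.lr = lr

    def step(self) -> str:
        self.params = [p - self.lr * g for p, g in zip(self.params, self.grads)]
        return "stepped"  # arbitrary marker, to show the value is passed through


def scaled_step(optimizer: ToySGD, scale: float) -> Optional[str]:
    """Unscale grads, then step unless any grad is inf / NaN."""
    optimizer.grads = [g / scale for g in optimizer.grads]
    if not all(math.isfinite(g) for g in optimizer.grads):
        return None  # overflow: skip the update, parameters unchanged
    return optimizer.step()


opt = ToySGD([1.0, 2.0])
opt.grads = [2.0 ** 16, 2.0 ** 17]  # scaled versions of true grads [1.0, 2.0]
first = scaled_step(opt, scale=2.0 ** 16)
print(first, opt.params)            # step taken: params move to about [0.9, 1.8]

opt.grads = [float("inf"), 1.0]     # overflow in one gradient
second = scaled_step(opt, scale=2.0 ** 16)
print(second, opt.params)           # None, params unchanged
```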
update
update(new_scale: float | None = None) → None
Update the scale factor.
If new_scale is provided, the scale is set to it directly. Otherwise the scale is multiplied by growth_factor once no overflow has been found for growth_interval consecutive steps, or multiplied by backoff_factor if an overflow was found.
Args: new_scale: Explicit new scale value (optional).
get_scale
get_scale() → float
Return the current scale factor.
state_dict
state_dict() → dict[str, float]
Return a serializable state dict.
load_state_dict
load_state_dict(state_dict: dict[str, float]) → None
Load scaler state from a dict, such as one returned by state_dict().
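Because state_dict is documented to return a flat dict[str, float], it can be checkpointed with the standard json module. The sketch below uses a hypothetical stand-in class, `ScalerState`, with made-up keys (`scale`, `growth_tracker`); the real key names are not documented here and may differ.

```python
import json
from typing import Dict


class ScalerState:
    """Hypothetical stand-in exposing the documented state_dict interface."""

    def __init__(self, scale: float = 2.0 ** 16) -> None:
        self.scale = scale
        self.growth_tracker = 0.0  # hypothetical: clean steps since the last overflow

    def state_dict(self) -> Dict[str, float]:
        return {"scale": self.scale, "growth_tracker": self.growth_tracker}

    def load_state_dict(self, state_dict: Dict[str, float]) -> None:
        self.scale = float(state_dict["scale"])
        self.growth_tracker = float(state_dict["growth_tracker"])


saved = ScalerState(scale=1024.0)
saved.growth_tracker = 17.0
blob = json.dumps(saved.state_dict())  # e.g. stored alongside a model checkpoint

restored = ScalerState()               # would otherwise start from init_scale
restored.load_state_dict(json.loads(blob))
print(restored.scale, restored.growth_tracker)  # 1024.0 17.0
```

Restoring the saved scale when resuming avoids restarting from init_scale, which would typically cost a few skipped steps while the scale backs off again.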