AdamW
Optimizer
AdamW(params: Iterable[Parameter] | Iterable[dict[str, object]], lr: float = 0.001, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.01, amsgrad: bool = False)
Adam optimizer with decoupled weight decay regularisation.
AdamW fixes the weight-decay coupling present in standard Adam by
applying the decay directly to the parameters rather than adding it to
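the gradient. With learning rate $\eta$ (lr), moment coefficients $(\beta_1, \beta_2)$ (betas), stability constant $\epsilon$ (eps), weight decay coefficient $\lambda$ (weight_decay), and $m_0 = v_0 = 0$, the update rule at step $t$ is:

$$
\begin{aligned}
g_t &= \nabla_\theta \mathcal{L}(\theta_{t-1}) \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \eta \lambda\, \theta_{t-1}
\end{aligned}
$$

The final term, $-\eta\lambda\theta_{t-1}$, is the decoupled weight decay; it is applied after the adaptive gradient step rather than mixed into the gradient $g_t$.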
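The AdamW update described above can be traced in plain Python for a single scalar weight. This is a minimal sketch, not lucid's engine-level kernel: the helper name `adamw_scalar_step` and its `state` dictionary are hypothetical, and the `amsgrad` branch assumes the usual AMSGrad rule (keep the running maximum of the second-moment estimate, then bias-correct it) rather than anything confirmed for lucid.

```python
import math


def adamw_scalar_step(theta, grad, state, lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, weight_decay=1e-2, amsgrad=False):
    """Return the weight after one AdamW step (hypothetical helper).

    `state` stores the step count "t", the moments "m" and "v", and,
    when amsgrad=True, the running maximum "v_max". It is updated in place.
    """
    beta1, beta2 = betas
    state["t"] = state.get("t", 0) + 1
    t = state["t"]

    # Exponential moving averages of the gradient and the squared gradient.
    state["m"] = beta1 * state.get("m", 0.0) + (1 - beta1) * grad
    state["v"] = beta2 * state.get("v", 0.0) + (1 - beta2) * grad * grad

    v = state["v"]
    if amsgrad:
        # Assumed AMSGrad rule: normalise by the largest second moment so far.
        state["v_max"] = max(state.get("v_max", 0.0), v)
        v = state["v_max"]

    # Bias correction for the zero-initialised moments.
    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)

    adaptive = m_hat / (math.sqrt(v_hat) + eps)
    # Decoupled decay: proportional to the previous weight, never added to grad.
    decay = weight_decay * theta
    return theta - lr * adaptive - lr * decay
```

On a fresh state with `theta=1.0`, `grad=0.5`, `lr=0.1`, and `weight_decay=0.01`, the bias-corrected first step is almost exactly `lr` in the direction opposite the gradient, so the result is about `1.0 - 0.1 - 0.1 * 0.01 = 0.899`.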
Parameters
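params : iterable of Parameter or iterable of dict
    Parameters to optimise, or dicts defining parameter groups.
lr : float, default 0.001
    Learning rate.
betas : tuple of float, default (0.9, 0.999)
    Coefficients of the moving averages of the gradient and of its square.
eps : float, default 1e-08
    Constant added to the denominator for numerical stability.
weight_decay : float, default 0.01
    Decoupled weight decay coefficient.
amsgrad : bool, default False
    If True, use the AMSGrad variant, which normalises by the running maximum of the second-moment estimate.
Attributes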
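param_groups : list of dict
    Parameter groups. Each dict holds the keys "params", "lr", "beta1", "beta2", "eps", "weight_decay", and "amsgrad".
defaults : dict
    Default hyperparameter values passed to the constructor.
Notes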
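Decoupling the decay from the adaptive step means every weight is shrunk by the same fraction, lr * weight_decay, on each step, regardless of its gradient history. With L2 regularisation in Adam, by contrast, the decay term passes through the adaptive denominator, so weights with large gradients are regularised less than weights with small ones. Decoupling also makes the best weight_decay value less sensitive to the learning rate, which simplifies hyperparameter tuning.
AdamW is a widely used default optimizer for transformer-based
models and often generalises better than Adam with L2 regularisation.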
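The contrast between decoupled decay and L2 regularisation can be made concrete with a back-of-the-envelope helper. It is a hypothetical sketch, not a lucid function: `decay_per_step` assumes the bias-corrected second moment has settled at `grad_rms ** 2` and that the decay term is too small to change it, so it only approximates the decay part of a single step.

```python
def decay_per_step(theta, grad_rms, lr, weight_decay, decoupled, eps=1e-8):
    """Approximate how far weight decay alone moves `theta` in one step.

    Hypothetical helper. Assumes the bias-corrected second moment has
    settled at grad_rms**2, so sqrt(v_hat) is simply grad_rms, and that
    the decay term is too small to change v_hat noticeably.
    """
    if decoupled:
        # AdamW: the decay bypasses the adaptive denominator entirely.
        return lr * weight_decay * theta
    # Adam + L2: weight_decay * theta is added to the gradient, so it is
    # divided by sqrt(v_hat) together with the rest of the gradient.
    return lr * weight_decay * theta / (grad_rms + eps)
```

With `lr=1e-3` and `weight_decay=1e-2`, the decoupled form shrinks a unit weight by `1e-5` per step whatever its gradient scale, whereas the L2 form gives about `1e-6` for a weight whose gradient RMS is 10 and about `1e-3` when the RMS is 0.01.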
Examples
>>> import lucid.optim as optim
>>> optimizer = optim.AdamW(
... model.parameters(), lr=1e-4, weight_decay=1e-2
... )
>>> optimizer.zero_grad()
>>> loss = model(inputs)
>>> loss.backward()
>>> optimizer.step()
Methods (2)
__init__
__init__(params: Iterable[Parameter] | Iterable[dict[str, object]], lr: float = 0.001, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.01, amsgrad: bool = False) → None
Initialise the AdamW optimizer. See the class docstring for parameter semantics.
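One plausible way the constructor arguments could map onto the `param_groups` and `defaults` attributes is sketched here. Everything about it is an assumption inferred only from the documented key names: the helper `build_param_groups` is hypothetical, as are the rule that a group inherits any key it omits from `defaults` and the splitting of a per-group `betas` tuple into `"beta1"` and `"beta2"`.

```python
def build_param_groups(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                       weight_decay=1e-2, amsgrad=False):
    """Hypothetical helper: normalise constructor input into param_groups."""
    defaults = {
        "lr": lr,
        "beta1": betas[0],
        "beta2": betas[1],
        "eps": eps,
        "weight_decay": weight_decay,
        "amsgrad": amsgrad,
    }
    params = list(params)
    # A plain iterable of parameters becomes a single group.
    if not params or not isinstance(params[0], dict):
        params = [{"params": params}]

    groups = []
    for spec in params:
        group = dict(spec)
        group["params"] = list(group["params"])
        # Assumed: a per-group `betas` override is stored as beta1/beta2.
        if "betas" in group:
            group["beta1"], group["beta2"] = group.pop("betas")
        # Fill in any hyperparameter the group did not override.
        for key, value in defaults.items():
            group.setdefault(key, value)
        groups.append(group)
    return groups, defaults
```

Under these assumptions, a common pattern such as exempting biases from weight decay amounts to passing a second group dict with `"weight_decay": 0.0`.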
step
step(closure: _OptimizerClosure = None) → Tensor or None
Perform a single AdamW optimisation step.
Calls the engine-level AdamW update for each parameter group, which applies the adaptive gradient update followed by decoupled weight decay directly on the parameters.
Parameters
closure : callable, default None
    Optional callable that re-evaluates the model and returns the loss.
Returns
Tensor or None
    The loss returned by closure, or None if no closure was provided.
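The documented return contract of `step` can be illustrated with a hypothetical stand-in class. It assumes, as in PyTorch-style optimizers, that the closure runs before the parameters are updated; `ClosureDemoOptimizer` and its `update` callback are illustrative names only, and the loss is a plain float instead of a lucid `Tensor`.

```python
from typing import Callable, Optional


class ClosureDemoOptimizer:
    """Hypothetical stand-in showing the step(closure) return contract."""

    def __init__(self, update: Callable[[], None]) -> None:
        # `update` stands in for the real per-group AdamW parameter update.
        self._update = update

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            # Assumed ordering: re-evaluate the loss first, then update.
            loss = closure()
        self._update()
        # Return the closure's loss, or None when no closure was given.
        return loss
```

With lucid, such a closure would typically call `optimizer.zero_grad()`, compute the loss, call `loss.backward()`, and return the loss.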
Examples
>>> optimizer.zero_grad()
>>> loss = model(inputs)
>>> loss.backward()
>>> optimizer.step()