class NAdam(Optimizer)
NAdam(params: Iterable[Parameter] | Iterable[dict[str, object]], lr: float = 0.002, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0)

Nesterov-accelerated Adaptive Moment Estimation optimizer.

NAdam incorporates Nesterov momentum into Adam by replacing the bias-corrected first-moment estimate $\hat{m}_t$ in the numerator of the Adam update with a one-step lookahead estimate. The update rule at step $t \ge 1$, with gradient $g_t$, is:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
\hat{m}_t^{\text{Nesterov}} &= \frac{\beta_1 m_t}{1 - \beta_1^{t+1}} + \frac{(1 - \beta_1) g_t}{1 - \beta_1^t} \\
\hat{v}_t &= \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t^{\text{Nesterov}}
\end{aligned}
$$
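To make the update rule concrete, here is a minimal pure-Python sketch of one NAdam step on a single scalar parameter, transcribed from the formulas above. It assumes $\hat{v}_t = v_t / (1 - \beta_2^t)$ (the standard bias-corrected second moment) and omits weight decay. The function name `nadam_step` and the toy quadratic objective are illustrative only; this is not Lucid's tensor implementation.

```python
import math


def nadam_step(theta, grad, m, v, t, lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One NAdam update for a scalar parameter (sketch, no weight decay).

    `t` is the 1-based step index. Returns the new (theta, m, v).
    """
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Nesterov lookahead: momentum bias-corrected as if at step t + 1,
    # plus the current gradient bias-corrected at step t.
    m_hat = beta1 * m / (1 - beta1 ** (t + 1)) + (1 - beta1) * grad / (1 - beta1 ** t)
    # Assumed bias correction for the second moment.
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# Toy problem: minimise f(theta) = (theta - 3)^2, whose gradient is 2 * (theta - 3).
theta, m, v = 0.0, 0.0, 0.0
for t in range(1, 2001):
    grad = 2.0 * (theta - 3.0)
    theta, m, v = nadam_step(theta, grad, m, v, t, lr=0.05)
print(round(theta, 3))
```

On the very first step ($m_0 = v_0 = 0$) the lookahead term equals $\left(\tfrac{\beta_1(1-\beta_1)}{1-\beta_1^2} + 1\right) g_1$, so with the default $\beta_1 = 0.9$ the first move is about $1.47\,\alpha$ rather than Adam's $\alpha$.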

Parameters

params : iterable of Parameter or iterable of dict
Parameters to optimise, or a list of parameter-group dicts.
lr : float = 0.002
Learning rate $\alpha$ (default: 2e-3).
betas : tuple of float = (0.9, 0.999)
Coefficients $(\beta_1, \beta_2)$ for the first- and second-moment estimates (default: (0.9, 0.999)).
eps : float = 1e-08
Term $\epsilon$ for numerical stability (default: 1e-8).
weight_decay : float = 0
L2 regularisation coefficient (default: 0).

Attributes

param_groups : list of dict
Parameter groups with keys "params", "lr", "beta1", "beta2", "eps", and "weight_decay".
defaults : dict
Default hyperparameter values.
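A common design for optimizers with this signature (PyTorch's optimizers work this way) is to normalise `params` into a list of group dicts and fill any hyperparameter a group does not set from `defaults`. The hypothetical `build_param_groups` below sketches that merge with plain Python objects, using the key names listed above; it is an illustration of the pattern, not Lucid's code.

```python
from typing import Any, Iterable


def build_param_groups(params: Iterable[Any], defaults: dict) -> list:
    """Normalise `params` into a list of group dicts with every default filled in."""
    params = list(params)
    # A bare iterable of parameters becomes a single group.
    if not params or not isinstance(params[0], dict):
        params = [{"params": params}]
    groups = []
    for group in params:
        merged = dict(defaults)  # start from the defaults
        merged.update(group)  # per-group values override them
        merged["params"] = list(group["params"])
        groups.append(merged)
    return groups


defaults = {"lr": 2e-3, "beta1": 0.9, "beta2": 0.999, "eps": 1e-8, "weight_decay": 0.0}
# Strings stand in for Parameter objects.
groups = build_param_groups(
    [
        {"params": ["encoder.w", "encoder.b"]},
        {"params": ["head.w"], "lr": 1e-4, "weight_decay": 0.01},
    ],
    defaults,
)
for group in groups:
    print(group["params"], group["lr"], group["weight_decay"])
```

Per-group overrides like this are typically used to give a pretrained backbone a smaller learning rate than a freshly initialised head.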

Notes

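NAdam often converges faster than vanilla Adam: its lookahead term applies the next step's momentum decay to $m_t$ and gives the current gradient $g_t$ more weight than Adam's $\hat{m}_t$ does, so updates react sooner when the gradient changes direction. It is often effective on recurrent networks and on tasks with noisy gradients.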

Examples

>>> import lucid.optim as optim
>>> # assumes `model` and a scalar `loss` computed from it are already defined
>>> optimizer = optim.NAdam(model.parameters(), lr=2e-3)
>>> optimizer.zero_grad()
>>> loss.backward()
>>> optimizer.step()

Methods


__init__

__init__(params: Iterable[Parameter] | Iterable[dict[str, object]], lr: float = 0.002, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0) -> None

Initialise the NAdam optimizer. See the class docstring for parameter semantics.


step

step(closure: _OptimizerClosure = None) -> Tensor | None

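Perform a single NAdam update of all parameters. If `closure` is given, it should re-evaluate the model and return the loss, which `step` returns; otherwise the return value is None.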
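The `closure` argument and the `Tensor | None` return type suggest the usual PyTorch-style contract: the closure clears gradients, recomputes the loss and its gradients, and `step` returns that loss. The toy `Param` and `ScalarNAdam` classes below are hypothetical stand-ins (plain floats instead of Lucid tensors) that sketch this contract together with the NAdam update; they are not Lucid's implementation.

```python
import math
from typing import Callable, Optional


class Param:
    """Hypothetical stand-in for a Parameter: a scalar value and its gradient."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class ScalarNAdam:
    """Toy NAdam over scalar Params, illustrating the step(closure) contract."""

    def __init__(self, params, lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.params = list(params)
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.t = 0
        self.state = {id(p): [0.0, 0.0] for p in self.params}  # [m, v] per parameter

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = 0.0

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        # If given, the closure recomputes the loss and gradients before the update.
        loss = closure() if closure is not None else None
        self.t += 1
        b1, b2, t = self.beta1, self.beta2, self.t
        for p in self.params:
            m, v = self.state[id(p)]
            g = p.grad
            m = b1 * m + (1 - b1) * g
            v = b2 * v + (1 - b2) * g * g
            m_hat = b1 * m / (1 - b1 ** (t + 1)) + (1 - b1) * g / (1 - b1 ** t)
            v_hat = v / (1 - b2 ** t)
            p.value -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)
            self.state[id(p)] = [m, v]
        return loss


w = Param(0.0)
opt = ScalarNAdam([w], lr=0.05)


def closure() -> float:
    # Re-evaluate f(w) = (w - 3)^2 and fill in its gradient.
    opt.zero_grad()
    w.grad = 2.0 * (w.value - 3.0)
    return (w.value - 3.0) ** 2


losses = [opt.step(closure) for _ in range(500)]
print(losses[0], losses[-1], w.value)
```

Returning the loss from `step` is mainly useful for optimizers that evaluate the objective several times per step (such as line-search methods); for NAdam the closure form and the plain `zero_grad()` / `backward()` / `step()` loop are interchangeable.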