optim.RAdam¶
- class lucid.optim.RAdam(params: Iterable[Parameter], lr: float = 0.001, betas: tuple[int | float | complex, int | float | complex] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.0)¶
The RAdam class implements the Rectified Adam (RAdam) optimization algorithm. RAdam addresses the slow convergence problem in Adam during early training steps by rectifying the variance of the adaptive learning rate. By doing so, RAdam combines the fast convergence of Adam with the stability and generalization properties of SGD.
Class Signature¶
class RAdam(optim.Optimizer):
    def __init__(
        self,
        params: Iterable[nn.Parameter],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
    ) -> None
Parameters¶
Learning Rate (`lr`): Controls the step size during parameter updates. A higher learning rate can speed up convergence but may lead to instability.
Betas (`betas`): A tuple of two coefficients controlling the exponential decay of the moment estimates. The first value (beta1) controls the decay rate for the first moment (the mean of the gradients), while the second value (beta2) controls the decay rate for the second moment (the uncentered variance of the gradients).
Epsilon (`eps`): A small constant added to the denominator to prevent division by zero. This improves numerical stability during training.
Weight Decay (`weight_decay`): This coefficient controls the L2 penalty applied to the model parameters. Unlike traditional L2 regularization, weight decay in RAdam is decoupled from the gradient calculation and applied directly to the parameter values.
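As a rough sketch of the decoupled decay described above, the following shows a single update in plain Python; the names `param` and `grad_step` and the numeric values are purely illustrative and do not mirror lucid's internals:
lr, weight_decay = 1e-3, 0.01
param = [1.0, 2.0, 3.0]          # current parameter values (illustrative)
grad_step = [0.1, -0.2, 0.05]    # placeholder for the rectified adaptive step
# The gradient-based step and the weight decay are applied separately:
param = [p - lr * g - lr * weight_decay * p for p, g in zip(param, grad_step)]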
Algorithm¶
The RAdam optimization algorithm updates each parameter according to the following formulas:
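A minimal statement of the update, following the original RAdam formulation (Liu et al., 2019); lucid's implementation may differ in minor details such as the exact rectification threshold:
\[
m_t = \beta_1 m_{t-1} + (1 - \beta_1)\nabla_t, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\nabla_t^2
\]
\[
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad
\hat{v}_t = \frac{v_t}{1 - \beta_2^t}, \qquad
\rho_\infty = \frac{2}{1 - \beta_2} - 1, \qquad
\rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1 - \beta_2^t}
\]
When \(\rho_t > 4\), the variance of the adaptive term is tractable and the rectified update is applied:
\[
r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}, \qquad
\theta_t = \theta_{t-1} - \text{lr} \cdot r_t \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\]
Otherwise, the step falls back to the unadapted momentum update:
\[
\theta_t = \theta_{t-1} - \text{lr} \cdot \hat{m}_t
\]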
Where:
\(\nabla_t\) is the gradient of the loss with respect to the parameter at iteration \(t\).
\(m_t\) and \(v_t\) are the exponentially decaying first and second moment estimates.
\(\hat{m}_t\) and \(\hat{v}_t\) are bias-corrected estimates of the first and second moments.
\(\rho_t\) is the length of the approximated simple moving average, used to decide whether the variance of the adaptive learning rate is tractable.
\(r_t\) is the variance rectification term applied to the adaptive step.
\(\theta_t\) is the parameter being updated at iteration \(t\).
\(\text{lr}\) is the learning rate.
The rectification term \(r_t\) shrinks the adaptive step during the early iterations, when too few gradients have been observed for \(\hat{v}_t\) to be a reliable variance estimate, and approaches 1 as training progresses. This stabilizes the otherwise noisy adaptive learning rate and removes the need for a separate warmup schedule.
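As a framework-agnostic sketch of how the rectification term behaves over training (plain Python for illustration; the function and variable names are assumptions, not lucid's internals):
import math

def rectification_term(t: int, beta2: float = 0.999) -> float | None:
    """Return r_t if the variance is tractable (rho_t > 4), else None."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    rho_t = rho_inf - 2.0 * t * beta2**t / (1.0 - beta2**t)
    if rho_t <= 4.0:
        return None  # fall back to the unadapted momentum step
    return math.sqrt(
        ((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )

# Early steps return None (no adaptive step); later steps approach 1.
print([rectification_term(t) for t in (1, 5, 100, 10_000)])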
Examples¶
Using the RAdam Optimizer
import lucid.optim as optim
import lucid.nn as nn


# Define a simple model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter([1.0, 2.0, 3.0])

    def forward(self, x):
        return x * self.param


# Initialize model and RAdam optimizer
model = MyModel()
optimizer = optim.RAdam(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
)

# Training loop
for input, target in data:
    optimizer.zero_grad()
    output = model(input)
    loss = compute_loss(output, target)
    loss.backward()
    optimizer.step()
Inspecting Optimizer State
Use the state_dict() and load_state_dict() methods to save and load the optimizer state.
# Save state
optimizer_state = optimizer.state_dict()
# Load state
optimizer.load_state_dict(optimizer_state)
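If the state needs to be persisted to disk, a plain-Python approach such as pickle works as long as the state dict contains only picklable objects; the file name below is arbitrary and this is not a lucid-specific API:
import pickle

# Persist the optimizer state to disk
with open("radam_state.pkl", "wb") as f:
    pickle.dump(optimizer.state_dict(), f)

# Restore it later
with open("radam_state.pkl", "rb") as f:
    optimizer.load_state_dict(pickle.load(f))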