VAEConfig
- class lucid.models.VAEConfig(encoders: list[lucid.nn.module.Module], decoders: list[lucid.nn.module.Module], priors: list[lucid.nn.module.Module] | None = None, reconstruction_loss: Literal['mse', 'bce'] = 'mse', kl_weight: float = 1.0, beta_schedule: Callable[[int], float] | None = None, hierarchical_kl: bool = True, depth: int | None = None)
VAEConfig stores the encoder stack, decoder stack, optional hierarchical
priors, and KL scheduling options used by lucid.models.VAE.
Class Signature
@dataclass
class VAEConfig:
    encoders: list[nn.Module]
    decoders: list[nn.Module]
    priors: list[nn.Module] | None = None
    reconstruction_loss: Literal["mse", "bce"] = "mse"
    kl_weight: float = 1.0
    beta_schedule: Callable[[int], float] | None = None
    hierarchical_kl: bool = True
    depth: int | None = None
Parameters
encoders (list[nn.Module]): Encoder modules that each emit concatenated mean and log-variance tensors.
decoders (list[nn.Module]): Decoder modules applied from the deepest latent back to the reconstruction.
priors (list[nn.Module] | None): Optional hierarchical prior modules. The current implementation expects exactly depth - 1 modules when provided.
reconstruction_loss (Literal["mse", "bce"]): Reconstruction loss mode.
kl_weight (float): Base KL multiplier.
beta_schedule (Callable[[int], float] | None): Optional function that overrides KL weight per training step.
hierarchical_kl (bool): Whether KL weight is distributed over latent levels.
depth (int | None): Number of latent levels. Defaults to the encoder count.
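The beta_schedule parameter accepts any callable that maps an integer training step to a float. A minimal sketch of one such callable, a linear warm-up, follows. The name linear_warmup and its warmup_steps and max_beta arguments are illustrative assumptions and are not part of lucid.

```python
from typing import Callable


def linear_warmup(warmup_steps: int, max_beta: float = 1.0) -> Callable[[int], float]:
    """Return a schedule that ramps the KL weight linearly from 0 to max_beta.

    Hypothetical helper for illustration; not provided by lucid.
    """
    if warmup_steps <= 0:
        raise ValueError("warmup_steps must be positive")

    def schedule(step: int) -> float:
        # Clamp progress to [0, 1] so the weight stays at max_beta after warm-up.
        progress = min(max(step / warmup_steps, 0.0), 1.0)
        return max_beta * progress

    return schedule


# Matches the documented type Callable[[int], float].
beta_schedule = linear_warmup(warmup_steps=1000, max_beta=1.0)
```

A callable like this could be passed as beta_schedule=beta_schedule when building the config. Which step counter lucid.models.VAE passes to it is not stated on this page.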
Validation
encoders and decoders must be non-empty and contain only nn.Module.
depth must be positive and must match both encoder and decoder counts.
priors, when provided, must contain only nn.Module and must have length depth - 1.
reconstruction_loss must be "mse" or "bce".
kl_weight must be non-negative.
beta_schedule must be callable or None.
hierarchical_kl must be a boolean.
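The rules above can be restated as a standalone check. This is only a sketch of the documented constraints, not lucid's actual validation code. The Module class is a stand-in for lucid.nn.Module, the function name check_vae_config is hypothetical, and raising ValueError is an assumption because the exception types are not documented here.

```python
from typing import Any, Callable, Optional, Sequence


class Module:
    """Hypothetical stand-in for lucid.nn.Module so the sketch runs without lucid."""


def check_vae_config(
    encoders: Sequence[Any],
    decoders: Sequence[Any],
    priors: Optional[Sequence[Any]] = None,
    reconstruction_loss: str = "mse",
    kl_weight: float = 1.0,
    beta_schedule: Optional[Callable[[int], float]] = None,
    hierarchical_kl: bool = True,
    depth: Optional[int] = None,
) -> int:
    """Apply the documented rules and return the resolved depth."""
    for name, stack in (("encoders", encoders), ("decoders", decoders)):
        if len(stack) == 0 or not all(isinstance(m, Module) for m in stack):
            raise ValueError(f"{name} must be a non-empty list of modules")

    # depth defaults to the encoder count when omitted.
    resolved = len(encoders) if depth is None else depth
    if resolved <= 0 or resolved != len(encoders) or resolved != len(decoders):
        raise ValueError("depth must be positive and match encoder and decoder counts")

    if priors is not None:
        if not all(isinstance(m, Module) for m in priors):
            raise ValueError("priors must contain only modules")
        if len(priors) != resolved - 1:
            raise ValueError("priors must have length depth - 1")

    if reconstruction_loss not in ("mse", "bce"):
        raise ValueError("reconstruction_loss must be 'mse' or 'bce'")
    if kl_weight < 0:
        raise ValueError("kl_weight must be non-negative")
    if beta_schedule is not None and not callable(beta_schedule):
        raise ValueError("beta_schedule must be callable or None")
    if not isinstance(hierarchical_kl, bool):
        raise ValueError("hierarchical_kl must be a boolean")
    return resolved
```

For example, two encoders and two decoders resolve to depth 2, so a priors list would need exactly one module.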
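Usage
import lucid.models as models
import lucid.nn as nn

# Single-level VAE: one encoder and one decoder, so depth defaults to 1.
# The encoder output (128) holds the concatenated mean and log-variance.
config = models.VAEConfig(
    encoders=[nn.Sequential(nn.Linear(784, 128))],
    decoders=[nn.Sequential(nn.Linear(64, 784))],
    reconstruction_loss="bce",
)
model = models.VAE(config)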
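In the usage above the encoder emits 128 features while the decoder consumes 64, which fits an encoder output that concatenates a 64-dimensional mean with a 64-dimensional log-variance. The plain-Python sketch below illustrates that split and the standard reparameterization step. It assumes the mean comes first and the halves are equal in size, neither of which is confirmed on this page. The helper names split_mean_logvar and reparameterize are hypothetical and are not lucid functions.

```python
import math
import random


def split_mean_logvar(encoded: list[float]) -> tuple[list[float], list[float]]:
    """Split a concatenated encoder output into (mean, log-variance) halves.

    Assumption: mean occupies the first half, log-variance the second.
    """
    if len(encoded) % 2 != 0:
        raise ValueError("encoder output length must be even")
    half = len(encoded) // 2
    return encoded[:half], encoded[half:]


def reparameterize(
    mean: list[float], logvar: list[float], rng: random.Random
) -> list[float]:
    """Sample z = mean + exp(0.5 * logvar) * eps with eps drawn from N(0, 1)."""
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mean, logvar)
    ]


# A dummy 128-feature encoder output, matching nn.Linear(784, 128) above.
encoded = [0.0] * 128
mean, logvar = split_mean_logvar(encoded)
z = reparameterize(mean, logvar, random.Random(0))  # 64 values for the decoder
```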