VAE

Variational Autoencoder

class lucid.models.VAE(config: VAEConfig)

The VAE class implements a generalized Variational Autoencoder supporting standard and hierarchical variants. It is compatible with various encoder-decoder architectures and prior configurations.

This model is modular, allowing users to define multiple layers of latent variables and optionally plug in learnable prior modules for deep hierarchies. It is designed to support \(\beta\)-scheduling and flexible loss configurations.

        %%{init: {"flowchart":{"curve":"monotoneX","nodeSpacing":50,"rankSpacing":50}} }%%
flowchart LR
  linkStyle default stroke-width:2.0px
  subgraph sg_m0["<span style='font-size:20px;font-weight:700'>VAE</span>"]
  style sg_m0 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
    subgraph sg_m1["encoders"]
      direction TB;
    style sg_m1 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m2["Sequential"]
        direction TB;
      style sg_m2 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m3["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,784) → (1,512)</span>"];
        m4["ReLU"];
        m5["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,512) → (1,128)</span>"];
      end
    end
    subgraph sg_m6["decoders"]
      direction TB;
    style sg_m6 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
      subgraph sg_m7["Sequential"]
        direction TB;
      style sg_m7 fill:#000000,fill-opacity:0.05,stroke:#000000,stroke-opacity:0.75,stroke-width:1px
        m8["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,64) → (1,512)</span>"];
        m9["ReLU"];
        m10["Linear<br/><span style='font-size:11px;color:#2b6cb0;font-weight:400'>(1,512) → (1,784)</span>"];
      end
    end
  end
  input["Input<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,784)</span>"];
  output["Output<br/><span style='font-size:11px;color:#a67c00;font-weight:400'>(1,784)x4</span>"];
  style input fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style output fill:#fff3cd,stroke:#a67c00,stroke-width:1px;
  style m3 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m4 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
  style m5 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m8 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  style m9 fill:#faf5ff,stroke:#6b46c1,stroke-width:1px;
  style m10 fill:#ebf8ff,stroke:#2b6cb0,stroke-width:1px;
  input --> m3;
  m10 --> output;
  m3 --> m4;
  m4 --> m5;
  m5 --> m8;
  m8 --> m9;
  m9 --> m10;
    

Class Signature

class VAE(config: VAEConfig)

Parameters

  • config (VAEConfig): Configuration object that stores the encoder stack, decoder stack, optional hierarchical priors, reconstruction loss mode, and KL weighting strategy.

Configuration

  • encoders (list[nn.Module]): Encoder modules. Each encoder must output (N, 2 * D) so it can be split into mean and log-variance tensors.

  • decoders (list[nn.Module]): Decoder modules applied in reverse order from the deepest latent sample.

  • priors (list[nn.Module] | None): Optional hierarchical prior modules. When provided, the current implementation expects exactly depth - 1 modules.

  • reconstruction_loss (Literal["mse", "bce"]): Reconstruction loss mode.

  • kl_weight (float): Base \(\beta\) multiplier for KL divergence.

  • beta_schedule (Callable[[int], float] | None): Optional schedule that overrides \(\beta\) per training step.

  • hierarchical_kl (bool): If True, distributes KL weighting across latent levels.

  • depth (int | None): Latent depth. If omitted, it defaults to the encoder count.
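As an illustration of the beta_schedule hook, a linear KL warm-up can be written as a plain callable. The helper below is a sketch, not part of lucid; any function mapping a step index to a float works.

```python
def linear_beta_warmup(step: int, warmup_steps: int = 10_000, max_beta: float = 1.0) -> float:
    """Linearly anneal beta from 0 up to max_beta over warmup_steps, then hold."""
    return min(max_beta, max_beta * step / warmup_steps)

# beta ramps from 0.0 at step 0 to max_beta at warmup_steps and stays there.
print(linear_beta_warmup(5_000))  # 0.5
```

Such a callable would be passed as beta_schedule in VAEConfig, overriding the static kl_weight per training step.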

Returns

Use the forward method to obtain outputs:

recon, mus, logvars, zs = model(x)
  • recon (Tensor): Reconstructed tensor in original input shape.

  • mus (list[Tensor]): Means for each latent layer.

  • logvars (list[Tensor]): Log-variances for each latent layer.

  • zs (list[Tensor]): Sampled latent vectors after reparameterization.

Loss Components

Use the method get_loss() to compute total VAE loss:

loss, recon_loss, kl = model.get_loss(x, recon, mus, logvars, zs)

Where:

  • recon_loss: the MSE or BCE reconstruction loss, as selected by reconstruction_loss.

  • kl: total KL divergence (may be hierarchical).

  • loss: full objective: \(\mathcal{L} = \text{recon} + \beta \cdot \text{KL}\)
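For reference, the KL term for a diagonal Gaussian posterior against a standard-normal prior has a closed form, sketched here in NumPy (gaussian_kl is an illustrative helper, not a lucid API):

```python
import numpy as np

def gaussian_kl(mu: np.ndarray, logvar: np.ndarray) -> float:
    # KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dims
    # and averaged over the batch dimension.
    kl = -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1)
    return float(np.mean(kl))

mu = np.zeros((4, 8))
logvar = np.zeros((4, 8))
print(gaussian_kl(mu, logvar))  # 0.0: the posterior already matches the prior
```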

Methods

VAE.reparameterize(mu: Tensor, logvar: Tensor) -> Tensor
VAE.encode(x: Tensor) -> tuple[list[Tensor], ...]
VAE.decode(zs: list[Tensor]) -> Tensor
VAE.current_beta() -> float

Module Output Requirements

Each encoder module in the encoders list must return a Tensor of shape (N, 2 * D), which is split into:

\[\mu, \log \sigma^2 = \text{split}(h, \text{axis}=1)\]

Similarly, prior modules (if provided) should accept the latent tensor from the deeper layer and output the same (N, 2 * D) format to define hierarchical priors.

Decoder modules in the decoders list must accept a latent Tensor z and sequentially transform it back to the original input shape.
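Concretely, the (N, 2 * D) split convention can be mimicked with NumPy (shapes only; this is not lucid code):

```python
import numpy as np

N, D = 32, 64
# Mock encoder output of shape (N, 2 * D).
h = np.random.default_rng(0).standard_normal((N, 2 * D))
# Split along axis 1 into the mean and log-variance halves, each (N, D).
mu, logvar = np.split(h, 2, axis=1)
print(mu.shape, logvar.shape)  # (32, 64) (32, 64)
```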

Examples

import lucid
import lucid.nn as nn
from lucid.models import VAE, VAEConfig

# Encoder outputs (N, 128) = (N, 2 * 64), split into mean and log-variance.
encoder = nn.Sequential(
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 128),
)

# Decoder maps a 64-dimensional latent sample back to the input shape.
decoder = nn.Sequential(
    nn.Linear(64, 512),
    nn.ReLU(),
    nn.Linear(512, 784),
)

vae = VAE(
    VAEConfig(
        encoders=[encoder],
        decoders=[decoder],
    )
)

x = lucid.randn(32, 784)
recon, mus, logvars, zs = vae(x)
loss, recon_loss, kl = vae.get_loss(x, recon, mus, logvars, zs)

Note

This VAE class encompasses the standard VAE, \(\beta\)-VAE, and Hierarchical VAE (HVAE).

Tip

The reparameterization is implemented as:

\[z = \mu + \epsilon \cdot \exp(\log \sigma^2 / 2), \quad \epsilon \sim \mathcal{N}(0, I)\]
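In NumPy terms, the trick looks like the following (a sketch of the idea, not the lucid implementation):

```python
import numpy as np

def reparameterize(mu, logvar, rng=None):
    # z = mu + eps * exp(logvar / 2), eps ~ N(0, I).
    # Sampling is pushed into eps, so gradients flow through mu and logvar.
    rng = rng or np.random.default_rng(0)
    eps = rng.standard_normal(mu.shape)
    return mu + eps * np.exp(0.5 * logvar)

mu = np.full((2, 3), 5.0)
z = reparameterize(mu, np.full((2, 3), -100.0))
print(np.allclose(z, mu))  # True: near-zero variance collapses z onto mu
```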

Warning

Make sure the final decoder output is in the same shape and scale as the input, especially when using BCE loss, which expects values in \([0, 1]\).