
register_kl

register_kl(p_cls: type, q_cls: type) -> Callable[[F], F]

Register a closed-form KL implementation for a distribution pair.

Decorator that adds an analytical $\mathrm{KL}(p \,\|\, q)$ formula to the global dispatch registry, keyed on the class pair (p_cls, q_cls). Once registered, kl_divergence routes every call whose arguments are instances of (p_cls, q_cls), or of any subclasses of them (resolved by walking the MRO), to the decorated function.

Parameters

p_cls : type
Class of the "left" distribution $p$, the distribution under which the expectation in $\mathrm{KL}(p \,\|\, q) = \mathbb{E}_{x \sim p}[\log p(x) - \log q(x)]$ is taken.
q_cls : type
Class of the "right" distribution $q$.

Returns

Callable[[F], F]

A decorator that takes a two-argument callable fn(p, q) -> Tensor and registers it in the internal KL registry, returning fn unchanged so the decorated function remains directly callable.

Notes

Dispatch precedence inside kl_divergence:

  1. Exact class match (type(p), type(q)).
  2. Most-derived ancestor pair found by walking the Cartesian product of type(p).__mro__ and type(q).__mro__ in order.
  3. Monte Carlo fallback (a single rsample draw) when $p$ supports reparameterised sampling.
  4. NotImplementedError otherwise.
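The dispatch order above can be illustrated with a self-contained sketch. It assumes a plain dict registry (_KL_REGISTRY), reads "in order" as itertools.product over the two MROs with the outer loop over $p$'s MRO, and uses hypothetical duck-typed attributes has_rsample, rsample and log_prob for the Monte Carlo branch; none of these names are taken from lucid.

```python
import itertools
from typing import Callable, Dict, Optional, Tuple

# Hypothetical registry; in this sketch entries are added by hand.
_KL_REGISTRY: Dict[Tuple[type, type], Callable] = {}


def _lookup(p_type: type, q_type: type) -> Optional[Callable]:
    # itertools.product yields (p_type, q_type) first, so the exact match
    # (step 1) is tried before any ancestor pair (step 2).
    for p_anc, q_anc in itertools.product(p_type.__mro__, q_type.__mro__):
        fn = _KL_REGISTRY.get((p_anc, q_anc))
        if fn is not None:
            return fn
    return None


def kl_divergence(p, q):
    fn = _lookup(type(p), type(q))
    if fn is not None:
        return fn(p, q)
    # Step 3: single-draw Monte Carlo estimate log p(x) - log q(x), x ~ p.
    # The attribute names used here are assumptions for this sketch.
    if getattr(p, "has_rsample", False):
        x = p.rsample()
        return p.log_prob(x) - q.log_prob(x)
    # Step 4: nothing applies.
    raise NotImplementedError(
        f"no KL for ({type(p).__name__}, {type(q).__name__})"
    )


# Hypothetical demo classes.
class Base: ...
class Child(Base): ...
class Unrelated: ...


class Reparam:
    has_rsample = True

    def __init__(self, shift: float) -> None:
        self.shift = shift

    def rsample(self) -> float:
        return 0.0  # deterministic stand-in for a reparameterised draw

    def log_prob(self, x: float) -> float:
        return -((x - self.shift) ** 2)  # unnormalised toy log-density


_KL_REGISTRY[(Base, Base)] = lambda p, q: "base-base"

print(kl_divergence(Child(), Child()))            # -> base-base
print(kl_divergence(Reparam(0.0), Reparam(1.0)))  # -> 1.0
```

Because the outer loop runs over $p$'s MRO, a more specific match on $p$ wins over a more specific match on $q$ in this sketch; whether lucid breaks such ties the same way is not stated in this section.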

Closed-form formulas are strongly preferred over Monte Carlo: they are deterministic, zero-variance, and propagate gradients exactly.
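To make the variance argument concrete, the sketch below compares the Normal closed form (the same formula as the _kl_normal_normal example in this section) against single-draw Monte Carlo estimates, using only the Python standard library. The helper names are hypothetical, and plain floats stand in for tensors, so gradient propagation is not shown.

```python
import math
import random
import statistics


def kl_normal_closed_form(mu_p, sigma_p, mu_q, sigma_q):
    # Same formula as _kl_normal_normal, written with plain floats.
    return (math.log(sigma_q) - math.log(sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sigma_q ** 2)
            - 0.5)


def normal_log_prob(x, mu, sigma):
    # Log-density of N(mu, sigma^2) evaluated at x.
    return (-0.5 * ((x - mu) / sigma) ** 2
            - math.log(sigma) - 0.5 * math.log(2.0 * math.pi))


def kl_normal_single_draw(mu_p, sigma_p, mu_q, sigma_q, rng):
    # Reparameterised draw: x = mu + sigma * eps with eps ~ N(0, 1).
    x = mu_p + sigma_p * rng.gauss(0.0, 1.0)
    return normal_log_prob(x, mu_p, sigma_p) - normal_log_prob(x, mu_q, sigma_q)


rng = random.Random(0)
exact = kl_normal_closed_form(0.0, 1.0, 0.0, 2.0)
draws = [kl_normal_single_draw(0.0, 1.0, 0.0, 2.0, rng) for _ in range(100_000)]

print(f"closed form:       {exact:.4f}")
print(f"MC mean:           {statistics.fmean(draws):.4f}")
print(f"single-draw stdev: {statistics.stdev(draws):.4f}")
```

For $p = \mathcal{N}(0, 1)$ and $q = \mathcal{N}(0, 2^2)$ the per-draw estimate simplifies to $\log 2 - 3x^2/8$ with $x \sim \mathcal{N}(0, 1)$. Its mean matches the closed form $\log 2 - 3/8 \approx 0.3181$, but a single draw has a standard deviation of about 0.53, larger than the quantity being estimated.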

Examples

>>> from lucid.distributions import Normal
>>> from lucid.distributions.kl import register_kl, kl_divergence
>>> @register_kl(Normal, Normal)
... def _kl_normal_normal(p, q):
...     return (q.scale.log() - p.scale.log()
...             + (p.variance + (p.loc - q.loc) ** 2) / (2.0 * q.variance)
...             - 0.5)
>>> p = Normal(0.0, 1.0); q = Normal(0.0, 2.0)
>>> kl_divergence(p, q)
Tensor(...)
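For reference, the _kl_normal_normal body above implements the standard univariate Gaussian KL. The mapping onto the code assumes loc is the mean $\mu$, scale is the standard deviation $\sigma$, and variance equals $\sigma^2$; the numeric value then follows from the example's arguments Normal(0.0, 1.0) and Normal(0.0, 2.0).

```latex
\mathrm{KL}\bigl(\mathcal{N}(\mu_p, \sigma_p^2) \,\|\, \mathcal{N}(\mu_q, \sigma_q^2)\bigr)
  = \log\frac{\sigma_q}{\sigma_p}
  + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{2\sigma_q^2}
  - \frac{1}{2}

% Example: \mu_p = \mu_q = 0,\ \sigma_p = 1,\ \sigma_q = 2
\mathrm{KL} = \log 2 + \frac{1}{8} - \frac{1}{2} = \log 2 - \frac{3}{8} \approx 0.3181
```

The three terms correspond, in order, to q.scale.log() - p.scale.log(), the fraction over 2.0 * q.variance, and the trailing - 0.5.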