register_kl
register_kl(p_cls: type, q_cls: type) → Callable[[F], F]
Register a closed-form KL implementation for a distribution pair.
Decorator that adds an analytical KL formula to the global dispatch
registry, keyed on the pair of distribution classes. Once registered,
kl_divergence routes any call whose argument types match
(p_cls, q_cls), either exactly or as subclasses found via an MRO walk,
to the decorated function.
Parameters
p_cls : type
    Class of the "left" distribution (the reference distribution p in KL(p ‖ q)).
q_cls : type
    Class of the "right" distribution q.
Returns
Callable[[F], F]
    A decorator that takes a two-argument callable fn(p, q) -> Tensor
    and registers it in the internal KL registry, returning fn unchanged
    so the decorated function remains directly callable.
Notes
Dispatch precedence inside kl_divergence:
- Exact class match (type(p), type(q)).
- Most-derived ancestor pair found by walking the Cartesian product of type(p).__mro__ and type(q).__mro__ in order.
- Monte Carlo fall-back (single rsample draw) when p supports reparameterised sampling.
- NotImplementedError otherwise.
Closed-form formulas are strongly preferred over Monte Carlo: they are deterministic, zero-variance, and propagate gradients exactly.
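The precedence rules can be sketched as follows. This is an illustration under stated assumptions, not the library's code: `dispatch_kl_sketch` and the toy classes are hypothetical, the single-sample estimate assumes `p` exposes `rsample()` and both arguments expose `log_prob()`, and iterating `itertools.product` with `type(p).__mro__` as the outer loop is only one possible reading of "in order".

```python
from itertools import product
from typing import Callable, Dict, Tuple

Registry = Dict[Tuple[type, type], Callable]


def dispatch_kl_sketch(registry: Registry, p, q):
    # product() yields (type(p), type(q)) first, so the exact match is
    # tried before any ancestor pair; later pairs follow MRO order.
    for pair in product(type(p).__mro__, type(q).__mro__):
        fn = registry.get(pair)
        if fn is not None:
            return fn(p, q)
    # Fall-back: one-sample Monte Carlo estimate log p(x) - log q(x),
    # with x drawn from p (rsample/log_prob are assumed method names).
    if hasattr(p, "rsample"):
        x = p.rsample()
        return p.log_prob(x) - q.log_prob(x)
    raise NotImplementedError(
        f"No KL rule for ({type(p).__name__}, {type(q).__name__})"
    )


# Toy hierarchy standing in for distribution classes.
class Base: ...
class Child(Base): ...


registry: Registry = {(Base, Base): lambda p, q: "base-base"}
via_ancestor = dispatch_kl_sketch(registry, Child(), Child())

registry[(Child, Base)] = lambda p, q: "child-base"
via_closer_pair = dispatch_kl_sketch(registry, Child(), Child())
```

In this sketch a `(Child, Child)` call first resolves to the `(Base, Base)` rule, and then to `(Child, Base)` once that more specific pair is registered, since it appears earlier in the product.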
Examples
>>> from lucid.distributions import Normal
>>> from lucid.distributions.kl import register_kl, kl_divergence
>>> @register_kl(Normal, Normal)
... def _kl_normal_normal(p, q):
...     return (q.scale.log() - p.scale.log()
...             + (p.variance + (p.loc - q.loc) ** 2) / (2.0 * q.variance)
...             - 0.5)
>>> p = Normal(0.0, 1.0); q = Normal(0.0, 2.0)
>>> kl_divergence(p, q)
Tensor(...)
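As a sanity check on the formula, the same expression can be evaluated with plain floats. The sketch below uses only the standard library; `kl_normal_normal_scalar` is a hypothetical helper, and it assumes the second argument of `Normal` is the standard deviation (scale).

```python
import math


def kl_normal_normal_scalar(mu_p, sigma_p, mu_q, sigma_q):
    """KL(N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2)) for scalar parameters."""
    return (math.log(sigma_q) - math.log(sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sigma_q ** 2)
            - 0.5)


value = kl_normal_normal_scalar(0.0, 1.0, 0.0, 2.0)
print(round(value, 4))  # 0.3181, i.e. ln 2 + 1/8 - 1/2
```

The divergence of a distribution from itself evaluates to zero, as expected for any KL formula.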