fn kl_divergence

kl_divergence(p: Distribution, q: Distribution) -> Tensor

Compute the Kullback–Leibler divergence $\mathrm{KL}(p \,\|\, q)$.

The KL divergence is the expected log-ratio of two probability measures $p$ and $q$ defined on the same sample space:

$$\mathrm{KL}(p \,\|\, q) = \mathbb{E}_{x \sim p}\!\left[\log \frac{p(x)}{q(x)}\right] = \int p(x) \log \frac{p(x)}{q(x)} \, dx$$
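When $p$ and $q$ are discrete, the integral becomes a sum over outcomes. A minimal stdlib sketch of that sum, assuming both distributions are given as probability lists over the same finite sample space (kl_discrete is a hypothetical helper, not part of lucid):

```python
import math


def kl_discrete(p, q):
    """KL(p || q) in nats for two probability lists over the same outcomes."""
    if len(p) != len(q):
        raise ValueError("p and q must cover the same sample space")
    total = 0.0
    for p_i, q_i in zip(p, q):
        if p_i == 0.0:
            continue  # convention: 0 * log(0 / q_i) = 0
        if q_i == 0.0:
            return math.inf  # p puts mass where q has none
        total += p_i * math.log(p_i / q_i)
    return total


print(kl_discrete([0.5, 0.5], [0.9, 0.1]))  # ln(5/3) ≈ 0.511
```

The two early exits encode the usual conventions: outcomes with zero mass under $p$ contribute nothing, and any outcome with positive mass under $p$ but zero mass under $q$ makes the divergence infinite.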

It is non-negative ($\mathrm{KL}(p \,\|\, q) \geq 0$, with equality iff $p = q$ almost everywhere), is not symmetric in general ($\mathrm{KL}(p \,\|\, q) \neq \mathrm{KL}(q \,\|\, p)$), and does not satisfy the triangle inequality, so it is a divergence rather than a metric.
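Asymmetry and the failure of the triangle inequality are easy to see with Bernoulli distributions, whose KL has a closed form. A small stdlib sketch (kl_bernoulli is a hypothetical helper, valid for parameters strictly between 0 and 1):

```python
import math


def kl_bernoulli(a, b):
    """Closed-form KL(Bernoulli(a) || Bernoulli(b)) in nats, for 0 < a, b < 1."""
    return a * math.log(a / b) + (1 - a) * math.log((1 - a) / (1 - b))


# Asymmetry: swapping the arguments changes the value.
forward = kl_bernoulli(0.2, 0.5)  # ≈ 0.193
reverse = kl_bernoulli(0.5, 0.2)  # ≈ 0.223

# Triangle inequality fails: the direct divergence exceeds the two-hop sum.
direct = kl_bernoulli(0.2, 0.8)                            # ≈ 0.832
two_hop = kl_bernoulli(0.2, 0.5) + kl_bernoulli(0.5, 0.8)  # ≈ 0.416

print(forward, reverse)
print(direct, two_hop)
```

Near its minimum KL behaves like a squared distance, which is why splitting a path into two halves roughly halves the total instead of preserving it.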

Dispatch walks the registry built by register_kl: it first tries an exact class match, then searches the ancestor pairs drawn from the MROs of type(p) and type(q) for the most-derived registered pair. When no analytical formula is registered and p.has_rsample is True, it falls back to a single-sample Monte Carlo estimate.
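The dispatch order described above can be mimicked with a toy registry. This is an illustrative sketch, not lucid's implementation: it works on Python floats rather than Tensors, register_kl and kl_divergence only reuse the documented names, and _KL_REGISTRY, _lookup, and the toy classes Normal, ScaledNormal, and Logistic are hypothetical. How "most-derived" is ranked is not stated on this page, so the sum-of-MRO-positions rule below is an assumption.

```python
import math
import random

# Hypothetical registry: (type_p, type_q) -> closed-form KL function.
_KL_REGISTRY = {}


def register_kl(type_p, type_q):
    """Toy decorator that records an analytical KL formula for a type pair."""
    def decorator(fn):
        _KL_REGISTRY[(type_p, type_q)] = fn
        return fn
    return decorator


def _lookup(type_p, type_q):
    """Exact match first, then the closest registered ancestor pair, else None."""
    if (type_p, type_q) in _KL_REGISTRY:
        return _KL_REGISTRY[(type_p, type_q)]
    best, best_rank = None, None
    for i, base_p in enumerate(type_p.__mro__):
        for j, base_q in enumerate(type_q.__mro__):
            fn = _KL_REGISTRY.get((base_p, base_q))
            # Assumed ranking: smallest total MRO distance; ties favour a closer p-side.
            if fn is not None and (best_rank is None or (i + j, i) < best_rank):
                best, best_rank = fn, (i + j, i)
    return best


def kl_divergence(p, q):
    fn = _lookup(type(p), type(q))
    if fn is not None:
        return fn(p, q)  # analytical formula
    if getattr(p, "has_rsample", False):
        x = p.rsample()  # one reparameterised draw x ~ p
        return p.log_prob(x) - q.log_prob(x)  # single-sample MC estimate
    raise NotImplementedError(
        f"no KL formula for ({type(p).__name__}, {type(q).__name__}) "
        "and p does not support rsample"
    )


class Normal:
    has_rsample = True

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def rsample(self):
        return self.loc + self.scale * random.gauss(0.0, 1.0)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class ScaledNormal(Normal):
    """Subclass with no formula of its own; resolved through the MRO."""


class Logistic:
    """No registered formula and no rsample, so it exercises both fall-back paths."""
    has_rsample = False

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -z - math.log(self.scale) - 2.0 * math.log1p(math.exp(-z))


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    var_ratio = (p.scale / q.scale) ** 2
    mean_term = ((p.loc - q.loc) / q.scale) ** 2
    return 0.5 * (var_ratio + mean_term - 1.0 - math.log(var_ratio))


print(kl_divergence(Normal(0.0, 1.0), Normal(1.0, 2.0)))        # exact match: ln 2 - 0.25 ≈ 0.4431
print(kl_divergence(ScaledNormal(0.0, 1.0), Normal(1.0, 2.0)))  # found via the MRO of type(p)
print(kl_divergence(Normal(0.0, 1.0), Logistic(0.0, 1.0)))      # no formula: noisy MC estimate
```

Calling kl_divergence(Logistic(0.0, 1.0), Normal(0.0, 1.0)) raises NotImplementedError in this sketch, because neither a formula nor reparameterised sampling from p is available.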

Parameters

p : Distribution
Left-hand distribution.
q : Distribution
Right-hand distribution (must share the support / event shape of p).

Returns

Tensor

Non-negative tensor of shape batch_shape giving the per-batch KL divergence in nats.

Raises

NotImplementedError
When no closed-form pair is registered and $p$ does not support reparameterised sampling, so the Monte Carlo fall-back is unavailable.

Notes

KL divergence appears throughout machine learning:

  • Variational inference: the ELBO equals $\log p(\mathbf{x}) - \mathrm{KL}(q(\mathbf{z})\,\|\, p(\mathbf{z} \mid \mathbf{x}))$, so maximising the ELBO minimises the KL from the approximation to the true posterior.
  • VAE training: the encoder regulariser is $\mathrm{KL}(q(\mathbf{z} \mid \mathbf{x})\,\|\, p(\mathbf{z}))$, available analytically for Normal-vs-Normal pairs.
  • Maximum likelihood: minimising the expected negative log-likelihood under $p_{\text{data}}$ is equivalent to minimising $\mathrm{KL}(p_{\text{data}} \,\|\, p_{\theta})$, since the two differ only by the entropy of $p_{\text{data}}$, which does not depend on $\theta$.
  • Mode-seeking vs mode-covering: minimising $\mathrm{KL}(p\,\|\,q)$ over $q$ is mode-covering, while minimising $\mathrm{KL}(q\,\|\,p)$ over $q$ is mode-seeking, which is why the direction matters for variational approximations.
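The VAE regulariser's closed form can be sanity-checked against a Monte Carlo average. A stdlib sketch, assuming a diagonal Gaussian encoder parameterised by a mean vector mu and a log-variance vector log_var (a common convention, not one this page specifies) and a standard-normal prior; vae_kl and mc_kl are hypothetical helpers:

```python
import math
import random


def vae_kl(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dims."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))


def mc_kl(mu, log_var, n_samples=100_000, seed=0):
    """Average of log q(z) - log p(z) over reparameterised draws z ~ q."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        for m, lv in zip(mu, log_var):
            eps = rng.gauss(0.0, 1.0)
            z = m + math.exp(0.5 * lv) * eps  # reparameterised sample from q
            log_q = -0.5 * eps * eps - 0.5 * lv  # shared -0.5*log(2*pi) dropped
            log_p = -0.5 * z * z                 # same constant dropped here
            total += log_q - log_p
    return total / n_samples


mu, log_var = [0.5, -1.0], [0.0, math.log(0.25)]
print(vae_kl(mu, log_var))  # 0.25 + ln 2 ≈ 0.9431
print(mc_kl(mu, log_var))   # close to the closed form; error shrinks like 1/sqrt(n_samples)
```

The analytical form is exact and cheap, which is why it is preferred whenever a formula is registered; the sampled average only approaches it as the number of draws grows.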

Examples

>>> import lucid
>>> from lucid.distributions import Normal
>>> from lucid.distributions.kl import kl_divergence
>>> p = Normal(loc=0.0, scale=1.0)
>>> q = Normal(loc=1.0, scale=2.0)
>>> kl_divergence(p, q)
Tensor(...)
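For these parameters the standard closed-form KL between two univariate Normals (a textbook result, not quoted from lucid's source) gives the value the returned tensor should hold up to floating-point rounding, assuming the analytical Normal/Normal rule is the one dispatched. With $\mu_1 = 0$, $\sigma_1 = 1$, $\mu_2 = 1$, $\sigma_2 = 2$:

```latex
\mathrm{KL}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
  = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
  = \log 2 + \frac{1 + 1}{8} - \frac{1}{2} \approx 0.4431 \text{ nats}
```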