class TransformedDistribution

extends Distribution
TransformedDistribution(base_distribution: Distribution, transforms: Transform | list[Transform], validate_args: bool | None = None)

Pushforward of a base distribution through a (composite) bijector.

Constructs a new Distribution whose samples are obtained by pushing samples from base_distribution through the supplied chain of Transform instances. log_prob accounts for the Jacobian of the chain via the change-of-variable formula. This is the canonical way to build normalising flows in Lucid: stack any number of bijections on top of a tractable base (typically a Normal) to obtain expressive densities while retaining exact log-likelihood evaluation and, whenever the base distribution supports it, reparameterised sampling.

Parameters

base_distribution : Distribution
Underlying distribution whose samples will be pushed through transforms.
transforms : Transform or list[Transform]
A single transform or an ordered list applied left-to-right. A single transform is internally wrapped in a list.
validate_args : bool or None, default None
Forwarded to Distribution.
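As a rough, Lucid-independent illustration of the "single transform or ordered list, applied left-to-right" convention, here is a stdlib-only sketch. Plain Python callables stand in for Transform objects, and `normalise_transforms` and `compose` are hypothetical helpers, not part of Lucid's API.

```python
from typing import Callable, Sequence, Union

# Hypothetical stand-in for a Transform: a plain callable y = T(x).
Transform = Callable[[float], float]


def normalise_transforms(transforms: Union[Transform, Sequence[Transform]]) -> list:
    """Wrap a single transform in a list; copy a list or tuple as-is."""
    if isinstance(transforms, (list, tuple)):
        return list(transforms)
    return [transforms]


def compose(transforms: Sequence[Transform]) -> Transform:
    """Left-to-right composition: compose([T1, T2])(x) == T2(T1(x))."""
    def T(x: float) -> float:
        for t in transforms:
            x = t(x)
        return x
    return T


shift = lambda x: x + 1.0
double = lambda x: 2.0 * x
print(compose(normalise_transforms([shift, double]))(3.0))  # (3 + 1) * 2 = 8.0
print(compose(normalise_transforms(double))(3.0))           # 2 * 3 = 6.0
```

Order matters: `[shift, double]` and `[double, shift]` give different maps, which is why the list is read as $T_1$ first.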

Notes

Sampling (with $T = T_n \circ \cdots \circ T_1$ the composite bijector):

$$Y = T(X), \quad X \sim p_{\text{base}}$$

Reparameterised sampling is available iff the base distribution supports it (has_rsample is forwarded).
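The sampling rule $Y = T(X)$ can be sketched in plain Python for a scalar Normal base. This is not Lucid code: `normal_rsample` and `pushforward_sample` are hypothetical names, and the base draw is written in the reparameterised form $x = \mu + \sigma\varepsilon$, which is what lets gradients flow from $Y$ back to the base parameters.

```python
import math
import random


def normal_rsample(loc: float, scale: float, n: int, rng: random.Random) -> list:
    # Reparameterised draw: x = loc + scale * eps, with eps ~ N(0, 1)
    return [loc + scale * rng.gauss(0.0, 1.0) for _ in range(n)]


def pushforward_sample(base_samples: list, transforms: list) -> list:
    # Apply Y = T_n(...T_1(X)...) to each base sample, transforms left-to-right
    out = []
    for x in base_samples:
        for t in transforms:
            x = t(x)
        out.append(x)
    return out


rng = random.Random(0)
xs = normal_rsample(0.0, 1.0, 5, rng)
ys = pushforward_sample(xs, [math.exp])  # log-normal samples, all positive
print(ys)
```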

Density (change of variables):

$$\log p_Y(\mathbf{y}) = \log p_X(T^{-1}(\mathbf{y})) - \sum_{i=1}^{n} \log\!\left|\det \frac{\partial T_i(\mathbf{x}_{i-1})}{\partial \mathbf{x}_{i-1}}\right|$$

where $p_X = p_{\text{base}}$, $\mathbf{x}_0 = T^{-1}(\mathbf{y})$ and $\mathbf{x}_i = T_i(\mathbf{x}_{i-1})$, so that $\mathbf{x}_n = \mathbf{y}$. The implementation walks the transforms in reverse, inverting one step at a time and accumulating the Jacobian correction.
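The reverse walk can be sketched for scalar bijectors in plain Python. The `ScalarBijector` interface below (forward, inverse, and `log_abs_det_jacobian(x, y)` returning $\log|dT/dx|$) is a hypothetical stand-in, not Lucid's Transform API; the loop mirrors the formula above by inverting $T_n, \dots, T_1$ in turn and summing the log-Jacobian of each step at its own input $x_{i-1}$.

```python
import math
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ScalarBijector:
    # Hypothetical scalar bijector interface (an assumption, not Lucid's Transform)
    forward: Callable[[float], float]
    inverse: Callable[[float], float]
    log_abs_det_jacobian: Callable[[float, float], float]  # log|dT/dx| at (x, y = T(x))


def normal_log_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def transformed_log_prob(y: float, base_log_prob, transforms: List[ScalarBijector]) -> float:
    log_det_sum = 0.0
    for t in reversed(transforms):
        x = t.inverse(y)                            # x_{i-1} = T_i^{-1}(x_i)
        log_det_sum += t.log_abs_det_jacobian(x, y)
        y = x
    return base_log_prob(y) - log_det_sum           # y now holds x_0


# exp: d exp(x)/dx = exp(x), so log|dT/dx| = x
exp_t = ScalarBijector(math.exp, math.log, lambda x, y: x)
# affine 2x + 1: constant Jacobian 2
affine_t = ScalarBijector(lambda x: 2.0 * x + 1.0, lambda y: (y - 1.0) / 2.0,
                          lambda x, y: math.log(2.0))

print(transformed_log_prob(1.0, normal_log_prob, [exp_t]))  # -0.5 * log(2*pi), about -0.9189
```

With `[affine_t, exp_t]` the same loop gives the density of $Y = e^{2X + 1}$, a log-normal with parameters $(1, 2)$.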

Examples

>>> import lucid
>>> from lucid.distributions import Normal
>>> from lucid.distributions.transforms import (
...     ExpTransform, TransformedDistribution)
>>> # LogNormal(0, 1): push Normal(0, 1) through exp
>>> log_normal = TransformedDistribution(Normal(loc=0.0, scale=1.0), [ExpTransform()])
>>> log_normal.rsample((4,))
Tensor([...])
>>> log_normal.log_prob(lucid.tensor(1.0))
Tensor(...)
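For a quick sanity check on the example above: the pushforward of Normal(0, 1) through $y = e^x$ is the log-normal distribution, whose log-density has the closed form $\log p(y) = -\log y - \log\sigma - \tfrac12\log(2\pi) - \tfrac12\bigl((\log y - \mu)/\sigma\bigr)^2$. The stdlib helper below (hypothetical, independent of Lucid) gives the value that `log_normal.log_prob(lucid.tensor(1.0))` should approximately match, assuming ExpTransform implements $y = e^x$.

```python
import math


def lognormal_log_prob(y: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    # Closed-form log-density of LogNormal(mu, sigma), valid for y > 0
    z = (math.log(y) - mu) / sigma
    return -math.log(y) - math.log(sigma) - 0.5 * math.log(2.0 * math.pi) - 0.5 * z * z


print(lognormal_log_prob(1.0))  # -0.5 * log(2*pi), about -0.9189
```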

Methods (5)

__init__(base_distribution: Distribution, transforms: Transform | list[Transform], validate_args: bool | None = None) -> None

Construct a transformed distribution.

Parameters

base_distribution : Distribution
Underlying distribution whose samples will be pushed through transforms.
transforms : Transform | list[Transform]
A single transform or an ordered list applied left-to-right.
validate_args : bool | None, default None
Forwarded to Distribution.

has_rsample: bool (property)

Whether reparameterised sampling is available; the value is forwarded from the base distribution.
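Forwarding a capability flag from the wrapped distribution can be sketched with a plain Python property. `ToyBase` and `ToyTransformed` are hypothetical stand-ins, not Lucid classes; the point is only that the wrapper stores no flag of its own and reads the base's value on every access.

```python
class ToyBase:
    # Hypothetical base distribution that declares whether it can rsample
    def __init__(self, has_rsample: bool) -> None:
        self._has_rsample = has_rsample

    @property
    def has_rsample(self) -> bool:
        return self._has_rsample


class ToyTransformed:
    # Hypothetical wrapper: the flag is looked up on the base each time
    def __init__(self, base: ToyBase) -> None:
        self.base_dist = base

    @property
    def has_rsample(self) -> bool:
        return self.base_dist.has_rsample


print(ToyTransformed(ToyBase(True)).has_rsample)   # True
print(ToyTransformed(ToyBase(False)).has_rsample)  # False
```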

rsample(sample_shape: tuple[int, ...] = ()) -> Tensor

Push a reparameterised base sample through the transform chain.
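Why the reparameterised path matters: once the noise $\varepsilon$ is fixed, $y = T(\mu + \sigma\varepsilon)$ is a deterministic, differentiable function of the base parameters, so gradients of an expectation can be estimated by differentiating through the samples. The stdlib sketch below does this by hand for $T = \exp$ with a Normal base (no autograd, no Lucid; `pathwise_grads` is a hypothetical helper) and compares the Monte Carlo estimates with the analytic gradients of $\mathbb{E}[Y] = e^{\mu + \sigma^2/2}$.

```python
import math
import random


def pathwise_grads(mu: float, sigma: float, n: int, seed: int = 0):
    """Monte Carlo pathwise gradients of E[exp(mu + sigma * eps)], eps ~ N(0, 1)."""
    rng = random.Random(seed)
    g_mu = 0.0
    g_sigma = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        y = math.exp(mu + sigma * eps)  # reparameterised sample pushed through exp
        g_mu += y                       # dy/dmu = y
        g_sigma += y * eps              # dy/dsigma = y * eps
    return g_mu / n, g_sigma / n


mu, sigma = 0.3, 0.5
est_mu, est_sigma = pathwise_grads(mu, sigma, 100_000)
exact = math.exp(mu + 0.5 * sigma ** 2)  # E[Y]; d/dmu E[Y] = E[Y], d/dsigma E[Y] = sigma * E[Y]
print(est_mu, exact)
print(est_sigma, sigma * exact)
```

A non-reparameterised sample gives no such path from the parameters to $y$, which is the practical difference between rsample and sample.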

sample(sample_shape: tuple[int, ...] = ()) -> Tensor

Push a (non-reparameterised) base sample through the transform chain.

log_prob(value: Tensor) -> Tensor

Evaluate the log-density of value under the transformed distribution.

Uses the change-of-variable formula:

$$\log p(y) = \log p_{\text{base}}(T^{-1}(y)) - \sum_{i=1}^{n} \log\!\left|\det \frac{\partial T_i(x_{i-1})}{\partial x_{i-1}}\right|$$

with $x_0 = T^{-1}(y)$ and $x_i = T_i(x_{i-1})$; the chain is unwound by walking the transforms in reverse.
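One way to see that the Jacobian term is needed is to integrate the density numerically: with the correction it integrates to about 1, without it it does not. The stdlib sketch below does this for the exp-of-Normal case with a plain trapezoid rule; `exp_normal_log_prob` and `trapezoid` are hypothetical helpers, and the integration range $[10^{-9}, 50]$ is an arbitrary truncation.

```python
import math


def normal_log_prob(x: float) -> float:
    # Standard normal log-density log N(x; 0, 1)
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def exp_normal_log_prob(y: float, with_jacobian: bool = True) -> float:
    # Reverse walk through the single transform T(x) = exp(x):
    # x = log(y), and log|dT/dx| evaluated at x is x itself.
    x = math.log(y)
    correction = x if with_jacobian else 0.0
    return normal_log_prob(x) - correction


def trapezoid(f, a: float, b: float, n: int) -> float:
    # Composite trapezoid rule with n equal sub-intervals on [a, b]
    h = (b - a) / n
    total = 0.5 * (f(a) + f(b))
    for k in range(1, n):
        total += f(a + k * h)
    return total * h


lo, hi, n = 1e-9, 50.0, 200_000
mass = trapezoid(lambda y: math.exp(exp_normal_log_prob(y)), lo, hi, n)
mass_no_jac = trapezoid(lambda y: math.exp(exp_normal_log_prob(y, False)), lo, hi, n)
print(mass)         # close to 1.0 (a tiny tail above y = 50 is cut off)
print(mass_no_jac)  # close to exp(0.5), about 1.65, so not a valid density
```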