class

Distribution

Distribution(batch_shape: tuple[int, ...] = (), event_shape: tuple[int, ...] = (), validate_args: bool | None = None)

Abstract base for a probability distribution.

A distribution encodes a probability measure over a measurable space and exposes a standard interface for sampling, evaluating log-probabilities, and computing closed-form moments. Every concrete distribution in lucid.distributions inherits from this class.

Subclasses set:

  • arg_constraints: dict of param-name → Constraint, used by validate_args and to document the parameter domain.
  • support: Constraint for the random variable.
  • has_rsample: True for reparameterisable families.
  • batch_shape, event_shape.

Either rsample or sample (or both) must be overridden.
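The override rule above can be illustrated with a small pure-Python sketch. Every name here (ToyDistribution, ToyUniform, ToyBernoulli) is hypothetical and stands in for the real classes; this is not lucid's implementation, and it works on plain floats rather than tensors.

```python
import random


class ToyDistribution:
    """Hypothetical stand-in for the abstract base; not lucid's actual code."""

    has_rsample = False

    def __init__(self, batch_shape=(), event_shape=()):
        self.batch_shape = tuple(batch_shape)
        self.event_shape = tuple(event_shape)

    def rsample(self):
        raise NotImplementedError(f"{type(self).__name__} has no rsample")

    def sample(self):
        # Default path: reuse the reparameterised sampler. A tensor library
        # would additionally detach the result from the autograd graph.
        return self.rsample()


class ToyUniform(ToyDistribution):
    """Continuous family: overrides rsample and inherits sample."""

    has_rsample = True

    def __init__(self, low, high):
        super().__init__()
        self.low, self.high = low, high

    def rsample(self):
        u = random.random()  # noise that does not depend on low/high
        return self.low + (self.high - self.low) * u


class ToyBernoulli(ToyDistribution):
    """Discrete family: overrides sample and leaves rsample unimplemented."""

    def __init__(self, probs):
        super().__init__()
        self.probs = probs

    def sample(self):
        return 1 if random.random() < self.probs else 0
```

With this layout, calling rsample on the discrete toy class raises NotImplementedError, while sample works on both.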

Parameters

batch_shape: tuple[int, ...] = ()
Shape of the batch of independent, non-identically parameterised distributions. A scalar parameter gives (); a vector parameter of length n gives (n,). Default is ().
event_shape: tuple[int, ...] = ()
Shape of a single event (observation). Univariate distributions have event_shape = (). Multivariate distributions such as lucid.distributions.Dirichlet have a non-empty event_shape. Default is ().
validate_args: bool or None = None
If True, parameter constraints are validated at construction time, and values passed to log_prob are checked against support. Useful during development; disable in production for speed. None inherits the class-level _validate_args flag. Default is None.

Attributes

arg_constraints: dict[str, Constraint]
Maps each constructor parameter name to the lucid.distributions.constraints.Constraint it must satisfy. Populated by each concrete subclass.
support: Constraint or None
Constraint describing the set on which the distribution has positive density (or probability mass). None means unconstrained.
has_rsample: bool
True when the distribution implements rsample, i.e. when the reparameterisation trick (Kingma & Welling, 2013) is available and gradients flow through sampled values.
has_enumerate_support: bool
True for finite discrete distributions that can enumerate every possible outcome, enabling exact marginalisation.
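To make "exact marginalisation" concrete, here is a minimal sketch in plain Python. It assumes a categorical variable over the outcomes 0..K-1 given as a list of probabilities; the helper name expectation_by_enumeration is hypothetical and not part of lucid.

```python
def expectation_by_enumeration(probs, f):
    """Exact E[f(X)] for X taking values 0..len(probs)-1 with the given probs."""
    return sum(p * f(k) for k, p in enumerate(probs))


probs = [0.2, 0.5, 0.3]

# No sampling is involved: every outcome is visited once and weighted by its mass.
mean = expectation_by_enumeration(probs, lambda k: k)
second_moment = expectation_by_enumeration(probs, lambda k: k * k)
variance = second_moment - mean ** 2
```

Because the sum runs over the full support, the result carries no Monte Carlo noise; the cost grows linearly with the number of outcomes.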

Notes

Shape semantics

Every tensor returned by sample or rsample has shape

$$\text{sample\_shape} + \text{batch\_shape} + \text{event\_shape}$$

where + denotes tuple concatenation. log_prob returns a tensor of shape sample_shape + batch_shape, having reduced over event_shape.
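As a quick illustration of the shape rule, the sketch below uses plain Python tuples; the particular shapes (100 draws from a batch of 3 five-dimensional distributions) are chosen only for the example.

```python
sample_shape = (100,)
batch_shape = (3,)
event_shape = (5,)

draw_shape = sample_shape + batch_shape + event_shape  # tuple concatenation
log_prob_shape = sample_shape + batch_shape            # event dims reduced away

print(draw_shape)      # (100, 3, 5)
print(log_prob_shape)  # (100, 3)
```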

Reparameterisation

When has_rsample = True the sampler can be written as a deterministic transformation of a fixed-distribution noise variable $\varepsilon$:

$$X = g_{\theta}(\varepsilon), \quad \varepsilon \sim p(\varepsilon)$$

This allows gradients $\nabla_{\theta} \mathbb{E}[f(X)]$ to be estimated with low variance via the pathwise derivative, which is the backbone of the VAE objective (ELBO) and stochastic computation graphs in general.
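A minimal sketch of the pathwise derivative, assuming a location-scale family with $g_{\theta}(\varepsilon) = \text{loc} + \text{scale} \cdot \varepsilon$ and the test function $f(x) = x^2$. It uses only the standard library and differentiates by hand instead of using autograd; the function name pathwise_grad_loc is made up for the example.

```python
import random


def pathwise_grad_loc(loc, scale, n_samples, seed=0):
    """Estimate d/d(loc) E[X^2] for X = loc + scale * eps, eps ~ N(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)  # noise drawn independently of loc and scale
        x = loc + scale * eps      # X = g_theta(eps)
        total += 2.0 * x           # chain rule: f'(x) * dx/dloc = 2x * 1
    return total / n_samples


estimate = pathwise_grad_loc(loc=1.5, scale=0.5, n_samples=50_000)
exact = 2.0 * 1.5  # E[X^2] = loc^2 + scale^2, so the derivative is 2 * loc
```

The Monte Carlo estimate should land close to the exact value because each term differentiates through the sample itself.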

Examples

>>> import lucid.distributions as dist
>>> d = dist.Normal(loc=0.0, scale=1.0)
>>> d.batch_shape
()
>>> d.event_shape
()
>>> x = d.rsample((100,))  # shape (100,)
>>> x.shape
(100,)

Methods (15)

dunder

__init__

None
__init__(batch_shape: tuple[int, ...] = (), event_shape: tuple[int, ...] = (), validate_args: bool | None = None)

Initialise batch/event shapes and optionally validate parameters.

Parameters

batch_shape: tuple[int, ...] = ()
Shape of the batch of independent distributions.
event_shape: tuple[int, ...] = ()
Shape of each individual event sample.
validate_args: bool or None = None
When True, _validate_params is called immediately so that out-of-constraint constructor arguments raise ValueError at construction time rather than silently producing NaN values later.
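The fail-fast behaviour can be sketched as follows. This is an illustration under stated assumptions, not lucid's _validate_params: constraints are modelled as plain Python predicates rather than Constraint objects, and validate_params is a hypothetical helper.

```python
def validate_params(params, arg_constraints):
    """Raise ValueError for the first parameter that violates its constraint."""
    for name, check in arg_constraints.items():
        if not check(params[name]):
            raise ValueError(
                f"parameter {name!r} is outside its domain: {params[name]!r}"
            )


# Hypothetical constraints for a Normal-like family, written as predicates.
normal_constraints = {
    "loc": lambda v: v == v,   # any real number (NaN fails v == v)
    "scale": lambda v: v > 0,  # strictly positive
}

validate_params({"loc": 0.0, "scale": 1.0}, normal_constraints)  # passes silently

try:
    validate_params({"loc": 0.0, "scale": -1.0}, normal_constraints)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Checking at construction time pins the error to the offending argument, instead of surfacing later as NaN in a loss value.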
prop

batch_shape

tuple[int, ...]
batch_shape: tuple[int, ...]

Shape of the batch of independent (but not identically parameterised) distributions.

Returns

tuple[int, ...]

A tuple of integers. () for a single scalar distribution.

prop

event_shape

tuple[int, ...]
event_shape: tuple[int, ...]

Shape of a single observation drawn from the distribution.

Returns

tuple[int, ...]

() for univariate distributions. Non-empty for multivariate distributions such as lucid.distributions.Dirichlet.

prop

mean

Tensor
mean: Tensor

Expected value of the distribution.

Returns

Tensor

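A tensor of shape batch_shape + event_shape, which reduces to batch_shape for univariate distributions. Raises NotImplementedError if the mean is undefined or has no closed form (e.g. lucid.distributions.Cauchy, whose mean does not exist).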

prop

mode

Tensor
mode: Tensor

Most likely value of the distribution (the argmax of the density).

Returns

Tensor

A tensor of shape batch_shape + event_shape, which reduces to batch_shape for univariate distributions. Raises NotImplementedError if not implemented by the concrete subclass.

prop

variance

Tensor
variance: Tensor

Variance of the distribution.

Returns

Tensor

A tensor of shape batch_shape + event_shape, which reduces to batch_shape for univariate distributions. The variance is the second central moment, $\text{Var}[X] = \mathbb{E}[(X - \mu)^2]$. Raises NotImplementedError if not provided by the concrete subclass.

prop

stddev

Tensor
stddev: Tensor

Standard deviation of the distribution.

Computed as $\sigma = \sqrt{\text{Var}[X]}$. Concrete subclasses may override this for numerical efficiency; the default implementation delegates to variance.

Returns

Tensor

A tensor with the same shape as variance.

fn

entropy

Tensor
entropy()

Shannon differential (or discrete) entropy.

Defined as

$$H[X] = -\mathbb{E}_{x \sim p}[\log p(x)]$$

where $p$ is the density of a continuous distribution or the probability mass function of a discrete one. Measured in nats (natural logarithm).

Returns

Tensor

A tensor of shape batch_shape. Raises NotImplementedError if not implemented by the concrete subclass.

Notes

Entropy quantifies the average amount of "surprise" or uncertainty in a single draw. Higher entropy means the distribution is more spread out / less predictable.
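The sketch below evaluates the definition for two simple cases using only the standard library. The helper names are illustrative, and the Gaussian formula is the standard closed form $\tfrac{1}{2}\log(2\pi e \sigma^2)$ rather than anything specific to lucid.

```python
import math


def discrete_entropy(probs):
    """H = -sum p log p in nats; zero-probability outcomes contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def normal_entropy(scale):
    """Differential entropy of a Gaussian: 0.5 * log(2 * pi * e * scale^2)."""
    return 0.5 * math.log(2.0 * math.pi * math.e * scale ** 2)


fair_coin = discrete_entropy([0.5, 0.5])    # log(2) nats, i.e. 1 bit
biased_coin = discrete_entropy([0.9, 0.1])  # lower: the outcome is more predictable
narrow = normal_entropy(0.1)                # differential entropy can be negative
wide = normal_entropy(10.0)
```

Note that differential entropy, unlike discrete entropy, is not bounded below by zero: a sufficiently concentrated density yields a negative value.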

fn

sample

Tensor
sample(sample_shape: tuple[int, ...] = ())

Draw independent, identically distributed samples.

The default implementation calls rsample and detaches the result from the autograd graph, so gradients do not flow through the returned tensor. Discrete distributions that cannot be reparameterised must override this method directly instead.

Parameters

sample_shape: tuple[int, ...] = ()
Shape of the block of independent samples to draw. The returned tensor has shape sample_shape + batch_shape + event_shape. Default is (), which returns a single sample with shape batch_shape + event_shape.

Returns

Tensor

A detached tensor (no gradient) of shape sample_shape + batch_shape + event_shape.

fn

rsample

Tensor
rsample(sample_shape: tuple[int, ...] = ())

Reparameterised sample: gradients flow through the sampled value back to the distribution parameters.

Unlike sample, rsample expresses the stochastic node as a deterministic transformation of a parameter-free noise variable:

$$X = g_{\theta}(\varepsilon), \quad \varepsilon \sim p_0(\varepsilon)$$

This factorisation allows the gradient $\nabla_{\theta} \mathbb{E}_{X}[f(X)]$ to be estimated cheaply via the pathwise (reparameterisation) estimator, which typically has much lower variance than the REINFORCE (score-function) estimator.

Concrete distributions must override either this or sample. Only distributions that admit a differentiable sampler set has_rsample = True.
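To give a feel for the variance gap, the following sketch compares the two estimators on a toy problem: the gradient of $\mathbb{E}[X^2]$ with respect to loc for $X \sim \mathcal{N}(\text{loc}, \text{scale}^2)$. The setup is an assumption made for illustration, with derivatives written out by hand and only the standard library used.

```python
import random
import statistics

rng = random.Random(0)
loc, scale = 1.0, 1.0
pathwise, score_function = [], []

for _ in range(20_000):
    eps = rng.gauss(0.0, 1.0)
    x = loc + scale * eps
    # Pathwise: differentiate f(g(eps)) = (loc + scale * eps)^2 w.r.t. loc.
    pathwise.append(2.0 * x)
    # Score function (REINFORCE): f(x) * d/dloc log N(x; loc, scale).
    score_function.append(x ** 2 * (x - loc) / scale ** 2)

# Both estimators target d/dloc E[X^2] = 2 * loc, but their spreads differ.
mean_pathwise = statistics.fmean(pathwise)
mean_score = statistics.fmean(score_function)
var_pathwise = statistics.variance(pathwise)
var_score = statistics.variance(score_function)
```

For this particular setup the per-sample variances work out analytically to 4 for the pathwise estimator and 30 for the score-function estimator, so the empirical values should show a similar gap.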

Parameters

sample_shape: tuple[int, ...] = ()
Shape of the block of independent samples to draw. The returned tensor has shape sample_shape + batch_shape + event_shape.

Returns

Tensor

A tensor attached to the autograd graph through the distribution parameters.

Raises

NotImplementedError
If the distribution does not support reparameterised sampling (has_rsample = False). Use sample instead in that case.
fn

log_prob

Tensor
log_prob(value: Tensor)

Log-probability (log-density) evaluated at value.

For a continuous distribution with density $p(x)$ this returns $\log p(x)$. For a discrete distribution with probability mass function $P(X = k)$ this returns $\log P(X = k)$.

Working in log-space is numerically preferable to evaluating the density directly: products of probabilities become sums of log-probabilities, avoiding underflow for long sequences.
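The underflow problem is easy to reproduce with ordinary floats. The sketch below multiplies 2000 per-step probabilities of 0.5 (an arbitrary stand-in for a long sequence likelihood) and compares that with summing their logarithms.

```python
import math

probs = [0.5] * 2000  # e.g. per-step likelihoods of a long sequence

product = 1.0
for p in probs:
    product *= p  # drops below the smallest float64 and becomes 0.0

log_sum = sum(math.log(p) for p in probs)  # stays finite

print(product)  # 0.0
print(log_sum)  # about -1386.29
```

Once the product has collapsed to 0.0 the information is lost, whereas the log-domain sum can still be compared, differentiated, or exponentiated after rescaling.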

Parameters

value: Tensor
Point(s) at which to evaluate the log-density. Must be broadcastable with batch_shape + event_shape.

Returns

Tensor

Log-probability tensor of shape broadcast(value.shape, batch_shape + event_shape)[:-len(event_shape)]. In the scalar / univariate case this simplifies to broadcast(value.shape, batch_shape).
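The shape rule can be checked with a small tuple-only sketch. Both helpers are hypothetical (they are not lucid functions) and follow NumPy-style right-aligned broadcasting. The sketch also shows why the univariate case needs care if the rule is coded literally: in Python, shape[:-0] is the empty tuple, not the full shape.

```python
def broadcast_shapes(a, b):
    """NumPy-style broadcasting of two shape tuples (right-aligned)."""
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(out)


def log_prob_shape(value_shape, batch_shape, event_shape):
    full = broadcast_shapes(value_shape, batch_shape + event_shape)
    # Slice with an explicit end index: full[:-0] would be the empty tuple.
    return full[: len(full) - len(event_shape)]


print(log_prob_shape((100, 3), (3,), ()))       # (100, 3)
print(log_prob_shape((100, 3, 5), (3,), (5,)))  # (100, 3)
print(log_prob_shape((5,), (3,), (5,)))         # (3,)
```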

Raises

NotImplementedError
If not implemented by the concrete subclass.
fn

cdf

Tensor
cdf(value: Tensor)

Cumulative distribution function (CDF) evaluated at value.

Returns the probability that a random variable $X$ drawn from this distribution is less than or equal to value:

$$F(x) = P(X \leq x) = \int_{-\infty}^{x} p(t)\, dt$$
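As a worked instance of the definition, the Gaussian CDF has a closed form in terms of the error function. The sketch below uses math.erf from the standard library; normal_cdf is an illustrative helper, not a lucid function.

```python
import math


def normal_cdf(x, loc=0.0, scale=1.0):
    """F(x) = 0.5 * (1 + erf((x - loc) / (scale * sqrt(2))))."""
    return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))


at_mean = normal_cdf(0.0)                              # 0.5 by symmetry
within_one_sigma = normal_cdf(1.0) - normal_cdf(-1.0)  # about 0.6827
```

Differences of CDF values give interval probabilities, as in the one-sigma interval above.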

Parameters

value: Tensor
Point(s) at which to evaluate the CDF.

Returns

Tensor

Values in $[0, 1]$ with shape broadcast(value.shape, batch_shape).

Raises

NotImplementedError
If not implemented by the concrete subclass.
fn

icdf

Tensor
icdf(value: Tensor)

Inverse CDF (quantile function / percent-point function).

Given a probability $p \in [0, 1]$ returns the smallest $x$ such that $F(x) \geq p$:

$$Q(p) = F^{-1}(p) = \inf\{x \in \mathbb{R} : F(x) \geq p\}$$

The quantile function is particularly useful for inverse-CDF (Smirnov transform) sampling: if $U \sim \text{Uniform}(0, 1)$ then $Q(U) \sim p$.
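A minimal sketch of inverse-CDF sampling, assuming an exponential distribution, whose quantile function has the closed form $Q(p) = -\log(1 - p) / \lambda$. Only the standard library is used, and exponential_icdf is a name made up for the example.

```python
import math
import random


def exponential_icdf(p, rate):
    """Q(p) = -log(1 - p) / rate, the inverse of F(x) = 1 - exp(-rate * x)."""
    return -math.log1p(-p) / rate


rng = random.Random(0)
rate = 2.0

# Push uniform draws through the quantile function to get exponential draws.
samples = [exponential_icdf(rng.random(), rate) for _ in range(50_000)]

sample_mean = sum(samples) / len(samples)  # should be close to 1 / rate = 0.5
median = exponential_icdf(0.5, rate)       # log(2) / rate
```

Since the quantile function here is smooth in rate, the same construction also gives a reparameterised sampler for this family.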

Parameters

value: Tensor
Probability values in $[0, 1]$.

Returns

Tensor

Quantiles with shape broadcast(value.shape, batch_shape).

Raises

NotImplementedError
If not implemented by the concrete subclass.
fn

prob

Tensor
prob(value: Tensor)

Probability density (or mass) at value.

Computed as $\exp(\log p(x))$ via log_prob. For numerical stability prefer working with log_prob directly; use this method only when the raw density value is needed.
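The stability advice matters most in the tails. In the sketch below, a hand-written standard-normal log-density (an illustrative helper, not lucid code) is evaluated 40 standard deviations from the mean: the log-probability is an ordinary float, but exponentiating it underflows.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """log N(x; loc, scale) = -z^2 / 2 - log(scale) - log(2 * pi) / 2."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


log_p = normal_log_prob(40.0)  # about -800.9, an ordinary float
p = math.exp(log_p)            # underflows to 0.0 in float64
```

Any later log(p) would then return negative infinity (or raise an error, depending on the library), which is why the log-domain value should be kept as long as possible.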

Parameters

value: Tensor
Point(s) at which to evaluate the density/mass.

Returns

Tensor

Non-negative density or probability-mass values.

dunder

__repr__

str
__repr__()

Concise string representation showing parameter shapes.

Returns

str

E.g. "Normal(loc=(3,), scale=(3,))" or "Bernoulli(probs=0.3)" for a scalar.