class Independent

Extends Distribution.

Independent(base_distribution: Distribution, reinterpreted_batch_ndims: int, validate_args: bool | None = None)

Re-interpret rightmost batch dimensions of a base distribution as event dims.

Independent(base_distribution, reinterpreted_batch_ndims=n) takes the last n dimensions of base_distribution.batch_shape and moves them into event_shape. Log-probabilities are summed over the re-interpreted axes. The result is the joint log-probability of the independent components, which is the log of the product of their marginal probabilities.

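This is a purely structural wrapper. It does not change how samples are drawn. It only changes how the shape is split between batch_shape and event_shape, and log_prob and entropy are summed over the absorbed dimensions. Its main uses are to express: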

  • Diagonal multivariate Normals from batches of scalar Normals, or
  • Independent Bernoulli likelihoods over image pixels from a per-pixel Bernoulli batch.
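The first use case can be sanity-checked numerically. The sketch below uses only Python's standard library, not lucid. It compares two quantities: the summed log-densities of D scalar Normals, and the closed-form log-density of a multivariate Normal with covariance diag(sigma_1^2, ..., sigma_D^2). The helper names normal_log_prob, summed_log_prob and diag_mvn_log_prob are hypothetical and exist only for this illustration.

```python
import math

def normal_log_prob(x: float, mu: float, sigma: float) -> float:
    """Log-density of a scalar Normal(mu, sigma) (hypothetical helper)."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def summed_log_prob(xs, mus, sigmas) -> float:
    """Sum of marginal log-densities, as reinterpreting one batch dim would give."""
    return sum(normal_log_prob(x, m, s) for x, m, s in zip(xs, mus, sigmas))

def diag_mvn_log_prob(xs, mus, sigmas) -> float:
    """Closed-form log-density of N(mu, diag(sigma**2)) (hypothetical helper)."""
    d = len(xs)
    log_det = sum(2 * math.log(s) for s in sigmas)  # log |Sigma| for a diagonal Sigma
    mahal = sum(((x - m) / s) ** 2 for x, m, s in zip(xs, mus, sigmas))
    return -0.5 * (d * math.log(2 * math.pi) + log_det + mahal)

xs = [0.3, -1.2, 2.0, 0.0]
mus = [0.0, 0.5, 1.0, -0.5]
sigmas = [1.0, 2.0, 0.5, 1.5]
print(math.isclose(summed_log_prob(xs, mus, sigmas), diag_mvn_log_prob(xs, mus, sigmas)))  # True
```

The two results agree because the log-determinant of a diagonal covariance splits into per-component terms, and so does the Mahalanobis distance.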

Parameters

base_distribution : Distribution
The underlying distribution. Its batch_shape must have at least reinterpreted_batch_ndims dimensions.
reinterpreted_batch_ndims : int
Number of rightmost batch dimensions to absorb into event_shape. Must satisfy 0 <= reinterpreted_batch_ndims <= len(base_distribution.batch_shape).
validate_args : bool | None = None
If True, validate parameter constraints at construction time.

Attributes

base_dist : Distribution
The wrapped base distribution.
reinterpreted_batch_ndims : int
Number of absorbed batch dimensions.

Notes

Given a base distribution with batch_shape = (B, D) and event_shape = (), wrapping it with reinterpreted_batch_ndims=1 yields:

  • batch_shape = (B,)
  • event_shape = (D,)
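The shape bookkeeping above can be sketched with plain tuples. reinterpret_shapes is a hypothetical helper, not part of lucid. It assumes that any existing event dimensions of the base stay rightmost, after the absorbed batch dimensions, as in PyTorch's Independent. It also mirrors the ValueError documented for __init__, though the exact error message is an assumption.

```python
def reinterpret_shapes(batch_shape, event_shape, n):
    """Move the rightmost n batch dims into the event shape (hypothetical helper)."""
    if not 0 <= n <= len(batch_shape):
        raise ValueError(
            f"reinterpreted_batch_ndims={n} must be in [0, {len(batch_shape)}]"
        )
    split = len(batch_shape) - n
    new_batch = tuple(batch_shape[:split])
    # Absorbed batch dims go first; the base's own event dims stay rightmost.
    new_event = tuple(batch_shape[split:]) + tuple(event_shape)
    return new_batch, new_event

print(reinterpret_shapes((8, 4), (), 1))       # ((8,), (4,))
print(reinterpret_shapes((8, 28, 28), (), 2))  # ((8,), (28, 28))
```

The second call corresponds to the per-pixel Bernoulli case. A batch of 8 images of 28x28 pixels becomes 8 joint distributions, each over a 28x28 event.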

The resulting log_prob(x) sums the D scalar log-probabilities:

$$\log p(x_1, \ldots, x_D) = \sum_{i=1}^{D} \log p_i(x_i)$$

which is correct because the components are independent by construction.

Entropy is similarly summed:

$$H[X_1, \ldots, X_D] = \sum_{i=1}^{D} H[X_i]$$

(In general, the joint entropy is at most the sum of the marginal entropies. Equality holds exactly when the components are independent, which is the case here by construction.)
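For a diagonal Normal, the summed entropy can be checked against the closed form 0.5 * log((2*pi*e)^D * det(Sigma)). This is a stdlib-only sketch, not lucid code. normal_entropy and diag_mvn_entropy are hypothetical helpers, and entropies are in nats.

```python
import math

def normal_entropy(sigma: float) -> float:
    """Differential entropy (nats) of a scalar Normal with std sigma (hypothetical helper)."""
    return 0.5 * math.log(2 * math.pi * math.e * sigma ** 2)

def diag_mvn_entropy(sigmas) -> float:
    """Closed-form entropy of N(mu, diag(sigma**2)) (hypothetical helper)."""
    d = len(sigmas)
    log_det = sum(math.log(s ** 2) for s in sigmas)  # log det(Sigma)
    return 0.5 * (d * math.log(2 * math.pi * math.e) + log_det)

sigmas = [1.0, 2.0, 0.5, 1.5]
summed = sum(normal_entropy(s) for s in sigmas)  # sum of marginal entropies
print(math.isclose(summed, diag_mvn_entropy(sigmas)))  # True
```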

The has_rsample property mirrors base_dist.has_rsample, so gradient flow is preserved when the base distribution supports it.

Examples

>>> import lucid
>>> from lucid.distributions import Independent
>>> from lucid.distributions import Normal
>>> # Diagonal Normal: batch of D=4 scalars → single 4-d event
>>> base = Normal(lucid.zeros(4), lucid.ones(4))
>>> dist = Independent(base, reinterpreted_batch_ndims=1)
>>> dist.batch_shape, dist.event_shape
((), (4,))
>>> x = dist.rsample()
>>> x.shape  # (4,)
(4,)
>>> dist.log_prob(x)  # scalar — sum of 4 log-probs

Methods (10)

__init__(base_distribution: Distribution, reinterpreted_batch_ndims: int, validate_args: bool | None = None) -> None

Initialise an Independent distribution wrapper.

Parameters

base_distribution : Distribution
The underlying distribution. Its batch_shape must have at least reinterpreted_batch_ndims dimensions.
reinterpreted_batch_ndims : int
Number of rightmost batch dimensions to move into event_shape. Must satisfy 0 <= reinterpreted_batch_ndims <= len(base_distribution.batch_shape).
validate_args : bool | None = None
If True, validate parameter constraints at construction time.

Raises

ValueError
If reinterpreted_batch_ndims exceeds the number of batch dimensions of base_distribution.

property has_rsample: bool

Whether reparameterised sampling is supported.

Mirrors base_dist.has_rsample, so gradient flow is preserved when the base distribution supports it.

Returns

bool

True if and only if the base distribution supports reparameterised sampling.
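has_rsample matters because a reparameterised sample is a deterministic, differentiable function of the parameters. The sketch below assumes the standard location-scale reparameterisation x = mu + sigma * eps for a Normal. That is a common choice, but this page does not confirm it is how lucid's Normal implements rsample. The sketch checks the gradients dx/dmu = 1 and dx/dsigma = eps with central finite differences. rsample_normal is a hypothetical helper.

```python
import random

def rsample_normal(mu: float, sigma: float, eps: float) -> float:
    """Location-scale reparameterisation; eps ~ N(0, 1) is drawn separately (hypothetical)."""
    return mu + sigma * eps

rng = random.Random(0)
eps = rng.gauss(0.0, 1.0)  # fixed noise that does not depend on the parameters
mu, sigma, h = 0.5, 2.0, 1e-6

# Central finite differences of the sample with respect to each parameter.
d_mu = (rsample_normal(mu + h, sigma, eps) - rsample_normal(mu - h, sigma, eps)) / (2 * h)
d_sigma = (rsample_normal(mu, sigma + h, eps) - rsample_normal(mu, sigma - h, eps)) / (2 * h)

print(abs(d_mu - 1.0) < 1e-6, abs(d_sigma - eps) < 1e-6)  # True True
```

Independent only regroups these samples and sums log-probabilities. The gradient of a sum is the sum of the gradients, so the gradient path from the base distribution is kept.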

property support: object

Support of the distribution — delegates to base_dist.support.

Returns

object

The support constraint of the underlying base distribution.

property mean: Tensor

Mean of the distribution — delegates to base_dist.mean.

Returns

Tensor

Mean with the same shape as base_dist.mean.

property mode: Tensor

Mode of the distribution — delegates to base_dist.mode.

Returns

Tensor

Mode with the same shape as base_dist.mode.

property variance: Tensor

Variance of the distribution — delegates to base_dist.variance.

Returns

Tensor

Variance with the same shape as base_dist.variance.
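Unlike log_prob and entropy, these properties need no reduction. The mean of a vector of independent components is just the vector of component means, and the same holds element-wise for the mode and the variance. The delegation pattern can be sketched with two toy classes. ToyNormalBatch and ToyIndependent are hypothetical stand-ins that use flat Python lists instead of lucid Tensors.

```python
class ToyNormalBatch:
    """Hypothetical batched base distribution holding scalar Normals in flat lists."""
    has_rsample = True
    support = "real"

    def __init__(self, loc, scale):
        self.loc, self.scale = list(loc), list(scale)

    @property
    def mean(self):
        return list(self.loc)

    @property
    def mode(self):
        return list(self.loc)

    @property
    def variance(self):
        return [s ** 2 for s in self.scale]


class ToyIndependent:
    """Hypothetical wrapper that forwards moment and support queries unchanged."""

    def __init__(self, base, reinterpreted_batch_ndims):
        self.base_dist = base
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    @property
    def support(self):
        return self.base_dist.support

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def mode(self):
        return self.base_dist.mode

    @property
    def variance(self):
        return self.base_dist.variance


dist = ToyIndependent(ToyNormalBatch([0.0, 1.0], [1.0, 3.0]), reinterpreted_batch_ndims=1)
print(dist.mean, dist.variance)  # [0.0, 1.0] [1.0, 9.0]
```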

rsample(sample_shape: tuple[int, ...] = ()) -> Tensor

Draw reparameterised samples — delegates to base_dist.rsample.

Parameters

sample_shape : tuple[int, ...] = ()
Leading shape of the output sample batch.

Returns

Tensor

Samples of shape (*sample_shape, *batch_shape, *event_shape).

sample(sample_shape: tuple[int, ...] = ()) -> Tensor

Draw samples — delegates to base_dist.sample.

Parameters

sample_shape : tuple[int, ...] = ()
Leading shape of the output sample batch.

Returns

Tensor

Samples of shape (*sample_shape, *batch_shape, *event_shape).
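Sampling is delegated to the base distribution, so wrapping does not change the shape of the sample tensor. It only changes how the trailing dimensions are labelled. (*sample_shape, *batch_shape, *event_shape) concatenates to the same tuple before and after wrapping. sample_output_shape below is a hypothetical helper that makes this concrete.

```python
def sample_output_shape(sample_shape, batch_shape, event_shape):
    """Shape of sample(sample_shape): sample dims, then batch dims, then event dims (hypothetical)."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)

# Base: a (B=8, D=4) batch of scalar Normals; wrapped with reinterpreted_batch_ndims=1.
base_shape = sample_output_shape((10,), (8, 4), ())
wrapped_shape = sample_output_shape((10,), (8,), (4,))
print(base_shape, wrapped_shape)  # (10, 8, 4) (10, 8, 4)
```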

log_prob(value: Tensor) -> Tensor

Log joint probability summed over re-interpreted event dimensions.

Computes the base distribution's log-probabilities and then sums over the rightmost reinterpreted_batch_ndims axes:

$$\log p(x_1, \ldots, x_D) = \sum_{i=1}^{D} \log p_i(x_i)$$

Parameters

value : Tensor
Observations of shape (*batch_shape, *event_shape).

Returns

Tensor

Log joint probability values of shape batch_shape.
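The reduction can be sketched without lucid for the per-pixel Bernoulli case. Take a batch of 2 images of 2x2 pixels. The base log-probs have shape (2, 2, 2), and reinterpreting the last 2 batch dimensions sums each image's pixels into one joint log-probability. bernoulli_log_prob and sum_rightmost are hypothetical helpers. They work on nested Python lists and stand in for a tensor sum over the last reinterpreted_batch_ndims axes.

```python
import math

def bernoulli_log_prob(x: int, p: float) -> float:
    """Per-pixel Bernoulli log-probability for x in {0, 1} (hypothetical helper)."""
    return math.log(p) if x == 1 else math.log1p(-p)

def _depth(v):
    """Nesting depth of a non-empty nested list (0 for a scalar)."""
    d = 0
    while isinstance(v, list):
        v, d = v[0], d + 1
    return d

def _total(v):
    """Sum every scalar inside a nested list."""
    return sum(_total(e) for e in v) if isinstance(v, list) else v

def sum_rightmost(values, n):
    """Sum a nested list over its rightmost n axes, keeping the leading axes (hypothetical)."""
    if _depth(values) == n:
        return _total(values)
    return [sum_rightmost(v, n) for v in values]

probs = [[[0.9, 0.2], [0.5, 0.7]],
         [[0.1, 0.8], [0.6, 0.3]]]
x = [[[1, 0], [1, 1]],
     [[0, 1], [0, 0]]]

# Element-wise base log-probs, batch_shape (2, 2, 2).
base_lp = [[[bernoulli_log_prob(xv, p) for xv, p in zip(xr, pr)]
            for xr, pr in zip(ximg, pimg)]
           for ximg, pimg in zip(x, probs)]

joint_lp = sum_rightmost(base_lp, 2)  # one value per image: [log(0.252), log(0.2016)]
print(len(joint_lp))  # 2
```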

entropy() -> Tensor

Joint entropy summed over re-interpreted event dimensions.

Because the components are independent:

$$H[X_1, \ldots, X_D] = \sum_{i=1}^{D} H[X_i]$$

Returns

Tensor

Entropy values of shape batch_shape (nats).