Independent
Bases: Distribution
Independent(base_distribution: Distribution, reinterpreted_batch_ndims: int, validate_args: bool | None = None)
Re-interpret the rightmost batch dimensions of a base distribution as event dimensions.
Independent(base_distribution, reinterpreted_batch_ndims=n) takes
the last n dimensions of base_distribution.batch_shape and
moves them into event_shape. Log-probabilities are summed over the
re-interpreted axes, turning a product of independent marginals into the
joint log-probability.
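The shape bookkeeping described above can be sketched in plain Python. `reinterpret_shapes` is a hypothetical helper, not part of lucid: it assumes shapes are plain tuples of ints, models only how the tuples are split, and raises ValueError for the documented out-of-range case.

```python
def reinterpret_shapes(batch_shape, event_shape, n):
    """Hypothetical helper (not lucid API): move the rightmost n batch
    dims of a base distribution into its event shape."""
    batch_shape, event_shape = tuple(batch_shape), tuple(event_shape)
    # Documented constraint: 0 <= reinterpreted_batch_ndims <= len(batch_shape).
    if not 0 <= n <= len(batch_shape):
        raise ValueError(
            f"reinterpreted_batch_ndims={n} must lie in [0, {len(batch_shape)}]"
        )
    split = len(batch_shape) - n
    new_batch = batch_shape[:split]
    # Moved dims come first, followed by the base's own event dims.
    new_event = batch_shape[split:] + event_shape
    return new_batch, new_event


print(reinterpret_shapes((8, 4), (), 1))    # ((8,), (4,))
print(reinterpret_shapes((8, 4), (), 2))    # ((), (8, 4))
print(reinterpret_shapes((8, 4), (3,), 1))  # ((8,), (4, 3))
```

Note that the total shape `new_batch + new_event` always equals `batch_shape + event_shape`; only the grouping changes.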
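This is a pure structural wrapper: it does not change the sampler or the underlying per-element computation, and only regroups dimensions and sums log-probabilities and entropies over the re-interpreted axes. Its main uses are to express: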
- Diagonal multivariate Normals from batches of scalar Normals, or
- Independent Bernoulli likelihoods over image pixels from a per-pixel Bernoulli batch.
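As a minimal illustration of the second use case, the sketch below works in plain Python rather than lucid. It computes per-pixel Bernoulli log-probabilities for a toy 2x2 binary image and sums them the way reinterpreted_batch_ndims=2 would. `bernoulli_log_prob` and the toy values are hypothetical.

```python
import math


def bernoulli_log_prob(x, p):
    # Hypothetical helper: log-probability of x in {0, 1} under Bernoulli(p).
    return x * math.log(p) + (1 - x) * math.log(1 - p)


# Toy per-pixel probabilities (base batch_shape (2, 2)) and an observed image.
probs = [[0.9, 0.2], [0.5, 0.7]]
image = [[1, 0], [1, 1]]

# What the per-pixel Bernoulli batch would return: one log-prob per pixel.
per_pixel = [
    [bernoulli_log_prob(x, p) for x, p in zip(x_row, p_row)]
    for x_row, p_row in zip(image, probs)
]

# Re-interpreting both axes sums them into a single joint log-likelihood.
joint = sum(sum(row) for row in per_pixel)

# Same value as the log of the product of the per-pixel probabilities.
direct = math.log(0.9 * 0.8 * 0.5 * 0.7)
print(joint, direct)
```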
Parameters
base_distribution (Distribution): The distribution to wrap. Its batch_shape must have at least reinterpreted_batch_ndims dimensions.
reinterpreted_batch_ndims (int): Number of rightmost batch dimensions to move into event_shape. Must satisfy 0 <= reinterpreted_batch_ndims <= len(batch_shape).
validate_args (bool | None, default None): If True, validate parameter constraints at construction time.
Attributes
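base_dist (Distribution): The wrapped base distribution.
reinterpreted_batch_ndims (int): Number of batch dimensions re-interpreted as event dimensions.
Notes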
Given a base distribution with batch_shape = (B, D) and event_shape = (),
wrapping it with reinterpreted_batch_ndims=1 yields:
batch_shape = (B,)
event_shape = (D,)
The resulting log_prob(x) sums the scalar log-probabilities:
$$\log p(x) = \sum_{d=1}^{D} \log p_d(x_d),$$
which is correct because the components are independent by construction.
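A quick numerical check of this identity for Normals, written in plain Python rather than lucid: the sum of D scalar Normal log-densities equals the closed-form log-density of a D-dimensional Normal with covariance diag(sigma^2). Both helper functions and the numbers are hypothetical illustrations.

```python
import math


def normal_log_prob(x, mu, sigma):
    # Hypothetical helper: scalar Normal log-density.
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def diag_mvn_log_prob(x, mu, sigma):
    # Hypothetical helper: D-dim Normal with covariance diag(sigma**2), using
    # -(D/2) log(2 pi) - (1/2) log det(Sigma) - (1/2) * squared Mahalanobis distance.
    d = len(x)
    log_det = sum(2 * math.log(s) for s in sigma)
    maha = sum(((xi - mi) / si) ** 2 for xi, mi, si in zip(x, mu, sigma))
    return -0.5 * d * math.log(2 * math.pi) - 0.5 * log_det - 0.5 * maha


x = [0.3, -1.2, 2.0, 0.0]
mu = [0.0, 0.5, 1.0, -0.5]
sigma = [1.0, 2.0, 0.5, 1.5]

summed = sum(normal_log_prob(xi, mi, si) for xi, mi, si in zip(x, mu, sigma))
joint = diag_mvn_log_prob(x, mu, sigma)
print(summed, joint)  # equal up to floating-point rounding
```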
Entropy is similarly summed:
$$H[X] = \sum_{d=1}^{D} H[X_d]$$
(equality holds because independence implies $p(x) = \prod_{d=1}^{D} p_d(x_d)$).
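The same additivity can be checked for Normal entropies in plain Python: the sum of scalar entropies 0.5 log(2 pi e sigma^2) matches the closed-form entropy of a diagonal-covariance multivariate Normal. The helpers below are hypothetical, not lucid functions.

```python
import math


def normal_entropy(sigma):
    # Hypothetical helper: differential entropy of a scalar Normal, in nats.
    return 0.5 * math.log(2 * math.pi * math.e * sigma ** 2)


def diag_mvn_entropy(sigma):
    # Hypothetical helper: entropy of a D-dim Normal with covariance diag(sigma**2),
    # (D/2) log(2 pi e) + (1/2) log det(Sigma).
    d = len(sigma)
    log_det = sum(math.log(s ** 2) for s in sigma)
    return 0.5 * d * math.log(2 * math.pi * math.e) + 0.5 * log_det


sigma = [1.0, 2.0, 0.5, 1.5]
summed_h = sum(normal_entropy(s) for s in sigma)
joint_h = diag_mvn_entropy(sigma)
print(summed_h, joint_h)  # equal up to floating-point rounding
```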
The has_rsample property mirrors base_dist.has_rsample, so
gradient flow is preserved when the base distribution supports it.
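The delegation pattern can be pictured with a toy pure-Python pair of classes. `ToyNormal` and `ToyIndependent` are hypothetical stand-ins written for illustration; they are not lucid's classes and are not claimed to match its implementation. They only show that the wrapper changes shapes while forwarding has_rsample, mean and rsample to the base.

```python
import random


class ToyNormal:
    """Hypothetical stand-in (not lucid's Normal): a batch of D scalar Normals."""

    has_rsample = True

    def __init__(self, loc, scale):
        self.loc, self.scale = list(loc), list(scale)
        self.batch_shape = (len(self.loc),)
        self.event_shape = ()

    @property
    def mean(self):
        return list(self.loc)

    def rsample(self):
        # Reparameterised draw: loc + scale * eps with eps ~ N(0, 1),
        # so the sample is a differentiable function of loc and scale.
        return [m + s * random.gauss(0.0, 1.0) for m, s in zip(self.loc, self.scale)]


class ToyIndependent:
    """Hypothetical sketch of the wrapper pattern: only the shapes change,
    everything else is forwarded to the base distribution."""

    def __init__(self, base, n):
        self.base_dist = base
        split = len(base.batch_shape) - n
        self.batch_shape = base.batch_shape[:split]
        self.event_shape = base.batch_shape[split:] + base.event_shape

    @property
    def has_rsample(self):
        # Mirrors the base: reparameterised sampling exists exactly when the base has it.
        return self.base_dist.has_rsample

    @property
    def mean(self):
        return self.base_dist.mean  # delegated unchanged

    def rsample(self):
        return self.base_dist.rsample()  # same sampler as the base


dist = ToyIndependent(ToyNormal([0.0] * 4, [1.0] * 4), n=1)
print(dist.batch_shape, dist.event_shape, dist.has_rsample)  # () (4,) True
```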
Examples
>>> import lucid
>>> from lucid.distributions import Independent
>>> from lucid.distributions import Normal
>>> # Diagonal Normal: batch of D=4 scalars → single 4-d event
>>> base = Normal(lucid.zeros(4), lucid.ones(4))
>>> dist = Independent(base, reinterpreted_batch_ndims=1)
>>> dist.batch_shape, dist.event_shape
((), (4,))
>>> x = dist.rsample()
>>> x.shape # (4,)
(4,)
>>> dist.log_prob(x)  # scalar: sum of 4 log-probs
Methods (10)
__init__
__init__(base_distribution: Distribution, reinterpreted_batch_ndims: int, validate_args: bool | None = None) → None
Initialise an Independent distribution wrapper.
Parameters
base_distribution (Distribution): The distribution to wrap. Its batch_shape must have at least reinterpreted_batch_ndims dimensions.
reinterpreted_batch_ndims (int): Number of rightmost batch dimensions to move into event_shape. Must satisfy 0 <= reinterpreted_batch_ndims <= len(batch_shape).
validate_args (bool | None, default None): If True, validate parameter constraints at construction time.
Raises
ValueError: If reinterpreted_batch_ndims exceeds the number of batch dimensions of base_distribution.
has_rsample
has_rsample: bool
Whether reparameterised sampling is supported.
Mirrors base_dist.has_rsample, so gradient flow is preserved
when the base distribution supports it.
Returns
bool: True if and only if the base distribution supports reparameterised sampling.
support
support: object
Support of the distribution; delegates to base_dist.support.
Returns
object: The support constraint of the underlying base distribution.
mean
mean: Tensor
Mean of the distribution; delegates to base_dist.mean.
Returns
Tensor: Mean with the same shape as base_dist.mean.
mode
mode: Tensor
Mode of the distribution; delegates to base_dist.mode.
Returns
Tensor: Mode with the same shape as base_dist.mode.
variance
variance: Tensor
Variance of the distribution; delegates to base_dist.variance.
Returns
Tensor: Variance with the same shape as base_dist.variance.
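Because variance is delegated with the base's shape, an Independent of D scalar Normals reports a length-D vector of per-component variances, not a D x D covariance matrix. By independence all off-diagonal covariances are zero, so the vector is exactly the diagonal. The hypothetical plain-Python helper below (not a lucid function) expands it for comparison.

```python
def diag_covariance(variance):
    """Hypothetical helper: expand a per-component variance vector into the
    dense covariance matrix it implies (zeros off the diagonal)."""
    d = len(variance)
    return [[variance[i] if i == j else 0.0 for j in range(d)] for i in range(d)]


print(diag_covariance([1.0, 4.0, 0.25]))
# [[1.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 0.25]]
```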
rsample
rsample(sample_shape: tuple[int, ...] = ()) → Tensor
Draw reparameterised samples; delegates to base_dist.rsample.
Parameters
sample_shape (tuple[int, ...], default ()): Shape of the batch of draws, prepended to the output shape.
Returns
Tensor: Samples of shape (*sample_shape, *batch_shape, *event_shape).
sample
sample(sample_shape: tuple[int, ...] = ()) → Tensor
Draw samples; delegates to base_dist.sample.
Parameters
sample_shape (tuple[int, ...], default ()): Shape of the batch of draws, prepended to the output shape.
Returns
Tensor: Samples of shape (*sample_shape, *batch_shape, *event_shape).
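Delegating sample and rsample works because re-interpretation never changes the concatenated shape of a draw. A minimal plain-Python sketch, with `sample_shape_of` as a hypothetical helper, compares a base batch of 8 x 4 scalars with the same base wrapped using reinterpreted_batch_ndims=1.

```python
def sample_shape_of(sample_shape, batch_shape, event_shape):
    # Hypothetical helper: shape of a draw, (*sample_shape, *batch_shape, *event_shape).
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


# Base: batch_shape (8, 4), event_shape (). Wrapped with n = 1: (8,) and (4,).
base = sample_shape_of((3,), (8, 4), ())
wrapped = sample_shape_of((3,), (8,), (4,))
print(base, wrapped)  # (3, 8, 4) (3, 8, 4)
```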
log_prob
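log_prob(value: Tensor) → Tensor
Log joint probability summed over the re-interpreted event dimensions.
Computes the base distribution's log-probabilities and then sums
over the rightmost reinterpreted_batch_ndims axes:
$$\log p(x) = \sum_{i_1, \ldots, i_n} \log p_{\text{base}}(x)_{[\ldots,\, i_1, \ldots, i_n]},$$
where $n$ = reinterpreted_batch_ndims and the indices $i_1, \ldots, i_n$ run over the rightmost $n$ axes of the base log-probability tensor.
Parameters
value (Tensor): Value of shape (*batch_shape, *event_shape).
Returns
Tensor: Log joint probability values of shape batch_shape.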
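The reduction itself can be sketched on nested Python lists standing in for tensors. `sum_rightmost` and its helpers are hypothetical illustrations, not lucid functions, and assume non-empty rectangular nested lists.

```python
def ndim(x):
    # Nesting depth, found by following first elements
    # (assumes a non-empty, rectangular nested list).
    depth = 0
    while isinstance(x, list):
        x = x[0]
        depth += 1
    return depth


def total(x):
    # Sum every leaf of a (possibly nested) list.
    return sum(total(v) for v in x) if isinstance(x, list) else x


def sum_rightmost(values, n):
    """Hypothetical helper: sum the rightmost n axes of a nested list."""
    keep = ndim(values) - n
    if n < 0 or keep < 0:
        raise ValueError("n must be between 0 and the number of axes")

    def collapse(x, depth):
        # Walk down the leading `keep` axes, then sum everything below.
        if depth == keep:
            return total(x)
        return [collapse(v, depth + 1) for v in x]

    return collapse(values, 0)


# Base log-probs with batch_shape (2, 3); wrapping with n = 1 leaves batch_shape (2,).
base_lp = [[-1.0, -2.0, -0.5], [-0.25, -0.75, -1.0]]
print(sum_rightmost(base_lp, 1))  # [-3.5, -2.0]
print(sum_rightmost(base_lp, 2))  # -5.5
```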
entropy
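entropy() → Tensor
Joint entropy summed over the re-interpreted event dimensions.
Because the components are independent:
$$H = \sum_{i_1, \ldots, i_n} H_{\text{base}}[\ldots,\, i_1, \ldots, i_n],$$
where $n$ = reinterpreted_batch_ndims.
Returns
Tensor: Entropy values of shape batch_shape (nats).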
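To see why summing is exact, the plain-Python sketch below compares the summed entropies of three independent Bernoullis with a brute-force entropy of their joint distribution, obtained by enumerating all 2^3 binary vectors with p(x) as the product of the marginals. Both helpers and the probabilities are hypothetical illustrations.

```python
import itertools
import math


def bernoulli_entropy(p):
    # Hypothetical helper: entropy of Bernoulli(p) in nats, with 0 * log 0 taken as 0.
    return -sum(q * math.log(q) for q in (p, 1 - p) if q > 0)


def joint_entropy_bruteforce(probs):
    # Enumerate every binary vector x and use p(x) = prod_d p_d(x_d).
    h = 0.0
    for x in itertools.product((0, 1), repeat=len(probs)):
        px = math.prod(p if xi else 1 - p for xi, p in zip(x, probs))
        if px > 0:
            h -= px * math.log(px)
    return h


probs = [0.1, 0.5, 0.8]  # one event made of three independent Bernoullis
summed_h = sum(bernoulli_entropy(p) for p in probs)
brute_h = joint_entropy_bruteforce(probs)
print(summed_h, brute_h)  # equal up to floating-point rounding
```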