Dirichlet
Dirichlet(concentration: Tensor, validate_args: bool | None = None)
Bases: ExponentialFamily
Dirichlet distribution on the $(K-1)$-simplex.
Multivariate generalisation of the lucid.distributions.Beta
distribution: a distribution over probability vectors $x \in \mathbb{R}^K$
with $x_i \ge 0$ and
$\sum_{i=1}^{K} x_i = 1$. It is the conjugate prior of the
lucid.distributions.Categorical /
lucid.distributions.Multinomial likelihoods and a foundational
building block in topic models (LDA), Bayesian mixture models, and
population genetics.
The last dimension of concentration is the simplex/event
dimension; all preceding dimensions form the batch shape.
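The shape convention can be made concrete with plain tuples standing in for tensors. This is a minimal sketch: `dirichlet_shapes` is a hypothetical helper, not a lucid function, and it assumes the `sample_shape + batch_shape + event_shape` layout that the `sample` method documents.

```python
def dirichlet_shapes(concentration_shape, sample_shape=()):
    """Hypothetical helper (not part of lucid): split a concentration shape
    into batch and event shapes and compute the shape returned by sampling."""
    if len(concentration_shape) < 1:
        raise ValueError("concentration needs at least one (event) dimension")
    batch_shape = tuple(concentration_shape[:-1])  # all leading dimensions
    event_shape = tuple(concentration_shape[-1:])  # the simplex dimension, (K,)
    sampled = tuple(sample_shape) + batch_shape + event_shape
    return batch_shape, event_shape, sampled


# A batch of 5 Dirichlets over K = 3 categories, sampled with sample_shape (4,).
print(dirichlet_shapes((5, 3), (4,)))  # ((5,), (3,), (4, 5, 3))
```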
Parameters
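concentration : Tensor
    Concentration parameters $\alpha$ of shape (..., K); all entries must be strictly positive.
validate_args : bool | None, default None
    If True, validate parameter constraints at construction time.
Notes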
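Probability density on the $(K-1)$-simplex ($x_i > 0$, $\sum_{i=1}^{K} x_i = 1$):
$$p(x \mid \alpha) = \frac{\Gamma(\alpha_0)}{\prod_{i=1}^{K} \Gamma(\alpha_i)} \prod_{i=1}^{K} x_i^{\alpha_i - 1}, \qquad \alpha_0 = \sum_{i=1}^{K} \alpha_i.$$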
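Moments (with $\alpha_0 = \sum_{k} \alpha_k$, $\bar\alpha_i = \alpha_i / \alpha_0$):
$$\mathbb{E}[x_i] = \bar\alpha_i, \qquad \operatorname{Var}[x_i] = \frac{\bar\alpha_i (1 - \bar\alpha_i)}{\alpha_0 + 1}, \qquad \operatorname{Cov}[x_i, x_j] = -\frac{\bar\alpha_i \bar\alpha_j}{\alpha_0 + 1} \quad (i \ne j).$$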
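As a quick numerical check of the mean, variance and covariance of a Dirichlet distribution, the sketch below evaluates the standard closed forms in plain Python for one concentration vector. `dirichlet_moments` is a hypothetical helper written for illustration, not part of lucid, and it handles a single vector rather than a batched tensor.

```python
def dirichlet_moments(alpha):
    """Hypothetical helper: mean, variance and covariance of Dirichlet(alpha)."""
    if any(a <= 0 for a in alpha):
        raise ValueError("concentration must be strictly positive")
    a0 = sum(alpha)                    # alpha_0, the total concentration
    abar = [a / a0 for a in alpha]     # normalised concentration = mean
    var = [m * (1.0 - m) / (a0 + 1.0) for m in abar]
    k = len(alpha)
    # Cov[x_i, x_j] = (delta_ij * abar_i - abar_i * abar_j) / (alpha_0 + 1)
    cov = [[abar[i] * ((1.0 if i == j else 0.0) - abar[j]) / (a0 + 1.0)
            for j in range(k)] for i in range(k)]
    return abar, var, cov


mean, var, cov = dirichlet_moments([1.0, 2.0, 3.0])
print([round(m, 4) for m in mean])  # [0.1667, 0.3333, 0.5]
```

Each row of the covariance matrix sums to zero, reflecting the constraint that the components always add up to 1; the diagonal reproduces the variances.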
Special cases:
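- $K = 2$ → $\mathrm{Beta}(\alpha_1, \alpha_2)$ (after dropping one redundant coordinate).
- $\alpha = (1, \dots, 1)$ → uniform over the simplex.
- $\alpha = c\,m$ with the mean $m$ on the simplex fixed and $c \to \infty$ → mass concentrates at $m$.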
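The $K = 2$ special case can be checked numerically: a Dirichlet log-density evaluated at $(t, 1 - t)$ should agree with the Beta log-density at $t$. Both functions below are hypothetical stdlib-only helpers built on `math.lgamma`, not lucid APIs.

```python
import math


def dirichlet_log_pdf(x, alpha):
    """Hypothetical helper: log-density of Dirichlet(alpha) at a simplex point x."""
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    return log_norm + sum((a - 1.0) * math.log(xi) for a, xi in zip(alpha, x))


def beta_log_pdf(t, a, b):
    """Hypothetical helper: log-density of Beta(a, b) at t in (0, 1)."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return log_norm + (a - 1.0) * math.log(t) + (b - 1.0) * math.log(1.0 - t)


a, b, t = 2.5, 4.0, 0.3
# The two values agree up to floating-point rounding.
print(dirichlet_log_pdf([t, 1.0 - t], [a, b]), beta_log_pdf(t, a, b))
```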
Conjugacy: observing categorical counts $n = (n_1, \dots, n_K)$ updates $\alpha \to \alpha + n$.
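The conjugate update amounts to adding observed category counts to the concentration vector. A minimal sketch in plain Python: `dirichlet_posterior` is a hypothetical helper, and observations are assumed to be integer category indices in `range(K)`.

```python
from collections import Counter


def dirichlet_posterior(alpha, observations):
    """Hypothetical helper: posterior concentration alpha + counts.

    alpha: prior concentration, one positive float per category.
    observations: iterable of category indices in range(len(alpha)).
    """
    counts = Counter(observations)
    for k in counts:
        if not 0 <= k < len(alpha):
            raise ValueError(f"category {k} out of range")
    return [a + counts[i] for i, a in enumerate(alpha)]


prior = [1.0, 1.0, 1.0]   # uniform prior over the simplex
data = [0, 2, 2, 1, 2, 2]  # six observed categories
print(dirichlet_posterior(prior, data))  # [2.0, 2.0, 5.0]
```

Because the update only adds counts, processing the data in batches or one observation at a time gives the same posterior.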
Sampling uses the normalised-Gamma method: draw independent $Y_i \sim \mathrm{Gamma}(\alpha_i, 1)$ and set $x_i = Y_i / \sum_{j=1}^{K} Y_j$. Samples are detached.
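A stdlib-only sketch of the normalised-Gamma sampler, using `random.gammavariate(alpha, beta)` (where `beta` is the scale) for the Gamma draws. `sample_dirichlet` is a hypothetical helper, not lucid's implementation; it draws one vector per call and does not guard against very small $\alpha_i$, where Gamma draws can underflow to zero.

```python
import random


def sample_dirichlet(alpha, rng=random):
    """Hypothetical helper: one draw from Dirichlet(alpha) via normalised Gammas."""
    # Independent Y_i ~ Gamma(shape=alpha_i, scale=1).
    y = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(y)
    # Dividing by the total projects the vector onto the simplex.
    return [yi / total for yi in y]


rng = random.Random(0)
draws = [sample_dirichlet([1.0, 2.0, 3.0], rng) for _ in range(20000)]
empirical_mean = [sum(d[i] for d in draws) / len(draws) for i in range(3)]
print(empirical_mean)  # approximately [0.1667, 0.3333, 0.5]
```

Averaging many draws should recover the mean $\alpha_i / \sum_j \alpha_j$, which makes this a cheap sanity check for any sampler.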
Examples
>>> import lucid
>>> from lucid.distributions import Dirichlet
>>> d = Dirichlet(lucid.tensor([1.0, 2.0, 3.0]))
>>> d.mean # α / Σ α
Tensor([0.1667, 0.3333, 0.5000])
>>> d.sample((4,))
Tensor([...])
Methods (6)
__init__
__init__(concentration: Tensor, validate_args: bool | None = None) → None
Construct a Dirichlet distribution.
Parameters
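concentration : Tensor
    Concentration parameters $\alpha$ of shape (..., K); all entries must be strictly positive.
validate_args : bool | None, default None
    If True, validate parameter constraints at construction time.
Notes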
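The Dirichlet distribution with concentration $\alpha = (\alpha_1, \dots, \alpha_K)$, $\alpha_i > 0$, has PDF over the $(K-1)$-simplex:
$$p(x) = \frac{1}{B(\alpha)} \prod_{i=1}^{K} x_i^{\alpha_i - 1},$$
where $B(\alpha) = \prod_{i=1}^{K} \Gamma(\alpha_i) \big/ \Gamma\big(\sum_{i=1}^{K} \alpha_i\big)$.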
Sampling uses the normalised-Gamma trick: draw independent $Y_i \sim \mathrm{Gamma}(\alpha_i, 1)$, then return $Y / \sum_{j=1}^{K} Y_j$.
Examples
>>> import lucid
>>> from lucid.distributions import Dirichlet
>>> d = Dirichlet(lucid.tensor([1.0, 2.0, 3.0]))
>>> d.mean # proportional to concentration
Tensor([0.1667, 0.3333, 0.5000])
mean
mean: Tensor
Expected value of the Dirichlet distribution.
Each component of the mean equals the normalised concentration:
$$\mathbb{E}[x_i] = \frac{\alpha_i}{\sum_{j=1}^{K} \alpha_j}.$$
Returns
Tensor
    Mean vector on the simplex, shape batch_shape + event_shape.
Examples
>>> Dirichlet(lucid.tensor([2.0, 2.0])).mean
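Tensor([0.5, 0.5])
variance
variance: Tensor
Variance of the Dirichlet distribution (component-wise):
$$\operatorname{Var}[x_i] = \frac{\alpha_i (\alpha_0 - \alpha_i)}{\alpha_0^2 (\alpha_0 + 1)}, \qquad \alpha_0 = \sum_{j=1}^{K} \alpha_j.$$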
Returns
Tensor
    Variance vector, shape batch_shape + event_shape.
sample
sample(sample_shape: tuple[int, ...] = ()) → Tensor
Draw samples from the Dirichlet distribution.
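Uses the normalised-Gamma method:
$$Y_i \sim \mathrm{Gamma}(\alpha_i, 1) \text{ independently}, \qquad x_i = \frac{Y_i}{\sum_{j=1}^{K} Y_j}.$$
The result lies on the probability simplex and is detached.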
Parameters
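sample_shape : tuple[int, ...], default ()
    Shape of the batch of independent draws, prepended to the output; () draws a single sample.
Returns
Tensor
    Simplex-valued samples of shape
    sample_shape + batch_shape + event_shape.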
Examples
>>> d = Dirichlet(lucid.tensor([1.0, 1.0, 1.0]))
>>> x = d.sample((100,))
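>>> x.sum(dim=-1)  # all ones, up to floating-point rounding
log_prob
log_prob(value: Tensor) → Tensor
Log-density of value under the Dirichlet distribution:
$$\log p(x \mid \alpha) = \log\Gamma(\alpha_0) - \sum_{i=1}^{K} \log\Gamma(\alpha_i) + \sum_{i=1}^{K} (\alpha_i - 1) \log x_i, \qquad \alpha_0 = \sum_{i=1}^{K} \alpha_i.$$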
Parameters
value : Tensor
    Points on the simplex; the last dimension has size K.
Returns
Tensor
    Log-densities, shape batch_shape.
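The log-density can be written directly with `math.lgamma`. This sketch is a hypothetical single-point helper, not lucid's `log_prob`; its rejection of points with a zero coordinate and the `tol` parameter for the sum-to-one check are assumptions of the sketch.

```python
import math


def dirichlet_log_prob(value, alpha, tol=1e-6):
    """Hypothetical helper: log p(value | alpha) for one point on the simplex."""
    if len(value) != len(alpha):
        raise ValueError("value and concentration must have the same length K")
    # Assumption of this sketch: only the open simplex is accepted.
    if any(v <= 0 for v in value) or abs(sum(value) - 1.0) > tol:
        raise ValueError("value must lie in the open simplex")
    a0 = sum(alpha)
    # log of 1 / B(alpha) = log Gamma(alpha_0) - sum_i log Gamma(alpha_i)
    log_norm = math.lgamma(a0) - sum(math.lgamma(a) for a in alpha)
    return log_norm + sum((a - 1.0) * math.log(v) for a, v in zip(alpha, value))


# Dirichlet(1, 1, 1) is uniform with density Gamma(3) = 2 on the simplex.
for point in ([1 / 3, 1 / 3, 1 / 3], [0.7, 0.2, 0.1]):
    print(dirichlet_log_prob(point, [1.0, 1.0, 1.0]))  # log 2 ≈ 0.6931 each time
```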
entropy
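entropy() → Tensor
Shannon entropy of the Dirichlet distribution (in nats):
$$H = \log B(\alpha) + (\alpha_0 - K)\,\psi(\alpha_0) - \sum_{i=1}^{K} (\alpha_i - 1)\,\psi(\alpha_i),$$
where $\alpha_0 = \sum_{i=1}^{K} \alpha_i$, $\log B(\alpha) = \sum_{i=1}^{K} \log\Gamma(\alpha_i) - \log\Gamma(\alpha_0)$, $K$ is the number of categories, and $\psi$ is the digamma function.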
Returns
Tensor
    Entropy in nats, shape batch_shape.
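Python's standard library has no digamma function, so the sketch below implements one (upward recurrence until $x \ge 6$, then the standard asymptotic series) and uses it to evaluate the Dirichlet entropy $\log B(\alpha) + (\alpha_0 - K)\psi(\alpha_0) - \sum_i (\alpha_i - 1)\psi(\alpha_i)$. Both `digamma` and `dirichlet_entropy` are hypothetical helpers for illustration, not lucid APIs.

```python
import math


def digamma(x):
    """Hypothetical helper: psi(x) for x > 0."""
    if x <= 0:
        raise ValueError("this sketch only supports x > 0")
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x  # psi(x) = psi(x + 1) - 1/x
        x += 1.0
    inv2 = 1.0 / (x * x)
    # Asymptotic series: ln x - 1/(2x) - 1/(12x^2) + 1/(120x^4) - ...
    series = inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 * (1 / 252 - inv2 * (1 / 240 - inv2 / 132))))
    return result + math.log(x) - 0.5 / x - series


def dirichlet_entropy(alpha):
    """Hypothetical helper: differential entropy of Dirichlet(alpha) in nats."""
    k = len(alpha)
    a0 = sum(alpha)
    log_b = sum(math.lgamma(a) for a in alpha) - math.lgamma(a0)  # log B(alpha)
    return log_b + (a0 - k) * digamma(a0) - sum((a - 1.0) * digamma(a) for a in alpha)


print(dirichlet_entropy([1.0, 1.0, 1.0]))  # ≈ -0.6931, i.e. -log 2
print(dirichlet_entropy([2.0, 2.0]))       # ≈ -0.1251, the Beta(2, 2) entropy
```

Negative values are expected: differential entropy drops below zero whenever the density is concentrated enough to exceed 1, as with the uniform Dirichlet over three categories (density 2).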