class

Dirichlet

extends ExponentialFamily
Dirichlet(concentration: Tensor, validate_args: bool | None = None)

Dirichlet distribution on the $K$-simplex.

Multivariate generalisation of the lucid.distributions.Beta distribution: a distribution over probability vectors $\mathbf{x} \in \Delta^{K-1}$ with $x_i \geq 0$ and $\sum_i x_i = 1$. It is the conjugate prior of the lucid.distributions.Categorical / lucid.distributions.Multinomial likelihoods and a foundational building block in topic models (LDA), Bayesian mixture models, and population genetics.

The last dimension of concentration is the simplex/event dimension; all preceding dimensions form the batch shape.
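The shape convention can be illustrated with plain tuples. This is a sketch only: split_shapes and sample_output_shape are hypothetical helpers, not part of lucid, and the composition sample_shape + batch_shape + event_shape mirrors what the sample method documents.

```python
def split_shapes(concentration_shape):
    """Split a concentration shape (..., K) into (batch_shape, event_shape)."""
    if len(concentration_shape) == 0:
        raise ValueError("concentration must have at least one dimension (K)")
    # Everything except the last axis is batch; the last axis is the simplex dimension K.
    return tuple(concentration_shape[:-1]), tuple(concentration_shape[-1:])


def sample_output_shape(sample_shape, concentration_shape):
    """Shape produced by sample(sample_shape): sample_shape + batch_shape + event_shape."""
    batch_shape, event_shape = split_shapes(concentration_shape)
    return tuple(sample_shape) + batch_shape + event_shape


print(split_shapes((2, 5, 3)))               # ((2, 5), (3,))
print(sample_output_shape((4,), (2, 5, 3)))  # (4, 2, 5, 3)
```

A concentration of shape (2, 5, 3) therefore describes a 2 × 5 batch of independent Dirichlets, each over 3 categories.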

Parameters

concentration : Tensor
Concentration vector $\boldsymbol{\alpha}$ with all entries $\alpha_i > 0$. Shape (..., K).
validate_args : bool | None = None
If True, validate parameter constraints at construction time.

Notes

Probability density on the $K$-simplex ($x_i > 0,\; \sum_i x_i = 1$):

$$p(\mathbf{x}; \boldsymbol{\alpha}) = \frac{1}{B(\boldsymbol{\alpha})} \prod_{i=1}^{K} x_i^{\alpha_i - 1}, \qquad B(\boldsymbol{\alpha}) = \frac{\prod_i \Gamma(\alpha_i)}{\Gamma(\alpha_0)}, \quad \alpha_0 = \sum_i \alpha_i$$

Moments (with $\alpha_0 = \sum_i \alpha_i$, $\mu_i = \alpha_i/\alpha_0$):

$$\mathbb{E}[X_i] = \mu_i, \qquad \mathrm{Var}[X_i] = \frac{\mu_i (1 - \mu_i)}{\alpha_0 + 1}, \qquad \mathrm{Cov}[X_i, X_j] = -\frac{\mu_i \mu_j}{\alpha_0 + 1} \;\;(i \neq j)$$
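The moment formulas translate directly into code. The following is a minimal plain-Python sketch (standard library only, lists instead of lucid tensors); dirichlet_moments is a hypothetical helper, not a lucid function.

```python
def dirichlet_moments(alpha):
    """Return (mean, covariance matrix) of Dirichlet(alpha) from the closed-form moments."""
    if any(a <= 0 for a in alpha):
        raise ValueError("all concentration entries must be > 0")
    a0 = sum(alpha)                  # alpha_0
    mu = [a / a0 for a in alpha]     # mu_i = alpha_i / alpha_0
    k = len(alpha)
    cov = [[0.0] * k for _ in range(k)]
    for i in range(k):
        for j in range(k):
            if i == j:
                cov[i][j] = mu[i] * (1 - mu[i]) / (a0 + 1)   # Var[X_i]
            else:
                cov[i][j] = -mu[i] * mu[j] / (a0 + 1)        # Cov[X_i, X_j], i != j
    return mu, cov


mean, cov = dirichlet_moments([1.0, 2.0, 3.0])
print([round(m, 4) for m in mean])  # [0.1667, 0.3333, 0.5]
```

Each row of the covariance matrix sums to zero, as it must: the components are constrained to sum to 1, so their total has zero variance.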

Special cases:

  • $K = 2$ → $\mathrm{Beta}(\alpha_1, \alpha_2)$ (after dropping one redundant coordinate).
  • $\boldsymbol{\alpha} = \mathbf{1}$ → uniform over the simplex.
  • $\alpha_0 \to \infty$ with fixed $\boldsymbol{\mu}$ → mass concentrates at $\boldsymbol{\mu}$.
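The uniform case follows by substituting $\boldsymbol{\alpha} = \mathbf{1}$ into the density: every exponent $\alpha_i - 1$ vanishes and only the normaliser remains, so the density is the constant $(K-1)!$ (for example, $2$ when $K = 3$). This also shows that the simplex has volume $1/(K-1)!$ with respect to the first $K-1$ coordinates.

```latex
\boldsymbol{\alpha} = \mathbf{1}:\qquad
\alpha_0 = K, \qquad
B(\mathbf{1}) = \frac{\Gamma(1)^K}{\Gamma(K)} = \frac{1}{(K-1)!}, \qquad
p(\mathbf{x}; \mathbf{1}) = (K-1)! \prod_{i=1}^{K} x_i^{0} = (K-1)!
```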

Conjugacy: observing categorical counts $\mathbf{n} = (n_1, \ldots, n_K)$ updates $\mathrm{Dirichlet}(\boldsymbol{\alpha}) \to \mathrm{Dirichlet}(\boldsymbol{\alpha} + \mathbf{n})$.
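The conjugate update is just "add the observed counts to the concentration". A minimal sketch in plain Python, assuming observations arrive as integer category indices; dirichlet_posterior is a hypothetical helper, not part of lucid.

```python
from collections import Counter


def dirichlet_posterior(alpha, observations):
    """Posterior concentration alpha + n after observing categorical draws.

    alpha: prior concentration, a list of K positive floats.
    observations: iterable of category indices in range(K).
    """
    k = len(alpha)
    counts = Counter(observations)  # n_i = number of times category i was observed
    for c in counts:
        if not 0 <= c < k:
            raise ValueError(f"category index {c} out of range for K={k}")
    return [alpha[i] + counts.get(i, 0) for i in range(k)]


print(dirichlet_posterior([1.0, 1.0, 1.0], [0, 2, 2, 1, 2]))  # [2.0, 2.0, 4.0]
```

Starting from the uniform prior, one observation each of categories 0 and 1 and three of category 2 give the posterior $\mathrm{Dirichlet}(2, 2, 4)$.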

Sampling uses the normalised-Gamma method: draw independent $G_i \sim \mathrm{Gamma}(\alpha_i, 1)$ and set $X_i = G_i / \sum_j G_j$. Samples are detached.
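The normalised-Gamma method can be sketched with the standard library alone. This is an illustration of the technique, not lucid's sampler; sample_dirichlet is a hypothetical name. It relies on Python's random.gammavariate(alpha, beta), whose second argument is the scale, so beta = 1.0 gives $\mathrm{Gamma}(\alpha_i, 1)$. For very small $\alpha_i$ all Gamma draws can underflow to zero, which this sketch does not guard against.

```python
import random


def sample_dirichlet(alpha, rng=random):
    """Draw one sample from Dirichlet(alpha) by normalising independent Gamma variates."""
    if any(a <= 0 for a in alpha):
        raise ValueError("all concentration entries must be > 0")
    # G_i ~ Gamma(shape=alpha_i, scale=1)
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    # X_i = G_i / sum_j G_j lies on the simplex
    return [gi / total for gi in g]


rng = random.Random(0)
x = sample_dirichlet([1.0, 2.0, 3.0], rng)
print(abs(sum(x) - 1.0) < 1e-12)  # True
```

Passing a seeded random.Random instance keeps draws reproducible without touching the global generator.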

Examples

>>> import lucid
>>> from lucid.distributions import Dirichlet
>>> d = Dirichlet(lucid.tensor([1.0, 2.0, 3.0]))
>>> d.mean  # α / Σ α
Tensor([0.1667, 0.3333, 0.5000])
>>> d.sample((4,))
Tensor([...])

Methods (6)

dunder

__init__

None
__init__(concentration: Tensor, validate_args: bool | None = None)

Construct a Dirichlet distribution.

Parameters

concentration : Tensor
Concentration vector $\boldsymbol{\alpha}$ with all entries $> 0$. The last dimension is the event (simplex) dimension $K$; all preceding dimensions form the batch shape.
validate_args : bool | None = None
If True, validate parameter constraints at construction time.

Notes

The Dirichlet distribution with concentration $\boldsymbol{\alpha}$ has PDF over the $K$-simplex:

$$p(\mathbf{x}; \boldsymbol{\alpha}) = \frac{1}{B(\boldsymbol{\alpha})} \prod_{i=1}^{K} x_i^{\alpha_i - 1}$$

where $B(\boldsymbol{\alpha}) = \prod_i \Gamma(\alpha_i) / \Gamma(\sum_i \alpha_i)$.

Sampling uses the normalised-Gamma trick: draw independent $G_i \sim \text{Gamma}(\alpha_i, 1)$, then return $\mathbf{x} = \mathbf{G} / \sum_i G_i$.

Examples

>>> import lucid
>>> from lucid.distributions import Dirichlet
>>> d = Dirichlet(lucid.tensor([1.0, 2.0, 3.0]))
>>> d.mean  # proportional to concentration
Tensor([0.1667, 0.3333, 0.5000])
prop

mean

Tensor
mean: Tensor

Expected value of the Dirichlet distribution.

Each component of the mean equals the normalised concentration:

$$\mathbb{E}[X_i] = \frac{\alpha_i}{\sum_j \alpha_j}$$

Returns

Tensor

Mean vector on the simplex, shape batch_shape + event_shape.

Examples

>>> Dirichlet(lucid.tensor([2.0, 2.0])).mean
Tensor([0.5, 0.5])
prop

variance

Tensor
variance: Tensor

Variance of the Dirichlet distribution (component-wise).

$$\operatorname{Var}[X_i] = \frac{\mu_i (1 - \mu_i)}{\alpha_0 + 1}, \quad \alpha_0 = \sum_j \alpha_j,\; \mu_i = \alpha_i / \alpha_0$$

Returns

Tensor

Variance vector, shape batch_shape + event_shape.

fn

sample

Tensor
sample(sample_shape: tuple[int, ...] = ())

Draw samples from the Dirichlet distribution.

Uses the normalised-Gamma method:

$$\mathbf{x} = \frac{\mathbf{g}}{\sum_i g_i}, \quad g_i \sim \text{Gamma}(\alpha_i, 1)$$

The result lies on the probability simplex and is detached.

Parameters

sample_shape : tuple[int, ...] = ()
Leading shape dimensions for the sample batch. Default is ().

Returns

Tensor

Simplex-valued samples of shape sample_shape + batch_shape + event_shape.

Examples

>>> d = Dirichlet(lucid.tensor([1.0, 1.0, 1.0]))
>>> x = d.sample((100,))
>>> x.sum(dim=-1)  # each of the 100 rows sums to 1 (up to rounding)
fn

log_prob

Tensor
log_prob(value: Tensor)

Log-density of value under the Dirichlet distribution.

$$\log p(\mathbf{x}; \boldsymbol{\alpha}) = \sum_i (\alpha_i - 1) \log x_i - \log B(\boldsymbol{\alpha})$$
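Computing $\log B(\boldsymbol{\alpha})$ through log-Gamma values avoids overflow in $\Gamma$ for large concentrations. A minimal plain-Python sketch of the log-density for a single interior point of the simplex (dirichlet_log_prob is a hypothetical helper; it rejects boundary points rather than returning infinities):

```python
import math


def dirichlet_log_prob(x, alpha):
    """log p(x; alpha) = sum_i (alpha_i - 1) log x_i - log B(alpha)."""
    if len(x) != len(alpha):
        raise ValueError("x and alpha must have the same length K")
    if any(xi <= 0 for xi in x) or abs(sum(x) - 1.0) > 1e-9:
        raise ValueError("x must lie in the interior of the simplex")
    # log B(alpha) = sum_i lgamma(alpha_i) - lgamma(alpha_0)
    log_b = sum(math.lgamma(a) for a in alpha) - math.lgamma(sum(alpha))
    return sum((a - 1.0) * math.log(xi) for a, xi in zip(alpha, x)) - log_b


print(round(dirichlet_log_prob([0.2, 0.3, 0.5], [1.0, 1.0, 1.0]), 6))  # 0.693147 (= log 2)
```

With $\boldsymbol{\alpha} = \mathbf{1}$ and $K = 3$ the result is $\log 2$ at every interior point, consistent with the constant uniform density $(K-1)! = 2$.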

Parameters

value : Tensor
Simplex-valued observations; the last dimension has size $K$.

Returns

Tensor

Log-densities, shape batch_shape.

fn

entropy

Tensor
entropy()

Shannon entropy of the Dirichlet distribution (in nats).

$$H = \log B(\boldsymbol{\alpha}) + (\alpha_0 - K)\, \psi(\alpha_0) - \sum_i (\alpha_i - 1)\, \psi(\alpha_i)$$

where $\alpha_0 = \sum_i \alpha_i$, $K$ is the number of categories, and $\psi$ is the digamma function.
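The entropy formula needs the digamma function, which Python's math module does not provide. The sketch below supplies one with a standard approach (recurrence $\psi(x) = \psi(x+1) - 1/x$ to shift the argument to at least 6, then the asymptotic series); digamma and dirichlet_entropy are hypothetical helpers, not lucid functions, and a library such as scipy.special.digamma would normally be used instead.

```python
import math


def digamma(x):
    """psi(x) for x > 0 via upward recurrence and the asymptotic expansion."""
    if x <= 0:
        raise ValueError("this sketch only handles x > 0")
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x   # psi(x) = psi(x + 1) - 1/x
        x += 1.0
    inv2 = 1.0 / (x * x)
    # psi(x) ~ ln x - 1/(2x) - 1/(12 x^2) + 1/(120 x^4) - 1/(252 x^6)
    result += math.log(x) - 0.5 / x - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))
    return result


def dirichlet_entropy(alpha):
    """H = log B(alpha) + (alpha_0 - K) psi(alpha_0) - sum_i (alpha_i - 1) psi(alpha_i), in nats."""
    k = len(alpha)
    a0 = sum(alpha)
    log_b = sum(math.lgamma(a) for a in alpha) - math.lgamma(a0)
    return log_b + (a0 - k) * digamma(a0) - sum((a - 1.0) * digamma(a) for a in alpha)


print(round(dirichlet_entropy([1.0, 1.0, 1.0]), 4))  # -0.6931 (= -log 2)
```

The negative value is expected: this is a differential entropy, and the uniform Dirichlet on the 3-simplex has density 2 everywhere, giving $H = -\log 2$.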

Returns

Tensor

Entropy in nats, shape batch_shape.