Multinomial
Distribution
Multinomial(total_count: Tensor | int = 1, probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None)
Multinomial distribution over categories.
Multivariate generalisation of lucid.distributions.Bernoulli / lucid.distributions.Binomial: the joint distribution of the counts $\mathbf{k} = (k_1, \ldots, k_K)$ obtained when $n$ independent draws are made from a Categorical with probabilities $\mathbf{p} = (p_1, \ldots, p_K)$. The $k_i$ are non-negative integers summing to $n$.
The event dimension is the last axis of probs / logits; all
preceding axes form the batch shape.
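The last-axis convention, together with the probs / logits handling described under Parameters, can be illustrated with a small plain-Python sketch. This is not the lucid implementation: nested lists stand in for tensors (outer levels are batch axes, the innermost level is the event axis of length K), and the helper names normalise_probs, softmax and resolve_probs are hypothetical.

```python
import math


def normalise_probs(probs):
    """Rescale the innermost (event) axis so it sums to 1."""
    if isinstance(probs[0], list):
        # Outer levels are batch axes: recurse until the event axis is reached.
        return [normalise_probs(row) for row in probs]
    total = sum(probs)
    return [p / total for p in probs]


def softmax(logits):
    """Convert logits on the innermost (event) axis into probabilities."""
    if isinstance(logits[0], list):
        return [softmax(row) for row in logits]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def resolve_probs(probs=None, logits=None):
    """Exactly one of probs / logits may be given (they are mutually exclusive)."""
    if (probs is None) == (logits is None):
        raise ValueError("pass exactly one of probs or logits")
    return normalise_probs(probs) if probs is not None else softmax(logits)
```

For example, resolve_probs(probs=[[2.0, 3.0, 5.0], [1.0, 1.0, 2.0]]) describes a batch of two distributions over K = 3 categories, with each row normalised independently.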
Parameters
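- total_count (int or Tensor, default 1): number of trials $n$. The default of 1 gives a single one-hot Categorical sample.
- probs (Tensor, default None): event probabilities of shape (..., K), normalised internally to sum to 1 along the last axis. Mutually exclusive with logits.
- logits (Tensor, default None): unnormalised log-probabilities of shape (..., K), converted to probabilities via softmax. Mutually exclusive with probs.
- validate_args (bool, default None): if True, validate parameter constraints at construction time.
Notes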
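Probability mass function (joint over the count vector $\mathbf{k}$ with $\sum_{i=1}^{K} k_i = n$):
$$P(\mathbf{k}) = \frac{n!}{k_1! \cdots k_K!} \prod_{i=1}^{K} p_i^{k_i}$$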
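In practice this mass function is evaluated in log space so that the factorials do not overflow. The sketch below is a plain-Python illustration, not lucid's log_prob; the name multinomial_log_prob is hypothetical. It uses math.lgamma(x + 1) = log(x!) and skips categories with zero counts so that a zero probability there does not produce 0 · log 0.

```python
import math


def multinomial_log_prob(counts, probs):
    """log P(k) = log n! - sum_i log k_i! + sum_i k_i log p_i."""
    n = sum(counts)
    # Log of the multinomial coefficient; lgamma(x + 1) equals log(x!).
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(k + 1) for k in counts)
    log_terms = 0.0
    for k, p in zip(counts, probs):
        if k == 0:
            continue  # contributes 0 whatever p is (0 * log 0 taken as 0)
        if p == 0.0:
            return -math.inf  # a count in a zero-probability category is impossible
        log_terms += k * math.log(p)
    return log_coef + log_terms
```

As a sanity check, exponentiating this value and summing over every count vector with the same total $n$ should give 1.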
Moments (per category):
$$\mathbb{E}[k_i] = n p_i, \qquad \operatorname{Var}[k_i] = n p_i (1 - p_i)$$
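For small $n$ and $K$ the per-category moments can be checked by brute force, enumerating every count vector in the support and weighting it by its probability. The following is a plain-Python sketch with hypothetical helper names (pmf, exact_moments); it is only practical for tiny problems because it loops over $(n+1)^K$ candidate vectors.

```python
import math
from itertools import product


def pmf(counts, probs):
    """n! / (k_1! ... k_K!) * prod_i p_i^k_i."""
    coef = math.factorial(sum(counts))
    for k in counts:
        coef //= math.factorial(k)  # each partial division is exact
    return coef * math.prod(p ** k for p, k in zip(probs, counts))


def exact_moments(n, probs):
    """Mean and variance of each k_i by exhaustive enumeration of the support."""
    K = len(probs)
    mean = [0.0] * K
    second = [0.0] * K
    for counts in product(range(n + 1), repeat=K):
        if sum(counts) != n:
            continue  # outside the support
        w = pmf(counts, probs)
        for i, k in enumerate(counts):
            mean[i] += w * k
            second[i] += w * k * k
    var = [s - m * m for s, m in zip(second, mean)]
    return mean, var
```

With n = 5 and p = (0.2, 0.3, 0.5), the enumerated means should match $n p_i$ = (1.0, 1.5, 2.5) and the variances $n p_i (1 - p_i)$ = (0.8, 1.05, 1.25) up to floating-point error.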
Special cases:
- $K = 2$ → lucid.distributions.Binomial (for $k_1$, with $k_2 = n - k_1$)
- $n = 1$ → one-hot lucid.distributions.Categorical
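Both special cases can be confirmed numerically from the mass function. This plain-Python sketch (hypothetical helpers multinomial_pmf and binomial_pmf, no lucid calls) checks that with two categories the count vector $(k, n - k)$ has the Binomial$(n, p)$ probability of $k$, and that with $n = 1$ each one-hot vector has probability $p_i$.

```python
import math


def multinomial_pmf(counts, probs):
    coef = math.factorial(sum(counts))
    for k in counts:
        coef //= math.factorial(k)
    return coef * math.prod(p ** k for p, k in zip(probs, counts))


def binomial_pmf(k, n, p):
    return math.comb(n, k) * p ** k * (1 - p) ** (n - k)


# K = 2: the Multinomial over (k, n - k) matches Binomial(n, p) at k.
n, p = 6, 0.3
for k in range(n + 1):
    assert math.isclose(multinomial_pmf([k, n - k], [p, 1 - p]), binomial_pmf(k, n, p))

# n = 1: the support is the set of one-hot vectors, each with probability p_i.
probs = [0.2, 0.3, 0.5]
for i, p_i in enumerate(probs):
    one_hot = [1 if j == i else 0 for j in range(len(probs))]
    assert math.isclose(multinomial_pmf(one_hot, probs), p_i)
```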
Conjugate prior: lucid.distributions.Dirichlet. Observing counts $\mathbf{k}$ updates
Dirichlet(α) → Dirichlet(α + k).
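The conjugate update is simply element-wise addition of the observed counts to the concentration vector. The sketch below works on plain lists and uses hypothetical helper names (dirichlet_posterior, dirichlet_mean); it does not use or describe any lucid API.

```python
def dirichlet_posterior(alpha, counts):
    """Dirichlet(alpha) prior + observed Multinomial counts -> Dirichlet(alpha + counts)."""
    if len(alpha) != len(counts):
        raise ValueError("alpha and counts must have the same length K")
    return [a + k for a, k in zip(alpha, counts)]


def dirichlet_mean(alpha):
    """Mean of a Dirichlet distribution: alpha_i / sum(alpha)."""
    total = sum(alpha)
    return [a / total for a in alpha]
```

Starting from a uniform prior Dirichlet(1, 1, 1) and observing counts (2, 3, 5) gives Dirichlet(3, 4, 6), whose mean is (3/13, 4/13, 6/13). Because the update is additive, observing several batches of counts one after another gives the same posterior as observing their sum at once.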
Sampling is non-reparameterised (has_rsample = False) and
implemented by summing one-hot Categorical draws.
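The "sum of one-hot Categorical draws" strategy can be sketched in plain Python as follows. This is an illustration under the assumption that probs is already normalised; it is not lucid's sample method, and the helper names sample_categorical and sample_multinomial are hypothetical. The result is a vector of integer counts with no gradient path back to probs, which is consistent with has_rsample = False.

```python
import random


def sample_categorical(probs, rng):
    """Draw one category index by inverse-CDF sampling."""
    u = rng.random()  # uniform in [0, 1)
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    # Floating-point round-off can leave the CDF total slightly below 1;
    # fall back to the last category that has non-zero probability.
    return max(i for i, p in enumerate(probs) if p > 0)


def sample_multinomial(total_count, probs, rng=None):
    """Sum total_count one-hot Categorical draws into a count vector."""
    rng = rng if rng is not None else random.Random()
    counts = [0] * len(probs)
    for _ in range(total_count):
        # Adding a one-hot vector is the same as incrementing one count.
        counts[sample_categorical(probs, rng)] += 1
    return counts
```

For example, sample_multinomial(10, [0.2, 0.3, 0.5], random.Random(0)) returns three non-negative counts summing to 10. This loop costs O(n · K) per sample, so the approach is the simplest one rather than the fastest for large total_count.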
Examples
>>> import lucid
>>> from lucid.distributions import Multinomial
>>> d = Multinomial(total_count=10, probs=lucid.tensor([0.2, 0.3, 0.5]))
>>> d.mean # n * p
Tensor([2., 3., 5.])
>>> d.sample((4,))
Tensor([...])
Methods (7)
__init__
__init__(total_count: Tensor | int = 1, probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None) → None
Initialise a Multinomial distribution.
Parameters
- total_count (Tensor | int, default 1): number of trials $n$. The default of 1 reduces to a one-hot Categorical.
- probs (Tensor | None, default None): event probabilities of shape (..., K), normalised internally to sum to 1. Mutually exclusive with logits.
- logits (Tensor | None, default None): unnormalised log-probabilities of shape (..., K), converted to probabilities via softmax. Mutually exclusive with probs.
- validate_args (bool | None, default None): if True, validate parameter constraints at construction time.
support
support: Constraint
Support of the distribution: non-negative integers.
Returns
Constraint: nonnegative_integer. The multinomial event count vector lives in $\mathbb{Z}_{\geq 0}^{K}$, subject to the additional constraint that the counts sum to total_count.
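A valid outcome therefore has to pass two checks: an element-wise one (what nonnegative_integer expresses) and the joint sum condition. The plain-Python sketch below, using a hypothetical helper in_support, checks both; it is only an illustration and does not describe how lucid's Constraint objects are evaluated.

```python
def in_support(counts, total_count):
    """Check both conditions a Multinomial outcome must satisfy."""
    # Element-wise: every count is a non-negative integer (integer-valued
    # floats such as 3.0 are accepted, since count tensors are often float).
    elementwise_ok = all(k >= 0 and float(k).is_integer() for k in counts)
    # Joint: the counts across the K categories must sum to total_count.
    return elementwise_ok and sum(counts) == total_count
```

For instance, (2, 3, 5) is in the support for total_count = 10, while (2, 3, 4) passes the element-wise check but fails the sum condition.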
total_count
total_count: Tensor
Total number of trials per Multinomial draw.
Returns
Tensor: the integer-valued n parameter, broadcast over the batch shape. The counts of each independent Multinomial sum to this many trials across the K categories.
mean
mean: Tensor
n · p (element-wise).
variance
variance: Tensor
n · p · (1 − p) (element-wise).
log_prob
log_prob(value: Tensor) → Tensor
log C(n; k₁, …, k_K) + Σᵢ kᵢ log pᵢ, where C(n; k₁, …, k_K) = n! / (k₁! ⋯ k_K!) is the multinomial coefficient.
sample
sample(sample_shape: tuple[int, ...] = ()) → Tensor
Draw samples by summing one-hot Categorical draws.