Categorical
Distribution
Categorical(probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None)
Categorical distribution: a discrete distribution over K labelled outcomes.
Categorical(probs=p) or Categorical(logits=l) defines a distribution
over the integer set $\{0, 1, \ldots, K-1\}$, where $K$ is the
number of categories. Exactly one of probs or logits must be given.
Parameters
probs (Tensor | None, default None): Probabilities, shape (..., K). Rows are automatically normalised to sum to 1. Mutually exclusive with logits.
logits (Tensor | None, default None): Logits, shape (..., K). The distribution uses softmax internally to convert them to normalised probabilities. Mutually exclusive with probs.
validate_args (bool | None, default None): If True, validate parameter constraints at construction time.
Attributes
probs (Tensor): Normalised probabilities, shape (..., K); present when constructed with probs.
logits (Tensor): Logits, shape (..., K); present when constructed with logits.
Notes
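PMF: $P(X = k) = p_k$ for $k \in \{0, \ldots, K-1\}$.
Parameterisations are related by: $p_k = \mathrm{softmax}(l)_k = \exp(l_k) / \sum_{j=0}^{K-1} \exp(l_j)$.
Entropy: $H(X) = -\sum_{k=0}^{K-1} p_k \log p_k$.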
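The relationship between the two parameterisations can be checked with a small stdlib-only sketch. The helper name softmax below is illustrative and is not part of the lucid API; the sketch assumes the standard softmax definition and shows that adding a constant to every logit leaves the probabilities unchanged.

```python
import math

def softmax(logits):
    """Convert a list of logits to normalised probabilities (illustrative helper)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 0.5, -1.0]
probs = softmax(logits)

# Adding a constant to every logit leaves the probabilities unchanged.
shifted = softmax([l + 10.0 for l in logits])

# log(probs) is itself a valid set of logits for the same distribution.
roundtrip = softmax([math.log(p) for p in probs])
```

Because logits are only defined up to an additive constant, two different logit tensors can describe the same distribution.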
Mean is not well-defined for a general Categorical (the labels have no
canonical metric), so mean returns a NaN tensor of the batch shape.
Sampling uses the Gumbel-max trick: add i.i.d. Gumbel(0, 1) noise to the log-probabilities and take the argmax. This is equivalent in distribution to ancestral sampling and avoids the cumulative sum and binary search.
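A minimal stdlib-only sketch of the Gumbel-max trick follows. It is not lucid's implementation; the function name gumbel_max_sample and the use of random.Random are assumptions made for illustration, and the noise is assumed to be standard Gumbel(0, 1), generated as -log(-log(u)) for uniform u.

```python
import math
import random

def gumbel_max_sample(log_probs, rng):
    """Return one category index drawn via the Gumbel-max trick (illustrative)."""
    best_index, best_score = 0, -math.inf
    for k, log_p in enumerate(log_probs):
        u = rng.random()
        while u == 0.0:                      # keep u strictly inside (0, 1)
            u = rng.random()
        gumbel = -math.log(-math.log(u))     # Gumbel(0, 1) draw
        score = log_p + gumbel
        if score > best_score:               # running argmax
            best_index, best_score = k, score
    return best_index

probs = [0.1, 0.2, 0.7]
log_probs = [math.log(p) for p in probs]
rng = random.Random(0)

n_draws = 20000
counts = [0] * len(probs)
for _ in range(n_draws):
    counts[gumbel_max_sample(log_probs, rng)] += 1
freqs = [c / n_draws for c in counts]  # empirical frequencies, close to probs
```

The empirical frequencies approach the target probabilities as the number of draws grows, which is the sense in which the trick matches direct sampling.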
The batch dimensions of the input correspond to independent distributions.
For example, probs of shape (B, K) yields a batch of B
Categorical distributions.
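To make the batch semantics concrete, here is a stdlib-only sketch using nested lists in place of tensors. The helper normalise_rows is hypothetical, and dividing each row by its sum is an assumption about how row normalisation behaves, not a statement about lucid's internals.

```python
def normalise_rows(batch):
    """Rescale each row of a (B, K) nested list so it sums to 1 (illustrative)."""
    out = []
    for row in batch:
        total = sum(row)
        out.append([w / total for w in row])
    return out

weights = [[1.0, 1.0, 2.0],   # row 0: unnormalised weights over K = 3 categories
           [3.0, 0.0, 1.0]]   # row 1: a second, independent distribution
row_probs = normalise_rows(weights)

batch_size = len(row_probs)        # B = 2 independent distributions
num_categories = len(row_probs[0]) # K = 3
```

Each row is normalised on its own, so changing one row never affects the distribution described by another.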
Examples
>>> import lucid
>>> from lucid.distributions import Categorical
>>> # Uniform over 4 categories
>>> dist = Categorical(probs=lucid.tensor([0.25, 0.25, 0.25, 0.25]))
>>> samples = dist.sample((10,))
>>> # Batch of 2 distributions
>>> dist_b = Categorical(logits=lucid.zeros(2, 5))
>>> dist_b.batch_shape, dist_b.event_shape
((2,), ())
Methods (6)
__init__
__init__(probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None) → None
Initialise a Categorical distribution.
Parameters
probs (Tensor | None, default None): Probabilities, shape (..., K). Rows are automatically normalised to sum to 1. Mutually exclusive with logits.
logits (Tensor | None, default None): Logits, shape (..., K). Converted to probabilities via softmax internally. Mutually exclusive with probs.
validate_args (bool | None, default None): If True, validate parameter constraints at construction time.
Raises
ValueError: If both probs and logits are provided.
support
support: Constraint
Support of the distribution: the integer interval $\{0, \ldots, K-1\}$.
Returns
Constraint: An integer_interval constraint from 0 to K - 1.
mean
mean: Tensor
Mean of the Categorical distribution (undefined; returns NaN).
The Categorical distribution assigns labels with no inherent ordering
or metric, so the mean is not well-defined. This property returns a
NaN tensor of the batch shape to match expected behaviour.
Returns
Tensor: Tensor of float('nan') values with shape batch_shape.
sample
sample(sample_shape: tuple[int, ...] = ()) → Tensor
Draw samples from the Categorical distribution.
Uses the Gumbel-max trick: add i.i.d. Gumbel(0, 1) noise to the log-probabilities and take the argmax, which is equivalent in distribution to ancestral sampling but avoids the cumulative sum and binary search.
Parameters
sample_shape (tuple[int, ...], default ()): Shape prepended to batch_shape in the result.
Returns
Tensor: Integer tensor of shape (*sample_shape, *batch_shape) with
values in $\{0, \ldots, K-1\}$. The result is detached
(no gradients flow through discrete samples).
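For contrast with the Gumbel-max approach described above, the cumulative-sum and binary-search method that it avoids can be sketched with the standard library. The function name inverse_cdf_sample is hypothetical, and this is a reference sketch of the alternative technique, not code from lucid.

```python
import bisect
import itertools
import random

def inverse_cdf_sample(probs, rng):
    """Draw one category index by inverting the cumulative distribution (illustrative)."""
    cdf = list(itertools.accumulate(probs))   # cumulative sum of the probabilities
    u = rng.random() * cdf[-1]                # scale by the total to absorb rounding
    k = bisect.bisect_right(cdf, u)           # binary search for the first cdf[k] > u
    return min(k, len(probs) - 1)             # guard against a rounding edge case

icdf_probs = [0.1, 0.2, 0.7]
icdf_rng = random.Random(1)

icdf_draws = 20000
icdf_counts = [0] * len(icdf_probs)
for _ in range(icdf_draws):
    icdf_counts[inverse_cdf_sample(icdf_probs, icdf_rng)] += 1
icdf_freqs = [c / icdf_draws for c in icdf_counts]  # close to icdf_probs
```

Both methods target the same distribution; the Gumbel-max form replaces the sequential cumulative sum with an elementwise noise addition followed by an argmax.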
log_prob
log_prob(value: Tensor) → Tensor
Log-probability of the given category indices.
For a category index $k$, the log-probability is: $\log P(X = k) = \log p_k = l_k - \log \sum_{j=0}^{K-1} \exp(l_j)$.
Parameters
value (Tensor): Category indices of shape batch_shape. Values must be in $\{0, \ldots, K-1\}$.
Returns
Tensor: Log-probabilities of shape batch_shape.
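As a sketch of how a log-probability can be computed from logits without forming the probabilities first, the following stdlib-only example uses the log-sum-exp identity. The helper name categorical_log_prob is hypothetical and does not correspond to a lucid function.

```python
import math

def categorical_log_prob(logits, k):
    """Log-probability of category k given a list of logits (illustrative)."""
    m = max(logits)  # shift by the max so the exponentials cannot overflow
    log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
    return logits[k] - log_norm

example_logits = [2.0, 0.5, -1.0]
log_ps = [categorical_log_prob(example_logits, k) for k in range(3)]

# Uniform logits over 4 categories give log(1/4) for every index.
uniform_lp = categorical_log_prob([0.0, 0.0, 0.0, 0.0], 2)
```

Exponentiating the values in log_ps and summing should recover 1 up to rounding, which is a quick sanity check on any log_prob implementation.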
entropy
entropy() → Tensor
Shannon entropy of the Categorical distribution.
Returns
Tensor: Entropy values of shape batch_shape, in nats.
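A short stdlib-only sketch of Shannon entropy in nats, using the usual convention that zero-probability categories contribute nothing. The helper entropy_nats is hypothetical; dividing by log 2 converts the result to bits.

```python
import math

def entropy_nats(probs):
    """Shannon entropy in nats; zero-probability terms are skipped (illustrative)."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

uniform_entropy = entropy_nats([0.25, 0.25, 0.25, 0.25])  # equals log(4)
point_mass_entropy = entropy_nats([1.0, 0.0, 0.0])        # no uncertainty
uniform_bits = uniform_entropy / math.log(2)               # nats to bits
```

A uniform distribution over K categories attains the maximum entropy log K, while a point mass has entropy zero.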