class OneHotCategorical (extends Distribution)

OneHotCategorical(probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None)

Categorical distribution with one-hot encoded samples.

OneHotCategorical wraps a Categorical and returns samples as one-hot vectors of shape (..., K) instead of integer indices. It is particularly useful for:

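  • REINFORCE-style (score-function) gradient estimators, where a discrete sample is fed into downstream computation (for example, multiplied with an embedding matrix) and gradients are obtained through log_prob rather than through the sample itself.
  • Relaxations: replacing it with RelaxedOneHotCategorical gives a differentiable approximation that converges to a one-hot vector as the temperature $\tau \to 0$.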
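
A plain-Python sketch of the relaxation in the second bullet follows. It assumes RelaxedOneHotCategorical uses the Gumbel-softmax (Concrete) construction, as relaxed one-hot distributions in other libraries do; `relaxed_one_hot` is a hypothetical helper, not part of lucid, and Python lists stand in for tensors.

```python
import math
import random


def relaxed_one_hot(probs, tau, rng=random):
    """Draw one Gumbel-softmax (Concrete) sample: softmax((log p + g) / tau).

    Hypothetical helper for illustration. Assumes strictly positive probs;
    g_k are i.i.d. Gumbel(0, 1) draws.
    """
    scores = []
    for p in probs:
        u = max(rng.random(), 1e-300)   # avoid log(0) if random() returns 0.0
        g = -math.log(-math.log(u))     # Gumbel(0, 1) via the inverse CDF
        scores.append((math.log(p) + g) / tau)
    # Numerically stable softmax: subtract the largest score before exponentiating.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


probs = [0.1, 0.5, 0.4]
for tau in (5.0, 1.0, 0.01):
    # Same seed each time, so the Gumbel noise is identical and only tau differs.
    print(tau, [round(v, 3) for v in relaxed_one_hot(probs, tau, random.Random(0))])
```

Because the noise is held fixed, the position of the largest entry never moves between lines; only its sharpness changes, from nearly uniform at large tau to nearly one-hot at small tau.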

Parameters

probs : Tensor | None = None
Non-negative probability vector of shape (..., K). Normalised internally. Mutually exclusive with logits.
logits : Tensor | None = None
Unnormalised log-probabilities of shape (..., K). Mutually exclusive with probs.
validate_args : bool | None = None
If True, validate parameter constraints at construction time.
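
The two parametrisations describe the same distribution when the logits equal log(probs) up to an additive constant. A plain-Python sketch, assuming probs are normalised by dividing by their sum and logits by a softmax (the usual convention; `normalise_probs` and `softmax` are illustrative helpers, not lucid functions):

```python
import math


def normalise_probs(probs):
    """Rescale non-negative weights so they sum to 1."""
    total = sum(probs)
    return [p / total for p in probs]


def softmax(logits):
    """Map unnormalised log-probabilities to probabilities (numerically stable form)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


weights = [0.2, 1.0, 0.8]
from_probs = normalise_probs(weights)                           # approximately [0.1, 0.5, 0.4]
from_logits = softmax([math.log(w) + 3.0 for w in weights])     # a constant shift does not matter
print(from_probs)
print(from_logits)
```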

Attributes

probs : Tensor
Normalised probability vector (present when constructed with probs).
logits : Tensor
Unnormalised log-probability vector (present when constructed with logits).

Notes

Samples are one-hot vectors in $\{0, 1\}^K$ (0/1 entries stored in a float tensor) with exactly one 1 at the sampled category index.

Log-probability for a one-hot vector $e_k$, whose $j$-th entry $e_{kj}$ is 1 if $j = k$ and 0 otherwise:

$$\log P(X = e_k) = \sum_{j} e_{kj} \log p_j = \log p_k$$

which is simply the log-probability of the selected category.
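
The identity can be checked numerically with a plain-Python sketch in which lists stand in for tensors (`one_hot` and `one_hot_log_prob` are illustrative names, not lucid's implementation):

```python
import math


def one_hot(k, K):
    """Return the length-K one-hot vector e_k as a list of floats."""
    return [1.0 if j == k else 0.0 for j in range(K)]


def one_hot_log_prob(value, probs):
    """log P(X = value) = sum_j value[j] * log p_j for a one-hot `value`."""
    # Skip zero entries so 0 * log(0) is never evaluated when some p_j = 0.
    return sum(v * math.log(p) for v, p in zip(value, probs) if v != 0.0)


probs = [0.1, 0.5, 0.4]
for k in range(len(probs)):
    # The dot product picks out exactly log p_k.
    print(k, one_hot(k, 3), one_hot_log_prob(one_hot(k, 3), probs), math.log(probs[k]))
```

Skipping the zero entries is a design choice for the sketch: a literal dense dot product against a log-probability vector that contains $\log 0 = -\infty$ would produce $0 \cdot (-\infty)$, which is NaN in floating point.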

Entropy equals that of the underlying Categorical:

$$H[X] = -\sum_{k} p_k \log p_k$$
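
A short numerical check of the formula as a plain-Python sketch (`categorical_entropy` is an illustrative name, not a lucid function). It uses the usual convention $0 \log 0 = 0$, which is an assumption about how zero probabilities are treated:

```python
import math


def categorical_entropy(probs):
    """H = -sum_k p_k log p_k in nats, skipping p_k = 0 (convention 0 * log 0 = 0)."""
    return sum(-p * math.log(p) for p in probs if p > 0.0)


print(categorical_entropy([0.1, 0.5, 0.4]))   # about 0.9433 nats
print(categorical_entropy([1/3, 1/3, 1/3]))   # approximately log(3), the maximum for K = 3
print(categorical_entropy([1.0, 0.0, 0.0]))   # 0.0: a deterministic outcome
```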

The event_shape is (K,) whereas Categorical has event_shape = ().
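
Concretely, a sample batch has shape sample_shape + batch_shape + event_shape, which gives the (*sample_shape, *batch_shape, K) shape documented for sample(). A small pure-Python helper (hypothetical, for illustration only) makes the difference between the two distributions explicit:

```python
def sample_shapes(sample_shape, batch_shape, K):
    """Return (categorical_shape, one_hot_shape) of sample(sample_shape) as tuples."""
    categorical = tuple(sample_shape) + tuple(batch_shape)   # Categorical: event_shape = ()
    one_hot = categorical + (K,)                             # OneHotCategorical: event_shape = (K,)
    return categorical, one_hot


print(sample_shapes((5,), (2,), 3))   # ((5, 2), (5, 2, 3))
```

In both cases log_prob reduces over the event dimensions, so its result has the categorical shape.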

Examples

>>> import lucid
>>> from lucid.distributions import OneHotCategorical
>>> dist = OneHotCategorical(probs=lucid.tensor([0.1, 0.5, 0.4]))
>>> sample = dist.sample()
>>> sample.shape  # (3,) — one-hot
(3,)
>>> sample.sum()  # always 1

Methods (5)

__init__(probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None) -> None

Initialise a OneHotCategorical distribution.

Parameters

probs : Tensor | None = None
Non-negative probability vector of shape (..., K). Normalised internally. Mutually exclusive with logits.
logits : Tensor | None = None
Unnormalised log-probabilities of shape (..., K). Mutually exclusive with probs.
validate_args : bool | None = None
If True, validate parameter constraints at construction time.

Raises

ValueError
If both or neither of probs and logits are provided.
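
The documented rule amounts to an exclusive-or on the two arguments. A minimal sketch; `check_parameters` and its error message are hypothetical and need not match lucid's actual check:

```python
def check_parameters(probs=None, logits=None):
    """Return which parametrisation was supplied, enforcing that exactly one is given."""
    # (probs is None) == (logits is None) holds exactly when both or neither were passed.
    if (probs is None) == (logits is None):
        raise ValueError("exactly one of probs or logits must be provided")
    return "probs" if probs is not None else "logits"


for kwargs in ({"probs": [0.5, 0.5]}, {"logits": [0.0, 0.0]}, {}, {"probs": [1.0], "logits": [0.0]}):
    try:
        print(kwargs, "->", check_parameters(**kwargs))
    except ValueError as err:
        print(kwargs, "-> ValueError:", err)
```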

support: Constraint (property)

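Support of the distribution, expressed as the probability simplex.

Returns

Constraint

The simplex constraint. Every sample is a one-hot vector whose entries are non-negative and sum to 1, so it lies in the simplex; the simplex is, however, a superset of the actual support, which consists only of the K one-hot vectors at its vertices.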
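
To see why the simplex is a valid but loose description, compare a simplex-membership test with a one-hot test in plain Python (both helpers are illustrative and are not lucid's Constraint API):

```python
import math


def in_simplex(x, tol=1e-6):
    """True if all entries are non-negative and they sum to 1 (within tol)."""
    return all(v >= 0.0 for v in x) and math.isclose(sum(x), 1.0, abs_tol=tol)


def is_one_hot(x):
    """True if x has exactly one entry equal to 1 and all other entries equal to 0."""
    return sorted(x) == [0.0] * (len(x) - 1) + [1.0]


sample = [0.0, 1.0, 0.0]     # a vertex of the simplex: a valid one-hot sample
interior = [0.2, 0.5, 0.3]   # inside the simplex, but never produced by sample()
print(in_simplex(sample), is_one_hot(sample))       # True True
print(in_simplex(interior), is_one_hot(interior))   # True False
```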

sample(sample_shape: tuple[int, ...] = ()) -> Tensor

Draw one-hot encoded samples.

Internally, it samples category indices from the underlying Categorical distribution and converts them to one-hot vectors via one_hot.
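
The two-step procedure (draw an index, then one-hot encode it) can be sketched in plain Python with inverse-CDF sampling. This illustrates the technique only; it is not lucid's implementation, which delegates to Categorical and one_hot, and `sample_one_hot` is a hypothetical name:

```python
import bisect
import itertools
import random


def sample_one_hot(probs, rng=random):
    """Draw one index from Categorical(probs) by inverse CDF, then one-hot encode it."""
    cdf = list(itertools.accumulate(probs))
    u = rng.random() * cdf[-1]          # scale by the total in case probs do not sum to 1
    k = bisect.bisect_right(cdf, u)     # first index whose cumulative mass exceeds u
    k = min(k, len(probs) - 1)          # guard against floating-point round-off at the top
    return [1.0 if j == k else 0.0 for j in range(len(probs))]


rng = random.Random(0)
counts = [0, 0, 0]
for _ in range(10_000):
    x = sample_one_hot([0.1, 0.5, 0.4], rng)
    counts = [c + v for c, v in zip(counts, x)]   # summing one-hot rows counts each category
print(counts)   # roughly proportional to [0.1, 0.5, 0.4]
```

Summing one-hot samples along the sample axis recovers category counts directly, which is one practical reason to prefer this encoding over integer indices.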

Parameters

sample_shape : tuple[int, ...] = ()
Leading shape of the output sample batch.

Returns

Tensor

Float tensor of shape (*sample_shape, *batch_shape, K) containing one-hot vectors (exactly one 1 per row).

log_prob(value: Tensor) -> Tensor

Log-probability of a one-hot encoded sample.

For a one-hot vector $e_k$ (all zeros except a 1 at position $k$):

$$\log P(X = e_k) = \sum_j e_{kj} \log p_j = \log p_k$$

Parameters

value : Tensor
One-hot tensor of shape (..., K) with float dtype.

Returns

Tensor

Log-probabilities of shape batch_shape.

entropy() -> Tensor

Shannon entropy of the OneHotCategorical distribution.

Equal to the entropy of the underlying Categorical:

$$H[X] = -\sum_{k=0}^{K-1} p_k \log p_k$$

Returns

Tensor

Entropy values of shape batch_shape (nats).