OneHotCategorical
Distribution
OneHotCategorical(probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None)
Categorical distribution with one-hot encoded samples.
OneHotCategorical wraps a Categorical and returns samples
as one-hot vectors of shape (..., K) instead of integer indices.
It is particularly useful for:
- REINFORCE-style gradient estimators where you need a discrete sample but want to use it in differentiable downstream computation.
- Relaxations: replacing it with RelaxedOneHotCategorical gives a differentiable approximation that converges to one-hot as the temperature $\tau \to 0$.
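The temperature limit mentioned in the last bullet can be illustrated with a plain-Python sketch. It assumes the usual Gumbel-softmax style of relaxation, where a softmax is applied to logits divided by a temperature; the random Gumbel noise is left out here so the effect of the temperature alone is visible. The helper `tempered_softmax` is hypothetical and is not part of lucid.

```python
import math

def tempered_softmax(logits, temperature):
    """Softmax of logits / temperature (hypothetical helper, not a lucid API)."""
    scaled = [value / temperature for value in logits]
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Logits corresponding to probabilities [0.1, 0.5, 0.4].
logits = [math.log(0.1), math.log(0.5), math.log(0.4)]

for tau in (1.0, 0.1, 0.01):
    relaxed = tempered_softmax(logits, tau)
    print(tau, [round(p, 4) for p in relaxed])
```

At a temperature of 1 the output reproduces the original probabilities; as the temperature shrinks, the mass concentrates on the most likely category and the vector approaches a one-hot encoding.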
Parameters
probs : Tensor | None = None
    Shape (..., K). Normalised internally. Mutually exclusive with logits.
logits : Tensor | None = None
    Shape (..., K). Mutually exclusive with probs.
validate_args : bool | None = None
    If True, validate parameter constraints at construction time.
Attributes
probs : Tensor
logits : Tensor
Notes
Samples are integer-valued one-hot vectors in $\{0, 1\}^K$ with exactly one 1 at the sampled category index.
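Log-probability for a one-hot vector $\mathbf{x}$:
$$\log p(\mathbf{x}) = \sum_{k=1}^{K} x_k \log p_k$$
where $p_k$ is the probability of category $k$. Because exactly one $x_k$ equals 1, this is simply the log-probability of the selected category.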
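As a small worked check of this identity, the sketch below evaluates the sum over categories in plain Python and compares it with the log-probability of the selected category. The function name `one_hot_log_prob` is illustrative only and does not mirror lucid's internals.

```python
import math

def one_hot_log_prob(value, probs):
    """Sum of x_k * log p_k over categories (illustrative, not lucid's code)."""
    # Zero entries are skipped so that log(0) is never evaluated for them.
    return sum(x * math.log(p) for x, p in zip(value, probs) if x != 0)

probs = [0.1, 0.5, 0.4]
x = [0.0, 1.0, 0.0]  # one-hot vector selecting category index 1

lp = one_hot_log_prob(x, probs)
print(lp, math.log(probs[1]))  # the two values agree
```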
Entropy equals that of the underlying Categorical:
$$H(p) = -\sum_{k=1}^{K} p_k \log p_k$$
The event_shape is (K,) whereas Categorical has
event_shape = ().
Examples
>>> import lucid
>>> from lucid.distributions import OneHotCategorical
>>> dist = OneHotCategorical(probs=lucid.tensor([0.1, 0.5, 0.4]))
>>> sample = dist.sample()
>>> sample.shape # (3,) — one-hot
(3,)
>>> sample.sum() # always 1
Methods (5)
__init__
__init__(probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None) → None
Initialise a OneHotCategorical distribution.
Parameters
probs : Tensor | None = None
    Shape (..., K). Normalised internally. Mutually exclusive with logits.
logits : Tensor | None = None
    Shape (..., K). Mutually exclusive with probs.
validate_args : bool | None = None
    If True, validate parameter constraints at construction time.
Raises
ValueError
    If both probs and logits are provided.
support
support: Constraint
Support of the distribution: the probability simplex.
Returns
Constraint
    The simplex constraint, as each sample is a one-hot vector whose entries are non-negative and sum to 1.
sample
sample(sample_shape: tuple[int, ...] = ()) → Tensor
Draw one-hot encoded samples.
Internally samples category indices from the underlying
Categorical distribution and converts them to one-hot
vectors via one_hot.
Parameters
sample_shape : tuple[int, ...] = ()
Returns
Tensor
    Float tensor of shape (*sample_shape, *batch_shape, K) containing one-hot vectors (exactly one 1 per row).
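The two-step procedure described above (draw category indices, then one-hot encode them) can be sketched with the Python standard library alone. This is a conceptual illustration for an unbatched distribution, not lucid's implementation; `one_hot` and `sample_one_hot` are hypothetical helpers written for this example.

```python
import random

def one_hot(index, num_classes):
    """Float vector of length num_classes with a single 1.0 at index."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec

def sample_one_hot(probs, n, rng):
    """Draw n category indices with the given weights, then one-hot encode them."""
    k = len(probs)
    indices = rng.choices(range(k), weights=probs, k=n)
    return [one_hot(i, k) for i in indices]

rng = random.Random(0)  # seeded generator for reproducibility
samples = sample_one_hot([0.1, 0.5, 0.4], 4, rng)

# The nested list has shape (4, 3), i.e. (*sample_shape, K) with an empty batch_shape.
for row in samples:
    print(row)
```

Each row sums to 1 and contains only 0.0 and 1.0, matching the "exactly one 1 per row" property of the returned tensor.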
log_prob
log_prob(value: Tensor) → Tensor
Log-probability of a one-hot encoded sample.
For a one-hot vector $\mathbf{x}$ (all zeros except a 1 at position $k$):
$$\log p(\mathbf{x}) = \sum_{j=1}^{K} x_j \log p_j = \log p_k$$
Parameters
value : Tensor
    Shape (..., K) with float dtype.
Returns
Tensor
    Log-probabilities of shape batch_shape.
entropy
entropy() → Tensor
Shannon entropy of the OneHotCategorical distribution.
Equal to the entropy of the underlying Categorical:
$$H(p) = -\sum_{k=1}^{K} p_k \log p_k$$
Returns
Tensor
    Entropy values of shape batch_shape (nats).
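For a quick numerical feel of the entropy in nats, the following standard-library sketch evaluates the Shannon entropy of a single (unbatched) probability vector. `categorical_entropy` is an illustrative helper, not a lucid function, and it assumes the probabilities are already normalised.

```python
import math

def categorical_entropy(probs):
    """Shannon entropy in nats; zero-probability terms contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)

h = categorical_entropy([0.1, 0.5, 0.4])
h_uniform = categorical_entropy([1 / 3, 1 / 3, 1 / 3])

print(h)          # roughly 0.943 nats
print(h_uniform)  # equals log(3), the maximum for K = 3
```

The uniform distribution attains the largest entropy for a given K, so any other probability vector over three categories yields a smaller value.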