RelaxedOneHotCategorical
Distribution
RelaxedOneHotCategorical(temperature: Tensor | float, probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None)
Concrete distribution: a continuous relaxation of OneHotCategorical.
RelaxedOneHotCategorical(temperature=τ, probs=p) defines a
distribution over the open (K-1)-simplex whose samples are
differentiable surrogates for one-hot categorical samples. As
τ → 0 the distribution concentrates on the
vertices of the simplex (recovering one-hot samples); as
τ → ∞ samples approach the uniform vector (1/K, ..., 1/K) at
the centre of the simplex.
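The two temperature limits can be illustrated without any sampling by holding the perturbed logits fixed and varying only τ. The sketch below uses plain Python rather than lucid; `tempered_softmax` and the `scores` values are hypothetical names and numbers chosen for illustration, not part of the library.

```python
import math

def tempered_softmax(scores, temperature):
    """Softmax of scores / temperature, with the max subtracted for stability."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical perturbed logits (log p_i + g_i) for K = 3, held fixed so that
# only the temperature changes between calls.
scores = [-1.2, 0.3, 0.9]

cold = tempered_softmax(scores, 0.01)    # nearly all mass on the largest score
warm = tempered_softmax(scores, 1.0)     # an interior point of the simplex
hot = tempered_softmax(scores, 1000.0)   # close to (1/3, 1/3, 1/3)
```

At a low temperature the output sits next to a vertex of the simplex, and at a high temperature it sits next to the centre, which matches the limiting behaviour described above.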
This is the Gumbel-softmax distribution of Jang et al. (2017) and the
Concrete distribution of Maddison et al. (2017). The
straight-through estimator (hard=True in the sampler) rounds to
one-hot in the forward pass while using the soft sample for gradients.
Parameters
temperature: Tensor | float
    Temperature τ of the relaxation.
probs: Tensor | None = None
    Shape (..., K). Normalised internally. Mutually exclusive with logits.
logits: Tensor | None = None
    Shape (..., K). Mutually exclusive with probs.
validate_args: bool | None = None
    If True, validate parameter constraints at construction time.
Attributes
temperature: Tensor
probs: Tensor
logits: Tensor
Notes
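Reparameterised sampling (Gumbel-softmax trick): with independent noise g_i ~ Gumbel(0, 1),
$$
y_i = \frac{\exp\big((\log p_i + g_i)/\tau\big)}{\sum_{j=1}^{K} \exp\big((\log p_j + g_j)/\tau\big)}, \qquad i = 1, \dots, K.
$$
The result y (on the open simplex) is differentiable w.r.t. the logits log p and the temperature τ.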
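The Gumbel-softmax trick can be sketched in plain Python as follows. This is an illustration under stated assumptions, not lucid's implementation: `sample_gumbel` and `gumbel_softmax_sample` are hypothetical helper names, and the standard-library `random` module stands in for the library's random number generator.

```python
import math
import random

def sample_gumbel(rng):
    """Standard Gumbel(0, 1) draw via the inverse CDF: g = -log(-log(u))."""
    # random() can return exactly 0.0, so clamp away from zero before the log.
    u = max(rng.random(), 1e-300)
    return -math.log(-math.log(u))

def gumbel_softmax_sample(probs, temperature, rng):
    """One relaxed one-hot sample: softmax((log p_i + g_i) / temperature)."""
    perturbed = [(math.log(p) + sample_gumbel(rng)) / temperature for p in probs]
    m = max(perturbed)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in perturbed]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
probs = [0.1, 0.4, 0.5]
samples = [gumbel_softmax_sample(probs, 0.5, rng) for _ in range(2000)]

# The argmax of each sample is a Gumbel-max draw, so its empirical
# frequencies should track probs regardless of the temperature.
counts = [0, 0, 0]
for y in samples:
    counts[y.index(max(y))] += 1
freqs = [c / len(samples) for c in counts]
```

Because the noise enters only through the additive term g_i, the output is a deterministic, smooth function of the logits and the temperature once the noise is drawn, which is what makes the sample reparameterised.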
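Log-PDF (Maddison et al. 2017, Eq. 1), writing f for the density:
$$
\log f(y) = \log (K-1)! + (K-1)\log\tau + \sum_{i=1}^{K}\big(\log p_i - (\tau+1)\log y_i\big) - K \log \sum_{j=1}^{K} p_j\, y_j^{-\tau}.
$$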
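For reference, the standard Concrete log-density can be evaluated for a single point with the standard library alone. The sketch assumes the usual form of the density from Maddison et al. (2017); `concrete_log_prob` is a hypothetical name, and lucid's own `log_prob` may differ in how it organises the computation.

```python
import math

def concrete_log_prob(value, probs, temperature):
    """Log-density of the Concrete distribution at a point on the open simplex.

    log f(y) = log (K-1)! + (K-1) log tau
               + sum_i (log p_i - (tau + 1) log y_i)
               - K log sum_j p_j * y_j ** (-tau)
    """
    k = len(probs)
    log_norm = math.lgamma(k) + (k - 1) * math.log(temperature)
    # Per-component terms log p_i - tau * log y_i, reused in the log-sum-exp.
    terms = [math.log(p) - temperature * math.log(y)
             for p, y in zip(probs, value)]
    m = max(terms)
    log_sum = m + math.log(sum(math.exp(t - m) for t in terms))
    sum_log_y = sum(math.log(y) for y in value)
    return log_norm + sum(terms) - sum_log_y - k * log_sum

# Sanity check: for K = 2, tau = 1 and equal probabilities the density is
# uniform on the segment, so the log-density is 0 everywhere.
uniform_case = concrete_log_prob([0.3, 0.7], [0.5, 0.5], 1.0)

# The density depends on probs only up to a positive scale factor.
normalised = concrete_log_prob([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.5)
unnormalised = concrete_log_prob([0.2, 0.3, 0.5], [1.0, 4.0, 5.0], 0.5)
```

The scale invariance checked at the end is consistent with the parameter description, which notes that probs is normalised internally.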
Examples
>>> import lucid
>>> from lucid.distributions import RelaxedOneHotCategorical
>>> dist = RelaxedOneHotCategorical(
... temperature=0.5,
... probs=lucid.tensor([0.1, 0.4, 0.5]),
... )
>>> samples = dist.rsample((50,))
>>> samples.shape # (50, 3) — lies on the open simplex
(50, 3)
>>> samples.sum(dim=-1) # each row sums to ~1
Methods (3)
__init__
__init__(temperature: Tensor | float, probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None) → None
Construct a RelaxedOneHotCategorical (Concrete) distribution.
Parameters
temperature: Tensor | float
    Temperature τ of the relaxation.
probs: Tensor | None = None
    Shape (..., K), on the (K-1)-simplex. Rows are automatically normalised to sum to 1. Mutually exclusive with logits.
logits: Tensor | None = None
    Shape (..., K). Mutually exclusive with probs.
validate_args: bool | None = None
    If True, validate parameter constraints at construction time.
Raises
ValueError
    If both probs and logits are provided.
rsample
rsample(sample_shape: tuple[int, ...] = ()) → Tensor
Draw a reparameterised sample via the Gumbel-softmax trick.
Delegates to lucid.nn.functional.gumbel_softmax with
hard=False, so the output lies strictly inside the open
(K-1)-simplex and gradients propagate through both logits and
temperature.
Parameters
sample_shape: tuple[int, ...] = ()
Returns
Tensor
    Samples on the open simplex of shape (*sample_shape, *batch_shape, K).
log_prob
log_prob(value: Tensor) → Tensor
Log-probability density of the Concrete distribution over the simplex.
Parameters
value: Tensor
    Shape (..., K).
Returns
Tensor
    Log-density values, shape (...,) (batch dimensions only).