class

RelaxedOneHotCategorical

extends Distribution
RelaxedOneHotCategorical(temperature: Tensor | float, probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None)

Concrete distribution — a continuous relaxation of OneHotCategorical.

RelaxedOneHotCategorical(temperature=τ, probs=p) defines a distribution over the open $K$-simplex whose samples are differentiable surrogates for one-hot categorical samples. As $\tau \to 0$ the distribution concentrates on the $K$ vertices of the simplex (recovering one-hot samples); as $\tau \to \infty$ samples concentrate near the centroid $(1/K, \ldots, 1/K)$ of the simplex.
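
The two temperature limits can be illustrated with a minimal pure-Python sketch. It does not use lucid; `tempered_softmax` and the `scores` values are hypothetical, chosen only to show how dividing a fixed set of perturbed logits by the temperature sharpens or flattens the result.

```python
import math


def tempered_softmax(scores, tau):
    """Softmax of scores / tau, computed stably by subtracting the max."""
    scaled = [s / tau for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical perturbed logits (l_k + g_k) for K = 3, fixed for illustration.
scores = [0.2, 1.5, 0.9]

# Low temperature: almost all mass moves to the largest score (index 1).
sharp = tempered_softmax(scores, tau=0.01)

# High temperature: the output approaches the centroid (1/3, 1/3, 1/3).
smooth = tempered_softmax(scores, tau=100.0)
```

In both cases the output still sums to 1; only how the mass is spread across the $K$ entries changes.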

This is the Gumbel-softmax distribution of Jang et al. (2017) and the Concrete distribution of Maddison et al. (2017). The straight-through estimator (hard=True in lucid.nn.functional.gumbel_softmax, the sampler that rsample delegates to with hard=False) rounds to one-hot in the forward pass while using the soft sample for gradients.

Parameters

temperature: Tensor | float
Temperature parameter $\tau > 0$. Smaller values give sharper (more discrete) samples.
probs: Tensor | None = None
Probability vector(s) of shape (..., K). Normalised internally. Mutually exclusive with logits.
logits: Tensor | None = None
Unnormalised log-probability vector(s) of shape (..., K). Mutually exclusive with probs.
validate_args: bool | None = None
If True, validate parameter constraints at construction time.

Attributes

temperature: Tensor
Temperature $\tau$.
probs: Tensor
Normalised probability vector (present when constructed with probs).
logits: Tensor
Unnormalised log-probability vector (present when constructed with logits).

Notes

Reparameterised sampling (Gumbel-softmax trick):

$$y_k = \frac{\exp\bigl((l_k + g_k)/\tau\bigr)}{\sum_j \exp\bigl((l_j + g_j)/\tau\bigr)}, \quad g_k \overset{\text{iid}}{\sim} \operatorname{Gumbel}(0, 1)$$

The result $y \in \Delta^{K-1}$ (open simplex) is differentiable w.r.t. $l$ and $\tau$.
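
The sampling rule above can be sketched in plain Python. This is not lucid's implementation: `sample_gumbel` and `gumbel_softmax_sample` are hypothetical helpers, and it works on Python lists rather than tensors, so no gradients are involved. It draws Gumbel noise by inverse-CDF sampling and checks empirically that the argmax of each relaxed sample follows the underlying categorical probabilities.

```python
import math
import random


def sample_gumbel(rng):
    """Standard Gumbel(0, 1) draw via the inverse CDF: g = -log(-log(u))."""
    u = rng.random()
    while u == 0.0:  # guard against log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_softmax_sample(logits, tau, rng):
    """One relaxed one-hot sample y_k = softmax((l_k + g_k) / tau)."""
    scores = [(l + sample_gumbel(rng)) / tau for l in logits]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
probs = [0.1, 0.4, 0.5]
logits = [math.log(p) for p in probs]

samples = [gumbel_softmax_sample(logits, tau=0.5, rng=rng) for _ in range(20000)]

# The argmax of each sample is distributed as Categorical(probs)
# (the Gumbel-max property), whatever the temperature.
counts = [0] * len(probs)
for y in samples:
    counts[y.index(max(y))] += 1
freqs = [c / len(samples) for c in counts]
```

With enough draws, `freqs` should be close to `probs`, and every sample sums to 1 up to floating-point error.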

Log-PDF (Maddison et al. 2017, Eq. 1):

$$\log p(y; l, \tau) = \log\Gamma(K) + (K-1)\log\tau + \sum_k \bigl(l_k - (\tau+1)\log y_k\bigr) - K \operatorname{logsumexp}_k\bigl(l_k - \tau\log y_k\bigr)$$
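
As a sanity check on the formula, here is a minimal pure-Python sketch that evaluates it term by term. `concrete_log_prob` is a hypothetical name, not part of lucid, and it takes plain lists rather than tensors. Two properties follow directly from the formula and are easy to verify: for $K = 2$, $\tau = 1$ and equal logits the density is constant (so the log-density is 0), and adding a constant to every logit leaves the density unchanged.

```python
import math


def concrete_log_prob(y, logits, tau):
    """Log-density of the Concrete distribution at a point y on the open simplex."""
    k = len(y)
    log_y = [math.log(v) for v in y]
    # sum_k (l_k - (tau + 1) * log y_k)
    linear = sum(l - (tau + 1.0) * ly for l, ly in zip(logits, log_y))
    # logsumexp_k (l_k - tau * log y_k), computed stably
    terms = [l - tau * ly for l, ly in zip(logits, log_y)]
    m = max(terms)
    lse = m + math.log(sum(math.exp(t - m) for t in terms))
    # log Gamma(K) + (K - 1) log tau + linear - K * logsumexp
    return math.lgamma(k) + (k - 1) * math.log(tau) + linear - k * lse


# K = 2, tau = 1, equal logits: the density is constant, so log p = 0.
lp_uniform = concrete_log_prob([0.3, 0.7], [0.0, 0.0], tau=1.0)

# Shifting all logits by a constant does not change the density.
lp_a = concrete_log_prob([0.2, 0.3, 0.5], [0.1, -0.4, 1.0], tau=0.7)
lp_b = concrete_log_prob([0.2, 0.3, 0.5], [5.1, 4.6, 6.0], tau=0.7)
```

The shift invariance is why unnormalised logits can be passed directly: the $\sum_k l_k$ term and the $K \operatorname{logsumexp}$ term absorb the same constant with opposite signs.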

Examples

>>> import lucid
>>> from lucid.distributions import RelaxedOneHotCategorical
>>> dist = RelaxedOneHotCategorical(
...     temperature=0.5,
...     probs=lucid.tensor([0.1, 0.4, 0.5]),
... )
>>> samples = dist.rsample((50,))
>>> samples.shape  # (50, 3) — lies on the open simplex
(50, 3)
>>> samples.sum(dim=-1)  # each row sums to ~1

Methods (3)

dunder

__init__

None
__init__(temperature: Tensor | float, probs: Tensor | None = None, logits: Tensor | None = None, validate_args: bool | None = None)

Construct a RelaxedOneHotCategorical (Concrete) distribution.

Parameters

temperature: Tensor | float
Temperature $\tau > 0$ controlling relaxation tightness. As $\tau \to 0$, samples concentrate on the vertices of the simplex, recovering a hard one-hot Categorical; as $\tau \to \infty$, samples concentrate near the simplex centroid $(1/K, \ldots, 1/K)$.
probs: Tensor | None = None
Probability vector of shape (..., K) on the K-simplex. Rows are automatically normalised to sum to 1. Mutually exclusive with logits.
logits: Tensor | None = None
Unnormalised log-probabilities of shape (..., K). Mutually exclusive with probs.
validate_args: bool | None = None
If True, validate parameter constraints at construction time.

Raises

ValueError
If both or neither of probs and logits are provided.
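
The "exactly one of probs and logits" rule, together with the row normalisation of probs, can be sketched as follows. `resolve_logits` is a hypothetical helper written for illustration; the error message and the use of plain lists are assumptions, and lucid's actual constructor may be structured differently.

```python
import math


def resolve_logits(probs=None, logits=None):
    """Hypothetical helper: derive logits from exactly one of probs / logits."""
    # Both given or both missing is an error.
    if (probs is None) == (logits is None):
        raise ValueError("Exactly one of `probs` or `logits` must be specified.")
    if logits is not None:
        return list(logits)
    # Normalise probs to sum to 1 before taking logs.
    total = sum(probs)
    return [math.log(p / total) for p in probs]


# Unnormalised weights (1, 3) become log(0.25) and log(0.75).
from_probs = resolve_logits(probs=[1.0, 3.0])
```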
fn

rsample

Tensor
rsample(sample_shape: tuple[int, ...] = ())

Draw a reparameterised sample via the Gumbel-softmax trick.

Delegates to lucid.nn.functional.gumbel_softmax with hard=False, ensuring the output lies strictly inside the open $K$-simplex and gradients propagate through both logits and temperature.

Parameters

sample_shape: tuple[int, ...] = ()
Leading shape of the output sample batch.

Returns

Tensor

Samples on the open simplex of shape (*sample_shape, *batch_shape, K).
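
The output shape rule can be sketched with plain tuples. `rsample_shape` is a hypothetical helper, not a lucid function; it only shows how the three parts of the shape are concatenated.

```python
def rsample_shape(sample_shape, batch_shape, num_categories):
    """Compose (*sample_shape, *batch_shape, K) as a plain tuple."""
    return (*sample_shape, *batch_shape, num_categories)


# Unbatched distribution with K = 3 and 50 draws.
unbatched = rsample_shape((50,), (), 3)

# Batch of 4 distributions, default sample_shape=().
single_draw = rsample_shape((), (4,), 3)

# Batch of 4 distributions, a 2 x 5 grid of draws.
grid = rsample_shape((2, 5), (4,), 3)
```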

fn

log_prob

Tensor
log_prob(value: Tensor)

Log-probability density of the Concrete distribution over the simplex.

Parameters

value: Tensor
Point(s) $y$ on the open $K$-simplex, shape (..., K).

Returns

Tensor

Log-density values, shape (...,) (batch dimensions only).