gumbel_softmax

gumbel_softmax(logits: Tensor, tau: float = 1.0, hard: bool = False, dim: int = -1) -> Tensor

Gumbel-Softmax — differentiable relaxation of categorical sampling.

Draws Gumbel noise, adds it to the logits, and applies a softmax at temperature $\tau$ (that is, to the perturbed logits divided by $\tau$). This lets gradients flow through what would otherwise be a non-differentiable categorical sample, which makes it central to differentiable discrete-action RL, VQ-VAE codebook training, and Concrete distributions (Jang et al., 2017; Maddison et al., 2017).
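For a single 1-D list of logits, the three steps above can be sketched in plain Python as follows. This is a hedged illustration, not lucid's code: the helper name `gumbel_softmax_1d`, the clamp constant `eps`, and the use of the standard-library `random` module are assumptions, and only the forward value is computed (no gradients).

```python
import math
import random


def gumbel_softmax_1d(logits, tau=1.0, rng=None, eps=1e-10):
    """Forward-only Gumbel-Softmax sample for a single list of logits.

    Hypothetical helper for illustration; it mirrors the three steps
    (noise, add, tempered softmax) but is not lucid's implementation.
    """
    if tau <= 0:
        raise ValueError("tau must be positive")
    if rng is None:
        rng = random.Random()
    perturbed = []
    for logit in logits:
        # Clamp U into [eps, 1 - eps] so neither logarithm sees 0.
        u = min(max(rng.random(), eps), 1.0 - eps)
        g = -math.log(-math.log(u))  # one Gumbel(0, 1) draw
        perturbed.append((logit + g) / tau)
    # Stable softmax: subtracting the max does not change the result.
    m = max(perturbed)
    exps = [math.exp(z - m) for z in perturbed]
    total = sum(exps)
    return [e / total for e in exps]


y = gumbel_softmax_1d([1.0, 2.0, 3.0], tau=0.5, rng=random.Random(0))
print(y, sum(y))
```

Subtracting the maximum before exponentiating matters most at small $\tau$, where $(\text{logit}_i + g_i)/\tau$ can be large enough to overflow `math.exp`.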

Parameters

logits : Tensor
Unnormalised log-probabilities of any shape; the axis selected by dim (the last axis by default) indexes the categorical alternatives.
tau : float = 1.0
Temperature $\tau > 0$. Smaller $\tau$ brings the relaxed sample closer to a one-hot vector, at the cost of higher gradient variance. Default 1.0.
hard : bool = False
If True, the forward pass returns a true one-hot vector while the backward pass uses the gradient of the soft sample (straight-through estimator). Default False.
dim : int = -1
Dimension over which to take the softmax. Default -1.
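The temperature trade-off described for `tau` is easiest to see with the Gumbel noise held fixed. The sketch below uses a hand-picked noise vector (a hypothetical stand-in for one Gumbel draw, so the output is deterministic) and a helper `tempered_softmax` that is illustrative only, not part of lucid.

```python
import math


def tempered_softmax(scores, tau):
    """Stable softmax of scores / tau (hypothetical helper)."""
    scaled = [s / tau for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 2.0, 3.0]
noise = [0.3, -0.2, 0.1]  # fixed stand-in for one Gumbel(0, 1) draw
perturbed = [logit + g for logit, g in zip(logits, noise)]

for tau in (5.0, 1.0, 0.1):
    y = tempered_softmax(perturbed, tau)
    print(tau, [round(v, 4) for v in y])
```

The largest entry grows toward 1 as $\tau$ falls, so the output approaches the one-hot vector at the argmax of the perturbed logits. The gradient-variance cost is not visible here, since this sketch computes forward values only.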

Returns

Tensor
Sample with the same shape as logits. Entries sum to 1 along dim; in hard mode each slice along dim is one-hot.
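Dividing by $\tau > 0$ and taking a softmax both preserve the ordering of the perturbed logits, so the one-hot produced in hard mode depends only on the Gumbel noise, not on $\tau$. The forward-only sketch below checks this in plain Python; `tempered_softmax`, `one_hot_argmax`, and `gumbel` are hypothetical helpers, and the straight-through gradient is not modelled.

```python
import math
import random


def tempered_softmax(scores, tau):
    """Stable softmax of scores / tau (hypothetical helper)."""
    scaled = [s / tau for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot_argmax(y):
    """One-hot vector marking the largest entry of y (hypothetical helper)."""
    k = max(range(len(y)), key=lambda i: y[i])
    return [1.0 if i == k else 0.0 for i in range(len(y))]


def gumbel(rng, eps=1e-10):
    """One Gumbel(0, 1) draw as -log(-log U), with U clamped into [eps, 1 - eps]."""
    u = min(max(rng.random(), eps), 1.0 - eps)
    return -math.log(-math.log(u))


rng = random.Random(0)
logits = [1.0, 2.0, 3.0]
results = []
for _ in range(5):
    # One noise vector, shared by both temperatures.
    perturbed = [logit + gumbel(rng) for logit in logits]
    cold = one_hot_argmax(tempered_softmax(perturbed, 0.1))
    warm = one_hot_argmax(tempered_softmax(perturbed, 10.0))
    results.append((cold, warm))
    print(cold, warm, cold == warm)
```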

Notes

$$g_i \sim \text{Gumbel}(0, 1), \qquad y_i = \frac{\exp\big((\log p_i + g_i)/\tau\big)}{\sum_j \exp\big((\log p_j + g_j)/\tau\big)}$$
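Taking $\tau \to 0$ in the formula above concentrates $y$ on $\arg\max_i(\log p_i + g_i)$, and the Gumbel-max trick states that this index is distributed exactly as $\text{Categorical}(p)$. The Monte Carlo sketch below checks that claim with the standard library only. The seed, the sample count, and the helper name `gumbel` are arbitrary choices, and it uses raw logits in place of $\log p_i$, which shifts every entry by the same constant and so leaves the argmax unchanged.

```python
import math
import random
from collections import Counter


def gumbel(rng, eps=1e-10):
    """One Gumbel(0, 1) draw as -log(-log U), with U clamped into [eps, 1 - eps]."""
    u = min(max(rng.random(), eps), 1.0 - eps)
    return -math.log(-math.log(u))


logits = [1.0, 2.0, 3.0]
m = max(logits)
z = sum(math.exp(logit - m) for logit in logits)
probs = [math.exp(logit - m) / z for logit in logits]  # softmax(logits)

rng = random.Random(0)
n = 20000
counts = Counter()
for _ in range(n):
    perturbed = [logit + gumbel(rng) for logit in logits]
    # Hard (tau -> 0) sample: index of the largest perturbed logit.
    counts[max(range(len(perturbed)), key=perturbed.__getitem__)] += 1

freqs = [counts[i] / n for i in range(len(logits))]
print([round(p, 3) for p in probs])
print([round(f, 3) for f in freqs])
```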

Gumbel samples are drawn as $g = -\log(-\log U)$ with $U \sim \mathrm{Uniform}(0, 1)$, where $U$ is clamped away from 0 and 1 to avoid $\log 0$. Because the softmax is unchanged when the same constant is added to every input, the raw logits can stand in for $\log p_i$. In hard mode the result is $y_{\text{hard}} - y_{\text{soft}}.\text{detach}() + y_{\text{soft}}$, where $y_{\text{hard}}$ is the one-hot vector at the argmax of $y_{\text{soft}}$; this evaluates to $y_{\text{hard}}$ in the forward pass, while its gradient is that of $y_{\text{soft}}$.
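In hard mode the gradient that reaches the logits is the gradient of $y_{\text{soft}}$. With the noise held fixed, writing $\ell$ for the logits and $y = y_{\text{soft}}$, that Jacobian is $\partial y_i / \partial \ell_j = \tfrac{1}{\tau}\, y_i(\delta_{ij} - y_j)$. The sketch below is a hypothetical stdlib-only check, not anything lucid exposes: it compares this closed form against central finite differences, using a hand-picked noise vector as a stand-in for one Gumbel draw.

```python
import math


def soft_sample(logits, noise, tau):
    """y_soft for fixed Gumbel noise: softmax((logits + noise) / tau)."""
    scaled = [(logit + g) / tau for logit, g in zip(logits, noise)]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def analytic_jacobian(y, tau):
    """J[i][j] = dy_i / dlogit_j = y_i * (delta_ij - y_j) / tau."""
    n = len(y)
    return [[y[i] * ((1.0 if i == j else 0.0) - y[j]) / tau for j in range(n)]
            for i in range(n)]


def numeric_jacobian(logits, noise, tau, h=1e-6):
    """Central finite differences of soft_sample with respect to each logit."""
    n = len(logits)
    jac = [[0.0] * n for _ in range(n)]
    for j in range(n):
        up = list(logits)
        up[j] += h
        down = list(logits)
        down[j] -= h
        y_up = soft_sample(up, noise, tau)
        y_down = soft_sample(down, noise, tau)
        for i in range(n):
            jac[i][j] = (y_up[i] - y_down[i]) / (2 * h)
    return jac


logits = [1.0, 2.0, 3.0]
noise = [0.3, -0.2, 0.1]  # fixed stand-in for one Gumbel(0, 1) draw
tau = 0.5
y = soft_sample(logits, noise, tau)
a = analytic_jacobian(y, tau)
b = numeric_jacobian(logits, noise, tau)
max_err = max(abs(a[i][j] - b[i][j]) for i in range(3) for j in range(3))
print(max_err)
```

Each row of the Jacobian sums to zero, reflecting that adding the same constant to every logit leaves the output unchanged.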

Examples

>>> import lucid
>>> from lucid.nn.functional import gumbel_softmax
>>> logits = lucid.tensor([[1.0, 2.0, 3.0]])
>>> y = gumbel_softmax(logits, tau=0.5, hard=True)
>>> y.sum().item()
1.0