gumbel_softmax(logits: Tensor, tau: float = 1.0, hard: bool = False, dim: int = -1) -> Tensor
Gumbel-Softmax: differentiable relaxation of categorical sampling.
Draws Gumbel noise $g$, adds it to the logits, and applies a softmax at temperature $\tau$, i.e. $y = \mathrm{softmax}\big((\text{logits} + g)/\tau\big)$. This lets gradients flow through what would otherwise be a non-differentiable categorical sample, which is central to differentiable discrete-action RL, VQ-VAE codebook training, and Concrete distributions (Jang et al., 2017; Maddison et al., 2017).
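As an illustration of the soft forward pass, here is a minimal pure-Python sketch for a single row of logits, assuming plain lists instead of tensors. The helper names (`sample_gumbel`, `gumbel_softmax_1d`) and the clamp value `eps = 1e-10` are hypothetical choices for this sketch, not lucid's internals.

```python
from __future__ import annotations

import math
import random


def sample_gumbel(rng: random.Random, eps: float = 1e-10) -> float:
    """Draw one standard Gumbel(0, 1) sample via g = -log(-log(u))."""
    # Clamp u into [eps, 1 - eps] so neither log() ever receives 0.
    u = min(max(rng.random(), eps), 1.0 - eps)
    return -math.log(-math.log(u))


def softmax(z: list[float]) -> list[float]:
    """Numerically stable softmax over a 1-D list."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_1d(logits: list[float], tau: float = 1.0,
                      rng: random.Random | None = None) -> list[float]:
    """Soft Gumbel-Softmax sample: softmax((logits + g) / tau)."""
    rng = rng or random.Random()
    noisy = [(x + sample_gumbel(rng)) / tau for x in logits]
    return softmax(noisy)


# Seeded RNG so the sketch is reproducible.
y = gumbel_softmax_1d([1.0, 2.0, 3.0], tau=0.5, rng=random.Random(0))
print(y, sum(y))
```

Each call draws fresh noise, so repeated calls give different probability vectors that always sum to 1.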
Parameters
logits : Tensor
    Unnormalised log-probabilities of any shape; the axis given by dim
    (the last axis by default) indexes the categorical alternatives.
tau : float, default 1.0
    Temperature $\tau$. Smaller values make the relaxation closer to a
    one-hot vector, at the cost of larger gradient variance.
hard : bool, default False
    If True, the forward pass returns a true one-hot vector while the
    backward pass uses the soft gradient (straight-through estimator).
dim : int, default -1
    Dimension over which to take the softmax.
Returns
Tensor
    Sample with the same shape as logits. Values sum to 1 along dim; in
    hard mode each slice along dim is one-hot.
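To see how `tau` and `hard` change the output, the sketch below fixes one Gumbel noise vector and sweeps the temperature. It is pure Python, uses a hypothetical helper `relax`, and models forward values only (no gradients).

```python
from __future__ import annotations

import math
import random


def softmax(z: list[float]) -> list[float]:
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def relax(logits: list[float], noise: list[float], tau: float,
          hard: bool = False) -> list[float]:
    """Gumbel-Softmax relaxation with a fixed, pre-drawn noise vector."""
    soft = softmax([(x + g) / tau for x, g in zip(logits, noise)])
    if not hard:
        return soft
    # Hard mode (forward value only): one-hot at the argmax of the soft sample.
    k = max(range(len(soft)), key=soft.__getitem__)
    return [1.0 if i == k else 0.0 for i in range(len(soft))]


rng = random.Random(42)
eps = 1e-10
# One Gumbel(0, 1) draw per class, reused for every temperature below.
noise = [-math.log(-math.log(min(max(rng.random(), eps), 1 - eps)))
         for _ in range(3)]
logits = [1.0, 2.0, 3.0]

for tau in (5.0, 1.0, 0.5, 0.1):
    y = relax(logits, noise, tau)
    print(f"tau={tau:<4} max={max(y):.3f}  y={[round(v, 3) for v in y]}")

print("hard:", relax(logits, noise, 0.5, hard=True))
```

With the noise held fixed, the largest entry grows toward 1 as `tau` shrinks, while the position of that entry never changes, since dividing by a positive `tau` preserves the argmax.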
Notes
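Gumbel samples are drawn as $g = -\log(-\log u)$ with $u \sim \mathcal{U}(0, 1)$, where $u$ is clamped away from the boundaries 0 and 1 to avoid $\log 0$. In hard mode the result is $y = \operatorname{sg}(y_{\text{hard}} - y_{\text{soft}}) + y_{\text{soft}}$, where $\operatorname{sg}$ denotes stop-gradient (detach) and $y_{\text{hard}}$ is the one-hot vector at $\arg\max y_{\text{soft}}$. It evaluates to $y_{\text{hard}}$ in the forward pass, and its gradient is that of $y_{\text{soft}}$ in the backward pass.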
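The straight-through trick can be checked by hand. The pure-Python sketch below assumes a fixed, hand-picked noise vector and an arbitrary upstream gradient `v`. The helpers `st_forward` and `st_backward` are hypothetical: they mimic what an autograd engine does with a stop-gradient, and they are not lucid's implementation. The analytic gradient is compared against finite differences of the soft path.

```python
from __future__ import annotations

import math


def softmax(z: list[float]) -> list[float]:
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [x / s for x in e]


def st_forward(logits, noise, tau):
    """Return (forward_value, y_soft) for hard Gumbel-Softmax."""
    y_soft = softmax([(l + g) / tau for l, g in zip(logits, noise)])
    k = max(range(len(y_soft)), key=y_soft.__getitem__)
    y_hard = [1.0 if i == k else 0.0 for i in range(len(y_soft))]
    # sg(y_hard - y_soft) + y_soft: numerically equal to y_hard,
    # but only the trailing y_soft term would carry gradient.
    forward = [(h - s) + s for h, s in zip(y_hard, y_soft)]
    return forward, y_soft


def st_backward(y_soft, upstream, tau):
    """Vector-Jacobian product of softmax((logits + noise) / tau) w.r.t. logits.

    dL/dlogit_j = (1 / tau) * y_j * (v_j - sum_i v_i * y_i),
    where v = upstream is dL/dy.
    """
    dot = sum(v * y for v, y in zip(upstream, y_soft))
    return [y * (v - dot) / tau for v, y in zip(upstream, y_soft)]


logits = [1.0, 2.0, 3.0]
noise = [0.2, -0.5, 0.1]   # fixed noise so the check is deterministic
tau = 0.5
v = [0.3, -1.0, 2.0]       # arbitrary upstream gradient dL/dy

fwd, y_soft = st_forward(logits, noise, tau)
grad = st_backward(y_soft, v, tau)


def soft_loss(ls):
    """L = v . y_soft(ls): the path the backward pass differentiates."""
    ys = softmax([(l + g) / tau for l, g in zip(ls, noise)])
    return sum(vi * yi for vi, yi in zip(v, ys))


# Central finite differences of the soft loss w.r.t. each logit.
h = 1e-6
fd = []
for j in range(len(logits)):
    up, dn = logits.copy(), logits.copy()
    up[j] += h
    dn[j] -= h
    fd.append((soft_loss(up) - soft_loss(dn)) / (2 * h))

print("forward :", fwd)
print("ST grad :", grad)
print("FD grad :", fd)
```

The forward value is one-hot, yet the gradient matches the soft path's finite differences, and (as for any softmax) the gradient entries sum to zero.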
Examples
>>> import lucid
>>> from lucid.nn.functional import gumbel_softmax
>>> logits = lucid.tensor([[1.0, 2.0, 3.0]])
>>> y = gumbel_softmax(logits, tau=0.5, hard=True)
>>> y.sum().item()
1.0
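Why the hard output counts as a categorical sample: assuming the one-hot is taken at the argmax of the noisy logits (the usual construction), the Gumbel-max trick says this index is an exact draw from $\mathrm{Categorical}(\mathrm{softmax}(\text{logits}))$, and dividing by $\tau > 0$ does not move the argmax. The pure-Python Monte Carlo sketch below, using a hypothetical helper `gumbel_argmax` and a fixed seed, compares empirical frequencies for the example's logits with softmax([1, 2, 3]).

```python
from __future__ import annotations

import math
import random
from collections import Counter


def gumbel_argmax(logits: list[float], rng: random.Random,
                  eps: float = 1e-10) -> int:
    """Index of the max of logits + Gumbel noise (the hard-mode pick)."""
    best_i, best_v = 0, -math.inf
    for i, x in enumerate(logits):
        u = min(max(rng.random(), eps), 1.0 - eps)
        v = x - math.log(-math.log(u))   # x + g with g = -log(-log u)
        if v > best_v:
            best_i, best_v = i, v
    return best_i


logits = [1.0, 2.0, 3.0]
rng = random.Random(0)
n = 20_000
counts = Counter(gumbel_argmax(logits, rng) for _ in range(n))
empirical = [counts[i] / n for i in range(len(logits))]

# Reference probabilities: a plain softmax of the logits.
m = max(logits)
exps = [math.exp(x - m) for x in logits]
expected = [e / sum(exps) for e in exps]

for i, (e, p) in enumerate(zip(empirical, expected)):
    print(f"class {i}: empirical {e:.3f}  softmax {p:.3f}")
```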