class RelaxedBernoulli

extends Distribution

RelaxedBernoulli(temperature: Tensor | float, probs: Tensor | float | None = None, logits: Tensor | float | None = None, validate_args: bool | None = None)

Concrete (Gumbel-sigmoid) relaxation of the Bernoulli distribution.

RelaxedBernoulli(temperature=τ, probs=p) defines a continuous distribution over $(0, 1)$ whose samples are differentiable surrogates for Bernoulli samples. As $\tau \to 0$ the distribution concentrates on $\{0, 1\}$, recovering the discrete Bernoulli. As $\tau \to \infty$ the distribution collapses to a point mass at $1/2$. The special case $\tau = 1$, $p = 1/2$ is exactly $\operatorname{Uniform}(0, 1)$.
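The limiting behaviour can be read off a closed-form CDF. Write a sample as $Y = \sigma((l + L)/\tau)$, where $L \sim \operatorname{Logistic}(0, 1)$ is the difference of two independent Gumbel(0, 1) draws. Then $Y \le t$ holds exactly when $L \le \tau\,\operatorname{logit}(t) - l$, so $P(Y \le t) = \sigma(\tau\,\operatorname{logit}(t) - l)$. The stdlib-only sketch below evaluates this CDF for $p = 0.7$. `relaxed_bernoulli_cdf`, `sigmoid`, and `logit` are hypothetical helpers written for illustration, not part of lucid's documented API.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logit(t: float) -> float:
    """Log-odds log(t / (1 - t)) for t in (0, 1)."""
    return math.log(t) - math.log1p(-t)


def relaxed_bernoulli_cdf(t: float, temperature: float, logits: float) -> float:
    """P(Y <= t) for Y = sigmoid((logits + L) / temperature), L ~ Logistic(0, 1)."""
    return sigmoid(temperature * logit(t) - logits)


l = logit(0.7)
for tau in (0.01, 1.0, 100.0):
    near_zero = relaxed_bernoulli_cdf(0.1, tau, l)        # P(Y < 0.1)
    near_one = 1.0 - relaxed_bernoulli_cdf(0.9, tau, l)   # P(Y > 0.9)
    near_half = (relaxed_bernoulli_cdf(0.55, tau, l)
                 - relaxed_bernoulli_cdf(0.45, tau, l))   # P(0.45 < Y < 0.55)
    print(f"tau={tau:g}: P(Y<0.1)={near_zero:.3f}  P(Y>0.9)={near_one:.3f}  "
          f"P(0.45<Y<0.55)={near_half:.3f}")
```

At $\tau = 0.01$ almost all mass sits below 0.1 or above 0.9, split roughly 0.3 / 0.7. At $\tau = 100$ almost all mass sits within 0.05 of $1/2$. Setting $t = 1/2$ gives $P(Y > 1/2) = \sigma(l) = p$ for every $\tau$, so thresholding a relaxed sample at $1/2$ yields an exact Bernoulli(p) draw.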

The Concrete distribution (Maddison et al. 2017), developed concurrently with the Gumbel-softmax trick (Jang et al. 2017), enables gradient-based optimisation through discrete latent variables in variational autoencoders and related models.

Parameters

temperature : Tensor | float
Temperature parameter $\tau > 0$. Controls the sharpness of the relaxation. Small values give near-discrete samples; large values give samples concentrated near $1/2$.
probs : Tensor | float | None = None
Bernoulli success probability $p \in (0, 1)$. Mutually exclusive with logits.
logits : Tensor | float | None = None
Log-odds $l = \log(p/(1-p)) \in \mathbb{R}$. Mutually exclusive with probs.
validate_args : bool | None = None
If True, validate parameter constraints at construction time.

Attributes

temperature : Tensor
Temperature $\tau$.
probs : Tensor
Success probability (present when constructed with probs).
logits : Tensor
Log-odds (present when constructed with logits).

Notes

Reparameterised sampling (Gumbel-sigmoid trick):

$$y = \sigma\!\left(\frac{l + g_1 - g_2}{\tau}\right), \qquad g_1, g_2 \overset{\text{iid}}{\sim} \operatorname{Gumbel}(0, 1)$$

where $\sigma(\cdot)$ is the sigmoid function. Gradients propagate through both $l$ and $\tau$.
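A minimal stdlib-only sketch of the Gumbel-sigmoid sampler $y = \sigma((l + g_1 - g_2)/\tau)$, assuming scalar (unbatched) logits and temperature. `gumbel`, `sigmoid`, and `relaxed_bernoulli_rsample` are hypothetical names, not lucid internals. Gumbel(0, 1) noise is drawn by inverse-CDF sampling, $g = -\log(-\log u)$ with $u \sim \operatorname{Uniform}(0, 1)$. Since $y > 1/2$ exactly when $l + g_1 - g_2 > 0$, the fraction of samples above $1/2$ should be close to $p$ at any temperature.

```python
import math
import random


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gumbel(rng: random.Random) -> float:
    """Draw Gumbel(0, 1) noise by inverse-CDF sampling."""
    u = rng.random()
    while u == 0.0:  # random() may return exactly 0.0, where log is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def relaxed_bernoulli_rsample(logits: float, temperature: float, rng: random.Random) -> float:
    """One Gumbel-sigmoid sample y = sigmoid((logits + g1 - g2) / temperature)."""
    g1, g2 = gumbel(rng), gumbel(rng)
    return sigmoid((logits + g1 - g2) / temperature)


rng = random.Random(0)
p, tau = 0.7, 0.5
l = math.log(p) - math.log1p(-p)  # log-odds of p
samples = [relaxed_bernoulli_rsample(l, tau, rng) for _ in range(20_000)]
frac_above_half = sum(y > 0.5 for y in samples) / len(samples)
print(f"fraction of samples above 1/2: {frac_above_half:.3f} (p = {p})")
```

In floating point, very small temperatures can saturate $\sigma$ to exactly 0.0 or 1.0. A production implementation may therefore need to clamp samples or work in logit space. Whether lucid does so is not stated here.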

Log-PDF (Maddison et al. 2017, Eq. 2):

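$$\log p(y; l, \tau) = \log\tau + l - \tau\,\operatorname{logit}(y) - 2\log\!\left(1 + e^{\,l - \tau\,\operatorname{logit}(y)}\right) - \log y - \log(1 - y)$$

where $\operatorname{logit}(y) = \log(y/(1-y))$. The first four terms are the log-density of the Logistic variable $\operatorname{logit}(y) = (l + g_1 - g_2)/\tau$, which has location $l/\tau$ and scale $1/\tau$. The last two terms, $-\log y - \log(1 - y)$, are the log-Jacobian of mapping back through the sigmoid.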
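A minimal stdlib-only sketch of the log-density $\log\tau + l - \tau\,\operatorname{logit}(y) - 2\log(1 + e^{l - \tau\,\operatorname{logit}(y)}) - \log y - \log(1 - y)$, with a numerical check that the density integrates to 1. `relaxed_bernoulli_log_prob`, `softplus`, and `midpoint_integral` are hypothetical helpers, not lucid's API. The check uses $\tau \ge 1$ because, for $\tau < 1$, the density behaves like $y^{\tau - 1}$ near the endpoints, which a plain midpoint rule handles poorly.

```python
import math


def softplus(x: float) -> float:
    """log(1 + e^x), computed without overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def relaxed_bernoulli_log_prob(y: float, logits: float, temperature: float) -> float:
    """Binary Concrete log-density at a scalar y in (0, 1)."""
    log_y, log_1my = math.log(y), math.log1p(-y)
    u = log_y - log_1my  # logit(y)
    return (math.log(temperature) + logits - temperature * u
            - 2.0 * softplus(logits - temperature * u)
            - log_y - log_1my)


def midpoint_integral(f, n: int = 20_000) -> float:
    """Midpoint-rule approximation of the integral of f over (0, 1)."""
    h = 1.0 / n
    return h * sum(f((i + 0.5) * h) for i in range(n))


l = math.log(0.7) - math.log1p(-0.7)
for tau in (1.0, 2.0, 5.0):
    total = midpoint_integral(lambda y: math.exp(relaxed_bernoulli_log_prob(y, l, tau)))
    print(f"tau={tau:g}: integral of density ~ {total:.5f}")
```

With $l = 0$ and $\tau = 1$ the log-density is identically 0, which is the $\operatorname{Uniform}(0, 1)$ special case.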

Examples

>>> import lucid
>>> from lucid.distributions import RelaxedBernoulli
>>> dist = RelaxedBernoulli(temperature=0.5, probs=0.7)
>>> samples = dist.rsample((100,))
>>> # Samples are in (0, 1)
>>> ((samples > 0) & (samples < 1)).all()

Methods (3)

__init__

__init__(temperature: Tensor | float, probs: Tensor | float | None = None, logits: Tensor | float | None = None, validate_args: bool | None = None) -> None

Construct a RelaxedBernoulli (Binary Concrete) distribution.

Parameters

temperature : Tensor | float
Temperature $\tau > 0$ controlling the relaxation tightness. As $\tau \to 0$, samples concentrate on $\{0, 1\}$, recovering a hard Bernoulli; as $\tau \to \infty$, samples concentrate around $1/2$.
probs : Tensor | float | None = None
Success probability $p \in (0, 1)$. Mutually exclusive with logits.
logits : Tensor | float | None = None
Log-odds $\log\frac{p}{1-p}$. Mutually exclusive with probs.
validate_args : bool | None = None
If True, validate parameter constraints at construction time.

Raises

ValueError
If both or neither of probs and logits are provided.
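The probs/logits contract can be sketched as a small resolver: exactly one argument must be given, and probs is converted to log-odds $l = \log p - \log(1 - p)$. `resolve_logits` is a hypothetical helper that mirrors the documented ValueError behaviour. It is not lucid's implementation, and it skips the range checks that validate_args would cover.

```python
import math
from typing import Optional


def resolve_logits(probs: Optional[float] = None, logits: Optional[float] = None) -> float:
    """Return log-odds from exactly one of probs / logits (hypothetical helper)."""
    if (probs is None) == (logits is None):
        # Both given, or neither given: the parameterisation is ambiguous.
        raise ValueError("Exactly one of `probs` or `logits` must be specified.")
    if logits is not None:
        return float(logits)
    # l = log(p / (1 - p)); log1p keeps precision when p is small.
    return math.log(probs) - math.log1p(-probs)


print(resolve_logits(probs=0.7))    # log(0.7 / 0.3)
print(resolve_logits(logits=-1.5))  # passed through unchanged
for bad in ({}, {"probs": 0.7, "logits": 0.5}):
    try:
        resolve_logits(**bad)
    except ValueError as err:
        print("ValueError:", err)
```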

rsample

rsample(sample_shape: tuple[int, ...] = ()) -> Tensor

Draw a reparameterised sample via the Gumbel-sigmoid trick.

Parameters

sample_shape : tuple[int, ...] = ()
Leading shape of the output sample batch.

Returns

Tensor

Samples $y \in (0, 1)$ of shape (*sample_shape, *batch_shape), with gradients flowing through both logits and temperature.
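With the Gumbel noise held fixed, $y = \sigma(z)$ with $z = (l + g_1 - g_2)/\tau$ is a deterministic, differentiable function of $l$ and $\tau$, which is what lets gradients flow through both: $\partial y/\partial l = y(1-y)/\tau$ and $\partial y/\partial \tau = -y(1-y)\,z/\tau$. An autodiff framework obtains these automatically. The stdlib-only sketch below writes them out by hand and checks them against central finite differences. `sample_and_grads` and the fixed noise values are hypothetical, chosen for illustration.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sample_and_grads(logits: float, temperature: float, g1: float, g2: float):
    """Return y and the pathwise derivatives dy/dlogits, dy/dtemperature for fixed noise."""
    z = (logits + g1 - g2) / temperature
    y = sigmoid(z)
    dy_dz = y * (1.0 - y)
    return y, dy_dz / temperature, -dy_dz * z / temperature


g1, g2 = 0.3, -0.4          # fixed stand-ins for two Gumbel(0, 1) draws
l, tau, eps = 0.8, 0.5, 1e-6
y, dy_dl, dy_dtau = sample_and_grads(l, tau, g1, g2)

# Central finite differences, holding the noise fixed (the reparameterisation).
fd_dl = (sample_and_grads(l + eps, tau, g1, g2)[0]
         - sample_and_grads(l - eps, tau, g1, g2)[0]) / (2 * eps)
fd_dtau = (sample_and_grads(l, tau + eps, g1, g2)[0]
           - sample_and_grads(l, tau - eps, g1, g2)[0]) / (2 * eps)
print(f"dy/dl:   analytic {dy_dl:.6f}  finite-diff {fd_dl:.6f}")
print(f"dy/dtau: analytic {dy_dtau:.6f}  finite-diff {fd_dtau:.6f}")
```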

log_prob

log_prob(value: Tensor) -> Tensor

Log-probability density of the Concrete/RelaxedBernoulli distribution.

Parameters

value : Tensor
Point(s) $y \in (0, 1)$ at which to evaluate the density.

Returns

Tensor

Log-density values of the same shape as value.