RelaxedBernoulli
Distribution
RelaxedBernoulli(temperature: Tensor | float, probs: Tensor | float | None = None, logits: Tensor | float | None = None, validate_args: bool | None = None)
Concrete (Gumbel-sigmoid) relaxation of the Bernoulli distribution.
RelaxedBernoulli(temperature=τ, probs=p) defines a continuous distribution over the open interval (0, 1) whose samples are differentiable surrogates for Bernoulli samples. As τ → 0 the distribution concentrates on {0, 1}, recovering the discrete Bernoulli(p). As τ → ∞ the distribution approaches a point mass at 1/2.
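Both limits can be checked numerically. A sample has the form X = σ((ℓ + L)/τ), where ℓ = logit(p) and L is Logistic(0, 1) noise. The plain-Python sketch below is not the lucid API. It holds two noise values fixed and varies τ. The helper names `stable_sigmoid` and `relaxed_value` are hypothetical.

```python
import math


def stable_sigmoid(z: float) -> float:
    """Logistic sigmoid that avoids overflow in math.exp for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def relaxed_value(logit_p: float, noise: float, tau: float) -> float:
    """Hypothetical helper: x = sigmoid((logit_p + noise) / tau) for one fixed noise draw."""
    return stable_sigmoid((logit_p + noise) / tau)


logit_p = math.log(0.7) - math.log1p(-0.7)  # logit of p = 0.7
for noise in (-2.0, 0.5):  # logit_p + noise is negative for -2.0, positive for 0.5
    for tau in (0.01, 1.0, 100.0):
        x = relaxed_value(logit_p, noise, tau)
        print(f"noise={noise:+.1f}  tau={tau:g}  x={x:.4f}")
```

At τ = 0.01, each draw is pushed essentially to 0 or 1, depending on the sign of ℓ + L. At τ = 100, both draws sit within about 0.01 of 1/2.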
The Concrete (Maddison et al. 2017) / Gumbel-softmax (Jang et al. 2017) relaxation enables gradient-based optimisation through discrete latent variables in variational autoencoders and related models.
Parameters
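temperature : Tensor | float
    Temperature τ > 0 of the relaxation.
probs : Tensor | float | None = None
    Probability p that the underlying Bernoulli takes the value 1. Mutually exclusive with logits.
logits : Tensor | float | None = None
    Log-odds ℓ = log p − log(1 − p) of the underlying Bernoulli. Mutually exclusive with probs.
validate_args : bool | None = None
    If True, validate parameter constraints at construction time.
Attributes
temperature : Tensor
    Temperature τ.
probs : Tensor
    Success probability p (equal to σ(logits) when constructed from logits).
logits : Tensor
    Log-odds ℓ (equal to logit(probs) when constructed from probs).
Notes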
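Reparameterised sampling (Gumbel-sigmoid trick):
$$U \sim \mathrm{Uniform}(0, 1), \qquad L = \log U - \log(1 - U), \qquad X = \sigma\!\left(\frac{\ell + L}{\tau}\right),$$
where σ is the sigmoid function and ℓ = log p − log(1 − p) are the logits. Gradients propagate through both ℓ and τ.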
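Below is a minimal scalar sketch of this sampler in plain Python. It assumes a single p and τ. `rsample_sketch` and the other helpers are hypothetical names, not lucid's implementation.

The sketch also checks a property that follows directly from the construction. X > 1/2 exactly when ℓ + L > 0, and this happens with probability σ(ℓ) = p for every τ > 0.

```python
import math
import random


def stable_sigmoid(z: float) -> float:
    """Logistic sigmoid that avoids overflow in math.exp for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logistic_noise(rng: random.Random) -> float:
    """L = log U - log(1 - U) with U ~ Uniform(0, 1), i.e. Logistic(0, 1) noise."""
    u = rng.random()
    while u == 0.0:  # random() returns values in [0, 1); reject 0 so log(u) is finite
        u = rng.random()
    return math.log(u) - math.log1p(-u)


def rsample_sketch(p: float, tau: float, n: int, seed: int = 0) -> list[float]:
    """Hypothetical scalar sampler: n draws of sigmoid((logit(p) + L) / tau)."""
    logit_p = math.log(p) - math.log1p(-p)
    rng = random.Random(seed)
    return [stable_sigmoid((logit_p + logistic_noise(rng)) / tau) for _ in range(n)]


samples = rsample_sketch(p=0.7, tau=0.5, n=100_000)
# X > 1/2 exactly when logit(p) + L > 0, which has probability sigmoid(logit(p)) = p.
frac_above_half = sum(x > 0.5 for x in samples) / len(samples)
print(f"fraction above 1/2: {frac_above_half:.3f}  (p = 0.7)")
```

The noise is drawn independently of p and τ. Each sample is therefore a deterministic, differentiable function of ℓ and τ, which is what makes the sample reparameterised. In floating point, extreme noise draws can round a sample to exactly 0 or 1.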
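Log-PDF (Maddison et al. 2017, Eq. 2):
$$\log p(x) = \log\tau + \log\alpha - (\tau + 1)\left[\log x + \log(1 - x)\right] - 2\log\!\left(\alpha\,x^{-\tau} + (1 - x)^{-\tau}\right), \qquad x \in (0, 1),$$
where α = e^ℓ = p / (1 − p).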
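The same density can be rewritten in a numerically friendlier form. The algebra below is a sketch, not taken from the lucid source. It starts from the Binary Concrete density p(x) = τα x^{−τ−1}(1 − x)^{−τ−1} / (α x^{−τ} + (1 − x)^{−τ})² with α = e^ℓ. Substitute y = log x − log(1 − x) and z = ℓ − τy, so that α(x/(1 − x))^{−τ} = e^z. The density then needs only softplus(z) = log(1 + e^z):

$$
\begin{aligned}
\alpha\,x^{-\tau} + (1 - x)^{-\tau} &= (1 - x)^{-\tau}\left(1 + e^{z}\right), \\
\tau\,\alpha\,x^{-\tau-1}(1 - x)^{-\tau-1} &= \frac{\tau\,e^{z}\,(1 - x)^{-2\tau}}{x(1 - x)}, \\
\log p(x) &= \log\tau + z - 2\,\mathrm{softplus}(z) - \log x - \log(1 - x).
\end{aligned}
$$

Dividing the second line by the square of the first gives p(x) = τe^z / (x(1 − x)(1 + e^z)²), whose logarithm is the last line. This form never computes x^{−τ} or (1 − x)^{−τ}, which can overflow when τ is large and x is close to 0 or 1.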
Examples
>>> import lucid
>>> from lucid.distributions import RelaxedBernoulli
>>> dist = RelaxedBernoulli(temperature=0.5, probs=0.7)
>>> samples = dist.rsample((100,))
>>> # Samples are in (0, 1)
>>> ((samples > 0) & (samples < 1)).all()
Methods (3)
__init__
__init__(temperature: Tensor | float, probs: Tensor | float | None = None, logits: Tensor | float | None = None, validate_args: bool | None = None) → None
Construct a RelaxedBernoulli (Binary Concrete) distribution.
Parameters
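temperature : Tensor | float
    Temperature τ > 0 of the relaxation.
probs : Tensor | float | None = None
    Probability p that the underlying Bernoulli takes the value 1. Mutually exclusive with logits.
logits : Tensor | float | None = None
    Log-odds ℓ = log p − log(1 − p) of the underlying Bernoulli. Mutually exclusive with probs.
validate_args : bool | None = None
    If True, validate parameter constraints at construction time.
Raises
ValueError
    If both probs and logits are provided.
rsample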
rsample(sample_shape: tuple[int, ...] = ()) → Tensor
Draw a reparameterised sample via the Gumbel-sigmoid trick.
Parameters
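sample_shape : tuple[int, ...] = ()
    Shape of the batch of independent draws, prepended to the distribution's batch_shape.
Returns
Tensor
    Samples of shape (*sample_shape, *batch_shape) with values in (0, 1), with gradients flowing through both logits and temperature.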
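Consider a single draw with the noise L held fixed. The sample x = σ((ℓ + L)/τ) has ∂x/∂ℓ = x(1 − x)/τ and ∂x/∂τ = −x(1 − x)(ℓ + L)/τ². The plain-Python sketch below compares these analytic derivatives against central finite differences. It uses hypothetical helper names and scalar inputs, and it does not use lucid's autograd. An autodiff framework would produce the same pathwise derivatives automatically.

```python
import math


def stable_sigmoid(z: float) -> float:
    """Logistic sigmoid that avoids overflow in math.exp for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def relaxed_sample(logit_p: float, tau: float, noise: float) -> float:
    """x = sigmoid((logit_p + noise) / tau) with the Logistic(0, 1) noise held fixed."""
    return stable_sigmoid((logit_p + noise) / tau)


def pathwise_grads(logit_p: float, tau: float, noise: float) -> tuple[float, float]:
    """Analytic dx/dlogit_p and dx/dtau for the reparameterised sample."""
    x = relaxed_sample(logit_p, tau, noise)
    s = x * (1.0 - x)  # derivative of the sigmoid at (logit_p + noise) / tau
    return s / tau, -s * (logit_p + noise) / tau**2


logit_p, tau, noise = 0.4, 0.7, -0.3
d_logit, d_tau = pathwise_grads(logit_p, tau, noise)

# Central finite differences as an independent check of the analytic gradients.
h = 1e-6
fd_logit = (relaxed_sample(logit_p + h, tau, noise) - relaxed_sample(logit_p - h, tau, noise)) / (2 * h)
fd_tau = (relaxed_sample(logit_p, tau + h, noise) - relaxed_sample(logit_p, tau - h, noise)) / (2 * h)
print(d_logit, fd_logit, d_tau, fd_tau)
```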
log_prob
log_prob(value: Tensor) → Tensor
Log-probability density of the Concrete/RelaxedBernoulli distribution.
Parameters
value : Tensor
    Points in (0, 1) at which to evaluate the log-density.
Returns
Tensor
    Log-density values of the same shape as value.
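Below is a scalar sketch of the log-density in plain Python. It uses the softplus rewrite log p(x) = log τ + z − 2 softplus(z) − log x − log(1 − x), with z = ℓ − τ(log x − log(1 − x)) and ℓ = logit(p). The function names are hypothetical, and this is not lucid's implementation. The direct Maddison et al. form (with α = p/(1 − p)) is included only as a cross-check, alongside a midpoint-rule check that the density integrates to about 1 over (0, 1).

```python
import math


def softplus(z: float) -> float:
    """log(1 + e^z) computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def log_prob_sketch(x: float, p: float, tau: float) -> float:
    """Hypothetical scalar log-density: log tau + z - 2*softplus(z) - log x - log(1 - x)."""
    logit_p = math.log(p) - math.log1p(-p)
    log_x, log_1mx = math.log(x), math.log1p(-x)
    z = logit_p - tau * (log_x - log_1mx)
    return math.log(tau) + z - 2.0 * softplus(z) - log_x - log_1mx


def log_prob_direct(x: float, p: float, tau: float) -> float:
    """The same density written directly with alpha = p / (1 - p)."""
    alpha = p / (1.0 - p)
    num = tau * alpha * x ** (-tau - 1) * (1.0 - x) ** (-tau - 1)
    den = (alpha * x ** (-tau) + (1.0 - x) ** (-tau)) ** 2
    return math.log(num / den)


# Midpoint-rule estimate of the integral of the density over (0, 1).
n = 100_000
total_mass = sum(math.exp(log_prob_sketch((i + 0.5) / n, 0.7, 0.5)) for i in range(n)) / n
print(f"integral ~ {total_mass:.4f}")
```

The sketch requires 0 < x < 1, because math.log raises a ValueError at the endpoints. With τ = 0.5, the density has integrable singularities at both ends. This is why the midpoint estimate falls slightly short of 1.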