class StickBreakingTransform

extends Transform
StickBreakingTransform()

Logistic stick-breaking bijection $\mathbb{R}^{K-1} \to \Delta^{K-1}$.

A true bijection between unconstrained $(K-1)$-vectors and the probability simplex $\Delta^{K-1}$ of $K$-component vectors: only $K-1$ stick breaks are free, and the $K$-th component is the residual stick. Unlike SoftmaxTransform, which takes $K$ inputs for a $(K-1)$-dimensional simplex and is therefore not injective, it has matching input and output dimensionality and a tractable log-Jacobian determinant. This makes it the preferred choice for normalising flows over probability vectors and for unconstrained reparameterisations of lucid.distributions.Dirichlet priors. event_dim = 1.
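To see why dimensionality matters, here is a small pure-Python sketch (it does not use lucid's SoftmaxTransform; the helper `softmax` is hypothetical) showing that a plain softmax on $\mathbb{R}^K$ is not injective: shifting every input by the same constant leaves the output unchanged, so no inverse exists and the square Jacobian is singular.

```python
import math


def softmax(v):
    # Subtracting the max is the usual stability trick; it works precisely
    # because softmax is invariant to adding a constant to every input.
    m = max(v)
    exps = [math.exp(t - m) for t in v]
    total = sum(exps)
    return [e / total for e in exps]


a = softmax([0.0, 1.0, 2.0])
b = softmax([5.0, 6.0, 7.0])  # the same input shifted by +5
# Two different inputs give the same point on the simplex.
print(all(math.isclose(p, q) for p, q in zip(a, b)))  # True
```

Stick-breaking avoids this by working with only $K-1$ free coordinates.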

Notes

Forward "stick breaking", for $k = 0, \dots, K-2$ with $z_k = \sigma\bigl(x_k - \log(K-1-k)\bigr)$:

$$y_k = z_k \prod_{j<k}(1 - z_j), \qquad y_{K-1} = \prod_{j=0}^{K-2}(1 - z_j)$$

The last component is the residual stick remaining after the first $K-1$ breaks. Each $y_k > 0$ and $\sum_k y_k = 1$ by construction.
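A minimal pure-Python sketch of the forward map, written from the formulas above rather than from lucid's source; `stick_breaking_forward` and `sigmoid` are hypothetical helpers. It assumes 0-based $k$ with the offset $\log(K-1-k)$, which is the offset that sends $\mathbf{x} = \mathbf{0}$ to the simplex centroid.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def stick_breaking_forward(x):
    """Map an unconstrained (K-1)-vector x to a K-component probability vector."""
    K = len(x) + 1
    y = []
    remaining = 1.0  # unbroken stick so far: prod_{j<k} (1 - z_j)
    for k, xk in enumerate(x):
        # Assumed 0-based offset log(K-1-k); it is log(1) = 0 for the last break.
        zk = sigmoid(xk - math.log(K - 1 - k))
        y.append(zk * remaining)  # y_k = z_k * prod_{j<k} (1 - z_j)
        remaining *= 1.0 - zk
    y.append(remaining)  # residual stick y_{K-1}
    return y


y = stick_breaking_forward([0.3, -1.2, 2.0])
print(len(y), sum(y))  # 4 components summing to 1 up to rounding
print(stick_breaking_forward([0.0, 0.0]))  # approximately [1/3, 1/3, 1/3]
```

The naive `sigmoid` can overflow for very negative inputs; a library implementation would normally use a numerically stable form.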

Inverse (back-solve the stick lengths):

$$z_k = \frac{y_k}{1 - \sum_{j<k} y_j}, \qquad x_k = \operatorname{logit}(z_k) + \log(K-1-k)$$
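The back-solve can be sketched in pure Python as follows (hypothetical helper names, not lucid's API; same 0-based offset $\log(K-1-k)$ as the forward sketch). The round trip forward then inverse should return the original $\mathbf{x}$ up to floating-point error.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def logit(p):
    return math.log(p) - math.log1p(-p)


def stick_breaking_forward(x):
    K = len(x) + 1
    y, remaining = [], 1.0
    for k, xk in enumerate(x):
        zk = sigmoid(xk - math.log(K - 1 - k))
        y.append(zk * remaining)
        remaining *= 1.0 - zk
    y.append(remaining)
    return y


def stick_breaking_inverse(y):
    """Recover the K-1 unconstrained coordinates from a K-component simplex point."""
    K = len(y)
    x = []
    used = 0.0  # sum_{j<k} y_j: stick already broken off before step k
    for k in range(K - 1):
        zk = y[k] / (1.0 - used)  # fraction of the remaining stick taken at step k
        x.append(logit(zk) + math.log(K - 1 - k))
        used += y[k]
    return x


x = [0.3, -1.2, 2.0]
x_back = stick_breaking_inverse(stick_breaking_forward(x))
print(max(abs(a - b) for a, b in zip(x, x_back)))  # tiny round-trip error
```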

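Log absolute Jacobian determinant of $\mathbf{x} \mapsto (y_0, \dots, y_{K-2})$ (the last component is fixed by $\sum_k y_k = 1$). The Jacobian is lower-triangular because $y_k$ depends only on $x_0, \dots, x_k$, and its diagonal entries are $\partial y_k / \partial x_k = z_k(1 - z_k)\prod_{j<k}(1 - z_j)$, so

$$\log\lvert\det J\rvert = \sum_{k=0}^{K-2} \bigl[\log z_k + \log(1 - z_k) + \log\bigl(1 - \textstyle\sum_{j<k} y_j\bigr)\bigr] = \sum_{k=0}^{K-2} \bigl[\log y_k + \log(1 - z_k)\bigr]$$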

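The constant shifts inside $z_k$ leave the Jacobian unchanged, since $\partial z_k / \partial x_k = z_k(1 - z_k)$ whatever the shift; they only recentre the map so that $\mathbf{x} = \mathbf{0}$ is sent to the simplex centroid $(1/K, \dots, 1/K)$, the mean of the uniform Dirichlet.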
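A short worked check of the centring claim, assuming 0-based $k$ and the offset $\log(K-1-k)$: at $\mathbf{x} = \mathbf{0}$ each break takes a fraction $1/(K-k)$ of what is left, and the remaining-stick product telescopes.

$$
\begin{aligned}
z_k &= \sigma\bigl(-\log(K-1-k)\bigr) = \frac{1}{1 + (K-1-k)} = \frac{1}{K-k}, \\
\prod_{j<k}(1 - z_j) &= \prod_{j=0}^{k-1} \frac{K-1-j}{K-j} = \frac{K-k}{K}, \\
y_k &= \frac{1}{K-k} \cdot \frac{K-k}{K} = \frac{1}{K}, \qquad y_{K-1} = \prod_{j=0}^{K-2}(1 - z_j) = \frac{1}{K}.
\end{aligned}
$$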

Examples

>>> import lucid
>>> from lucid.distributions.transforms import StickBreakingTransform
>>> T = StickBreakingTransform()
>>> y = T(lucid.tensor([0.0, 0.0]))  # x = 0 maps to the centroid (1/3, 1/3, 1/3), the mean of Dirichlet(1, 1, 1)
>>> y.sum()
Tensor(1.0)

Methods (1)

fn log_abs_det_jacobian

log_abs_det_jacobian(x: Tensor, y: Tensor) -> Tensor

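Log absolute Jacobian determinant of the forward map $\mathbf{x} \mapsto \mathbf{y}$ from $\mathbb{R}^{K-1}$ to the simplex, evaluated from $\mathbf{y}$ with $z_k = y_k / (1 - \sum_{j<k} y_j)$:

$$\log\lvert\det J\rvert = \sum_{k=0}^{K-2} \bigl[\log y_k + \log(1 - z_k)\bigr] = \sum_{k=0}^{K-1} \log y_k,$$

where the second form holds because $\prod_{k=0}^{K-2}(1 - z_k) = y_{K-1}$.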
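A pure-Python sketch (not lucid's implementation; all helper names are hypothetical) that evaluates the log-Jacobian from $(x, y)$ with the closed form $\sum_{k=0}^{K-2}[\log y_k + \log(1 - z_k)]$ and cross-checks it against the determinant of a central finite-difference Jacobian of $\mathbf{x} \mapsto (y_0, \dots, y_{K-2})$. It assumes the 0-based offset $\log(K-1-k)$, although the offset does not affect the result.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def forward(x):
    K = len(x) + 1
    y, remaining = [], 1.0
    for k, xk in enumerate(x):
        zk = sigmoid(xk - math.log(K - 1 - k))
        y.append(zk * remaining)
        remaining *= 1.0 - zk
    y.append(remaining)
    return y


def log_abs_det_jacobian(x, y):
    """Closed form sum_k [log y_k + log(1 - z_k)]; x is unused, kept to mirror (x, y)."""
    K = len(y)
    total, used = 0.0, 0.0
    for k in range(K - 1):
        zk = y[k] / (1.0 - used)  # recover the break fraction from y
        total += math.log(y[k]) + math.log1p(-zk)
        used += y[k]
    return total


def det(m):
    """Determinant by Gaussian elimination with partial pivoting."""
    m = [row[:] for row in m]
    n, d = len(m), 1.0
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(m[r][c]))
        if m[p][c] == 0.0:
            return 0.0
        if p != c:
            m[c], m[p] = m[p], m[c]
            d = -d  # a row swap flips the sign
        d *= m[c][c]
        for r in range(c + 1, n):
            f = m[r][c] / m[c][c]
            for j in range(c, n):
                m[r][j] -= f * m[c][j]
    return d


def fd_log_abs_det(x, h=1e-6):
    """Central differences on the first K-1 outputs (the last is implied by sum-to-one)."""
    n = len(x)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        yp, ym = forward(xp), forward(xm)
        for i in range(n):
            J[i][j] = (yp[i] - ym[i]) / (2 * h)
    return math.log(abs(det(J)))


x = [0.3, -1.2, 2.0]
y = forward(x)
print(log_abs_det_jacobian(x, y), fd_log_abs_det(x))  # agree to roughly 1e-8
print(sum(math.log(v) for v in y))  # same value via the telescoped form
```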