class CatTransform extends Transform
CatTransform(transforms: list[Transform], dim: int = 0, lengths: list[int] | None = None)
Apply different transforms to contiguous partitions along an axis.
Splits the input along dim into contiguous chunks of size
lengths[i] (or equal partitions if lengths is None),
applies the i-th transform to the i-th chunk, and
concatenates the results back along the same axis. Unlike
StackTransform, which gives each transform a single index along
the axis, each partition here can span several positions, and
partitions can differ in length.
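The split, apply, concatenate pipeline can be sketched on a plain Python list standing in for a 1-D tensor with dim=0. This is a hedged illustration, not lucid's implementation: cat_forward is a hypothetical helper, each transform is modelled as an elementwise float function, and the lengths are passed explicitly. The call mirrors the Examples section, assuming AffineTransform(0.0, 2.0) maps x to 0.0 + 2.0 * x.

```python
from __future__ import annotations

import math
from typing import Callable, Sequence


def cat_forward(
    fns: Sequence[Callable[[float], float]],
    x: Sequence[float],
    lengths: Sequence[int],
) -> list[float]:
    """Hypothetical helper: split x, apply fns[i] to chunk i, concatenate."""
    # Sanity checks belonging to this sketch, not documented lucid behaviour.
    if len(fns) != len(lengths):
        raise ValueError("need exactly one length per transform")
    if sum(lengths) != len(x):
        raise ValueError("lengths must sum to the axis size")
    out: list[float] = []
    start = 0
    for fn, n in zip(fns, lengths):
        chunk = x[start:start + n]        # i-th contiguous partition
        out.extend(fn(v) for v in chunk)  # i-th transform, applied elementwise
        start += n
    return out


# exp on the first 2 entries, y = 2x on the last 3.
result = cat_forward([math.exp, lambda v: 2.0 * v], [0.0, 0.0, 1.0, 2.0, 3.0], [2, 3])
print(result)  # [1.0, 1.0, 2.0, 4.0, 6.0]
```

The output keeps all five positions, which is why the example's result still has shape (5,).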
Parameters
transforms : list[Transform]
    One transform per partition.
dim : int, default 0
    Concatenation dimension.
lengths : list[int] | None, default None
    Length of each partition. If None, the axis size must be
    divisible by len(transforms) and equal partitions are used.
Raises
ValueError
    If transforms is empty, or if the axis size is not divisible
    by len(transforms) when lengths is None.
Notes
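Forward (with $x_i$ the i-th partition of $x$ along dim, $T_i$ the i-th
transform, and $n$ = len(transforms)):
$$y = \operatorname{cat}_{\mathrm{dim}}\big(T_1(x_1), \ldots, T_n(x_n)\big)$$
Inverse: split $y$ the same way, invert each partition, and re-concatenate:
$$x = \operatorname{cat}_{\mathrm{dim}}\big(T_1^{-1}(y_1), \ldots, T_n^{-1}(y_n)\big)$$
Log Jacobian determinant: per-partition Jacobians concatenated back
along dim:
$$\log\lvert\det J\rvert = \operatorname{cat}_{\mathrm{dim}}\big(\log\lvert\det J_{T_1}(x_1)\rvert, \ldots, \log\lvert\det J_{T_n}(x_n)\rvert\big)$$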
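The inverse follows the same pattern as the forward pass: split with the same lengths, invert chunk by chunk, re-concatenate. A minimal sketch on plain lists, assuming elementwise transforms; split and cat_inverse are hypothetical helpers, and the inverses (log for exp, division by 2 for y = 2x) are supplied by hand rather than taken from any lucid API.

```python
from __future__ import annotations

import math
from typing import Callable, Sequence


def split(x: Sequence[float], lengths: Sequence[int]) -> list[list[float]]:
    """Cut x into contiguous chunks of the given lengths."""
    chunks: list[list[float]] = []
    start = 0
    for n in lengths:
        chunks.append(list(x[start:start + n]))
        start += n
    return chunks


def cat_inverse(
    inverses: Sequence[Callable[[float], float]],
    y: Sequence[float],
    lengths: Sequence[int],
) -> list[float]:
    """Split y, apply the i-th inverse to the i-th chunk, re-concatenate."""
    x: list[float] = []
    for inv, chunk in zip(inverses, split(y, lengths)):
        x.extend(inv(v) for v in chunk)
    return x


# Undo exp on the first 2 entries and y = 2x on the last 3.
recovered = cat_inverse([math.log, lambda v: v / 2.0], [1.0, 1.0, 2.0, 4.0, 6.0], [2, 3])
print(recovered)  # [0.0, 0.0, 1.0, 2.0, 3.0]
```

Reusing the forward lengths is what makes the round trip exact: the i-th inverse only ever sees values that the i-th transform produced.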
Examples
>>> import lucid
>>> from lucid.distributions.transforms import ExpTransform, AffineTransform, CatTransform
>>> T = CatTransform([ExpTransform(), AffineTransform(0.0, 2.0)],
... dim=0, lengths=[2, 3])
>>> T(lucid.tensor([0.0, 0.0, 1.0, 2.0, 3.0])).shape
(5,)
Methods (2)
__init__(transforms: list[Transform], dim: int = 0, lengths: list[int] | None = None) → None
Store the per-partition transforms, concatenation axis, and partition lengths.
Raises
ValueError
    If transforms is empty.
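A minimal sketch of the bookkeeping described for __init__ and the lengths parameter, in plain Python. PartitionSpec and its resolve method are hypothetical names, not part of lucid. Since __init__ is documented to reject only an empty transforms list, the sketch assumes the divisibility check runs later, once the axis size of an actual input is known.

```python
from __future__ import annotations


class PartitionSpec:
    """Hypothetical helper: stores partition settings, resolves lengths lazily."""

    def __init__(self, num_transforms: int, dim: int = 0,
                 lengths: list[int] | None = None) -> None:
        # Mirrors the documented __init__ check: no transforms is an error.
        if num_transforms == 0:
            raise ValueError("transforms must be non-empty")
        self.num_transforms = num_transforms
        self.dim = dim
        self.lengths = None if lengths is None else list(lengths)

    def resolve(self, axis_size: int) -> list[int]:
        """Return the partition lengths for an axis of the given size."""
        if self.lengths is not None:
            # Explicit lengths are used as given.
            return list(self.lengths)
        # Equal partitions require the axis size to divide evenly.
        if axis_size % self.num_transforms != 0:
            raise ValueError(
                f"axis size {axis_size} is not divisible by "
                f"{self.num_transforms} transforms"
            )
        return [axis_size // self.num_transforms] * self.num_transforms
```

For example, PartitionSpec(3).resolve(6) gives [2, 2, 2], PartitionSpec(2).resolve(5) raises ValueError, and PartitionSpec(2, lengths=[2, 3]).resolve(5) returns [2, 3] because the divisibility rule applies only when lengths is None.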
log_abs_det_jacobian(x: Tensor, y: Tensor) → Tensor
Per-partition Jacobians concatenated back along dim.
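For elementwise transforms, "per-partition Jacobians concatenated back along dim" means each position carries its own log|dy/dx|. A hedged sketch on plain lists: cat_log_abs_det_jacobian is a hypothetical helper, and each partition's log-Jacobian is written by hand as a function of x only (exp gives log|J| = x; y = 2x gives log|J| = log 2), so y from the documented signature is not needed here.

```python
from __future__ import annotations

import math
from typing import Callable, Sequence


def cat_log_abs_det_jacobian(
    ladj_fns: Sequence[Callable[[float], float]],
    x: Sequence[float],
    lengths: Sequence[int],
) -> list[float]:
    """Per-element log|dy/dx| for each partition, concatenated in order."""
    out: list[float] = []
    start = 0
    for ladj, n in zip(ladj_fns, lengths):
        out.extend(ladj(v) for v in x[start:start + n])
        start += n
    return out


# exp: d/dx exp(x) = exp(x), so log|J| = x.
# y = 2x: dy/dx = 2 everywhere, so log|J| = log 2.
ladj = cat_log_abs_det_jacobian(
    [lambda v: v, lambda v: math.log(2.0)],
    [0.0, 0.0, 1.0, 2.0, 3.0],
    [2, 3],
)
```

The result has one entry per input position, the same length as x, which is what concatenating along dim preserves.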