fn

where

where(cond: Tensor, x: Tensor, y: Tensor) -> Tensor

Elementwise selection between x and y based on cond.

For each element position the result is x where cond is True and y otherwise. All three inputs broadcast against each other.
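
The broadcasting behaviour can be sketched with NumPy's `np.where` as a stand-in. This assumes that lucid.where follows NumPy-style broadcasting rules, which the description suggests but does not spell out; the array names are illustrative only.

```python
import numpy as np

# NumPy stand-in for illustration; assumes lucid.where broadcasts the same way.
cond = np.array([[True], [False]])   # shape (2, 1)
x = np.array([1.0, 2.0, 3.0])        # shape (3,)
y = np.array(0.0)                    # shape (), a scalar

out = np.where(cond, x, y)           # all three inputs broadcast to shape (2, 3)
print(out.shape)
print(out)
```

The first row takes its values from x because cond is True there; the second row is filled with the scalar y.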

Parameters

cond: Tensor
Boolean tensor.
x: Tensor
Values used where cond is True.
y: Tensor
Values used where cond is False.

Returns

Tensor

Tensor whose shape is the result of broadcasting cond, x, and y together.

Notes

Mathematically,

$$\text{out}_i \;=\; \begin{cases} x_i & \text{if } \text{cond}_i \\ y_i & \text{otherwise.} \end{cases}$$

Differentiable with respect to both x and y: the upstream gradient is gated by cond, flowing to x at positions where cond is True and to y where it is False.
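
One way to picture the gating is a hand-written backward pass. The sketch below uses NumPy rather than lucid, and the helper `where_backward` is hypothetical, not part of the library. It also assumes all inputs already share one shape; with broadcasting, each gradient would additionally need to be summed over the broadcast axes.

```python
import numpy as np

def where_backward(cond, grad_out):
    """Hypothetical helper: split an upstream gradient between x and y."""
    zeros = np.zeros_like(grad_out)
    grad_x = np.where(cond, grad_out, zeros)  # x gets the gradient where cond is True
    grad_y = np.where(cond, zeros, grad_out)  # y gets the gradient where cond is False
    return grad_x, grad_y

a = np.array([1.0, 2.0, 3.0])
cond = a > 1.5                    # [False, True, True]
grad_out = np.ones_like(a)        # upstream gradient of ones

grad_x, grad_y = where_backward(cond, grad_out)
print(grad_x)  # [0. 1. 1.]
print(grad_y)  # [1. 0. 0.]
```

Every position contributes its gradient to exactly one of the two inputs, so grad_x + grad_y equals the upstream gradient.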

Examples

>>> import lucid
>>> a = lucid.tensor([1.0, 2.0, 3.0])
>>> b = lucid.tensor([-1.0, -2.0, -3.0])
>>> lucid.where(a > 1.5, a, b)
Tensor([-1.,  2.,  3.])