lucid.where

lucid.where(condition: Tensor, a: Tensor, b: Tensor, /) Tensor

The where function performs element-wise selection based on a condition tensor, returning elements chosen from either a or b based on whether the condition is true.

Function Signature

def where(condition: Tensor, a: Tensor, b: Tensor) -> Tensor

Parameters

  • condition (Tensor): A boolean tensor indicating which elements to choose from a (if True) or b (if False).

  • a (Tensor): Values selected at positions where condition is True.

  • b (Tensor): Values selected at positions where condition is False.

Returns

  • Tensor: A new tensor composed of values from a and b selected according to condition.

Gradient Calculation

\[\begin{split}\frac{\partial \text{out}}{\partial a} = \text{where}(\text{cond}, \text{grad}, 0) \\ \frac{\partial \text{out}}{\partial b} = \text{where}(\neg \text{cond}, \text{grad}, 0) \\ \frac{\partial \text{out}}{\partial \text{cond}} = 0\end{split}\]

Note

The gradient for condition is always zero since it is not differentiable. Gradients for a and b are masked accordingly.

Example

>>> import lucid
>>> cond = lucid.tensor([[True, False], [False, True]])
>>> a = lucid.tensor([[1, 2], [3, 4]], requires_grad=True)
>>> b = lucid.tensor([[10, 20], [30, 40]], requires_grad=True)
>>> out = lucid.where(cond, a, b)
>>> print(out)
Tensor([[ 1, 20], [30,  4]], grad=None)