linearize

linearize(func: Callable[..., Tensor | tuple[Tensor, ...]], *primals: Tensor) -> tuple

Linearise func around primals for reuse across many tangents.

Returns the primal output together with a callable linear_fn that applies the linear (first-order) term of the Taylor expansion of func at primals. Mathematically, linear_fn(t) evaluates $J_f(\mathrm{primals}) \, t$, the same quantity as jvp, but the linearisation cost is paid once even if many tangent vectors are queried afterwards.
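The "pay once, query many times" idea can be illustrated without the library. The sketch below is not how linearize is implemented; it is a pure-Python stand-in (the names linearize_sketch, square and Vector are hypothetical) that estimates the Jacobian once by central differences and then answers each tangent query with a matrix-vector product.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def linearize_sketch(
    f: Callable[[Vector], Vector], x: Vector, eps: float = 1e-6
) -> Tuple[Vector, Callable[[Vector], Vector]]:
    """Hypothetical stand-in for linearize: build J_f(x) once, then reuse it."""
    y = f(x)

    # One-time cost: estimate each Jacobian column by central differences.
    columns = []
    for j in range(len(x)):
        hi = list(x)
        lo = list(x)
        hi[j] += eps
        lo[j] -= eps
        columns.append([(a - b) / (2 * eps) for a, b in zip(f(hi), f(lo))])

    def linear_fn(t: Vector) -> Vector:
        # Per-tangent cost: only the product J_f(x) @ t, no call to f.
        return [
            sum(columns[j][i] * t[j] for j in range(len(x)))
            for i in range(len(y))
        ]

    return y, linear_fn


def square(v: Vector) -> Vector:
    return [a * a for a in v]


y, lin = linearize_sketch(square, [1.0, 2.0, 3.0])
jvp_ones = lin([1.0, 1.0, 1.0])   # approximately 2 * x
jvp_basis = lin([1.0, 0.0, 0.0])  # approximately the first Jacobian column
```

Both calls to lin reuse the same stored columns, which is the saving the description refers to. A real implementation would more likely record the forward-mode computation than materialise a dense Jacobian.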

Parameters

func : Callable
Differentiable function.
*primals : Tensor
Points at which to linearise func.

Returns

tuple

(primals_out, linear_fn). Calling linear_fn(*tangents) returns the JVP at primals against tangents.

Notes

Conceptually equivalent to a Taylor expansion truncated at first order:

$$f(x + t) \approx f(x) + J_f(x) \, t.$$
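As a quick numerical check of this approximation, the sketch below uses the scalar function f(x) = x ** 2, for which the first-order remainder is exactly t ** 2. It is plain Python and does not call the library; jvp_f is a hand-derived, hypothetical helper standing in for what linear_fn would compute.

```python
def f(x: float) -> float:
    return x ** 2


def jvp_f(x: float, t: float) -> float:
    # Hand-derived J_f(x) * t for f(x) = x ** 2.
    return 2.0 * x * t


x = 3.0
errors = {}
for t in (1e-1, 1e-2, 1e-3):
    first_order = f(x) + jvp_f(x, t)
    # Gap between the true value and the first-order estimate;
    # for this f it equals t ** 2 up to floating-point rounding.
    errors[t] = abs(f(x + t) - first_order)
```

The error shrinks quadratically as t decreases, which is the expected behaviour when the expansion is truncated after the linear term.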

The returned linear_fn is flagged for vmap isolation so that composing vmap(linear_fn) slices tangents one at a time.

Examples

>>> import lucid
>>> from lucid.func import linearize
>>> f = lambda x: x ** 2
>>> x = lucid.tensor([1.0, 2.0, 3.0])
>>> y, lin = linearize(f, x)
>>> lin(lucid.ones_like(x))  # 2 * x