linearize
linearize(func: Callable[..., Tensor | tuple[Tensor, ...]], *primals: Tensor) -> tuple
Linearise func around primals for reuse across many tangents.
Returns the primal output together with a callable linear_fn that
applies the linear (first-order) term of the Taylor expansion of func at primals.
Mathematically, with x denoting primals, linear_fn(t) evaluates the Jacobian-vector
product $J_f(x)\,t$, the same quantity that jvp computes. The difference is cost: the
linearisation is paid once, however many tangent vectors are queried afterwards.
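The "pay once, query many times" idea can be sketched in plain Python, assuming real-valued inputs held in lists rather than lucid tensors. This is not lucid's implementation: it materialises the Jacobian with one forward-mode (dual-number) pass per input coordinate, whereas a tracing-based linearize would typically record the JVP computation instead. The `Dual` class and this `linearize` are hypothetical names used only for illustration.

```python
from typing import Callable, List, Sequence, Tuple

Vec = List[float]


class Dual:
    """Forward-mode dual number val + tan * eps, where eps ** 2 == 0."""

    def __init__(self, val: float, tan: float = 0.0) -> None:
        self.val = val
        self.tan = tan

    @staticmethod
    def lift(other: "Dual | float") -> "Dual":
        # Treat plain numbers as constants (zero tangent)
        return other if isinstance(other, Dual) else Dual(float(other))

    def __add__(self, other):
        o = Dual.lift(other)
        return Dual(self.val + o.val, self.tan + o.tan)

    def __mul__(self, other):
        o = Dual.lift(other)
        # Product rule
        return Dual(self.val * o.val, self.val * o.tan + self.tan * o.val)

    def __pow__(self, n: int):
        # Power rule for a constant integer exponent
        return Dual(self.val ** n, n * self.val ** (n - 1) * self.tan)

    __radd__ = __add__
    __rmul__ = __mul__


def linearize(
    func: Callable[[List[Dual]], List[Dual]], x: Sequence[float]
) -> Tuple[Vec, Callable[[Sequence[float]], Vec]]:
    """Hypothetical sketch: return f(x) and a linear map t -> J_f(x) @ t."""
    n = len(x)
    # Primal output: one plain evaluation
    y = [d.val for d in func([Dual(v) for v in x])]
    m = len(y)
    # Jacobian J[i][j] = d y_i / d x_j, one forward pass per input coordinate
    jac = [[0.0] * n for _ in range(m)]
    for j in range(n):
        seeded = [Dual(v, 1.0 if i == j else 0.0) for i, v in enumerate(x)]
        for i, d in enumerate(func(seeded)):
            jac[i][j] = d.tan

    def linear_fn(t: Sequence[float]) -> Vec:
        # J @ t: reuses the stored Jacobian and never calls func again
        return [sum(jac[i][j] * t[j] for j in range(n)) for i in range(m)]

    return y, linear_fn


def f(xs):
    return [v ** 2 for v in xs]


y, lin = linearize(f, [1.0, 2.0, 3.0])
print(y, lin([1.0, 1.0, 1.0]))  # prints [1.0, 4.0, 9.0] [2.0, 4.0, 6.0]
```

All calls to func happen inside linearize; every later call to lin is a matrix-vector product, which is the trade-off the paragraph above describes.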
Parameters
func : Callable
    Differentiable function.
*primals : Tensor
    Points at which to linearise func.
Returns
tuple
    (primals_out, linear_fn). Calling linear_fn(*tangents) returns the JVP of
    func at primals along tangents, one tangent per primal.
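For several primals, the (primals_out, linear_fn) contract can be sketched for a scalar function of scalar inputs. This sketch swaps in complex-step differentiation (exact up to rounding for real-analytic code, but it requires func to accept complex arguments, here via cmath) in place of lucid's automatic differentiation. `linearize`, `f`, and the step size `H` are hypothetical and not part of the lucid API.

```python
import cmath
from typing import Callable, Tuple

H = 1e-20  # complex-step size; no subtractive cancellation, so it can be tiny


def linearize(func: Callable[..., complex], *primals: float) -> Tuple[float, Callable[..., float]]:
    """Hypothetical sketch: scalar func of several scalar primals."""
    out = func(*primals).real
    partials = []
    for k in range(len(primals)):
        # Perturb only the k-th primal along the imaginary axis:
        # Im f(x + i*H*e_k) / H approximates df/dx_k (complex-step derivative)
        shifted = [complex(p, H if i == k else 0.0) for i, p in enumerate(primals)]
        partials.append(func(*shifted).imag / H)

    def linear_fn(*tangents: float) -> float:
        # JVP = sum_k (df/dx_k) * t_k, using partials computed once above
        if len(tangents) != len(primals):
            raise ValueError("expected one tangent per primal")
        return sum(d * t for d, t in zip(partials, tangents))

    return out, linear_fn


def f(x, y):
    return x * y + cmath.sin(x)


out, lin = linearize(f, 0.5, 2.0)
print(out, lin(1.0, 0.0), lin(0.0, 1.0))  # f, df/dx = y + cos(x), df/dy = x
```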
Notes
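Conceptually equivalent to a Taylor expansion truncated at first order, with x denoting primals:
$$f(x + t) \approx f(x) + J_f(x)\,t$$
primals_out is $f(x)$ and linear_fn(t) supplies the linear term $J_f(x)\,t$.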
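A quick numerical check of the first-order truncation, using the elementwise square f(x) = x ** 2 from the example below with its linearisation written out by hand (J = diag(2x)). Plain lists stand in for lucid tensors, and `make_linear_fn` and `remainder` are hypothetical helpers. The gap between f(x + h t) and f(x) + h * linear_fn(t) should shrink like h ** 2.

```python
from typing import Callable, List

Vec = List[float]


def f(x: Vec) -> Vec:
    return [v * v for v in x]


def make_linear_fn(x: Vec) -> Callable[[Vec], Vec]:
    # Hand-derived linearisation of elementwise squaring at x: J = diag(2x)
    return lambda t: [2.0 * xi * ti for xi, ti in zip(x, t)]


x = [1.0, 2.0, 3.0]
t = [1.0, -1.0, 0.5]
lin = make_linear_fn(x)


def remainder(h: float) -> float:
    """Max over components of |f(x + h t) - (f(x) + h * lin(t))|."""
    shifted = [xi + h * ti for xi, ti in zip(x, t)]
    approx = [fi + h * li for fi, li in zip(f(x), lin(t))]
    return max(abs(a - b) for a, b in zip(f(shifted), approx))


for h in (1e-1, 1e-2, 1e-3):
    # For squaring the remainder is exactly (h * t_i) ** 2, i.e. O(h ** 2)
    print(f"h={h:g}  remainder={remainder(h):.3e}")
```

Shrinking h by 10 shrinks the remainder by about 100, which is what "truncated at first order" means in practice.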
The returned linear_fn is flagged for vmap isolation, so vmap(linear_fn)
processes a batch of tangents one slice at a time.
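Slicing tangents one at a time means a batched call behaves like a loop over the leading (batch) axis: each tangent is pushed through linear_fn on its own and the results are stacked. The sketch below assumes plain lists instead of lucid tensors; `vmap_sliced` is a hypothetical helper and says nothing about how lucid's vmap handles the isolation flag internally. Feeding the standard basis returns the columns of the Jacobian, here diag(2, 4, 6).

```python
from typing import Callable, List, Sequence

Vec = List[float]


def vmap_sliced(fn: Callable[[Vec], Vec]) -> Callable[[Sequence[Vec]], List[Vec]]:
    """Hypothetical helper: apply fn to each batch slice separately, then stack."""
    def batched(tangents: Sequence[Vec]) -> List[Vec]:
        return [fn(list(t)) for t in tangents]
    return batched


# A linear_fn for f(x) = x ** 2 at x = [1, 2, 3]; its Jacobian is diag(2, 4, 6)
x = [1.0, 2.0, 3.0]


def linear_fn(t: Vec) -> Vec:
    return [2.0 * xi * ti for xi, ti in zip(x, t)]


basis = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
rows = vmap_sliced(linear_fn)(basis)
print(rows)  # row j is J @ e_j, i.e. column j of the Jacobian
```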
Examples
>>> import lucid
>>> from lucid.func import linearize
>>> f = lambda x: x ** 2
>>> x = lucid.tensor([1.0, 2.0, 3.0])
>>> y, lin = linearize(f, x)
>>> lin(lucid.ones_like(x))  # 2 * x, i.e. [2., 4., 6.]