vmap

vmap(func: Callable[..., Tensor | tuple[Tensor, ...]], in_dims: int | tuple[int | None, ...] = 0, out_dims: int | tuple[int, ...] = 0, randomness: str = 'error', chunk_size: int | None = None, strategy: str = 'auto') -> Callable

Vectorise func over a batch axis to produce a batched function.

Lifts a function operating on a single example into one operating on a whole batch in a single dispatch, without writing an explicit Python loop. Mathematically, if $f : \mathbb{R}^d \to \mathbb{R}^e$, then vmap(f) realises $F : \mathbb{R}^{B \times d} \to \mathbb{R}^{B \times e}$ such that F(x)[b] = f(x[b]) for every batch index $b$. The transform composes with grad, jacrev, jvp, and vjp to yield per-sample gradients, batched Jacobians, and higher-order constructs.
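
The defining property F(x)[b] = f(x[b]) can be modelled with a plain-Python loop over nested lists. This is only a semantic sketch: `reference_vmap` and `square_and_shift` are hypothetical names introduced here, and the loop says nothing about how lucid dispatches the batched computation.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def reference_vmap(
    f: Callable[[Vector], Vector],
) -> Callable[[Sequence[Vector]], List[Vector]]:
    """Hypothetical reference model: apply f to each row of a batch (axis 0)."""

    def batched(xs: Sequence[Vector]) -> List[Vector]:
        # F(x)[b] = f(x[b]) for every batch index b.
        return [f(x) for x in xs]

    return batched


def square_and_shift(x: Vector) -> Vector:
    # Per-example function f: R^d -> R^d.
    return [v * v + 1.0 for v in x]


batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # B = 3, d = 2
out = reference_vmap(square_and_shift)(batch)
```

A vectorised implementation should agree with this loop element by element, which makes the loop a convenient oracle when testing a batched function.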

Parameters

func : Callable
Function to vectorise. Must accept and return tensors (or tuples of tensors).
in_dims : int or tuple of (int or None) = 0
Batch axis for each input. An int applies to all positional arguments; a tuple gives per-argument control. Use None to broadcast an argument unchanged across the batch. Default 0.
out_dims : int or tuple of int = 0
Where the batch axis appears in the output(s). Default 0.
randomness : str = 'error'
"error" (default) forbids random ops inside func; "different" and "same" allow them with shared RNG state.
chunk_size : int or None = None
If set, process the batch in chunks of this size to cap peak memory. Applies in both vectorised and isolated strategies. Default None (no chunking).
strategy : str = 'auto'
"auto" (default) picks isolated mode for transforms that materialise per-output backward passes (jacrev/jacfwd/hessian) and falls back to vectorised mode otherwise. "vectorized" always moves the batch axis to the front and calls func once. "isolated" always loops per-element in Python.

Returns

Callable

A new function with the batched semantics described above.
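
How the returned callable treats in_dims entries of None and a chunk_size setting can be sketched in plain Python. The helper below is a hypothetical model, not lucid's implementation: it supports only batch axis 0 or None, works on nested lists, and its chunk loop merely marks where a real engine might issue one dispatch per chunk.

```python
from typing import Callable, List, Optional, Tuple


def reference_vmap_with_dims(
    f: Callable[..., float],
    in_dims: Tuple[Optional[int], ...],
    chunk_size: Optional[int] = None,
) -> Callable[..., List[float]]:
    """Hypothetical model of in_dims (0 or None only) and chunk_size."""
    if any(d not in (0, None) for d in in_dims):
        raise NotImplementedError("this sketch handles only axis 0 or None")

    def batched(*args):
        # The batch size is taken from the arguments that are actually mapped.
        sizes = {len(a) for a, d in zip(args, in_dims) if d is not None}
        if len(sizes) != 1:
            raise ValueError("mapped arguments must share one batch size")
        (batch,) = sizes
        step = chunk_size if chunk_size else max(batch, 1)
        out: List[float] = []
        for start in range(0, batch, step):
            stop = min(start + step, batch)
            for b in range(start, stop):
                # Mapped args are indexed along axis 0; None args pass through whole.
                call_args = [a if d is None else a[b] for a, d in zip(args, in_dims)]
                out.append(f(*call_args))
        return out

    return batched


def dot(x, w):
    # Per-example function: inner product of one sample with shared weights.
    return sum(xi * wi for xi, wi in zip(x, w))


xs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
w = [10.0, 1.0]  # shared across the batch, hence in_dims entry None
full = reference_vmap_with_dims(dot, in_dims=(0, None))(xs, w)
chunked = reference_vmap_with_dims(dot, in_dims=(0, None), chunk_size=2)(xs, w)
```

In this model chunking changes only how the batch is partitioned, so `chunked` and `full` hold the same values; the last chunk is simply shorter when the batch size is not a multiple of chunk_size.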

Notes

In vectorised mode there is exactly one underlying engine dispatch: on GPU this becomes a single Metal kernel launch across all batch elements via MLX; on CPU it becomes an Accelerate BLAS / vDSP call over the fully batched tensor. Reductions inside func must specify dim — an unqualified .sum() would also collapse the batch axis, which is rarely desired. In-place ops inside the vectorised function are unsupported.
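
The reduction caveat can be seen on nested lists, where the outer list plays the role of the batch axis. This is an illustrative sketch with hypothetical helper names (`sum_all`, `sum_last_axis`), not lucid code: the first helper mimics an unqualified reduction, the second a reduction restricted to the feature axis.

```python
from typing import List

Batch = List[List[float]]


def sum_all(x: Batch) -> float:
    # Analogue of an unqualified reduction: every axis is collapsed,
    # so the batch axis disappears as well.
    return sum(v for row in x for v in row)


def sum_last_axis(x: Batch) -> List[float]:
    # Analogue of a reduction over the feature axis only:
    # one result per batch element survives.
    return [sum(row) for row in x]


batched_input = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
collapsed = sum_all(batched_input)
per_example = sum_last_axis(batched_input)
```

Here `collapsed` is a single number for the whole batch, while `per_example` keeps one entry per sample, which is the behaviour a batched function normally needs.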

Examples

Per-sample gradients:
>>> import lucid
>>> from lucid.func import grad, vmap
>>> f = lambda x: (x ** 2).sum()
>>> X = lucid.randn(32, 4)
>>> per_sample = vmap(grad(f))(X)  # shape (32, 4)