vmap
vmap(func: Callable[..., Tensor | tuple[Tensor, ...]], in_dims: int | tuple[int | None, ...] = 0, out_dims: int | tuple[int, ...] = 0, randomness: str = 'error', chunk_size: int | None = None, strategy: str = 'auto') → Callable
Vectorise func over a batch axis to produce a batched function.
Lifts a function operating on a single example into one operating on a
whole batch in a single dispatch, without writing an explicit Python
loop. Mathematically, if f maps a single example to its output, then
vmap(f) realises a batched function F such that F(x)[b] = f(x[b]) for every
batch index b. The transform composes with grad,
jacrev, jvp, and vjp to yield per-sample
gradients, batched Jacobians, and higher-order constructs.
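The identity F(x)[b] = f(x[b]) can be illustrated with a pure-Python reference model. This is a sketch of the semantics only, not of how vmap is implemented: `reference_vmap` and `squared_norm` are hypothetical names, nested lists stand in for tensors, and the batch axis is assumed to be the leading one.

```python
from typing import Callable, Sequence


def reference_vmap(f: Callable) -> Callable:
    """Hypothetical reference model: apply f to each slice along the leading axis."""

    def batched(x: Sequence):
        # F(x)[b] = f(x[b]) for every batch index b.
        return [f(x[b]) for b in range(len(x))]

    return batched


def squared_norm(v: Sequence[float]) -> float:
    # Written for a single example (one vector), with no batch awareness.
    return sum(t * t for t in v)


batch = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.5]]
batched_norm = reference_vmap(squared_norm)
result = batched_norm(batch)
```

The explicit loop here is exactly what the real transform is meant to avoid; it only serves as a specification that a vectorised result can be compared against.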
Parameters
func : Callable
in_dims : int or tuple of (int or None), default 0
    An int applies to all positional arguments; a tuple gives
    per-argument control. Use None to broadcast an argument unchanged
    across the batch.
out_dims : int or tuple of int, default 0
randomness : str, default 'error'
    "error" (default) forbids random ops inside func; "different" and
    "same" allow them with shared RNG state.
chunk_size : int or None, default None
strategy : str, default 'auto'
    "auto" (default) picks isolated mode for transforms that
    materialise per-output backward passes (jacrev/jacfwd/hessian) and
    falls back to vectorised mode otherwise. "vectorized" always
    moves the batch axis to the front and calls func once.
    "isolated" always loops per-element in Python.
Returns
Callable
    A new function with the batched semantics described above.
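The in_dims rules (one int for all arguments, a tuple for per-argument control, None to pass an argument through unbatched) can be sketched with the same kind of pure-Python model. Everything below is an assumption made for illustration: `reference_vmap`, `take`, `axis_length`, and `dot` are hypothetical helpers, only batch axes 0 and 1 of nested lists are handled, and results are always stacked along the leading axis, which corresponds to out_dims=0.

```python
from typing import Callable, Optional, Sequence, Tuple, Union

InDims = Union[int, Tuple[Optional[int], ...]]


def take(x: Sequence, axis: int, b: int):
    """Select index b along axis 0 or 1 of a nested list."""
    if axis == 0:
        return x[b]
    if axis == 1:
        return [row[b] for row in x]
    raise ValueError("this sketch only handles batch axes 0 and 1")


def axis_length(x: Sequence, axis: int) -> int:
    """Length of a nested list along axis 0 or 1."""
    if axis == 0:
        return len(x)
    if axis == 1:
        return len(x[0])
    raise ValueError("this sketch only handles batch axes 0 and 1")


def reference_vmap(f: Callable, in_dims: InDims = 0) -> Callable:
    """Hypothetical reference model of in_dims; outputs are stacked on axis 0."""

    def batched(*args):
        # A single int applies to every positional argument.
        dims = (in_dims,) * len(args) if isinstance(in_dims, int) else tuple(in_dims)
        if len(dims) != len(args):
            raise ValueError("in_dims must have one entry per positional argument")
        # Arguments marked None are not batched and do not define a batch size.
        sizes = {axis_length(a, d) for a, d in zip(args, dims) if d is not None}
        if len(sizes) != 1:
            raise ValueError("batched arguments must agree on a single batch size")
        (batch,) = sizes
        return [
            f(*(a if d is None else take(a, d, b) for a, d in zip(args, dims)))
            for b in range(batch)
        ]

    return batched


def dot(u: Sequence[float], w: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, w))


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # three 2-vectors, batch axis 0
w = [10.0, 1.0]                            # shared weights, passed unbatched

rows = reference_vmap(dot, in_dims=(0, None))(X, w)
cols = reference_vmap(dot, in_dims=(1, None))(X, [1.0, 1.0, 1.0])
```

With in_dims=(0, None) each row of X is paired with the same w; with in_dims=(1, None) the batch runs over the two columns of X instead.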
Notes
In vectorised mode there is exactly one underlying engine dispatch:
on GPU this becomes a single Metal kernel launch across all batch
elements via MLX; on CPU it becomes an Accelerate BLAS / vDSP call
over the fully batched tensor. Reductions inside func must
specify dim — an unqualified .sum() would also collapse the
batch axis, which is rarely desired. In-place ops inside the
vectorised function are unsupported.
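The reduction caveat above can be seen without any tensor library. In vectorised mode func receives the whole batch at once, so a reduction over every axis also sums across examples. The sketch below uses nested lists as stand-ins for tensors; `sum_all` and `sum_last` are hypothetical helpers that model an unqualified reduction and a reduction restricted to the last axis, respectively.

```python
def sum_all(x):
    """Model of an unqualified reduction: collapses every axis, batch included."""
    if isinstance(x, list):
        return sum(sum_all(item) for item in x)
    return x


def sum_last(x):
    """Model of a reduction over the last axis only; leading axes survive."""
    if isinstance(x[0], list):
        return [sum_last(item) for item in x]
    return sum(x)


# A batch of three examples, as the function would see it in vectorised mode.
batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

collapsed = sum_all(batch)     # one scalar: the batch axis is gone
per_example = sum_last(batch)  # one value per example: the batch axis is kept
```

Only the second form preserves one result per batch element, which is what a per-example function is usually expected to return.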
Examples
Per-sample gradients:
>>> import lucid
>>> from lucid.func import grad, vmap
>>> f = lambda x: (x ** 2).sum()
>>> X = lucid.randn(32, 4)
>>> per_sample = vmap(grad(f))(X) # shape (32, 4)
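As a sanity check on what the per-sample gradient should contain: for f(x) = sum(x ** 2) the gradient of each example is 2 * x[b]. The sketch below confirms this with central differences in pure Python. It does not call lucid; `numeric_grad` is a hypothetical helper and the small batch is made up for illustration.

```python
def numeric_grad(f, x, eps=1e-4):
    """Central-difference gradient of a scalar function f at the vector x."""
    grad = []
    for i in range(len(x)):
        hi = list(x)
        lo = list(x)
        hi[i] += eps
        lo[i] -= eps
        grad.append((f(hi) - f(lo)) / (2 * eps))
    return grad


def f(x):
    # Same function as in the example above, written for a single example.
    return sum(t ** 2 for t in x)


X = [[1.0, -2.0, 0.5, 3.0], [0.0, 1.5, -1.0, 2.0]]

# Explicit per-example loop: the reference that vmap(grad(f)) should match.
per_sample = [numeric_grad(f, row) for row in X]
expected = [[2 * t for t in row] for row in X]
```

Each row of `per_sample` agrees with the corresponding row of `expected` up to floating-point rounding, and the result keeps the shape of the input batch.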