JIT Compilation

The Lucid JIT (Just-In-Time) compilation system accelerates model execution by tracing a forward pass, capturing it as an intermediate representation (IR) graph, applying optimization passes, and replaying it through a specialized executor. This approach can yield significant speedups—typically around 4x on CPU and up to 10x on MLX (Apple Silicon GPU)—while preserving full compatibility with Lucid’s autograd engine for training.

Overview

Traditional eager execution evaluates each operation immediately as it is encountered in Python. While this offers maximum flexibility, it carries per-operation Python overhead. The JIT compiler eliminates this overhead by recording the computation graph once and replaying it efficiently on subsequent calls.

The compilation pipeline consists of three stages:

  1. Tracing — A real forward pass is executed while a TracingContext records every operation, its inputs, outputs, and metadata into an IR graph.

  2. Optimization — The captured IR graph is refined by a series of passes (e.g., dead-node elimination) that remove redundant computations.

  3. Execution — A ForwardExecutor replays the optimized graph using the current parameter values, optionally constructing a backward graph for training.

Important

Compilation is lazy: calling model.compile() returns a wrapper immediately, but the actual tracing and compilation happen on the first forward call, not at compile time. Subsequent calls reuse the cached compiled plan.
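The lazy trace-and-replay pattern can be sketched in plain Python. This is a toy model, not Lucid's actual internals; `LazyCompiled` and `_Tracer` are hypothetical names used only for illustration:

```python
# Toy model of lazy trace-and-replay (NOT Lucid's real internals).
# The first call runs the real function and records each primitive op;
# later calls replay the recording without re-entering the Python function.

class LazyCompiled:
    def __init__(self, fn):
        self.fn = fn
        self.tape = None  # compiled "plan": filled in on the first call

    def __call__(self, x):
        if self.tape is None:          # lazy: compile on first forward
            self.tape = []
            trace = _Tracer(x, self.tape)
            self.fn(trace)             # real execution + recording
            return trace.value
        out = x
        for op in self.tape:           # replay the cached plan
            out = op(out)
        return out

class _Tracer:
    def __init__(self, value, tape):
        self.value, self.tape = value, tape

    def apply(self, op):               # every "operation" goes through here
        self.tape.append(op)
        self.value = op(self.value)
        return self

compiled = LazyCompiled(lambda t: t.apply(lambda v: v * 2).apply(lambda v: v + 1))
compiled(3)   # first call traces and returns 7
compiled(10)  # replays the tape: returns 21
```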

How It Works

Tracing

When a compiled model or function is called for the first time, Lucid activates a thread-local TracingContext. During this tracing pass, the normal forward computation runs as usual, but every operation decorated with @func_op automatically registers itself in the tracer.

The tracer records:

  • Input tensors and parameter tensors (registered with unique value IDs)

  • Each operation as an IRNode with its operator class, initialization arguments, tensor input/output IDs, non-tensor arguments, device placement, and gradient metadata

  • Output tensors of the forward pass
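The registration mechanism described above can be approximated with a small decorator that records into a thread-local context. Names and structure here are illustrative stand-ins, not Lucid's actual implementation of @func_op or TracingContext:

```python
import threading

# Sketch of decorator-based tracing into a thread-local context.
# Names are illustrative, not Lucid's real API.

_tls = threading.local()

def func_op(fn):
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)            # real computation runs as usual
        ctx = getattr(_tls, "ctx", None)
        if ctx is not None:                  # record only while tracing
            ctx.append({"op": fn.__name__, "args": args, "out": out})
        return out
    return wrapper

class TracingContext:
    def __enter__(self):
        _tls.ctx = []
        return self

    def __exit__(self, *exc):
        self.nodes = _tls.ctx                # finalized trace
        _tls.ctx = None

@func_op
def add(a, b):
    return a + b

@func_op
def mul(a, b):
    return a * b

with TracingContext() as tc:
    y = mul(add(1, 2), 4)  # executes normally AND records both ops
```

Outside the context, the same decorated ops run with no recording overhead beyond a single attribute check.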

Note

Because tracing executes the real forward function, no proxy tensors or symbolic shapes are involved. This means every standard Lucid operation is automatically supported without special JIT annotations.

IR Graph Construction

After the traced forward pass completes, the TracingContext finalizes its state into an IRGraph containing:

  • input_ids — ordered list of external input value IDs

  • param_ids — set of parameter value IDs (kept separate so the executor can inject fresh parameter values on each call)

  • nodes — list of IRNode objects in execution order

  • values — mapping from value ID to IRValue metadata (shape, dtype, device)

  • output_ids — value IDs that form the function’s return values

  • constant_ids — values that are neither inputs, parameters, nor produced by any node (e.g., inline constants created during the forward pass)
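Assuming the fields listed above, the IR containers might look roughly like the following dataclasses, including the rule that classifies constants. These are hypothetical shapes; the real Lucid classes may differ:

```python
from dataclasses import dataclass, field

# Hypothetical IR container shapes; field names follow the text above,
# but the real Lucid classes may differ.

@dataclass
class IRValue:
    shape: tuple
    dtype: str
    device: str

@dataclass
class IRNode:
    op: str                  # operator class name
    inputs: list             # input value IDs
    outputs: list            # output value IDs
    has_gradient: bool = True

@dataclass
class IRGraph:
    input_ids: list          # ordered external inputs
    param_ids: set           # parameters, injected fresh on each call
    nodes: list              # IRNode objects in execution order
    values: dict             # value ID -> IRValue metadata
    output_ids: list         # return values
    constant_ids: set = field(default_factory=set)

    def infer_constants(self):
        """A value is a constant if it is neither an input, a parameter,
        nor produced by any node."""
        produced = {v for n in self.nodes for v in n.outputs}
        self.constant_ids = (
            set(self.values) - set(self.input_ids) - self.param_ids - produced
        )
```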

Optimization Passes

The compiled IR graph is refined by a configurable pass pipeline before execution.

  • DeadNodeElimPass (applied in training & inference) — Removes operations whose outputs are not consumed by any downstream node or the final output. Uses a reverse liveness analysis starting from the output IDs.

  • NoGradStripPass (applied in inference only) — Sets has_gradient = False on all nodes, preventing the executor from constructing backward operations. This avoids unnecessary memory and compute overhead during inference.
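The reverse liveness analysis behind dead-node elimination can be sketched as follows, assuming each node exposes its input and output value IDs as described above (a simplified stand-in, not the actual pass implementation):

```python
# Sketch of reverse-liveness dead-node elimination: walk the node list
# backward from the graph outputs, keeping only nodes whose outputs are
# (transitively) needed.

def dead_node_elim(nodes, output_ids):
    live = set(output_ids)
    kept = []
    for node in reversed(nodes):            # nodes are in execution order
        if live & set(node["outputs"]):     # some output is still needed
            kept.append(node)
            live |= set(node["inputs"])     # its inputs become live too
    kept.reverse()
    return kept

# Node producing value 3 feeds nothing downstream, so it is pruned:
nodes = [
    {"inputs": [0], "outputs": [1]},
    {"inputs": [1], "outputs": [2]},
    {"inputs": [0], "outputs": [3]},   # dead auxiliary branch
]
dead_node_elim(nodes, output_ids=[2])  # keeps only the first two nodes
```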

Hint

The pass pipeline is selected automatically based on grad_enabled and module.training at the time of the first forward call. No manual configuration is needed.

Plan Caching

Each compiled plan is stored in a PlanCache keyed by a CacheKey that encodes:

  • Input tensor shapes, dtypes, and devices

  • Gradient state (lucid.grad_enabled())

  • Training mode (module.training for modules, always False for standalone functions)

Tip

If your model always receives the same input shape (e.g., fixed batch size), the plan is compiled once and reused for every subsequent call. Varying input shapes will create additional cache entries, up to a configurable maximum (default: 8) with FIFO eviction.
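A minimal sketch of such a shape-keyed cache with FIFO eviction follows. The `PlanCache` and key structure here are illustrative, not Lucid's actual classes; only the keying fields and the default maximum of 8 come from the text above:

```python
from collections import OrderedDict

# Sketch of a plan cache keyed on input signature + execution context,
# with FIFO eviction at a configurable maximum (default 8, per the text).

class PlanCache:
    def __init__(self, max_entries=8):
        self.max_entries = max_entries
        self._plans = OrderedDict()

    @staticmethod
    def make_key(shapes, dtypes, devices, grad_enabled, training):
        return (tuple(shapes), tuple(dtypes), tuple(devices),
                grad_enabled, training)

    def get_or_compile(self, key, compile_fn):
        if key not in self._plans:
            if len(self._plans) >= self.max_entries:
                self._plans.popitem(last=False)   # FIFO: evict oldest entry
            self._plans[key] = compile_fn()       # trace + compile once
        return self._plans[key]
```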

Execution

The ForwardExecutor replays the cached plan on each forward call:

  1. Builds a value_map by mapping input IDs to the actual input tensors, parameter IDs to the current parameter objects (reflecting any optimizer updates), and constant IDs to their stored live tensors.

  2. Iterates through the execution order, instantiating each operator from its saved class and initialization arguments, then calling the raw forward kernel (bypassing the Python-level @func_op wrapper for maximum speed).

  3. In training mode, attaches BackwardOperation objects to each output tensor, connecting them to their input tensors via weak references. This preserves the full autograd graph for loss.backward().

Note

Because the executor reads parameter values at execution time (not at trace time), optimizer updates between forward passes are automatically reflected. There is no need to recompile after optimizer.step().
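The replay procedure can be sketched as follows. Because the parameter map is rebuilt on every call, updated parameter values flow in automatically; the graph and node structures here are illustrative, not Lucid's real executor:

```python
# Sketch of plan replay (illustrative structures, not Lucid's executor).
# The value_map is rebuilt on every call, so the *current* parameter
# values are read each time -- optimizer updates need no recompilation.

def run_plan(graph, inputs, params, constants):
    value_map = {}
    value_map.update(zip(graph["input_ids"], inputs))  # fresh inputs
    value_map.update(params)      # current parameters (value ID -> tensor)
    value_map.update(constants)   # live constants captured at trace time

    for node in graph["nodes"]:   # replay in recorded execution order
        args = [value_map[i] for i in node["inputs"]]
        value_map[node["output"]] = node["kernel"](*args)

    return [value_map[i] for i in graph["output_ids"]]

# A one-node graph: out = input * param
graph = {
    "input_ids": [0],
    "nodes": [{"inputs": [0, 1], "output": 2, "kernel": lambda a, b: a * b}],
    "output_ids": [2],
}
run_plan(graph, [3], {1: 2}, {})  # -> [6]
run_plan(graph, [3], {1: 5}, {})  # "updated" param is picked up -> [15]
```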

Key Features

Automatic Plan Caching

Compiled plans are cached and reused across calls. The cache key captures input tensor shapes, dtypes, devices, gradient state, and training mode, ensuring that different execution contexts (e.g., train vs. eval, different batch sizes) receive appropriately compiled plans.

Training & Inference Modes

The JIT system fully supports training workflows. In training mode, only dead-node elimination is applied, and the executor constructs the backward graph so that loss.backward() propagates gradients back to all parameters. In inference mode, an additional NoGradStripPass disables all gradient tracking for faster execution.

Module Hook Preservation

JITModule respects both forward_pre_hooks and forward_hooks registered on the original module. Pre-hooks run before the compiled execution, and post-hooks run on the executor’s output, matching the behavior of non-compiled modules.
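The hook ordering can be sketched like this, with simplified signatures (the real hook protocol may differ in how return values are treated):

```python
# Sketch of hook ordering around a compiled call (simplified signatures).
# Pre-hooks may rewrite the input; post-hooks may rewrite the output.

def call_with_hooks(module, compiled_forward, x, pre_hooks, post_hooks):
    for hook in pre_hooks:              # runs BEFORE the compiled execution
        result = hook(module, x)
        if result is not None:
            x = result
    out = compiled_forward(x)           # the cached, compiled plan
    for hook in post_hooks:             # runs on the executor's output
        result = hook(module, x, out)
        if result is not None:
            out = result
    return out
```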

Dead Code Elimination

Operations whose outputs are never used by the final output are automatically pruned from the execution plan. This is especially beneficial for models with auxiliary branches that are unused during inference.

Usage Examples

Compiling an nn.Module

The simplest way to compile a model is via the compile() method:

>>> import lucid
>>> import lucid.nn as nn
>>> from lucid.models import lenet_5

>>> model = lenet_5()
>>> model = model.compile()  # Returns a JITModule wrapper

>>> x = lucid.randn(1, 1, 32, 32)
>>> output = model(x)  # First call triggers tracing + compilation

Subsequent calls with the same input shape reuse the cached plan.

Compiling a Standalone Function

You can also compile a plain function using lucid.compile:

>>> import lucid

>>> def my_func(x, y):
...     return lucid.sum(x * y + x)

>>> compiled_func = lucid.compile(my_func)
>>> result = compiled_func(lucid.randn(3, 3), lucid.randn(3, 3))

Full Training Loop

Compiled models integrate seamlessly into standard training loops:

>>> import lucid.nn.functional as F

>>> model = lenet_5()
>>> model.to("gpu")
>>> model = model.compile()
>>> model.train()

>>> optimizer = lucid.optim.Adam(model.parameters(), lr=1e-3)

>>> for x, y in train_loader:
...     x = x.to("gpu")
...     y = y.to("gpu")
...
...     logits = model(x)
...     loss = F.cross_entropy(logits, y)
...
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

The JIT system automatically handles gradient graph construction in training mode and uses the updated parameter values on each forward call.

Inference with no_grad

For maximum inference speed, combine compilation with lucid.no_grad():

>>> model = model.compile()
>>> model.eval()

>>> with lucid.no_grad():
...     predictions = model(test_input)

This applies both DeadNodeElimPass and NoGradStripPass, eliminating all gradient-related overhead.

Cache Invalidation

If your model’s structure changes (e.g., after pruning or adding layers), you can manually clear the compiled plan cache:

>>> model.invalidate_cache()
>>> output = model(x)  # Re-traces and recompiles

Limitations

Warning

Data-dependent control flow is captured only once. The JIT traces a single execution path through the model. If your forward pass contains if / else branches that depend on tensor values (not shapes), only the branch taken during the first call will be compiled. Subsequent calls will always follow that same branch, regardless of input values.

# This pattern is NOT safe for JIT compilation:
def forward(self, x):
    if x.sum() > 0:     # Data-dependent branch
        return self.path_a(x)
    return self.path_b(x)
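A toy trace/replay pair makes the pitfall concrete: the branch chosen on the first call is baked into the recording, and replay ignores later input values. This is purely illustrative Python, not Lucid code:

```python
# Toy demonstration of branch freezing under tracing (NOT Lucid code).
# Tracing records only the branch taken on the first call; replay then
# follows that recording no matter what the new input value is.

def trace(fn, x):
    tape = []

    def record(op):           # stand-in for op registration during tracing
        tape.append(op)
        return op(x)

    fn(x, record)             # first (and only) real execution
    return tape

def replay(tape, x):
    out = x
    for op in tape:
        out = op(out)
    return out

def forward(x, record):
    if x > 0:                             # data-dependent branch
        return record(lambda v: v * 2)    # "path_a"
    return record(lambda v: v - 1)        # "path_b"

tape = trace(forward, 3)   # positive input: only path_a is recorded
replay(tape, -5)           # still runs path_a -> -10, not path_b's -6
```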

Caution

Dynamic input shapes cause recompilation. Each unique combination of input shapes, dtypes, and devices produces a separate cache entry. If your inputs vary in shape every call (e.g., variable-length sequences without padding), the cache will frequently miss and the tracing overhead will negate any speedup.

Caution

First-call overhead. The very first call to a compiled model or function incurs the cost of tracing the full forward pass plus compilation. This one-time cost is amortized over all subsequent calls.

Integration with Lucid

The JIT system is designed to be a drop-in acceleration layer that works transparently with the rest of the Lucid ecosystem:

  • `lucid.nn.Module` — Any module can be compiled via model.compile() or lucid.compile(model). All registered parameters, buffers, and submodule hierarchies are preserved.

  • `lucid.optim` — Optimizers hold references to the original parameter objects. Since the JIT executor reads parameters at execution time, optimizer updates (param.data -= …) are reflected on the next forward call.

  • `lucid.autograd` — In training mode, the executor constructs BackwardOperation objects that integrate with Lucid’s autograd engine. Calling loss.backward() after a compiled forward pass works identically to the non-compiled case.

Conclusion

The Lucid JIT compilation system provides a simple, high-level interface for accelerating model execution with minimal code changes. By calling model.compile(), users gain automatic tracing, graph optimization, plan caching, and full training support—all without sacrificing the flexibility of Lucid’s eager execution model.

Attention

For detailed API documentation of the compile function, including parameter descriptions, return types, and advanced usage, see the lucid.compile reference page.