JIT Compilation¶
The Lucid JIT (Just-In-Time) compilation system accelerates model execution by tracing a forward pass, capturing it as an intermediate representation (IR) graph, applying optimization passes, and replaying it through a specialized executor. This approach can yield significant speedups—typically around 4x on CPU and up to 10x on MLX (Apple Silicon GPU)—while preserving full compatibility with Lucid’s autograd engine for training.
Overview¶
Traditional eager execution evaluates each operation immediately as it is encountered in Python. While this offers maximum flexibility, it carries per-operation Python overhead. The JIT compiler eliminates this overhead by recording the computation graph once and replaying it efficiently on subsequent calls.
The compilation pipeline consists of three stages:
Tracing — A real forward pass is executed while a TracingContext records every operation, its inputs, outputs, and metadata into an IR graph.
Optimization — The captured IR graph is refined by a series of passes (e.g., dead-node elimination) that remove redundant computations.
Execution — A ForwardExecutor replays the optimized graph using the current parameter values, optionally constructing a backward graph for training.
Important
Compilation is lazy: calling model.compile() returns a wrapper immediately, but the actual tracing and compilation happen on the first forward call, not at compile time. Subsequent calls reuse the cached compiled plan.
How It Works¶
Tracing¶
When a compiled model or function is called for the first time, Lucid activates a thread-local TracingContext. During this tracing pass, the normal forward computation runs as usual, but every operation decorated with @func_op automatically registers itself in the tracer.
The tracer records:
Input tensors and parameter tensors (registered with unique value IDs)
Each operation as an IRNode with its operator class, initialization arguments, tensor input/output IDs, non-tensor arguments, device placement, and gradient metadata
Output tensors of the forward pass
Note
Because tracing executes the real forward function, no proxy tensors or symbolic shapes are involved. This means every standard Lucid operation is automatically supported without special JIT annotations.
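The mechanism can be illustrated with a self-contained sketch. The names `TracingContext` and `func_op` mirror the description above, but the toy internals here are assumptions for illustration, not Lucid's actual implementation:

```python
import threading

# Thread-local slot holding the active tracer (hypothetical sketch).
_tls = threading.local()

class TracingContext:
    def __init__(self):
        self.nodes = []  # recorded operations, in execution order

    def __enter__(self):
        _tls.ctx = self
        return self

    def __exit__(self, *exc):
        _tls.ctx = None

def func_op(fn):
    """Decorator: run the op eagerly, and record it if tracing is active."""
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)            # real computation still runs
        ctx = getattr(_tls, "ctx", None)
        if ctx is not None:                  # tracing pass: append an IR node
            ctx.nodes.append({"op": fn.__name__, "args": args, "out": out})
        return out
    return wrapper

@func_op
def add(a, b):
    return a + b

with TracingContext() as ctx:
    add(2, 3)  # executes normally *and* is recorded

print([n["op"] for n in ctx.nodes])  # -> ['add']
```

Because the decorator wraps every operation uniformly, the tracer sees the whole forward pass without any per-model annotation.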
IR Graph Construction¶
After the traced forward pass completes, the TracingContext finalizes its state into an IRGraph containing:
input_ids — ordered list of external input value IDs
param_ids — set of parameter value IDs (kept separate so the executor can inject fresh parameter values on each call)
nodes — list of IRNode objects in execution order
values — mapping from value ID to IRValue metadata (shape, dtype, device)
output_ids — value IDs that form the function’s return values
constant_ids — values that are neither inputs, parameters, nor produced by any node (e.g., inline constants created during the forward pass)
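These containers can be pictured as plain dataclasses. The field names follow the list above, but the types here are simplified guesses, not Lucid's real definitions:

```python
from dataclasses import dataclass, field

@dataclass
class IRValue:
    shape: tuple   # tensor shape recorded at trace time
    dtype: str
    device: str

@dataclass
class IRNode:
    op_class: str          # operator class to re-instantiate at replay time
    init_args: dict        # non-tensor initialization arguments
    input_ids: list        # tensor input value IDs
    output_ids: list       # tensor output value IDs
    has_gradient: bool = True

@dataclass
class IRGraph:
    input_ids: list = field(default_factory=list)
    param_ids: set = field(default_factory=set)
    nodes: list = field(default_factory=list)
    values: dict = field(default_factory=dict)
    output_ids: list = field(default_factory=list)
    constant_ids: set = field(default_factory=set)

graph = IRGraph(input_ids=[0], output_ids=[2])
graph.nodes.append(IRNode("MatMulOp", {}, input_ids=[0, 1], output_ids=[2]))
print(graph.nodes[0].has_gradient)  # -> True
```

Keeping `param_ids` separate from `input_ids` is what lets the executor substitute fresh parameter values on every call without re-tracing.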
Optimization Passes¶
The compiled IR graph is refined by a configurable pass pipeline before execution.
| Pass | Applied In | Description |
|---|---|---|
| DeadNodeElimPass | Training & Inference | Removes operations whose outputs are not consumed by any downstream node or the final output. Uses a reverse liveness analysis starting from output IDs. |
| NoGradStripPass | Inference only | Sets has_gradient = False on all nodes, preventing the executor from constructing backward operations. This avoids unnecessary memory and compute overhead during inference. |
Hint
The pass pipeline is selected automatically based on grad_enabled and module.training at the time of the first forward call. No manual configuration is needed.
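The reverse liveness analysis behind DeadNodeElimPass can be sketched in a few lines. Toy node dicts stand in for `IRNode` objects here; this is an illustration, not the Lucid source:

```python
def dead_node_elim(nodes, output_ids):
    """Keep only nodes whose outputs are (transitively) reachable from output_ids."""
    live = set(output_ids)
    kept = []
    # Walk the execution order backwards: a node survives only if one of its
    # outputs is live; its inputs then become live for earlier nodes.
    for node in reversed(nodes):
        if live & set(node["outputs"]):
            kept.append(node)
            live |= set(node["inputs"])
    kept.reverse()
    return kept

nodes = [
    {"op": "mul", "inputs": [0, 1], "outputs": [2]},
    {"op": "add", "inputs": [0, 1], "outputs": [3]},  # output 3 is never used
    {"op": "sum", "inputs": [2],    "outputs": [4]},
]
pruned = dead_node_elim(nodes, output_ids=[4])
print([n["op"] for n in pruned])  # -> ['mul', 'sum']
```

The unused `add` node is pruned because nothing downstream of the graph's outputs ever consumes value 3.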
Plan Caching¶
Each compiled plan is stored in a PlanCache keyed by a CacheKey that encodes:
Input tensor shapes, dtypes, and devices
Gradient state (lucid.grad_enabled())
Training mode (module.training for modules, always False for standalone functions)
Tip
If your model always receives the same input shape (e.g., fixed batch size), the plan is compiled once and reused for every subsequent call. Varying input shapes will create additional cache entries, up to a configurable maximum (default: 8) with FIFO eviction.
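The caching behavior can be sketched with a simplified stand-in for `PlanCache`. The real `CacheKey` also folds in dtypes, devices, gradient state, and training mode; the key here is just a shape tuple for brevity:

```python
from collections import OrderedDict

class PlanCacheSketch:
    def __init__(self, max_entries=8):
        self.max_entries = max_entries
        self._plans = OrderedDict()   # insertion order doubles as FIFO order

    def get_or_compile(self, key, compile_fn):
        if key not in self._plans:
            if len(self._plans) >= self.max_entries:
                self._plans.popitem(last=False)  # evict the oldest entry (FIFO)
            self._plans[key] = compile_fn()      # trace + compile exactly once
        return self._plans[key]

cache = PlanCacheSketch(max_entries=2)
cache.get_or_compile((1, 3, 32, 32), lambda: "plan-a")
cache.get_or_compile((2, 3, 32, 32), lambda: "plan-b")
cache.get_or_compile((4, 3, 32, 32), lambda: "plan-c")  # evicts plan-a
print(len(cache._plans))  # -> 2
```

A repeated shape hits the cache and skips compilation entirely, which is why fixed-shape workloads see the largest benefit.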
Execution¶
The ForwardExecutor replays the cached plan on each forward call:
Builds a value_map by mapping input IDs to the actual input tensors, parameter IDs to the current parameter objects (reflecting any optimizer updates), and constant IDs to their stored live tensors.
Iterates through the execution order, instantiating each operator from its saved class and initialization arguments, then calling the raw forward kernel (bypassing the Python-level @func_op wrapper for maximum speed).
In training mode, attaches BackwardOperation objects to each output tensor, connecting them to their input tensors via weak references. This preserves the full autograd graph for loss.backward().
Note
Because the executor reads parameter values at execution time (not at trace time), optimizer updates between forward passes are automatically reflected. There is no need to recompile after optimizer.step().
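A toy replay loop shows why this works. The graph layout and kernel registry below are assumptions for illustration, not the `ForwardExecutor` internals:

```python
# Minimal kernel registry standing in for the raw forward kernels.
KERNELS = {"mul": lambda a, b: a * b, "add": lambda a, b: a + b}

def replay(graph, inputs, params):
    value_map = dict(params)  # parameters are read *now*, at call time
    value_map.update(zip(graph["input_ids"], inputs))
    for node in graph["nodes"]:
        args = [value_map[i] for i in node["inputs"]]
        value_map[node["output"]] = KERNELS[node["op"]](*args)  # raw kernel call
    return [value_map[o] for o in graph["output_ids"]]

graph = {
    "input_ids": ["x"],
    "nodes": [
        {"op": "mul", "inputs": ["x", "w"], "output": "h"},
        {"op": "add", "inputs": ["h", "b"], "output": "y"},
    ],
    "output_ids": ["y"],
}
params = {"w": 2.0, "b": 1.0}
print(replay(graph, [3.0], params))  # -> [7.0]
params["b"] = 5.0                    # simulate an optimizer step
print(replay(graph, [3.0], params))  # -> [11.0]
```

Because `value_map` is rebuilt from the live parameter objects on every call, the second replay picks up the updated `b` with no recompilation.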
Key Features¶
Automatic Plan Caching
Compiled plans are cached and reused across calls. The cache key captures input tensor shapes, dtypes, devices, gradient state, and training mode, ensuring that different execution contexts (e.g., train vs. eval, different batch sizes) receive appropriately compiled plans.
Training & Inference Modes
The JIT system fully supports training workflows. In training mode, only dead-node elimination is applied, and the executor constructs the backward graph so that loss.backward() propagates gradients back to all parameters. In inference mode, an additional NoGradStripPass disables all gradient tracking for faster execution.
Module Hook Preservation
JITModule respects both forward_pre_hooks and forward_hooks registered on the original module. Pre-hooks run before the compiled execution, and post-hooks run on the executor’s output, matching the behavior of non-compiled modules.
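The hook ordering can be sketched with a toy wrapper. The hook signatures below follow the common pre-hook/post-hook convention and are an assumption; this is not the actual `JITModule` code:

```python
class JITModuleSketch:
    def __init__(self, executor):
        self.executor = executor          # stands in for the compiled plan
        self.forward_pre_hooks = []
        self.forward_hooks = []

    def __call__(self, *args):
        for hook in self.forward_pre_hooks:
            result = hook(self, args)
            if result is not None:
                args = result             # pre-hook may replace the inputs
        out = self.executor(*args)        # compiled forward execution
        for hook in self.forward_hooks:
            result = hook(self, args, out)
            if result is not None:
                out = result              # post-hook may replace the output
        return out

m = JITModuleSketch(lambda x: x * 2)
m.forward_pre_hooks.append(lambda mod, a: (a[0] + 1,))   # shift input by 1
m.forward_hooks.append(lambda mod, a, o: o + 100)        # shift output by 100
print(m(3))  # -> 108, i.e. (3 + 1) * 2 + 100
```

Since the hooks run outside the compiled region, they behave identically whether or not the module is compiled.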
Dead Code Elimination
Operations whose outputs are never used by the final output are automatically pruned from the execution plan. This is especially beneficial for models with auxiliary branches that are unused during inference.
Usage Examples¶
Compiling an nn.Module
The simplest way to compile a model is via the compile() method:
>>> import lucid
>>> import lucid.nn as nn
>>> from lucid.models import lenet_5
>>> model = lenet_5()
>>> model = model.compile() # Returns a JITModule wrapper
>>> x = lucid.randn(1, 1, 32, 32)
>>> output = model(x) # First call triggers tracing + compilation
Subsequent calls with the same input shape reuse the cached plan.
Compiling a Standalone Function
You can also compile a plain function using lucid.compile:
>>> import lucid
>>> def my_func(x, y):
...     return lucid.sum(x * y + x)
>>> compiled_func = lucid.compile(my_func)
>>> result = compiled_func(lucid.randn(3, 3), lucid.randn(3, 3))
Full Training Loop
Compiled models integrate seamlessly into standard training loops:
>>> import lucid.nn.functional as F
>>> model = lenet_5()
>>> model.to("gpu")
>>> model = model.compile()
>>> model.train()
>>> optimizer = lucid.optim.Adam(model.parameters(), lr=1e-3)
>>> for x, y in train_loader:
...     x = x.to("gpu")
...     y = y.to("gpu")
...
...     logits = model(x)
...     loss = F.cross_entropy(logits, y)
...
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
The JIT system automatically handles gradient graph construction in training mode and uses the updated parameter values on each forward call.
Inference with no_grad
For maximum inference speed, combine compilation with lucid.no_grad():
>>> model = model.compile()
>>> model.eval()
>>> with lucid.no_grad():
... predictions = model(test_input)
This applies both DeadNodeElimPass and NoGradStripPass, eliminating all gradient-related overhead.
Cache Invalidation
If your model’s structure changes (e.g., after pruning or adding layers), you can manually clear the compiled plan cache:
>>> model.invalidate_cache()
>>> output = model(x) # Re-traces and recompiles
Limitations¶
Warning
Data-dependent control flow is captured only once. The JIT traces a single execution path through the model. If your forward pass contains if / else branches that depend on tensor values (not shapes), only the branch taken during the first call will be compiled. Subsequent calls will always follow that same branch, regardless of input values.
# This pattern is NOT safe for JIT compilation:
def forward(self, x):
    if x.sum() > 0:  # Data-dependent branch
        return self.path_a(x)
    return self.path_b(x)
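A toy tracer makes the failure mode concrete: only the operations that actually execute during the tracing pass enter the plan, so the untaken branch never exists in the compiled graph. This is an illustration, not Lucid's tracer:

```python
def trace(fn, *example_inputs):
    """Run fn once, recording every op it actually executes."""
    ops = []
    def record(name, value):
        ops.append((name, value))
        return value
    fn(record, *example_inputs)
    return ops  # the captured "plan" contains exactly one path

def forward(record, x):
    if x > 0:                                # data-dependent branch
        return record("path_a", x * 2)
    return record("path_b", x - 1)

plan = trace(forward, 5)   # first call: x > 0, so only path_a is captured
print(plan)                # -> [('path_a', 10)]
# Replaying this plan would run path_a even for inputs where x <= 0.
```

If the branch decision must stay dynamic, one option is to keep the `if` in eager Python and compile each branch's submodule separately, so the data-dependent dispatch never passes through the tracer.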
Caution
Dynamic input shapes cause recompilation. Each unique combination of input shapes, dtypes, and devices produces a separate cache entry. If your inputs vary in shape every call (e.g., variable-length sequences without padding), the cache will frequently miss and the tracing overhead will negate any speedup.
Caution
First-call overhead. The very first call to a compiled model or function incurs the cost of tracing the full forward pass plus compilation. This one-time cost is amortized over all subsequent calls.
Integration with Lucid¶
The JIT system is designed to be a drop-in acceleration layer that works transparently with the rest of the Lucid ecosystem:
`lucid.nn.Module` — Any module can be compiled via model.compile() or lucid.compile(model). All registered parameters, buffers, and submodule hierarchies are preserved.
`lucid.optim` — Optimizers hold references to the original parameter objects. Since the JIT executor reads parameters at execution time, optimizer updates (param.data -= …) are reflected on the next forward call.
`lucid.autograd` — In training mode, the executor constructs BackwardOperation objects that integrate with Lucid’s autograd engine. Calling loss.backward() after a compiled forward pass works identically to the non-compiled case.
Conclusion¶
The Lucid JIT compilation system provides a simple, high-level interface for accelerating model execution with minimal code changes. By calling model.compile(), users gain automatic tracing, graph optimization, plan caching, and full training support—all without sacrificing the flexibility of Lucid’s eager execution model.
Attention
For detailed API documentation of the compile function, including parameter descriptions, return types, and advanced usage, see the lucid.compile reference page.