lucid.compile¶
- lucid.compile(target: Callable, *, max_cache_entries: int = 8) → JITFunction¶
- lucid.compile(target: nn.Module, *, max_cache_entries: int = 8) → JITModule
The lucid.compile function is the primary entry point for JIT compilation in Lucid. It accepts either an nn.Module instance or a plain callable and returns a compiled wrapper that traces the computation graph on the first call, caches the optimized plan, and replays it on subsequent calls for accelerated execution.
Function Signature¶
def compile(
    target: nn.Module | Callable,
    *,
    max_cache_entries: int = 8,
) -> JITModule | JITFunction
Parameters¶
target (nn.Module | Callable): The compilation target. If an nn.Module is provided, a JITModule wrapper is returned that respects training mode, hooks, and parameter management. If a plain callable (function or lambda) is provided, a JITFunction wrapper is returned.
max_cache_entries (int, default=8): Maximum number of compiled plans to keep in the internal cache. Each unique combination of input shapes, dtypes, devices, gradient state, and training mode creates a separate cache entry. When the cache is full, the oldest entry is evicted (FIFO policy).
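The FIFO eviction described above can be sketched with a plain OrderedDict. The PlanCache class and its method names below are illustrative stand-ins, not Lucid's actual internals:

```python
from collections import OrderedDict

class PlanCache:
    """Illustrative FIFO cache of compiled plans (not Lucid's implementation)."""

    def __init__(self, max_entries=8):
        self.max_entries = max_entries
        self._plans = OrderedDict()  # cache key -> compiled plan

    def get_or_compile(self, key, compile_fn):
        if key in self._plans:
            return self._plans[key]          # cache hit: replay
        if len(self._plans) >= self.max_entries:
            self._plans.popitem(last=False)  # cache full: evict oldest (FIFO)
        plan = compile_fn()                  # cache miss: trace + compile
        self._plans[key] = plan
        return plan

cache = PlanCache(max_entries=2)
cache.get_or_compile(("f32", (32, 128)), lambda: "plan-A")
cache.get_or_compile(("f32", (64, 128)), lambda: "plan-B")
cache.get_or_compile(("f32", (16, 128)), lambda: "plan-C")  # evicts plan-A
```

With `max_entries=2`, the third distinct key evicts the first one inserted, which is why workloads with many input shapes benefit from a larger cache.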
Returns¶
JITModule — when target is an nn.Module. Wraps the module and compiles its forward() method. Delegates attribute access (e.g., model.training, model.parameters()) to the underlying module transparently.
JITFunction — when target is a callable. Wraps the function for compiled execution without any module-specific behavior (no training mode, no parameter tracking, no hooks).
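The transparent attribute delegation mentioned for JITModule is the standard `__getattr__` fall-through pattern. The classes below are hypothetical stand-ins used only to illustrate the mechanism:

```python
class ModuleStub:
    """Stand-in for an nn.Module (illustrative only)."""
    training = True

    def parameters(self):
        return ["w", "b"]

class JITModuleSketch:
    """Illustrative wrapper: unknown attributes fall through to the module."""

    def __init__(self, module):
        self._module = module

    def __getattr__(self, name):
        # Invoked only when normal lookup on the wrapper fails,
        # so wrapper-specific attributes still take precedence.
        return getattr(self._module, name)

wrapped = JITModuleSketch(ModuleStub())
wrapped.training       # read from the underlying module
wrapped.parameters()   # delegated method call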
Raises¶
TypeError — if target is neither an nn.Module nor a callable.
>>> lucid.compile(42)
TypeError: lucid.compile() expects an nn.Module or a callable, got <class 'int'>
JITModule vs JITFunction¶
The return type of lucid.compile depends on what is passed in. The two wrappers differ in several important ways:
| Feature | JITModule | JITFunction |
|---|---|---|
| Target | nn.Module instance | Any callable |
| Training mode | Respects module.training; separate plans for train/eval | Always training_mode=False |
| Parameters | Registers and tracks all module.parameters(); rebuilds param_map on every call so optimizer updates are reflected | No parameter awareness |
| Hooks | Runs forward_pre_hooks before execution and forward_hooks after execution | No hook support |
| Attribute access | __getattr__ delegates to the wrapped module (e.g., model.training, model.eval()) | Standard function wrapper (functools.update_wrapper) |
| Cache key | Includes input shapes + grad_enabled + module.training | Includes input shapes + grad_enabled only |
Note
For most use cases, JITModule is what you want. JITFunction is useful for compiling standalone utility functions that do not involve learnable parameters or training state.
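The cache-key difference in the table can be made concrete with tuple keys. The field layout below is an assumption for illustration, not Lucid's actual key format:

```python
# Illustrative cache keys mirroring the table above (hypothetical layout).

def function_cache_key(shapes, grad_enabled):
    # JITFunction: keyed on input shapes + gradient state only
    return (tuple(shapes), grad_enabled)

def module_cache_key(shapes, grad_enabled, training):
    # JITModule: additionally keyed on module.training, so train
    # and eval mode map to separate compiled plans
    return (tuple(shapes), grad_enabled, training)

train_key = module_cache_key([(32, 784)], True, True)
eval_key = module_cache_key([(32, 784)], False, False)
```

Because `training` participates in the key, calling `model.eval()` never reuses a plan that was traced with backward-graph construction enabled.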
Examples¶
Compiling a Module
There are two equivalent ways to compile an nn.Module:
>>> import lucid
>>> import lucid.nn as nn
>>> model = nn.Sequential(
... nn.Linear(784, 256),
... nn.ReLU(),
... nn.Linear(256, 10),
... )
# Method 1: Using the Module's convenience method
>>> compiled_model = model.compile()
# Method 2: Using lucid.compile directly
>>> compiled_model = lucid.compile(model)
Both forms produce an equivalent JITModule wrapping the original model.
Compiling a Standalone Function
>>> import lucid
>>> def normalize(x):
... mean = lucid.mean(x, axis=-1, keepdims=True)
... std = lucid.std(x, axis=-1, keepdims=True)
... return (x - mean) / (std + 1e-5)
>>> fast_normalize = lucid.compile(normalize)
>>> x = lucid.randn(32, 128)
>>> result = fast_normalize(x) # First call: trace + compile
>>> result = fast_normalize(x) # Cached: fast execution
Custom Cache Size
If your model receives inputs of many different shapes, increase the cache size to avoid repeated recompilation:
>>> model = model.compile(max_cache_entries=32)
Conversely, if memory is constrained and input shapes are fixed, you can reduce it:
>>> model = model.compile(max_cache_entries=2)
Training with a Compiled Model
A compiled model integrates seamlessly into the standard training loop. The key requirement is that model.train() is called before the first forward pass so the compiled plan includes backward graph construction:
>>> import lucid
>>> import lucid.nn as nn
>>> import lucid.nn.functional as F
>>> import lucid.optim as optim
>>> model = MyModel()
>>> model.to("gpu")
>>> model = model.compile()
>>> model.train()
>>> optimizer = optim.Adam(model.parameters(), lr=1e-3)
>>> for epoch in range(num_epochs):
... model.train()
... for x, y in train_loader:
... logits = model(x)
... loss = F.cross_entropy(logits, y)
...
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
...
... # Evaluation uses a separate cached plan
... model.eval()
... with lucid.no_grad():
... test_logits = model(test_x)
Important
The optimizer references the same parameter objects as the compiled model. When optimizer.step() updates param.data, the next compiled forward pass automatically uses the updated values—no recompilation is needed.
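The "same parameter objects" guarantee is ordinary Python object sharing. The sketch below uses a stand-in Param class to show why an in-place `.data` update is visible to the compiled model without recompilation:

```python
class Param:
    """Stand-in for a learnable parameter (illustrative only)."""
    def __init__(self, data):
        self.data = data

weight = Param([1.0, 2.0])

model_params = [weight]      # what the compiled forward pass reads
optimizer_params = [weight]  # what the optimizer writes

# Simulated optimizer.step(): mutate .data in place on the shared object
optimizer_params[0].data = [0.9, 1.9]
```

Both lists hold the *same* object, so the next forward pass reads the updated values with no cache invalidation involved.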
Switching Between Train and Eval
Calling model.train() and model.eval() on a JITModule delegates to the underlying module. Because training_mode is part of the cache key, separate compiled plans are maintained for each mode:
>>> model = model.compile()
>>> model.train()
>>> logits = model(x) # Compiled with backward graph
>>> model.eval()
>>> with lucid.no_grad():
... logits = model(x) # Compiled without backward graph (faster)
Each mode has its own cached plan, so switching is cost-free after the initial compilation of each mode.
Cache Invalidation
If the model’s computation graph has changed (e.g., layers were added, removed, or reconfigured after compilation), the cached plans become stale. Invalidate the cache to force recompilation:
>>> model.invalidate_cache()
>>> output = model(x) # Re-traces and recompiles
This clears all cached plans for this JITModule or JITFunction.
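Conceptually, invalidation just empties the plan cache so the next call re-traces. The wrapper below is an illustrative sketch with hypothetical names, not Lucid's JITModule:

```python
class JITWrapperSketch:
    """Illustrative wrapper counting traces to show invalidation."""

    def __init__(self, fn):
        self._fn = fn
        self._plans = {}
        self.trace_count = 0

    def __call__(self, shape):
        if shape not in self._plans:
            self.trace_count += 1       # cold cache: re-trace
            self._plans[shape] = self._fn
        return self._plans[shape](shape)

    def invalidate_cache(self):
        self._plans.clear()             # drop every cached plan

wrapped = JITWrapperSketch(lambda s: s)
wrapped((32, 128))          # first call: traces
wrapped((32, 128))          # second call: replays from cache
wrapped.invalidate_cache()
wrapped((32, 128))          # traces again after invalidation
```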
Notes¶
Note
Module.compile(**kwargs) is a convenience shorthand defined on nn.Module that simply calls lucid.compile(self, **kwargs). Both forms are fully equivalent.
Tip
Compilation is lazy—the tracing and IR construction happen on the first forward call, not when compile() is invoked. This means model.compile() itself is essentially free. The one-time tracing cost is incurred when the model first processes real data.
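That laziness can be observed with a toy wrapper that logs events. `lazy_compile` and the event names are hypothetical, used only to show the ordering the tip describes:

```python
events = []

def lazy_compile(fn):
    """Illustrative lazy wrapper: wrapping is cheap, tracing is deferred."""
    state = {"plan": None}

    def wrapper(*args):
        if state["plan"] is None:
            events.append("trace")   # one-time cost, paid on the first call
            state["plan"] = fn
        events.append("run")
        return state["plan"](*args)

    events.append("wrap")            # compile() itself: essentially free
    return wrapper

f = lazy_compile(lambda x: x * 2)
# nothing has been traced yet
f(3)   # first call: trace, then run
f(4)   # subsequent calls: run only
```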
Warning
The JIT compiler traces a single execution path through the model. Data-dependent control flow (e.g., if x.sum() > 0) is captured as a static branch. If your model contains such patterns, the traced path may not match all future inputs. Avoid data-dependent branches in compiled models, or restructure them to be shape-dependent instead.
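The hazard is easiest to see with a toy tracer. `trace` below is a deliberately naive stand-in, not Lucid's tracer, but it exhibits the same failure mode: the branch taken by the example input is frozen into the plan:

```python
def f(x):
    if x > 0:                    # data-dependent branch
        return "positive path"
    return "negative path"

def trace(fn, example_input):
    """Illustrative tracer: records only the path the example exercises."""
    recorded = fn(example_input)
    return lambda x: recorded    # replay ignores the actual data

compiled = trace(f, 1.0)         # traced with x > 0, so the positive
compiled(-5.0)                   # branch is baked in: wrong for x <= 0
```

The eager function returns "negative path" for -5.0, but the traced plan keeps replaying the branch captured at trace time.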
Note
Separate plans are cached for different contexts:
- Different input shapes (e.g., batch size 32 vs. 64)
- Train vs. eval mode (module.training is part of the cache key)
- Gradient enabled vs. disabled (lucid.grad_enabled() is part of the cache key)
This ensures that the executor always uses an appropriately optimized plan.