Neural Caches¶
The lucid.nn.cache module provides generic cache abstractions used by
neural components. It defines common lifecycle and batch-manipulation utilities
for cache state management, and currently includes Transformer KV cache
implementations.
Why neural caches matter¶
Neural models often maintain reusable intermediate state across decoding steps. Without a cache, each step recomputes all historical state from scratch; with a cache, only the new state is appended or updated.
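The cost difference can be sketched with a toy example in plain Python (illustrative only, not lucid code): without a cache, per-step work grows with the length of the history; with a cache, each step does constant work.

```python
# Toy illustration (plain Python, not the lucid API): per-step cost
# without a cache grows with history length, while a cache only
# appends the newest state.

def decode_without_cache(tokens):
    # Each step re-encodes the entire prefix from scratch.
    work = 0
    for step in range(1, len(tokens) + 1):
        work += step  # recompute all `step` positions
    return work

def decode_with_cache(tokens):
    cache = []
    work = 0
    for tok in tokens:
        cache.append(tok)  # only the new state is appended
        work += 1          # constant work per step
    return work

print(decode_without_cache(range(8)))  # 36 = 1 + 2 + ... + 8
print(decode_with_cache(range(8)))     # 8
```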
Common cache workflow¶
1. Create a cache object before decoding.
2. Pass the cache into attention-enabled model forward calls.
3. Update cache with each decoding step.
4. Reorder/select cache during beam search.
5. Optionally crop cache for memory control.
6. Reset cache when sequence generation is finished.
Quick start¶
```python
import lucid
import lucid.nn as nn

cache = nn.DynamicKVCache()

# key/value shape example: (batch, heads, seq_len, head_dim)
key = lucid.randn(2, 8, 1, 64)
value = lucid.randn(2, 8, 1, 64)

# append into layer 0
k_all, v_all = cache.update(key, value, layer_idx=0)
print(k_all.shape, v_all.shape)  # (2, 8, 1, 64), (2, 8, 1, 64)

# next-token append
key2 = lucid.randn(2, 8, 1, 64)
value2 = lucid.randn(2, 8, 1, 64)
k_all, v_all = cache.update(key2, value2, layer_idx=0)
print(cache.get_seq_length(0))  # 2
```
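Conceptually, a dynamic KV cache accumulates keys and values per layer by concatenating along the sequence axis. The sketch below mimics that behavior with plain Python lists standing in for tensors; it is an illustration of the semantics, not lucid's implementation.

```python
# Minimal sketch of dynamic append semantics using nested lists
# in place of tensors; not the actual lucid implementation.

class ToyDynamicKVCache:
    def __init__(self):
        self.keys = {}    # layer_idx -> list of per-token keys
        self.values = {}  # layer_idx -> list of per-token values

    def update(self, key_tokens, value_tokens, layer_idx):
        # Append new tokens along the sequence dimension and
        # return the full accumulated key/value sequences.
        self.keys.setdefault(layer_idx, []).extend(key_tokens)
        self.values.setdefault(layer_idx, []).extend(value_tokens)
        return self.keys[layer_idx], self.values[layer_idx]

    def get_seq_length(self, layer_idx):
        return len(self.keys.get(layer_idx, []))

toy = ToyDynamicKVCache()
toy.update(["k0"], ["v0"], layer_idx=0)
k_all, v_all = toy.update(["k1"], ["v1"], layer_idx=0)
print(k_all)                    # ['k0', 'k1']
print(toy.get_seq_length(0))    # 2
```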
Cache types¶
lucid.nn.Cache: Abstract base interface for all cache types.
lucid.nn.KVCache: Abstract KV-specific interface built on top of lucid.nn.Cache.
lucid.nn.EncoderDecoderCache: Container cache that routes updates/reads to the decoder self-attention cache or the cross-attention cache.
lucid.nn.DynamicKVCache: Grows sequence length dynamically (append/expand on demand).
lucid.nn.StaticKVCache: Uses fixed preallocated storage per layer (max_cache_len).
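The base/derived relationship can be pictured as abstract classes whose method sets mirror the utilities documented on this page. The sketch below is a hypothetical rendering of that hierarchy, not lucid's actual class definitions.

```python
from abc import ABC, abstractmethod

# Illustrative hierarchy only; method names follow the operations
# documented on this page, but the signatures are assumptions and
# do not reproduce lucid's actual source.

class Cache(ABC):
    @abstractmethod
    def reset(self):
        """Clear all cached state."""

class KVCache(Cache):
    @abstractmethod
    def update(self, key, value, layer_idx, cache_position=None):
        """Append or write key/value state for one layer."""

    @abstractmethod
    def get_seq_length(self, layer_idx):
        """Number of cached tokens for one layer."""

# Concrete caches (dynamic, static) would subclass KVCache.
print(issubclass(KVCache, Cache))  # True
```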
Choosing Dynamic vs Static¶
Use DynamicKVCache when:
- sequence length is highly variable,
- you want simpler bring-up and debugging.
Use StaticKVCache when:
- you know an upper bound on decode length,
- stable memory behavior is important.
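The trade-off comes down to allocation strategy: a static cache writes into a buffer sized once at construction, so memory never grows during decoding but length is bounded. A toy plain-Python sketch of that behavior (not the lucid implementation):

```python
# Toy static-cache behavior: a fixed-size buffer plus a write
# cursor. Allocation happens once; writing past max_cache_len
# fails. Not the actual lucid implementation.

class ToyStaticKVCache:
    def __init__(self, max_cache_len):
        self.max_cache_len = max_cache_len
        self.buffer = [None] * max_cache_len  # preallocated once
        self.length = 0                       # tokens written so far

    def append(self, token):
        if self.length >= self.max_cache_len:
            raise IndexError("static cache is full")
        self.buffer[self.length] = token
        self.length += 1

cache = ToyStaticKVCache(max_cache_len=4)
for tok in ("a", "b", "c"):
    cache.append(tok)
print(cache.length)       # 3
print(len(cache.buffer))  # 4 (allocation is fixed)
```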
Supported utility operations¶
All cache classes support:
- update(…) for layer KV update.
- get(layer_idx) and get_seq_length(layer_idx).
- reorder_cache(beam_idx) for beam search.
- batch_select_indices(indices) for general batch reindexing.
- batch_repeat_interleave(repeats) for beam expansion.
- crop(max_length) for keeping recent tokens only.
- reset() for clearing all cached state.
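The batch-manipulation operations all amount to indexing along the batch dimension. The sketch below shows the repeat-interleave and reorder semantics on plain lists standing in for batched tensors; it illustrates the behavior only and is not lucid code.

```python
# Batch-dimension semantics sketched on plain lists; not lucid code.

def batch_repeat_interleave(batch, repeats):
    # ['a', 'b'] with repeats=2 -> ['a', 'a', 'b', 'b'] (beam expansion)
    return [item for item in batch for _ in range(repeats)]

def reorder_cache(batch, beam_idx):
    # Select/permute batch entries by index (beam pruning/reordering)
    return [batch[i] for i in beam_idx]

batch = ["seq0", "seq1"]
expanded = batch_repeat_interleave(batch, 2)
print(expanded)                               # ['seq0', 'seq0', 'seq1', 'seq1']
print(reorder_cache(expanded, [3, 1, 0, 2]))  # ['seq1', 'seq0', 'seq0', 'seq1']
```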
cache_position modes¶
Both dynamic and static caches support:
- None: append mode.
- 0-D tensor: write a single token to a single position.
- 1-D tensor: write tokens to positions (length must match seq_len).
- 2-D tensor: batch-wise position mapping (batch, seq_len).
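The first three modes can be sketched as write rules into a per-layer buffer. In the toy version below, a Python int stands in for a 0-D tensor and a list for a 1-D tensor (the 2-D batch-wise mode is omitted for brevity); this illustrates the semantics and is not lucid's implementation.

```python
# Sketch of cache_position write modes on a simple list buffer;
# ints/lists stand in for 0-D/1-D tensors. Not lucid code, and
# the 2-D batch-wise mode is not shown.

def write(buffer, tokens, cache_position=None):
    if cache_position is None:
        buffer.extend(tokens)              # append mode
    elif isinstance(cache_position, int):  # "0-D": single position
        buffer[cache_position] = tokens[0]
    else:                                  # "1-D": one position per token
        assert len(cache_position) == len(tokens)
        for pos, tok in zip(cache_position, tokens):
            buffer[pos] = tok
    return buffer

buf = []
write(buf, ["t0", "t1"])             # append -> ['t0', 't1']
write(buf, ["x"], cache_position=0)  # overwrite position 0
print(buf)                           # ['x', 't1']
write(buf, ["a", "b"], cache_position=[1, 0])
print(buf)                           # ['b', 'a']
```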
Beam search snippets¶
```python
import lucid
import lucid.nn as nn

cache = nn.DynamicKVCache()
# ... prefill/update cache first ...

num_beams = 4
cache.batch_repeat_interleave(num_beams)  # B -> B*num_beams

# after beam pruning
beam_idx = lucid.Tensor([3, 1, 0, 2], dtype=lucid.Int32)
cache.reorder_cache(beam_idx)
```
Memory control snippet¶
```python
# keep only the most recent 1024 tokens
cache.crop(1024)

# clear everything between requests
cache.reset()
```
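In list terms, cropping is a tail slice of each layer's cached sequence and reset discards all per-layer state. The toy sketch below shows those semantics on a dict of token lists; it is illustrative only, not lucid's implementation.

```python
# Toy crop/reset semantics on a per-layer dict of token lists;
# not the actual lucid implementation.

class ToyCache:
    def __init__(self):
        self.layers = {}  # layer_idx -> list of cached tokens

    def crop(self, max_length):
        # Keep only the most recent max_length tokens per layer.
        for idx, seq in self.layers.items():
            self.layers[idx] = seq[-max_length:]

    def reset(self):
        self.layers.clear()

cache = ToyCache()
cache.layers[0] = list(range(10))
cache.crop(4)
print(cache.layers[0])  # [6, 7, 8, 9]
cache.reset()
print(cache.layers)     # {}
```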