nn.StaticKVCache¶
- class lucid.nn.StaticKVCache(max_cache_len: int, num_layers: int)¶
Overview¶
StaticKVCache preallocates fixed-size key/value storage for each layer. This is useful when the maximum decode length is known in advance and predictable memory behavior is desired.
Class Signature¶
class lucid.nn.StaticKVCache(
max_cache_len: int,
num_layers: int,
)
Parameters¶
max_cache_len (int): Maximum number of tokens each layer can store along the sequence axis.
num_layers (int): Number of layer slots maintained by this cache instance.
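Because storage is preallocated, the cache footprint can be estimated up front. The calculation below is an illustrative back-of-envelope sketch, not part of the lucid API; the (batch, num_heads, seq_len, head_dim) shape convention and float32 element size are assumptions taken from the examples on this page.

```python
# Hypothetical memory estimate for a preallocated static KV cache.
# Assumed layout (from the examples below): keys and values are stored
# separately per layer with shape (batch, num_heads, max_cache_len, head_dim).
def static_kv_cache_bytes(
    num_layers: int,
    max_cache_len: int,
    batch: int,
    num_heads: int,
    head_dim: int,
    bytes_per_elem: int = 4,  # float32
) -> int:
    per_layer = 2 * batch * num_heads * max_cache_len * head_dim  # K and V
    return num_layers * per_layer * bytes_per_elem

# e.g. 24 layers, max_cache_len=2048, batch 2, 8 heads, head_dim 64:
print(static_kv_cache_bytes(24, 2048, 2, 8, 64) / 2**20, "MiB")  # 384.0 MiB
```

The 2x factor accounts for keys and values being cached separately; halve the element size for float16 storage.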
Constructor¶
import lucid.nn as nn
cache = nn.StaticKVCache(
max_cache_len=2048, # per-layer max sequence length
num_layers=24, # number of Transformer layers
)
Basic append example¶
import lucid
import lucid.nn as nn
cache = nn.StaticKVCache(max_cache_len=16, num_layers=2)
k = lucid.randn(2, 8, 4, 64)
v = lucid.randn(2, 8, 4, 64)
cache.update(k, v, layer_idx=0)
print(cache.get_seq_length(0)) # 4
print(cache.get_max_cache_shape()) # 16
Position update example (0-D)¶
import lucid
import lucid.nn as nn
cache = nn.StaticKVCache(max_cache_len=16, num_layers=1)
k = lucid.randn(1, 8, 1, 64)
v = lucid.randn(1, 8, 1, 64)
pos = lucid.Tensor(10, dtype=lucid.Int32)
cache.update(k, v, layer_idx=0, cache_position=pos)
print(cache.get_seq_length(0)) # 11
Position update example (1-D)¶
import lucid
import lucid.nn as nn
cache = nn.StaticKVCache(max_cache_len=16, num_layers=1)
k = lucid.randn(1, 8, 3, 64)
v = lucid.randn(1, 8, 3, 64)
pos = lucid.Tensor([0, 4, 7], dtype=lucid.Int32)
cache.update(k, v, layer_idx=0, cache_position=pos)
print(cache.get_seq_length(0)) # 8
Batch-wise position update (2-D)¶
import lucid
import lucid.nn as nn
cache = nn.StaticKVCache(max_cache_len=32, num_layers=1)
k = lucid.randn(2, 8, 2, 64)
v = lucid.randn(2, 8, 2, 64)
pos = lucid.Tensor(
[[0, 5],
[1, 7]],
dtype=lucid.Int32,
)
cache.update(k, v, layer_idx=0, cache_position=pos)
print(cache.get_seq_length(0)) # 8
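The three position forms above share one bookkeeping rule: after an update, the valid sequence length is one past the largest written position (or the running length, for plain appends, whichever is greater). A minimal pure-Python model of that rule follows; it illustrates the behavior shown in the examples, not lucid's actual implementation.

```python
def seq_length_after_update(positions, num_new_tokens, current_len=0):
    """Model of how a position-indexed write advances the valid length.

    positions: None (plain append), an int (0-D scalar position),
    a flat list (1-D), or a nested list per batch row (2-D).
    """
    if positions is None:                       # plain append
        return current_len + num_new_tokens
    if isinstance(positions, int):              # 0-D scalar
        return max(current_len, positions + 1)
    flat = []
    for p in positions:                         # flatten 1-D or 2-D
        flat.extend(p if isinstance(p, list) else [p])
    return max(current_len, max(flat) + 1)

print(seq_length_after_update(None, 4))              # 4
print(seq_length_after_update(10, 1))                # 11
print(seq_length_after_update([0, 4, 7], 3))         # 8
print(seq_length_after_update([[0, 5], [1, 7]], 2))  # 8
```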
Out-of-bounds behavior¶
StaticKVCache raises a ValueError when a write would exceed max_cache_len.
import lucid
import lucid.nn as nn
cache = nn.StaticKVCache(max_cache_len=4, num_layers=1)
k = lucid.randn(1, 8, 5, 64)
v = lucid.randn(1, 8, 5, 64)
try:
    cache.update(k, v, layer_idx=0)  # 5 new tokens > max_cache_len=4
except ValueError:
    print("exceeded max_cache_len")
Beam search utilities¶
import lucid
import lucid.nn as nn
cache = nn.StaticKVCache(max_cache_len=64, num_layers=4)
cache.update(lucid.randn(2, 8, 8, 64), lucid.randn(2, 8, 8, 64), layer_idx=0)
cache.batch_repeat_interleave(3) # B=2 -> B=6
alive = lucid.Tensor([4, 1, 5], dtype=lucid.Int32)
cache.reorder_cache(alive) # B=3
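Both beam-search helpers operate only on the batch dimension. The pure-Python sketch below models their index semantics; it is illustrative, whereas the real cache reindexes preallocated per-layer tensors.

```python
def batch_repeat_interleave(rows, repeats):
    # Each batch entry is duplicated `repeats` times, interleaved:
    # [a, b] with repeats=3 -> [a, a, a, b, b, b].
    return [row for row in rows for _ in range(repeats)]

def reorder_cache(rows, beam_idx):
    # Keep only the rows selected by beam_idx, in that order.
    return [rows[i] for i in beam_idx]

rows = ["req0", "req1"]
rows = batch_repeat_interleave(rows, 3)  # B=2 -> B=6
rows = reorder_cache(rows, [4, 1, 5])    # B=3
print(rows)  # ['req1', 'req0', 'req1']
```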
Cropping example¶
For a static cache, crop() keeps the most recent valid tokens and updates the stored sequence length accordingly.
import lucid
import lucid.nn as nn
cache = nn.StaticKVCache(max_cache_len=32, num_layers=1)
cache.update(lucid.randn(1, 8, 20, 64), lucid.randn(1, 8, 20, 64), layer_idx=0)
cache.crop(8)
print(cache.get_seq_length(0)) # 8
Practical guidance¶
- Choose max_cache_len based on your longest expected decode context.
- Use reset() between independent requests.
- Use crop() when implementing sliding-window style decoding.
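The last point can be combined with update() into a sliding-window decode loop. The sketch below models only the length bookkeeping in pure Python; the update/crop calls it mirrors are the API documented above, but the loop structure itself is an assumption about usage rather than a prescribed lucid recipe.

```python
WINDOW = 8  # tokens to keep; would be chosen <= max_cache_len

def decode_step(seq_len, new_tokens, window=WINDOW):
    # Mirrors: cache.update(...) appends, then cache.crop(window)
    # trims back to the most recent `window` tokens when over budget.
    seq_len += new_tokens
    if seq_len > window:
        seq_len = window
    return seq_len

seq_len = 0
for _ in range(12):                 # 12 single-token decode steps
    seq_len = decode_step(seq_len, 1)
print(seq_len)  # 8
```

Cropping after each step bounds the valid region of the cache, so the same preallocated buffers can serve an arbitrarily long stream.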