nn.EncoderDecoderCache¶
- class lucid.nn.EncoderDecoderCache(self_attention_cache: KVCache | None = None, cross_attention_cache: KVCache | None = None)¶
Overview¶
EncoderDecoderCache is a container cache for encoder-decoder style models. It bundles two KV cache instances:
- self_attention_cache for decoder self-attention
- cross_attention_cache for decoder cross-attention
The class routes cache reads and writes to one of the two internal caches based on the is_cross_attention flag.
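The routing idea can be sketched in plain Python. This is a schematic stand-in, not Lucid's actual implementation: a container holds two per-layer caches and dispatches every read and write on is_cross_attention.

```python
class ToyEncoderDecoderCache:
    """Schematic sketch of the routing logic (lists stand in for tensors)."""

    def __init__(self):
        # each inner cache maps layer_idx -> (keys, values)
        self.self_attention_cache = {}
        self.cross_attention_cache = {}

    def _select(self, is_cross_attention):
        # single dispatch point: every read/write goes through here
        return (self.cross_attention_cache if is_cross_attention
                else self.self_attention_cache)

    def update(self, key, value, layer_idx, is_cross_attention=False):
        cache = self._select(is_cross_attention)
        k, v = cache.setdefault(layer_idx, ([], []))
        k.extend(key)
        v.extend(value)
        return k, v

    def get(self, layer_idx, is_cross_attention=False):
        return self._select(is_cross_attention).get(layer_idx)


cache = ToyEncoderDecoderCache()
cache.update(["k0"], ["v0"], layer_idx=0)                      # self-attention
cache.update(["ck"], ["cv"], layer_idx=0, is_cross_attention=True)
print(cache.get(0))                           # (['k0'], ['v0'])
print(cache.get(0, is_cross_attention=True))  # (['ck'], ['cv'])
```

The real class applies the same dispatch to tensor-valued KV caches; the toy only illustrates that one flag selects which inner cache is touched.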
Class Signature¶
class lucid.nn.EncoderDecoderCache(
self_attention_cache: nn.KVCache | None = None,
cross_attention_cache: nn.KVCache | None = None,
)
If either cache is not provided, a DynamicKVCache is created for it by default.
Key methods¶
- EncoderDecoderCache.update(key: Tensor, value: Tensor, layer_idx: int, cache_position: Tensor | None = None, is_cross_attention: bool = False) tuple[Tensor, Tensor]¶
- EncoderDecoderCache.get(layer_idx: int, is_cross_attention: bool = False) tuple[Tensor, Tensor] | None¶
- EncoderDecoderCache.get_seq_length(layer_idx: int = 0, is_cross_attention: bool = False) int¶
- EncoderDecoderCache.reset() None¶
- EncoderDecoderCache.batch_repeat_interleave(repeats: int) None¶
- EncoderDecoderCache.crop(max_length: int) None¶
is_updated¶
EncoderDecoderCache exposes is_updated: dict[int, bool] to track, per layer, whether the cross-attention cache has already been updated.
Minimal example¶
import lucid
import lucid.nn as nn
cache = nn.EncoderDecoderCache(
self_attention_cache=nn.DynamicKVCache(),
cross_attention_cache=nn.DynamicKVCache(),
)
# decoder self-attention cache update
k_self = lucid.randn(2, 8, 1, 64)
v_self = lucid.randn(2, 8, 1, 64)
cache.update(k_self, v_self, layer_idx=0, is_cross_attention=False)
# decoder cross-attention cache update
k_cross = lucid.randn(2, 8, 10, 64)
v_cross = lucid.randn(2, 8, 10, 64)
cache.update(k_cross, v_cross, layer_idx=0, is_cross_attention=True)
print(cache.get_seq_length(0, is_cross_attention=False)) # 1
print(cache.get_seq_length(0, is_cross_attention=True)) # 10
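Over a full decode loop, the two sequence lengths evolve differently: the self-attention cache grows by one position per generated token, while the cross-attention cache is written once from the encoder states and then stays fixed. A plain-Python sketch of that pattern (lists stand in for the (batch, heads, seq, dim) tensors above):

```python
# Sketch of a typical decode loop over the two caches.
self_kv, cross_kv = [], []
encoder_states = ["e%d" % i for i in range(10)]   # 10 encoder positions

for step in range(4):
    if step == 0:
        cross_kv.extend(encoder_states)  # encoder K/V cached once, then reused
    self_kv.append("d%d" % step)         # one new decoder token per step

print(len(self_kv), len(cross_kv))  # 4 10
```

In terms of the API above, this corresponds to calling update(..., is_cross_attention=True) only on the first step, after which get_seq_length(layer_idx, is_cross_attention=True) stays constant while the self-attention length keeps growing.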