nn.EncoderDecoderCache

class lucid.nn.EncoderDecoderCache(self_attention_cache: KVCache | None = None, cross_attention_cache: KVCache | None = None)

Overview

EncoderDecoderCache is a container cache for encoder-decoder style models. It bundles two KV cache instances:

  • self_attention_cache for decoder self-attention

  • cross_attention_cache for decoder cross-attention

The class routes cache reads and writes to one of the two internal caches based on the is_cross_attention flag.
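The routing can be sketched in plain Python. This is a simplified stand-in, not lucid's actual implementation: each internal cache is modeled as a dict mapping layer_idx to a (key, value) pair, and the class name and helper _select are hypothetical.

```python
# Simplified sketch of is_cross_attention routing.
# Not lucid's actual implementation; caches are plain dicts here.
class EncoderDecoderCacheSketch:
    def __init__(self):
        self.self_attention_cache = {}   # layer_idx -> (key, value)
        self.cross_attention_cache = {}  # layer_idx -> (key, value)

    def _select(self, is_cross_attention):
        # Pick the internal cache the call should be routed to.
        if is_cross_attention:
            return self.cross_attention_cache
        return self.self_attention_cache

    def update(self, key, value, layer_idx, is_cross_attention=False):
        cache = self._select(is_cross_attention)
        cache[layer_idx] = (key, value)
        return cache[layer_idx]

    def get(self, layer_idx, is_cross_attention=False):
        return self._select(is_cross_attention).get(layer_idx)
```

Writes with is_cross_attention=True land in the cross-attention cache; all other calls go to the self-attention cache, so the two streams never collide.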

Class Signature

class lucid.nn.EncoderDecoderCache(
    self_attention_cache: nn.KVCache | None = None,
    cross_attention_cache: nn.KVCache | None = None,
)

If either cache is not provided, a DynamicKVCache is created for it by default.

Key Methods

EncoderDecoderCache.update(key: Tensor, value: Tensor, layer_idx: int, cache_position: Tensor | None = None, is_cross_attention: bool = False) -> tuple[Tensor, Tensor]
EncoderDecoderCache.get(layer_idx: int, is_cross_attention: bool = False) -> tuple[Tensor, Tensor] | None
EncoderDecoderCache.get_seq_length(layer_idx: int = 0, is_cross_attention: bool = False) -> int
EncoderDecoderCache.reset() -> None
EncoderDecoderCache.batch_select_indices(indices: Tensor) -> None
EncoderDecoderCache.batch_repeat_interleave(repeats: int) -> None
EncoderDecoderCache.crop(max_length: int) -> None
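Assuming the batch-level methods follow the usual KV-cache semantics (batch_repeat_interleave and batch_select_indices act on the batch dimension of every stored tensor, crop truncates the sequence dimension), their effect can be illustrated with plain nested lists standing in for one layer's cached keys:

```python
# Sketch of batch-level cache ops on a single layer's cached keys,
# represented as a nested list [batch][seq]. Lucid operates on Tensors;
# these free functions only illustrate the intended semantics.
def batch_repeat_interleave(cache, repeats):
    # Repeat each batch entry `repeats` times, consecutively.
    return [row for row in cache for _ in range(repeats)]

def batch_select_indices(cache, indices):
    # Keep only the selected batch entries, in the given order.
    return [cache[i] for i in indices]

def crop(cache, max_length):
    # Truncate every batch entry to at most `max_length` positions.
    return [row[:max_length] for row in cache]
```

batch_repeat_interleave and batch_select_indices are typically used for beam search, where hypotheses are duplicated and then pruned per step; crop is useful when rolling a generation back to an earlier position.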

is_updated

EncoderDecoderCache exposes is_updated: dict[int, bool] to track, per layer, whether the cross-attention cache has been updated. Because the encoder output is fixed during decoding, cross-attention keys and values only need to be computed once; this flag lets a model skip recomputation on subsequent decoding steps.

Minimal Example

import lucid
import lucid.nn as nn

cache = nn.EncoderDecoderCache(
    self_attention_cache=nn.DynamicKVCache(),
    cross_attention_cache=nn.DynamicKVCache(),
)

# Decoder self-attention cache update (one new token per step)
k_self = lucid.randn(2, 8, 1, 64)
v_self = lucid.randn(2, 8, 1, 64)
cache.update(k_self, v_self, layer_idx=0, is_cross_attention=False)

# Decoder cross-attention cache update (full encoder sequence, once)
k_cross = lucid.randn(2, 8, 10, 64)
v_cross = lucid.randn(2, 8, 10, 64)
cache.update(k_cross, v_cross, layer_idx=0, is_cross_attention=True)

print(cache.get_seq_length(0, is_cross_attention=False))  # 1
print(cache.get_seq_length(0, is_cross_attention=True))   # 10