nn.utils.apply_chunking_to_forward
lucid.nn.utils.apply_chunking_to_forward(forward_fn: Callable[..., Tensor], chunk_size: int, chunk_dim: int, *input_tensors: Tensor) → Tensor
Function Signature
def apply_chunking_to_forward(
forward_fn: Callable[..., Tensor],
chunk_size: int,
chunk_dim: int,
*input_tensors: Tensor,
) -> Tensor
Overview
apply_chunking_to_forward splits the input tensors into chunks along chunk_dim, applies forward_fn independently to each chunk, and concatenates the per-chunk outputs back along the same dimension.
When forward_fn treats positions along chunk_dim independently (for example, a per-token linear projection), this produces the same result as applying forward_fn to the full tensors directly, while reducing peak activation memory: only one chunk's activations are live at a time.
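A minimal sketch of this equivalence, using an elementwise forward_fn (lucid.random.randn is assumed here as the input constructor; any tensor of the right shape works the same way):

import lucid
import lucid.nn as nn

x = lucid.random.randn(2, 8, 4)  # (batch, seq_len, hidden); randn assumed

def scale(t: lucid.Tensor) -> lucid.Tensor:
    # Elementwise, so every position along dim 1 is independent.
    return t * 2.0

full = scale(x)
chunked = nn.utils.apply_chunking_to_forward(scale, 4, 1, x)  # 8 / 4 = 2 chunks
# chunked and full hold the same values; only peak memory differs.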
Parameters
forward_fn (Callable[..., Tensor]): Forward function to execute on each chunk.
chunk_size (int): Size of each chunk along chunk_dim. If chunk_size == 0, chunking is disabled and forward_fn is called once on the full inputs (see the sketch after this list).
chunk_dim (int): Dimension index along which the inputs are chunked.
input_tensors (Tensor): One or more tensors passed into forward_fn. All tensors must have the same size at chunk_dim.
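Continuing the sketch from the Overview, chunk_size == 0 falls back to a single call on the full inputs:

out_direct = scale(x)
out_disabled = nn.utils.apply_chunking_to_forward(scale, 0, 1, x)
# With chunking disabled, both calls are equivalent.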
Return Value
Tensor: Concatenated output tensor across all chunks.
Validation Rules
input_tensors must be non-empty.
chunk_size must be non-negative.
chunk_dim must be a valid dimension for the inputs.
All inputs must share the same length along chunk_dim.
If chunking is enabled (chunk_size > 0), the size at chunk_dim must be divisible by chunk_size, as the sketch after this list illustrates.
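For instance, the divisibility rule means the chunk count must be exact; the check below mirrors what the function enforces (the exact error it raises is not specified here):

seq_len, chunk_size = 128, 64
assert seq_len % chunk_size == 0  # valid: 128 / 64 = 2 chunks
# seq_len = 100 with chunk_size = 64 would fail this check, and
# apply_chunking_to_forward would reject those inputs.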
Examples
import lucid
import lucid.nn as nn
class TinyHead(nn.Module):
    def __init__(self, hidden_size: int, chunk_size: int = 64):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, hidden_size)
        self.chunk_size = chunk_size
        self.seq_dim = 1  # chunk along the sequence dimension

    def forward_chunk(self, hidden_states: lucid.Tensor) -> lucid.Tensor:
        # Runs on a single chunk; each sequence position is processed
        # independently, so chunking does not change the result.
        return self.decoder(hidden_states)

    def forward(self, hidden_states: lucid.Tensor) -> lucid.Tensor:
        return nn.utils.apply_chunking_to_forward(
            self.forward_chunk,
            self.chunk_size,
            self.seq_dim,
            hidden_states,
        )
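A hypothetical usage of TinyHead (again assuming lucid.random.randn as the input constructor):

head = TinyHead(hidden_size=256, chunk_size=64)
hidden = lucid.random.randn(2, 512, 256)  # (batch, seq_len, hidden)
out = head(hidden)  # same shape: (2, 512, 256)
# Internally, the 512-long sequence dimension is processed in 8 chunks of 64.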