lucid.chunk

lucid.chunk(a: Tensor, /, chunks: int, axis: int = 0) tuple[Tensor, ...]

The chunk function splits a tensor into multiple sub-tensors along a specified axis. If the tensor cannot be evenly divided, the last chunk will be smaller.

Function Signature

def chunk(input_: Tensor, chunks: int, axis: int = 0) -> tuple[Tensor, ...]:

Parameters

  • input_ (Tensor): The input tensor to be split.

  • chunks (int): The number of chunks to split the tensor into. Must be greater than 0.

  • axis (int, optional): The dimension along which to split the tensor. Default is 0.

Returns

  • tuple[Tensor, …]: A tuple containing chunks number of tensors, where each tensor is a portion of input_ along the specified axis.

Gradient Computation

Each returned tensor retains a corresponding compute_grad function, which ensures that gradients are correctly mapped back to their original positions in input_.

Example

Splitting a tensor into two equal parts along the first axis:

>>> import lucid
>>> x = lucid.ones(4, 6, requires_grad=True)
>>> y1, y2 = x.chunk(2, axis=0)
>>> y1 *= 2
>>> y2 *= 3
>>> y1.backward()
>>> y2.backward()
>>> print(x.grad)
[[2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2.]
 [3. 3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3. 3.]]

Note

  • If chunks is larger than the size of the axis being split, some returned tensors may be empty.

  • This function ensures proper gradient propagation, preserving the backpropagation mechanism in lucid.