lucid.split¶
- lucid.split(a: Tensor, /, size_or_sections: int | list[int] | tuple[int], axis: int = 0) -> tuple[Tensor, ...] ¶
The split function divides a tensor into multiple sub-tensors along a specified axis. It supports both equal-sized splits and custom-sized splits, making it useful for partitioning data in deep learning applications.
Function Signature¶
def split(
a: Tensor, size_or_sections: int | list[int] | tuple[int], axis: int = 0
) -> tuple[Tensor, ...]
Parameters¶
a (Tensor): The input tensor to be split.
size_or_sections (int | list[int] | tuple[int]): If an integer, the tensor is split into equal parts along the specified axis. If a list or tuple, its elements give the size of each part along that axis.
axis (int, optional): The axis along which to split the tensor. Default is 0.
Mathematical Expression¶
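Let the input tensor $\mathbf{A}$ have shape $(d_0, \ldots, d_{n-1})$. The operation performs:
$$\mathbf{A} \mapsto (\mathbf{A}_1, \mathbf{A}_2, \ldots, \mathbf{A}_m)$$
where each sub-tensor satisfies:
$$\operatorname{shape}(\mathbf{A}_i) = (d_0, \ldots, d_{\text{axis}-1},\, s_i,\, d_{\text{axis}+1}, \ldots, d_{n-1}), \qquad \sum_{i=1}^{m} s_i = d_{\text{axis}}$$
and the $s_i$ are determined by the specified axis and split sizes: if size_or_sections is an integer k, all $s_i$ are equal; if it is a list or tuple, $s_i$ is its i-th element.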
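The shape bookkeeping for the custom-size case can be sketched in plain Python. This is only an illustration of how the sizes s_i enter the output shapes; `split_shapes` is a hypothetical helper, not part of lucid, and it assumes a non-negative axis.

```python
def split_shapes(shape, sizes, axis=0):
    """Return the shapes of the pieces when `shape[axis]` is cut into `sizes`.

    Hypothetical helper for illustration; it mirrors the rule that the
    sizes must add up to the length of the split axis.
    """
    shape = tuple(shape)
    if not 0 <= axis < len(shape):
        raise ValueError(f"axis {axis} is out of range for shape {shape}")
    if sum(sizes) != shape[axis]:
        raise ValueError(
            f"sizes {list(sizes)} do not sum to axis length {shape[axis]}"
        )
    # Every piece keeps the other dimensions and replaces shape[axis] by s_i.
    return [shape[:axis] + (s,) + shape[axis + 1:] for s in sizes]


shapes = split_shapes((2, 4), [1, 3], axis=1)
print(shapes)  # [(2, 1), (2, 3)]
```

Only the entry at the split axis changes from piece to piece; every other dimension is carried over unchanged.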
Return Values¶
tuple[Tensor, ...]: A tuple containing the resulting sub-tensors after the split operation.
Examples¶
from lucid import Tensor, split
x = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
split_tensors = split(x, 2, axis=1) # Splits into two tensors along axis 1
print(split_tensors[0].data) # Output: [[1, 2], [5, 6]]
print(split_tensors[1].data) # Output: [[3, 4], [7, 8]]
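The example above covers only the equal-size case. As a rough illustration of the custom-size case, the following plain-Python sketch partitions the same data along axis 1 with sizes [1, 3]. It works on nested lists rather than lucid tensors, and `split_columns` is a hypothetical stand-in; per the parameter description, `split(x, [1, 3], axis=1)` would be expected to produce pieces of the same sizes.

```python
def split_columns(rows, sizes):
    """Cut each row of a 2-D nested list into consecutive column blocks.

    Hypothetical reference helper; not part of lucid.
    """
    width = len(rows[0])
    if sum(sizes) != width:
        raise ValueError(f"sizes {list(sizes)} do not sum to row length {width}")
    parts = []
    start = 0
    for size in sizes:
        stop = start + size
        # Take columns [start, stop) from every row.
        parts.append([row[start:stop] for row in rows])
        start = stop
    return tuple(parts)


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
first, second = split_columns(x, [1, 3])
print(first)   # [[1], [5]]
print(second)  # [[2, 3, 4], [6, 7, 8]]
```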
Note
If size_or_sections is an integer, the length of the input tensor along the specified axis must be evenly divisible by it.
If size_or_sections is a list or tuple, the sum of its elements must equal the length of the tensor along the given axis.
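Both conditions in this note can be checked before calling split. The sketch below is one way to do that in plain Python; `check_split_arg` is a hypothetical helper, and the use of ValueError is an assumption, since the source does not say which exception lucid raises.

```python
def check_split_arg(axis_len, size_or_sections):
    """Raise ValueError if the argument violates the conditions in the note.

    Hypothetical pre-check; lucid's own error type and message may differ.
    """
    if isinstance(size_or_sections, int):
        # Integer case: the axis length must divide evenly.
        if size_or_sections <= 0 or axis_len % size_or_sections != 0:
            raise ValueError(
                f"axis length {axis_len} is not evenly divisible "
                f"by {size_or_sections}"
            )
    else:
        # List or tuple case: the sizes must add up to the axis length.
        total = sum(size_or_sections)
        if total != axis_len:
            raise ValueError(
                f"split sizes sum to {total}, expected {axis_len}"
            )


check_split_arg(4, 2)       # passes: 4 is divisible by 2
check_split_arg(4, [1, 3])  # passes: 1 + 3 == 4

try:
    check_split_arg(4, 3)
except ValueError as err:
    print(err)  # axis length 4 is not evenly divisible by 3
```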