class TensorDataset

extends Dataset
TensorDataset(*tensors: Tensor)

Dataset wrapping one or more Tensors, indexed along their first axis.

Each sample is the tuple (t1[i], t2[i], ...) where t1, t2, ... are the wrapped tensors. All tensors must agree in their first dimension (the sample axis); subsequent dimensions are independent.

Parameters

*tensors : Tensor
One or more tensors of identical leading-dimension size. The dataset length equals tensors[0].shape[0].

Raises

ValueError
If no tensors are provided or the leading dimensions disagree.
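The following is a minimal sketch of the behaviour described above. It is written in plain Python, not taken from lucid's source. The class name `SketchTensorDataset` is hypothetical, and lists stand in for tensors. The sketch assumes `len(t)` gives the leading-dimension size, which for a real tensor corresponds to `t.shape[0]`.

```python
from typing import Any, Sequence, Tuple


class SketchTensorDataset:
    """Hypothetical stand-in for TensorDataset; not lucid's implementation."""

    def __init__(self, *tensors: Sequence[Any]) -> None:
        # At least one tensor is required.
        if not tensors:
            raise ValueError("at least one tensor is required")
        # Every tensor must match the first one along axis 0.
        # Assumption: len(t) plays the role of t.shape[0].
        n = len(tensors[0])
        for i, t in enumerate(tensors[1:], start=1):
            if len(t) != n:
                raise ValueError(
                    f"tensor {i} has leading size {len(t)}, expected {n}"
                )
        # Keep references only; nothing is copied.
        self.tensors: Tuple[Sequence[Any], ...] = tensors

    def __len__(self) -> int:
        # The shared leading-dimension size.
        return len(self.tensors[0])

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        # One element per wrapped tensor, in constructor order.
        return tuple(t[index] for t in self.tensors)


X = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 samples, 2 features each
y = [0, 2, 1]                             # 3 labels
ds = SketchTensorDataset(X, y)
x_i, y_i = ds[1]
print(len(ds), x_i, y_i)  # 3 [0.3, 0.4] 2
```

Each sample keeps the trailing shape of its own tensor. Here that is a 2-element feature row for `X` and a scalar label for `y`, which is why only axis 0 is checked.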

Notes

All wrapped tensors must share the same length along axis 0, and that shared length defines __len__. The tensors are held by reference rather than copied, so in-place modifications of the source tensors are also visible through the dataset. This keeps construction O(1) with respect to the amount of data. It also means the caller must not invalidate the buffers (e.g. by resizing a source tensor) while iterating.
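The reference semantics can be illustrated with plain Python lists standing in for tensors. This is a hypothetical illustration of holding references, not a test of lucid itself. It shows that an in-place write to a source is visible through the held tuple, and that resizing a source breaks the equal-length invariant.

```python
# Hypothetical illustration: lists stand in for tensors.
X = [[1.0, 2.0], [3.0, 4.0]]
y = [0, 1]

# Storing a tuple of references copies no data.
held = (X, y)

# An in-place write to the source...
X[0][1] = 99.0
# ...is visible through the held reference.
print(held[0][0])  # [1.0, 99.0]

# Resizing a source changes what the holder sees as well, so the
# two "tensors" no longer agree along axis 0.
X.append([5.0, 6.0])
print(len(held[0]), len(held[1]))  # 3 2
```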

Examples

>>> X = lucid.randn(100, 4)
>>> y = lucid.randint(0, 3, (100,))
>>> ds = TensorDataset(X, y)
>>> x_i, y_i = ds[0]
>>> len(ds)
100

Methods (3)

__init__

__init__(*tensors: Tensor) -> None

Initialise the instance. See the class docstring for parameter semantics.

__getitem__

__getitem__(index: int) -> tuple[Tensor, ...]

Return the sample at index, i.e. tuple(t[index] for t in self.tensors).

Parameters

index : int
Sample index along the leading dimension.

Returns

tuple of Tensor
One element per wrapped tensor, in the order the tensors were passed to the constructor.

__len__

__len__() -> int

Return the leading-dimension size shared by all wrapped tensors.