class DataLoader

DataLoader(dataset: Dataset, batch_size: int = 1, shuffle: bool | None = None, sampler: Sampler | None = None, batch_sampler: Sampler | None = None, num_workers: int = 0, collate_fn: Callable[..., object] | None = None, drop_last: bool = False, timeout: float = 0.0, worker_init_fn: Callable[..., object] | None = None, multiprocessing_context: object = None, generator: object = None, prefetch_factor: int | None = None, persistent_workers: bool = False)

Combine a dataset with a sampler to provide iteration over mini-batches.

Wraps a Dataset to provide batching, optional shuffling, parallel data loading via worker processes, and customisable collation. Iteration yields one collated batch per step until the underlying sampler is exhausted.
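As a rough illustration of the single-process path, the loop below pulls indices from a sampler, groups them into batches, fetches each sample, and collates the result. It is a minimal sketch, not this class's implementation. `iterate_batches` and `simple_collate` are hypothetical names, and `simple_collate` only transposes tuples into plain lists rather than producing Tensors.

```python
from typing import Callable, Iterable, Iterator, Sequence


def simple_collate(samples: list) -> tuple:
    """Hypothetical stand-in for a collate function: transpose a list of
    (field_a, field_b, ...) samples into one list per field."""
    return tuple(list(field) for field in zip(*samples))


def iterate_batches(
    dataset: Sequence,
    sampler: Iterable[int],
    batch_size: int,
    drop_last: bool = False,
    collate_fn: Callable[[list], object] = simple_collate,
) -> Iterator[object]:
    """Single-process loop: sampler indices -> samples -> collated batches."""
    batch: list = []
    for idx in sampler:
        batch.append(dataset[idx])
        if len(batch) == batch_size:
            yield collate_fn(batch)
            batch = []
    # The trailing partial batch is kept unless drop_last is set.
    if batch and not drop_last:
        yield collate_fn(batch)


data = [(i, i * i) for i in range(5)]  # five (x, x**2) samples
for batch in iterate_batches(data, range(len(data)), batch_size=2):
    print(batch)
# ([0, 1], [0, 1])
# ([2, 3], [4, 9])
# ([4], [16])
```

Passing a shuffled index sequence as `sampler` gives the shuffled variant without changing the loop itself.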

Parameters

dataset : Dataset
Dataset to load data from. May be either map-style (Dataset) or iterable-style (IterableDataset).
batch_size : int = 1
Number of samples per batch. Ignored when batch_sampler is provided.
shuffle : bool | None = None
If True, the default sampler is a RandomSampler; otherwise it is a SequentialSampler. Mutually exclusive with sampler.
sampler : Sampler | None = None
Custom sampler that yields one dataset index at a time. Mutually exclusive with shuffle.
batch_sampler : Sampler | None = None
Custom batch sampler that yields lists of indices. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
num_workers : int = 0
Number of worker processes used for parallel data loading. 0 loads data in the main process; a positive value spawns that many worker processes.
collate_fn : Callable | None = None
Function that merges a list of samples into a batch. Defaults to default_collate.
drop_last : bool = False
If True, drop the final incomplete batch when the dataset length is not divisible by batch_size.
timeout : float = 0.0
Seconds to wait for a worker to deliver a batch before raising RuntimeError. 0 waits indefinitely.
worker_init_fn : Callable | None = None
Called as worker_init_fn(worker_id) at the start of each worker process; useful for seeding per-worker RNGs.
prefetch_factor : int | None = None
Number of batches each worker loads in advance (default 2 when num_workers > 0). Higher values trade memory for throughput.
persistent_workers : bool = False
Keep worker processes alive between epochs to avoid repeated process-startup overhead. Requires num_workers > 0.
pin_memory : bool = False
Accepted for API compatibility. Pinning is a no-op on Apple Silicon, which uses a unified memory architecture.
generator : object = None
RNG handle forwarded to the default RandomSampler.
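worker_init_fn is the usual hook for giving each worker its own random stream. The sketch below assumes the dataset draws randomness from Python's `random` module and that a fixed base seed (the hypothetical `BASE_SEED`) is acceptable; any other RNG the dataset uses, such as NumPy's, would need to be seeded the same way.

```python
import random

BASE_SEED = 1234  # hypothetical per-run seed chosen by the caller


def seed_worker(worker_id: int) -> None:
    """Seed this worker's RNG with a distinct, reproducible value.

    Offsetting by worker_id keeps workers from producing identical
    random augmentations while keeping runs repeatable.
    """
    random.seed(BASE_SEED + worker_id)
```

It would then be passed as `DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)`.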

Notes

Worker processes communicate via multiprocessing.Queue: each worker owns one index queue, all workers share a single result queue, and the main process reorders results back into sampler order before yielding. Tagging each request with a sequence number makes delivery deterministic regardless of the order in which workers finish.
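The reordering step can be sketched with a small buffer keyed by sequence number. This assumes results arrive as hypothetical `(sequence_number, batch)` pairs numbered from 0; the actual message format on the result queue is not documented here, and `reorder` is an illustrative name.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def reorder(results: Iterable[tuple[int, T]]) -> Iterator[T]:
    """Yield payloads in sequence-number order, buffering any result
    that arrives before its predecessors."""
    pending: dict[int, T] = {}
    next_seq = 0
    for seq, payload in results:
        pending[seq] = payload
        # Flush every consecutive result that is now available.
        while next_seq in pending:
            yield pending.pop(next_seq)
            next_seq += 1


arrivals = [(2, "c"), (0, "a"), (1, "b"), (3, "d")]  # out-of-order completion
print(list(reorder(arrivals)))  # ['a', 'b', 'c', 'd']
```

The buffer only ever holds results that are ahead of the next expected sequence number, so its size is bounded by how far the fastest worker runs ahead.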

Examples

>>> dl = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
>>> for batch in dl:
...     ...

Methods (3)

dunder

__init__

None
__init__(dataset: Dataset, batch_size: int = 1, shuffle: bool | None = None, sampler: Sampler | None = None, batch_sampler: Sampler | None = None, num_workers: int = 0, collate_fn: Callable[..., object] | None = None, drop_last: bool = False, timeout: float = 0.0, worker_init_fn: Callable[..., object] | None = None, multiprocessing_context: object = None, generator: object = None, prefetch_factor: int | None = None, persistent_workers: bool = False)

Configure a DataLoader; see the class docstring for parameter semantics.

Raises

ValueError
If mutually exclusive arguments are combined or an argument is out of range (see Notes).

Notes

sampler, batch_sampler, shuffle, batch_size, and drop_last interact: passing batch_sampler precludes the other four, and passing sampler precludes shuffle. When no sampler is supplied, a SequentialSampler (shuffle False or None) or a RandomSampler (shuffle=True) is constructed automatically. persistent_workers requires num_workers > 0.
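The interaction rules above could be enforced with a check along these lines. This is a hypothetical sketch: `validate_loader_args` is not part of the documented API, "passing" an argument is assumed to mean giving it a non-default (truthy) value, and the num_workers and timeout range checks are assumptions about what counts as a range violation.

```python
from typing import Optional


def validate_loader_args(
    batch_size: int = 1,
    shuffle: Optional[bool] = None,
    sampler: object = None,
    batch_sampler: object = None,
    drop_last: bool = False,
    num_workers: int = 0,
    persistent_workers: bool = False,
    timeout: float = 0.0,
) -> None:
    """Raise ValueError for argument combinations the Notes rule out."""
    # batch_sampler precludes batch_size, shuffle, sampler, and drop_last.
    if batch_sampler is not None and (
        batch_size != 1 or shuffle or sampler is not None or drop_last
    ):
        raise ValueError(
            "batch_sampler is mutually exclusive with "
            "batch_size, shuffle, sampler, and drop_last"
        )
    # sampler precludes shuffle.
    if sampler is not None and shuffle:
        raise ValueError("sampler is mutually exclusive with shuffle")
    # Assumed range checks.
    if num_workers < 0:
        raise ValueError("num_workers must be >= 0")
    if timeout < 0:
        raise ValueError("timeout must be >= 0")
    if persistent_workers and num_workers == 0:
        raise ValueError("persistent_workers requires num_workers > 0")
```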

dunder

__iter__

Iterator[Tensor | tuple[Tensor, ...]]
__iter__()

Yield collated mini-batches for one full pass over the dataset.

Dispatches to either the single-process iterator (num_workers == 0) or the multi-process iterator. When persistent_workers is enabled, the multi-process worker pool survives between epochs; otherwise workers are spawned and joined on every call.

dunder

__len__

int
__len__()

Return the number of batches per epoch (len(batch_sampler)).
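Assuming the batch sampler groups a fixed number of indices into chunks of batch_size, the batch count follows from integer division, with drop_last deciding whether the trailing partial batch counts. `num_batches` is a hypothetical helper for this arithmetic, not a method of this class.

```python
def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Batches per epoch for a sampler that yields num_samples indices."""
    if drop_last:
        return num_samples // batch_size  # trailing partial batch discarded
    return (num_samples + batch_size - 1) // batch_size  # ceiling division


print(num_batches(100, 32, drop_last=True))   # 3
print(num_batches(100, 32, drop_last=False))  # 4
```

Integer ceiling division avoids floating-point rounding for very large sample counts.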