WeightedRandomSampler
Sampler
WeightedRandomSampler(weights: list[float], num_samples: int, replacement: bool = True, generator: object = None)
Yield indices drawn proportionally to user-supplied weights.
Each index i is selected with probability proportional to
weights[i]. Internally, the weights are normalised by their sum, so
they need not form a true probability distribution on input.
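A minimal sketch of the normalisation step described above. It assumes the weights are non-negative with a positive sum, and `normalise` is a hypothetical helper for illustration, not part of this class's API.

```python
def normalise(weights: list[float]) -> list[float]:
    """Return the selection probability of each index: w_i / sum(w)."""
    total = sum(weights)
    if total <= 0:
        # Assumption: a zero (or negative) total cannot define a distribution.
        raise ValueError("weights must have a positive sum")
    return [w / total for w in weights]


# Scaling every weight by the same constant leaves the probabilities
# unchanged, which is why the input weights need not sum to 1.
print(normalise([1.0, 3.0]))  # [0.25, 0.75]
print(normalise([2.0, 6.0]))  # [0.25, 0.75]
```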
Parameters
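weights : list of float
    Entry i is the (unnormalised) weight of index i.
num_samples : int
    Number of indices yielded by each iteration over the sampler.
replacement : bool, default True
    If True (default), draws are independent and made with replacement, so
    the same index may appear multiple times. If False, draws are made with
    random.choices on the raw weights. Note that random.choices itself
    samples with replacement, so repeats are still possible on this path;
    callers who need unique indices should provide num_samples <= len(weights)
    and validate downstream.
generator : optional, default None
    A random.Random instance used for the draws; pass a seeded instance for
    reproducibility.
Notes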
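Each index i is selected with probability

$$P(i) = \frac{w_i}{\sum_{j} w_j},$$

so the user-supplied weights need not be normalised. A classic use case is
class-imbalance correction: set w_i = 1 / class_count[label_i], so that every
class has total weight 1 and is drawn with probability 1/K for K classes.
Under-represented classes are thereby upsampled to roughly uniform frequency.
Sampling with replacement is the default: every draw then follows the target
marginal P(i) exactly, and it is the only fully consistent option when
num_samples exceeds the number of nonzero-weight indices (without
replacement, there would not be enough distinct indices to draw).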
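A sketch of the class-imbalance weighting described above, using the same 900/90/10 class split as the Examples. The `labels` list is hypothetical, and `random.choices` is used here only as a stand-in for weighted sampling with replacement; it is not the sampler itself.

```python
import random
from collections import Counter

# Hypothetical labels: 900 / 90 / 10 items in classes 0 / 1 / 2.
labels = [0] * 900 + [1] * 90 + [2] * 10
class_count = Counter(labels)

# w_i = 1 / class_count[label_i]: each class ends up with total weight ~1.0.
weights = [1.0 / class_count[y] for y in labels]
per_class_total = {
    c: sum(w for w, y in zip(weights, labels) if y == c) for c in class_count
}
print(per_class_total)  # every value is approximately 1.0

# Weighted draws with replacement: each class is hit about 1/3 of the time.
rng = random.Random(0)
draws = rng.choices(range(len(weights)), weights=weights, k=30_000)
draw_counts = Counter(labels[i] for i in draws)
print(draw_counts)  # roughly 10,000 draws per class
```

Because the rare class has only 10 items, each of them is repeated about 1,000 times in this run, which is exactly why upsampling of this kind needs replacement.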
Examples
>>> # 3 classes, counts [900, 90, 10]; upweight rare classes
>>> weights = [1/900]*900 + [1/90]*90 + [1/10]*10
>>> sampler = WeightedRandomSampler(weights, num_samples=1000)
>>> for idx in sampler:
...     x, y = my_dataset[idx]  # my_dataset: your own indexable dataset
Methods
__init__
__init__(weights: list[float], num_samples: int, replacement: bool = True, generator: object = None) -> None
Store sampling configuration.
Parameters
weights : list of float
num_samples : int
replacement : bool, default True
generator : optional, default None
__iter__
__iter__() -> Iterator[int]
Yield num_samples indices, each drawn with probability proportional to its weight.
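With replacement=True, indices are produced by inverse-CDF sampling against
the normalised weight vector. With replacement=False, random.choices is
called with the raw weights; because random.choices also samples with
replacement, repeated indices remain possible on this path.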
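The inverse-CDF path can be sketched as follows. `InverseCDFSampler` is a hypothetical class written for illustration, not this library's implementation; it assumes non-negative weights with a positive sum and mirrors only the replacement=True behaviour.

```python
import bisect
import itertools
import random
from typing import Iterator, Optional


class InverseCDFSampler:
    """Hypothetical sketch of weighted sampling with replacement via inverse CDF."""

    def __init__(self, weights: list[float], num_samples: int,
                 generator: Optional[random.Random] = None) -> None:
        self.num_samples = num_samples
        self.rng = generator if generator is not None else random.Random()
        # Running totals: cdf[i] = w_0 + w_1 + ... + w_i.
        self.cdf = list(itertools.accumulate(weights))
        self.total = self.cdf[-1]
        if self.total <= 0:
            raise ValueError("weights must have a positive sum")

    def __iter__(self) -> Iterator[int]:
        last = len(self.cdf) - 1
        for _ in range(self.num_samples):
            # u is uniform on [0, total); the first index whose running total
            # exceeds u is returned, so index i is hit with prob w_i / total.
            # Zero-weight indices are never returned (their interval is empty).
            u = self.rng.random() * self.total
            # hi=last guards against float rounding pushing u up to total.
            yield bisect.bisect_right(self.cdf, u, 0, last)

    def __len__(self) -> int:
        return self.num_samples


sampler = InverseCDFSampler([0.0, 1.0, 3.0], num_samples=8,
                            generator=random.Random(42))
print(list(sampler))  # only 1s and 2s: index 0 has zero weight
```

Comparing against the cumulative sums instead of normalising explicitly is equivalent to sampling against the normalised vector, since both u and the running totals are on the same unnormalised scale. CPython's random.choices uses the same accumulate-and-bisect approach when given weights.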
__len__
__len__() -> int
Return num_samples.