class

WeightedRandomSampler

extends Sampler
WeightedRandomSampler(weights: list[float], num_samples: int, replacement: bool = True, generator: object = None)

Yield indices drawn proportionally to user-supplied weights.

Each index i is selected with probability proportional to weights[i]. Internally the weights are normalised by their sum, so they need not form a probability distribution on input.

Parameters

weights: list of float
Non-negative weight per index. The effective probability of index i is $p_i = w_i / \sum_j w_j$.
num_samples: int
Number of indices to draw per epoch.
replacement: bool = True
If True (default), draws are independent and made with replacement, so the same index may appear multiple times. If False, indices are still drawn via random.choices, which always samples with replacement; repeats therefore remain possible and the result is not a true weighted sample without replacement. Callers that need strictly unique indices should keep num_samples <= len(weights) and validate the output downstream.
generator: optional = None
Seed-like object forwarded to random.Random for reproducibility.
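Because random.choices permits repeats, a caller who needs strictly unique indices has to draw them some other way. The following is a minimal sketch of one option, the Efraimidis-Spirakis key method for weighted sampling without replacement. It is a swapped-in technique, not something this class is documented to do, and the helper name weighted_sample_without_replacement is hypothetical.

```python
import heapq
import math
import random


def weighted_sample_without_replacement(weights, k, rng=None):
    """Return k distinct indices, favouring larger weights.

    Hypothetical helper (not part of WeightedRandomSampler).
    """
    if rng is None:
        rng = random.Random()
    # Only indices with a positive weight can ever be selected.
    candidates = [(i, w) for i, w in enumerate(weights) if w > 0]
    if k > len(candidates):
        raise ValueError("k exceeds the number of nonzero-weight indices")
    # Efraimidis-Spirakis: give each index the key u ** (1 / w) and keep
    # the k largest keys. The log form log(u) / w has the same ordering.
    # 1.0 - rng.random() lies in (0, 1], so the logarithm is always finite.
    keyed = [(math.log(1.0 - rng.random()) / w, i) for i, w in candidates]
    return [i for _, i in heapq.nlargest(k, keyed)]


# random.choices draws with replacement: 5 draws from 3 indices must repeat.
repeated = random.Random(0).choices([0, 1, 2], weights=[1, 1, 1], k=5)
assert len(set(repeated)) < len(repeated)

# The key method returns distinct indices and skips the zero-weight one.
unique = weighted_sample_without_replacement([1.0, 0.0, 2.0, 3.0], k=3,
                                             rng=random.Random(0))
assert sorted(unique) == [0, 2, 3]
```

Passing a seeded random.Random instance, as above, keeps the draw reproducible.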

Notes

Each index i is selected with probability

$$P(i) = \frac{w_i}{\sum_j w_j},$$

so the user-supplied weights need not be normalised. The classic use case is class-imbalance correction: set w_i = 1 / class_count[label_i] so that every class carries the same total weight and under-represented classes are upsampled to roughly uniform frequency. Sampling with replacement is the default: it preserves the target marginal exactly and is the only fully consistent option when num_samples exceeds the number of nonzero-weight indices.
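The class-imbalance recipe can be sketched as follows. The helper class_balanced_weights and the toy label list are illustrative assumptions, not part of the documented API; the check at the end applies the normalisation $P(i) = w_i / \sum_j w_j$ and confirms that each class ends up with the same total probability.

```python
from collections import Counter


def class_balanced_weights(labels):
    """Return one weight per sample: 1 / (size of that sample's class).

    Hypothetical helper, shown only to illustrate the recipe.
    """
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


# Toy dataset: 3 classes with counts 900, 90 and 10.
labels = ["a"] * 900 + ["b"] * 90 + ["c"] * 10
weights = class_balanced_weights(labels)

# Normalise by the sum and accumulate the probability mass per class.
total = sum(weights)
class_mass = Counter()
for y, w in zip(labels, weights):
    class_mass[y] += w / total

# Each class should carry about one third of the total probability.
for label, mass in sorted(class_mass.items()):
    print(label, round(mass, 6))
```

The resulting weights list has one entry per sample and can be passed as the weights argument of the sampler.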

Examples

>>> # 3 classes, counts [900, 90, 10]; upweight rare classes
>>> weights = [1/900]*900 + [1/90]*90 + [1/10]*10
>>> sampler = WeightedRandomSampler(weights, num_samples=1000)
>>> for idx in sampler:
...     x, y = my_dataset[idx]

Methods (3)

dunder

__init__

None
__init__(weights: list[float], num_samples: int, replacement: bool = True, generator: object = None)

Store sampling configuration.

Parameters

weights: list of float
See class docstring.
num_samples: int
See class docstring.
replacement: bool = True
See class docstring.
generator: optional = None
See class docstring.
dunder

__iter__

Iterator[int]
__iter__()

Yield num_samples indices proportional to the weights.

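With replacement=True, indices are produced by inverse-CDF sampling against the normalised weight vector. With replacement=False, random.choices is called with the raw weights; since random.choices itself samples with replacement, repeated indices are still possible in this mode.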
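Inverse-CDF sampling can be illustrated with a short sketch. This is one plausible way to implement the idea, not necessarily how this class does it; the function name inverse_cdf_indices is hypothetical. Instead of dividing every weight by the sum, it scales the uniform draw by the total, which is equivalent.

```python
import bisect
import itertools
import random


def inverse_cdf_indices(weights, num_samples, rng=None):
    """Yield num_samples indices with probability w_i / sum(w).

    Hypothetical sketch of inverse-CDF sampling with replacement.
    """
    if rng is None:
        rng = random.Random()
    # Running totals of the raw weights form an unnormalised CDF.
    cdf = list(itertools.accumulate(weights))
    if not cdf or cdf[-1] <= 0:
        raise ValueError("weights must have a positive sum")
    last = len(cdf) - 1
    for _ in range(num_samples):
        # Uniform draw on [0, total) instead of normalising each weight.
        u = rng.random() * cdf[-1]
        # First position whose running total strictly exceeds u, so
        # zero-weight indices are skipped. min() guards against rounding.
        yield min(bisect.bisect_right(cdf, u), last)


rng = random.Random(0)
draws = list(inverse_cdf_indices([0.0, 1.0, 0.0, 3.0], 2000, rng))
# Indices 0 and 2 have zero weight and should never appear.
print(sorted(set(draws)))
```

Each draw costs O(log n) after an O(n) pass to build the running totals.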

dunder

__len__

int
__len__()

Return num_samples.