Pytorch|Pytorch Torch.utils.data.Sampler

Data Loading Order and Sampler

For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample at each time).
The rest of this section concerns the case with map-style datasets. torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.
A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch.
A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument. Automatic batching can also be enabled via batch_size and drop_last arguments. See the next section for more details on this.
NOTE
Neither sampler nor batch_sampler is compatible with iterable-style datasets, since such datasets have no notion of a key or an index.
Sampler主要是结合Map-style的Dataset使用,所有的Sampler都有__iter__方法,返回一个取样索引的迭代器。可以选择性地实现__len__。

import torch
import torch.utils.data as data

# Toy "dataset" of 4 samples, shape (4, 1, 2, 2), holding the values 1..16.
data_tensor = torch.arange(1, 17, dtype=torch.float32).reshape(4, 1, 2, 2)

# Yields 0, 1, 2, 3 in order; len == 4.
sequential_sampler = data.SequentialSampler(data_tensor)

# Random permutation of the indices without repeats; len == 4.
random_sampler = data.RandomSampler(data_tensor, replacement=False, num_samples=None)
for i in random_sampler:
    print(i)
# Example output (the order is random): 0 1 3 2

# With replacement the number of draws is free to choose; len == 5 and
# indices may repeat.
random_sampler2 = data.RandomSampler(data_tensor, replacement=True, num_samples=5)

# Random permutation restricted to the given indices.
subset_random_sampler = data.SubsetRandomSampler([0, 1, 3])

weighted_random_sampler1 = data.WeightedRandomSampler(
    [0.1, 0.2, 0.1, 0.1],  # four samples; relative probability of drawing each
    num_samples=4,         # draws to make; must be <= the number of non-zero weights
    replacement=False,     # no index may be drawn twice
)

weighted_random_sampler2 = data.WeightedRandomSampler(
    [0.1, 0.2, 0.1, 0.1],  # four samples; relative probability of drawing each
    num_samples=6,         # draws to make; may exceed the number of non-zero weights
    replacement=True,      # repeats allowed (this is the default)
)

# Groups the indices produced by another sampler into lists of `batch_size`.
batch_sampler = data.BatchSampler(random_sampler, batch_size=2, drop_last=False)
print(list(batch_sampler))  # e.g. [[2, 0], [3, 1]]
print(len(batch_sampler))   # 2

Sampler.py
import torch

# `torch._six.int_classes` was only an alias for `int` on Python 3, and the
# `torch._six` module has been removed from recent PyTorch releases, so the
# alias is defined locally instead of being imported.
_int_classes = int


class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a
              :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Many times we have an abstract class representing a collection/iterable of
    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
    # implementing a `__len__` method. In such cases, we must make sure to not
    # provide a default implementation, because both straightforward default
    # implementations have their issues:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError()`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising a `NotImplementedError` will propagate and make the call
    #     fail where it could have used `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)


class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a
    shuffled dataset. If with replacement, then user can specify
    :attr:`num_samples` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``,
            default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
            This argument is supposed to be specified only when `replacement`
            is ``True``.

    Raises:
        ValueError: if ``replacement`` is not a bool, if ``num_samples`` is
            given while ``replacement`` is ``False``, or if the effective
            number of samples is not a positive integer.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        # Checked through the property so that the `len(data_source)` fallback
        # is validated as well.
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples


class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # Convert the permutation to plain Python ints once, like the other
        # samplers do, instead of iterating the tensor element by element.
        return (self.indices[i] for i in torch.randperm(len(self.indices)).tolist())

    def __len__(self):
        return len(self.indices)


class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.

    Raises:
        ValueError: if ``num_samples`` is not a positive, non-bool integer or
            if ``replacement`` is not a bool.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [0, 0, 0, 1, 0]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement=True):
        # `bool` is a subclass of `int`, so it has to be rejected explicitly.
        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        return self.num_samples


class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Raises:
        ValueError: if ``sampler`` is not a :class:`Sampler`, if ``batch_size``
            is not a positive, non-bool integer, or if ``drop_last`` is not a
            bool.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # `bool` is a subclass of `int`, so it has to be rejected explicitly.
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Emit the trailing partial batch unless the caller asked to drop it.
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            # Ceiling division: a trailing partial batch counts as one batch.
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size


    推荐阅读