不均衡样本的sampler构建 Imbalanced Dataset Sampler
# Demo: using ImbalancedDatasetSampler with fastNLP's SST-2 pipeline.
# NOTE(review): dict/vocab keys below were unquoted in the original paste
# (`words`, `train`, `dev`, `target`) — restored as string literals, which is
# what fastNLP's API expects.
from fastNLP.io import SST2Pipe
from fastNLP import DataSetIter
from torchsampler import ImbalancedDatasetSampler

pipe = SST2Pipe()
databundle = pipe.process_from_file()
vocab = databundle.vocabs['words']
print(databundle)
print(databundle.datasets['train'][0])
print(databundle.vocabs['words'])

train_data = databundle.get_dataset('train')
train_data, test_data = train_data.split(0.015)
dev_data = databundle.get_dataset('dev')
print(len(train_data), len(dev_data), len(test_data))

# Small slice of dev data just to demonstrate the sampler.
tmp_data = dev_data[:10]


def callback_get_label(dataset, idx):
    """Return the label of sample `idx`.

    ImbalancedDatasetSampler needs a callback to read each sample's label;
    for a fastNLP DataSet the label lives in the 'target' field.
    """
    label = dataset[idx]['target']
    return label


sampler = ImbalancedDatasetSampler(tmp_data, callback_get_label=callback_get_label)
batch = DataSetIter(batch_size=3, dataset=tmp_data, sampler=sampler)
for batch_x, batch_y in batch:
    print("batch_x: ", batch_x)
    print("batch_y: ", batch_y)
class IMB(torchvision.datasets.MNIST):
    """MNIST training set with selected rows removed to create class imbalance.

    NOTE(review): relies on a module-level `idx_to_del` (indices of samples to
    drop) that is defined elsewhere — confirm it exists before instantiation.
    The original paste read from an undefined `train_loader`; the locally
    built `train_dataset` is clearly what was intended and is used instead.
    """

    def __init__(self, transform=None, target_transform=None):
        # Build the full training set once (root "." — quotes were lost in the
        # original paste), then delete the `idx_to_del` rows from data/labels.
        train_dataset = torchvision.datasets.MNIST(
            ".", train=True, download=True, transform=transform
        )
        train_labels = np.delete(train_dataset.train_labels, idx_to_del, axis=0)
        train_data = np.delete(train_dataset.train_data, idx_to_del, axis=0)
        self.data, self.targets = train_data, train_labels
        self.transform = transform
        self.target_transform = target_transform


# NOTE(review): `train_transform`, `args` and `kwargs` must be defined by the
# surrounding script; the original line was truncated at `**kwar` and has been
# completed as `**kwargs`.
imbalanced_train_dataset = IMB(transform=train_transform)
imbalanced_train_loader = torch.utils.data.DataLoader(
    imbalanced_train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs,
)
pytorch版本和mxnet版本的Imbalanced Dataset Sampler。目前自己写的mxnet版本效率极低,推荐pytorch版本,不影响sampler是在pytorch或mxnet中的使用。
def callback_get_label(dataset, idx):
    """Default label getter: assumes each sample is a (data, label) pair."""
    label = dataset[idx][1]
    return label


import torch


class ImbalancedDatasetSampler(Sampler):
    """Samples elements randomly from a given list of indices for an
    imbalanced dataset.

    Sampling idea (translated from the original Chinese note): compute each
    class's frequency, assign every sample the inverse frequency of its class
    as a weight (one weight per sample), then draw indices via a multinomial
    distribution over those weights.

    Arguments:
        dataset: the dataset to sample from.
        indices (list, optional): a list of indices; defaults to all of
            ``range(len(dataset))``.
        num_samples (int, optional): number of samples to draw per epoch;
            defaults to ``len(indices)``.
        callback_get_label (callable): ``f(dataset, idx) -> label`` used to
            read each sample's label. Required — ``_get_label`` raises
            ``NotImplementedError`` without it.
    """

    def __init__(self, dataset, indices=None, num_samples=None,
                 callback_get_label=None):
        # If indices are not provided, consider every element of the dataset.
        self.indices = list(range(len(dataset))) if indices is None else indices

        # Custom callback used to extract a sample's label.
        self.callback_get_label = callback_get_label

        # If num_samples is not provided, draw `len(indices)` per iteration.
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # Class distribution over the selected indices. Labels are computed
        # once per index (the original called _get_label twice per index).
        labels = [self._get_label(dataset, idx) for idx in self.indices]
        label_to_count = {}
        for label in labels:
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # Per-sample weight = inverse frequency of the sample's class.
        weights = [1.0 / label_to_count[label] for label in labels]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        """Return the label of sample `idx` via the user-supplied callback."""
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        raise NotImplementedError

    def __iter__(self):
        # BUGFIX: the original mxmulti always drew len(weights) samples,
        # ignoring self.num_samples, so __len__ and the actual iteration
        # length disagreed whenever num_samples was passed explicitly.
        return (self.indices[i]
                for i in self.mxmulti(self.weights, self.num_samples))

    @classmethod
    def mxmulti(cls, weights, num_samples=None):
        """NumPy-based multinomial sampling with replacement.

        Draws per-index counts from a multinomial distribution over the
        normalized weights, expands them into an index list, and shuffles it —
        equivalent in distribution to ``num_samples`` i.i.d. weighted draws.

        Arguments:
            weights: per-sample weights (any sequence convertible to float).
            num_samples (int, optional): total draws; defaults to
                ``len(weights)`` (the original behavior).
        """
        import random

        probs = np.array(weights).astype(float)
        probs = probs / probs.sum()
        total = len(probs) if num_samples is None else num_samples
        sample_times = np.random.multinomial(total, probs)
        sample_list = []
        for i, t in enumerate(sample_times):
            t = int(t)
            if t > 0:
                sample_list.extend([i] * t)
        random.shuffle(sample_list)
        return sample_list

    def __len__(self):
        return self.num_samples