Building a Sampler for Imbalanced Data: Imbalanced Dataset Sampler

from fastNLP.io import SST2Pipe
from fastNLP import DataSetIter
from torchsampler import ImbalancedDatasetSampler
pipe = SST2Pipe()
databundle = pipe.process_from_file()
vocab = databundle.vocabs['words']
print(databundle)
print(databundle.datasets['train'][0])
print(databundle.vocabs['words'])

train_data = databundle.get_dataset('train')
train_data, test_data = train_data.split(0.015)
dev_data = databundle.get_dataset('dev')
print(len(train_data),len(dev_data),len(test_data))
tmp_data = dev_data[:10]

def callback_get_label(dataset, idx):
    label = dataset[idx]['target']
    return label

# For the dataset we need a callback_get_label function that tells the sampler how to read each sample's label
sampler = ImbalancedDatasetSampler(tmp_data, callback_get_label=callback_get_label)
batch = DataSetIter(batch_size=3, dataset=tmp_data,
                    sampler=sampler)
for batch_x, batch_y in batch:
    print("batch_x: ", batch_x)
    print("batch_y: ", batch_y)
The same idea carries over to a torchvision dataset: the MNIST subclass below deletes some samples (the indices listed in idx_to_del, assumed to be defined elsewhere) to create an artificial class imbalance. train_transform, args and kwargs are likewise assumed to come from the surrounding training script.

import numpy as np
import torch
import torchvision

class IMB(torchvision.datasets.MNIST):
    """MNIST with selected samples removed to create class imbalance."""
    def __init__(self, transform=None, target_transform=None):
        # load the full training set, then drop the rows listed in idx_to_del
        train_dataset = torchvision.datasets.MNIST('.', train=True, download=True)
        train_labels = np.delete(train_dataset.targets.numpy(), idx_to_del, axis=0)
        train_data = np.delete(train_dataset.data.numpy(), idx_to_del, axis=0)
        self.data = torch.from_numpy(train_data)
        self.targets = torch.from_numpy(train_labels)
        self.transform = transform
        self.target_transform = target_transform

imbalanced_train_dataset = IMB(transform=train_transform)

imbalanced_train_loader = torch.utils.data.DataLoader(
    imbalanced_train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
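To actually rebalance the batches drawn from this skewed MNIST set, the sampler replaces shuffle=True (a DataLoader cannot take both a sampler and shuffle=True). A minimal sketch: mnist_get_label is a hypothetical helper I introduce here, and whether torchsampler's ImbalancedDatasetSampler accepts callback_get_label depends on its version (the hand-written class below takes the same argument).

# hypothetical helper: read the label from dataset.targets
def mnist_get_label(dataset, idx):
    return int(dataset.targets[idx])

balanced_train_loader = torch.utils.data.DataLoader(
    imbalanced_train_dataset,
    batch_size=args.batch_size,   # same args as above
    sampler=ImbalancedDatasetSampler(imbalanced_train_dataset,
                                     callback_get_label=mnist_get_label),
)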

Below are PyTorch and MXNet versions of the Imbalanced Dataset Sampler. The MXNet version I wrote is currently very inefficient, so the PyTorch version is recommended; this does not affect whether the sampler is used from PyTorch or MXNet.

import random

import numpy as np
import torch
from torch.utils.data.sampler import Sampler


def callback_get_label(dataset, idx):
    label = dataset[idx][1]
    return label


class ImbalancedDatasetSampler(Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset

    Arguments:
        indices (list, optional): a list of indices
        num_samples (int, optional): number of samples to draw
        callback_get_label func: a callback-like function which takes two
            arguments - dataset and index
    """

    def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):
        """
        Torch-version __iter__ vs. MXNet-version __iter__.
        Imbalanced sampling idea: compute the probability of each class, assign that
        probability to every sample to build a probability list whose length equals
        the dataset size, then draw sample indices from a multinomial distribution.
        """
        # if indices is not provided,
        # all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices

        # define custom callback
        self.callback_get_label = callback_get_label

        # if num_samples is not provided,
        # draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # distribution of classes in the dataset
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            if label in label_to_count:
                label_to_count[label] += 1
            else:
                label_to_count[label] = 1

        # weight for each sample: the inverse of its class frequency
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
                   for idx in self.indices]

        self.weights = torch.DoubleTensor(weights)
        # self.weights = np.array(weights)

    def _get_label(self, dataset, idx):
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        else:
            raise NotImplementedError

    # torch version: draw indices directly from a multinomial distribution
    # def __iter__(self):
    #     return (self.indices[i] for i in torch.multinomial(
    #         self.weights, self.num_samples, replacement=True))

    def __iter__(self):
        # print(self.mxmulti(self.weights))
        return (self.indices[i] for i in self.mxmulti(self.weights))

    @classmethod
    def mxmulti(cls, weights):
        # numpy/mxnet version: draw how many times each index is sampled,
        # expand the counts into an index list, then shuffle it
        probs = np.array(weights).astype(float)
        probs = probs / probs.sum()
        sample_times = np.random.multinomial(len(probs), probs)
        sample_list = []
        for i, t in enumerate(sample_times):
            t = t.item()
            if t > 0:
                sample_list.extend([i] * t)
        random.shuffle(sample_list)
        return sample_list

    def __len__(self):
        return self.num_samples
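A small smoke test for the class above (my own sketch, not from the original post): build a 90/10 toy dataset of (feature, label) pairs, which is the layout callback_get_label expects, and check that the drawn labels come out close to 50/50.

# toy dataset: 90 samples of class 0, 10 samples of class 1
toy_data = [("a%d" % i, 0) for i in range(90)] + [("b%d" % i, 1) for i in range(10)]

sampler = ImbalancedDatasetSampler(toy_data, callback_get_label=callback_get_label)

# each class carries total weight 1.0, so the draw is roughly balanced
drawn = [toy_data[i][1] for i in sampler]
print("class 0:", drawn.count(0), "class 1:", drawn.count(1))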
