MAML-Pytorch Code Reading Notes
(Still being updated.)

Reference blog (Zhihu):
Code:
Paper:
MAML-Pytorch:
First, let's look at the network structure. It is a standard four-layer convolutional neural network: every convolution uses 32 filters of size 3×3 and is followed by ReLU, batch normalization, and max pooling; finally the convolutional output is flattened and fed into a linear layer with n_way outputs.
config = [
    ('conv2d', [32, 3, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 1, 0]),
    ('flatten', []),
    ('linear', [args.n_way, 32 * 5 * 5])
]

maml = Meta(args, config).to(device)
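To make the config concrete, here is a minimal sketch of how such a config list can be walked with an explicit parameter list vars. This is not the repo's actual Learner class: the function name functional_forward is my own, and batch norm is omitted for brevity. The point it illustrates is that passing a different vars on every call is what makes the "fast weights" trick in the inner loop possible.

import torch.nn.functional as F

def functional_forward(x, config, vars):
    # walk the config list, consuming parameters from `vars` as needed
    idx = 0
    for name, param in config:
        if name == 'conv2d':
            w, b = vars[idx], vars[idx + 1]
            x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
            idx += 2
        elif name == 'linear':
            w, b = vars[idx], vars[idx + 1]
            x = F.linear(x, w, b)
            idx += 2
        elif name == 'relu':
            x = F.relu(x, inplace=param[0])
        elif name == 'max_pool2d':
            x = F.max_pool2d(x, param[0], param[1], param[2])
        elif name == 'flatten':
            x = x.view(x.size(0), -1)
        # 'bn' omitted here; the real Learner also keeps running statistics
    return x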
The core of the inner loop looks like this: for each task i, the fast weights are repeatedly updated on the support set (x_spt), and after every update the loss and accuracy of the updated weights on the query set (x_qry) are recorded.

for k in range(1, self.update_step):
    # 1. run the i-th task and compute loss for k=1~K-1
    logits = self.net(x_spt[i], fast_weights, bn_training=True)
    loss = F.cross_entropy(logits, y_spt[i])
    # 2. compute grad on theta_pi
    grad = torch.autograd.grad(loss, fast_weights)
    # 3. theta_pi = theta_pi - train_lr * grad
    fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

    logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
    # loss_q will be overwritten and just keep the loss_q on last update step.
    loss_q = F.cross_entropy(logits_q, y_qry[i])
    losses_q[k + 1] += loss_q

    with torch.no_grad():
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        correct = torch.eq(pred_q, y_qry[i]).sum().item()  # convert to numpy
        corrects[k + 1] = corrects[k + 1] + correct
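For context, this loop starts at k = 1 because it is preceded in the repo by an initial adaptation step (k = 0), which looks roughly like the sketch below (paraphrased, not quoted verbatim): the first fast_weights are derived directly from the meta-parameters self.net.parameters(), and losses_q[0]/losses_q[1] and corrects[0]/corrects[1] record query performance before and after that first update.

# initial step (k = 0): adapt from the meta-parameters theta
logits = self.net(x_spt[i], vars=None, bn_training=True)  # forward with theta itself
loss = F.cross_entropy(logits, y_spt[i])
grad = torch.autograd.grad(loss, self.net.parameters())
# first fast weights: theta - update_lr * grad
fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0],
                        zip(grad, self.net.parameters())))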
Afterwards, each task's loss on the query set (x_qry) is computed and summed across tasks, and the second, outer-loop gradient update of the meta-parameters is performed in the following code:
# end of all tasks
# sum over all losses on query set across all tasks
loss_q = losses_q[-1] / task_num

# optimize theta parameters
self.meta_optim.zero_grad()
loss_q.backward()
self.meta_optim.step()
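The key point is that fast_weights is a differentiable function of the original parameters theta, so loss_q.backward() can propagate through the inner-loop updates. The following self-contained toy example is my own illustration with a tiny linear model, not code from the repo; note that obtaining the true second-order meta-gradient in PyTorch requires create_graph=True in the inner torch.autograd.grad call.

import torch
import torch.nn.functional as F

theta = torch.randn(5, 3, requires_grad=True)            # "meta" parameters of a tiny linear model
x_spt, y_spt = torch.randn(10, 3), torch.randint(0, 5, (10,))
x_qry, y_qry = torch.randn(10, 3), torch.randint(0, 5, (10,))
inner_lr = 0.01

# inner step: create_graph=True keeps the graph, so the outer gradient is second order
loss_spt = F.cross_entropy(x_spt @ theta.t(), y_spt)
grad = torch.autograd.grad(loss_spt, theta, create_graph=True)[0]
theta_prime = theta - inner_lr * grad                     # fast weights, a function of theta

# outer step: query loss of the adapted parameters, differentiated w.r.t. theta
loss_qry = F.cross_entropy(x_qry @ theta_prime.t(), y_qry)
loss_qry.backward()
print(theta.grad.shape)                                   # gradient w.r.t. the original meta-parameters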
The meta-train and meta-test datasets are built with the MiniImagenet wrapper:

mini = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='train',
                    n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry,
                    batchsz=10000, resize=args.imgsz)
mini_test = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='test',
                         n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry,
                         batchsz=100, resize=args.imgsz)
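As a rough sanity check, each item yielded by a DataLoader over mini is one meta-batch of episodes. The shapes below are hypothetical numbers assuming a 5-way 1-shot setting with 15 queries per class, imgsz = 84, and task_num = 4 (and the script's usual DataLoader/torch imports); the exact values depend on args.

db = DataLoader(mini, 4, shuffle=True, num_workers=1, pin_memory=True)
x_spt, y_spt, x_qry, y_qry = next(iter(db))
print(x_spt.shape)  # torch.Size([4, 5, 3, 84, 84])   -> [task_num, n_way * k_shot, C, H, W]
print(y_spt.shape)  # torch.Size([4, 5])
print(x_qry.shape)  # torch.Size([4, 75, 3, 84, 84])  -> [task_num, n_way * k_query, C, H, W]
print(y_qry.shape)  # torch.Size([4, 75])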
Next comes the overall training loop. At every step, the MAML core algorithm above runs on a batch of tasks sampled from the meta-train set; every 500 steps, the current meta-parameters are evaluated on the meta-test set by fine-tuning on each test task and measuring query accuracy.
for epoch in range(args.epoch // 10000):
    # sample a batch of tasks from the meta-train set
    db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

    for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 30 == 0:
            print('step:', step, ' training acc:', accs)

        if step % 500 == 0:  # evaluation phase
            db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
            accs_all_test = []

            for x_spt, y_spt, x_qry, y_qry in db_test:
                x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                             x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)

            # average accuracy over all test tasks, one entry per adaptation step
            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
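maml.finetunning evaluates a single test task. Roughly (paraphrased from the repo, not its exact code), it deep-copies the meta-trained network so that theta itself is never modified, adapts the copy on the support set for update_step_test steps, and measures query-set accuracy after each step:

from copy import deepcopy

net = deepcopy(self.net)                     # theta stays untouched
fast_weights = list(net.parameters())
for k in range(self.update_step_test):
    # one adaptation step on the support set of this test task
    loss = F.cross_entropy(net(x_spt, fast_weights, bn_training=True), y_spt)
    grad = torch.autograd.grad(loss, fast_weights)
    fast_weights = [w - self.update_lr * g for g, w in zip(grad, fast_weights)]
    with torch.no_grad():
        pred_q = net(x_qry, fast_weights, bn_training=True).argmax(dim=1)
        acc = torch.eq(pred_q, y_qry).float().mean().item()  # query accuracy after k+1 steps
del net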