nvidia训练深度学习模型利器apex使用解读
一、背景
gpu显存不大,很多模型没法跑,不能用很大的batch size等导致loss没法降低。使用apex工具可以从中解脱出来。
二、apex介绍
apex是nvidia开源的,完美支持pytorch框架,用于改变数据格式来减小模型显存占用的工具。
其中最有价值的是amp(Automatic Mixed Precision),将模型的大部分操作都用float16数据类型替代,一些特别操作仍然使用float32.
并且用户仅仅通过三行代码即可完美将自己的训练代码迁移到该模型。
实验证明,使用float16作为大部分操作的数据类型,并没有降低参数,在一些实验中,反而由于可以增大batch size,带来精度上的提升,以及训练速度上的提升。
它号称能够在不降低性能的情况下,将模型训练的速度提升2~4倍,训练显存消耗减少为之前的一半。
三、apex配置
见:
四、代码实现
1、三行代码示例
from apex import amp model, optimizer = amp.initialize(model, optimizer, opt_level="O1") with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward()
2、opt_level参数设置
只有一个opt_level需要用户自行配置
-
O0:纯FP32训练,可以作为accuracy的baseline O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM,卷积)还是FP32(Softmax)进行计算 O2:"几乎FP16"混合精度训练,不存在黑白名单,除了Batch Norm,几乎都是用FP16计算 O3:纯FP16训练,很不稳定,但是可以作为speed的baseline。
3、测试代码(不带amp)
import torch N, D_in, D_out = 64, 1024, 512 x = torch.randn(N, D_in, device=cuda) y = torch.randn(N, D_out, device=cuda) model = torch.nn.Linear(D_in, D_out).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) for _ in range(1000): y_pred = model(x) loss = torch.nn.functional.mse_loss(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step()
4、测试代码(带amp)
import torch from apex import amp N, D_in, D_out = 64, 1024, 512 x = torch.randn(N, D_in, device=cuda) y = torch.randn(N, D_out, device=cuda) model = torch.nn.Linear(D_in, D_out).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) model, optimizer = amp.initialize(model, optimizer, opt_level="O1") for _ in range(1000): y_pred = model(x) loss = torch.nn.functional.mse_loss(y_pred, y) optimizer.zero_grad() with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step()
5、swin-transformer算法添加amp实战
以上总共3处修改点。
测试结果:2g显存的batch_size从4增加到了6,1.5倍,似乎没有想象中那么多。
五、溢出问题
因为float16保存数据位数少了,能保存数据的上限和下限的绝对值也小了。
如果我们在处理分割类问题,需要用到一些涉及到求和的操作,如sigmoid,softmax,这些操作都涉及到求和。
分割问题特征图很大,求个sigmoid可能会导致数据溢出,得到错误的结果。
所以针对这些操作,仍然使用float32作为数据格式。
修改方式:仅需在模型定义中,在构造函数__init__中的某一个位置,加上下面这段:
from apex import amp class xxxNet(Module): def __init__(using_map=False) ... ... if using_amp: amp.register_float_function(torch, sigmoid) amp.register_float_function(torch, softmax)
用register_float_function指明后面的函数需要使用float类型,注意第二实参是string类型。
和register_float_function相似的注册函数还有:
-
amp.register_half_function(module, function_name) amp.register_float_function(module, function_name) amp.register_promote_function(module, function_name)
需要在使用amp.initialize之前使用注册函数,所以最号的位置就放在模型的构造函数中。