服务器利用PyTorch对多个GPU或单个GPU使用的几个方式
由于要在服务器的不同GPU中进行模型训练,这里记录一下改变使用GPU的一些方式。
一、单个GUP使用修改
Note:“cuda:0"或"cuda"都代表起始device_id为0,系统默认从0开始。可根据需要修改起始位置,如“cuda:1”等效"cuda:0"或"cuda”
# 任取一个,torch版本不同会有差别 torch.cuda.device(id) # id 是GPU编号 or torch.cuda.set_device(id) or torch.device(cuda)
还有一种方式是在终端指定运行GPU:
CUDA_VISIBLE_DEVICES=1 python main.py,表示只有第1块gpu可见,其他gpu不可用。第1块gpu编号已变成第0块,如果依然使用cuda:1会报invalid device ordinal;以下同效。
单GPU中保存训练模型(2选1)
state = {model: self.model.state_dict(), epoch: ite} torch.save(state, self.model.name()) or # 直接保存 torch.save(self.model.state_dict(), Mymodel.pth) # 当前目录
单GPU/CPU中加载 single-gpu 训练模型(3选1)
checkpoint = torch.load(self.model.name()) self.model.load_state_dict(checkpoint[model]) or # 直接加载 self.model.load_state_dict(torch.load(Mymodel.pth)) or # load gpu or cpu if torch.cuda.is_available(): # gpu self.model.load_state_dict(torch.load(Mymodel.pth)) else: # cpu 官方推荐CPU的加载方式 checkpoint = torch.load(self.model.name(),map_location=lambda storage, loc: storage) self.model.load_state_dict(checkpoint[model])
二、多个GPU使用修改
1. 在终端shell:CUDA_VISIBLE_DEVICES=0,1,3 python main.py
2.
# gpu_ids = [0, 1, 3] # 或 os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,3" # os.environ["CUDA_VISIBLE_DEVICES"] = ,.join(map(str, [0, 1, 3])) import os os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,3" # CUDA_VISIBLE_DEVICES 表当前可被python程序检测到的显卡 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 多GPU时可指定起始位置/编号 # 若不加if项,也不报错,但训练可能会变成单GPU if torch.cuda.device_count() > 1: # 查看当前电脑可用的gpu数量,或 if len(gpu_ids) > 1: print("Lets use", torch.cuda.device_count(), "GPUs!") # self.model = torch.nn.DataParallel(self.model, device_ids=gpu_ids) self.model = torch.nn.DataParallel(self.model) # 声明所有设备 net = self.model.to(device) # 从指定起始位置开始,将模型放到gpu或cpu上 images = self.images.to(device) # 模型和训练数据都放在主设备 labels = self.labels.to(device)
Note:使用多GPU训练,单用 model = torch.nn.DataParallel(model),默认所有存在的显卡都会被使用。
多GPU中保存训练模型(3选1)
if isinstance(self.model,torch.nn.DataParallel): # 判断是否并行 self.model = self.model.module state = {model: self.model.state_dict(), epoch: ite} torch.save(state, self.model.name()) # No-module or if isinstance(self.model, torch.nn.DataParallel): torch.save(self.model.module.stat_dict, Mymodel) # No-module else: torch.save(self.model.stat_dict, Mymodel) # No-module or # 直接保存 torch.save(self.model.state_dict(), Mymodel.pth) # is-module
下一篇:
Java在linux使用.sh启动程序