当前位置:
首页 > Python基础教程 >
-
Pytorch多GPU训练过程
这篇文章主要介绍了Pytorch多GPU训练过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
1 导入库
import torch#深度学习的pytoch平台
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
2 指定GPU
2.1 单GPU声明
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2.2 多GPU声明
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5' #指定GPU编号
device = torch.device("cuda") #创建GPU对象
3 数据放到GPU
x_train = Variable(train,requires_grad=True).to(device=device,dtype=torch.float32) #把训练变量放到GPU
4 把模型网络放到GPU 【重要】
net = DNN(layers)
net = nn.DataParallel(net)
net.to(device=device)
重要:nn.DataParallel
net = nn.DataParallel(net)
net.to(device=device)
1.使用 nn.DataParallel 打包模型
2.然后用 nn.DataParallel 的 model.to(device) 把模型传送到多块GPU中进行运算
torch.nn.DataParallel(DP)
DataParallel(DP)中的参数:
module即表示你定义的模型
device_ids表示你训练时用到的gpu device
output_device这个参数表示输出结果的device,默认就是在第一块卡上,因此第一块卡的显存会占用的比其他卡要更多一些。
当调用nn.DataParallel的时候,input数据是并行的,但是output loss却不是这样的,每次都会在output_device上相加计算
===> 这就造成了第一块GPU的负载远远大于剩余其他的显卡。
DP的优势是实现简单,不涉及多进程,核心在于使用nn.DataParallel将模型wrap一下,代码其他地方不需要做任何更改。
例子:
5 其他:多GPU并行
加个判断:
if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)
model = Model(input_size, output_size) # 实例化模型对象
if torch.cuda.device_count() > 1: # 检查电脑是否有多块GPU
print(f"Let's use {torch.cuda.device_count()} GPUs!")
model = nn.DataParallel(model) # 将模型对象转变为多GPU并行运算的模型
model.to(device) # 把并行的模型移动到GPU上
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持
原文链接:https://blog.csdn.net/weixin_43338969/article/details/128834863
栏目列表
最新更新
vbscript基础篇 - vbs数组Array的定义与使用方
vbscript基础篇 - vbs变量定义与使用方法
vbs能调用的系统对象小结
vbscript网页模拟登录效果代码
VBScript 根据IE窗口的标题输出ESC
杀死指定进程名称的小VBS
通过vbs修改以点结尾的文件的属性为隐藏
查询电脑开关机时间的vbs代码
VBA中的Timer函数用法
ComboBox 控件的用法教程
SQL SERVER中递归
2个场景实例讲解GaussDB(DWS)基表统计信息估
常用的 SQL Server 关键字及其含义
动手分析SQL Server中的事务中使用的锁
openGauss内核分析:SQL by pass & 经典执行
一招教你如何高效批量导入与更新数据
天天写SQL,这些神奇的特性你知道吗?
openGauss内核分析:执行计划生成
[IM002]Navicat ODBC驱动器管理器 未发现数据
初入Sql Server 之 存储过程的简单使用
uniapp/H5 获取手机桌面壁纸 (静态壁纸)
[前端] DNS解析与优化
为什么在js中需要添加addEventListener()?
JS模块化系统
js通过Object.defineProperty() 定义和控制对象
这是目前我见过最好的跨域解决方案!
减少回流与重绘
减少回流与重绘
如何使用KrpanoToolJS在浏览器切图
performance.now() 与 Date.now() 对比