MindTorch用户手册
1.简介
MindTorch是一款将PyTorch训练脚本高效迁移至MindSpore框架执行的实用工具,旨在不改变原生PyTorch用户的编程使用习惯下,使得PyTorch风格代码能在昇腾硬件上获得高效性能。用户只需要在PyTorch源代码主入口调用torch
系列相关的包导入部分(如torch、torchvision
等)之前调用from mindtorch.tools import mstorch_enable
,加上少量训练代码适配即可实现模型在昇腾硬件上的训练。
本教程旨在协助用户快速完成PyTorch脚本迁移工作,精度调优和性能调优可参考MindTorch调试调优指南。
2.模型迁移入门指南
将现有PyTorch原生代码利用MindTorch移植至MindSpore时,当前通常需要如下两个步骤,替换导入模块以及替换网络训练脚本:
Step 1: 替换导入模块
方式一:运行时自动替换(推荐)
用户只需要在PyTorch源代码主入口调用torch
系列相关的包导入部分之前调用from mindtorch.tools import mstorch_enable
,代码执行时torch同名的导入模块会自动被转换为mindtorch相应的模块(目前支持torch、torchvision、torchaudio
相关模块的自动转换)。
from mindtorch.tools import mstorch_enable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
# 上述代码等价于
# import mindtorch.torch as torch
# import mindtorch.torch.nn as nn
# import mindtorch.torch.nn.functional as F
# from mindtorch.torchvision import datasets, transforms
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = x.view(-1, 32*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
criterion = nn.CrossEntropyLoss()
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_data = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2, drop_last=True)
方式二:预先手动替换
替换代码中导入torch
相关包的代码,可以利用mindtorch/tools下提供的replace_import_package工具可快速完成工程代码中torch及torchvision相关导入包的替换。
bash replace_import_package.sh [Project Path]
Project Path
为需要进行替换的工程路经,默认为"./"。
或者,用户也可以逐文件手动的替换文件中的导入包部分代码,示例代码如下:
# 替换前
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torchvision import datasets, transforms
# 替换后
import mindtorch.torch as torch
import mindtorch.torch.nn as nn
import mindtorch.torch.nn.functional as F
from mindtorch.torchvision import datasets, transforms
MindTorch目前已支持大部分PyTorch和torchvision的原生态表达接口,用户只需要替换导入包即可完成模型定义和数据初始化。模型中所使用的高阶API支持状态可以从这里找到 Supported List。如果有一些必要的接口和功能缺失可以通过ISSUE 向我们反馈,我们会优先支持。
Step 2: 替换网络训练脚本
由于MindSpore的自动微分采用函数式表达,和PyTorch的微分接口存在差异,目前需要用户手动适配训练部分的少量代码,即将PyTorch版本的训练流程代码转换为MindSpore的函数式编程写法,从而使能MindSpore动静统一、自动并行等竞争力功能。详细内容可参考MindSpore使用文档。以下示例展示了如何将PyTorch训练流程转换为MindSpore函数式训练流程:
迁移前网络表达:
net = LeNet().to(config_args.device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
net.train()
# 数据迭代训练
for i in range(epochs):
for X, y in train_data:
X, y = X.to(config_args.device), y.to(config_args.device)
out = net(X)
loss = criterion(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("------>epoch:{}, loss:{:.6f}".format(i, loss))
替换为Mindspore函数式迭代训练表达,其中前向过程通常包含了模型网络接口调用
以及损失函数调用
,反向求导过程包含了反向梯度接口调用
以及优化器接口调用
部分,此外,MindSpore不需要调用loss.backward()
以及optimizer.zero_grad()
,具体示例如下:
import mindtorch.torch as torch
import mindspore as ms
net = LeNet().to(config_args.device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
# 定义前向过程
def forward_fn(data, label):
logits = net(data)
loss = criterion(logits, label)
return loss, logits
# 反向梯度定义
grad_fn = ms.ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# 单步训练定义
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
optimizer(grads)
return loss
net.train()
# 数据迭代训练
for i in range(epochs):
for X, y in train_data:
X, y = X.to(config_args.device), y.to(config_args.device)
res = train_step(X, y)
print("------>epoch:{}, loss:{:.6f}".format(i, res.numpy()))
Step 2.1(可选): 使用MindSpore的优化器及学习率调整
当前MindTorch支持了SGD、Adam及AdamW, 如果需要使用其他的优化器, 还可以使用MindSpore的优化器进行替代, 比如可以使用RMSProp,Rprop等, 更多优化器可参考MindSpore官网关于优化器章节的讲述。下面例子可供参考。
迁移前代码:
import torch
net = Net()
optimizer = torch.optim.RMSprop(net.parameters(), lr=args.base_lr)
scheduler = torch.optim.StepLR(optimizer, args.step_size)
...
optimizer.step()
scheduler.step()
迁移后代码:
import mindspore
net = Net()
# 根据StepLR的学习率调整逻辑,生成一个list,list中的每个元素代表每个epoch所对应的lr。
def _step_lrs(total_epoch, step_size, base_lr, gamma):
_step = 0
_tmp_lr = base_lr
lrs= []
for i in range(total_epoch):
if i == 0 or i % step_size != 0:
lrs.append(_tmp_lr)
else:
_tmp_lr = _tmp_lr * gamma
lrs.append(_tmp_lr)
return lrs
lrs = _step_lrs(args.total_epoch, args.step_size, args.base_lr, args.gamma) # 生成全部epoch对应的学习率
optimizer = mindspore.nn.RMSProp(net.trainable_params(), learning_rate=lrs) # 定义优化器,并将所有epoch的学习率全部导入优化器中
...
optimizer(grads)
# 此时,不再需要scheduler.step()对学习率进行调整,因为在定义优化器的时候已经将所有epoch的lr传入,
# 后续,optimizer就会根据不同的epoch来选取对应的lr。
如果还需要分组学习率,可以参考下面例子。
迁移前代码:
import torch
net = Net()
conv_param = []
other_param = []
for name, param in net.named_parameters():
if 'conv' in name:
conv_param += [param]
else:
other_param += [param]
group_param = [{'params': conv_param, 'lr': args.base_lr},
{'params': other_param, 'lr': 0.01}]
optimizer = torch.optim.RMSprop(group_param)
scheduler = torch.optim.StepLR(optimizer, args.step_size)
...
optimizer.step()
scheduler.step()
迁移后代码:
import mindspore
net = Net()
conv_param = []
other_param = []
for name, param in net.named_parameters():
if 'conv' in name:
conv_param += [param]
else:
other_param += [param]
# 根据StepLR算法逻辑,生成全部epoch对应的学习率列表
def _step_lrs(total_epoch, step_size, base_lr, gamma):
_step = 0
_tmp_lr = base_lr
lrs= []
for i in range(total_epoch):
if i == 0 or i % step_size != 0:
lrs.append(_tmp_lr)
else:
_tmp_lr = _tmp_lr * gamma
lrs.append(_tmp_lr)
return lrs
lrs_conv = _step_lrs(args.total_epoch, args.step_size, args.base_lr, args.gamma) # 生成conv分组的学习率列表
lrs_linear = _step_lrs(args.total_epoch, args.step_size, 0.01, args.gamma) # 生成linear分组的学习率列表
# 将包含全部epoch的学习率列表传入对应分组的'lr'当中
group_param = [{'params': conv_param, 'lr': lrs_conv},
{'params': other_param, 'lr': lrs_linear}]
optimizer = mindspore.nn.RMSProp(group_param)
...
optimizer(grads)
# 不再需要scheduler.step(),optimizer会根据epoch选取分组中对应的学习率
Step 3: 模型保存与加载
import mindtorch.torch as torch
# 加载来自PyTorch原生脚本的预训练权重pth(内部依赖PyTorch库)
net.load_state_dict(torch.load('pytorch.pth'))
...
# 网络训练脚本
...
# 模型保存
torch.save(net.state_dict(),'msa.pth')
# 加载来自MindTorch迁移模型保存的pth进行finetune
net.load_state_dict(torch.load('msa.pth'))
我们支持PyTorch原生的模型保存与加载语法,允许用户保存网络权重或以字典形式保存其他数据;对于模型加载阶段,当前暂不支持加载网络模型结构。用户同样可以加载来自PyTorch原生的pth文件(该功能依赖环境安装PyTorch库),当前仅支持加载网络权重,不支持加载网络结构。基于MindTorch保存的pth文件不支持PyTorch原生脚本使用。
如果您想了解更多当前流程与PyTorch原生流程的区别可参考与PyTorch执行流程区别。
如果您想要运用静态图模式加速、分布式训练和混合精度等更高阶的训练方式加速训练可以参考3.进阶训练指南。如果在使用过程中遇到问题或无法对标的内容欢迎通过ISSUE 和我们反馈交流。当前存在部分接口暂时无法完全对标PyTorch(参考Supported List),针对这类接口我们正在积极优化中,您可以暂时参考4.手动适配指南进行适配处理(不影响网络的正常执行训练)。
更多迁移用例请参考MSAdaterModelZoo。
3.进阶训练指南
3.1 使用混合精度加速训练
混合精度训练是指在训练时,对神经网络不同的运算采用不同的数值精度的运算策略。对于conv、matmul等运算占比较大的神经网络,其训练速度通常会有较大的加速比。mindspore.amp模块提供了便捷的自动混合精度接口,用户可以在不同的硬件后端通过简单的接口调用获得训练加速。目前由于框架机制不同,用户需要将torch.cuda.amp.autocast
接口替换成mindspore.amp.auto_mixed_precision
接口,从而使能MindSpore的自动混合精度训练。
迁移前代码:
from torch.cuda.amp import autocast, GradScaler
model = Net().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
scaler = GradScaler()
model.train()
for epoch in epochs:
for inputs, target in data:
optimizer.zero_grad()
with autocast():
output = model(input)
loss = loss_fn(output, target)
loss = scaler.scale(loss) # 损失缩放
loss.backward()
scaler.unscale_(optimizer) # 反向缩放梯度
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) # 梯度裁剪
scaler.step(optimizer) # 梯度更新
scaler.update() # 更新系数
...
迁移后代码:
import mindtorch.torch as torch
from mindtorch.torch.cuda.amp import GradScaler
from mindspore.amp import auto_mixed_precision
...
model = Net().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
scaler = GradScaler()
model.train() # model的方法调用需放在混合精度模型转换前
model = auto_mixed_precision(model, 'O3') # Ascend环境推荐配置'O3',GPU环境推荐配置'O2'
def forward_fn(data, target):
logits = model(data)
logits = torch.cast_to_adapter_tensor(logits) # model为混合精度模型,需要对输出tensor进行类型转换
loss = criterion(logits, target)
loss = scaler.scale(loss) # 损失缩放
return loss
grad_fn = ms.ops.value_and_grad(forward_fn, None, optimizer.parameters)
def train_step(data, target):
loss, grads = grad_fn(data, target)
return loss, grads
for epoch in epochs:
for inputs, target in data:
loss, grads = train_step(input, target)
scaler.unscale_(optimizer, grads) # 反向缩放梯度
grads = ms.ops.clip_by_global_norm(grads, max_norm) # 梯度裁剪
scaler.step(optimizer, grads) # 梯度更新
scaler.update() # 更新系数
...
Step 1:调用auto_mixed_precision
自动生成混合精度模型,如果需要调用原始模型的方法请在混合精度模型生成前执行,如model.train()
;
Step 2(可选):如果后续有对网络输出Tensor的操作,需调用cast_to_adapter_tensor
手动将输出Tensor转换为MindTorch Tensor。
Step 3:调用GradScaler
对梯度进行缩放时,由于自动微分机制和接口区别,unscale_
和step
等接口需要把梯度grads作为入参传入。
更多细节请参考自动混合精度使用教程。
3.2 使用静态图模式加速训练
MindSpore框架的执行模式有两种:动态图(PyNative)模式和静态图(Graph)模式:
- 动态图模式下,程序按照代码的编写顺序执行,在执行正向过程中根据反向传播的原理,动态生成反向执行图。动态图模式方便编写和调试神经网络模型。
- 静态图模式下,程序在编译执行时先生成神经网络的图结构,然后再执行图中涉及的计算操作。静态图模式利用图优化等技术对执行图进行更大程度的优化,因此能获得较好的性能,但是执行图是从源码转换而来,因此在静态图下不是所有的Python语法都能支持。
更多详细信息请参考MindSpore动静统一机制介绍。
目前MSAdapte默认支持PyNative模式,请首先在PyNative模式下完成功能调试。如果想调用静态图模式进行训练加速,再尝试切换到Graph模式执行。下面介绍两种切换静态图的方式:
方式一:采用即时编译装饰器jit
,使能部分函数粒度表达模块以静态图模式执行。
import mindspore as ms
ms.set_context(jit_syntax_level=ms.STRICT)
@ms.jit
def mul(x, y):
return x * y
方式二:全局设置Graph模式,更适合基于Module表达。
import mindspore as ms
ms.set_context(mode=ms.GRAPH_MODE)
ms.set_context(jit_syntax_level=ms.STRICT)
由于Graph模式下不是所有的Python语法都能支持,通过上面两种方式切换到Graph模式后部分网络可能会出现语法不支持情况,推荐调整日志级别export MSA_LOG=2
协助调试分析,根据报错信息对代码进行相应调整,当前主要体现在in-place类型操作和部分语法用法限制,具体可参考静态图语法支持。
3.3 使用分布式训练加速训练
分布式并行训练可以降低对内存、计算性能等硬件的需求,是进行训练的重要优化手段。目前MindTorch中对标torch.distributed
相关分布式接口还在开发中,如果用户想要使用分布式训练进行加速训练,需要将torch.distributed
相关接口替换成MindSpore提供的更简单易用的高阶API。MindTorch基于MindSpore分布式并行能力提供两种并行模式:
- 数据并行:对数据进行切分的并行模式,一般按照batch维度切分,将数据分配到各个计算单元中,进行模型计算。
- 自动并行:融合了数据并行、算子级模型并行的分布式并行模式,可以自动建立代价模型,找到训练时间较短的并行策略,为用户选择合适的并行模式。
相关机制请参考MindSpore原生分布式并行架构。
数据并行
import mindtorch.torch as torch
from mindtorch.torch.utils.data import DataLoader, DistributedSampler
from mindspore.communication import init
import mindspore as ms
...
init("hccl") # 初始化通信环境:"hccl"---Ascend, "nccl"---GPU, "mccl"---CPU
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL) # 配置数据并行模式
torch.manual_seed(1) # 设置随机种子,使得每张卡上权重初始化值一样,便于收敛
train_images = datasets.CIFAR10('./', train=True, download=True, transform=transform)
sampler = DistributedSampler(train_images) # 分布式数据处理
train_data = DataLoader(train_images, batch_size=32, num_workers=2, drop_last=True, sampler=sampler)
...
def forward_fn(data, label):
logits = net(data)
loss = criterion(logits, label)
return loss, logits
grad_fn = ms.ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
grad_reducer = nn.DistributedGradReducer(optimizer.parameters) # 定义分布式优化器
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
grads = grad_reducer(grads) # 梯度聚合
optimizer(grads)
return loss
net.train()
for i in range(epochs):
for inputs, target in train_data:
res = train_step(inputs, target)
...
自动并行
import mindtorch.torch as torch
from mindtorch.torch.utils.data import DataLoader, DistributedSampler
from mindspore.communication import init
import mindspore as ms
...
ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=True) # 自动并行仅支持静态图模式
init("hccl") # 初始化通信环境:"hccl"---Ascend, "nccl"---GPU, "mccl"---CPU
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL, search_mode="recursive_programming") # 配置自动并行模式
torch.manual_seed(1) # 设置随机种子,使得每张卡上权重初始化值一样,便于收敛
train_images = datasets.CIFAR10('./', train=True, download=True, transform=transform)
sampler = DistributedSampler(train_images) # 分布式数据处理
train_data = DataLoader(train_images, batch_size=32, num_workers=2, drop_last=True, sampler=sampler)
...
def forward_fn(data, label):
logits = net(data)
loss = criterion(logits, label)
return loss, logits
grad_fn = ms.ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
@ms.jit
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
optimizer(grads)
return loss
net.train()
for i in range(epochs):
for inputs, target in train_data:
res = train_step(inputs, target)
...
自动并行功能目前在实验性阶段,仅支持部分场景。如果在使用过程中出现不支持的报错信息,可以通过ISSUE反馈。
分布式启动
通过OpenMPI的mpirun运行分布式脚本。下面以使用单机8卡的分布式训练为例,当执行该命令时, 脚本会在后台运行,日志文件会保存到当前目录下,不同卡上的日志会按rank_id分别保存在log_output/1/路径下对应的文件中。
mpirun -n 8 --output-filename log_output --merge-stderr-to-stdout python train.py > train.log 2>&1 &
多机多卡启动等更复杂的用法请参考MindSpore分布式训练样例。
4.手动适配指南
4.1 数据处理部分
通常情况下仅需将数据处理相关导入模块转换为mindtorch相应模块,即可实现PyTorch数据部分的迁移,示例如下:
from mindtorch.tools import mstorch_enable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 手动转换
# from mindtorch.torch.utils.data import DataLoader
# from mindtorch.torchvision import datasets, transforms
transform = transforms.Compose([transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.2435, 0.2616])
])
train_images = datasets.CIFAR10('./', train=True, download=True, transform=transform)
train_data = DataLoader(train_images, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
TorchVision接口支持: MindTorch torchvision是迁移自PyTorch官方实现的计算机视觉工具库,延用PyTorch官方API设计与使用习惯,内部计算调用MindSpore算子,实现与torchvision原始库同等功能。用户只需要将PyTorch源代码中import torchvision
替换为import mindtorch.torchvision
即可。torchvision支持状态可以从这里找到 TorchVision Supported List。
另外,如果遇到数据处理接口未完全适配的场景,可以暂时使用PyTorch原生的数据处理流程,将生成的数据PyTorch张量转为MindTorch支持的张量对象,请参考convert_tensor 工具使用教程实现。
4.2 模型构建部分
4.2.1 自定义module
from mindtorch.torch.nn import Module, Linear, Flatten
class MLP(Module):
def __init__(self):
super(MLP, self).__init__()
self.flatten = Flatten()
self.line1 = Linear(in_features=1024, out_features=64)
self.line2 = Linear(in_features=64, out_features=128, bias=False)
self.line3 = Linear(in_features=128, out_features=10)
def forward(self, inputs):
x = self.flatten(inputs)
x = self.line1(x)
x = self.line2(x)
x = self.line3(x)
return x
自定义Module写法和PyTorch原生写法一致,但需要注意下述问题:
- 自定义module时可能出现变量名已被使用场景,如
self.phase
,需要用户自行变更变量名;
- 自定义反向传播函数差异,反向函数需要满足MindSpore自定义反向函数格式要求,以下是适配案例,混合精度等更多信息可以参考自定义反向章节内容。
# PyTorch 写法
class GdnFunction(Function):
@staticmethod
def forward(ctx, x, gamma, beta):
# save variables for backprop
ctx.save_for_backward(x, gamma, beta)
...
return y
@staticmethod
def backward(ctx, grad_output):
x, gamma, beta = ctx.saved_variables
...
return grad_input, grad_gamma, grad_beta
# MindTorch 写法
class GdnFunction(nn.Module):
def __init__(self):
super(GdnFunction, self).__init__()
def forward(self, x, gamma, beta):
...
return y
def bprop(self, x, gamma, beta, out, grad_output):
x = torch.Tensor(x)
gamma = torch.Tensor(gamma)
beta = torch.Tensor(beta)
grad_output = torch.Tensor(grad_output)
...
return grad_input, grad_gamma, grad_beta
4.2.2 view类接口和inplace类接口适配
-
当前torch.view
操作实际等价于创建指定shape的新tensor,并不真实共享内存,需要用户自己保证tensor的赋值更新。(共享内存的view接口正在研发中,敬请期待!);
-
暂时无法对标inplace相关操作,当前此类并不真实共享内存,所以torch.xxx(*, out=output)
接口推荐写成output = torch.xxx(*)
形式,tensor_a.xxx_(*)
推荐写成tensor_b = tensor_a.xxx(*)
形式,则该接口在图模式下也可正常执行;
-
切片后的inplace算子不生效,需修改为如下写法:
# PyTorch 原生写法
boxes[i,:,0::4].clamp_(0, im_shape[i, 1]-1)
# MindTorch 推荐写法
a = boxes[i,:,0::4].clamp_(0, im_shape[i, 1]-1)
boxes[i, :, 0::4] = a
4.3 训练流程部分
4.3.1 指定执行硬件
PyTorch原生接口通过to
等接口将数据拷贝到指定硬件中执行,但是MindTorch暂不支持指定硬件执行,实际执行的硬件后端由conetxt指定。如果您的程序运行在云脑2,则默认执行昇腾硬件,如果想执行在其他硬件后端可以参考如下代码;
import mindspore as ms
ms.set_context(device_target="CPU") # 指定CPU执行
ms.set_context(device_target="Ascend", device_id=1) # 指定昇腾1号卡执行, device_id为可选参数,默认0号卡执行
如果未设置device_target
参数,则使用MindSpore包对应的后端设备。
4.3.2 网络训练流程
- 当调用
ms.ops.value_and_grad
接口时,如果has_aux
为True,不允许存在多层嵌套的输出(优化中),且求导位置必须为第一个输出;
torch.nn.utils.clip_grad_norm_
可替换为 ms.ops.clip_by_global_norm
等价实现梯度裁剪功能;
4.4 自定义操作
当内置接口和算子不满足使用需求时,可以利用框架提供的接口和机制实现自定义反向以及硬件算子。MindTorch基于MindSpore框架提供了不同于PyTorch的算子定义方式和接口。
4.4.1 自定义反向
在部分场景中,不仅需要自定义神经网络层的正向逻辑,也需要手动控制其反向的计算。MindTorch使用nn.Module.bprop
代替torch.autograd.Function.backward
来实现自定义反向。
PyTorch写法:
from torch.cuda.amp import custom_fwd, custom_bwd
class CustomNet(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
return a.mm(b)
@staticmethod
@custom_bwd
def backward(ctx, grad):
a, b = ctx.saved_tensors
return grad.mm(b.t()), a.t().mm(grad)
output = CustomNet.apply(input)
MindTorch写法:
import mindtorch.torch as torch
class CustomNet(torch.nn.Module):
def forward(self, a, b):
return a.mm(b)
def bprop(self, a, b, out, dout):
return dout.mm(b.t()), a.t().mm(dout)
custom_net = CustomNet()
custom_net = torch.amp.auto_mixed_precision(custom_net)
output = custom_net(input)
上述两种写法的差异点如下:
- MindTorch的正向和反向计算分别使用
forward
和bprop
代替forward
和backward
,并且无需使用staticmethod
。
- MindTorch目前不支持通过
ctx
保存正向计算的中间值,而是在反向计算中根据输入重新获取和计算。
bprop
方法有三类入参:
a
,b
:正向输入,当正向输入有多个时,需同样数量的入参;
out
:正向输出;
dout
:反向传播时,当前Module执行前的反向结果。
- 调用方式由
CustomNet()
代替CustomNet.apply
。
- 混合精度场景下,使用
auto_mixed_precision
接口代替custom_fwd
和custom_bwd
。
4.4.2 自定义算子
由于框架机制和硬件差异,开发者基于PyTorch自定义实现的C++算子需要手动迁移到MindSpore框架,详细使用教程请参考MindSpore自定义算子描述。
4.5 其他
- 网络中如果调用了MindSpore原生接口,则需要调用
mindtorch.torch.cast_to_adapter_tensor
接口将输出tensor转换为MindTorch tensor后方可继续调用PyTorch风格接口。除网络训练部分,不推荐混用MindTorch接口和MindSpore接口;
- MindTorch tensor暂不支持格式化输出,如
label = f"{class_names[labels[i]]}: {probs[i]:.2f}"
,可先转换为numpy后输出;
- 代码中调用
torch.autograd.Variable
接口,替换为torch.tensor
即可;
- 输出tensor如果要输入到opencv等其他组件进行处理时需要先转为numpy后再执行;
- 三方库适配:如果代码中使用的三方库不依赖PyTorch,可正常使用无需适配。如果使用的的三方库是基于PyTorch接口开发的,则需要将三方库相关功能代码迁移到MindTorch。目前我们也正在适配一些常用的三方库,例如einops,具体使用方式可参考third_party路径下内容。
5.MindTorch相关环境变量
环境变量 |
功能 |
类型 |
取值 |
MSA_LOG |
控制MindTorch日志的级别 (MindSpore的日志级别可通过GLOG_v配置,详情请参考日志环境变量。) 注意,这是一个实验性环境变量,后续可能修改或删除。 |
整型 |
0: DEBUG 1: INFO 2: WARNING 3: ERROR 4: CRITICAL 默认值:2,指定日志级别后,将会输出大于或等于该级别的日志信息;
|
ENABLE_FORK_UTILS |
用于指定多进程创建方式,默认使用spawn方式创建多进程,配置后采用fork方式创建多进程; 注意,这是一个实验性环境变量,后续可能修改或删除。 |
整型 |
0: spawn方式创建多进程; 1: fork方式创建多进程; 默认值:1 注意:Windows环境下只能使用spawn方式创建多进程。 |
FAQ
Q:设置mindspore.set_context(mode=context.GRAPH_MODE)后运行出现类似问题:
"Tensor.add_" is an in-place operation and "x.add_()" is not encouraged to use in MindSpore static graph mode. Please use "x = x.add()" or other API instead。
A:目前在设置GRAPH模式下不支持原地操作相关的接口,需要按照提示信息进行修改。需要注意的是,即使在PYNATIVE模式下,原地操作相关接口也是不鼓励使用的,因为目前在MindTorch不会带来内存收益,而且会给反向梯度计算带来不确定性。
Q:运行代码出现类似报错信息:
AttributeError: module 'mindtorch.torch' has no attribute 'xxx'。
A:首先确定'xxx'是否为torch 1.12版本支持的接口,PyTorch官网明确已废弃或者即将废弃的接口和参数,MindTorch不会兼容支持,请使用其他同等功能的接口代替。如果是PyTorch对应版本支持,而MindTorch中暂时没有,欢迎参与MindTorch项目贡献你的代码,也可以通过创建任务(New issue)反馈需求。
Q:为什么TensorDataset返回值为numpy.ndarray
类型?
A:为了加速数据处理流程以及避免在GPU/Ascend中SyncDeviceToHost失败,TensorDataset返回值会被转换为numpy.ndarray
类型。如果您结合DataLoader使用则无需关注返回值类型,如果您单独调用该接口则需要手动将输出转换为Tensor类型。
dataset = TensorDataset(all_input_ids)
for data in dataset:
data = torch.tensor(data)
Q:mindtorch.torch.__version__
对应版本是多少?具体怎么使用?
A:和MindTorch接口对标的策略相同,mindtorch.torch.__version__
当前对应于1.12.1
。使用方法上,在PyNative模式下, 保持与PyTorch一样的使用方法及功能;在Graph模式下,需要在图外(ms.jit
的作用域之外或者nn.module.forward
之外)使用,才能保证功能的正确性。
Q:图模式下不支持torch.dtype的is_floating_point/is_complex/is_signed方法?
E # 1 In file testing/ut/pytorch/torch/test_dtype.py:15
E if x_dtype.is_floating_point:
E ^
E
E ----------------------------------------------------
E - C++ Call Stack: (For framework developers)
E ----------------------------------------------------
E mindspore/ccsrc/pipeline/jit/ps/static_analysis/prim.cc:1590 GetEvaluatedValueForNameSpace
A:当前图模式下暂不支持取dtype属性方法,可以用以下写法替代:
from mindtorch.torch.common.dtype import all_float_type, all_signed_type, all_complex_type
# Replace x_dtype.is_floating_point
if x_dtype in all_float_type:
...
# Replace x_dtype.is_signed
if x_dtype in all_signed_type:
...
# Replace x_dtype.is_complex
if x_dtype in all_complex_type:
...
Q: torch.set_default_dtype
和torch.set_default_tensor_type
设置的dtype类型会影响哪些接口的输出类型?
A:目前对函数式接口arange
,bartlett_window
, empty
, empty_strided
, eye
, full
, hamming_window
, hann_window
, kaiser_window
, linspace
, logspace
, ones
, rand
, randn
, range
, zeros
的结果类型会产生影响,以及使用这些接口的其他数据,例如,nn.Conv2d
的weight;暂不支持对复数类型的影响。
Q: 项目更名为MindTorch后,之前已迁移至MSAdapter的网络脚本,怎么快速地迁移到MindTorch上?
A:项目更名之后,使用上主要的改变,是从import msadapter.pytorch
变更为import mindtorch.torch
,以及从import msadapter.torchvision
变更为import mindtorch.torchvision
。对于该变更,在mindtorch/tools目录下,提供了replace_import_msadapter_to_mindtorch.sh自动化脚本,可以一键化地将相关的import从msadapter切换成mindtorch。
例如有目录文件 /mynet, 执行以下命令即可进行自动替换:
bash replace_import_msadapter_to_mindtorch.sh /mynet
Q: 原始脚本中调用Apex相关接口实现混合精度报错,该怎么适配迁移到MindTorch上?
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
A:Apex是基于PyTorch开发的混合精度训练加速库,未适配其他框架,如果用户想实现混合精度加速,可参考3.1 使用混合精度加速训练 章节,调用mindspore.amp.auto_mixed_precision
接口实现混合精度训练。
Q: 接口迁移时报错没有某些关键字属性,出现类似报错信息:
TypeError: mean() got an unexpected keyword argument 'keepdims'
A:PyTorch部分接口的入参有别名的情况。如mean
接口,用户指定keepdim
与keepdims
可以实现等价效果;sort
接口,用户指定dim
或axis
均可以指定轴。针对这类场景当前需要用户修改使用PyTorch官方文档中呈现的关键字属性进行赋值。
Q: 接口迁移时报错入参个数与位置参数个数不一致,出现类似报错信息:
TypeError: addmm_() takes 3 positional arguments but 5 were given
A:PyTorch新版本中存在一些接口和老版本的入参个数不一致的现象,如add
, addcmul
, addmm
等接口,这类老版本的用法在新版本中已废弃。针对这类场景当前需要用户修改使用PyTorch官方文档(1.12.1
版本)中呈现的用法。
Q: 自定义Module执行时报错AttributeError:cells?
A:Module继承于mindspore.nn.Cell,有些方法名或属性名已被使用,因此在Module的子类中不能定义名为’cast’的方法,不能定义名为’phase’和’cells’的属性,否则会报错。如该报错示例,在自定义Module中重新定义了cells
属性导致报错,用户在自定义子类中使用其他名字即可正常执行。
Q:自定义Module执行时出现类似报错信息:
TypeError: 'NoneType' object does not support item assignment
A:请检查自定义Module中是否存在张量对象赋值为类属性的情况,如self.xxx=torch.Tensor(xxx)
,请确保super().__init__()
在此之前被调用。
Q:当使用运行时自动替换导入模块,如何使能一些原生PyTorch的模块功能?
A:可以调用from mindtorch.tools import pytorch_enable
接口临时使能原生PyTorch的模块功能,示例如下:
from mindtorch.tools import mstorch_enable
import torch
from mindtorch.tools import pytorch_enable
import torch as pytorch
from mindtorch.tools import mstorch_enable
pytorch.xxx # 调用pytroch原生功能模块
torch.xxx # 调用mindtorch相应模块
Q:程序并未在torch.cuda.set_device
指定的卡上执行?:
A:如4.3.1指定执行硬件所示,若mindtorch期望在特定后端(卡)上执行时,需通过mindspore.set_context
接口显示指定设备,且暂不支持计算过程中在不同设备间转移Tensor。因此下列PyTorch接口的指定执行设备的功能均暂不生效:
torch.cuda.set_device(device)
module.cuda(device)
tensor.cuda(device)
tensor.to(device)
如果期望实现异构执行功能,可调用Primitive.set_device
接口实现,详情请参考Host&Device异构。