torch.jit.trace与torch.jit.script的区别
C++落地部署python深度学习模型
·
文章目录
术语
- Tochscript:狭义概念导出图形的表示/格式;广义概念为导出模型的方法;
- (Torch)Scriptable:可以用torch.jit.script导出模型
- Traceable:可以用torch.jit.trace导出模型
什么时候用torch.jit.trace(结论:首选)
- torch.jit.trace一种导出方法;它运行具有某些张量输入的模型,并“跟踪/记录”所有执行到图形中的操作。
- 在模型内部的数据类型只有张量,且没有for if while等控制流,选择torch.jit.trace
- 支持python的预处理和动态行为;
- torch.jit.trace编译function并返回一个可执行文件,该可执行文件将使用即时编译进行优化。
- 大项目优先选择torch.jit.trace,特别是是图像检测和分割的算法;
优点
- 不会损害代码质量;
- 2.它的主要限制可以通过与torch.jit.script混合来解决
什么时候用torch.jit.script(结论:必要时)
- 定义:一种模型导出方法,其实编译python的模型源码,得到可执行的图;
- 在模型内部的数据类型只有张量,且没有for if while等控制流,也可以选择torch.jit.script
- 不支持python的预处理和动态行为;
- 必须做一下类型标注;
- torch.jit.script在编译function或 nn.Module 脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码。
错误举例
import torch
from torch import nn
class MyModule(nn.Module):
def __init__(self, return_b=False):
super().__init__()
self.return_b = return_b
def forward(self, x):
a = x + 2
if self.return_b: #属于静态控制
b = x + 3
return a, b
return a
model = MyModule(return_b=True)
# Will work 成功
traced = torch.jit.trace(model, (torch.randn(10, ), ))
# Will fail 失败
scripted = torch.jit.script(model)
- 总结:控制流是静态的,torch.jit.trace将正常工作
动态控制
- if x[0] == 4: x += 1 is a dynamic control flow.
model: nn.Sequential = ...
for m in model: # 动态控制
x = m(x)
输入和输出有丰富类型的模型需要格外注意
outputs = model(inputs) # inputs/outputs are rich structure
# torch.jit.trace(model, inputs) # FAIL! unsupported format
adapter = TracingAdapter(model, inputs)
traced = torch.jit.trace(adapter, adapter.flattened_inputs) # Can now trace the model
# Traced model can only produce flattened outputs (tuple of tensors):
flattened_outputs = traced(*adapter.flattened_inputs)
# Adapter knows how to convert it back to the rich structure (new_outputs == outputs):
new_outputs = adapter.outputs_schema(flattened_outputs)
QA
-
- JIT要求python的代码要是低级的;详情 因为更多动态高级的python语法,jit不支持.具体哪些支持哪些没支持官方也没有详细的列表; JIT should not force users to write ugly code #48108
-
- 错误示例:动态控制流:对于动态控制流torch.jit.trace只会编译一个分支,在其他分支处理的时候会报错;
def f(x):
return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
m = torch.jit.trace(f, torch.tensor(3))
print(m.code) # 可以打印出trace的情况
#--------------------------------------------
def f(x: Tensor) -> Tensor:
return torch.sqrt(x)
-
- 错误示例:将变量视为常量
import torch
a, b = torch.rand(1), torch.rand(2)
print(a,b)
def f1(x): return torch.arange(x.shape[0])
def f2(x): return torch.arange(len(x))
result = torch.jit.trace(f1, a)(b)
print(result)
result =torch.jit.trace(f2, a)(b) # TracerWarning
print(result) #
print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
-
-
错误示例:获取设备
解决错误的方法
-
- 严格消除警告信息,才C++运行的时候会报错
-
- 局部单元测试
- 单元测试一样要做在导出模型后,这样避免在应用模型的时候(C++运行)出错;
assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
-
- 避免非必要的动态控制,例如:
if x.numel() > 0:
output = self.layers(x)
else:
output = torch.zeros((0, C, H, W)) # Create empty outputs

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。
更多推荐
所有评论(0)