TorchScript介紹
TorchScript是PyTorch模型推理部署的中間表示,可以在高性能環(huán)境libtorch(C ++)中直接加載,實(shí)現(xiàn)模型推理,而無(wú)需Pytorch訓(xùn)練框架依賴。torch.jit是torchscript Python語(yǔ)言包支持,支持pytorch模型快速,高效,無(wú)縫對(duì)接到libtorch運(yùn)行時(shí),實(shí)現(xiàn)高效推理。它是Pytorch中除了訓(xùn)練部分之外,開發(fā)者最需要掌握的Pytorch框架開發(fā)技能之一。trace使用
Torchscript使用分為兩個(gè)部分分別是script跟trace,其中trace是跟蹤執(zhí)行步驟,記錄模型調(diào)用推理時(shí)執(zhí)行的每個(gè)步驟,代碼演示如下:
classMyCell(torch.nn.Module):
def__init__(self):
super(MyCell,self).__init__()
self.linear=torch.nn.Linear(4,4)
defforward(self,x,h):
new_h=torch.tanh(self.linear(x)+h)
returnnew_h,new_h
my_cell=MyCell()
x,h=torch.rand(3,4),torch.rand(3,4)
traced_cell=torch.jit.trace(my_cell,(x,h))
print(traced_cell)
traced_cell(x,h)
print(traced_cell.graph)
運(yùn)行結(jié)果如下:
MyCell(
original_name=MyCell
(linear):Linear(original_name=Linear)
)
跟蹤執(zhí)行結(jié)果
graph(%self.1:__torch__.MyCell,
%input:Float(3:4,4:1,requires_grad=0,device=cpu),
%h:Float(3:4,4:1,requires_grad=0,device=cpu)):
%19:__torch__.torch.nn.modules.linear.Linear=prim::GetAttr[name="linear"](%self.1)
%21:Tensor=prim::CallMethod[name="forward"](%19,%input)
%12:int=prim::Constant[value=1]()#D:/python/pytorch_openvino_demo/ch5/faster_rcnn2onnx.py0
%13:Float(3:4,4:1,requires_grad=1,device=cpu)=aten::add(%21,%h,%12)#D:/python/pytorch_openvino_demo/ch5/faster_rcnn2onnx.py0
%14:Float(3:4,4:1,requires_grad=1,device=cpu)=aten::tanh(%13)#D:/python/pytorch_openvino_demo/ch5/faster_rcnn2onnx.py0
%15:(Float(3:4,4:1,requires_grad=1,device=cpu),Float(3:4,4:1,requires_grad=1,device=cpu))=prim::TupleConstruct(%14,%14)
return(%15)
script部分使用
script是導(dǎo)出模型為中間IR格式文件,支持高性能libtorch C++部署,我們以torchvision中Mask-RCNN導(dǎo)出中間格式IR為例,代碼演示如下:
importtorch
importtorchvisionastv
num_classes=3
model=tv.models.detection.maskrcnn_resnet50_fpn(
pretrained=False,progress=True,
num_classes=num_classes,
pretrained_backbone=True)
im=torch.zeros(1,3,*(1333,800)).to("cpu")
model.load_state_dict(torch.load("D:/gaobao_model.pth"))
model=model.to("cpu")
model.eval()
ts=torch.jit.script(model)
ts.save("gaobao.ts")
loaded_trace=torch.jit.load("gaobao.ts")
loaded_trace.eval()
withtorch.no_grad():
print(loaded_trace(list(im)))
最終得到torchscript文件,支持直接通過(guò)libtorch部署,其中通過(guò)torchscript C++部分加載的代碼如下:
#include//One-stopheader.
#include
#include
intmain(intargc,constchar*argv[]){
if(argc!=2){
std::cerr<"usage:example-app
" ;
return-1;
}
//DeserializetheScriptModulefromafileusingtorch::load().
torch::Modulemodule=torch::load(argv[1]);
std::vectorinputs;
inputs.push_back(torch::randn({4,8}));
inputs.push_back(torch::randn({8,5}));
torch::Tensoroutput=module.forward(std::move(inputs)).toTensor();
std::cout<std::endl;
}
上面代碼來(lái)自官方測(cè)試程序,特別說(shuō)明一下!
審核編輯 :李倩
-
開發(fā)
+關(guān)注
關(guān)注
0文章
373瀏覽量
41525 -
C++
+關(guān)注
關(guān)注
22文章
2119瀏覽量
75342 -
pytorch
+關(guān)注
關(guān)注
2文章
809瀏覽量
13970
原文標(biāo)題:輕松學(xué)Pytorch之torchscript使用!
文章出處:【微信號(hào):CVSCHOOL,微信公眾號(hào):OpenCV學(xué)堂】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
元器件及單元電路介紹-610頁(yè)
【「AI Agent應(yīng)用與項(xiàng)目實(shí)戰(zhàn)」閱讀體驗(yàn)】書籍介紹
誰(shuí)能詳細(xì)介紹一下track-and-hold
RK3568內(nèi)置MCU開發(fā)介紹之二

技術(shù)分享 柵極驅(qū)動(dòng)器及其應(yīng)用介紹

評(píng)論