English | 简体中文
-
ResNet部署实现来自Torchvision的代码,和基于ImageNet2012的预训练模型。
导入Torchvision,加载预训练模型,并进行模型转换,具体转换步骤如下。
import torch
import torchvision.models as models
model = models.resnet50(pretrained=True)
batch_size = 1 #批处理大小
input_shape = (3, 224, 224) #输入数据,改成自己的输入shape
# #set the model to inference mode
model.eval()
x = torch.randn(batch_size, *input_shape) # 生成张量
export_onnx_file = "ResNet50.onnx" # 目的ONNX文件名
torch.onnx.export(model,
x,
export_onnx_file,
opset_version=12,
input_names=["input"], # 输入名
output_names=["output"], # 输出名
dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
"output":{0:"batch_size"}})
为了方便开发者的测试,下面提供了ResNet导出的各系列模型,开发者可直接下载使用。(下表中模型的精度来源于源官方库)
模型 | 大小 | 精度 |
---|---|---|
ResNet-18 | 45MB | |
ResNet-34 | 84MB | |
ResNet-50 | 98MB | |
ResNet-101 | 170MB |
- 本版本文档和代码基于Torchvision v0.12.0 编写