备注
单击 此处 下载完整的示例代码
4. microTVM PyTorch 教程
该教程展示了如何使用 PyTorch 模型进行 microTVM 主机驱动的 AOT 编译。此教程可以在使用 C 运行时(CRT)的 x86 CPU 上执行。
注意: 此教程仅在使用 CRT 的 x86 CPU 上运行,不支持在 Zephyr 上运行,因为该模型不适用于我们当前支持的 Zephyr 单板。
安装 microTVM Python 依赖项
TVM 不包含用于 Python 串行通信包,因此在使用 microTVM 之前我们必须先安装一个。我们还需要TFLite来加载模型。
pip install pyserial==3.5 tflite==2.1
import pathlib
import torch
import torchvision
from torchvision import transforms
import numpy as np
from PIL import Image
import tvm
from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.backend import Executor
import tvm.micro.testing
加载预训练 PyTorch 模型
首先,从 torchvision 中加载预训练的 MobileNetV2 模型。然后,下载一张猫的图像并进行预处理,以便用作模型的输入。
model = torchvision.models.quantization.mobilenet_v2(weights="DEFAULT", quantize=True)
model = model.eval()
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))
# 预处理图片并转换为张量
my_preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)
input_name = "input0"
shape_list = [(input_name, input_shape)]
relay_mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
输出:
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/ao/quantization/utils.py:310: UserWarning: must run observer before calling calculate_qparams. Returning default values.
warnings.warn(
Downloading: "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth" to /workspace/.cache/torch/hub/checkpoints/mobilenet_v2_qnnpack_37f702c5.pth
0%| | 0.00/3.42M [00:00<?, ?B/s]
61%|###### | 2.09M/3.42M [00:00<00:00, 11.6MB/s]
100%|##########| 3.42M/3.42M [00:00<00:00, 18.5MB/s]
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/torch/_utils.py:314: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
device=storage.device,
/workspace/python/tvm/relay/frontend/pytorch_utils.py:47: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
return LooseVersion(torch_ver) > ver
/venv/apache-tvm-py3.8/lib/python3.8/site-packages/setuptools/_distutils/version.py:346: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
other = LooseVersion(other)
定义目标、运行时与执行器
在本教程中,我们使用 AOT 主机驱动执行器。为了在 x86 机器上对嵌入式模拟环境编译模型,我们使用 C 运行时(CRT),并使用主机微型目标。使用该设置,TVM 为 C 运行时编译可以在 x86 CPU 机器上运行的模型,可以在物理微控制器上运行的相同流程。CRT 使用 src/runtime/crt/host/main.cc
中的 main()。要使用物理硬件,请将 board 替换为另一个物理微型目标,例如 nrf5340dk_nrf5340_cpuapp
或 mps2_an521
,并将平台类型更改为 Zephyr。在《为 Arduino 上的 microTVM 训练视觉模型》和《microTVM TFLite 教程》中,可以看到 更多目标示例。
target = tvm.micro.testing.get_target(platform="crt", board=None)
# 使用 C 运行时 (crt) 并通过设置 system-lib 为 True 打开静态链接
runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})
# 使用 AOT 执行器代替图或 vm 执行器。不要使用未包装的 API 或 C 风格调用
executor = Executor("aot")
编译模型
现在为目标编译模型:
with tvm.transform.PassContext(
opt_level=3,
config={"tir.disable_vectorize": True},
):
module = tvm.relay.build(
relay_mod, target=target, runtime=runtime, executor=executor, params=params
)