如何导出rot90算子至onnx
- 1 背景描述
- 2 等价替换
- 2.1 rot90替换(NCHW)
- 2.2 rot180替换(NCHW)
- 2.3 rot270替换(NCHW)
- 3 rot导出ONNX
1 背景描述
在部署模型时,如果某些模型中或者前后处理中含有rot90
算子,但又希望一起和模型导出onnx时,可能会遇到如下错误(当前使用环境pytorch2.0.1
,opset_version
为17):
import torch
import torch.nn as nn
class RotModel(nn.Module):
def forward(self, x: torch.Tensor):
x = torch.rot90(x, k=1, dims=(2, 3))
return x
def main():
print("pytorch version:", torch.__version__)
model = RotModel()
with torch.inference_mode():
x = torch.randn(size=(1, 3, 224, 224))
torch.onnx.export(model,
args=(x,),
f="rot90_counterclockwise.onnx",
opset_version=17)
if __name__ == '__main__':
main()
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator ‘aten::rot90’ to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
简单的说就是不支持导出该算子,包括在onnx支持的算子文档中也找不到rot90
算子,onnx官方github链接:
https://github.com/onnx/onnx
2 等价替换
导不出咋办,那就想想旋转矩阵的原理,以及如何使用现有支持的算子替换。
2.1 rot90替换(NCHW)
废话不多说,rot90度(以逆时针为例)可以使用翻转和转置实现。具体代码如下,使用torch自带的rot90与自己实现的对比,通过torch.equal()
来对比两个Tensor是否一致,结果一致,不信自己试试。
import torch
def self_rot90_counterclockwise(x: torch.Tensor):
x = x.flip(dims=[3]).permute([0, 1, 3, 2])
return x
def main():
print("pytorch version:", torch.__version__)
with torch.inference_mode():
x = torch.randn(size=(1, 3, 224, 224))
y0 = torch.rot90(x, k=1, dims=[2, 3])
y1 = self_rot90_counterclockwise(x)
print(torch.equal(y0, y1))
if __name__ == '__main__':
main()
2.2 rot180替换(NCHW)
rot180度(以逆时针为例)可以使用翻转实现。具体代码如下:
import torch
def self_rot180_counterclockwise(x: torch.Tensor):
x = x.flip(dims=[2, 3])
return x
def main():
print("pytorch version:", torch.__version__)
with torch.inference_mode():
x = torch.randn(size=(1, 3, 224, 224))
y0 = torch.rot90(x, k=2, dims=[2, 3])
y1 = self_rot180_counterclockwise(x)
print(torch.equal(y0, y1))
if __name__ == '__main__':
main()
2.3 rot270替换(NCHW)
rot270度(以逆时针为例)可以使用翻转和转置实现。具体代码如下:
import torch
def self_rot270_counterclockwise(x: torch.Tensor):
x = x.flip(dims=[2]).permute([0, 1, 3, 2])
return x
def main():
print("pytorch version:", torch.__version__)
with torch.inference_mode():
x = torch.randn(size=(1, 3, 224, 224))
y0 = torch.rot90(x, k=3, dims=[2, 3])
y1 = self_rot270_counterclockwise(x)
print(torch.equal(y0, y1))
if __name__ == '__main__':
main()
3 rot导出ONNX
这里以rot90度(以逆时针为例)结合刚刚的等价实现来导出ONNX:
import torch
import torch.nn as nn
class RotModel(nn.Module):
def forward(self, x: torch.Tensor):
# x = torch.rot90(x, k=1, dims=(2, 3))
x = x.flip(dims=[3]).permute([0, 1, 3, 2])
return x
def main():
print("pytorch version:", torch.__version__)
model = RotModel()
with torch.inference_mode():
x = torch.randn(size=(1, 3, 224, 224))
torch.onnx.export(model,
args=(x,),
f="rot90_counterclockwise.onnx",
opset_version=17)
if __name__ == '__main__':
main()
使用netron
打开生成的rot90_counterclockwise.onnx
文件,如下所示: