绝大数情况,我们使用网上公开数据集,跑通一个深度学习算法模型,获得了较高的准确率,这样仅停留在仿真阶段,并未考虑算法实际的部署。以下将以故障诊断为例,将深度学习算法模型部署到Android系统,并实时进行故障诊断,此处输入为一维振动信号。
对于目标检测而言,可以考虑使用硬件加速接口以及对网络模型进行剪枝、量化、蒸馏等优化技术,从而加速模型推理。
将深度学习部署到移动端有以下几个重要的优势:
-
实时推理: 在移动设备上部署深度学习模型可以实现实时推理,使移动应用程序能够即时响应用户的请求和输入。这对于需要即时反馈的应用程序(如语音识别、图像识别、自然语言处理等)非常重要。
-
隐私保护: 在移动设备上执行推理可以保护用户的隐私数据,因为数据不需要传输到远程服务器进行处理。这对于处理敏感数据(如人脸识别、语音识别等)的应用程序非常重要。
-
离线功能: 移动设备上部署的深度学习模型可以在没有互联网连接的情况下工作,从而使应用程序具有离线功能。这对于在网络连接不稳定或不可用的环境下使用应用程序的用户非常有用。
-
减少延迟: 将深度学习模型部署到移动端可以减少与远程服务器进行通信的延迟,从而提高应用程序的响应速度和性能。
-
降低带宽消耗: 在移动设备上执行推理可以减少与远程服务器之间的数据传输量,从而降低带宽消耗和通信成本。
-
定制化需求: 将深度学习模型部署到移动端可以根据特定应用程序的需求进行定制化开发,从而提供更好的用户体验和功能。
1.仅使用PyCharm+AndroidStudio
PyCharm相关
将网络模型转换为TochScript模型并保存,便于后续部署
- android_test.py
import torch.utils.data.distributed
# 定义转化后的模型名称
model_ori_pt = 'model_ori.pt'
'加载pytorch模型'
model_ori = torch.load('androidTest.pth')
# 打印模型类型
print(type(model_ori))
# 将模型移动到 CPU 上
device = torch.device('cpu')
model_ori = model_ori.to(device)
# 设置模型为评估模式
model_ori.eval()
# 定义输入信号的大小
input_tensor = torch.rand(1, 1, 2048)
# 转化模型并存储
mobile_ori = torch.jit.trace(model_ori, input_tensor)
mobile_ori.save(model_ori_pt)
- androidTest.pth
torch.save(net, f"./results/androidTest.pth") # 保存整个模型的结构信息
- models.py为网络模型的源码
AndroidStudio相关
在gradle中添加依赖
implementation 'org.pytorch:pytorch_android:1.12.1'
implementation 'org.pytorch:pytorch_android_torchvision:1.12.1'
将生成的model_ori.pt拷贝到安卓目录
在MainActivity中编写代码
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Module module_ori = null;
try {
module_ori = Module.load(assetFilePath(this, "model_ori.pt"));
} catch (IOException e) {
e.printStackTrace();
}
//生成随机信号数组
float[] signal = new float[2048];
Random random = new Random();
for (int i = 0; i < 2048; i++) {
signal[i] = random.nextFloat(); // 在[0, 1)范围内生成随机浮点数
}
// 将一维信号转换为张量
Tensor inputTensor = Tensor.fromBlob(signal, new long[]{1, 1,signal.length});
Tensor outputTensor = module_ori.forward(IValue.from(inputTensor)).toTensor();
final float[] scores = outputTensor.getDataAsFloatArray();
System.out.println("------------------------------------");
System.out.println(Arrays.toString(scores));
}
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
} catch (IOException e) {
e.printStackTrace();
}
return file.getAbsolutePath();
}
}
}