【如何训练一个中英翻译模型】LSTM机器翻译模型部署之ncnn(python)(五)

系列文章
【如何训练一个中英翻译模型】LSTM机器翻译seq2seq字符编码(一)
【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存(二)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署(三)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

目录

  • 一、事情准备
  • 二、模型转换
  • 三、ncnn模型加载与推理(python版)

一、事情准备

这篇是在【如何训练一个中译英翻译器】LSTM机器翻译模型部署之onnx(python)(四)的基础上进行的,要用到文件为:

input_words.txt
target_words.txt
config.json
encoder_model-sim.onnx
decoder_model-sim.onnx

其中的onnx就是用来转为ncnn模型的,这里借助了onnx这个中间商,所以前面我们需要先通过onnxsim对模型进行simplify,要不然在模型转换时会出现op不支持的情况(模型转换不仅有中间商这个例子,目前还可以通过pnnx直接将pytorch模型转为ncnn,感兴趣的小伙伴可以去折腾下)
老规矩,先给出工具:

onnx2ncnn:https://github.com/Tencent/ncnn
netron:https://netron.app

二、模型转换

这里进行onnx转ncnn,通过命令进行转换

onnx2ncnn onnxModel/encoder_model-sim.onnx ncnnModel/encoder_model.param ncnnModel/encoder_model.bin
onnx2ncnn onnxModel/decoder_model-sim.onnx ncnnModel/decoder_model.param ncnnModel/decoder_model.bin

转换成功可以看到:
在这里插入图片描述
转换之后可以对模型进行优化,但是奇怪的是,这里优化了不起作用,去不了MemoryData这些没用的op

ncnnoptimize ncnnModel/encoder_model.param ncnnModel/encoder_model.bin ncnnModel/encoder_model.param ncnnModel/encoder_model.bin 1
ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

三、ncnn模型加载与推理(python版)

跟onnx的推理比较类似,就是函数的调用方法有点不同,这里先用python实现,验证下是否没问题,方面后面部署到其它端,比如android。
主要包括:模型加载、推理模型搭建跟模型推理,但要注意的是这里的输入输出名称需要在param这个文件里面获取。

采用netron分别查看encoder与decoder的网络结构,获取输入输出名称:

encoder:
输入输出分别如图
在这里插入图片描述
decoder:

输入
在这里插入图片描述
输出:
在这里插入图片描述

推理代码如下,推理过程感觉没问题,但是推理输出结果相差很大(对比过第一层ncnn与onnx的推理结果了),可能问题出在模型转换环节的精度损失上,而且第二层模型转换后网络输出结果不一致了,很迷,还没找出原因,但是以下的推理是能运行通过,只不过输出结果有问题

import numpy as np
import ncnn


# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:
    input_words = f.readlines()
    input_characters = [line.rstrip('\n') for line in input_words]

# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:
    target_words = [line.strip() for line in f.readlines()]
    target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]

#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])

# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量

import json
with open('config/config.json', 'r') as file:
    loaded_data = json.load(file)

# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]





# Load the ncnn models for the encoder and decoder
encoderNet = ncnn.Net()
encoderNet.load_param("ncnnModel/encoder_model.param")
encoderNet.load_model("ncnnModel/encoder_model.bin")

decoderNet = ncnn.Net()
decoderNet.load_param("ncnnModel/decoder_model.param")
decoderNet.load_model("ncnnModel/decoder_model.bin")





def decode_sequence(input_seq):
    # Encode the input as state vectors.
    # print(input_seq)
    ex_encoder = encoderNet.create_extractor()
    ex_encoder.input("input_1", ncnn.Mat(input_seq))
    states_value = []

    _, LSTM_1 = ex_encoder.extract("lstm")
    _, LSTM_2 = ex_encoder.extract("lstm_1")


    states_value.append(LSTM_1)
    states_value.append(LSTM_2)


    # print(ncnn.Mat(input_seq))
    # print(vgdgd)
    
    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, 849))

    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.
    # this target_seq you can treat as initial state



    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    ex_decoder = decoderNet.create_extractor()
    while not stop_condition:
        
        
        #print(ncnn.Mat(target_seq))
        
        print("---------")

        
        ex_decoder.input("input_2", ncnn.Mat(target_seq))
        ex_decoder.input("input_3", states_value[0])
        ex_decoder.input("input_4", states_value[1])
        _, output_tokens = ex_decoder.extract("dense")
        _, h = ex_decoder.extract("lstm_1")
        _, c = ex_decoder.extract("lstm_1_1")

        print(output_tokens)


        tk = []
        for i in range(849):
            tk.append(output_tokens[849*i])

        tk = np.array(tk)
        output_tokens = tk.reshape(1,1,849)

        print(output_tokens)



        # print(fdgd)
        
        print(h)
        print(c)
        
        
        # output_tokens = np.array(output_tokens)
        # output_tokens = output_tokens.reshape(1, 1, -1)


        # # h = np.array(h)
        # # c = np.array(c)
        # print(output_tokens.shape)
        # print(h.shape)
        # print(c.shape)
        
        
        #output_tokens, h, c = decoder_model.predict([target_seq] + states_value)

        # Sample a token
        # argmax: Returns the indices of the maximum values along an axis
        # just like find the most possible char
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        # find char using index
        sampled_char = reverse_target_char_index[sampled_token_index]
        # and append sentence
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        # append then ?
        # creating another new target_seq
        # and this time assume sampled_token_index to 1.0
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        print(sampled_token_index)

        # Update states
        # update states, frome the front parts
        
        states_value = [h, c]

    return decoded_sentence
    

import numpy as np

input_text = "Call me."
encoder_input_data = np.zeros(
    (1,max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
for t, char in enumerate(input_text):
    print(char)
    # 3D vector only z-index has char its value equals 1.0
    encoder_input_data[0,t, input_token_index[char]] = 1.


input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

decoder的模型输出为849*849,感觉怪怪的,然后我们把模型的输入固定下来看看是不是模型的问题。
打开decoder_model.param,把输入层固定下来,0=w 1=h 2=c,那么:
input_2:0=849 1=1 2=1
input_3:0=256 1=1
input_4:0=256 1=1

运行以下命令进行优化

ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

结果如下:
在这里插入图片描述
打开网络来看一下:
可以看到输出确实是849849(红色框),那就是模型转换有问题了
在这里插入图片描述
仔细看,能够看到有两个shape(蓝色框)分别为849跟849
1,这两个不同维度的网络进行BinaryOP之后,就变成849849了,那么,我们把Reshape这个网络去掉试试(不把前面InnerProduct的输入维度有849reshape为8491),下面来看手术刀怎么操作。

我们需要在没经过固定维度并ncnnoptimize的模型上操作(也就是没经过上面0=w 1=h 2=c修改的模型上操作)
根据名字我们找到Reshape那一层:
在这里插入图片描述
然后找到与reshape那一层相连接的上一层(红色框)与下一层(蓝色框)
在这里插入图片描述
通过红色框与蓝色框里面的名字我们找到了上层与下层分别为InnerProduct与BinaryOp
在这里插入图片描述
这时候,把InnerProduct与BinaryOp接上,把Reshape删掉
在这里插入图片描述
再改一下最上面的层数,把19改为18,因为我们删掉了一层
在这里插入图片描述保存之后再次执行

ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

执行后可以看到网络层数跟blob数都更新了
在这里插入图片描述

这时候改一下固定一下输入层数,并运行ncnnoptimize,再打开netron看一下网络结构,可以看到输出维度正常了
在这里插入图片描述
但是通过推理结果还是不对,没找到原因,推理代码如下:

import numpy as np
import ncnn


# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:
    input_words = f.readlines()
    input_characters = [line.rstrip('\n') for line in input_words]

# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:
    target_words = [line.strip() for line in f.readlines()]
    target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]

#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])

# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量

import json
with open('config/config.json', 'r') as file:
    loaded_data = json.load(file)

# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]





# Load the ncnn models for the encoder and decoder
encoderNet = ncnn.Net()
encoderNet.load_param("ncnnModel/encoder_model.param")
encoderNet.load_model("ncnnModel/encoder_model.bin")

decoderNet = ncnn.Net()
decoderNet.load_param("ncnnModel/decoder_model.param")
decoderNet.load_model("ncnnModel/decoder_model.bin")





def decode_sequence(input_seq):
    # Encode the input as state vectors.
    # print(input_seq)
    ex_encoder = encoderNet.create_extractor()
    ex_encoder.input("input_1", ncnn.Mat(input_seq))
    states_value = []

    _, LSTM_1 = ex_encoder.extract("lstm")
    _, LSTM_2 = ex_encoder.extract("lstm_1")


    states_value.append(LSTM_1)
    states_value.append(LSTM_2)


    # print(ncnn.Mat(input_seq))
    # print(vgdgd)
    
    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, 849))

    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.
    # this target_seq you can treat as initial state



    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    ex_decoder = decoderNet.create_extractor()
    while not stop_condition:
        
        
        #print(ncnn.Mat(target_seq))
        
        print("---------")

        
        ex_decoder.input("input_2", ncnn.Mat(target_seq))
        ex_decoder.input("input_3", states_value[0])
        ex_decoder.input("input_4", states_value[1])
        _, output_tokens = ex_decoder.extract("dense")
        _, h = ex_decoder.extract("lstm_1")
        _, c = ex_decoder.extract("lstm_1_1")

        print(output_tokens)


        # print(ghfhf)


        # tk = []
        # for i in range(849):
        #     tk.append(output_tokens[849*i])

        # tk = np.array(tk)
        # output_tokens = tk.reshape(1,1,849)

        # print(output_tokens)



        # print(fdgd)
        
        print(h)
        print(c)
        
        
        output_tokens = np.array(output_tokens)
        output_tokens = output_tokens.reshape(1, 1, -1)


        # # h = np.array(h)
        # # c = np.array(c)
        # print(output_tokens.shape)
        # print(h.shape)
        # print(c.shape)
        
        
        #output_tokens, h, c = decoder_model.predict([target_seq] + states_value)

        # Sample a token
        # argmax: Returns the indices of the maximum values along an axis
        # just like find the most possible char
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        # find char using index
        sampled_char = reverse_target_char_index[sampled_token_index]
        # and append sentence
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        # append then ?
        # creating another new target_seq
        # and this time assume sampled_token_index to 1.0
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        print(sampled_token_index)

        # Update states
        # update states, frome the front parts
        
        states_value = [h, c]

    return decoded_sentence
    

import numpy as np

input_text = "Call me."
encoder_input_data = np.zeros(
    (1,max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
for t, char in enumerate(input_text):
    print(char)
    # 3D vector only z-index has char its value equals 1.0
    encoder_input_data[0,t, input_token_index[char]] = 1.


input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)



参考文献:https://github.com/Tencent/ncnn/issues/2586

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/51242.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Unity光照相关知识和实践 (烘焙光照,环境光设置,全局光照)

简介 本文将会通过一个简单的场景搭建,介绍如何使用烘焙光照以及相关的注意事项。另外还介绍了Unity内全局光照(GI)的知识和GI实际在游戏内的表现效果。 Unity关于光照相关的参考文档地址:https://docs.unity.cn/cn/current/Man…

Linux CentOS快速安装VNC并开启服务

以下是在 CentOS 上安装并开启 VNC 服务的步骤: 安装 VNC 服务器软件包。运行以下命令: sudo yum install tigervnc-server 输出 $ sudo yum install tigervnc-server Loaded plugins: fastestmirror, langpacks Repository epel is missing name i…

计算机论文中名词翻译和解释笔记

看论文中一些英文的简写不知道中文啥意思,或者一个名词不知道啥意思。 于是自己做了一个个人总结。 持续更新 目录 SoftmaxDeep Learning(深度学习)循环神经网络(Recurrent Neural Network简称 RNN)损失函数/代价函数(Loss Function)基于手绘草图的三维模型检索(Ske…

【笔记】PyTorch DDP 与 Ring-AllReduce

转载请注明出处:小锋学长生活大爆炸[xfxuezhang.cn] 文内若有错误,欢迎指出! 今天我想跟大家分享的是一篇虽然有点老,但是很经典的文章,这是一个在分布式训练中会用到的一项技术, 实际上叫ringallreduce。 …

用html+javascript打造公文一键排版系统8:主送机关排版

公文一般在标题和正文之间还有主送机关,相关规定为: 主送机关 编排于标题下空一行位置,居左顶格,回行时仍顶格,最后一个机关名称后标全角冒号。如主送机关名称过多导致公文首页不能显示正文时,应当将主送机…

redis的并发安全问题:redis的事务VSLua脚本

redis为什么会发生并发安全问题? 在redis中,处理的数据都在内存中,数据操作效率极高,单线程的情况下,qps轻松破10w。反而在使用多线程时,为了保证线程安全,采用了一些同步机制,以及多…

20.3 HTML 表格

1. table表格 table标签是HTML中用来创建表格的元素. table标签通常包含以下子标签: - th标签: 表示表格的表头单元格(table header), 用于描述列的标题. - tr标签: 表示表格的行(table row). - td标签: 表示表格的单元格(table data), 通常位于tr标签内, 用于放置单元格中的…

C语言枚举与联合体详解

本篇文章带来枚举与联合体相关知识详细讲解! 如果您觉得文章不错,期待你的一键三连哦,你的鼓励是我创作的动力之源,让我们一起加油,一起奔跑,让我们顶峰相见!!! 目录 一…

InnoDB引擎底层逻辑讲解——架构之内存架构

1.InnoDB引擎架构 下图为InnoDB架构图,左侧为内存结构,右侧为磁盘结构。 2.InnoDB内存架构讲解 2.1 Buffer Pool缓冲池 2.2 Change Buffer更改缓冲区 2.3 Adaptive Hash Index自适应hash索引 查看自适应hash索引是否开启: show variable…

Modbus TCP/IP之异常响应

文章目录 一、异常响应二、异常码分析2.1 异常码0x012.2 异常码0x022.3 异常码0x032.4 异常码0x062.5 异常码0x04、0x05等 一、异常响应 对于查询报文,存在以下四种处理反馈: 正常接收,正常处理,返回正常响应报文;因为…

部署问题集合(十八)Windows环境下使用两个Tomcat

下载Tomcat Tomcat镜像下载地址:https://mirrors.cnnic.cn/apache/tomcat/进入如下地址:zip的是压缩版,exe是安装版 修改第二个Tomcat配置文件 第一步:编辑conf/server.xml文件,修改三个端口,有些版本改…

【Rust日报】2023-07-28 使用 Cargo-PGO 优化 Rust 程序

使用 Cargo-PGO 优化 Rust 程序 去年,作者致力于改进用于构建 Rust 编译器的配置文件引导优化 (PGO) 工作流程。在这样做的过程中,虽然 PGO 对于 Rust 工作得很好,但它并不像希望的那样易于使用和发现。这促使我创建了 cars-pgo,这…

【雕爷学编程】Arduino动手做(175)---机智云ESP8266开发板模块2

37款传感器与执行器的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&am…

【Vue3】递归组件

1. 递归组件mock数据 App.vue <template><div><Tree :data"data"></Tree></div> </template><script setup lang"ts"> import { reactive } from vue; import Tree from ./components/Tree.vue; interface Tr…

CentOS 8 上安装 Nginx

Nginx是一款高性能的开源Web服务器和反向代理服务器&#xff0c;以其轻量级和高效能而广受欢迎。在本教程中&#xff0c;我们将学习在 CentOS 8 操作系统上安装和配置 Nginx。 步骤 1&#xff1a;更新系统 在安装任何软件之前&#xff0c;让我们先更新系统的软件包列表和已安…

读发布!设计与部署稳定的分布式系统(第2版)笔记26_安全性上

1. 安全问题 1.1. 系统违规并不总是涉及数据获取&#xff0c;有时会出现植入假数据&#xff0c;例如假身份或假运输文件 1.2. 必须在整个开发过程中持续地把安全内建到系统里&#xff0c;而不是把安全像胡椒面那样在出锅前才撒到系统上 2. OWASP 2.1. Open Web Application…

Godot 4 源码分析 - 动态导入图片文件

用Godot 4尝试编一个电子书软件&#xff0c;初步效果已经出来&#xff0c;并且通过管道通信接口可以获取、设置属性、调用函数&#xff0c;貌似能处理各种事宜了。 其实不然&#xff0c;外因通过内因起作用&#xff0c;如果没把里面搞明白&#xff0c;功能没有开放出来&#x…

【SpringCloud Alibaba】(六)使用 Sentinel 实现服务限流与容错

今天&#xff0c;我们就使用 Sentinel 实现接口的限流&#xff0c;并使用 Feign 整合 Sentinel 实现服务容错的功能&#xff0c;让我们体验下微服务使用了服务容错功能的效果。 因为内容仅仅围绕着 SpringCloud Alibaba技术栈展开&#xff0c;所以&#xff0c;这里我们使用的服…

详解Mybatis之分页插件【PageHelper】

编译软件&#xff1a;IntelliJ IDEA 2019.2.4 x64 操作系统&#xff1a;win10 x64 位 家庭版 Maven版本&#xff1a;apache-maven-3.6.3 Mybatis版本&#xff1a;3.5.6 文章目录 一. 什么是分页&#xff1f;二. 为什么使用分页&#xff1f;三. 如何设计一个Page类&#xff08;分…

【玩转python系列】【小白必看】使用Python爬虫技术获取代理IP并保存到文件中

文章目录 前言导入依赖库打开文件准备写入数据循环爬取多个页面完整代码运行效果结束语 前言 这篇文章介绍了如何使用 Python 爬虫技术获取代理IP并保存到文件中。通过使用第三方库 requests 发送HTTP请求&#xff0c;并使用 lxml 库解析HTML&#xff0c;我们可以从多个网页上…