Inference with C# BERT NLP Deep Learning and ONNX Runtime

目录

效果

测试一

测试二

测试三

模型信息

项目

代码

下载


Inference with C# BERT NLP Deep Learning and ONNX Runtime

效果

测试一

Context :Bob is walking through the woods collecting blueberries and strawberries to make a pie.  

Question :What is his name?

测试二

Context :Bob is walking through the woods collecting blueberries and strawberries to make a pie.  

Question :What will he bring home?

测试三

Context :Bob is walking through the woods collecting blueberries and strawberries to make a pie.  

Question :Where is Bob?

模型信息

Inputs
-------------------------
name:unique_ids_raw_output___9:0
tensor:Int64[-1]
name:segment_ids:0
tensor:Int64[-1, 256]
name:input_mask:0
tensor:Int64[-1, 256]
name:input_ids:0
tensor:Int64[-1, 256]
---------------------------------------------------------------

Outputs
-------------------------
name:unstack:1
tensor:Float[-1, 256]
name:unstack:0
tensor:Float[-1, 256]
name:unique_ids:0
tensor:Int64[-1]
---------------------------------------------------------------

项目

代码

using BERTTokenizers;
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Windows.Forms;

namespace Inference_with_C__BERT_NLP_Deep_Learning_and_ONNX_Runtime
{
    public struct BertInput
    {
        public long[] InputIds { get; set; }
        public long[] InputMask { get; set; }
        public long[] SegmentIds { get; set; }
        public long[] UniqueIds { get; set; }
    }

    public partial class Form1 : Form
    {
        public Form1()
        {
            InitializeComponent();
        }

        RunOptions runOptions;
        InferenceSession session;
        BertUncasedLargeTokenizer tokenizer;
        Stopwatch stopWatch = new Stopwatch();

        private void Form1_Load(object sender, EventArgs e)
        {
            string modelPath = "bertsquad-10.onnx";
            runOptions = new RunOptions();
            session = new InferenceSession(modelPath);
            tokenizer = new BertUncasedLargeTokenizer();
        }

        int MaxAnswerLength = 30;
        int bestN = 20;

        private void button1_Click(object sender, EventArgs e)
        {
            txt_answer.Text = "";
            Application.DoEvents();

            string question = txt_question.Text.Trim();
            string context = txt_context.Text.Trim();

            // Get the sentence tokens.
            var tokens = tokenizer.Tokenize(question, context);

            // Encode the sentence and pass in the count of the tokens in the sentence.
            var encoded = tokenizer.Encode(tokens.Count(), question, context);

            var padding = Enumerable
              .Repeat(0L, 256 - tokens.Count)
              .ToList();

            var bertInput = new BertInput()
            {
                InputIds = encoded.Select(t => t.InputIds).Concat(padding).ToArray(),
                InputMask = encoded.Select(t => t.AttentionMask).Concat(padding).ToArray(),
                SegmentIds = encoded.Select(t => t.TokenTypeIds).Concat(padding).ToArray(),
                UniqueIds = new long[] { 0 }
            };

            // Create input tensors over the input data.
            var inputIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
                  new long[] { 1, bertInput.InputIds.Length });

            var inputMaskOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputMask,
                  new long[] { 1, bertInput.InputMask.Length });

            var segmentIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.SegmentIds,
                  new long[] { 1, bertInput.SegmentIds.Length });

            var uniqueIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.UniqueIds,
                  new long[] { bertInput.UniqueIds.Length });

            var inputs = new Dictionary<string, OrtValue>
              {
                  { "unique_ids_raw_output___9:0", uniqueIdsOrtValue },
                  { "segment_ids:0", segmentIdsOrtValue},
                  { "input_mask:0", inputMaskOrtValue },
                  { "input_ids:0", inputIdsOrtValue }
              };

            stopWatch.Restart();
            // Run session and send the input data in to get inference output. 
            var output = session.Run(runOptions, inputs, session.OutputNames);
            stopWatch.Stop();

            var startLogits = output[1].GetTensorDataAsSpan<float>();

            var endLogits = output[0].GetTensorDataAsSpan<float>();

            var uniqueIds = output[2].GetTensorDataAsSpan<long>();

            var contextStart = tokens.FindIndex(o => o.Token == "[SEP]");

            var bestStartLogits = startLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestEndLogits = endLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestResultsWithScore = bestStartLogits
                .SelectMany(startLogit =>
                    bestEndLogits
                    .Select(endLogit =>
                        (
                            StartLogit: startLogit.Index,
                            EndLogit: endLogit.Index,
                            Score: startLogit.Logit + endLogit.Logit
                        )
                     )
                )
                .Where(entry => !(entry.EndLogit < entry.StartLogit || entry.EndLogit - entry.StartLogit > MaxAnswerLength || entry.StartLogit == 0 && entry.EndLogit == 0 || entry.StartLogit < contextStart))
                .Take(bestN);

            var (item, probability) = bestResultsWithScore
                .Softmax(o => o.Score)
                .OrderByDescending(o => o.Probability)
                .FirstOrDefault();

            int startIndex = item.StartLogit;
            int endIndex = item.EndLogit;

            var predictedTokens = tokens
                          .Skip(startIndex)
                          .Take(endIndex + 1 - startIndex)
                          .Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
                          .ToList();

            // Print the result.
            string answer = "answer:" + String.Join(" ", StitchSentenceBackTogether(predictedTokens))
                + "\r\nprobability:" + probability
                + $"\r\n推理耗时:{stopWatch.ElapsedMilliseconds}毫秒";

            txt_answer.Text = answer;
            Console.WriteLine(answer);

        }

        private List<string> StitchSentenceBackTogether(List<string> tokens)
        {
            var currentToken = string.Empty;

            tokens.Reverse();

            var tokensStitched = new List<string>();

            foreach (var token in tokens)
            {
                if (!token.StartsWith("##"))
                {
                    currentToken = token + currentToken;
                    tokensStitched.Add(currentToken);
                    currentToken = string.Empty;
                }
                else
                {
                    currentToken = token.Replace("##", "") + currentToken;
                }
            }

            tokensStitched.Reverse();

            return tokensStitched;
        }
    }
}
 

using BERTTokenizers;
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Windows.Forms;

namespace Inference_with_C__BERT_NLP_Deep_Learning_and_ONNX_Runtime
{
    public struct BertInput
    {
        public long[] InputIds { get; set; }
        public long[] InputMask { get; set; }
        public long[] SegmentIds { get; set; }
        public long[] UniqueIds { get; set; }
    }

    public partial class Form1 : Form
    {
        public Form1()
        {
            InitializeComponent();
        }

        RunOptions runOptions;
        InferenceSession session;
        BertUncasedLargeTokenizer tokenizer;
        Stopwatch stopWatch = new Stopwatch();

        private void Form1_Load(object sender, EventArgs e)
        {
            string modelPath = "bertsquad-10.onnx";
            runOptions = new RunOptions();
            session = new InferenceSession(modelPath);
            tokenizer = new BertUncasedLargeTokenizer();
        }

        int MaxAnswerLength = 30;
        int bestN = 20;

        private void button1_Click(object sender, EventArgs e)
        {
            txt_answer.Text = "";
            Application.DoEvents();

            string question = txt_question.Text.Trim();
            string context = txt_context.Text.Trim();

            // Get the sentence tokens.
            var tokens = tokenizer.Tokenize(question, context);

            // Encode the sentence and pass in the count of the tokens in the sentence.
            var encoded = tokenizer.Encode(tokens.Count(), question, context);

            var padding = Enumerable
              .Repeat(0L, 256 - tokens.Count)
              .ToList();

            var bertInput = new BertInput()
            {
                InputIds = encoded.Select(t => t.InputIds).Concat(padding).ToArray(),
                InputMask = encoded.Select(t => t.AttentionMask).Concat(padding).ToArray(),
                SegmentIds = encoded.Select(t => t.TokenTypeIds).Concat(padding).ToArray(),
                UniqueIds = new long[] { 0 }
            };

            // Create input tensors over the input data.
            var inputIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
                  new long[] { 1, bertInput.InputIds.Length });

            var inputMaskOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputMask,
                  new long[] { 1, bertInput.InputMask.Length });

            var segmentIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.SegmentIds,
                  new long[] { 1, bertInput.SegmentIds.Length });

            var uniqueIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.UniqueIds,
                  new long[] { bertInput.UniqueIds.Length });

            var inputs = new Dictionary<string, OrtValue>
              {
                  { "unique_ids_raw_output___9:0", uniqueIdsOrtValue },
                  { "segment_ids:0", segmentIdsOrtValue},
                  { "input_mask:0", inputMaskOrtValue },
                  { "input_ids:0", inputIdsOrtValue }
              };

            stopWatch.Restart();
            // Run session and send the input data in to get inference output. 
            var output = session.Run(runOptions, inputs, session.OutputNames);
            stopWatch.Stop();

            var startLogits = output[1].GetTensorDataAsSpan<float>();

            var endLogits = output[0].GetTensorDataAsSpan<float>();

            var uniqueIds = output[2].GetTensorDataAsSpan<long>();

            var contextStart = tokens.FindIndex(o => o.Token == "[SEP]");

            var bestStartLogits = startLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestEndLogits = endLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestResultsWithScore = bestStartLogits
                .SelectMany(startLogit =>
                    bestEndLogits
                    .Select(endLogit =>
                        (
                            StartLogit: startLogit.Index,
                            EndLogit: endLogit.Index,
                            Score: startLogit.Logit + endLogit.Logit
                        )
                     )
                )
                .Where(entry => !(entry.EndLogit < entry.StartLogit || entry.EndLogit - entry.StartLogit > MaxAnswerLength || entry.StartLogit == 0 && entry.EndLogit == 0 || entry.StartLogit < contextStart))
                .Take(bestN);

            var (item, probability) = bestResultsWithScore
                .Softmax(o => o.Score)
                .OrderByDescending(o => o.Probability)
                .FirstOrDefault();

            int startIndex = item.StartLogit;
            int endIndex = item.EndLogit;

            var predictedTokens = tokens
                          .Skip(startIndex)
                          .Take(endIndex + 1 - startIndex)
                          .Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
                          .ToList();

            // Print the result.
            string answer = "answer:" + String.Join(" ", StitchSentenceBackTogether(predictedTokens))
                + "\r\nprobability:" + probability
                + $"\r\n推理耗时:{stopWatch.ElapsedMilliseconds}毫秒";

            txt_answer.Text = answer;
            Console.WriteLine(answer);

        }

        private List<string> StitchSentenceBackTogether(List<string> tokens)
        {
            var currentToken = string.Empty;

            tokens.Reverse();

            var tokensStitched = new List<string>();

            foreach (var token in tokens)
            {
                if (!token.StartsWith("##"))
                {
                    currentToken = token + currentToken;
                    tokensStitched.Add(currentToken);
                    currentToken = string.Empty;
                }
                else
                {
                    currentToken = token.Replace("##", "") + currentToken;
                }
            }

            tokensStitched.Reverse();

            return tokensStitched;
        }
    }
}

下载

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/207506.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

传统算法:使用 Pygame 实现插入排序

使用 Pygame 模块实现了插入排序的动画演示。首先,它生成一个包含随机整数的数组,并通过 Pygame 在屏幕上绘制这个数组的条形图。接着,通过插入排序算法对数组进行排序,动画效果可视化每一步的排序过程。在排序的过程中,程序将当前元素插入到已排序的部分,通过适度的延迟…

每日一练2023.12.1——输出GPLT【PTA】

题目链接&#xff1a;L1-023 输出GPLT 题目要求&#xff1a; 给定一个长度不超过10000的、仅由英文字母构成的字符串。请将字符重新调整顺序&#xff0c;按GPLTGPLT....这样的顺序输出&#xff0c;并忽略其它字符。当然&#xff0c;四种字符&#xff08;不区分大小写&#x…

《opencv实用探索·七》一文看懂图像卷积运算

1、图像卷积使用场景 图像卷积是图像处理中的一种常用的算法&#xff0c;它是一种基本的滤波技术&#xff0c;通过卷积核&#xff08;也称为滤波器&#xff09;对图像进行操作&#xff0c;使用场景如下&#xff1a; 模糊&#xff08;Blur&#xff09;&#xff1a; 使用加权平…

C++入门篇(零) C++入门篇概述

目录 一、C概述 1. 什么是C 2. C的发展史 3. C的工作领域 4. C关键字(C98) 二、C入门篇导论 一、C概述 1. 什么是C C是基于C语言而产生的计算机程序设计语言&#xff0c;支持多重编程模式&#xff0c;包括过程化程序设计、数据抽象、面向对象程序设计、泛型程序设计和设计模式…

Maven无法拉取依赖/构建失败操作步骤(基本都能解决)

首先检查配置文件&#xff0c;确认配置文件没有问题(也可以直接用同事的配置文件(记得修改文件里的本地仓库地址)) 1.file->Invalidate Caches清除缓存重启(简单粗暴&#xff0c;但最有效) 2.刷新maven以及mvn clean&#xff0c;多刷几次&#xff0c;看看还有没有报红的依赖…

Python 中 AttributeError: Int object Has No Attribute 错误

int 数据类型是最基本和最原始的数据类型之一&#xff0c;它不仅在 Python 中&#xff0c;而且在其他几种编程语言中都用于存储和表示整数。 只要没有小数点&#xff0c;int 数据类型就可以存储任何正整数或负整数。 本篇文章重点介绍并提供了一种解决方案&#xff0c;以应对我…

基于Netty的网络调用实现

作为一个分布式消息队列&#xff0c;通信的质量至关重要。基于TCP协议和Socket实现一个高效、稳定的通信程序并不容易&#xff0c;有很多大大小小的“坑”等待着经验不足的开发者。RocketMQ选择不重复发明轮子&#xff0c;基于Netty库来实现底层的通信功能。 1 Netty介绍 Net…

TCP报文解析

1.端口号 标记同一台计算机上的不同进程 源端口&#xff1a;占2个字节&#xff0c;源端口和IP的作用是标记报文的返回地址。 目的端口&#xff1a;占2个字节&#xff0c;指明接收方计算机上的应用程序接口。 TCP报头中的源端口号和目的端口号同IP报头中的源IP和目的IP唯一确定一…

马蹄集第34周

1.战神的对称谜题 不知道为什么超时&#xff01; def main():s input()result 0for i in range(len(s)):l i - 1r i 1while l > 0 and r < len(s) and s[l] s[r]:result max(result, r - l 1)l - 1r 1l ir i 1while l > 0 and r < len(s) and s[l] s…

二分查找与搜索树高频问题

关卡名 逢试必考的二分查找 我会了✔️ 内容 1.山脉数组的峰顶索引 ✔️ 2.旋转数字的最小数字 ✔️ 3.寻找缺失数字 ✔️ 4.优化求平方根 ✔️ 5.中序与搜索树原理 ✔️ 6.二叉搜索树中搜索特定值 ✔️ 7.验证二叉搜索树 ✔️ 基于二分查找思想&#xff0c;可以拓展出很…

【PUSDN】WebStorm中报错Switch language version to React JSX

简述 WebStorm中报错Switch language version to React JSX 可能本页面的写法是其他语法。所以可以不用管。 测试项目&#xff1a;ant design vue pro 前情提示 系统&#xff1a; 一说 同步更新最新版、完整版请移步PUSDN Powered By PUSDN - 平行宇宙软件开发者网www.pusdn…

算法学习—排序

排序算法 一、选择排序 1.算法简介 选择排序是一个简单直观的排序方法&#xff0c;它的工作原理很简单&#xff0c;首先从未排序序列中找到最大的元素&#xff0c;放到已排序序列的末尾&#xff0c;重复上述步骤&#xff0c;直到所有元素排序完毕。 2.算法描述 1&#xff…

C语言-预处理与库

预处理、动态库、静态库 1. 声明与定义分离 一个源文件对应一个头文件 注意&#xff1a; 头文件名以 .h 作为后缀头文件名要与对应的原文件名 一致 例&#xff1a; 源文件&#xff1a;01_code.c #include <stdio.h> int num01 10; int num02 20; void add(int a, in…

uniapp 使用web-view外接三方

来源 前阵子有个需求是需要在原有的项目上加入一个电子签名的功能&#xff0c;为了兼容性和复用性后面解决方法是将这个电子签名写在一个新的项目中&#xff0c;然后原有的项目使用web-view接入这个电子签名项目&#xff1b; 最近又有一个需求&#xff0c;是需要接入第三方的…

蓝桥杯每日一题2023.11.30

题目描述 九数组分数 - 蓝桥云课 (lanqiao.cn) 题目分析 此题目实际上是使用dfs进行数字确定&#xff0c;每次循环中将当前数字与剩下的数字进行交换 eg.1与2、3、4、、、进行交换 2与3、4、、、进行交换 填空位置将其恢复原来位置即可&#xff0c;也就直接将其交换回去即可…

Linux(CentOS7.5):新增硬盘分区纪实

一、服务器概述 1、既有一块系统硬盘&#xff0c;新增一块100G硬盘。 2、要求&#xff0c;将新插入硬盘分为&#xff1a;20G、30G、50G。 二、操作步骤 1、确认新硬盘是否插入成功&#xff1a; fdisk -l# 红色框出来的&#xff0c;为识别出来的新硬盘信息 # 黄色框出来的&#…

Linux:锁定部分重要文件,防止误操作

一、情景描述 比如root用户或者拥有root权限的用户&#xff0c;登陆系统后&#xff0c;通过useradd指令&#xff0c;新增一个用户。 而我们业务限制&#xff0c;只能某一个人才有权限新增用户。 那么&#xff0c;这个时候&#xff0c;我们就用chattr来锁定/etc/passwd文件&…

一些ab命令

1.ab简介 ab是apache自带的压力测试工具&#xff0c;是apachebench命令的缩写。ab非常实用&#xff0c;它不仅可以对apache服务器进行网站访问压力测试&#xff0c;也可以对或其它类型的服务器如nginx、tomcat、IIS等进行压力测试。 ab的原理&#xff1a;ab命令会创建多个并发…

力扣 790. 多米诺和托米诺平铺(一维dp)

题目描述&#xff1a; 有两种形状的瓷砖&#xff1a;一种是 2 x 1 的多米诺形&#xff0c;另一种是形如 "L" 的托米诺形。两种形状都可以旋转。 给定整数 n &#xff0c;返回可以平铺 2 x n 的面板的方法的数量。返回对 109 7 取模 的值。 平铺指的是每个正方形都…

df新增一列数据,并指定列名

方法1&#xff1a;直接指定df列名赋值为list即可 skill_info_df[age] age_list ps:list的长度要和df对齐 方法二 使用insert&#xff1a;