【Text2SQL】WikiSQL 数据集与 Seq2SQL 模型

论文:Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning

⭐⭐⭐⭐⭐

ICLR 2018

Dataset: github.com/salesforce/WikiSQL

Code:Seq2SQL 模型实现

一、论文速读

本文提出了 Text2SQL 方向的一个经典数据集 —— WikiSQL,同时提出了一个模型 Seq2SQL,用于把自然语言问句转为 SQL。

WikiSQL 数据集中的 SQL 形式较为简单,不包括排序(order by)、分组(group by)、子查询等其他复杂操作。根据这种简单的形式,本文的 Seq2SQL 模型针对一个 table 和一个 question,预测出 SELECT 部分、Aggregation 部分和 WHERE 部分,并将其构造成一个 SQL 语句。下图展示了一个示例:

在这里插入图片描述

Seq2SQL 基于 Augmented Pointer Network 来实现,下面先介绍一下这个网络结构,然后再介绍基于此来实现 Seq2SQL 模型。

二、Augmented Pointer Network(增广指针网络)

Augmented Pointer Network 能够从输入序列中选择 token 并逐个 token 生成输出序列。

对于一个 example,输入序列 x x x 是由"table 的列名"、“SQL 词汇表”、"question"三者用特殊分隔符拼接起来的序列:

在这里插入图片描述

比如在前面图片的示例中,列名 token 包括 “Pick”、“#”、“CFL” 等等组成,question token 包括 “How”、“many”、“CFL” 等等,SQL 词汇表包括 “SELECT”、“WHERE”、“COUNT”、“MIN” 等等。

这个网络首先对 input sequence x x x 做 word embedding,然后输入给两层的 Bi-LSTM 做编码得到 h e n c h^{enc} henc,其中 input 的第 i 个 token 的编码是 h t e n c h_t^{enc} htenc,这样每个 token 经过编码都变成了一个 vector。

解码器部分使用双层的单向 LSTM,每一步生成一个 token。具体生成方式是:使用上一步生成的 token y s − 1 y_{s-1} ys1 作为输入,输出一个 state g s g_s gs,然后拿 g s g_s gs 与 input sequence 的每个位置 t 的 h t h_t ht 做计算得到一个标量的注意力分数 α s , t p t r \alpha_{s,t}^{ptr} αs,tptr,选择分数最高的对应的输入 token 作为生成的下一个 token。其中注意力分数的计算公式如下:

20240518155338

三、Seq2SQL 模型

虽然可以直接训练 Augmented Pointer Network 让他生成 SQL 序列作为结果,但是这没有利用 SQL 本身固有的结构。本论文固定 SQL 的结构由三部分组成:SELECT、WHERE 和 Aggregation,并训练三个组件来分别生成这三部分:

在这里插入图片描述

3.1 Aggregation Classifier

他就是一个 classifier,最终输出一个 softmax 计算后的分布,从 NULLMAXMINCOUNTSUMAVG 中做分类,NULL 表示没有 aggregation 操作。其 loss L a g g L^{agg} Lagg 使用 cross entropy 来计算。

比如,“How many” 类型的 question 往往被分类为 COUNT

3.2 SELECT column prediction

SELECT column prediction 是一个匹配问题,这里使用指针网络的思想来解决:输入列名序列和 question 的拼接,输出与 question 最匹配的一个 column。

首先使用 LSTM 对每一列进行编码,column j j j 对应一个 vector e j c e_j^c ejc,然后对 input x x x 编码出一个 vector κ s e l \kappa^{sel} κsel,然后使用 MLP,计算 input representation κ s e l \kappa^{sel} κsel 与每一个 column j 的分数 α j s e l \alpha^{sel}_{j} αjsel,之后使用 softmax 对分数进行归一化:

  • 训练时,使用交叉熵损失 L s e l L^{sel} Lsel 来训练该模块
  • 预测时,选分数最大的 column 作为预测结果

对于输入 x x x 编码为 input representation 和计算分数的详细信息可以参考论文和代码实现

3.3 WHERE Clause

这里使用类似于 Augmented Pointer Network 的 pointer decoder 来训练这一模块。但是使用 cross entropy 有一个限制:两个 WHERE 条件可以被交换并产生相同结果。但两个顺序不同的 WHERE 会被 cross entropy 错误地惩罚,比如 year>18 and male=1male=1 and year>18 是等价的,但由于 cross entropy 是精确匹配 tokens,导致这个结果会被计算损失。

这里使用强化学习(RL)来训练, q ( y ) q(y) q(y) 是生成的查询, q g q_g qg 是真实查询,奖励函数的定义如下:

20240518171120

并根据此奖励函数计算出 loss L w h e L^{whe} Lwhe

3.4 Seq2SQL 的训练

设置一个混合损失函数 L = L a g g + L s e l + L w h e L = L^{agg} + L^{sel} + L^{whe} L=Lagg+Lsel+Lwhe,并使用梯度下降来最小化该 loss 从而训练模型。

四、WikiSQL 数据集

该文更重要的一个贡献是提供了一个 WikiSQL 数据集,包含 80654 条样本和 24241 个 schema。这些数据被随机划分为 train、dev 和 test 三个 split。

下面是一个 example:

20240518173309

解释如下:

  • phase: the phase in which the dataset was collected. We collected WikiSQL in two phases.
  • question: the natural language question written by the worker.
  • table_id: the ID of the table to which this question is addressed.
  • sql: the SQL query corresponding to the question. This has the following subfields:
    • sel: the numerical index of the column that is being selected. You can find the actual column from the table.
    • agg: the numerical index of the aggregation operator that is being used. You can find the actual operator from Query.agg_ops in lib/query.py.
    • conds: a list of triplets (column_index, operator_index, condition) where:
      • column_index: the numerical index of the condition column that is being used. You can find the actual column from the table.
      • operator_index: the numerical index of the condition operator that is being used. You can find the actual operator from Query.cond_ops in lib/query.py.
      • condition: the comparison value for the condition, in either string or float type.

同时还给出了每个 table 的 schema 和数据部分。

五、评估指标

  • N N N:数据集的样本总数
  • N e x N_{ex} Nex:运行生成的 SQL 后,得到正确结果的样本数
  • N l f N_{lf} Nlf:生成的 SQL 与 ground-truth SQL 字符串完全精确匹配的样本数

由此提出两个指标:

  • A C C e x = N e x / N ACC_{ex} = N_{ex} / N ACCex=Nex/N执行精度指标,如果生成的 SQL 与 ground-truth SQL 的执行结果相同,那就算作正确。存在一个缺点:如果构造一个错误的 SQL 但执行结果正确,依然被算作正确
  • A C C l f = N l f / N ACC_{lf} = N_{lf} / N ACClf=Nlf/N逻辑形式的精确指标,如果生成的 SQL 与 ground-truth SQL 完全匹配,才被算作正确。存在一个缺点:两个等价但写法不同的 SQL 会被算作错误

六、总结

这篇论文给出了一个 WikiSQL 数据集,并提出了 Text2SQL 的一个解决方案以及评价指标。

但是很明显,该方案存在不少缺点,之后的方案会继续改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/641139.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Aligned Layer:trustless应用的通用验证层

1. 引言 Aligned Layer近期获得了2000万美金的A轮融资,Aligned Layer: 致力于成为 去中心化 zero-knowledge (ZK) proof 验证层。旨在为不同证明系统,提供快速、经济且可扩展的零知识证明验证。作为以太坊生态系统及其他生态系统的通用sour…

【pyspark速成专家】4_Spark之RDD编程2

目录 四,常用PairRDD的转换操作 五,缓存操作 四,常用PairRDD的转换操作 PairRDD指的是数据为长度为2的tuple类似(k,v)结构的数据类型的RDD,其每个数据的第一个元素被当做key,第二个元素被当做value. reduceByKey #reduceByKey…

罗德里格斯公式(旋转矩阵)推导

文章目录 1. 推导2. 性质3. 参考 1. 推导 r r r为旋转轴, θ \theta θ为旋转角度。 先将旋转轴单位化 u r ∣ ∣ r ∣ ∣ u\frac{r}{||r||} u∣∣r∣∣r​ 旋转可以被分为垂直和旋转两个方向, 我们求沿轴方向的分量其实就是在求 p p p向量在 u u u方…

02-Linux【基础篇】

一、Linux的目录结构 1.基本介绍 Linux的文件系统采用层级式的树状目录结构,在此结构中的最上层是根目录"/",然后在此目录下再创建其他的目录 深刻理解Linux树状文件目录是非常重要的 记住一句经典的话:在Linux世界里&#xff…

C语言内存函数(与上篇字符函数及字符串函数一起食用效果更佳哦~)

顾名思义,内存函数就是针对内存块(即一块内存)来处理的。 因此本篇所讲的四种内存函数: memcpy(内存拷贝)memmove(内存移动)memset(内存设置)memcmp&#x…

批量复制文件智能删除已复制,轻松管理文件新体验!让您的文件整理更高效无忧

在信息爆炸的时代,文件管理无疑成为我们日常生活和工作中不可或缺的一部分。面对堆积如山的文件,我们时常陷入无尽的复制、粘贴、删除循环中,不仅耗时耗力,还容易出错。但今天,我要向您推荐一款颠覆传统的文件管理工具…

移动云——让每个人都能享受云技术的魅力

一、引言 云技术的起源可以追溯到20世纪60年代和70年代,随着科技的发展,现在早就和所有人息息相关。在云技术的浪潮中,有这么一家厂商通过自己的努力,深耕云计算市场,不仅有各种各样的产品为开发者提供服务&#xff0…

基于vue3速学angular

因为工作原因,需要接手新的项目,新的项目是angular框架的,自学下和vue3的区别,写篇博客记录下: 参考:https://zhuanlan.zhihu.com/p/546843290?utm_id0 1.结构上: vue3:一个vue文件&#xff…

【simple-admin】simple-admin-core 首次服务启动 如何配置mysql数据库表 | 如何docker启动core

一、下载启动S-A 1、下载源码 https://github.com/suyuan32/simple-admin-core.git git clone https://github.com/suyuan32/simple-admin-core.git2、修改etc下yaml配置 需要对RPC和API 分别2个文件夹下的etc下的yaml进行修改 替换成我们的数据库 3、初始化数据库 核心代…

JavaEE之线程(9) _定时器的实现代码

前言 定时器也是软件开发中的一个重要组件. 类似于一个 “闹钟”。 达到一个设定的时间之后,就执行某个指定好的代码,比如: 在受上述场景中,当客户端发出去请求之后, 就要等待响应,如果服务器迟迟没有响应&…

YOLOv10最全使用教程(含ONNX和TensorRT推理)

论文题目:YOLOv10: Real-Time End-to-End Object Detection 研究单位:清华大学 论文链接:http://arxiv.org/abs/2405.14458 代码链接:https://github.com/THU-MIG/yolov10 作者提供的模型性能评价图,如下:…

基于灰狼优化算法优化RBF(GWO-RBF)的数据回归预测(多输入多输出)

代码原理及流程 基于灰狼优化算法优化多输入多输出(MIMO)的RBF神经网络的数据回归预测,可以采取以下步骤: 1. 数据准备:准备包含多个输入特征和多个输出目标的数据集,确保数据已经经过预处理和归一化。 …

盲人社区生活支持体系:织就一张温暖的网

在当今社会,构建一个全面、包容的盲人社区生活支持体系成为了推动社会进步、保障残障人士权益的重要议题。随着科技的不断革新,一款名为“蝙蝠避障”的辅助软件走进了盲人的日常生活,它如同一位无形的向导,通过实时避障与拍照识别…

Failed to load property source from location ‘classpath:/application.yml‘

总帮客户二开 经常会遇到这种问题 网上的解决办法都是刷新maven (对我没用!) 我的解决办法是修改编码!

PLSQL连接Linux Oracle21c

PLSQL连接Linux Oracle21c 一、安装PLsql 下载官网 https://www.allroundautomations.com/registered-plsqldev/ 二、Oracle Instant Client下载 使用plsql连接oracle的时候是需要本地先安装oracle客户端,英文名就是Oracle Instant Client。 官方下载地址&…

初识java——javaSE (6)接口的实现——比较器与深拷贝,浅拷贝

文章目录 前言一 比较器1.1 关于两个对象的比较1.2 Comparable接口:1.3 Arrays.sort方法的实现1.4 比较器的实现Comparator接口 二 深拷贝与浅拷贝2.1 浅拷贝:Cloneable接口:clone方法:实现拷贝:浅拷贝: 2.…

【C++】list容器

目录 一.list容器介绍 二.C中list的基本组成 三.list容器相关接口的模拟实现 1.push_back() 2.迭代器的begin()和end() 3.insert() 4.erase() 5.pop_front() 6.pop_back() 7.size() 8.empty() 9.析构~list()和清除数据clear() 10.拷贝构造 11.赋值运算 四.模拟…

分享几张漂亮的linux kde主题

分享几张漂亮的linux kde主题:在系统设置的全局主题内下载。

SpringBoot——整合Redis

目录 Redis 创建Commodity表 启动MySQL和Redis 新建一个SpringBoot项目 pom.xml application.properties Commodity实体类 ComMapper接口 ComService业务层接口 ComServiceImpl业务接口的实现类 ComController控制器 RedisConfig配置类 SpringbootRdisApplication启…

c++|多态

c|多态 1 多态的概念2 多态的定义及其实现2.1 满足多态的条件2.2 虚函数2.3 虚函数的重写2.4 析构函数适合加virtural吗2.4 C11 override 和 final2.5 三个概念的对比 3 多态的原理4 抽象类4.1 概念4.2 纯虚函数 1 多态的概念 多态的概念:通俗来说,就是…