fastapi 调用ollama之下的sqlcoder模式进行对话操作数据库

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
import ollama
import mysql.connector
from mysql.connector.cursor import MySQLCursor
import json

app = FastAPI()

# 数据库连接配置
DB_CONFIG = {
    "database": "web",        # 您的数据库名,用于存储业务数据
    "user": "root",          # 数据库用户名,需要有读写权限
    "password": "XXXXXX",    # 数据库密码,建议使用强密码
    "host": "127.0.0.1",    # 数据库主机地址,本地开发环境使用localhost
    "port": "3306"          # MySQL 默认端口,可根据实际配置修改
}

# 数据库连接函数
def get_db_connection():
    try:
        conn = mysql.connector.connect(**DB_CONFIG)
        return conn
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"数据库连接失败: {str(e)}")

class SQLRequest(BaseModel):
    question: str

def get_table_relationships():
    """动态获取表之间的关联关系"""
    conn = get_db_connection()
    cur = conn.cursor()
    try:
        # 获取当前数据库名
        cur.execute("SELECT DATABASE()")
        db_name = cur.fetchone()[0]
        
        # 获取外键关系
        cur.execute("""
            SELECT 
                TABLE_NAME,
                COLUMN_NAME,
                REFERENCED_TABLE_NAME,
                REFERENCED_COLUMN_NAME
            FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
            WHERE TABLE_SCHEMA = %s
                AND REFERENCED_TABLE_NAME IS NOT NULL
            ORDER BY TABLE_NAME, COLUMN_NAME
        """, (db_name,))
        
        relationships = []
        for row in rows:
            table_name, column_name, ref_table, ref_column = row
            relationships.append(
                f"-- {table_name}.{column_name} can be joined with {ref_table}.{ref_column}"
            )
        
        return "\n".join(relationships) if relationships else "-- No foreign key relationships found"
        
    finally:
        cur.close()
        conn.close()

def get_database_schema():
    """获取MySQL数据库表结构,以CREATE TABLE格式返回"""
    conn = get_db_connection()
    cur = conn.cursor()
    try:
        # 获取当前数据库名
        cur.execute("SELECT DATABASE()")
        db_name = cur.fetchone()[0]
        
        # 获取所有表的结构信息
        cur.execute("""
            SELECT 
                t.TABLE_NAME,
                c.COLUMN_NAME,
                c.COLUMN_TYPE,
                c.IS_NULLABLE,
                c.COLUMN_KEY,
                c.COLUMN_COMMENT
            FROM INFORMATION_SCHEMA.TABLES t
            JOIN INFORMATION_SCHEMA.COLUMNS c 
                ON t.TABLE_NAME = c.TABLE_NAME
            WHERE t.TABLE_SCHEMA = %s
                AND t.TABLE_TYPE = 'BASE TABLE'
            ORDER BY t.TABLE_NAME, c.ORDINAL_POSITION
        """, (db_name,))
        
        rows = cur.fetchall()
        
        schema = []
        current_table = None
        table_columns = []
        
        for row in rows:
            table_name, column_name, column_type, nullable, key, comment = row
            
            if current_table != table_name:
                if current_table is not None:
                    schema.append(f"CREATE TABLE {current_table} (\n" + 
                                ",\n".join(table_columns) + 
                                "\n);\n")
                current_table = table_name
                table_columns = []
            
            # 构建列定义
            column_def = f"  {column_name} {column_type.upper()}"
            if key == "PRI":
                column_def += " PRIMARY KEY"
            elif nullable == "NO":
                column_def += " NOT NULL"
                
            if comment:
                column_def += f" -- {comment}"
                
            table_columns.append(column_def)
        
        # 添加最后一个表
        if current_table is not None:
            schema.append(f"CREATE TABLE {current_table} (\n" + 
                        ",\n".join(table_columns) + 
                        "\n);\n")
            
        return "\n".join(schema)
    finally:
        cur.close()
        conn.close()

def get_chinese_table_mapping():
    """动态生成表名的中文映射"""
    conn = get_db_connection()
    cur = conn.cursor()
    try:
        # 获取所有表的注释信息
        cur.execute("""
            SELECT 
                t.TABLE_NAME,
                t.TABLE_COMMENT
            FROM information_schema.TABLES t
            WHERE t.TABLE_SCHEMA = DATABASE()
            ORDER BY t.TABLE_NAME
        """)
        
        mappings = []
        for table_name, table_comment in cur.fetchall():
            # 生成表的中文名称
            chinese_name = table_name
            if table_name.startswith('web_'):
                chinese_name = table_name.replace('web_', '').replace('_', '')
            if table_comment:
                chinese_name = table_comment.split('--')[0].strip()
                # 如果中文名称以"表"结尾,则去掉"表"if chinese_name.endswith('表'):
                    chinese_name = chinese_name[:-1]
            
            mappings.append(f'           - "{chinese_name}" -> {table_name} table')
        
        return "\n".join(mappings)
    finally:
        cur.close()
        conn.close()

@app.post("/query")
async def query_database(request: Request):
    try:
        # 获取请求体数据并确保正确处理中文
        body = await request.body()
        try:
            request_data = json.loads(body.decode('utf-8'))
        except UnicodeDecodeError:
            request_data = json.loads(body.decode('gbk'))
        
        question = request_data.get('question')
        print(f"收到问题: {question}")  # 调试日志
        
        if not question:
            raise HTTPException(status_code=400, detail="缺少 question 参数")
            
        # 获取数据库结构
        db_schema = get_database_schema()
        #print(f"数据库结构: {db_schema}")  # 调试日志
        
        # 获取中文映射并打印
        chinese_mapping = get_chinese_table_mapping()
        #print(f"表映射关系:\n{chinese_mapping}")  # 添加这行来打印映射
        
        # 修改 prompt 使用更严格的指导
        prompt = f"""
        ### Instructions:
        Convert Chinese question to MySQL query. Follow these rules strictly:
        1. ONLY return a valid SELECT SQL query
        2. Use EXACT table names from the mapping below
        3. DO NOT use any table that's not in the mapping
        4. For Chinese terms, use these exact mappings:
{chinese_mapping}

        ### Examples:
        Question: 所有装修记录
        SQL: SELECT * FROM web_decoration ORDER BY id;

        Question: 查询装修
        SQL: SELECT * FROM web_decoration ORDER BY id;

        ### Database Schema:
        {db_schema}

        ### Question:
        {question}

        ### SQL Query:
        """
        
        # 获取 SQL 查询
        response = ollama.chat(model='sqlcoder:latest', 
                             messages=[{'role': 'user', 'content': prompt}])
        sql_query = response['message']['content'].strip()
        
        print(f"生成的SQL: {sql_query}")  # 调试日志
        
        # 验证 SQL 查询
        if "装修" in question and "web_decoration" not in sql_query:
            sql_query = "SELECT * FROM web_decoration ORDER BY id"
            
        if not sql_query.upper().startswith('SELECT'):
            raise HTTPException(status_code=400, detail="无效的SQL查询格式")
            
        conn = get_db_connection()
        cur = conn.cursor(dictionary=True)
        
        try:
            cur.execute(sql_query)
            results = cur.fetchall()
            return {
                "sql_query": sql_query,
                "results": results
            }
        except mysql.connector.Error as e:
            raise HTTPException(status_code=400, detail=f"SQL 执行错误: {str(e)}")
        finally:
            cur.close()
            conn.close()
            
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"JSON 解析错误: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"处理请求时生错误: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    port = 8666
    print(f"Starting server on port {port}")
    uvicorn.run(
        "main:app",
        host="0.0.0.0", 
        port=port,
        reload=True
    )

我来为您解释这段代码的主要功能和结构:

主要功能

这是一个使用 FastAPI 框架开发的 API 服务,主要用于将中文问题转换为 SQL 查询语句并执行。

核心组件说明

  1. 数据库配置与连接
DB_CONFIG = {
    "database": "web",
    "user": "root",
    "password": "XXXXXX",
    "host": "127.0.0.1",
    "port": "3306"
}

提供了 MySQL 数据库的基本连接配置。

  1. 主要工具函数
  • get_table_relationships(): 获取数据库表之间的外键关系
  • get_database_schema(): 获取数据库表结构
  • get_chinese_table_mapping(): 生成表名的中文映射关系
  1. 核心 API 端点
@app.post("/query")

这个端点接收中文问题,主要处理流程:

  • 接收并解析用户的中文问题
  • 获取数据库结构和表映射
  • 使用 ollama 模型将中文转换为 SQL 查询
  • 执行 SQL 查询并返回结果
  1. 智能转换功能
    使用 ollamasqlcoder 模型将中文问题转换为 SQL 查询,包含:
  • 严格的表名映射
  • SQL 查询验证
  • 错误处理机制

特点

  1. 支持中文输入处理
  2. 自动获取数据库结构
  3. 动态生成中文表名映射
  4. 完善的错误处理机制
  5. 支持热重载的开发模式

使用示例

可以通过 POST 请求访问 /query 端点:

{
    "question": "查询所有装修记录"
}

服务会返回:

{
    "sql_query": "SELECT * FROM web_decoration ORDER BY id",
    "results": [...]
}

安全特性

  1. 数据库连接错误处理
  2. SQL 注入防护
  3. 请求体编码自适应(支持 UTF-8 和 GBK)
  4. 查询结果的安全封装

查看效果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/917733.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于微信小程序的乡村研学游平台设计与实现,LW+源码+讲解

摘 要 信息数据从传统到当代,是一直在变革当中,突如其来的互联网让传统的信息管理看到了革命性的曙光,因为传统信息管理从时效性,还是安全性,还是可操作性等各个方面来讲,遇到了互联网时代才发现能补上自…

基于Java Springboot城市交通管理系统

一、作品包含 源码数据库设计文档万字PPT全套环境和工具资源部署教程 二、项目技术 前端技术:Html、Css、Js、Vue、Element-ui 数据库:MySQL 后端技术:Java、Spring Boot、MyBatis 三、运行环境 开发工具:IDEA/eclipse 数…

手机直连卫星NTN通信初步研究

目录 1、手机直连卫星之序幕 2、卫星NTN及其网络架构 2.1 NTN 2.2 NTN网络架构 3、NTN的3GPP标准化进程 3.1 NTN需要适应的特性 3.2 NTN频段 3.3 NTN的3GPP标准化进程概况 3.4 NTN的3GPP标准化进程的详情 3.4.1 NR-NTN 3.4.1.1 NTN 的无线相关 SI/WI 3.4.1.2…

基本数据类型和包装类型的区别、缓存池、自动拆箱装箱(面试题)

目录 1. 八种基本类型及对应包装类型 2. 基本类型和包装类型 区别 3. 自动拆箱装箱 3.1 自动装箱 3.2 自动拆箱 3.3 缓存池 4. 高频面试案例分析 1. 八种基本类型及对应包装类型 基本数据类型类型描述范围(指数形式)位数包装类型byte整型&#x…

Aria2-CVE-2023-39141漏洞分析

前言: 在偶然一次的渗透靶机的时候,上网查询Aria2的历史漏洞,发现了这个漏洞,但是网上并没有对应的漏洞解释,于是我就就源代码进行分析,发现这是一个非常简单的漏洞,于是发这篇文章跟大家分享一…

androidstudio入门到放弃配置

b站视频讲解传送门 android_studio安装包:https://developer.android.google.cn/studio?hlzh-cn 下载安装 开始创建hello-world 1.删除缓存 文件 下载gradle文件压缩:gradle-8.9用自己创建项目时自动生成的版本即可,不用和我一样 https://…

河道无人机雷达测流监测系统由哪几部分组成?

在现代水利管理中,河道无人机雷达监测系统正逐渐成为一种重要的工具,为河道的安全和管理提供了强大的技术支持。那么,这个先进的监测系统究竟由哪几部分组成呢? 河道无人机雷达监测系统工作原理 雷达传感器通过发射电磁波或激光束…

Mac上详细配置java开发环境和软件(更新中)

文章目录 概要JDK的配置JDK下载安装配置JDK环境变量文件 Idea的安装Mysql安装和配置Navicat Premium16.1安装安装Vscode安装和配置Maven配置本地仓库配置阿里云私服Idea集成Maven 概要 这里使用的是M3型片 14.6版本的Mac 用到的资源放在网盘 链接: https://pan.baidu.com/s/17…

CKA认证 | Day3 K8s管理应用生命周期(上)

第四章 应用程序生命周期管理(上) 1、在Kubernetes中部署应用流程 1.1 使用Deployment部署Java应用 在 Kubernetes 中,Deployment 是一种控制器,用于管理 Pod 的部署和更新。以下是使用 Deployment 部署 Java 应用的步骤&#x…

ffmpeg编程入门

文章目录 ffmpeg流程常用的音视频术语常用概念复用器编解码器ffmpeg的整体结构注册组件相关封装格式相关函数的调用流程 相关的ffpmeg数据结构简介数据结构之间的关系 ffmpeg流程 图中的函数 以及结构体都是ffmpeg自带提供的 ffmpeg打开的时候 和其他io操作差不多 有一个类似句…

函数指针示例

目录&#xff1a; 代码&#xff1a; main.c #include <stdio.h> #include <stdlib.h>int Max(int x, int y); int Min(int x, int y);int main(int argc, char**argv) {int x,y;scanf("%d",&x);scanf("%d",&y);int select;printf(&q…

间接采购管理:主要挑战与实战策略

间接采购支出会悄然消耗掉企业的现金流&#xff0c;即使是管理完善的公司也难以避免。这是因为间接支出不直接关联特定客户、产品或项目&#xff0c;使采购人员难以跟踪。但正确管理间接支出能为企业带来显著收益——前提是要有合适的工具。本文将分享管理间接支出的关键信息与…

TCP(下):三次握手四次挥手 动态控制

欢迎浏览高耳机的博客 希望我们彼此都有更好的收获 感谢三连支持! TCP(上)&#xff1a;成熟可靠的传输层协议-CSDN博客 &#x1f95d;在上篇博客中&#xff0c;我们针对TCP的特性,报文结构,连接过程以及相对于其他协议的区别进行了探讨&#xff0c;提供了初步的理解和概览。本…

ASP.NET 部署到IIS,访问其它服务器的共享文件 密码设定

asp.net 修改上面的 IIS需要在 配置文件 添加如下内容 》》》web.config <system.web><!--<identity impersonate"true"/>--><identity impersonate"true" userName"您的账号" password"您的密码" /><co…

python实现十进制转换二进制,tkinter界面

目录 需求 效果 代码实现 代码解释 需求 python实现十进制转换二进制 效果 代码实现 import tkinter as tk from tkinter import messageboxdef convert_to_binary():try:# 获取输入框中的十进制数decimal_number int(entry.get())# 转换为二进制binary_number bin(de…

现代密码学|古典密码学例题讲解|AES数学基础(GF(2^8)有限域上的运算问题)| AES加密算法

文章目录 古典密码凯撒密码和移位变换仿射变换例题多表代换例题 AES数学基础&#xff08;GF&#xff08;2^8&#xff09;有限域上的运算问题&#xff09;多项式表示法 | 加法 | 乘法X乘法模x的四次方1的乘法 AES加密算法初始变换字节代换行移位列混合轮密钥加子密钥&#xff08…

ubuntu使用DeepSpeech进行语音识别(包含交叉编译)

文章目录 前言一、DeepSpeech编译二、DeepSpeech使用示例三、核心代码分析1.创建模型核心代码2.识别过程核心代码 四、交叉编译1.交叉编译2.使用 总结 前言 由于工作需要语音识别的功能&#xff0c;环境是在linux arm版上&#xff0c;所以想先在ubuntu上跑起来看一看&#xff…

阿里云引领智算集群网络架构的新一轮变革

阿里云引领智算集群网络架构的新一轮变革 云布道师 11 月 8 日~ 10 日在江苏张家港召开的 CCF ChinaNet&#xff08;即中国网络大会&#xff09;上&#xff0c;众多院士、教授和业界技术领袖齐聚一堂&#xff0c;畅谈网络未来的发展方向&#xff0c;聚焦智算集群网络的创新变…

PyQt5 加载UI界面与资源文件

步骤一: 使用 Qt Designer 创建 XXX.ui文件 步骤二: 使用 Qt Designer 创建 资源文件 步骤三: Python文件中创建相关类, 使用 uic.loadUi(mainwidget.ui, self ) 加载UI文件 import sys from PyQt5 import QtCore, QtWidgets, uic from PyQt5.QtCore import Qt f…

7.高可用集群架构Keepalived双主热备原理

一. 高可用集群架构Keepalived双主热备原理 (1)主机+备机keepalived配置(192.168.1.171) ! Configuration File for keepalivedglobal_defs {# 路由id:当前安装keepalived节点主机的标识符,全局唯一router_id keep_101 } #计算机节点(主机配置) vrrp_instance VI_1 {</