Pipeline:
1. Identify company-name entities in the question against the provided data files. Questions that mention a company name are routed to semantic retrieval; questions without one are routed to structured recall.
2. Structured recall: Qwen generates SQL from the question, the SQL is executed to obtain the result values, and those values are fed back to Qwen together with the question to produce the final answer.
3. Semantic retrieval: use the company name identified in step 1 and the competition data files to locate the corresponding prospectus, split it into N text chunks, use Qwen to build a vector for each chunk (set A) and a vector for the question (B), rank the chunks by cosine similarity, take the top 5, merge them into a single context T, and prompt Qwen with the question plus T to generate the answer. (A minimal sketch of the chunk-ranking step follows this list.)
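To make step 3 concrete, below is a minimal, self-contained sketch of the chunk-ranking idea (an illustration, not the competition code): the question and every chunk become token-count vectors, chunks are ranked by cosine similarity, and the best ones are merged into the context T. The whitespace split is a toy tokenizer used only in this sketch; the actual scripts further down use the Qwen tokenizer and then prompt Qwen with the merged context.
import math
from collections import Counter

def cosine(c1, c2):
    """Cosine similarity between two token-count vectors."""
    terms = set(c1) | set(c2)
    dot = sum(c1[t] * c2[t] for t in terms)
    mag = math.sqrt(sum(v * v for v in c1.values())) * math.sqrt(sum(v * v for v in c2.values()))
    return dot / mag if mag else 0.0

def top_k_chunks(question, chunks, k=5):
    """Rank chunks by cosine similarity to the question and keep the k best."""
    q_vec = Counter(question.split())  # toy whitespace tokenizer, for illustration only
    scored = [(cosine(Counter(c.split()), q_vec), c) for c in chunks]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [c for _, c in scored[:k]]

# Toy usage: the merged context T is then placed into a prompt together with the question.
chunks = ["公司 主营 业务 为 芯片 设计", "公司 注册 地址 位于 上海", "公司 董事长 为 张三"]
context_T = "\n".join(top_k_chunks("公司 的 主营 业务 是 什么", chunks, k=2))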
Follow-up optimization directions include, but are not limited to:
Improving recall: for both structured recall and semantic retrieval.
Improving accuracy: mainly for semantic retrieval, by refining the prompts and normalizing the question and the retrieved text. (A small normalization sketch follows this list.)
Model fine-tuning: SQL generation and vector generation could use fine-tuned models.
Model switching: the current solution uses Qwen2.5 7B; larger models or finance-specific models are worth trying.
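As a hedged illustration of the normalization idea in the accuracy item above (a sketch only; the stopword list and date regex here are examples, not the exact lists used in the scripts below), a question can be stripped of explicit dates, stopwords, and punctuation before matching so that the similarity keys on informative terms:
import re

# Illustrative stopwords/punctuation; the real pipeline maintains its own lists.
STOPWORDS = ['根据', '招股意向书', '请问', '是多少', '的', '?', '。', ',']
DATE_PATTERN = re.compile(r'\d{1,4}年\d{1,2}月\d{1,2}日')

def normalize(text):
    """Drop explicit dates and stopwords, then collapse whitespace."""
    text = DATE_PATTERN.sub(' ', text)
    for w in STOPWORDS:
        text = text.replace(w, ' ')
    return ' '.join(text.split())

# e.g. normalize("请问截至2020年12月31日公司的员工人数是多少?") -> "截至 公司 员工人数"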
Scores: overall 78.49
Structured recall: 89.05
Semantic retrieval: 62.65
Rank: 31/3502
Notes:
Full source code for this article: https://download.csdn.net/download/love254443233/90106437
Baseline reference: the 大模型说的队 solution from Tongyi-EconML/FinQwen (FinQwen: an open, stable, high-quality financial LLM project that builds an LLM-based intelligent QA system for financial scenarios and promotes "AI + finance" through open source): https://github.com/Tongyi-EconML/FinQwen
Key source code:
Entity extraction (classify each question and attach the matched company entity and data file):
import csv
import pandas as pd
import numpy as np
import re
import copy
from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
from modelscope import GenerationConfig
model_dir = '/data/nfs/baozhi/models/Qwen-7B-Chat'
# Note: The default behavior now has injection attack prevention off.
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
new_question_file_dir = 'intermediate/A01_question_classify.csv'
new_question_file = pd.read_csv(new_question_file_dir,delimiter = ",",header = 0)
company_file_dir = 'files/AF0_pdf_to_company.csv'
company_file = pd.read_csv(company_file_dir,delimiter = ",",header = 0)
company_data_csv_list = list()
company_index_list = list()
company_name_list = list()
for cyc in range(len(company_file)):
    company_name_list.append(company_file[cyc:cyc+1]['公司名称'][cyc])
    company_data_csv_list.append(company_file[cyc:cyc+1]['csv文件名'][cyc])
    temp_index_cp = tokenizer(company_file[cyc:cyc+1]['公司名称'][cyc])
    temp_index_cp = temp_index_cp['input_ids']
    company_index_list.append(temp_index_cp)
g = open('intermediate/A02_question_classify_entity.csv', 'w', newline='', encoding = 'utf-8-sig')
csvwriter = csv.writer(g)
csvwriter.writerow(['问题id','问题','分类','对应实体','csv文件名'])
for cyc in range(len(new_question_file)):
    tempw_id = new_question_file[cyc:cyc+1]['问题id'][cyc]
    tempw_q = new_question_file[cyc:cyc+1]['问题'][cyc]
    tempw_q_class = new_question_file[cyc:cyc+1]['分类'][cyc]
    tempw_entity = 'N_A'
    tempw_csv_name = 'N_A'
    if new_question_file[cyc:cyc+1]['分类'][cyc] == 'Text':
        # 'Text' questions: pick the company whose tokenized name overlaps most with the question
        temp_index_q = tokenizer(new_question_file[cyc:cyc+1]['问题'][cyc])
        temp_index_q = temp_index_q['input_ids']
        q_cp_similarity_list = list()
        for cyc2 in range(len(company_file)):
            temp_index_cp = company_index_list[cyc2]
            temp_simi = len(set(temp_index_cp) & set(temp_index_q)) / (len(set(temp_index_cp)) + len(set(temp_index_q)))
            q_cp_similarity_list.append(temp_simi)
        t = copy.deepcopy(q_cp_similarity_list)
        max_number = []
        max_index = []
        for _ in range(1):
            number = max(t)
            index = t.index(number)
            t[index] = 0
            max_number.append(number)
            max_index.append(index)
        t = []
        tempw_entity = company_name_list[max_index[0]]
        tempw_csv_name = company_data_csv_list[max_index[0]]
        csvwriter.writerow([str(tempw_id), str(tempw_q), tempw_q_class, tempw_entity, tempw_csv_name])
    elif new_question_file[cyc:cyc+1]['分类'][cyc] == 'SQL':
        # 'SQL' questions go to structured recall; no entity is attached
        csvwriter.writerow([str(tempw_id), str(tempw_q), tempw_q_class, tempw_entity, tempw_csv_name])
    else:
        # Other questions: fall back to an exact substring match on the company name
        find_its_name_flag = 0
        for cyc_name in range(len(company_name_list)):
            if company_name_list[cyc_name] in tempw_q:
                tempw_entity = company_name_list[cyc_name]
                tempw_csv_name = company_data_csv_list[cyc_name]
                csvwriter.writerow([str(tempw_id), str(tempw_q), tempw_q_class, tempw_entity, tempw_csv_name])
                find_its_name_flag = 1
                break
        if find_its_name_flag == 0:
            csvwriter.writerow([str(tempw_id), str(tempw_q), tempw_q_class, tempw_entity, tempw_csv_name])
g.close()
print('A02_finished')
exit()
SQL generation (few-shot prompting of Qwen to write SQL):
import csv
import pandas as pd
import numpy as np
import sqlite3
import re
import copy
from langchain_community.utilities import SQLDatabase
from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
from modelscope import GenerationConfig
table_name_list = ['基金基本信息','基金股票持仓明细','基金债券持仓明细','基金可转债持仓明细','基金日行情表','A股票日行情表','港股票日行情表','A股公司行业划分表','基金规模变动表','基金份额持有人结构']
table_info_dict = {}
n = 5
deny_list = ['0','1','2','3','4','5','6','7','8','9',',','?','。',
'一','二','三','四','五','六','七','八','九','零','十',
'的','小','请','.','?','有多少','帮我','我想','知道',
'是多少','保留','是什么','-','(',')','(',')',':',
'哪个','统计','且','和','来','请问','记得','有','它们']
# url='sqlite:model_train/other/FinQwen-main/solutions/4_大模型说的队/app/tcdata/bobi.db'
# url="sqlite:data/nfs/baozhi/my_model_train/other/FinQwen-main/bs_challenge_financial_14b_dataset/dataset/bobi.db"
# db0 = SQLDatabase.from_uri(url, sample_rows_in_table_info=0)
# dbd0 = db0.table_info
#
# db2 = SQLDatabase.from_uri(url, sample_rows_in_table_info=2)
# dbd2 = db2.table_info
# list1 = dbd2.split('CREATE TABLE')
# for cyc_piece in range(len(list1)):
# list1[cyc_piece] = 'CREATE TABLE' + list1[cyc_piece]
# for piece in list1:
# for word in table_name_list:
# if word in piece:
# table_info_dict[word] = piece
question_csv_file_dir = "intermediate/A01_question_classify.csv"
question_csv_file = pd.read_csv(question_csv_file_dir,delimiter = ",",header = 0)
model_dir = '/data/nfs/baozhi/models/Qwen-7B-Chat'
# Note: The default behavior now has injection attack prevention off.
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="cuda:0", trust_remote_code=True, bf16=True).eval()
model.generation_config = GenerationConfig.from_pretrained(model_dir,
trust_remote_code=True,
temperature = 0.0001,
top_p = 1,
do_sample = False,
seed = 1234)
print('B01_model_loaded')
deny_token_list = list()
for word in deny_list:
    temp_tokens = tokenizer(word)
    temp_tokens = temp_tokens['input_ids']
    deny_token_list = deny_token_list + temp_tokens
def get_prompt_v33(question, index_list):
    Examples = '以下是一些例子:'
    for index in index_list:
        Examples = Examples + "问题:" + example_question_list[index] + '\n'
        Examples = Examples + "SQL:" + example_sql_list[index] + '\n'
    impt2 = """
你是一个精通SQL语句的程序员。
我会给你一个问题,请按照问题描述,仿照以下例子写出正确的SQL代码。
"""
    impt2 = impt2 + Examples
    impt2 = impt2 + "问题:" + question + '\n'
    impt2 = impt2 + "SQL:"
    return impt2
SQL_examples_file_dir = "files/ICL_EXP.csv"
SQL_examples_file = pd.read_csv(SQL_examples_file_dir,delimiter = ",",header = 0)
example_employ_list = list()
for cyc in range(len(SQL_examples_file)):
    example_employ_list.append(0)
example_question_list = list()
example_table_list = list()
example_sql_list = list()
example_token_list = list()
for cyc in range(len(SQL_examples_file)):
    example_question_list.append(SQL_examples_file[cyc:cyc+1]['问题'][cyc])
    example_sql_list.append(SQL_examples_file[cyc:cyc+1]['SQL'][cyc])
    temp_tokens = tokenizer(SQL_examples_file[cyc:cyc+1]['问题'][cyc])
    temp_tokens = temp_tokens['input_ids']
    temp_tokens2 = [x for x in temp_tokens if x not in deny_token_list]
    example_token_list.append(temp_tokens2)
g = open('intermediate/question_SQL_V6.csv', 'w', newline='', encoding = 'utf-8-sig')
csvwriter = csv.writer(g)
csvwriter.writerow(['问题id','问题','SQL语句','prompt'])
pattern1 = r'\d{8}'
for cyc in range(len(question_csv_file)):
    if cyc % 50 == 0:
        print(cyc)
    response2 = 'N_A'
    prompt2 = 'N_A'
    if question_csv_file['分类'][cyc] == 'SQL' and cyc not in [174]:
        temp_question = question_csv_file[cyc:cyc+1]['问题'][cyc]
        # remove explicit dates (yyyymmdd) before matching against the example questions
        date_list = re.findall(pattern1, temp_question)
        temp_question2_for_search = temp_question
        for t_date in date_list:
            temp_question2_for_search = temp_question2_for_search.replace(t_date, ' ')
        temp_tokens = tokenizer(temp_question2_for_search)
        temp_tokens = temp_tokens['input_ids']
        temp_tokens2 = [x for x in temp_tokens if x not in deny_token_list]
        temp_tokens = temp_tokens2
        # compute token-overlap similarity with every example question
        similarity_list = list()
        for cyc2 in range(len(SQL_examples_file)):
            similarity_list.append(len(set(temp_tokens) & set(example_token_list[cyc2])) / (len(set(temp_tokens)) + len(set(example_token_list[cyc2]))))
        # pick the n most similar example questions (largest values and their indices)
        t = copy.deepcopy(similarity_list)
        max_number = []
        max_index = []
        for _ in range(n):
            number = max(t)
            index = t.index(number)
            t[index] = 0
            max_number.append(number)
            max_index.append(index)
        t = []
        # keep adding examples until the few-shot block would exceed ~2300 characters
        temp_length_test = ""
        short_index_list = list()
        for index in max_index:
            temp_length_test_1 = temp_length_test
            temp_length_test = temp_length_test + example_question_list[index]
            temp_length_test = temp_length_test + example_sql_list[index]
            if len(temp_length_test) > 2300:
                break
            short_index_list.append(index)
        prompt2 = get_prompt_v33(question_csv_file['问题'][cyc], short_index_list)
        print(f"{str(cyc)} prompt2:{prompt2}")
        response2, history = model.chat(tokenizer, prompt2, history=None)
        print(f"response2 = {response2}, \n history = {history}")
        print("---------------------------------------------------------------------------------")
    else:
        pass
    csvwriter.writerow([str(question_csv_file[cyc:(cyc+1)]['问题id'][cyc]),
                        str(question_csv_file[cyc:(cyc+1)]['问题'][cyc]),
                        response2, prompt2])
g.close()
Semantic retrieval (chunking, chunk ranking, and answer generation):
import json
import csv
import pandas as pd
import re
from collections import Counter
import math
from modelscope import AutoTokenizer
from ai_loader import tongyi
def counter_cosine_similarity(c1, c2):  # cosine similarity over (truncated) token counters
    terms = set(c1).union(c2)
    dotprod = sum(c1.get(k, 0) * c2.get(k, 0) for k in terms)
    magA = math.sqrt(sum(c1.get(k, 0) ** 2 for k in terms))
    magB = math.sqrt(sum(c2.get(k, 0) ** 2 for k in terms))
    if magA * magB != 0:
        return dotprod / (magA * magB)
    else:
        return 0
pattern1 = r'截至'
pattern2 = r'\d{1,4}年\d{1,2}月\d{1,2}日'
q_file_dir = 'intermediate/A02_question_classify_entity.csv'
q_file = pd.read_csv(q_file_dir, delimiter=",", header=0)
c00_file = 'intermediate/C00_text_understanding.csv'
g = open(c00_file, 'w', newline='', encoding='utf-8-sig')
text_file_dir = 'tcdata/pdf_txt_file'
csvwriter = csv.writer(g)
csvwriter.writerow(['问题id', '问题', '问题[标准化后]', '对应实体', 'csv文件名', 'FA', 'top_text'])
stopword_list = ['根据', '招股意见书', '招股意向书', '截至', '千元', '万元', '哪里', '哪个',
'知道', "什么", '?', '是',
'的', '想', '元', '。', ',', '怎样', '谁', '以及', '了', '对', '?', ',']
bd_list = ['?', '。', ',', '[', ']']
tongyi_model_path = "/data/nfs/baozhi/models/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(tongyi_model_path, trust_remote_code=True)
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
def text_split(content):
    """Split the document into smaller chunks."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=['\n\n', "\n", "。"],
        keep_separator=False)
    return text_splitter.split_text(content)
def embedding_2_ver(embedding):
    temp_tokens = list()
    for word_add in embedding:
        temp_tokens.append(word_add)
    return temp_tokens
n = 30
cap = 4
def text_similarity(text, C_temp_q_tokens):
    """Compute the similarity between a text chunk and the tokenized question."""
    temp_s_tokens = tokenizer(text)
    temp_s_tokens = temp_s_tokens['input_ids']
    C_temp_s_tokens = Counter(temp_s_tokens)
    C_temp_s_tokens[220] = 0  # drop token id 220 (presumably a whitespace/filler token) so it does not dominate
    for token in C_temp_s_tokens:
        if C_temp_s_tokens[token] >= cap:
            C_temp_s_tokens[token] = cap  # cap term frequencies (truncated counts)
    return counter_cosine_similarity(C_temp_s_tokens, C_temp_q_tokens)
import copy
def process_text_question(question, company, file_path):
    """Answer a single question against the company's prospectus text file."""
    try:
        temp_q_list = question.split()
        temp_q_tokens = list()
        for word in temp_q_list:
            temp_q_tokens_add = tokenizer(word)
            temp_q_tokens_add = temp_q_tokens_add['input_ids']
            for word_add in temp_q_tokens_add:
                temp_q_tokens.append(word_add)
        C_temp_q_tokens = Counter(temp_q_tokens)
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        content = content.replace(' ', '')
        text_list = text_split(content)
        t = copy.deepcopy(text_list)
        sim_list = list()
        for text in text_list:
            text1 = text
            for bd in bd_list:
                text1 = text1.replace(bd, ' ')
            sim = text_similarity(text1, C_temp_q_tokens)
            sim_list.append(sim)
        # rank chunks by similarity and keep the top 5 as context
        sorted_indices = sorted(enumerate(sim_list), key=lambda x: x[1], reverse=True)
        top_texts = [t[index] for index, _ in sorted_indices[:5]]
        # prompt = ChatPromptTemplate.from_template(
        #     "你是一个能精准提取文本信息并回答问题的AI。\n"
        #     "请根据以下资料的所有内容,首先帮我判断能否依据给定材料回答出问题。"
        #     "如果能根据给定材料回答,则提取出最合理的答案来回答问题,并回答出完整内容,不要输出表格:\n\n"
        #     "{text}\n\n"
        #     "请根据以上材料回答:{q}\n\n"
        #     "请按以下格式输出:\n"
        #     "能否根据给定材料回答问题:回答能或否\n"
        #     "答案:").format_messages(q=question, text="\n".join(top_texts))
        prompt = ChatPromptTemplate.from_template(
            "你是一个能精准提取文本信息并回答问题的AI。\n"
            "下面是一段资料,不要计算,不要计算,直接从资料中寻找问题的答案,使用完整的句子回答问题。\n "
            "如果资料不包含问题的答案,回答“不知道。”如果从资料无法得出问题的答案,回答“不知道。”如果答案未在资料中说明,回答“不知道。”如果资料与问题无关或者在资料中找不到问题的答案,回答“不知道。”如果资料没有明确说明问题答案,回答“不知道。”资料:\n\n"
            "{text}\n\n"
            "请根据以上材料回答:{q}\n\n"
            "答案:").format_messages(q=question, text="\n".join(top_texts))
        response = tongyi(prompt[0].content, temperature=0.01, top_p=0.5)
        return (response, top_texts)
    except Exception as e:
        print(f"Error processing question: {e}")
        return ('N_A', [])  # return a safe pair so the caller's tuple unpacking does not fail
print('C00_Started')
for cyc in range(1000):
    temp_q = q_file[cyc:cyc + 1]['问题'][cyc]
    temp_class = q_file[cyc:cyc + 1]['分类'][cyc]
    temp_e = q_file[cyc:cyc + 1]['对应实体'][cyc]
    print(cyc)
    if temp_e == 'N_A':
        csvwriter.writerow([q_file[cyc:cyc + 1]['问题id'][cyc],
                            q_file[cyc:cyc + 1]['问题'][cyc],
                            'N_A', 'N_A', 'N_A', 'N_A', 'N_A'])
        continue
    else:
        if '\n' in temp_e:
            temp_e = temp_e.replace('\n', '')
        print(f'问题:{temp_q}')
        print(f'分类:{temp_class}')
        print(f'对应实体:{temp_e}')
        temp_text_name = q_file[cyc:cyc + 1]['csv文件名'][cyc]
        print(f'csv文件名:{temp_text_name}')
        # map the company's csv file name to its prospectus txt file
        temp_text_name = temp_text_name.replace('PDF.csv', '')
        temp_text_name = temp_text_name + "txt"
        temp_csv_dir = text_file_dir + '/' + temp_text_name
        print(f'csv文件名[转换后]:{temp_csv_dir}')
        temp_q = temp_q.replace(' ', '')
        temp_q = temp_q.replace(temp_e, ' ')
        # strip "截至" and explicit dates so the matching focuses on the informative part of the question
        str1_list = re.findall(pattern1, temp_q)
        str2_list = re.findall(pattern2, temp_q)
        for word in str1_list:
            temp_q = temp_q.replace(word, '')
        for word in str2_list:
            temp_q = temp_q.replace(word, '')
        for word in stopword_list:
            temp_q = temp_q.replace(word, ' ')
        print(f'问题[标准化后]:{temp_q}')
        FA, top_text = process_text_question(temp_q, temp_e, temp_csv_dir)
        print('答案如下:')
        print(FA)
        print("-----------------------------------------------------------")
        csvwriter.writerow([q_file[cyc:cyc + 1]['问题id'][cyc],
                            q_file[cyc:cyc + 1]['问题'][cyc],
                            temp_q, temp_e, temp_text_name, FA, json.dumps(top_text, ensure_ascii=False)])
g.close()