自制RAG工具:docx文档读取工具
- 1. 介绍
- 2. 源码
- 2.1 chunk
- 2.2 DocReader
- 3. 使用方法
- 3.1 文档格式设置
- 3.2 代码使用方法
1. 介绍
在RAG相关的工作中,经常会用到读取docx文档的功能,为了更好地管理文档中的各个分块,以提供更高质量的prompt给LLM,我在去年实现了一个轻量好用的docx管理工具。
主要应用到python模块docx。安装依赖:
python-docx 1.0.1
2. 源码
代码结构非常简单,仅有两个类构成。以及需要引用的部分:
import docx
from uuid import uuid4
from typing import *
ZH2NUM = {'一': 1, '二': 2, '三': 3, '四': 4, '五': 5, '六': 6, '七': 7, '八': 8, '九': 9, '十': 10}
2.1 chunk
chunk类指的是文档中一个分块,考虑到LLM的长度限制问题和成本问题,通常需要对文档进行分块处理,尤其是对于篇幅很长的文档,需要在文档内部再做一次召回。
class Chunk:
"""
文本块
---------------
ver: 2023-11-02
by: changhongyu
"""
def __init__(self, id_: str, level: int, content: str, children: List = None, max_depth: int = 99):
"""
:param id_: 此文本块的唯一id
:param level: 此文本块的层级
:param content: 此文本块的内容
:param children: 此文本块的所有下一级文本块
:param max_depth: 允许存在的文本块最大层级数,由DocReader控制
"""
self.id = id_
self.level = level
self.content = content
self.children = children
self.max_depth = max_depth
if not self.children:
self.children = []
self.path_to_this_chunk = None
self.title_path = None
def __len__(self):
return len(self.content)
def __str__(self):
msg = ''
msg += f'[{self.level}]'
if self.level >= 99:
msg += ' ' * self.max_depth
else:
msg += ' ' * (self.level - 1)
if len(self.content) < 20:
msg += self.content
else:
msg += self.content[:20]
msg += '...'
msg += '\n'
for child in self.children:
msg += str(child)
return msg
2.2 DocReader
文档读取器,直接传入一篇文档的地址以实例化。
class DocxReader:
"""
读取一篇docx文档,并转化为结构化格式
---------------
ver: 2023-11-02
by: changhongyu
"""
def __init__(self, doc_path: str, doc_name: str = None, filter_none_text: bool = True):
"""
:param doc_path: 文件路径
:param doc_name: 文档名称,如空则用路径为名
:param filter_none_text: 过滤掉非文本的内容
"""
self.doc = docx.Document(doc_path)
self.doc_name = doc_name if doc_name else doc_path
max_depth = max(self.style2level(para.style.name) for para in self.doc.paragraphs)
self.chunks = [self.para2chunk(para, max_depth=max_depth) for para in self.doc.paragraphs if para.text]
if filter_none_text:
self.chunks = [chunk for chunk in self.chunks if chunk.level <= 10]
# 合并level==10的内容部分
self.chunks = self.combine_chunks()
self.id2chunks = {chunk.id: chunk for chunk in self.chunks}
self.doc_tree = Chunk(self.doc_name, 0, self.doc_name, self.build_tree())
for chunk in self.chunks:
_ = self.get_path_to(chunk.id)
def __len__(self):
return len(self.chunks)
def __getitem__(self, idx):
return self.id2chunks[idx]
@classmethod
def style2level(cls, style_name):
if '级' in style_name:
num = style_name.split('级')[0]
if num in ZH2NUM:
return ZH2NUM[num]
else:
return int(num)
if '标题' in style_name:
return int(style_name.strip().split('标题')[-1])
elif '正文' in style_name or 'Normal' in style_name:
return 10
elif 'Heading' in style_name:
return int(style_name.split('Heading')[-1])
elif '图' in style_name or 'Caption' in style_name:
return 11
else:
return 12
def para2chunk(self, para, max_depth):
return Chunk(self.doc_name + '**' + str(uuid4()), self.style2level(para.style.name), para.text, [], max_depth)
def combine_chunks(self):
"""
将同一级目录下的数据合并成大一点的分块
只对内容块进行操作,不对标题块进行操作
"""
combined_chunks = []
prev_level = -1
cur_chunk = None
for chunk in self.chunks:
if chunk.level < 10:
if cur_chunk is not None:
# 如果当前存在分块,则将其添加
combined_chunks.append(cur_chunk)
cur_chunk = None
# 如果是标题,则无需对这个分块进行处理
combined_chunks.append(chunk)
elif chunk.level == 10:
# 如果是内容,则判断上一个分块是否也是内容
if prev_level == 10:
# 如果是连续的内容,则合并
cur_chunk.content += f'\n{chunk.content}'
else:
# 如果是第一次遇到内容,则创建要给新的
cur_chunk = Chunk(self.doc_name + '**' + str(uuid4()), 10, chunk.content, [], 10)
else:
continue
prev_level = chunk.level
return combined_chunks
def build_tree(self):
"""
建立树状结构
"""
def build_subtree(chunks_):
if not chunks_:
return []
root = chunks_[0]
remaining_nodes = chunks_[1:]
# 找到当前节点的子节点
child_nodes = []
while remaining_nodes and remaining_nodes[0].level > root.level:
child_nodes.append(remaining_nodes.pop(0))
# 递归构建子树
root.children = build_subtree(child_nodes)
return [root] + build_subtree(remaining_nodes)
return build_subtree(self.chunks)
def get_path_to(self, id_: str):
"""
给定一个文本块的id,获取从根节点到该块的路径
"""
def trace_back(chunk: Chunk, cur_path: List, target_id: str):
if chunk is None:
return
# 如果找到了直接返回
cur_path.append(chunk)
if chunk.id == target_id:
return cur_path
# dfs子树
for child in chunk.children:
path_ = trace_back(child, cur_path, target_id)
if path_ is not None:
return path_
# 如果没有找到,回溯
cur_path.pop()
return
# 先检查这个路径有没有已经计算过
if self[id_].path_to_this_chunk is not None:
return self[id_].path_to_this_chunk
path_chunks = trace_back(self.doc_tree, [], id_)
for i in range(1, len(path_chunks)):
# 跟新路径中的所有chunk的路径
if path_chunks[i].path_to_this_chunk is None:
path_chunks[i].path_to_this_chunk = [chunk.id for chunk in path_chunks[: i]]
path_id = [chunk.id for chunk in path_chunks]
title_path = '_'.join(chunk.content for chunk in path_chunks if chunk.level < 10)
if not title_path:
title_path = self.doc_name
self[id_].path_to_this_chunk = path_id
self[id_].title_path = title_path
return path_id
3. 使用方法
3.1 文档格式设置
这个工具并不是任何文档都能使用的,由于是基于docx模块来确定每个段落的层级的,所以需要在文档中,把文档格式设置成正确的格式。这个过程目前没有想到特别好的自动化实现的方法。最好是在编辑文档的时候就注意一下格式规范。
以WPS为例,需要在这里选中对应的格式:
3.2 代码使用方法
本工具的功能主要包括如下:
(1)创建
doc_item = DocxReader('./测试.docx', doc_name='测试')
(2)打印结构
print(doc_item.doc_tree)
# [0]测试
# [1]标题1
# [2] 标题1-1
# [3] 标题1-1-1
# [10] 111内容内容内容内容
# [2] 标题1-2
# [3] 标题1-2-1
# [10] 121内容内容内容内容
# [3] 标题1-2-2
(3)id索引
doc_item.id2chunks
# {'测试**a7f1dba1-89fc-4861-b9b4-2035c1d136f7': <__main__.Chunk at 0x7f235145d750>,
# '测试**5dc92bc5-e856-47e7-a67e-8fff6b496614': <__main__.Chunk at 0x7f235145d6f0>,
# '测试**7dec18a1-b50d-47d8-808d-250f16f240e9': <__main__.Chunk at 0x7f235145c4c0>,
# '测试**2ab5291e-b2a1-4d04-a19b-37e9fceb56f9': <__main__.Chunk at 0x7f235145cfa0>,
# '测试**e3eb2029-6493-40df-803a-5c80f2ce7040': <__main__.Chunk at 0x7f235145cee0>,
# '测试**4d081a9e-b689-47aa-a75a-6786f90a111e': <__main__.Chunk at 0x7f235145c460>,
# '测试**07dc5c3f-7fd0-4cce-8204-1fe7da847e68': <__main__.Chunk at 0x7f235145c0d0>,
# '测试**f0bcc14a-7bdd-4899-8691-fca284772a9b': <__main__.Chunk at 0x7f235145c160>}
doc_item.id2chunks['测试**a7f1dba1-89fc-4861-b9b4-2035c1d136f7'].children
# [<__main__.Chunk at 0x7f235145d6f0>, <__main__.Chunk at 0x7f235145cee0>]
doc_item.id2chunks['测试**a7f1dba1-89fc-4861-b9b4-2035c1d136f7'].children[0].content
# '标题1-1'
(4)按照所有父级标题拼接当前标题
这个功能是为了给当前文本内容(正文部分,level=10),生成一个综合了之前所有标题信息的当前标题,用于向量化检索。
doc_item.chunks[4].title_path
# '测试_标题1_标题1-2'
(5)获取路径
类似于4中拼接标题,也可以获取从根节点到当前chunk的路径,返回结果是路径中所有id构成的列表:
doc_item.chunks[4].path_to_this_chunk
# ['测试',
# '测试**a7f1dba1-89fc-4861-b9b4-2035c1d136f7',
# '测试**e3eb2029-6493-40df-803a-5c80f2ce7040']
(6)更多自定义扩展应用
使用者可以根据自己的实际需求进一步开发该工具的使用,例如获取某个层级的所有分块:
level_2 = [chunk for chunk in doc_item.chunks if chunk.level == 2]
level_2
# [<__main__.Chunk at 0x7f235145d6f0>, <__main__.Chunk at 0x7f235145cee0>]
以上就是本文的全部内容,如果对你有所帮助,记得点一个免费的赞。