NLP_Knowledge Graph_Knowledge-Graph QA in Practice

Contents

  • Knowledge-graph QA
    • NER
    • Aho-Corasick automaton
    • Entity linking
    • Entity disambiguation
  • Multi-hop QA
    • neo4j_graph execution flow
    • Structure diagram
      • company_data
  • Code and data
    • Start the Neo4j graph database first
    • import_data
    • create_question_data
    • data_process
    • ac_automaton
    • torch_utils
    • text_cnn
    • train
    • main
  • Summary


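Knowledge-graph QA

Knowledge-graph question answering covers several kinds of questions: looking up the tail entity from a head entity and a relation, looking up the relation between entities, and even multi-hop questions. Each kind calls for a slightly different method. We start with the simplest case: given a head entity and a relation, find the tail entity.

1. Find the entity and the relation in the question. This can be done with NER in BIO format, or directly with classification.
2. Entity linking: if several entities share the same name, they must be disambiguated.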

NER

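There are many NER approaches today, but most share the same basic structure: an encoder (for example a BiLSTM or BERT) followed by a CRF layer.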
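With BIO tagging, each token is labeled B-X (begins an entity of type X), I-X (inside it), or O (outside any entity). Whatever encoder produces the tags, a small decoding step turns them back into entity spans. A minimal sketch, assuming character-level tokens and tag names such as `B-PER`; `bio_to_spans` and the toy sentence are hypothetical, and a stray `I-X` that does not continue an open span is leniently treated as the start of a new one:

```python
def bio_to_spans(tokens, tags):
    """Decode BIO tags into (text, type, start, end) spans; end is exclusive."""
    spans = []
    start, etype = None, None
    # The sentinel "O" closes a span that runs to the end of the sentence
    for i, tag in enumerate(list(tags) + ["O"]):
        continues = tag.startswith("I-") and start is not None and tag[2:] == etype
        if continues:
            continue
        if start is not None:
            spans.append(("".join(tokens[start:i]), etype, start, i))
        # B-X (or an unmatched I-X) opens a new span; O closes without opening one
        start, etype = (i, tag[2:]) if tag != "O" else (None, None)
    return spans


tokens = list("张三的年龄")
tags = ["B-PER", "I-PER", "O", "B-ATTR", "I-ATTR"]
print(bio_to_spans(tokens, tags))  # [('张三', 'PER', 0, 2), ('年龄', 'ATTR', 3, 5)]
```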

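Aho-Corasick automaton

1. Build a prefix tree (trie) from the patterns.
2. Add fail pointers to the trie.
For a node i on the first level, the fail pointer points to the root. Otherwise, go to the node that the fail pointer of i's parent points to: if that node has a child with the same character as i, i's fail pointer points to that child; if not, keep following fail pointers until such a child is found, or fall back to the root.

Example patterns: she, he, say, shr, her
Text to match: yasherhs
Scanning the text reports three matches: she, he, and her.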
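The two steps above can be sketched in plain Python to see how the fail pointers work. This is a teaching sketch, not the pyahocorasick library used in the code later; `build_automaton` and `search` are hypothetical names. It runs the example patterns against yasherhs:

```python
from collections import deque


def build_automaton(patterns):
    """Build an Aho-Corasick automaton: a trie plus fail pointers and output lists."""
    goto = [{}]   # goto[node] maps a character to a child node id; node 0 is the root
    fail = [0]    # fail[node] is the node to fall back to on a mismatch
    out = [[]]    # out[node] lists every pattern ending at node (including via fail links)

    # Step 1: insert every pattern into the trie
    for pattern in patterns:
        node = 0
        for ch in pattern:
            if ch not in goto[node]:
                goto.append({})
                fail.append(0)
                out.append([])
                goto[node][ch] = len(goto) - 1
            node = goto[node][ch]
        out[node].append(pattern)

    # Step 2: set fail pointers breadth-first; first-level nodes keep fail = root (0)
    queue = deque(goto[0].values())
    while queue:
        node = queue.popleft()
        for ch, child in goto[node].items():
            # Follow the parent's fail chain until a node with a `ch` child is found, or the root
            f = fail[node]
            while f and ch not in goto[f]:
                f = fail[f]
            fail[child] = goto[f].get(ch, 0)
            out[child] += out[fail[child]]
            queue.append(child)
    return goto, fail, out


def search(text, automaton):
    """Return (start, end, pattern) for every match; end is inclusive, as in pyahocorasick."""
    goto, fail, out = automaton
    node, hits = 0, []
    for i, ch in enumerate(text):
        while node and ch not in goto[node]:
            node = fail[node]
        node = goto[node].get(ch, 0)
        for pattern in out[node]:
            hits.append((i - len(pattern) + 1, i, pattern))
    return hits


hits = search("yasherhs", build_automaton(["she", "he", "say", "shr", "her"]))
print(hits)  # [(2, 4, 'she'), (3, 4, 'he'), (3, 5, 'her')]
```

The fail link from the e of she to the e of he is what lets a single left-to-right scan report he together with she, and then continue into her without restarting.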

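Entity linking

Entity linking consists of two steps: candidate entity generation and entity disambiguation.

Once the candidate entities have been found, the next step is entity disambiguation.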
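Candidate generation is often a dictionary lookup: a table maps each surface form (full name, abbreviation, alias) to the graph entities it may refer to, with fuzzy matching as a fallback for mentions that are not in the table. A minimal sketch; the alias table, the entity IDs, and `generate_candidates` are hypothetical, and the standard library's `difflib.get_close_matches` stands in for whatever fuzzy matcher a real system would use:

```python
import difflib

# Hypothetical alias table: surface form -> candidate entity IDs in the graph
ALIASES = {
    "苹果": ["apple_inc", "apple_fruit"],
    "苹果公司": ["apple_inc"],
}


def generate_candidates(mention, aliases=ALIASES, cutoff=0.6):
    """Exact alias lookup first; otherwise collect entities of the closest aliases."""
    if mention in aliases:
        return list(aliases[mention])
    close = difflib.get_close_matches(mention, list(aliases), n=3, cutoff=cutoff)
    result = []
    for alias in close:
        for entity_id in aliases[alias]:
            if entity_id not in result:  # keep order, drop duplicates
                result.append(entity_id)
    return result


print(generate_candidates("苹果"))
```

An ambiguous mention such as 苹果 returns more than one candidate, which is exactly the case the disambiguation step has to resolve.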

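Entity disambiguation

For entity disambiguation we use a matching approach, for example:
1. use a Siamese network to compute similarity;
2. embed the question and each candidate, then compute cosine similarity.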
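Option 2 can be sketched with cosine similarity. Here character-count vectors stand in for learned embeddings (or for the Siamese encoder of option 1), and each candidate is represented by a short description text; the descriptions and `disambiguate` are hypothetical:

```python
from collections import Counter

import numpy as np


def char_vector(text, vocab):
    """Bag-of-characters vector over a fixed character list (a stand-in for an embedding)."""
    counts = Counter(text)
    return np.array([counts[c] for c in vocab], dtype=float)


def disambiguate(question, candidates):
    """candidates: dict entity_id -> description; returns (best_id, cosine similarity)."""
    vocab = sorted(set(question).union(*candidates.values()))
    q = char_vector(question, vocab)
    best_id, best_score = None, -1.0
    for entity_id, description in candidates.items():
        v = char_vector(description, vocab)
        denom = np.linalg.norm(q) * np.linalg.norm(v)
        score = float(q @ v / denom) if denom else 0.0
        if score > best_score:
            best_id, best_score = entity_id, score
    return best_id, best_score


print(disambiguate("苹果手机的价格", {"apple_inc": "苹果公司手机电脑科技",
                                      "apple_fruit": "水果果树果园"}))
```

With real embeddings only `char_vector` changes; the ranking by cosine similarity stays the same.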

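Multi-hop QA

A multi-hop question chains several relations, for example 浙江东阳东欣房地产开发有限公司的客户的供应商 ("the supplier of the customer of 浙江东阳东欣房地产开发有限公司"). The main script answers the first hop, substitutes that answer into the rest of the question, and repeats until no relation is left.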
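The chaining can be sketched without Neo4j by storing (subject, relation) -> object triples in a dictionary. The relation names 客户 (customer), 供应商 (supplier), and 担保 (guarantee) come from the data set, but the companies A公司/B公司/C公司 and the helpers `relations_in_order` and `multi_hop` are hypothetical. Relations are read left to right, preferring the longest name at each position, and each answer becomes the subject of the next hop:

```python
# Hypothetical toy graph: (subject, relation) -> object
TRIPLES = {
    ("A公司", "客户"): "B公司",
    ("B公司", "供应商"): "C公司",
}


def relations_in_order(question, relation_names):
    """Scan left to right and collect relation names in order, longest match first."""
    names = sorted(relation_names, key=len, reverse=True)
    found, i = [], 0
    while i < len(question):
        for name in names:
            if question.startswith(name, i):
                found.append(name)
                i += len(name)
                break
        else:
            i += 1  # no relation starts here
    return found


def multi_hop(subject, relations, triples=TRIPLES):
    """Follow the relations in order; return None as soon as a hop has no answer."""
    current = subject
    for relation in relations:
        current = triples.get((current, relation))
        if current is None:
            return None
    return current


rels = relations_in_order("A公司的客户的供应商", ["客户", "供应商", "担保"])
print(rels, multi_hop("A公司", rels))
```

In the main script each hop is a Cypher query instead of a dictionary lookup.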

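neo4j_graph execution flow

1. First run the import_data.py script to import the data under company_data into Neo4j.

2. Run the gnn/saint.py script for node classification.

3. The company.csv file holds the attributes of each node.

Structure diagram

[Figure: structure diagram]

company_data

[Figure: company_data]

A few example screenshots (note: the data is fictitious and for learning purposes only):
[Figures: three example screenshots]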

Code and data

Start the Neo4j graph database first

Steps: press Win+R, type cmd, then run neo4j.bat console

import_data

import os
import py2neo
from py2neo import Node, Subgraph, Graph, Relationship, NodeMatcher
from tqdm import tqdm
import pandas as pd
import numpy as np

# Alternative connection settings (uncomment the one that matches your Neo4j setup):
# graph = Graph("http://127.0.0.1:7474", auth=("neo4j", "qwer"))
# graph = Graph('bolt://localhost:7687', auth=("neo4j", "password"), secure=True)

default_host = os.environ.get("STELLARGRAPH_NEO4J_HOST")

# Create the Neo4j Graph database object; the arguments can be edited to specify location and authentication
graph = py2neo.Graph(host=default_host, port=7687, user='neo4j', password='qwer')

def import_company():
    df = pd.read_csv('company_data/公司.csv')
    eid = df['eid'].values
    name = df['companyname'].values

    nodes = []
    data = list(zip(eid, name))
    for eid, name in tqdm(data):
        profit = np.random.randint(100000, 100000000, 1)[0]
        node = Node('company', name=name, profit=int(profit), eid=eid)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_person():
    df = pd.read_csv('company_data/人物.csv')
    pid = df['personcode'].values
    name = df['personname'].values

    nodes = []
    data = list(zip(pid, name))
    for pid_value, name in tqdm(data):
        age = np.random.randint(20, 70, 1)[0]
        node = Node('person', name=name, age=int(age), pid=str(pid_value))
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_industry():
    df = pd.read_csv('company_data/行业.csv')
    names = df['orgtype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('industry', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_assign():
    df = pd.read_csv('company_data/分红.csv')
    names = df['schemetype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('assign', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_violations():
    df = pd.read_csv('company_data/违规类型.csv')
    names = df['gooltype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('violations', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


def import_bond():
    df = pd.read_csv('company_data/债券类型.csv')
    names = df['securitytype'].values

    nodes = []
    for name in tqdm(names):
        node = Node('bond', name=name)
        nodes.append(node)

    graph.create(Subgraph(nodes))


# def import_dishonesty():
#     node = Node('dishonesty', name='失信')
#     graph.create(node)


def import_relation():
    df = pd.read_csv('company_data/公司-人物.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    pid = df['pid'].values
    post = df['post'].values
    relations = []
    data = list(zip(eid, pid, post))
    for e, p, po in tqdm(data):
        company = matcher.match('company', eid=e).first()
        person = matcher.match('person', pid=str(p)).first()
        if company is not None and person is not None:
            relations.append(Relationship(company, po, person))

    graph.create(Subgraph(relationships=relations))
    print('import company-person relation succeeded')

    df = pd.read_csv('company_data/公司-行业.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['industry'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        industry = matcher.match('industry', name=str(n)).first()
        if company is not None and industry is not None:
            relations.append(Relationship(company, '行业类型', industry))

    graph.create(Subgraph(relationships=relations))
    print('import company-industry relation succeeded')

    df = pd.read_csv('company_data/公司-分红.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['assign'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        assign = matcher.match('assign', name=str(n)).first()
        if company is not None and assign is not None:
            relations.append(Relationship(company, '分红方式', assign))

    graph.create(Subgraph(relationships=relations))
    print('import company-assign relation succeeded')

    df = pd.read_csv('company_data/公司-违规.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['violations'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        violations = matcher.match('violations', name=str(n)).first()
        if company is not None and violations is not None:
            relations.append(Relationship(company, '违规类型', violations))

    graph.create(Subgraph(relationships=relations))
    print('import company-violations relation succeeded')

    df = pd.read_csv('company_data/公司-债券.csv')
    matcher = NodeMatcher(graph)
    eid = df['eid'].values
    name = df['bond'].values
    relations = []
    data = list(zip(eid, name))
    for e, n in tqdm(data):
        company = matcher.match('company', eid=e).first()
        bond = matcher.match('bond', name=str(n)).first()
        if company is not None and bond is not None:
            relations.append(Relationship(company, '债券类型', bond))

    graph.create(Subgraph(relationships=relations))
    print('import company-bond relation succeeded')

    # df = pd.read_csv('company_data/公司-失信.csv')
    # matcher = NodeMatcher(graph)
    # eid = df['eid'].values
    # rel = df['dishonesty'].values
    # relations = []
    # data = list(zip(eid, rel))
    # for e, r in tqdm(data):
    #     company = matcher.match('company', eid=e).first()
    #     dishonesty = matcher.match('dishonesty', name='失信').first()
    #     if company is not None and dishonesty is not None:
    #         if pd.notna(r):
    #             if int(r) == 0:
    #                 relations.append(Relationship(company, '无', dishonesty))
    #             elif int(r) == 1:
    #                 relations.append(Relationship(company, '有', dishonesty))
    #
    # graph.create(Subgraph(relationships=relations))
    # print('import company-dishonesty relation succeeded')


def import_company_relation():
    df = pd.read_csv('company_data/公司-供应商.csv')
    matcher = NodeMatcher(graph)
    eid1 = df['eid1'].values
    eid2 = df['eid2'].values
    relations = []
    data = list(zip(eid1, eid2))
    for e1, e2 in tqdm(data):
        if pd.notna(e1) and pd.notna(e2) and e1 != e2:
            company1 = matcher.match('company', eid=e1).first()
            company2 = matcher.match('company', eid=e2).first()

            if company1 is not None and company2 is not None:
                relations.append(Relationship(company1, '供应商', company2))

    graph.create(Subgraph(relationships=relations))
    print('import company-supplier relation succeeded')

    df = pd.read_csv('company_data/公司-担保.csv')
    matcher = NodeMatcher(graph)
    eid1 = df['eid1'].values
    eid2 = df['eid2'].values
    relations = []
    data = list(zip(eid1, eid2))
    for e1, e2 in tqdm(data):
        if pd.notna(e1) and pd.notna(e2) and e1 != e2:
            company1 = matcher.match('company', eid=e1).first()
            company2 = matcher.match('company', eid=e2).first()

            if company1 is not None and company2 is not None:
                relations.append(Relationship(company1, '担保', company2))

    graph.create(Subgraph(relationships=relations))
    print('import company-guarantee relation succeeded')

    df = pd.read_csv('company_data/公司-客户.csv')
    matcher = NodeMatcher(graph)
    eid1 = df['eid1'].values
    eid2 = df['eid2'].values
    relations = []
    data = list(zip(eid1, eid2))
    for e1, e2 in tqdm(data):
        if pd.notna(e1) and pd.notna(e2):
            company1 = matcher.match('company', eid=e1).first()
            company2 = matcher.match('company', eid=e2).first()

            if company1 is not None and company2 is not None:
                relations.append(Relationship(company1, '客户', company2))

    graph.create(Subgraph(relationships=relations))
    print('import company-customer relation succeeded')


def delete_relation():
    cypher = 'match ()-[r]-() delete r'
    graph.run(cypher)


def delete_node():
    cypher = 'match (n) delete n'
    graph.run(cypher)


def import_data():
    import_company()
    import_company_relation()

    import_person()
    import_industry()
    import_assign()
    import_violations()
    import_bond()
    # import_dishonesty()

    import_relation()


def delete_data():
    delete_relation()
    delete_node()
    print('delete data succeeded')


if __name__ == '__main__':
    # Clear the existing graph first, then import all nodes and relationships again
    delete_data()
    import_data()

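In import_data, every row triggers two `matcher.match(...).first()` calls, that is, two separate queries against Neo4j. One alternative, sketched here rather than taken from the original code, is to fetch each node type once into a dictionary keyed by its ID and resolve the pairs in memory. The helpers `is_missing` and `build_edges` are hypothetical; plain strings stand in for py2neo nodes so the sketch runs without a database, and `allow_self_loops` mirrors the fact that the 公司-客户 block, unlike the other two company-company blocks, does not drop pairs with `e1 == e2`:

```python
import math


def is_missing(value):
    """True for None or NaN (pandas reads empty CSV cells as NaN)."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def build_edges(pairs, nodes_by_eid, rel_type, allow_self_loops=False):
    """Turn (eid1, eid2) pairs into (node1, rel_type, node2) triples using a prebuilt lookup."""
    edges = []
    for e1, e2 in pairs:
        if is_missing(e1) or is_missing(e2):
            continue
        if e1 == e2 and not allow_self_loops:
            continue
        n1, n2 = nodes_by_eid.get(e1), nodes_by_eid.get(e2)
        if n1 is not None and n2 is not None:  # skip IDs that have no node
            edges.append((n1, rel_type, n2))
    return edges


nodes = {"e1": "N1", "e2": "N2"}  # stand-ins for py2neo Node objects
pairs = [("e1", "e2"), ("e1", "e1"), ("e1", float("nan")), ("e1", "e9")]
print(build_edges(pairs, nodes, "供应商"))
```

With py2neo, the lookup could be built as `{n['eid']: n for n in NodeMatcher(graph).match('company')}`, and each triple wrapped in `Relationship(n1, rel_type, n2)` before `graph.create(Subgraph(relationships=...))`.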

create_question_data

from py2neo import Graph
import numpy as np
import pandas as pd

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))

# import os
# import py2neo
# default_host = os.environ.get("STELLARGRAPH_NEO4J_HOST")
# graph = py2neo.Graph(host=default_host, port=7687, user='neo4j', password='qwer')

def create_attribute_question():
    company = graph.run('MATCH (n:company) RETURN n.name as name').to_ndarray()
    person = graph.run('MATCH (n:person) RETURN n.name as name').to_ndarray()

    questions = []

    for c in company:
        c = c[0].strip()
        question = f"{c}的收益"
        questions.append(question)
        question = f"{c}的收入"
        questions.append(question)

    for p in person:
        p = p[0].strip()
        question = f"{p}的年龄是几岁"
        questions.append(question)
        question = f"{p}多大"
        questions.append(question)
        question = f"{p}几岁"
        questions.append(question)

    return questions


def create_entity_question():
    questions = []

    for _ in range(250):
        for op in ['大于', '等于', '小于', '是', '有']:
            profit = np.random.randint(10000, 10000000, 1)[0]
            question = f"收益{op}{profit}的公司有哪些"
            questions.append(question)
            profit = np.random.randint(10000, 10000000, 1)[0]
            question = f"哪些公司收益{op}{profit}"
            questions.append(question)

    for _ in range(250):
        for op in ['大于', '等于', '小于', '是', '有']:
            profit = np.random.randint(20, 60, 1)[0]
            question = f"年龄{op}{profit}的人有哪些"
            questions.append(question)
            profit = np.random.randint(20, 60, 1)[0]
            question = f"哪些人年龄{op}{profit}"
            questions.append(question)

    return questions


def create_relation_question():
    relation = graph.run('MATCH (n)-[r]->(m) RETURN n.name as name, type(r) as r').to_ndarray()

    questions = []

    for r in relation:
        if str(r[1]) in ['董事', '监事']:
            question = f"{r[0]}{r[1]}是谁"
            questions.append(question)
        else:
            question = f"{r[0]}{r[1]}"
            questions.append(question)
            question = f"{r[0]}{r[1]}是啥"
            questions.append(question)
            question = f"{r[0]}{r[1]}什么"
            questions.append(question)

    return questions


q1 = create_entity_question()
q2 = create_attribute_question()
q3 = create_relation_question()

df = pd.DataFrame()
df['question'] = q1 + q2 + q3
df['label'] = [0] * len(q1) + [1] * len(q2) + [2] * len(q3)

df.to_csv('question_classification.csv', encoding='utf_8_sig', index=False)

data_process

import pandas as pd
import jieba
from collections import defaultdict
import numpy as np
import os

# Directory containing this module; vocab.txt and question_classification.csv are kept next to it
path = os.path.dirname(os.path.abspath(__file__))


def tokenize(text, use_jieba=True):
    if use_jieba:
        res = list(jieba.cut(text, cut_all=False))
    else:
        res = list(text)
    return res


# Build the vocabulary
def build_vocab(del_word_frequency=0):
    data = pd.read_csv(os.path.join(path, 'question_classification.csv'))
    segment = data['question'].apply(tokenize)

    word_frequency = defaultdict(int)
    for row in segment:
        for i in row:
            word_frequency[i] += 1

    word_sort = sorted(word_frequency.items(), key=lambda x: x[1], reverse=True)  # sort by word frequency, descending

    # Index 0 is [PAD] and index 1 is [UNK]; keep only words seen more than del_word_frequency times
    with open(os.path.join(path, 'vocab.txt'), 'w', encoding='utf-8') as f:
        f.write('[PAD]' + "\n" + '[UNK]' + "\n")
        for d in word_sort:
            if d[1] > del_word_frequency:
                f.write(d[0] + "\n")

# Split into training and evaluation sets (split = fraction of rows used for training)
def split_data(df, split=0.7):
    df = df.sample(frac=1)  # shuffle all rows
    n_train = int(len(df) * split)
    train_data = df[:n_train]
    eval_data = df[n_train:]

    return train_data, eval_data


vocab = {}
vocab_path = os.path.join(path, 'vocab.txt')
if os.path.exists(vocab_path):
    with open(vocab_path, encoding='utf-8') as file:
        for line in file.readlines():
            vocab[line.strip()] = len(vocab)


# Convert a sentence into a list of vocabulary indices; unknown words map to 1 ([UNK])
def seq2index(seq):
    seg = tokenize(seq)
    seg_index = []
    for s in seg:
        seg_index.append(vocab.get(s, 1))
    return seg_index


# Make every sequence the same length: pad with 0 ([PAD]) or truncate to max_len
def padding_seq(X, max_len=10):
    return np.array([
        np.concatenate([x, [0] * (max_len - len(x))]) if len(x) < max_len else x[:max_len] for x in X
    ])


if __name__ == '__main__':
    build_vocab(5)

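The vocabulary and index conventions in data_process ([PAD] = 0, [UNK] = 1, tokens sorted by descending frequency, sequences padded with 0 or truncated to `max_len`) can be checked on a toy example. This sketch is character-level so it runs without jieba (like `tokenize(text, use_jieba=False)`), keeps the vocabulary in memory instead of vocab.txt, and breaks frequency ties by character order; `build_char_vocab` and `encode` are hypothetical names:

```python
from collections import Counter


def build_char_vocab(texts, drop_freq=0):
    """[PAD]=0, [UNK]=1, then characters by descending frequency (kept if freq > drop_freq)."""
    counts = Counter(ch for text in texts for ch in text)
    vocab = {"[PAD]": 0, "[UNK]": 1}
    for ch, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq > drop_freq:
            vocab[ch] = len(vocab)
    return vocab


def encode(text, vocab, max_len=10):
    """Map characters to indices (unknown -> 1), then truncate or pad with 0 to max_len."""
    ids = [vocab.get(ch, 1) for ch in text][:max_len]
    return ids + [0] * (max_len - len(ids))


vocab = build_char_vocab(["张三几岁", "李四几岁"])
print(vocab)
print(encode("张三几岁", vocab, max_len=6), encode("王五", vocab, max_len=4))
```

Raising `drop_freq` (the equivalent of `build_vocab(5)`) shrinks the vocabulary, and the dropped tokens are then encoded as [UNK].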

ac_automaton

import ahocorasick
from py2neo import Graph

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))
company = graph.run('MATCH (n:company) RETURN n.name as name').to_ndarray()
relation = graph.run('MATCH ()-[r]-() RETURN distinct type(r)').to_ndarray()

ac_company = ahocorasick.Automaton()
ac_relation = ahocorasick.Automaton()

for row in company:
    ac_company.add_word(row[0], row[0])
for row in relation:
    ac_relation.add_word(row[0], row[0])

ac_company.make_automaton()
ac_relation.make_automaton()

# haystack = '浙江东阳东欣房地产开发有限公司的客户的供应商'
haystack = '衡水中南锦衡房地产有限公司的债券类型'
# haystack = '临沂金丰公社农业服务有限公司的分红方式'
print('question:', haystack)

subject = ''
predicate = []

# iter() yields (end_index, value) for every dictionary word found in the question
for end_index, original_value in ac_company.iter(haystack):
    start_index = end_index - len(original_value) + 1
    print('company entity:', (start_index, end_index, original_value))
    assert haystack[start_index:start_index + len(original_value)] == original_value
    subject = original_value

for end_index, original_value in ac_relation.iter(haystack):
    start_index = end_index - len(original_value) + 1
    print('relation:', (start_index, end_index, original_value))
    assert haystack[start_index:start_index + len(original_value)] == original_value
    predicate.append(original_value)

# Follow the relations in the order they appear: each hop's answer is the next hop's subject.
# The relationship is directed (subject -> object), matching how import_data created it.
res = []
for p in predicate:
    cypher = f'match (s:company)-[p:`{p}`]->(o) where s.name=$name return o.name'
    print(cypher)
    res = graph.run(cypher, name=subject).to_ndarray()
    if len(res) == 0:
        break  # this hop has no answer
    subject = res[0][0]
print('answer:', res[0][0] if len(res) > 0 else None)


torch_utils

import torch
import time
import numpy as np
import six


class TrainHandler:

    def __init__(self,
                 train_loader,
                 valid_loader,
                 model,
                 criterion,
                 optimizer,
                 model_path,
                 batch_size=32,
                 epochs=5,
                 scheduler=None,
                 gpu_num=0):
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.model_path = model_path
        self.batch_size = batch_size
        self.epochs = epochs
        self.scheduler = scheduler
        if torch.cuda.is_available():
            self.device = torch.device(f'cuda:{gpu_num}')
            print(f'Training device is gpu:{gpu_num}')
        else:
            self.device = torch.device('cpu')
            print('Training device is cpu')
        self.model = model.to(self.device)

    def _train_func(self):
        self.model.train()  # training mode: dropout is active
        train_loss = 0
        train_correct = 0
        for i, (x, y) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            x, y = x.to(self.device).long(), y.to(self.device)
            output = self.model(x)
            loss = self.criterion(output, y)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            train_correct += (output.argmax(1) == y).sum().item()

        if self.scheduler is not None:
            self.scheduler.step()

        return train_loss / len(self.train_loader), train_correct / len(self.train_loader.dataset)

    def _test_func(self):
        self.model.eval()  # evaluation mode: dropout is disabled
        valid_loss = 0
        valid_correct = 0
        for x, y in self.valid_loader:
            x, y = x.to(self.device).long(), y.to(self.device)
            with torch.no_grad():
                output = self.model(x)
                loss = self.criterion(output, y)
                valid_loss += loss.item()
                valid_correct += (output.argmax(1) == y).sum().item()

        return valid_loss / len(self.valid_loader), valid_correct / len(self.valid_loader.dataset)

    def train(self):
        min_valid_loss = float('inf')

        for epoch in range(self.epochs):
            start_time = time.time()
            train_loss, train_acc = self._train_func()
            valid_loss, valid_acc = self._test_func()

            if min_valid_loss > valid_loss:
                min_valid_loss = valid_loss
                torch.save(self.model, self.model_path)
                print(f'\tSave model done valid loss: {valid_loss:.4f}')

            secs = int(time.time() - start_time)
            mins = secs // 60
            secs = secs % 60

            print('Epoch: %d' % (epoch + 1), " | time in %d minutes, %d seconds" % (mins, secs))
            print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
            print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')


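def torch_text_process():
    # Uses the legacy torchtext Field/TabularDataset API (torchtext.legacy from 0.9, removed in 0.12).
    # It prepares LCQMC sentence pairs and is not used by the KBQA scripts.
    from torchtext import data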

    def tokenizer(text):
        import jieba
        return list(jieba.cut(text))

    TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True, fix_length=20)
    LABEL = data.Field(sequential=False, use_vocab=False)
    all_dataset = data.TabularDataset.splits(path='',
                                             train='LCQMC.csv',
                                             format='csv',
                                             fields=[('sentence1', TEXT), ('sentence2', TEXT), ('label', LABEL)])[0]
    TEXT.build_vocab(all_dataset)
    train, valid = all_dataset.split(0.1)  # split_ratio is the training fraction: 10% train, 90% valid
    (train_iter, valid_iter) = data.BucketIterator.splits(datasets=(train, valid),
                                                          batch_sizes=(64, 128),
                                                          sort_key=lambda x: len(x.sentence1))
    return train_iter, valid_iter


def pad_sequences(sequences, maxlen=None, dtype='int32',
                  padding='post', truncating='pre', value=0.):
    """Pads sequences to the same length.

    This function transforms a list of
    `num_samples` sequences (lists of integers)
    into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
    `num_timesteps` is either the `maxlen` argument if provided,
    or the length of the longest sequence otherwise.

    Sequences that are shorter than `num_timesteps`
    are padded with `value` at the end.

    Sequences longer than `num_timesteps` are truncated
    so that they fit the desired length.
    The position where padding or truncation happens is determined by
    the arguments `padding` and `truncating`, respectively.

    Post-padding is the default here (`padding='post'`); truncation defaults to `'pre'`.

    # Arguments
        sequences: List of lists, where each element is a sequence.
        maxlen: Int, maximum length of all sequences.
        dtype: Type of the output sequences.
            To pad sequences with variable length strings, you can use `object`.
        padding: String, 'pre' or 'post':
            pad either before or after each sequence.
        truncating: String, 'pre' or 'post':
            remove values from sequences larger than
            `maxlen`, either at the beginning or at the end of the sequences.
        value: Float or String, padding value.

    # Returns
        x: Numpy array with shape `(len(sequences), maxlen)`

    # Raises
        ValueError: In case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    num_samples = len(sequences)

    lengths = []
    for x in sequences:
        try:
            lengths.append(len(x))
        except TypeError:
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))

    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:
            sample_shape = np.asarray(s).shape[1:]
            break

    is_dtype_str = np.issubdtype(dtype, np.str_)  # np.unicode_ was an alias of np.str_ and is removed in NumPy 2.0
    if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
        raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
                         "You should set `dtype=object` for variable length strings."
                         .format(dtype, type(value)))

    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" '
                             'not understood' % truncating)

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s '
                             'is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x


if __name__ == '__main__':
    torch_text_process()

text_cnn

import torch
from torch import nn


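class TextCNN(nn.Module):
    """TextCNN: parallel convolutions of heights 3/4/5 over the embeddings, max-pooled over time."""

    def __init__(self, vocab_len, embedding_size, n_class):
        super().__init__()

        self.embedding = nn.Embedding(vocab_len, embedding_size)

        # Each Conv2d spans the full embedding width, so it acts as a 1-D convolution over positions:
        # (batch, 1, seq_len, emb) -> (batch, 100, seq_len - k + 1, 1)
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[3, embedding_size])
        self.cnn2 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[4, embedding_size])
        self.cnn3 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=[5, embedding_size])

        # Pool sizes assume seq_len == 10 (padding_seq's default max_len): 10-3+1=8, 10-4+1=7, 10-5+1=6
        self.max_pool1 = nn.MaxPool1d(kernel_size=8)
        self.max_pool2 = nn.MaxPool1d(kernel_size=7)
        self.max_pool3 = nn.MaxPool1d(kernel_size=6)

        self.drop_out = nn.Dropout(0.2)
        self.full_connect = nn.Linear(300, n_class)  # 3 branches x 100 channels

    def forward(self, x):
        embedding = self.embedding(x)       # (batch, seq_len, emb)
        embedding = embedding.unsqueeze(1)  # (batch, 1, seq_len, emb)

        cnn1_out = self.cnn1(embedding).squeeze(-1)  # (batch, 100, seq_len - 2)
        cnn2_out = self.cnn2(embedding).squeeze(-1)  # (batch, 100, seq_len - 3)
        cnn3_out = self.cnn3(embedding).squeeze(-1)  # (batch, 100, seq_len - 4)

        out1 = self.max_pool1(cnn1_out)  # (batch, 100, 1)
        out2 = self.max_pool2(cnn2_out)
        out3 = self.max_pool3(cnn3_out)

        out = torch.cat([out1, out2, out3], dim=1).squeeze(-1)  # (batch, 300)

        out = self.drop_out(out)
        out = self.full_connect(out)
        # Return raw logits: CrossEntropyLoss applies log-softmax itself
        # out = torch.softmax(out, dim=-1).squeeze(dim=-1)
        return out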
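The pooling kernel sizes 8, 7, and 6 in TextCNN are tied to the input length: they equal the number of positions each convolution leaves on a length-10 input (the default `max_len` of `padding_seq`), so each max-pool reduces its branch to a single value per channel. A quick check (the helper name is hypothetical):

```python
def conv_output_lengths(max_len=10, kernel_heights=(3, 4, 5)):
    """Positions left after a stride-1, unpadded Conv2d whose kernel is [k, embedding_size]."""
    return [max_len - k + 1 for k in kernel_heights]


lengths = conv_output_lengths()
print(lengths)             # [8, 7, 6] -> the MaxPool1d kernel sizes in TextCNN
print(100 * len(lengths))  # 300 -> in_features of the final nn.Linear
```

If `max_len` changes, these kernel sizes must change with it; one option is `nn.AdaptiveMaxPool1d(1)`, which pools over whatever length remains.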

train

import torch
from torch.utils.data import TensorDataset, DataLoader
from kbqa.torch_utils import TrainHandler
from kbqa.data_process import *
from kbqa.text_cnn import TextCNN

# df = pd.read_csv('question_classification.csv')
# print(df['label'].value_counts())


def load_data(batch_size=32):
    df = pd.read_csv('kbqa/question_classification.csv')
    train_df, eval_df = split_data(df)
    train_x = train_df['question']
    train_y = train_df['label']
    valid_x = eval_df['question']
    valid_y = eval_df['label']

    train_x = padding_seq(train_x.apply(seq2index))
    train_y = np.array(train_y)
    valid_x = padding_seq(valid_x.apply(seq2index))
    valid_y = np.array(valid_y)

    train_data_set = TensorDataset(torch.from_numpy(train_x),
                                   torch.from_numpy(train_y))
    valid_data_set = TensorDataset(torch.from_numpy(valid_x),
                                   torch.from_numpy(valid_y))
    train_data_loader = DataLoader(dataset=train_data_set, batch_size=batch_size, shuffle=True)
    valid_data_loader = DataLoader(dataset=valid_data_set, batch_size=batch_size, shuffle=False)

    return train_data_loader, valid_data_loader


train_loader, valid_loader = load_data(batch_size=64)

model = TextCNN(len(vocab), 256, 3)  # vocabulary size = number of lines in vocab.txt
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)
model_path = 'text_cnn.p'
handler = TrainHandler(train_loader,
                       valid_loader,
                       model,
                       criterion,
                       optimizer,
                       model_path,
                       batch_size=64,
                       epochs=5,
                       scheduler=scheduler,
                       gpu_num=0)
handler.train()


main

import torch
from kbqa.data_process import *
import ahocorasick
from py2neo import Graph
import re
import traceback


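# TrainHandler saved the whole pickled model, so the TextCNN class must be importable here.
# PyTorch >= 2.6 defaults to torch.load(..., weights_only=True); pass weights_only=False on those versions.
model = torch.load('kbqa/text_cnn.p', map_location=torch.device('cpu'))
model.eval()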

graph = Graph("http://localhost:7474", auth=("neo4j", "qwer"))
company = graph.run('MATCH (n:company) RETURN n.name as name').to_ndarray()
person = graph.run('MATCH (n:person) RETURN n.name as name').to_ndarray()
relation = graph.run('MATCH ()-[r]-() RETURN distinct type(r)').to_ndarray()

ac_company = ahocorasick.Automaton()
ac_person = ahocorasick.Automaton()
ac_relation = ahocorasick.Automaton()

for row in company:
    ac_company.add_word(row[0], row[0])
for row in person:
    ac_person.add_word(row[0], row[0])
for row in relation:
    ac_relation.add_word(row[0], row[0])
ac_relation.add_word('年龄', '年龄')
ac_relation.add_word('年纪', '年纪')
ac_relation.add_word('收入', '收入')
ac_relation.add_word('收益', '收益')

ac_company.make_automaton()
ac_person.make_automaton()
ac_relation.make_automaton()


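def classification_predict(s):
    # Question type: 0 = comparison on an attribute, 1 = attribute of one entity, 2 = relation (possibly multi-hop)
    s = seq2index(s)
    s = torch.from_numpy(padding_seq([s])).long()  # CPU inference; move both model and input with .cuda() for GPU
    with torch.no_grad():
        out = model(s)
    out = out.cpu().numpy()
    print(out)
    return out.argmax(1)[0]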


def entity_link(text):
    # Collect (name, type) for every company or person name that appears in the text
    matches = []
    for automaton, entity_type in ((ac_company, 'company'), (ac_person, 'person')):
        for end_index, original_value in automaton.iter(text):
            start_index = end_index - len(original_value) + 1
            print('entity:', (start_index, end_index, original_value))
            assert text[start_index:start_index + len(original_value)] == original_value
            matches.append((original_value, entity_type))
    if not matches:
        raise ValueError(f'no known entity in question: {text}')
    # Return the first match together with its own type
    return matches[0]


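def get_op(text):
    pattern = re.compile(r'\d+')
    num = pattern.findall(text)
    if not num:
        raise ValueError(f'no number in question: {text}')
    op = None
    if '大于' in text:
        op = '>'
    elif '小于' in text:
        op = '<'
    elif '等于' in text or '是' in text or '有' in text:
        # create_question_data also generates '有' questions; they are treated as equality here
        op = '='
    return op, float(num[0])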


def kbqa(text):
    print('*' * 100)
    cls = classification_predict(text)
    print('question type:', cls)
    res = ''

    if cls == 0:
        op, num = get_op(text)
        subject_type = ''
        attribute = ''
        for w in ['年龄', '年纪']:
            if w in text:
                subject_type = 'person'
                attribute = 'age'
                break
        for w in ['收入', '收益']:
            if w in text:
                subject_type = 'company'
                attribute = 'profit'
                break
        cypher = f'match (n:{subject_type}) where n.{attribute}{op}{num} return n.name'
        print(cypher)
        res = graph.run(cypher).to_ndarray()
    elif cls == 1:
        # Attribute query: one property of one entity
        subject, subject_type = entity_link(text)
        predicate = ''
        for w in ['年龄', '年纪']:
            if w in text and subject_type == 'person':
                predicate = 'age'
                break
        for w in ['收入', '收益']:
            if w in text and subject_type == 'company':
                predicate = 'profit'
                break
        cypher = f'match (n:{subject_type}) where n.name=$name return n.{predicate}'
        print(cypher)
        res = graph.run(cypher, name=subject).to_ndarray()
    elif cls == 2:
        # Relation question, possibly multi-hop
        subject = ''
        for end_index, original_value in ac_company.iter(text):
            start_index = end_index - len(original_value) + 1
            print('company entity:', (start_index, end_index, original_value))
            assert text[start_index:start_index + len(original_value)] == original_value
            subject = original_value
        predicate = []
        for end_index, original_value in ac_relation.iter(text):
            start_index = end_index - len(original_value) + 1
            print('relation:', (start_index, end_index, original_value))
            assert text[start_index:start_index + len(original_value)] == original_value
            predicate.append(original_value)
        for i, p in enumerate(predicate):
            cypher = f'match (s:company)-[p:`{p}`]->(o) where s.name=$name return o.name'
            print(cypher)
            res = graph.run(cypher, name=subject).to_ndarray()
            if len(res) == 0:
                break  # this hop has no answer
            subject = res[0][0]
            if i == len(predicate) - 1:
                break
            # More hops remain: rewrite the rest of the question around this answer and ask again
            new_index = text.index(p) + len(p)
            new_question = subject + str(text[new_index:])
            print('new question:', new_question)
            res = kbqa(new_question)
            break
    return res


if __name__ == '__main__':

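    while True:
        try:
            text = input('text:')
            res = kbqa(text)
            print(res)
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C or end of input: leave the interactive loop
            break
        except Exception:
            # Print the error and keep the loop running for the next question
            print(traceback.format_exc())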
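The `cls == 2` branch answers a chained question hop by hop: it looks up the first relation from the matched company, substitutes the answer entity into the rest of the question, and calls `kbqa` again. A minimal sketch of the same hop-by-hop idea as a plain loop is shown below. It assumes a hypothetical in-memory dictionary `TOY_GRAPH` in place of neo4j, hypothetical company and relation names, and plain substring search in place of the two AC automata; it illustrates the idea and is not the repository's code.

```python
# Hypothetical toy graph standing in for neo4j: (head entity, relation) -> tail entity
TOY_GRAPH = {
    ("CompanyA", "supplier"): "CompanyB",
    ("CompanyB", "customer"): "CompanyC",
}
COMPANIES = ["CompanyA", "CompanyB", "CompanyC"]  # hypothetical entity vocabulary
RELATIONS = ["supplier", "customer"]              # hypothetical relation vocabulary


def find_in_order(text, vocab):
    """Return (start, word) pairs for every vocab word found in text, sorted by position.

    Plain substring search stands in for the AC automaton used in the article."""
    hits = []
    for word in vocab:
        start = text.find(word)
        while start != -1:
            hits.append((start, word))
            start = text.find(word, start + 1)
    return sorted(hits)


def multi_hop(text):
    companies = find_in_order(text, COMPANIES)
    relations = find_in_order(text, RELATIONS)
    if not companies or not relations:
        return None
    # Like the article's loop, keep the last company matched as the starting entity
    subject = companies[-1][1]
    for _, relation in relations:
        # One hop: (subject)-[relation]->(tail); stop if the edge does not exist
        tail = TOY_GRAPH.get((subject, relation))
        if tail is None:
            return None
        subject = tail
    return subject


print(multi_hop("CompanyA's supplier's customer"))  # CompanyC
```

Unlike the article's recursive version, this loop never re-classifies the rewritten question, so it only covers pure relation chains. The recursive design has the advantage that the remaining question can fall into any other branch of `kbqa`, for example an attribute query on the intermediate entity.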


Summary of the KBQA Practice

Overall structure of the model: AC automaton + entity recognition + multi-hop question answering
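The first stage of the pipeline is the AC automaton, built in the two steps listed earlier: a prefix tree, then fail pointers. The `ac_company.iter(text)` and `ac_relation.iter(text)` calls in the code yield `(end_index, value)` pairs, which looks like the `pyahocorasick` interface, although the automaton setup is not shown in this part. As an illustration only, here is a small pure-Python sketch of the two build steps that produces the same `(end_index, pattern)` output on the she / he / say / shr / her example; `build_automaton` and `iter_matches` are hypothetical names.

```python
from collections import deque


def build_automaton(patterns):
    """Build an Aho-Corasick automaton: step 1 is a trie, step 2 adds fail pointers."""
    goto = [{}]    # goto[state][char] -> next state (the trie edges)
    fail = [0]     # fail[state] -> state of the longest proper suffix that is also in the trie
    output = [[]]  # output[state] -> patterns that end at this state
    # Step 1: insert every pattern into the trie
    for pattern in patterns:
        state = 0
        for ch in pattern:
            if ch not in goto[state]:
                goto.append({})
                fail.append(0)
                output.append([])
                goto[state][ch] = len(goto) - 1
            state = goto[state][ch]
        output[state].append(pattern)
    # Step 2: fail pointers in breadth-first order; depth-1 nodes keep fail = root (0)
    queue = deque(goto[0].values())
    while queue:
        state = queue.popleft()
        for ch, child in goto[state].items():
            queue.append(child)
            # Walk the parent's fail chain until a state with an edge on ch is found
            f = fail[state]
            while f and ch not in goto[f]:
                f = fail[f]
            fail[child] = goto[f].get(ch, 0)
            # A match at the fail target is also a match here (e.g. "he" inside "she")
            output[child] += output[fail[child]]
    return goto, fail, output


def iter_matches(automaton, text):
    """Yield (end_index, pattern) for every occurrence of every pattern in text."""
    goto, fail, output = automaton
    state = 0
    for i, ch in enumerate(text):
        # Follow fail pointers until ch can be consumed or we are back at the root
        while state and ch not in goto[state]:
            state = fail[state]
        state = goto[state].get(ch, 0)
        for pattern in output[state]:
            yield i, pattern


matches = list(iter_matches(build_automaton(["she", "he", "say", "shr", "her"]), "yasherhs"))
print(matches)  # [(4, 'she'), (4, 'he'), (5, 'her')]
```

As in the article's code, the start index of each match is `end_index - len(pattern) + 1`.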

PS: No entity name here is shared by multiple entities, so entity disambiguation is not needed; the approach used here is matching.
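If a name did match several entities, the matching approach from the entity disambiguation section would apply: embed the question and each candidate, then keep the candidate with the highest cosine similarity. The sketch below assumes hypothetical candidate ids with one-line descriptions, and uses a whitespace bag-of-words vector as a stand-in for a learned (for example Siamese-network) encoder; Chinese text would first need word segmentation or a character-level encoder.

```python
import math
from collections import Counter


def bag_of_words(text):
    """Stand-in embedding: whitespace token counts (a trained encoder would replace this)."""
    return Counter(text.split())


def cosine(u, v):
    """Cosine similarity between two sparse count vectors; 0.0 if either is empty."""
    dot = sum(u[k] * v[k] for k in u if k in v)
    norm = math.sqrt(sum(x * x for x in u.values())) * math.sqrt(sum(x * x for x in v.values()))
    return dot / norm if norm else 0.0


def disambiguate(question, candidates):
    """Return the id of the candidate whose description is most similar to the question.

    candidates: list of (entity_id, description) pairs sharing the same surface name."""
    q = bag_of_words(question)
    return max(candidates, key=lambda c: cosine(q, bag_of_words(c[1])))[0]


# Hypothetical candidates for the ambiguous mention "apple"
candidates = [("Apple_Inc", "apple company iphone"), ("Apple_fruit", "apple fruit tree")]
print(disambiguate("apple iphone sales", candidates))  # Apple_Inc
```

Only `bag_of_words` would change if a trained encoder were plugged in; the argmax over cosine similarity stays the same.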


Learning reference:
七月在线 (July Online) NLP Advanced Course

Code reference:
https://github.com/terrifyzhao/neo4j_graph
