神经网络语言模型(NNLM)

神经网络语言模型【NNLM】

  • 1 为什么使用神经网络模型?
  • 2 什么是神经网络模型?
  • 3. 代码实现
    • 3.1 语料库预处理代码
    • 3.2 词向量创建
    • 3.3 NNLM模型类
    • 3.4 完整代码

1 为什么使用神经网络模型?

  • 解决独热编码无法解决词之间相似性问题
    • 使用神经网络语言模型中出现的词向量 C w i C_{wi} Cwi代替
    • C w i C_{wi} Cwi就是单词对应的 Word Embedding【词向量】
  • 解决独热编码占用内存较大的问题

2 什么是神经网络模型?

在这里插入图片描述

  • Q矩阵相关参数

    • Q矩阵:从one-hot编码生成新的词向量
    • Q矩阵是参数,需要学习训练,刚开始用随机值初始化 Q矩阵,当这个网络训练好之后,Q矩阵的内容被正确赋值,每一行代表一个单词对应的 Word embedding
      参数含义
      Q Q Q V ∗ m V*m Vm的矩阵,模型参数
      V V V词典大小,词的个数,有几个词就有几行
      m m m新词向量的大小

    在这里插入图片描述

  • 神经网络相关参数

    参数含义
    W W Wword缩写,表示单词
    t t ttarget缩写,表示目标词【标签词】
    n n n窗口大小,上下文的大小(即周围的单词有多少个)称为窗口大小
    C C C就是Q 矩阵,需要学习的参数函数

在这里插入图片描述

  • 举例:
    • 文本You say goodbye and I say hello

      单词index
      You0
      say 1
      goodbye2
      and 3
      I4
      hello5
    • 窗口大小n=2,t从2开始,下标从0开始

      contextstargetcontexts indextarget index
      [You ,say][goodbye][0,1][2]
      [say ,goodbye] [and ][1,2] [3]
      [goodbye, and][I][2,3][4]
      [and, I][say ][3,4][1]
      [I ,say][hello][4,1][5]

3. 代码实现

3.1 语料库预处理代码

def preprocess(sentences_list,lis=[]):
   """
      语料库预处理

      :param text_list:句子列表
      :return:
           word_list 是单词列表
           word_dict:是单词到单词 ID 的字典
           number_dict 是单词 ID 到单词的字典
           n_class 单词数
   """
   for i in sentences_list:
       text = i.split(' ')  # 按照空格分词,统计 sentences的分词的个数
       word_list = list({}.fromkeys(text).keys())  # 去重 统计词典个数
       lis=lis+word_list
   word_list=list({}.fromkeys(lis).keys())
   word_dict = {w: i for i, w in enumerate(word_list)}
   number_dict = {i: w for i, w in enumerate(word_list)}
   n_class = len(word_dict)  # 词典的个数,也是softmax 最终分类的个数
   return word_list, word_dict, number_dict,n_class

在这里插入图片描述

3.2 词向量创建

def make_batch(sentences_list, word_dict, windows_size):
    """
    词向量编码函数

    :param sentences_list:句子列表
    :param word_dict: 字典{'You': 0,,,} key:单词,value:索引
    :param windows_size: 窗口大小
    :return:
        input_batch:数据集向量
        target_batch:标签值
    """
    input_batch, target_batch = [], []
    for sen in sentences_list:
        word_repeat_list = sen.split(' ')  # 按照空格分词
        for i in range(windows_size, len(word_repeat_list)):  # 目标词索引迭代
            target = word_repeat_list[i]  # 获取目标词
            input_index = [word_dict[word_repeat_list[j]] for j in range((i - windows_size), i)]  # 获取目标词相关输入数据集
            target_index = word_dict[target]  # 目标词索引
            input_batch.append(input_index)
            target_batch.append(target_index)

    return input_batch, target_batch

在这里插入图片描述

3.3 NNLM模型类

# Model
class NNLM(nn.Module):
    def __init__(self):
        super(NNLM, self).__init__()
        self.C = nn.Embedding(n_class, m)  # 矩阵Q  (V x m)  V 表示word的字典大小, m 表示词向量的维度
        self.H = nn.Linear(windows_size * m, n_hidden, bias=False)  #
        self.d = nn.Parameter(torch.ones(n_hidden))
        self.U = nn.Linear(n_hidden, n_class, bias=False)
        self.W = nn.Linear(windows_size * m, n_class, bias=False)
        self.b = nn.Parameter(torch.ones(n_class))

    def forward(self, X):
        X = self.C(X)  # X : [batch_size, n_step, m]
        X = X.view(-1, windows_size * m)  # [batch_size, n_step * m]
        tanh = torch.tanh(self.d + self.H(X))  # [batch_size, n_hidden]
        output = self.b + self.W(X) + self.U(tanh)  # [batch_size, n_class]
        return output

3.4 完整代码

import torch
import torch.nn as nn
import torch.optim as optim


def preprocess(sentences_list,lis=[]):
    """
       语料库预处理

       :param text_list:句子列表
       :return:
            word_list 是单词列表
            word_dict:是单词到单词 ID 的字典
            number_dict 是单词 ID 到单词的字典
            n_class 单词数
    """
    for i in sentences_list:
        text = i.split(' ')  # 按照空格分词,统计 sentences的分词的个数
        word_list = list({}.fromkeys(text).keys())  # 去重 统计词典个数
        lis=lis+word_list
    word_list=list({}.fromkeys(lis).keys())
    word_dict = {w: i for i, w in enumerate(word_list)}
    number_dict = {i: w for i, w in enumerate(word_list)}
    n_class = len(word_dict)  # 词典的个数,也是softmax 最终分类的个数
    return word_list, word_dict, number_dict,n_class

def make_batch(sentences_list, word_dict, windows_size):
    """
    词向量编码函数

    :param sentences_list:句子列表
    :param word_dict: 字典{'You': 0,,,} key:单词,value:索引
    :param windows_size: 窗口大小
    :return:
        input_batch:数据集向量
        target_batch:标签值
    """
    input_batch, target_batch = [], []
    for sen in sentences_list:
        word_repeat_list = sen.split(' ')  # 按照空格分词
        for i in range(windows_size, len(word_repeat_list)):  # 目标词索引迭代
            target = word_repeat_list[i]  # 获取目标词
            input_index = [word_dict[word_repeat_list[j]] for j in range((i - windows_size), i)]  # 获取目标词相关输入数据集
            target_index = word_dict[target]  # 目标词索引
            input_batch.append(input_index)
            target_batch.append(target_index)

    return input_batch, target_batch

# Model
class NNLM(nn.Module):
    def __init__(self):
        super(NNLM, self).__init__()
        self.C = nn.Embedding(n_class, m)  # 矩阵Q  (V x m)  V 表示word的字典大小, m 表示词向量的维度
        self.H = nn.Linear(windows_size * m, n_hidden, bias=False)  #
        self.d = nn.Parameter(torch.ones(n_hidden))
        self.U = nn.Linear(n_hidden, n_class, bias=False)
        self.W = nn.Linear(windows_size * m, n_class, bias=False)
        self.b = nn.Parameter(torch.ones(n_class))

    def forward(self, X):
        X = self.C(X)  # X : [batch_size, n_step, m]
        X = X.view(-1, windows_size * m)  # [batch_size, n_step * m]
        tanh = torch.tanh(self.d + self.H(X))  # [batch_size, n_hidden]
        output = self.b + self.W(X) + self.U(tanh)  # [batch_size, n_class]
        return output
if __name__ == '__main__':
    m = 2  # 词向量的维度
    n_hidden = 2  # 隐层个数
    windows_size = 2
    sentences_list = ['You say goodbye and I say hello']  # 训练数据
    word_list, word_dict, number_dict,n_class=preprocess(sentences_list)
    print('word_list为: ',word_list)
    print('word_dict为:',word_dict)
    print('number_dict为:',number_dict)
    print('n_class为:',n_class)
    #
    input_batch, target_batch = make_batch(sentences_list, word_dict, windows_size)  # 构建输入数据和 target label
    # # 转为 tensor 形式
    input_batch,target_batch = torch.LongTensor(input_batch), torch.LongTensor(target_batch)
    print(input_batch)
    print(target_batch)
    model = NNLM()
    # 损失函数定义
    criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数

    # 采用 Adam 优化算法 学习率定义为   0.001
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Training 迭代 2000次
    for epoch in range(2000):
        optimizer.zero_grad()  # 梯度归零
        output = model(input_batch)

        # output : [batch_size, n_class], target_batch : [batch_size]
        loss = criterion(output, target_batch)
        if (epoch + 1) % 250 == 0:
            print('Epoch:', '%d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
        loss.backward()  # 反向传播计算 每个参数的梯度值
        optimizer.step()  # 每一个参数的梯度值更新

    # Predict
    predict = model(input_batch).data.max(1, keepdim=True)[1]
    print(list(predict))
    # Test
    print([number_dict[n.item()] for n in predict.squeeze()])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/23588.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Blazor实战——Known框架增删改查导

本章介绍学习增、删、改、查、导功能如何实现,下面以商品资料作为示例,该业务栏位如下: 类型、编码、名称、规格、单位、库存下限、库存上限、备注 1. 前后端共用 1.1. 创建实体类 在KIMS项目Entities文件夹下创建KmGoods实体类该类继承Ent…

【C++】类和对象的应用案例 2 - 点和圆的关系

欢迎来到博主 Apeiron 的博客,祝您旅程愉快 !时止则止,时行则行。动静不失其时,其道光明。 目录 1、缘起 2、分析 3、示例代码 1 4、代码优化 4.1、point.h 4.2、point.c 4.3、circle.h 4.4、circle.c 4.4、main.c …

Netty 源码分析系列(十八)一行简单的writeAndFlush都做了哪些事?

文章目录 前言源码分析ctx.writeAndFlush 的逻辑writeAndFlush 源码ChannelOutBoundBuff 类addMessage 方法addFlush 方法AbstractNioByteChannel 类 小结 前言 对于使用netty的小伙伴来说,我们想通过服务端往客户端发送数据,通常我们会调用ctx.writeAn…

实时聊天组合功能,你了解吗?

你有兴趣安装实时聊天组合功能吗?如果您选择了SaleSmartly(ss客服),您的实时聊天插件可以不仅仅只是聊天通道,还可以有各种各样的功能,你不需要包含每一个功能,正所谓「宁缺勿滥」,功…

再获认可!腾讯连续三年被Gartner列为CWPP供应商之一

随着云的快速发展,企业的工作负载已经从服务器发展到虚拟机、容器、serverless等,部署的模式也日益复杂,包括公有云、混合云和多云等。在此背景下,传统的主机安全防护已无法满足需求,CWPP(云工作负载保护平…

C#,码海拾贝(23)——求解“复系数线性方程组“的“全选主元高斯消去法“之C#源代码,《C#数值计算算法编程》源代码升级改进版

using System; namespace Zhou.CSharp.Algorithm { /// <summary> /// 求解线性方程组的类 LEquations /// 原作 周长发 /// 改编 深度混淆 /// </summary> public static partial class LEquations { /// <summary&g…

day20 - 绘制物体的运动轨迹

在我们平常做目标检测或者目标追踪时&#xff0c;经常要画出目标的轨迹图。绘制轨迹图的一种方法就是利用光流估计来进行绘制。 本期我们主要来介绍视频中光流估计的使用和效果&#xff0c;利用光流估计来绘制运动轨迹。 完成本期内容&#xff0c;你可以&#xff1a; 掌握视…

网站部署与上线(1)虚拟机

文章目录 .1 虚拟机简介2 虚拟机的安装 本章将搭建实例的生产环境&#xff0c;将所有的代码搭建在一台Linux服务器中&#xff0c;并且测试其能否正常运行。 使用远程服务器进行连接&#xff1b; 基本的Linux命令&#xff1b; 使用Nginx搭建Node.js服务器&#xff1b; 在服务器端…

一、预约挂号详情

文章目录 一、预约挂号详情1、需求分析 2、api接口2.1 添加service接口2.2 添加service接口实现2.2.1 在ScheduleServiceImpl类实现接口2.2.2 在获取科室信息 2.3 添加controller方法 3、前端3.1封装api请求3.2 页面展示 二、预约确认1、api接口1.1 添加service接口1.2 添加con…

通过python采集整站lazada商品列表数据,支持多站点

要采集整站lazada商品列表数据&#xff0c;需要先了解lazada网站的结构和数据源。Lazada是东南亚最大的电商平台之一&#xff0c;提供各种商品和服务。Lazada的数据源主要分为两种&#xff1a;HTML和API。 方法1&#xff1a;采集HTML数据 步骤1&#xff1a;确定采集目标 首先…

一、CNNs网络架构-基础网络架构(LeNet、AlexNet、ZFNet)

目录 1.LeNet 2.AlexNet 2.1 激活函数&#xff1a;ReLU 2.2 随机失活&#xff1a;Droupout 2.3 数据扩充&#xff1a;Data augmentation 2.4 局部响应归一化&#xff1a;LRN 2.5 多GPU训练 2.6 论文 3.ZFNet 3.1 网络架构 3.2 反卷积 3.3 卷积可视化 3.4 ZFNet改…

Java的Arrays类的sort()方法(41)

目录 sort&#xff08;&#xff09;方法 1.sort&#xff08;&#xff09;方法的格式 2.使用sort&#xff08;&#xff09;方法时要导入的类 3.作用 4.作用的对象 5.注意 6.代码及结果 &#xff08;1&#xff09;代码 &#xff08;2&#xff09;结果 sort&#xff08;&…

【Netty】字节缓冲区 ByteBuf (六)(上)

文章目录 前言一、ByteBuf类二、ByteBuffer 实现原理2.1 ByteBuffer 写入模式2.2 ByteBuffer 读取模式2.3 ByteBuffer 写入模式切换为读取模式2.4 clear() 与 compact() 方法2.5 ByteBuffer 使用案例 总结 前言 回顾Netty系列文章&#xff1a; Netty 概述&#xff08;一&…

亏损?盈利?禾赛科技Q1财报背后的激光雷达赛道「现实」

随着禾赛科技在去年登陆美股&#xff0c;作为全球为数不多已经开始前装量产交付的激光雷达上市公司&#xff0c;财务数据的变化&#xff0c;也在一定程度上反映了行业的真实状况。 根据禾赛科技最新发布的今年一季度财报显示&#xff0c;公司季度净营收为4.3亿元&#xff08;人…

day13 - 对指纹图片进行噪声消除

在指纹识别的过程中&#xff0c;指纹图片通常都是现场采集的&#xff0c;受环境的影响会有产生很多的噪声点&#xff0c;如果直接使用&#xff0c;会对指纹的识别产生很大的影响&#xff0c;而指纹识别的应用场景又都是一些比较严肃不容有错的场合&#xff0c;所以去除噪声又不…

python+vue空巢老人网上药店购药系统9h2k5

本空巢老人购药系统主要包括三大功能模块&#xff0c;即用户功能模块、家属功能模块和管理员功能模块。 &#xff08;1&#xff09;管理员模块&#xff1a;系统中的核心用户是管理员&#xff0c;管理员登录后&#xff0c;通过管理员功能来管理后台系统。主要功能有&#xff1a;…

【实验】SegViT: Semantic Segmentation with Plain Vision Transformers

想要借鉴SegViT官方模型源码部署到本地自己代码文件中 1. 环境配置 官网要求安装mmcv-full1.4.4和mmsegmentation0.24.0 在这之前记得把mmcv和mmsegmentation原来版本卸载 pip uninstall mmcv pip uninstall mmcv-full pip uninstall mmsegmentation安装mmcv 其中&#xff…

旋翼无人机常用仿真工具

四旋翼常用仿真工具 rviz&#xff1a; 简单的质点&#xff08;也可以加上动力学姿态&#xff09;&#xff0c;用urdf模型在rviz中显示无人机和飞行轨迹、地图等。配合ROS代码使用&#xff0c;轻量化适合多机。典型的比如浙大ego-planner的仿真&#xff1a; https://github.c…

Java面试知识点(全)-分布式算法- ZAB算法

Java面试知识点(全) 导航&#xff1a; https://nanxiang.blog.csdn.net/article/details/130640392 注&#xff1a;随时更新 研究zookeeper时&#xff0c;必须要了解zk的选举和集群间个副本间的数据一致性。 什么是 ZAB 协议&#xff1f; ZAB 协议介绍 ZAB 协议全称&#xf…

树和二叉树

树 逻辑表示方法 树形表示法 文氏图表示法 凹入表示法 括号表示法 性质 树的结点数等于所有结点的度加一 度为m的树中第i层最多有m的(i-1)次方个结点 高度为h的m次树最多的节点数&#xff08;等比数列公式求和&am…