Word2Vec:将词汇转化为向量的技术

文章目录

  • Word2Vec
    • 来龙去脉
      • 分层Softmax
      • 负采样


Word2Vec

下面的文章纯属笔记,看完后不会有任何收获,如果想理解这两种优化技术,给大家推荐一篇博客,讲的很好:
详解-----分层Softmax与负采样


来龙去脉

word2vec,即将词语转换为向量。在机器学习或自然语言任务中,我们需要对句子进行翻译或者根据某些词生成另一些词,这些任务现在大多数都可以用神经网络来做。比如在句子翻译任务中,我们给神经网络输入该句子,那么想要的输出就是该句子的翻译版本。但是由于计算机只接受数字形式的输入,所以我们要将词语转化为数字形式。
word2vec就是将词语转换为数字向量的技术,经过该方法训练之后,我们就可以得到每一个词的固定向量表示,使得意思相近的词在向量空间中距离较近,不相关的词在向量空间中距离较远。
word2vec有两种经典的方法来进行训练,从而得到词的向量表示,一种叫做CBOW(连续词袋模型),一种叫skip-gram(跳元模型)。
CBOW的核心想法是通过一个词周围的词,预测该词,类似于完形填空。
skip-gram的核心想法是通过预测该词周围的词,相当于根据一个词造句子。
在用神经网络训练者两种模型的时候,我们输出层的个数就是我们所有词语的个数,需要经过一个softmax才能得到预测的每一个词的概率,这会导致指数运算次数非常多,导致对计算资源的要求很高。
基于这个问题,提出了两种优化方案,一种叫分层softmax,一种叫负采样。本章重点介绍这两种技术。

分层Softmax

在这里插入图片描述

这张图就说明了分层softmax的核心流程,我们以CBOW为例,在得到每一个周围词的词嵌入表示后,对其进行加权平均,就得到了图中的h,然后构建Huffman树(基于每一个词出现的频率构建),得到Huffman树之后,每一个叶子节点就表示了词汇表中的一个词,现在为每一个非叶子节点赋予一个可训练参数,然后将h与每一个非叶子节点的参数相乘,经过一个sigmoid得到一个0-1的值,在计算一个词的概率的时候,将路径上所有非叶子节点得到的值相乘,就得到输出该词的概率值,通过二叉树这种设计,保证了最后得到的所有词的概率的和为1。
损失函数的设计用的是二元交叉熵损失。
Huffman树的构建

import heapq
import numpy as np

# 构建Huffman树
class HuffmanTree:
    def __init__(self, vocab, freq):
        self.vocab = vocab
        self.freq = freq
        self.tree = self.build_huffman_tree()

    def build_huffman_tree(self):
        heap = [[weight, [symbol, ""]] for symbol, weight in zip(self.vocab, self.freq)]
        heapq.heapify(heap)

        while len(heap) > 1:
            lo = heapq.heappop(heap)
            hi = heapq.heappop(heap)
            for pair in lo[1:]:
                pair[1] = '0' + pair[1]
            for pair in hi[1:]:
                pair[1] = '1' + pair[1]
            heapq.heappush(heap, [lo[0] + hi[0]] + lo[1:] + hi[1:])
        
        return heap[0][1:]

    def get_code(self):
        return {symbol: code for symbol, code in self.tree}

分层softmax代码

def preprocess(text):
	text = text.lower()
	text = text.replace('.', ' .')
	text = text.replace(',', ' ,')
	text = text.replace('!', ' !')
	words = text.split(' ')
	
	word_to_id = {}
	id_to_word = {}
	word_count = {}
	for word in words:
		if word not in word_to_id:
			new_id = len(word_to_id)
			word_to_id[word] = new_id
			id_to_word[new_id] = word
			word_count[new_id] = 1
		else:
			word_count[word_to_id[word]] += 1
			corpus = np.array([word_to_id[w] for w in words])
	
	return corpus, word_to_id, id_to_word, word_count

负采样

在这里插入图片描述
负采样的前置部分和前面一样,它的基本思想是从一个概率分布中选择少数几个负样本参与每一次的训练,一般情况下,我们不是只用正的样本吗?我们将h经过一个网络后,会得到所有词的logits值,注意,此时我们还没有将其softmax,也就避免了大量的指数运算。我们根据所有logits值和给的正样本的标签得到此时正样本的logits值,同理,我们从词汇库里选择几个词作为负样本,经过网络传播后,也会得到几个负样本的logits值,接下来,我们对这几个词(正样本和负样本)做softmax,从而得到正样本的概率和这几个负样本的概率,通过最大化的正样本的概率并且最小化负样本的概率,进而训练网络。
那么这里负样本的个数应该选几个呢?在word2vec原文中,当数据量较大时,通常选用的负例个数为5,当数据量较小时,选5-20个。

负采样代码:

import random
import numpy as np
from collections import Counter

# 示例语料库
corpus = [
    'cat is on the mat',
    'dog is in the house',
    'cat and dog are friends',
    'dog is playing with the ball'
]

# 1. 构建词汇表并计算词频
def build_vocab(corpus):
    words = []
    for sentence in corpus:
        words.extend(sentence.split())
    word_counts = Counter(words)
    vocab = {word: count for word, count in word_counts.items()}
    return vocab

# 计算词频并构建词汇表
vocab = build_vocab(corpus)
vocab_size = len(vocab)
print("词汇表:", vocab)

# 2. 负样本采样
def get_negative_sample(vocab, num_samples=5):
    # 获取词频的平方根,并进行归一化
    word_freq = np.array([count for count in vocab.values()])
    word_freq = word_freq ** 0.75  # 使用0.75的幂次方分布
    word_freq /= word_freq.sum()  # 归一化,使得总和为1

    # 根据权重选择负样本
    negative_samples = np.random.choice(list(vocab.keys()), size=num_samples, p=word_freq)
    return negative_samples

# 测试负采样
negative_samples = get_negative_sample(vocab, num_samples=5)
print("负样本:", negative_samples)

# 3. 简单的训练步骤(Skip-Gram模型)
def train_step(context, target, negative_samples, learning_rate=0.1):
    # 这里只是简单地输出每个样本的训练步骤,实际情况中会根据模型进行参数更新
    print(f"\n上下文词: {context}")
    print(f"目标词: {target}")
    print(f"负样本: {negative_samples}")
    
    # 计算损失和梯度的代码可以根据实际模型来实现
    # 这里只是模拟训练过程
    
    # 示例的损失计算(假设目标词是正样本,负样本是负样本)
    for word in [target] + list(negative_samples):
        if word == target:
            print(f"目标词 {target} 的损失:正样本,最大化概率")
        else:
            print(f"负样本 {word} 的损失:最小化概率")

# 4. 模拟训练过程
def train_model(corpus, vocab, num_epochs=10, num_negative_samples=5, learning_rate=0.1):
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        
        # 遍历语料库
        for sentence in corpus:
            words = sentence.split()
            
            # 模拟Skip-Gram的训练过程
            for i, target in enumerate(words):
                # 上下文词是目标词附近的其他词
                context = [words[j] for j in range(len(words)) if j != i]
                
                # 从词汇表中选择负样本
                negative_samples = get_negative_sample(vocab, num_samples=num_negative_samples)
                
                # 进行训练步骤
                train_step(context, target, negative_samples, learning_rate)

# 训练模型
train_model(corpus, vocab)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/938434.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

虚幻5描边轮廓材质

很多游戏内都有这种描边效果,挺实用也挺好看的,简单复刻一下 效果演示: Linethickness可以控制轮廓线条的粗细 这样连完,然后放到网格体细节的覆层材质上即可 可以自己更改粗细大小和颜色

websocket_asyncio

WebSocket 和 asyncio 指南 简介 本指南涵盖了使用 Python 中的 websockets 库进行 WebSocket 编程的基础知识,以及 asyncio 在异步非阻塞 I/O 中的作用。它提供了构建高效 WebSocket 服务端和客户端的知识,以及 asyncio 的特性和优势。 1. 什么是 WebS…

序列模型的使用示例

序列模型的使用示例 1 RNN原理1.1 序列模型的输入输出1.2 循环神经网络(RNN)1.3 RNN的公式表示2 数据的尺寸 3 PyTorch中查看RNN的参数4 PyTorch中实现RNN(1)RNN实例化(2)forward函数(3&#xf…

Hadoop学习笔记(包括hadoop3.4.0集群安装)(黑马)

Hadoop学习笔记 0-前置章节-环境准备 0.1 环境介绍 配置环境:hadoop-3.4.0,jdk-8u171-linux-x64 0.2 VMware准备Linux虚拟机 0.2.1主机名、IP、SSH免密登录 1.配置固定IP地址(root权限) 开启master,修改主机名为…

【计算机网络】Layer4-Transport layer

目录 传输层协议How demultiplexing works in transport layer(传输层如何进行分用)分用(Demultiplexing)的定义:TCP/UDP段格式: UDPUDP的特点:UDP Format端口号Trivial File Transfer Protocol…

Android Studio创建新项目并引入第三方so外部aar库驱动NFC读写器读写IC卡

本示例使用设备:https://item.taobao.com/item.htm?spma21dvs.23580594.0.0.52de2c1bbW3AUC&ftt&id615391857885 一、打开Android Studio,点击 File> New>New project 菜单,选择 要创建的项目模版,点击 Next 二、输入项目名称…

【Linux】—简单实现一个shell(myshell)

大家好呀,我是残念,希望在你看完之后,能对你有所帮助,有什么不足请指正!共同学习交流哦! 本文由:残念ing原创CSDN首发,如需要转载请通知 个人主页:残念ing-CSDN博客&…

【Python爬虫系列】_032.Scrapy_全站爬取

课 程 推 荐我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈虚 拟 环 境 搭 建 :👉👉 Python项目虚拟环境(超详细讲解) 👈👈PyQt5 系 列 教 程:👉👉 Python GUI(PyQt5)教程合集 👈👈

Android通过okhttp下载文件(本文案例 下载mp4到本地,并更新到相册)

使用步骤分为两步 第一步导入 okhttp3 依赖 第二步调用本文提供的 utils 第一步这里不做说明了,直接提供第二步复制即用 DownloadUtil 中 download 为下载文件 参数说明 这里主要看你把 destFileName 下载文件名称定义为什么后缀,比如我定义为 .mp4 下…

win10配置子系统Ubuntu子系统(无需通过Windows应用市场)实际操作记录

win10配置子系统Ubuntu子系统(无需通过Windows应用市场)实际操作记录 参考教程 : win10配置子系统Ubuntu子系统(无需通过Windows应用市场) - 一佳一 - 博客园 开启虚拟机服务的 以管理员方式运行PowerShell运行命令。 &#xf…

Showrunner AI技术浅析(四):多智能体模拟

多智能体模拟技术涉及多个智能体(Agents)在虚拟环境中的行为和互动,每个智能体都有自己的属性、目标和行为规则。 1. 多智能体模拟概述 多智能体模拟技术通过模拟多个智能体在虚拟环境中的互动来生成复杂的剧情和场景。每个智能体都有其独特…

创新性融合丨卡尔曼滤波+目标检测 新突破!

2024深度学习发论文&模型涨点之——卡尔曼滤波目标检测 卡尔曼滤波是一种递归算法,用于估计线性动态系统的状态。它通过预测和更新两个步骤,结合系统模型和观测数据,来估计系统状态,并最小化估计的不确定性。 在目标检测中&am…

USB模块布局布线

1、USB接口定义 2、USB模块常规分类介绍 3、USB常用管脚定义图示 4、USB模块布局布线分析 USB3.0高速线因为速度比较高,建议走圆弧线不能走钝角 5、总结 1、CTRL鼠标中间滑轮按下可以看线的长度 2、不懂差分类和规则的设置,可以看本人写的AD基础操作…

SpringCloud系列之分布式配置中心极速入门与实践

[toc] 1、分布式配置中心简介 在实际的项目开发中,配置文件是使用比较多的,很多项目有测试环境(TEST)、开发环境(DEV)、规范的项目还有集成环境(UAT)、生产环境(PROD),每个环境就一个配置文件。 CSDN链接:SpringCloud系列之分布式…

【Vue3学习】setup语法糖中的ref,reactive,toRef,toRefs

在 Vue 3 的组合式 API(Composition API)中,ref、reactive、toRef 和 toRefs 是四个非常重要的工具函数,用于创建和管理响应式数据。 一、ref 用ref()包裹数据,返回的响应式引用对象,包含一个 .value 属性&#xff0…

解决 Git Permission denied 问题

前言 push项目时出现gitgithub.com: Permission denied (publickey). fatal: Could not read from remote repository.Please make sure you have the correct access rights and the repository exists.出现这个问题表示你在尝试将本地代码推送到GitHub时,没有提供…

React的状态管理库-Redux

核心思想:单一数据源、状态是只读的、以及使用纯函数更新状态。 组成部分 Store(存储) 应用的唯一状态容器,存储整个应用的状态树,使用 createStore() 创建。 getState():获取当前状态。dispatch(action)&#xff…

蓝卓总裁谭彰:AI+工业互联网推动制造业数字化转型

近日,新一代工业操作系统supOS6.0在2024中国5G工业互联网大会上重磅发布。 大会期间,工信部新闻宣传中心《人民邮电报》对蓝卓总裁谭彰就“工业互联网人工智能技术融合的思考”“supOS6.0的探索与实践”“未来工业互联网平台的发展方向”展开专题访谈&am…

RabbitMQ消息队列的笔记

Rabbit与Java相结合 引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-amqp</artifactId> </dependency> 在配置文件中编写关于rabbitmq的配置 rabbitmq:host: 192.168.190.132 /…

数据结构:贪吃蛇详解

目录 一.地图的设计 1.字符与坐标&#xff1a; 2.本地化&#xff08;头文件&#xff09;: 3.类项&#xff1a; 4.setlocale函数&#xff1a; &#xff08;1&#xff09;函数原型&#xff1a; &#xff08;2&#xff09;使用&#xff1a; 5.宽字符的打印&#xff1a; &a…