pytorch基于FastText实现词嵌入

FastText 是 Facebook AI Research 提出的 改进版 Word2Vec,可以: ✅ 利用 n-grams 处理未登录词
比 Word2Vec 更快、更准确
适用于中文等形态丰富的语言

完整的 PyTorch FastText 代码(基于中文语料),包含:

  • 数据预处理(分词 + n-grams)
  • 模型定义
  • 训练
  • 测试
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import jieba
from collections import Counter
import random

# ========== 1. 数据预处理 ==========
corpus = [
    "我们 喜欢 深度 学习",
    "自然 语言 处理 是 有趣 的",
    "人工智能 改变 了 世界",
    "深度 学习 是 人工智能 的 重要 组成部分"
]

# 分词
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]


# 构建 n-grams
def generate_ngrams(words, n=3):
    ngrams = []
    for word in words:
        ngrams += [word[i:i + n] for i in range(len(word) - n + 1)]
    return ngrams


# 生成 n-grams 词表
all_ngrams = set()
for sentence in tokenized_corpus:
    for word in sentence:
        all_ngrams.update(generate_ngrams(word))

# 构建词汇表
vocab = set(word for sentence in tokenized_corpus for word in sentence) | all_ngrams
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

# 构建训练数据(CBOW 方式)
window_size = 2
data = []

for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_idx in range(len(indices)):
        context = []
        for offset in range(-window_size, window_size + 1):
            context_idx = center_idx + offset
            if 0 <= context_idx < len(indices) and context_idx != center_idx:
                context.append(indices[context_idx])
        if context:
            data.append((context, indices[center_idx]))  # (上下文, 目标词)


# ========== 2. 定义 FastText 模型 ==========
class FastText(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(FastText, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, context):
        context_vec = self.embeddings(context).mean(dim=1)  # 平均上下文向量
        output = self.linear(context_vec)
        return output


# 初始化模型
embedding_dim = 10
model = FastText(len(vocab), embedding_dim)

# ========== 3. 训练 FastText ==========
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 100

for epoch in range(num_epochs):
    total_loss = 0
    random.shuffle(data)

    for context, target in data:
        context = torch.tensor([context], dtype=torch.long)
        target = torch.tensor([target], dtype=torch.long)

        optimizer.zero_grad()
        output = model(context)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

# ========== 4. 获取词向量 ==========
word_vectors = model.embeddings.weight.data.numpy()


# ========== 5. 计算相似度 ==========
def most_similar(word, top_n=3):
    if word not in word2idx:
        return "单词不在词汇表中"

    word_vec = word_vectors[word2idx[word]].reshape(1, -1)
    similarities = np.dot(word_vectors, word_vec.T).squeeze()
    similar_idx = similarities.argsort()[::-1][1:top_n + 1]
    return [(idx2word[idx], similarities[idx]) for idx in similar_idx]


# 测试
test_words = ["深度", "学习", "人工智能"]
for word in test_words:
    print(f"【{word}】的相似单词:", most_similar(word))

1. 生成 n-grams

  • FastText 处理单词的 子词单元(n-grams)
  • 例如 "学习" 会生成 ["学习", "习学", "学"]
  • 这样即使遇到未登录词也能拆分为 n-grams 计算

2. 训练数据

  • 使用 CBOW(上下文预测中心词)
  • 窗口大小 = 2,即:
    句子: ["深度", "学习", "是", "人工智能"]
    示例: (["深度", "是"], "学习")
    

3. FastText 模型

  • 词向量是 n-grams 词向量的平均值
  • 计算公式: 
  • 这样,即使单词没见过,也能用它的 n-grams 计算词向量!

 4. 计算相似度

  • cosine similarity 找出最相似的单词
  • FastText 比 Word2Vec 更准确,因为它能利用 n-grams 捕捉词的语义信息
特性FastTextWord2VecGloVe
原理预测中心词 + n-grams预测中心词或上下文统计词共现信息
未登录词处理可处理无法处理无法处理
训练速度 快
适合领域中文、罕见词传统 NLP大规模数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963078.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【hot100】刷题记录(8)-矩阵置零

题目描述&#xff1a; 给定一个 m x n 的矩阵&#xff0c;如果一个元素为 0 &#xff0c;则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 示例 1&#xff1a; 输入&#xff1a;matrix [[1,1,1],[1,0,1],[1,1,1]] 输出&#xff1a;[[1,0,1],[0,0,0],[1,0,1]]示例 2…

PyTorch框架——基于深度学习YOLOv8神经网络学生课堂行为检测识别系统

基于YOLOv8深度学习的学生课堂行为检测识别系统&#xff0c;其能识别三种学生课堂行为&#xff1a;names: [举手, 读书, 写字] 具体图片见如下&#xff1a; 第一步&#xff1a;YOLOv8介绍 YOLOv8 是 ultralytics 公司在 2023 年 1月 10 号开源的 YOLOv5 的下一个重大更新版本…

Doki Doki Mods Maker小指南

-*- 做都做了&#xff0c;那就做到底吧。 -*- 前言&#xff1a; 项目的话&#xff0c;在莫盘里&#xff0c;在贴吧原帖下我有发具体地址。 这里是Doki Doki Mods Maker&#xff0c;是用来做DDLC Mods的小工具。 说是“Mods”&#xff0c;实则不然&#xff0c;这个是我从零仿…

nodejs:express + js-mdict 网页查询英汉词典

向 DeepSeek R1 提问&#xff1a; 我想写一个Web 前端网页&#xff0c;后台用 nodejs js-mdict, 实现在线查询英语单词 1. 项目结构 首先&#xff0c;创建一个项目目录&#xff0c;结构如下&#xff1a; mydict-app/ ├── public/ │ ├── index.html │ ├── st…

LabVIEW纤维集合体微电流测试仪

LabVIEW开发纤维集合体微电流测试仪。该设备精确测量纤维材料在特定电压下的电流变化&#xff0c;以分析纤维的结构、老化及回潮率等属性&#xff0c;对于纤维材料的科学研究及质量控制具有重要意义。 ​ 项目背景 在纤维材料的研究与应用中&#xff0c;电学性能是评估其性能…

dfs枚举问题

碎碎念&#xff1a;要开始刷算法题备战蓝桥杯了&#xff0c;一切的开头一定是dfs 定义 枚举问题就是咱数学上学到的&#xff0c;从n个数里面选m个数&#xff0c;有三种题型(来自Acwing) 从 1∼n 这 n个整数中随机选取任意多个&#xff0c;输出所有可能的选择方案。 把 1∼n这…

SOME/IP--协议英文原文讲解3

前言 SOME/IP协议越来越多的用于汽车电子行业中&#xff0c;关于协议详细完全的中文资料却没有&#xff0c;所以我将结合工作经验并对照英文原版协议做一系列的文章。基本分三大块&#xff1a; 1. SOME/IP协议讲解 2. SOME/IP-SD协议讲解 3. python/C举例调试讲解 Note: Thi…

leetcode——二叉树的中序遍历(java)

给定一个二叉树的根节点 root &#xff0c;返回 它的 中序 遍历 。 示例 1&#xff1a; 输入&#xff1a;root [1,null,2,3] 输出&#xff1a;[1,3,2] 示例 2&#xff1a; 输入&#xff1a;root [] 输出&#xff1a;[] 示例 3&#xff1a; 输入&#xff1a;root [1] 输出…

91,【7】 攻防世界 web fileclude

进入靶场 <?php // 包含 flag.php 文件 include("flag.php");// 以高亮语法显示当前文件&#xff08;即包含这段代码的 PHP 文件&#xff09;的内容 // 方便查看当前代码结构和逻辑&#xff0c;常用于调试或给解题者提示代码信息 highlight_file(__FILE__);// 检…

Microsoft Power BI:融合 AI 的文本分析

Microsoft Power BI 是微软推出的一款功能强大的商业智能工具&#xff0c;旨在帮助用户从各种数据源中提取、分析和可视化数据&#xff0c;以支持业务决策和洞察。以下是关于 Power BI 的深度介绍&#xff1a; 1. 核心功能与特点 Power BI 提供了全面的数据分析和可视化功能&…

海外问卷调查,最常用到的渠道查有什么特殊之处

市场调研&#xff0c;包含市场调查和市场研究两个步骤&#xff0c;是企业和机构根据经营方向而做出的决策问题&#xff0c;最终通过海外问卷调查中的渠道查&#xff0c;来系统地设计、收集、记录、整理、分析、研究市场反馈的工作流程。 市场调研的工作流程包括&#xff1a;确…

《苍穹外卖》项目学习记录-Day10来单提醒

type&#xff1a;用来标识消息的类型&#xff0c;比如说type1表示来单提醒&#xff0c;type2表示客户催单。 orderId&#xff1a;表示订单id&#xff0c;因为不管是来单提醒还是客户催单&#xff0c;这一次提醒都对应一个订单。是用户下了某个单或者催促某个订单&#xff0c;这…

【全栈】SprintBoot+vue3迷你商城(10)

【全栈】SprintBootvue3迷你商城&#xff08;10&#xff09; 往期的文章都在这里啦&#xff0c;大家有兴趣可以看一下 后端部分&#xff1a; 【全栈】SprintBootvue3迷你商城&#xff08;1&#xff09; 【全栈】SprintBootvue3迷你商城&#xff08;2&#xff09; 【全栈】Sp…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.27 线性代数王国:矩阵分解实战指南

1.27 线性代数王国&#xff1a;矩阵分解实战指南 #mermaid-svg-JWrp2JAP9qkdS2A7 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-JWrp2JAP9qkdS2A7 .error-icon{fill:#552222;}#mermaid-svg-JWrp2JAP9qkdS2A7 .erro…

【愚公系列】《循序渐进Vue.js 3.x前端开发实践》030-自定义组件的插槽Mixin

标题详情作者简介愚公搬代码头衔华为云特约编辑&#xff0c;华为云云享专家&#xff0c;华为开发者专家&#xff0c;华为产品云测专家&#xff0c;CSDN博客专家&#xff0c;CSDN商业化专家&#xff0c;阿里云专家博主&#xff0c;阿里云签约作者&#xff0c;腾讯云优秀博主&…

langchain 实现多智能体多轮对话

这里写目录标题 工具定义模型选择graph节点函数定义graph 运行 工具定义 import random from typing import Annotated, Literalfrom langchain_core.tools import tool from langchain_core.tools.base import InjectedToolCallId from langgraph.prebuilt import InjectedSt…

pytorch生成对抗网络

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 生成对抗网络&#xff08;GAN&#xff0c;Generative Adversarial Network&#xff09;是一种深度学习模型&#xff0c;由两个神经网络组成&#xff1a;生成器&#xff08;Generator&#xff09;和判别器&#xff0…

ui-automator定位官网文档下载及使用

一、ui-automator定位官网文档简介及下载 AndroidUiAutomator&#xff1a;移动端特有的定位方式&#xff0c;uiautomator是java实现的&#xff0c;定位类型必须写成java类型 官方地址&#xff1a;https://developer.android.com/training/testing/ui-automator.html#ui-autom…

RabbitMQ持久化队列配置修改问题

文章目录 1.问题产生2.问题解决1.询问gpt2.独立思考 1.问题产生 我在给一个普通队列去绑定死信交换机和死信队列的时候&#xff0c;发现总是报错x-dead-letter-exchange的属性为none ERROR [PFTID:] [Module:defaultModule] org.springframework.amqp.rabbit.connection.Cach…

MySQL常用数据类型和表的操作

文章目录 (一)常用数据类型1.数值类2.字符串类型3.二进制类型4.日期类型 (二)表的操作1查看指定库中所有表2.创建表3.查看表结构和查看表的创建语句4.修改表5.删除表 (三)总代码 (一)常用数据类型 1.数值类 BIT([M]) 大小:bit M表示每个数的位数&#xff0c;取值范围为1~64,若…