pytorch使用SVM实现文本分类

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics

# 1. 数据准备(中文文本)
texts = [
    "今天的足球比赛非常激烈,球队表现出色,最终赢得了比赛。",
    "NBA比赛今天开打,球员们的表现非常精彩,球迷们热情高涨。",
    "张艺谋的新电影上映了,票房成绩非常好,观众反响热烈。",
    "娱乐圈最近又出了一些新闻,明星们的私生活成了大家讨论的焦点。",
    "昨晚的篮球赛真是太精彩了,球员们的进攻和防守都非常强硬。",
    "李宇春在最新的音乐会上演出了她的新歌,现场观众反应热烈。",
    "今年的世界杯比赛激烈异常,球队之间的竞争越来越激烈。",
    "最近的综艺节目非常火,明星嘉宾的表现让观众们大笑不已。"
]

# 标签:0表示体育,1表示娱乐
labels = [0, 0, 1, 1, 0, 1, 0, 1]


# 2. 数据预处理:中文分词和 TF-IDF 特征提取
def jieba_cut(text):
    return " ".join(jieba.cut(text))


texts_cut = [jieba_cut(text) for text in texts]

vectorizer = TfidfVectorizer(max_features=10000)
X_tfidf = vectorizer.fit_transform(texts_cut).toarray()
y = np.array(labels)

# 3. 数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_tfidf, y, test_size=0.2, random_state=42)


# 4. PyTorch 数据加载
class NewsGroupDataset(torch.utils.data.Dataset):
    def __init__(self, features, labels):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


train_dataset = NewsGroupDataset(X_train, y_train)
test_dataset = NewsGroupDataset(X_test, y_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)


# 5. 定义 SVM 模型(使用线性层)
class SVM(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SVM, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x)


# 6. 获取特征数并初始化模型
input_dim = X_tfidf.shape[1]  # 自动获取特征数
model = SVM(input_dim=input_dim, output_dim=2)  # 使用特征数量设置输入维度
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 优化器,调整学习率

# 7. 训练模型
num_epochs = 50  # 增加训练周期

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total}%")

# 8. 测试模型
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total}%")

# 9. 输出性能指标
y_pred = []
y_true = []

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        y_pred.extend(predicted.numpy())
        y_true.extend(labels.numpy())

print(metrics.classification_report(y_true, y_pred))


# 10. 测试新样本
def predict(text, model, vectorizer):
    # 1. 进行分词
    text_cut = jieba_cut(text)

    # 2. 将文本转为 TF-IDF 特征向量
    text_tfidf = vectorizer.transform([text_cut]).toarray()

    # 3. 转换为 PyTorch 张量
    text_tensor = torch.tensor(text_tfidf, dtype=torch.float32)

    # 4. 模型预测
    model.eval()  # 设置模型为评估模式
    with torch.no_grad():
        output = model(text_tensor)
        _, predicted = torch.max(output.data, 1)

    # 5. 返回预测结果
    return predicted.item()


# 测试一个新的中文文本
new_text = "今天的篮球比赛真是太精彩了,球员们的表现让大家都为之喝彩。"
predicted_label = predict(new_text, model, vectorizer)

# 输出预测结果
if predicted_label == 0:
    print("预测类别: 体育")
else:
    print("预测类别: 娱乐")

1. 数据准备

  • 文本数据:我们定义了一个包含中文文本的列表,每条文本表示一个新闻或评论。
  • 标签:为每条文本分配了一个标签,0 代表“体育”,1 代表“娱乐”。

2. 数据预处理

  • 中文分词:使用 jieba 库对每条文本进行分词,并将分词后的结果连接成字符串。这是处理中文文本时的常见做法。
  • TF-IDF 特征提取:使用 TfidfVectorizer 将文本转化为数值特征。TF-IDF 是一种常见的文本表示方式,能够衡量单词在文档中的重要性。

3. 数据集分割

  • 使用 train_test_split 将数据分为训练集和测试集。80% 的数据用于训练,20% 用于测试。

4. PyTorch 数据加载

  • 定义 Dataset 类:创建了一个自定义的 NewsGroupDataset 类,继承自 torch.utils.data.Dataset,用于将文本特征和标签封装为 PyTorch 可用的数据集格式。
  • DataLoader:使用 DataLoader 将训练集和测试集数据进行批处理和加载。

5. 模型定义

  • 定义了一个简单的线性 SVM 模型。实际上,使用了一个线性层 (nn.Linear) 来进行分类,输入是文本的 TF-IDF 特征,输出是两个类别(体育或娱乐)。
  • 使用了 CrossEntropyLoss 作为损失函数,因为这是分类任务中常用的损失函数。
  • 优化器使用了随机梯度下降(SGD),并设置了学习率为 0.01。

6. 训练过程

  • 训练模型的过程包括:前向传播(计算输出),计算损失,反向传播(更新参数),并在每个 epoch 后输出损失和准确率。
  • 每个 batch 的训练过程中,模型会通过计算损失并进行优化,逐步提升准确率。

7. 测试和评估

  • 在测试过程中,将模型设置为评估模式 (model.eval()),并计算测试集上的准确率。通过比较预测标签与真实标签,计算正确的预测数量并输出准确率。
  • 使用 classification_report 来输出精确度、召回率和 F1 分数等更多的评估指标。

8. 预测新样本

  • 定义了一个 predict() 函数,用于预测新的文本样本的分类。
  • 预测过程包括:对文本进行分词,转化为 TF-IDF 特征,传入模型进行前向传播,最后返回模型预测的标签。

9. 输出

  • 代码的最后,会输出模型对新文本的预测结果,标明是属于体育类别还是娱乐类别。

关键技术点:

  • 中文分词:使用 jieba 对中文文本进行分词处理,这对于中文文本的处理至关重要。
  • TF-IDF:将文本转换为数值特征,便于模型处理。TF-IDF 是基于单词在文档中的出现频率及其在整个语料中的稀有度进行加权的。
  • 模型训练与评估:通过多轮训练提升模型准确度,使用测试集来评估模型的泛化能力。
  • PyTorch DataLoader:通过 DataLoader 高效地处理训练集和测试集,进行批处理和自动化管理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962640.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Git进阶之旅:Git Hub注册创建仓库

介绍: GitHub 是一个面向开源及私有软件项目的托管平台,因为只支持 git 作为唯一的版本库格式进行托管,故名 GitHub 仓库注册: GitHub官网:https://github.com/ 修改本地仓库用户名: git config --local…

【B站保姆级视频教程:Jetson配置YOLOv11环境(五)Miniconda安装与配置】

Jetson配置YOLOv11环境(5)Miniconda安装与配置 文章目录 0. Anaconda vs Miniconda in Jetson1. 下载Miniconda32. 安装Miniconda33. 换源3.1 conda 换源3.2 pip 换源 4. 创建环境5. 设置默认启动环境 0. Anaconda vs Miniconda in Jetson Jetson 设备资…

Cesium ArcGisMapServerImageryProvider API 介绍

作为一名GIS研究生,WebGIS 技术无疑是我们必学的核心之一。说到WebGIS,要提的就是 Cesium —— 这个让3D地球可视化变得简单又强大的工具。为了帮助大家更好地理解和使用 Cesium,我决定把我自己在学习 Cesium 文档过程中的一些心得和收获分享…

STM32-时钟树

STM32-时钟树 时钟 时钟

【Unity3D】实现横版2D游戏——单向平台(简易版)

目录 问题 项目Demo直接使用免费资源:Hero Knight - Pixel Art (Asset Store搜索) 打开Demo场景,进行如下修改,注意Tag是自定义标签SingleDirCollider using System.Collections; using System.Collections.Generic;…

基于UKF-IMM无迹卡尔曼滤波与交互式多模型的轨迹跟踪算法matlab仿真,对比EKF-IMM和UKF

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于UKF-IMM无迹卡尔曼滤波与交互式多模型的轨迹跟踪算法matlab仿真,对比EKF-IMM和UKF。 2.测试软件版本以及运行结果展示 MATLAB2022A版本运行 3.核心程序 .…

学习笔记 ---- 平衡树 总结

文章目录 平衡树的含义二叉搜索树 t r e a p treap treap S p l a y Splay Splay(延伸树)优化思想 S p l a y Splay Splay 的定义核心操作文艺平衡树(序列操作)练习题平衡树维护序列平衡树维护数集 F H Q t r e a p FHQ \ treap F…

Windows基础

一. Windows防火墙与Defender 介绍:Windows防火墙与Defender是Windows操作系统中两大重要的安全组件,它们共同工作以保护计算机免受各种网络威胁和病毒攻击。 Windows防火墙:Windows防火墙是一种软件防火墙,旨在监控和控制进出计…

(1)Linux高级命令简介

Linux高级命令简介 在安装好linux环境以后第一件事情就是去学习一些linux的基本指令,我在这里用的是CentOS7作演示。 首先在VirtualBox上装好Linux以后,启动我们的linux,输入账号密码以后学习第一个指令 简介 Linux高级命令简介ip addrtou…

人工智能|基本概念|人工智能相关重要概念---AI定义以及模型相关知识

一、 前言: 最近deepseek(深度求索)公司的开源自然语言处理模型非常火爆。 本人很早就对人工智能比较感兴趣,但由于种种原因没有过多的深入此领域,仅仅是做了一点初步的了解,借着这个deepseek&#xff0…

【疑海破局】一个注解引发的线上事故

【疑海破局】一个注解引发的线上事故 1、问题背景 在不久前一个阳光明媚的上午,我的思绪正在代码中游走、双手正在键盘上飞舞。突然,公司内部通讯工具上,我被拉进了一个临时工作群,只见群中产品、运营、运维、测试等关键人员全部严阵以待,我就知道大的可能要来了。果不其…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.1 NumPy高级索引:布尔型与花式索引的底层原理

2.1 NumPy高级索引:布尔型与花式索引的底层原理 目录 #mermaid-svg-NpcC75NxxU2mkB3V {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-NpcC75NxxU2mkB3V .error-icon{fill:#552222;}#mermaid-svg-NpcC75…

如何在 ACP 中建模复合罐

概括 本篇博文介绍了 ANSYS Composite PrepPost (ACP) 缠绕向导。此工具允许仅使用几个条目自动定义高压罐中常见的悬垂复合结构。 ACP 绕线向导 将必要的信息输入到绕组向导中。重要的是要注意“参考半径”,它代表圆柱截面的半径,以及“轴向”&#x…

【Linux】使用管道实现一个简易版本的进程池

文章目录 使用管道实现一个简易版本的进程池流程图代码makefileTask.hppProcessPool.cc 程序流程: 使用管道实现一个简易版本的进程池 流程图 代码 makefile ProcessPool:ProcessPool.ccg -o $ $^ -g -stdc11 .PHONY:clean clean:rm -f ProcessPoolTask.hpp #pr…

【算法-位运算】求数字的补数

文章目录 1. 题目2. 思路3. 代码4. 小结 1. 题目 476. 数字的补数 对整数的二进制表示取反(0 变 1 ,1 变 0)后,再转换为十进制表示,可以得到这个整数的补数。 例如,整数 5 的二进制表示是 “101” &…

DeepSeek能下围棋吗?(续)

休息了一下,接着琢磨围棋,其实前面一篇里的规则有个漏洞的,就是邻居关系定义有问题,先回顾一下游戏规则: 游戏规则 定义: 1.数字对,是指两个1到9之间的整数组成的有序集合。可与记为(m,n)&…

[Collection与数据结构] B树与B+树

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

origin如何在已经画好的图上修改数据且不改变原图像的画风和格式

例如我现在的.opju文件长这样 现在我换了数据集,我想修改这两个图表里对应的算法里的数据,但是我还想保留这图像现在的形式,可以尝试像下面这样做: 右击第一个图,出现下面,选择Book[sheet1] 选择工作簿 出…

Workbench 中的热源仿真

探索使用自定义工具对移动热源进行建模及其在不同行业中的应用。 了解热源动力学 对移动热源进行建模为各种工业过程和应用提供了有价值的见解。激光加热和材料加工使用许多激光束来加热、焊接或切割材料。尽管在某些情况下,热源 (q) 不是通…

Midjourney中的强变化、弱变化、局部重绘的本质区别以及其有多逆天的功能

开篇 Midjourney中有3个图片“微调”,它们分别为: 强变化;弱变化;局部重绘; 在Discord里分别都是用命令唤出的,但如今随着AI技术的发达在类似AI可人一类的纯图形化界面中,我们发觉这样的逆天…