昇思25天学习打卡营第25天 | RNN情感分类

内容介绍:

情感分类是自然语言处理中的经典任务,是典型的分类问题。本节使用MindSpore实现一个基于RNN网络的情感分类模型,实现如下的效果:

输入: This film is terrible
正确标签: Negative
预测标签: Negative

输入: This film is great
正确标签: Positive
预测标签: Positive

具体内容:

1. 导包

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path
import re
import six
import string
import tarfile
import mindspore.dataset as ds
import zipfile
import numpy as np
import mindspore as ms
import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniform

2. 数据下载

# 指定保存路径为 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'

def http_get(url: str, temp_file: IO):
    """使用requests库下载数据,并使用tqdm库进行流程可视化"""
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit='B', total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()

def download(file_name: str, url: str):
    """下载数据并存为指定名称"""
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    cache_path = os.path.join(cache_dir, file_name)
    cache_exist = os.path.exists(cache_path)
    if not cache_exist:
        with tempfile.NamedTemporaryFile() as temp_file:
            http_get(url, temp_file)
            temp_file.flush()
            temp_file.seek(0)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)
    return cache_path
imdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_path

3. 数据集加载

class IMDBData():
    """IMDB数据集加载器

    加载IMDB数据集并处理为一个Python迭代对象。

    """
    label_map = {
        "pos": 1,
        "neg": 0
    }
    def __init__(self, path, mode="train"):
        self.mode = mode
        self.path = path
        self.docs, self.labels = [], []

        self._load("pos")
        self._load("neg")

    def _load(self, label):
        pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))
        # 将数据加载至内存
        with tarfile.open(self.path) as tarf:
            tf = tarf.next()
            while tf is not None:
                if bool(pattern.match(tf.name)):
                    # 对文本进行分词、去除标点和特殊字符、小写处理
                    self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r"))
                                         .translate(None, six.b(string.punctuation)).lower()).split())
                    self.labels.append([self.label_map[label]])
                tf = tarf.next()

    def __getitem__(self, idx):
        return self.docs[idx], self.labels[idx]

    def __len__(self):
        return len(self.docs)
imdb_train = IMDBData(imdb_path, 'train')
len(imdb_train)
def load_imdb(imdb_path):
    imdb_train = ds.GeneratorDataset(IMDBData(imdb_path, "train"), column_names=["text", "label"], shuffle=True, num_samples=10000)
    imdb_test = ds.GeneratorDataset(IMDBData(imdb_path, "test"), column_names=["text", "label"], shuffle=False)
    return imdb_train, imdb_test
imdb_train, imdb_test = load_imdb(imdb_path)
imdb_train

4. 加载预训练词向量

def load_glove(glove_path):
    glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')
    if not os.path.exists(glove_100d_path):
        glove_zip = zipfile.ZipFile(glove_path)
        glove_zip.extractall(cache_dir)

    embeddings = []
    tokens = []
    with open(glove_100d_path, encoding='utf-8') as gf:
        for glove in gf:
            word, embedding = glove.split(maxsplit=1)
            tokens.append(word)
            embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))
    # 添加 <unk>, <pad> 两个特殊占位符对应的embedding
    embeddings.append(np.random.rand(100))
    embeddings.append(np.zeros((100,), np.float32))

    vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)
    embeddings = np.array(embeddings).astype(np.float32)
    return vocab, embeddings

5. 下载Glove

glove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip')
vocab, embeddings = load_glove(glove_path)
len(vocab.vocab())
idx = vocab.tokens_to_ids('the')
embedding = embeddings[idx]
idx, embedding

6. 数据预处理

lookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
type_cast_op = ds.transforms.TypeCast(ms.float32)
imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])

imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])
imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
imdb_train = imdb_train.batch(64, drop_remainder=True)
imdb_valid = imdb_valid.batch(64, drop_remainder=True)

6. 模型构建

class RNN(nn.Cell):
    def __init__(self, embeddings, hidden_dim, output_dim, n_layers,
                 bidirectional, pad_idx):
        super().__init__()
        vocab_size, embedding_dim = embeddings.shape
        self.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)
        self.rnn = nn.LSTM(embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           batch_first=True)
        weight_init = HeUniform(math.sqrt(5))
        bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))
        self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)

    def construct(self, inputs):
        embedded = self.embedding(inputs)
        _, (hidden, _) = self.rnn(embedded)
        hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)
        output = self.fc(hidden)
        return output

7. 损失函数和优化器

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')

model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)

8. 训练过程

def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss

grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

def train_step(data, label):
    loss, grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train_one_epoch(model, train_dataset, epoch=0):
    model.set_train()
    total = train_dataset.get_dataset_size()
    loss_total = 0
    step_total = 0
    with tqdm(total=total) as t:
        t.set_description('Epoch %i' % epoch)
        for i in train_dataset.create_tuple_iterator():
            loss = train_step(*i)
            loss_total += loss.asnumpy()
            step_total += 1
            t.set_postfix(loss=loss_total/step_total)
            t.update(1)

9. 评估指标

def binary_accuracy(preds, y):
    """
    计算每个batch的准确率
    """

    # 对预测值进行四舍五入
    rounded_preds = np.around(ops.sigmoid(preds).asnumpy())
    correct = (rounded_preds == y).astype(np.float32)
    acc = correct.sum() / len(correct)
    return acc
def evaluate(model, test_dataset, criterion, epoch=0):
    total = test_dataset.get_dataset_size()
    epoch_loss = 0
    epoch_acc = 0
    step_total = 0
    model.set_train(False)

    with tqdm(total=total) as t:
        t.set_description('Epoch %i' % epoch)
        for i in test_dataset.create_tuple_iterator():
            predictions = model(i[0])
            loss = criterion(predictions, i[1])
            epoch_loss += loss.asnumpy()

            acc = binary_accuracy(predictions, i[1])
            epoch_acc += acc

            step_total += 1
            t.set_postfix(loss=epoch_loss/step_total, acc=epoch_acc/step_total)
            t.update(1)

    return epoch_loss / total

10. 模型训练与保存

num_epochs = 2
best_valid_loss = float('inf')
ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt')

for epoch in range(num_epochs):
    train_one_epoch(model, imdb_train, epoch)
    valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        ms.save_checkpoint(model, ckpt_file_name)

我对RNN的基本原理,如时间步的展开、状态传递、梯度消失与爆炸等问题有了更加深入的理解,还掌握了如何通过门控机制(如LSTM、GRU)来优化这些问题。MindSpore的API设计清晰直观,使得从理论到实践的转换变得顺畅无比,我能够迅速地将理论知识应用于构建具体的模型,如文本生成、情感分析或时间序列预测等任务中。、

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/796689.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

五. TensorRT API的基本使用-MNIST-model-build-infer

目录 前言0. 简述1. 案例运行2. 代码分析2.1 main函数2.2 build接口2.3 infer接口2.4 其他 总结参考 前言 自动驾驶之心推出的 《CUDA与TensorRT部署实战课程》&#xff0c;链接。记录下个人学习笔记&#xff0c;仅供自己参考 本次课程我们来学习课程第五章—TensorRT API 的基…

Mediapipe-姿态估计实例

Mediapipe简介 Mediapipe 是由 Google Research 开发的一款开源框架&#xff0c;旨在帮助开发者轻松地构建、测试和部署复杂的多模态、多任务的机器学习模型。它特别擅长于实时处理和分析音频、视频等多媒体数据。以下是 Mediapipe 的一些关键特点和组件&#xff1a; 关键特点…

Telegram Bot、小程序开发(一)基础入门

文章目录 一、Telegram Bot是什么&#xff1f;二、Telegram Bot应用场景三、机器人是如何工作的&#xff1f;架构getUpdates 和 webhookswebhooks要求自签名证书 四、如何创建和使用Telegram Bot&#xff1f;整体步骤和流程Bot 的申请过程将机器人添加到 Telegram 群组 一、Tel…

函数(实参以及形参)

实际参数&#xff08;实参&#xff09; 实际参数就是在调用函数时传递给函数的具体值。这些值可以是常量、变量、表达式或更复杂的数据结构。实参的值在函数被调用时传递给对应的形参&#xff0c;然后函数内部就可以使用这些值来执行相应的操作。 int main() {int a 0;int b …

嵌入式人工智能应用-篇外-烧写说明

1 外部接线 1.1 前期准备 需要准备的工具 ⚫ 一根 Mini USB 线 ⚫ 嵌入式人工智能教学科研平台 ⚫ 12V DC 电源 ⚫ 一台电脑 1.2 接线 12V DC 电源接入 12V IN&#xff1b;Mini USB 线连接 USB OTG&#xff1b;如果有两条 Mini USB 线&#xff0c;可以接入 UART2 to USB 口…

git安装使用gitlab

第一步&#xff1a;下载git 第二步&#xff1a;安装 第三步&#xff1a;配置sshkey 第四步&#xff1a;处理两台电脑的sshkey问题 第一步下载git 网址&#xff1a;Git点Downloads根据你的操作系统选择对应的版本&#xff0c;我的是Windows&#xff0c;所以我选择了Windows …

动手学深度学习(Pytorch版)代码实践 -注意力机制-Transformer

68Transformer 1. PositionWiseFFN 基于位置的前馈网络 原理&#xff1a;这是一个应用于每个位置的前馈神经网络。它使用相同的多层感知机&#xff08;MLP&#xff09;对序列中的每个位置独立进行变换。作用&#xff1a;对输入序列的每个位置独立地进行非线性变换&#xff0c…

Open-TeleVision——通过VR沉浸式感受人形机器人视野:兼备远程控制和深度感知能力

前言 7.3日&#xff0c;我司七月在线(集AI大模型职教、应用开发、机器人解决方案为一体的科技公司)的「大模型机器人(具身智能)线下营」群里的一学员发了《Open-TeleVision: Teleoperation with Immersive Active Visual Feedback》这篇论文的链接&#xff0c;我当时快速看了一…

【网络文明】关注网络安全

在这个数字化时代&#xff0c;互联网已成为我们生活中不可或缺的一部分&#xff0c;它极大地便利了我们的学习、工作、娱乐乃至日常生活。然而&#xff0c;随着网络空间的日益扩大&#xff0c;网络安全问题也日益凸显&#xff0c;成为了一个不可忽视的全球性挑战。认识到网络安…

C双指针滑动窗口算法

这也许是双指针技巧的最⾼境界了&#xff0c;如果掌握了此算法&#xff0c;可以解决⼀⼤类⼦字符串匹配的问题 原理 1、我们在字符串 S 中使⽤双指针中的左右指针技巧&#xff0c;初始化 left right 0&#xff0c;把索引闭区间 [left, right] 称为⼀个「窗⼝」。 2、我们先…

开发个人Ollama-Chat--6 OpenUI

开发个人Ollama-Chat–6 OpenUI Open-webui Open WebUI 是一种可扩展、功能丰富且用户友好的自托管 WebUI&#xff0c;旨在完全离线运行。它支持各种 LLM 运行器&#xff0c;包括 Ollama 和 OpenAI 兼容的 API。 功能 由于总所周知的原由&#xff0c;OpenAI 的接口需要密钥才…

总结单例模式的写法

一、单例模式的概念 1.1 单例模式的概念 单例模式&#xff08;Singleton Pattern&#xff09;是 Java 中最简单的设计模式之一。这种类型的设计模式属于创建型模式&#xff0c;它提供了一种创建对象的最佳方式。就是当前进程确保一个类全局只有一个实例。 1.2 单例模式的优…

首次跑通Arduino IDE + ESP8266

感触 网络&#xff0c;网络&#xff0c;还是网络。 中文开发者往往具备更多的网络知识和能力。 文档&#xff0c;文档&#xff0c;更是文档。 在计算机领域&#xff0c;遇到的问题&#xff0c;99.999%前人已经解决过。

Lesson 50 He likes ... But he doesn‘t like ...

Lesson 50 He likes … But he doesn’t like … 词汇 tomato n. 西红柿 复数&#xff1a;tomatoes 同义词&#xff1a;ketch-up 番茄酱     dead horse 例句&#xff1a;冰箱里有很多西红柿。    There are some tomatoes in the fridge. potato n. 土豆 复数&#…

【公益案例展】华为云X《无尽攀登》——攀登不停,向上而行

‍ 华为云公益案例 本项目案例由华为云投递并参与数据猿与上海大数据联盟联合推出的 #榜样的力量# 《2024中国数据智能产业最具社会责任感企业》榜单/奖项”评选。 大数据产业创新服务媒体 ——聚焦数据 改变商业 夏伯渝&#xff0c;中国无腿登珠峰第一人&#xff0c;一生43年…

Redis持久化RDB,AOF

目 录 CONFIG动态修改配置 慢查询 持久化 在上一篇主要对redis的了解入门&#xff0c;安装&#xff0c;以及基础配置&#xff0c;多实例的实现&#xff1a;redis的安装看我上一篇&#xff1a; Redis安装部署与使用,多实例 redis是挡在MySQL前面的&#xff0c;运行在内存…

Java基础-I/O流

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 字节流 定义 说明 InputStream与OutputStream示意图 说明 InputStream的常用方法 说明 OutputStrea…

虚函数__

10 文章目录 虚函数虚函数表override(不允许后续函数继承)虚析构纯虚函数 虚函数 虚函数表 override(不允许后续函数继承) 虚析构 纯虚函数

C++的deque(双端队列),priority_queue(优先级队列)

deque deque是一个容器,是双端队列,从功能上来讲,deque是一个vector和list的结合体 顺序表和链表 deque的结构和优缺点 开辟buff小数组,空间不够了,不扩容,而是开辟一个新的小数组 开辟中控数组(指针数组)指向buff小数组 将已存在的数组指针存在中控数组中间,可以使用下标访…

【ARM】CCI集成指导整理

目录 1.CCI集成流程 2.CCI功能集成指导 2.1CCI结构框图解释 Request concentrator Transaction tracker Read-data Network Write-data Network B-response Network 2.2 接口注意项 记录一下CCI500的ACE slave interface不支持的功能&#xff1a; 对于ACE-Lite slav…