CENet及多模态情感计算实战


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

一、概述

二、论文地址

三、研究背景

四、主要贡献

五、论文思路

六、主要内容和网络架构

七、数据集介绍

八、性能对比

九、复现过程(重要)

十、演示结果


本文所有资源均可在该地址处获取。

一、概述

本文对 “Cross-Modal Enhancement Network for Multimodal Sentiment Analysis” 论文进行讲解和手把手复现教学,解决当下热门的多模态情感计算问题,并展示在MOSI和MOSEI两个数据集上的效果

二、论文地址

DOI: 10.1109/TMM.2022.3183830

三、研究背景

情感分析在人工智能向情感智能发展中起着重要作用。早期的情感分析研究主要集中在分析单模态数据上,包括文本情感分析、图像情感分析、音频情感分析等。然而,人类的情感是通过人体的多种感官来传达的。因此,单模态情感分析忽略了人类情感的多维性。相比之下,多模态情感分析通过结合文本、视觉和音频等多模态数据来推断一个人的情感状态。与单模态情感分析相比,多模态数据包含多样化的情感信息,具有更高的预测精度。目前,多模态情感分析已被广泛应用于视频理解、人机交互、政治活动等领域。近年来,随着互联网和各种多媒体平台的快速发展,通过互联网表达情感的载体和方式也变得越来越多样化。这导致了多媒体数据的快速增长,为多模态情感分析提供了大量的数据源。下图展示了多模态在情感计算任务中的优势。

四、主要贡献

  • 提出了一种跨模态增强网络,通过融入长范围非文本情感语境来增强预训练语言模型中的文本表示;
  • 提出一种特征转换策略,通过减小文本模态和非文本模态的初始表示之间的分布差异,促进了不同模态的融合;
  • 融合了新的预训练语言模型SentiLARE来提高模型对情感词的提取效率,从而提升对情感计算的准确性。

五、论文思路

作者提出的跨模态增强网络(CENet)模型通过将视觉和声学信息集成到语言模型中来增强文本表示。在基于transformer的预训练语言模型中嵌入跨模态增强(CE)模块,根据非对齐非文本数据中隐含的长程情感线索增强每个单词的表示。此外,针对声学模态和视觉模态,提出了一种特征转换策略,以减少语言模态和非语言模态的初始表示之间的分布差异,从而促进不同模态的融合。

六、主要内容和网络架构

首先我们展示一下CENet的总体网络架构


通过该图,我们可以看出该模型主要有以下几部分组成:1.非文本模态特征转化;2.跨模态增强;3.预训练语言模型输出;接下来将对他们分别进行讲解:
1. 非文本模态转换
针对预训练语言模型,初始文本表示是基于词汇表的单词索引序列,而视觉和听觉的表示则是实值向量序列。为了缩小这些异质模态之间的初始分布差异,进而减少在融合过程中非文本特征和文本特征之间的分布差距,本文提出了一种将非文本向量转换为索引的特征转换策略。这种策略有助于促进文本表征与非文本情感语境的有机融合。

具体而言,特征转换策略利用无监督聚类算法分别构建了“声学词汇表”和“视觉词汇表”。通过查询这些非语言词汇表,可以将原始的非语言特征序列转换为索引序列。下图展示了特征转换过程的具体步骤。考虑到k-means方法具有计算复杂度低和实现简单等优点,作者选择使用k-means算法来学习非语言模态的词汇。

2. 跨模态增强模块
本文提出的CE模块旨在将长程视觉和声学信息集成到预训练语言模型中,以增强文本的表示能力。CE模块的核心组件是跨模态嵌入单元,其结构如下图所示。该单元利用跨模态注意力机制捕捉长程非文本情感信息,并生成基于文本的非语言嵌入。嵌入层的参数可学习,用于将经过特征转换策略处理后得到的非文本索引向量映射到高维空间,然后生成文本模态对非文本模态的注意力权重矩阵。

在初始训练阶段,由于语言表征和非语言表征处于不同的特征空间,它们之间的相关性通常较低。因此,注意力权重矩阵中的元素可能较小。为了更有效地学习模型参数,研究者在应用softmax之前使用超参数来缩放这些注意力权重矩阵。

基于注意力权重矩阵,可以生成基于文本的非语言向量。将基于文本的声学嵌入和基于文本的视觉嵌入结合起来,形成非语言增强嵌入。最后,通过整合非语言增强嵌入来更新文本的表示。因此,CE模块的提出旨在为文本提供非语言上下文信息,通过增加非语言增强嵌入来调整文本表示,从而使其在语义上更加准确和丰富。

3. 预训练语言模型输出
作者采用SentiLARE作为语言模型,其利用包括词性和单词情感极性在内的语言知识来学习情感感知的语言表示。CE模块被集成到预训练语言模型的第i层中。值得注意的是,任何基于Transformer的预训练语言模型都可以与CE模块集成。下面是作者根据SentiLARE的设置进行的步骤:

  1. 给定一个单词序列,首先通过Stanford Log-Linear词性(POS)标记器学习其词性序列,并通过SentiwordNet学习单词级情感极性序列。

  2. 然后,使用预训练语言模型的分词器获取词标索引序列。这个序列作为输入,产生一个初步的增强语言知识表示。

  3. 更新后的文本表示将作为第(i+1)层的输入,并通过SentiLARE中的剩余层进行处理。

  4. 每一层的输出将是具有视觉和听觉信息的文本主导的高级情感表示。

  5. 最后,将这些文本表示输入到分类头中,以获取情感强度。

因此,CE模块通过将非语言增强嵌入集成到预训练语言模型中,有助于生成更富有情感感知的语言表示。这种方法能够在文本表示中有效地整合视觉和听觉信息,从而提升情感分析等任务的性能。

七、数据集介绍

1. CMU-MOSI: 它是一个多模态数据集,包括文本、视觉和声学模态。它来自Youtube上的93个电影评论视频。这些视频被剪辑成2199个片段。每个片段都标注了[-3,3]范围内的情感强度。该数据集分为三个部分,训练集(1,284段)、验证集(229段)和测试集(686段)。
2. CMU-MOSEI: 它类似于CMU-MOSI,但规模更大。它包含了来自在线视频网站的23,453个注释视频片段,涵盖了250个不同的主题和1000个不同的演讲者。CMU-MOSEI中的样本被标记为[-3,3]范围内的情感强度和6种基本情绪。因此,CMU-MOSEI可用于情感分析和情感识别任务。

八、性能对比

有下图可以观察到,该论文提出的CENet与其他SOTA模型对比性能有明显提升:

九、复现过程(重要)

1. 数据集准备
下载MOSI和MOSEI数据集已提取好的特征文件(.pkl)。把它放在"./dataset”目录。

2. 下载预训练语言模型
下载SentiLARE语言模型文件,然后将它们放入"/pretrained-model / sentilare_model”目录。

3. 下载需要的包

pip install -r requirements.txt

4. 搭建CENet模块
利用pytorch框架对CENet模块进行搭建:

class CE(nn.Module):
    def __init__(self, beta_shift_a=0.5, beta_shift_v=0.5, dropout_prob=0.2):
        super(CE, self).__init__()
        self.visual_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)
        self.acoustic_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)
        self.hv = SelfAttention(TEXT_DIM)
        self.ha = SelfAttention(TEXT_DIM)
        self.cat_connect = nn.Linear(2 * TEXT_DIM, TEXT_DIM)
        

    def forward(self, text_embedding, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None):
        visual_ = self.visual_embedding(visual_ids)
        acoustic_ = self.acoustic_embedding(acoustic_ids)
        visual_ = self.hv(text_embedding, visual_)
        acoustic_ = self.ha(text_embedding, acoustic_) 
        visual_acoustic = torch.cat((visual_, acoustic_), dim=-1)
        shift = self.cat_connect(visual_acoustic)
        embedding_shift = shift + text_embedding
    
        return embedding_shift

5. 将CE模块与预训练语言模型融合

class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.CE = CE()

    def forward(self, hidden_states, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None, attention_mask=None, head_mask=None):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if i == ROBERTA_INJECTION_INDEX:
                hidden_states = self.CE(hidden_states, visual=visual, acoustic=acoustic, visual_ids=visual_ids, acoustic_ids=acoustic_ids)

6. 训练代码编写

  • 定义一个整体的训练过程 train()函数,它负责训练模型多个 epoch,并在每个 epoch 结束后评估模型在验证集和测试集上的性能,并记录相关的指标和损失;并在训练最后一轮输出所有测试集id,true label 和 predicted label。
def train(
    args,
    model,
    train_dataloader,
    validation_dataloader,
    test_data_loader,
    optimizer,
    scheduler,
):
    valid_losses = []
    test_accuracies = []
    for epoch_i in range(int(args.n_epochs)):
        train_loss, train_pre, train_label = train_epoch(args, model, train_dataloader, optimizer, scheduler)
        valid_loss, valid_pre, valid_label = evaluate_epoch(args, model, validation_dataloader)

        test_loss, test_pre, test_label = evaluate_epoch(args, model, test_data_loader)
        train_acc, train_mae, train_corr, train_f_score = score_model(train_pre, train_label)
        test_acc, test_mae, test_corr, test_f_score = score_model(test_pre, test_label)
        non0_test_acc, _, _, non0_test_f_score = score_model(test_pre, test_label, use_zero=True)
        valid_acc, valid_mae, valid_corr, valid_f_score = score_model(valid_pre, valid_label)

        print(
            "epoch:{}, train_loss:{}, train_acc:{}, valid_loss:{}, valid_acc:{}, test_loss:{}, test_acc:{}".format(
                epoch_i, train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc
            )
        )
        valid_losses.append(valid_loss)
        test_accuracies.append(test_acc)
        wandb.log(
            (
                {
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                    "train_acc": train_acc,
                    "train_corr": train_corr,
                    "valid_acc": valid_acc,
                    "valid_corr": valid_corr,
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                    "test_mae": test_mae,
                    "test_corr": test_corr,
                    "test_f_score": test_f_score,
                    "non0_test_acc": non0_test_acc,
                    "non0_test_f_score": non0_test_f_score,
                    "best_valid_loss": min(valid_losses),
                    "best_test_acc": max(test_accuracies),
                }
            )
        )

    # 输出测试集的 id、真实标签和预测标签
    with torch.no_grad():
        for step, batch in enumerate(test_data_loader):
            batch = tuple(t.to(DEVICE) for t in batch)
            input_ids, visual_ids, acoustic_ids, pos_ids, senti_ids, polarity_ids, visual, acoustic, input_mask, segment_ids, label_ids = batch
            visual = torch.squeeze(visual, 1)
            outputs = model(
                input_ids,
                visual,
                acoustic,
                visual_ids,
                acoustic_ids,
                pos_ids, senti_ids, polarity_ids,
                token_type_ids=segment_ids,
                attention_mask=input_mask,
                labels=None,
            )
            logits = outputs[0]
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.detach().cpu().numpy()
            logits = np.squeeze(logits).tolist()
            label_ids = np.squeeze(label_ids).tolist()
            
            # 假设您从 label_ids 中获取 ids
            ids = [f"sample_{idx}" for idx in range(len(label_ids))]  # 这里是示例,您可以根据实际情况生成合适的 ids
            
            # 输出所有测试样本的 id、真实标签和预测标签值
            for i in range(len(ids)):
                print(f"id: {ids[i]}, true label: {label_ids[i]}, predicted label: {logits[i]}")

6. 开始训练+测试

python train.py

7. 输出结果
此外,为了方便直观地查看模型性能,我在最后一层训练结束后将所有测试集视频的clip id、真实标签和预测标签依次进行输出;并且结合wandb库自动保存结果可视化;结果在后续章节展示。

十、演示结果

  • 以下是我们的训练过程

  • 下面是模型性能结果

  • 接下来是我们自己补充的每个测试集的真实标签和预测标签

  • 可视化

​​

希望对你有帮助!加油!

若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/926813.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于深度学习和卷积神经网络的乳腺癌影像自动化诊断系统(PyQt5界面+数据集+训练代码)

乳腺癌是全球女性中最常见的恶性肿瘤之一,早期准确诊断对于提高生存率具有至关重要的意义。传统的乳腺癌诊断方法依赖于放射科医生的经验,然而,由于影像分析的复杂性和人类判断的局限性,准确率和一致性仍存在挑战。近年来&#xf…

【热门主题】000074 深度学习模型:探索与应用

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 【热…

MacOS使用VSCode编写C++程序如何配置clang编译环境

前言 这段时间在练习写C和Python,用vscode这个开发工具,调试的时候遇到一些麻烦,浪费了很多时间,因此整理了这个文档。将详细的细节描述清楚,避免与我遇到同样问题的人踩坑。 1.开发环境的配置 vscode的开发环境配置…

Scala关于成绩的常规操作

score.txt中的数据: 姓名,语文,数学,英语 张伟,87,92,88 李娜,90,85,95 王强,78,90,82 赵敏,92,8…

【实战】在Koa.js中实现文件上传的接口 (本地存储)

目录 环境准备 使用 koa-body 中间件获取上传的文件 使用 Postman 测试 使用 koa-static 中间件生成图片链接 编写前端页面上传文件 文件上传是一个基本的功能,每个系统几乎都会有,比如上传图片、上传Excel等。那么在Node Koa应用中如何实现一个支持…

使用html语言完成拼多多移动端导航栏的设计-大连东软信息学院计算机科学与技术专业高级网页设计基础课题

目录 前言 一、效果图 二、图标的使用 三、代码的编写 四、运行效果 五、文档编写 前言 1.本文所讲内容来自辽宁大连东软信息学院计算机与技术专业高级网页设计(专升本)课程期中四级项目课题之一,题目要求是自主选择相应的APP移动端&…

从语法、功能、社区和使用场景来比较 Sass 和 LESS

一:可以从语法、功能、社区和使用场景来比较 Sass 和 LESS: 1:语法 原始的 Sass 采用的是缩进而不是大括号,后续的 Sass 版本与 LESS 一样使用与 CSS 类似的语法: address {.fa.fa-mobile-phone {margin: 0 3px 0 2…

7. 现代卷积神经网络

文章目录 7.1. 深度卷积神经网络(AlexNet)7.2. 使用块的网络(VGG)7.3. 网络中的网络(NiN)7.4. 含并行连结的网络(GoogLeNet)7.5. 批量规范化7.5.1. 训练深层网络7.5.2. 批量规范化层…

sqlmap详细使用

SQLmap使用详解 SQLmap(常规)使用步骤 1、查询注入点 python sqlmap.py -u http://127.0.0.1/sqli-labs/Less-1/?id12、查询所有数据库 python sqlmap.py -u http://127.0.0.1/sqli-labs/Less-1/?id1 --dbs3、查询当前数据库 python sqlmap.py -u htt…

React+TS+css in js 练习

今天分享的内容是动态规划的经典问题--0-1 背包问题 0-1背包问题的描述如下:给定一组物品,每种物品都有自己的重量和价值,背包的总容量是固定的。我们需要从这些物品中挑选一部分,使得背包内物品的总价值最大,同时不超过背包的总容量。 举个例子:假设这组物品的质量…

刷题日常(找到字符串中所有字母异位词,​ 和为 K 的子数组​,​ 滑动窗口最大值​,全排列)

找到字符串中所有字母异位词 给定两个字符串 s 和 p,找到 s 中所有 p 的 异位词的子串,返回这些子串的起始索引。不考虑答案输出的顺序。 题目分析: 1.将p里面的字符先丢进一个hash1中,只需要在S字符里面找到多少个和他相同的has…

《C++ Primer Plus》学习笔记|第8章 函数探幽 (24-11-30更新)

文章目录 8.1 内联函数8.2 引用变量8.2.1 创建引用变量8.2.2 将引用用作函数参数8.2.3 引用的属性和特别之处特点1:在计算过程中,传入的形参的值也被改变了。特点2:使用引用的函数参数只接受变量,而不接受变量与数值的运算左值引用…

[2024年1月28日]第15届蓝桥杯青少组stema选拔赛C++中高级(第二子卷、编程题(1))

参考程序&#xff1a; #include <iostream> #include <algorithm> // 用于 std::sortusing namespace std;int main() {int a, b, c;cin >> a >> b >> c;// 将三个数放入一个数组中int arr[3] {a, b, c};// 对数组进行排序sort(arr, arr 3);…

基于hexo框架的博客搭建流程

这篇博文讲一讲hexo博客的搭建及文章管理&#xff0c;也算是我对于暑假的一个交代 &#xff01;&#xff01;&#xff01;注意&#xff1a;下面的操作是基于你已经安装了node.js和git的前提下进行的&#xff0c;并且拥有github账号 创建一个blog目录 在磁盘任意位置创建一个…

基于Java Springboot传统戏曲推广微信小程序

一、作品包含 源码数据库设计文档万字PPT全套环境和工具资源部署教程 二、项目技术 前端技术&#xff1a;Html、Css、Js、Vue、Element-ui 数据库&#xff1a;MySQL 后端技术&#xff1a;Java、Spring Boot、MyBatis 三、运行环境 开发工具&#xff1a;IDEA/eclipse 微信…

数据结构--树二叉树顺序结构存储的二叉树(堆)

前言 前面我们学习了顺序表、链表、栈和队列&#xff0c;这些都是线性的数据结构。今天我们要来学习一种非线性的数据结构——树。 树的概念及结构 树的概念 树是一种非线性的数据结构&#xff0c;是由n&#xff08;n≥0&#xff09;个有效结点组成的一个具有层次关系的集合…

网络安全运行与维护 加固练习题

1. 提交用户密码的最小长度要求。 输入代码: cat /etc/pam.d/common-password 提交答案: flag{20} 2.提交iptables配置以允许10.0.0.0/24网段访问22端口的命令。 输入代码: iptables -A INPUT -p tcp -s 10.0.0.0/24 --dport 22 -j ACCEPT 提交答案: flag{iptables -A I…

【汇编语言】call 和 ret 指令(三) —— 深度解析汇编语言中的批量数据传递与寄存器冲突

文章目录 前言1. 批量数据的传递1.1 存在的问题1.2 如何解决这个问题1.3 示例演示1.3.1 问题说明1.3.2 程序实现 2. 寄存器冲突问题的引入2.1 问题引入2.2 分析与解决问题2.2.1 字符串定义方式2.2.2 分析子程序功能2.2.3 得到子程序代码 2.3 子程序的应用2.3.1 示例12.3.2 示例…

Java 泛型详细解析

泛型的定义 泛型类的定义 下面定义了一个泛型类 Pair&#xff0c;它有一个泛型参数 T。 public class Pair<T> {private T start;private T end; }实际使用的时候就可以给这个 T 指定任何实际的类型&#xff0c;比如下面所示&#xff0c;就指定了实际类型为 LocalDate…

Python语法基础(四)

&#x1f308;个人主页&#xff1a;羽晨同学 &#x1f4ab;个人格言:“成为自己未来的主人~” 高阶函数之map 高阶函数就是说&#xff0c;A函数作为B函数的参数&#xff0c;B函数就是高阶函数 map&#xff1a;映射 map(func,iterable) 这个是map的基本语法&#xff0c;…