Embedding模型提升效果的方法之一:Whitening和pooling

0. 前言

Embedding模型的主流框架基本上分为三类——基于bert结构的,基于GPT结构的和基于T5结构的,当然这些结构都是Transformer的变形。对于Embedding模型,使用bert结构目前看是最好的。有篇论文论文对基于bert的Embedding模型和基于GPT的Embedding模型做过比较,暂时找不到了,后续找到会附上。另外本人也对基于bert的embedding模型和基于GPT的embedding模型做了比较试验,结果表明基于bert的embedding模型完胜。

要让 embedding 模型性能提升,除了在模型结构和训练数据上做文章之外,还可以使用Whitening方法,特殊的pooling方法和Simcse方法来提升效果。本文先介绍Whitening方法和特殊的pooling方法,下一篇介绍Simcse方法。

1. Whitening

1.1 为什么可以用白化方法提升模型效果

白化操作不仅可以提升模型效果,还可以对句子向量进行降维

白化(whitening)方法之所以能够在embedding模型上产生正向效果是因为我们通常会用两个句子向量的余弦相似度来衡量这两个句子的相似性。但是由于类似由BERT和GPT这样的预训练语言模型得到的句子向量往往是具备各向异性的,表现状态就是向量会不均匀分布,且充斥在一个狭窄的锥形空间下。但是,具备各向异性的向量之间直接计算余弦相似度会出现一些偏差,这就导致 embedding 模型的表现变差。所以,我们只要将各向异性的句子向量转化为一个各向同性的句子向量就可以提升 embedding 模型的效果。此时就用到了白化操作。

在这里插入图片描述

1.2 余弦相似度和各向异性

  • consin
    假设 x x x y y y 两个向量,维度都是 R d R^d Rd。那么,利用cosine的计算方法,他们的相似度为:
    在这里插入图片描述
    上述方程 (1) 仅在坐标基(二维向量)为标准正交基时才成立。余弦角度具有明显的几何意义,但方程(1) 是基于运算的,它取决于所选的坐标基。因此,内积的坐标公式随着坐标基的变化而变化,余弦值的坐标公式也会随之变化。

    (Li et al., 2020) 验证了来自 BERT (Devlin et al., 2019) 的句子嵌入虽然没有得到适当的利用,但已经包含了足够的语义。在这种情况下,如果在操作方程(1) 计算语义相似度的余弦值时句子向量表现不佳,原因可能是句子向量所属的坐标基不是标准正交基。从统计学的角度,我们可以推断,当我们为一组向量选择基时,应该保证每个基向量是独立且一致的。如果这组基是标准正交基,则相应的向量组应该具备各向同性。

    综上所述,上述启发式假设详尽地表明:如果一组向量满足各向同性,我们可以假设从标准正交基推导出来,这也表明我们可以通过方程 (1) 计算余弦相似度。否则,如果它是各向异性的,我们需要对原始向量进行变换以某种方式嵌入句子以强制它是各向同性的,然后使用等式 (1) 计算余弦相似度。

  • 各向异性

    定义:各向异性是指在不同的方向上物理性质(表达含义)不同,各向同性是指不同的方向上物理性质相同。

    BERT和GPT的各项异性是怎么产生的:假设一个句子的向量为 { x i } i = 1 N \{x_i\}_{i=1}^N {xi}i=1N,某2个字的向量分别为 x j = [ x j 1 , x j 2 , x j 3 , … , x j n ] x_j=[x_j^1,x_j^2,x_j^3,\dots,x_j^n] xj=[xj1,xj2,xj3,,xjn] x h = [ x h 1 , x h 2 , x h 3 , … , x h n ] x_h=[x_h^1,x_h^2,x_h^3,\dots,x_h^n] xh=[xh1,xh2,xh3,,xhn],其中 可以理解为参数句子长度sequence_length。由于 BERT 的Token Embedding与Position Embedding的设计结构,导致了生成的句子向量不仅仅包含某单一token的MASK信息,同时还具备了这个token在不同位置所代表的Position信息,这直接为各向异性创作了条件。简单理解,假设 x j 1 x_j^1 xj1 x j 2 x_j^2 xj2 分别代表这个 token 的 mask 信息和 position 信息(实际不是这样简单的),这两个维度就是不同方向有不同的性质。再举一个反例,one-hot词向量就不具备各向异性,而是具备了各向同性的特点。

1.3 白化计算

经过上述的分析,想要让基于 bert 的 embedding 模型生成的句子向量用余弦相似度正确表示两个句子的相似性就得将句子向量转换到标准正交基下面

  • 标准正交基
    我们知道,对于两个向量 A A A B B B来说,如果 A ⋅ B = 0 A \cdot B=0 AB=0,那么,我们称这两个向量正交(零向量与任何向量正交)。 我们知道,在n维的欧式空间中,由n个向量组成的正交向量组称为正交基;由单位向量组成的正交基称为标准正交基。

已知:
在这里插入图片描述

在这里插入图片描述
A A A B B B 都不是0向量的时候,要让 A ⋅ B = 0 A \cdot B=0 AB=0,则 c o s ( A , B ) = 0 cos(A,B)=0 cos(A,B)=0,也就是

∑ i = 1 d a i b i = 0 \sum_{i=1}^da_ib_i=0 i=1daibi=0,而 ∑ i = 1 d a i b i \sum_{i=1}^da_ib_i i=1daibi 表示向量 A A A B B B 的协方差。

此时问题转换为:

已知原句子向量矩阵为 X X X,协方差矩阵为 C C C,目标是将 X X X 转换为协方差为0的向量矩阵 Y Y Y Y Y Y 的协方差矩阵为 D D D,求转换矩阵为 P P P

其中 X X X 可表示为:
在这里插入图片描述
其中 n n n 为sequence length, a a a b b b 代表不同的维度。

然后根据协方差的计算公式,可得:

在这里插入图片描述
我们可以看到这个矩阵对角线上的分别是两个变量的方差,而其它元素是 a 和 b 的协方差。两者被统一到了一个矩阵里。 我们很容易被推广到一般情况:

设我们有 m 个 n 维数据记录,将其排列成矩阵 X m , n X_{m,n} Xm,n ,设 C = 1 m X X T C = \frac{1}{m}XX^T C=m1XXT,则 C C C 是一个对称矩阵,其对角线分别对应各个变量的方差,而第 i 行 j 列和 j 行 i 列元素相同,表示 i 和 j 两个变量的协方差。

由此可知,我们需要将除对角线外的其它元素化为 0,并且在对角线上将元素按大小从上到下排列(变量方差尽可能大),这里就是将协方差转为一个单位矩阵,也就是矩阵的对角化。

推导一下 D D D C C C 的关系:
在这里插入图片描述
现在的目标变成了让 D D D 变成一个对角矩阵,这样的话 Y Y Y 的协方差就都为0了。并且对角元素按从大到小依次排列,那么 P P P 的前 K K K 行就是要寻找的基,用 P P P 的前 K K K 行组成的矩阵乘以 X X X 就使得 X X X N N N 维降到了 K K K 维并满足上述优化条件。

回到 embedding 模型本身的输出,根据上述的协方差矩阵,假设有一组句子向量,也可以写为行向量 { x } i = 1 N \{x\}_{i=1}^N {x}i=1N,在对它做线性变换之后生成一个均值为0、协方差矩阵为单位阵的目标向量 { x ~ } i = 1 N \{\tilde x\}_{i=1}^N {x~}i=1N

在这里插入图片描述
其中:
在这里插入图片描述

下面求解 W W W,将原始数据的协方差记为:
在这里插入图片描述
由上面推导出的 D = P C P T D=PCP^T D=PCPT,可以得到 ∑ ~ = W ∑ W T \tilde \sum=W\sum W^T ~=WWT,而我们的目标是 W ∑ W T = I W\sum W^T=I WWT=I,于是可得:
在这里插入图片描述
我们知道 ∑ \sum 是一个正定对称矩阵,正定对称矩阵都具有如下形式的SVD分解:
在这里插入图片描述
其中 U U U 是一个正交矩阵, ∧ \land 是一个正对角矩阵,则可以让 W − 1 = ∧ U T W^{-1}=\sqrt{\land}U^T W1= UT,则可以得到:

在这里插入图片描述
由于 ∧ \land U U U 均可以由 ∑ \sum 求得,所以 W W W 就被求出来了。

1.4 代码实现

def compute_kernel_bias(vecs, n_components=None):
    """计算kernel和bias
        vecs.shape = [num_samples, embedding_size],
        最后的变换:y = (x + bias).dot(kernel)
    :return kernel, bias
    """
    if isinstance(vecs, list):
        vecs = np.concatenate(vecs, axis=0)
    mu = vecs.mean(axis=0, keepdims=True)
    cov = np.cov(vecs.T)
    u, s, vh = np.linalg.svd(cov)
    W = np.dot(u, np.diag(1 / np.sqrt(s)))
    print(W)
    print(-mu)
    if n_components is not None:
        return W[:, :n_components], -mu
    else:
        return W, -mu

【论文解读】BERT Whitening
whitening计算详解

2. pooling

模型生成文本的 embedding 时会用 pooling 的方法进行降维,但是在降维操作的时候常见的mean pooling 和 last token pooling都有一定的局限性,比如都会忽略位置信息,如果使用position weighted mean pooling将位置信息加进来会有更好的效果。

hidden_state的形状为(batch_size, sequence_length, hidden_size),mean pooling 就是在sequence_length上进行平均,生成形状为(batch_size, hidden_size) 的embedding 矩阵。

而last token pooling是直接取 “[CLS]” 字短的embedding 作为整个文本的embedding。

position weighted mean pooling是在mean pooling 里面加上了位置信息。

  • mean pooling
def mean_pooling(hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
    if attention_mask is None:
        return torch.mean(hidden_state, dim=1)
    attention_mask = attention_mask.float()
    return torch.sum(hidden_state * attention_mask.unsqueeze(-1), dim=1) / torch.sum(attention_mask, dim=-1, keepdim=True)
  • position weighted mean pooling
def position_weighted_mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
    weights = (
        torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(last_hidden_state.size())
            .float().to(last_hidden_state.device)
    )

    input_mask_expanded = (
        attention_mask
            .unsqueeze(-1)
            .expand(last_hidden_state.size())
            .float()
    )

    # Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
    sum_mask = torch.sum(input_mask_expanded * weights, dim=1)

    embeddings = sum_embeddings / sum_mask

    return embeddings

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/491682.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

How to convert .py to .ipynb in Ubuntu 22.04

How to convert .py to .ipynb in Ubuntu 22.04 jupyter nbconvertp2j 最近看到大家在用jupyter notebook,我也试了一下,感觉还不错,不过,也遇到了一些问题,比方说,我有堆的.py文件,如果要一个一…

STM32 使用gcc编译介绍

文章目录 前言1. keil5下的默认编译工具链用的是哪个2. Arm编译工具链和GCC编译工具链有什么区别吗?3. Gcc交叉编译工具链的命名规范4. 怎么下载gcc-arm编译工具链参考资料 前言 我们在STM32上进行开发时,一般都是基于Keil5进行编译下载,Kei…

C++STL学习之unordered_map与unordered_set(底层Hash)

前言:我们前面已经学习论map和set,现在又冒出来一个unordered_map和unordered_set,这两个有啥差别吗?前面我们已经说过,map和set的底层是红黑树,那unordered_map和unordered_set的底层是什么呢?…

2024 解决 Failed to launch process [ElasticSearch]

操作系统:centos 7 (x86) sonarQube不能使⽤root账号进⾏启动,所以需要创建普通⽤户及其⽤户组 一、问题描述:使用root启动时,一直反馈 SonarQube is not running 问题原因:不能够使用root用户进行启动 解决方案…

Python(Socket) +Unreal(HTTP)

Python(Socket) Unreal(HTTP) python(Socket):UE:Post请求并发送本机IP 上班咯,好久没记笔记了。。。 局域网 UE的apk,请求Python的Socket 跑起Socket ,UE发 …

【Python机器学习系列】sklearn机器学习模型的保存---joblib法

这是我的第247篇原创文章。 一、引言 joblib包是由scikit-learn外带的,是一个用于将Python对象序列化为磁盘文件的库,专门用于大型数组,常用于保存机器学习模型。它可以高效地处理大型数据集和模型。对于大数据和大型机器学习模型&#xff0…

JavaScript高级(一)--V8引擎上

浏览器渲染的原理 主流浏览器及其内核 内核浏览器css前缀备注TridentIE4-IE11-ms最新的Edge已转向BlinkGecko火狐浏览器-mozWebkitsafari、旧版谷歌-webkitBlinkGoogle Chrome-webkitPrestoopera-o现在的opera转向了Blink 我们常说的浏览器内核指的就是浏览器的排版引擎&…

【No.20】蓝桥杯简单数论下|寻找整数|素数的判断|笨小猴|最大最小公倍数|素数筛|埃氏筛|欧氏线性筛|质数|分解质因子(C++)

寻找整数 【题目描述】 有一个不超过 1 0 1 7 10^17 1017的正整数n,知道这个数除以2至49后的余数如下表所示,求这个正整数最小是多少 解法一:模拟 暴力法:一个个检验 1 … 1 0 17 1\dots 10^{17} 1…1017的每个数 由于这个数n…

证券公司数据摆渡,如何兼顾安全性、可控性和效率?

根据国家和金融行业的法律法规要求,我国的证券公司不少采用网络隔离的方式将内部网络隔离为操作内网和操作外网,但网络隔离后,证券公司的操作内外网间仍需要进行数据交换,如提数、与第三方合作机构的数据外发和收取等业务需求&…

【AI绘画/作图】风景背景类关键词模板参考

因为ds官网被墙,所以翻了IDE的源码整理了下stablestudio里的官方模板,顺便每个模板生成了一份…不知道怎么写关键词的可以参考 Stunning sunset over a futuristic city, with towering skyscrapers and flying vehicles, golden hour lighting and dramatic cloud…

MySQL数据库的高级SQL语句与高级操作(1)

目录 以下例子都是基于该数据表 1、查询不重复记录(distinct) 2、and 、or:根据多条件查询 3、IN ----显示已知的值的数据记录 4、BETWEEN ----显示两个值范围内的数据记录 5、 like通配符:模糊查询 6、order by&#xff1a…

vlan、三层交换机、网关、DNS、子网掩码、MAC地址详解

vlan、三层交换机、网关、DNS、子网掩码、MAC地址详解 一、 什么是VLAN? VLAN中文是“虚拟局域网”。 ​ LAN可以是由少数几台家用计算机构成的网络,也可以是数以百计的计算机构成的企业网络。 ​ VLAN所指的LAN特指使用路由器分割的网络——也就是广…

【数字图像处理】改变图像灰度级别

改变图像灰度级别 首先,对原始图像 O O O进行灰度级量化: q int ⁡ ( O 2 i ) 2 i , q\operatorname{int}\left(\frac{O}{2^{i}}\right) \times 2^{i}, qint(2iO​)2i, 灰度级别256,128,64,32,16,8&…

FastAPI+React全栈开发08 安装MongoDB

Chapter02 Setting Up the Document Store with MongoDB 08 Installing MongoDB and friends FastAPIReact全栈开发08 安装MongoDB The MongoDB ecosystem is composed of different pieces of software, and I remember that when I was starting to play with it, there w…

QT_day5:使用定时器实现闹钟

1、 程序代码&#xff1a; widget.h&#xff1a; #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QTime>//时间类 #include <QTimer>//时间事件类 #include <QTextToSpeech>//文本转语音类 QT_BEGIN_NAMESPACE namespace Ui { cla…

深度学习十大算法之图神经网络(GNN)

一、图神经网络的基础 图的基本概念 图是数学中的一个基本概念&#xff0c;用于表示事物间复杂的关系。在图论中&#xff0c;图通常被定义为一组节点&#xff08;或称为顶点&#xff09;以及连接这些节点的边。每个边可以有方向&#xff0c;称为有向边&#xff0c;或者没有方向…

C#学习笔记4:PC串口发送数据

今日继续我的C#学习之路&#xff0c;今日学习制作PC串口发送数据的窗口程序 串口是单片机上位机开发的重点&#xff0c;本文围绕做一个通过PC端串口发送数据的程序进行实践学习&#xff0c; 文章提供源码与解释、整体工程文件 目录 1、控件的选择与摆放&#xff1a; 2、程序设…

46 div 下面包含 el-radio, 导致点击一次 div, div 的 click 事件执行多次

前言 这是一个最近碰到的一个很奇怪的问题 情况如下一个 div 下面有一个 el-radio, 然后 div 上面配置了 click 的回调为 handleClick 然后 但是点击 div 的时候, handleClick 触发了两次 然后 这里 来模拟一下, 并解决一下 这个问题 这里的知识主要是 设计到 label 和 …

pytorch反向传播算法

目录 1. 链式法则复习2. 多输出感知机3. 多层感知机4. 多层感知机梯度推导5. 反向传播的总结 1. 链式法则复习 2. 多输出感知机 3. 多层感知机 如图&#xff1a; 4. 多层感知机梯度推导 简化式子把( O k O_k Ok​ - t k t_k tk​) O k O_k Ok​(1 - O k O_k Ok​)起个别名…

09-LearnTheArchitecture-MemoryManagement

快速链接: 【精选】ARMv8/ARMv9架构入门到精通-[目录] &#x1f448;&#x1f448;&#x1f448; 1 Overview 本文介绍了 Armv8-A 中的内存转换&#xff0c;这是内存管理的关键。 它解释了虚拟地址如何转换为物理地址、转换表格式以及软件如何管理Translation Lookaside Buffe…