Transformer - 注意力机制
flyfish
计算过程
flyfish
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import math
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: tensor whose last dimension has size d_k (usually the word
            embedding size). Leading dimensions are handled as batch
            dimensions by ``torch.matmul``.
        key: tensor with the same last dimension as ``query``; its
            second-to-last dimension indexes the positions attended to.
        value: tensor whose second-to-last dimension matches ``key``'s.
        mask: optional tensor broadcastable to the score shape. Positions
            where ``mask == 0`` are excluded from attention.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        A tuple ``(output, p_attn)``. ``p_attn`` holds the attention weights
        (softmax over the last dimension) and ``output`` is
        ``p_attn @ value``.
    """
    # Size of query's last dimension; usually equal to the word-embedding size.
    d_k = query.size(-1)
    # Dividing by sqrt(d_k) makes this *scaled* dot-product attention (the
    # motivation in "Attention Is All You Need" is to keep large dot products
    # from saturating the softmax).
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    print("scores.shape:",scores.shape)#scores.shape: torch.Size([1, 12, 12])
    if mask is not None:
        # Fill with the most negative finite value of the score dtype rather
        # than a hard-coded -1e9: -1e9 does not fit in float16 and makes
        # masked_fill raise. After softmax the masked weights are still ~0,
        # and a fully masked row still becomes uniform.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    For position ``pos`` and frequency index ``2i``::

        PE[pos, 2i]     = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: size of the inputs' last dimension (embedding size). Odd
            values are supported; the last column then holds only a sine.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        # Positions as a column vector, shape (max_len, 1), so the product
        # below broadcasts against the row of frequencies.
        position = torch.arange(0, max_len).unsqueeze(1)
        # 1 / 10000^(2i / d_model) for 2i = 0, 2, 4, ..., written as
        # exp(-2i * ln(10000) / d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns, one fewer than the even columns
        # when d_model is odd, so trim the frequencies to match; otherwise
        # the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading batch dimension: shape (1, max_len, d_model).
        pe = pe.unsqueeze(0)
        # Registered as a buffer: part of state_dict and moved by .to(),
        # but not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Assumes ``x`` is ``(batch, seq_len, d_model)`` with
        ``seq_len <= max_len`` — the encoding is sliced along dim 1 and
        broadcast over dim 0.
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
# Demo: attention() needs an input, so build one with PositionalEncoding first.
# Word-embedding dimension: 8
d_model = 8
# Dropout (zeroing) probability: 0.1
dropout = 0.1
# Maximum sentence length
max_len=12
# All-zero input, so pe_result contains only the positional encoding (after dropout).
x = torch.zeros(1, max_len, d_model)
pe = PositionalEncoding(d_model, dropout, max_len)
# NOTE(review): the freshly built module is in training mode, so its Dropout randomly
# zeroes entries and scales the rest by 1/(1 - 0.1); the printed values differ on every
# run unless a random seed is fixed.
pe_result = pe(x)
print("pe_result:", pe_result)
# Q = K = V: attending a sequence to itself is called self-attention.
query = key = value = pe_result
print("pe_result.shape:",pe_result.shape)
# Expected: pe_result.shape: torch.Size([1, 12, 8])
# Case 1: no mask.
attn, p_attn = attention(query, key, value)
print("no mask\n")
print("attn:", attn)
print("p_attn:", p_attn)
# Reference formula used inside attention():
#   scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Dividing by math.sqrt(d_k) makes this scaled dot-product attention; without it, it is
# plain dot-product attention.
# Case 2: with a mask.
print("mask\n")
# An all-zero mask marks every position as masked, so every score is replaced by the same
# fill value: each row of p_attn becomes uniform (1/max_len), and every row of attn is the
# mean of value over all positions.
mask = torch.zeros(1, max_len, max_len)
attn, p_attn = attention(query, key, value, mask=mask)
print("attn:", attn)
print("p_attn:", p_attn)
pe_result: tensor([[[ 0.0000e+00, 1.1111e+00, 0.0000e+00, 1.1111e+00, 0.0000e+00,
1.1111e+00, 0.0000e+00, 1.1111e+00],
[ 9.3497e-01, 6.0034e-01, 1.1093e-01, 1.1056e+00, 1.1111e-02,
1.1111e+00, 1.1111e-03, 1.1111e+00],
[ 1.0103e+00, -4.6239e-01, 2.2074e-01, 1.0890e+00, 2.2221e-02,
0.0000e+00, 2.2222e-03, 1.1111e+00],
[ 1.5680e-01, -1.1000e+00, 0.0000e+00, 1.0615e+00, 3.3328e-02,
0.0000e+00, 3.3333e-03, 1.1111e+00],
[-8.4089e-01, -7.2627e-01, 4.3269e-01, 1.0234e+00, 4.4433e-02,
1.1102e+00, 4.4444e-03, 1.1111e+00],
[-1.0655e+00, 3.1518e-01, 5.3270e-01, 0.0000e+00, 5.5532e-02,
1.1097e+00, 5.5555e-03, 1.1111e+00],
[-3.1046e-01, 1.0669e+00, 6.2738e-01, 9.1704e-01, 0.0000e+00,
1.1091e+00, 6.6666e-03, 0.0000e+00],
[ 7.2999e-01, 8.3767e-01, 7.1580e-01, 8.4982e-01, 7.7714e-02,
1.1084e+00, 7.7777e-03, 1.1111e+00],
[ 1.0993e+00, -1.6167e-01, 7.9706e-01, 7.7412e-01, 8.8794e-02,
1.1076e+00, 8.8888e-03, 1.1111e+00],
[ 4.5791e-01, -0.0000e+00, 8.7036e-01, 6.9068e-01, 9.9865e-02,
1.1066e+00, 9.9999e-03, 1.1111e+00],
[-6.0447e-01, -9.3230e-01, 9.3497e-01, 6.0034e-01, 1.1093e-01,
1.1056e+00, 1.1111e-02, 1.1111e+00],
[-1.1111e+00, 4.9174e-03, 9.9023e-01, 5.0400e-01, 1.2198e-01,
1.1044e+00, 1.2222e-02, 1.1110e+00]]])
pe_result.shape: torch.Size([1, 12, 8])
scores.shape: torch.Size([1, 12, 12])
no mask
attn: tensor([[[ 1.0590e-01, 2.7361e-01, 4.9333e-01, 8.3999e-01, 5.0599e-02,
1.0079e+00, 5.6491e-03, 1.0138e+00],
[ 2.7554e-01, 2.0916e-01, 4.9203e-01, 8.6593e-01, 5.2177e-02,
9.7066e-01, 5.6513e-03, 1.0398e+00],
[ 2.8765e-01, -3.8825e-02, 4.7812e-01, 8.7535e-01, 5.4246e-02,
8.4157e-01, 5.7015e-03, 1.0659e+00],
[ 9.3666e-02, -1.8286e-01, 4.8727e-01, 8.5124e-01, 5.7070e-02,
8.2547e-01, 5.9523e-03, 1.0712e+00],
[-1.6747e-01, -1.0274e-01, 5.6960e-01, 7.7584e-01, 6.3699e-02,
9.6958e-01, 6.7169e-03, 1.0546e+00],
[-2.2646e-01, 6.8462e-02, 5.8668e-01, 7.2227e-01, 6.3119e-02,
1.0233e+00, 6.8004e-03, 1.0310e+00],
[ 8.8945e-04, 2.7654e-01, 5.3750e-01, 8.0958e-01, 5.2289e-02,
1.0259e+00, 6.1360e-03, 9.6094e-01],
[ 2.2231e-01, 2.2832e-01, 5.2263e-01, 8.4111e-01, 5.4828e-02,
9.9655e-01, 5.9765e-03, 1.0298e+00],
[ 2.6388e-01, 7.2239e-02, 5.3800e-01, 8.4070e-01, 5.8958e-02,
9.5033e-01, 6.2306e-03, 1.0564e+00],
[ 1.2822e-01, 7.4518e-02, 5.5305e-01, 8.1381e-01, 6.0125e-02,
9.7442e-01, 6.4089e-03, 1.0462e+00],
[-1.5757e-01, -1.3194e-01, 5.9562e-01, 7.6069e-01, 6.7079e-02,
9.7264e-01, 7.0187e-03, 1.0607e+00],
[-2.3505e-01, 5.6245e-03, 6.0160e-01, 7.3040e-01, 6.5491e-02,
1.0176e+00, 7.0038e-03, 1.0367e+00]]])
p_attn: tensor([[[0.1488, 0.1215, 0.0514, 0.0396, 0.0698, 0.0703, 0.0875, 0.1205,
0.0790, 0.0814, 0.0544, 0.0757],
[0.1170, 0.1434, 0.0757, 0.0489, 0.0590, 0.0460, 0.0642, 0.1304,
0.1161, 0.0943, 0.0527, 0.0524],
[0.0716, 0.1094, 0.1341, 0.1067, 0.0716, 0.0379, 0.0407, 0.0930,
0.1221, 0.0921, 0.0713, 0.0494],
[0.0597, 0.0765, 0.1155, 0.1397, 0.1127, 0.0506, 0.0359, 0.0627,
0.0918, 0.0806, 0.1056, 0.0688],
[0.0692, 0.0607, 0.0509, 0.0740, 0.1475, 0.0846, 0.0509, 0.0607,
0.0692, 0.0788, 0.1342, 0.1194],
[0.0887, 0.0601, 0.0343, 0.0423, 0.1076, 0.1341, 0.0721, 0.0748,
0.0591, 0.0777, 0.1057, 0.1435],
[0.1232, 0.0938, 0.0411, 0.0335, 0.0722, 0.0804, 0.1351, 0.1103,
0.0722, 0.0814, 0.0633, 0.0935],
[0.1124, 0.1263, 0.0623, 0.0388, 0.0571, 0.0553, 0.0731, 0.1388,
0.1134, 0.1001, 0.0571, 0.0652],
[0.0758, 0.1157, 0.0841, 0.0584, 0.0670, 0.0450, 0.0492, 0.1166,
0.1429, 0.1101, 0.0763, 0.0588],
[0.0822, 0.0989, 0.0668, 0.0540, 0.0803, 0.0622, 0.0584, 0.1084,
0.1158, 0.1046, 0.0879, 0.0804],
[0.0548, 0.0551, 0.0515, 0.0705, 0.1364, 0.0845, 0.0454, 0.0617,
0.0801, 0.0877, 0.1499, 0.1224],
[0.0763, 0.0548, 0.0357, 0.0459, 0.1213, 0.1146, 0.0669, 0.0703,
0.0616, 0.0802, 0.1224, 0.1499]]])
mask
scores.shape: torch.Size([1, 12, 12])
attn: tensor([[[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185]]])
p_attn: tensor([[[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833],
[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
0.0833, 0.0833, 0.0833, 0.0833]]])