Transformer 代码补充

本文是对Transformer - Attention is all you need 论文阅读-CSDN博客以及【李宏毅机器学习】Transformer 内容补充-CSDN博客的补充,是对相关代码的理解。

先说个题外话,在之前李宏毅老师的课程中提到multi-head attention是把得到的qkv分别乘上不同的矩阵,得到更多的qkv。

实际上,这里采用的方法是直接截取,比如这里有两个头,那么q^i就被分成两部分q^{i,1}和q^{i,2}。在BERT Intro-CSDN博客中有解释,也推荐手推transformer_哔哩哔哩_bilibili

self-attention

本节内容是self-attention这个模块的实现,会先从某一个句子开始,先不在乎怎么组装在一起批量的处理,只是单个拆开看看每一个部件是怎么work的。

现在需要解决的是:

  1. 输入怎么embedding? 
  2. 位置信息怎么保留?
  3. 三个矩阵怎么初始化?

单个句子的attention

输入embedding

sentence = 'Life is short, eat dessert first'

dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print("dictionary: {}".format(dc))

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print("sentence, but words have been replaced by index in dictionary: \n{}".format(sentence_int))

torch.manual_seed(123)
# len(sentence.replace(',', '').split()) == 6
# embedded length == 16
embed = torch.nn.Embedding(6, 16)
embedded_sentence = embed(sentence_int).detach()

print("sentence, but word embedded: \n{}".format(embedded_sentence))
dictionary: {'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
sentence, but words have been replaced by index in dictionary: 
tensor([0, 4, 5, 2, 1, 3])
sentence, but word embedded: 
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])

Embedding — PyTorch 2.1 documentation

位置embedding

我发现这里似乎没有一个固定的名字,有叫position embedding的,有叫position encoding的,还有positional embedding和positional encoding,排列组合orz

### position embedding
def sinusoid_positional_encoding(length, dimensions):
    # odd position
    # cos(position/100000^{2i/d_model})
    # even position
    # sin(position/100000^{2i/d_model})
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2*(i//2)/dimensions) for i in range(dimensions)]
    
    PE = np.array([get_position_angle_vec(i) for i in range(length)])
    PE[:, 0::2] = np.sin(PE[:, 0::2])
    PE[:, 1::2] = np.sin(PE[:, 1::2])
    return PE
embedded_position = torch.tensor(sinusoid_positional_encoding(6, 16))
print("position embedding: \n{}".format(embedded_position))
position embedding: 
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 8.4147e-01,  8.4147e-01,  3.1098e-01,  3.1098e-01,  9.9833e-02,
          9.9833e-02,  3.1618e-02,  3.1618e-02,  9.9998e-03,  9.9998e-03,
          3.1623e-03,  3.1623e-03,  1.0000e-03,  1.0000e-03,  3.1623e-04,
          3.1623e-04],
        [ 9.0930e-01,  9.0930e-01,  5.9113e-01,  5.9113e-01,  1.9867e-01,
          1.9867e-01,  6.3203e-02,  6.3203e-02,  1.9999e-02,  1.9999e-02,
          6.3245e-03,  6.3245e-03,  2.0000e-03,  2.0000e-03,  6.3246e-04,
          6.3246e-04],
        [ 1.4112e-01,  1.4112e-01,  8.1265e-01,  8.1265e-01,  2.9552e-01,
          2.9552e-01,  9.4726e-02,  9.4726e-02,  2.9996e-02,  2.9996e-02,
          9.4867e-03,  9.4867e-03,  3.0000e-03,  3.0000e-03,  9.4868e-04,
          9.4868e-04],
        [-7.5680e-01, -7.5680e-01,  9.5358e-01,  9.5358e-01,  3.8942e-01,
          3.8942e-01,  1.2615e-01,  1.2615e-01,  3.9989e-02,  3.9989e-02,
          1.2649e-02,  1.2649e-02,  4.0000e-03,  4.0000e-03,  1.2649e-03,
          1.2649e-03],
        [-9.5892e-01, -9.5892e-01,  9.9995e-01,  9.9995e-01,  4.7943e-01,
          4.7943e-01,  1.5746e-01,  1.5746e-01,  4.9979e-02,  4.9979e-02,
          1.5811e-02,  1.5811e-02,  5.0000e-03,  5.0000e-03,  1.5811e-03,
          1.5811e-03]], dtype=torch.float64)

在手推transformer_哔哩哔哩_bilibili中提到这一方法与傅里叶变换相关(这个细节是我在其他地方没有看到的,记录一下)

初始化权重矩阵

new_embedding =(embedded_position+embedded_sentence).to(torch.float32)
print('add embedding together:\n{}'.format(new_embedding))
torch.manual_seed(123)
​
d = new_embedding.shape[1]
print('embedding dimension:\n{}'.format(d))
​
d_q, d_k, d_v = 24, 24, 28
​
# torch.rand 均匀分布 torch.nn.Parameter 普通的tensor不可训练,转换成可以训练的类型
W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))
print('size of query matrix: {}'.format(W_query.shape))
print('size of key matrix: {}'.format(W_key.shape))
print('size of value matrix: {}'.format(W_value.shape))
add embedding together:
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 1.3561,  1.8352,  0.0523, -0.7716,  0.0555,  1.7234, -2.2913,  1.1194,
          0.6816,  0.7033, -0.9456, -0.0733, -0.1516,  0.1177,  0.4406, -1.4462],
        [ 1.1646,  0.3597,  1.5954,  1.4184, -0.1961,  0.6879, -0.1536, -1.6840,
         -1.5825, -1.0564,  0.9095, -0.7155, -0.5931, -0.7092,  0.6236, -1.3722],
        [-1.1838,  0.3195, -1.3211,  1.8650, -0.0930, -0.6388, -0.4044, -0.9919,
          0.9105,  1.5842,  0.6361, -0.1660,  0.1013, -0.0905,  0.2672, -0.5841],
        [-0.8338, -1.7773,  0.7846,  1.8713,  1.9704,  1.6905,  1.4015, -0.0748,
          0.5365, -1.5323,  0.9792, -1.1355, -1.1549,  0.3295, -0.6302, -2.8387],
        [-0.0821,  0.6632, -0.4780,  2.1331, -0.7409,  1.7933,  1.2108,  0.2963,
          2.2973, -0.7537, -0.2650,  0.7855, -0.6546, -0.7929,  0.1854,  0.2309]],
       dtype=torch.float64)
embedding dimension:
16
size of query matrix: torch.Size([24, 16])
size of key matrix: torch.Size([24, 16])
size of value matrix: torch.Size([28, 16])

Parameter — PyTorch 2.1 documentation

OK,我们在这里先断一下,整理一下:

此时sequence:new_embedding(来源:word embedding+position embedding)

word embedding:6×16(有6个token,每个token用16维向量表示)

position embedding:6×16(和word embedding大小相同,因为要相加)

q:24×16

k:24×16

v:28×16

这样后面再计算query的时候就是每个token(1×16)×q(24×16),反正两个得转置一个

计算qkv

x_1 = embedded_sentence[0]
query_1 = W_query.matmul(x_1)
key_1 = W_key.matmul(x_1)
value_1 = W_value.matmul(x_1)

x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

torch.matmul — PyTorch 2.1 documentation

querys = W_key.matmul(new_embedding.T).T
keys = W_key.matmul(new_embedding.T).T
values = W_value.matmul(new_embedding.T).T

print("querys.shape:", querys.shape)
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
querys.shape: torch.Size([6, 24])
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])

计算attention score

alpha_24 = query_2.dot(keys[4])
print(alpha_24)

比如这里,就是第2个query对第5个key的attention

import torch.nn.functional as F

attention_score = F.softmax(keys.matmul(querys.T) / d_k**0.5, dim=0)
print(attention_score)
tensor([[3.8184e-01, 3.7217e-08, 2.8697e-08, 2.3739e-03, 8.8205e-04, 2.0233e-18],
        [5.1460e-03, 3.1125e-02, 3.3185e-09, 1.9323e-03, 1.3870e-07, 4.5397e-13],
        [1.8988e-04, 1.5880e-10, 9.9998e-01, 6.8005e-04, 8.3661e-03, 2.5704e-26],
        [2.1968e-04, 1.2932e-09, 9.5111e-09, 7.5759e-01, 6.8821e-05, 7.0347e-20],
        [1.5536e-02, 1.7667e-11, 2.2270e-05, 1.3099e-02, 9.9068e-01, 3.1426e-23],
        [5.9707e-01, 9.6887e-01, 1.1463e-12, 2.2433e-01, 5.2653e-07, 1.0000e+00]],
       grad_fn=<SoftmaxBackward0>)

获得context value

context_vector_2 = attention_score[2].matmul(values)
print(context_vector_2)
tensor([-2.8135, -0.2665, -0.1881,  0.4058,  0.8079, -3.1120,  0.5449, -1.2232,
        -0.1618,  0.3803,  0.6926, -0.4669,  0.2446, -0.3647, -0.0034, -2.2524,
        -2.7228, -1.5109, -0.7725, -1.0958, -2.1254,  0.3064,  0.5129, -0.1340,
         0.7020, -2.2086, -1.9595,  0.4520], grad_fn=<SqueezeBackward4>)
context_vector = attention_score.matmul(values)
print(context_vector)
tensor([[ 2.8488e-01,  6.4077e-01,  1.0665e+00,  5.5947e-01, -3.2868e-01,
          4.2391e-01, -3.2123e-01,  1.0594e-01,  6.5982e-01,  6.1927e-01,
          8.2067e-01,  4.3722e-01,  6.4925e-01,  5.9935e-01,  6.7425e-01,
          3.6706e-01,  5.0318e-01,  9.9682e-02,  1.1377e-01,  1.2804e-01,
          9.1880e-01,  7.6178e-01, -4.2619e-01,  2.5550e-01, -8.1348e-02,
          3.1145e-01,  1.9705e-01,  3.8195e-01],
        [ 3.6250e-02,  3.7593e-02,  8.9476e-02,  9.9750e-02,  9.1430e-02,
          6.2556e-02,  5.8136e-02,  5.5746e-02,  3.5098e-02,  4.1406e-02,
          4.1621e-02,  1.9771e-02,  4.0799e-02, -4.7170e-03,  4.1176e-02,
          4.3792e-02,  6.2029e-02,  5.2132e-02,  7.6929e-03,  5.4507e-02,
          1.4537e-02,  6.9540e-02,  4.1809e-02,  5.8921e-02,  1.2542e-02,
          1.4625e-01,  3.0627e-02,  1.0624e-01],
        [-2.8135e+00, -2.6652e-01, -1.8809e-01,  4.0583e-01,  8.0793e-01,
         -3.1120e+00,  5.4491e-01, -1.2232e+00, -1.6184e-01,  3.8030e-01,
          6.9257e-01, -4.6693e-01,  2.4462e-01, -3.6468e-01, -3.3741e-03,
         -2.2524e+00, -2.7228e+00, -1.5109e+00, -7.7255e-01, -1.0958e+00,
         -2.1254e+00,  3.0638e-01,  5.1293e-01, -1.3400e-01,  7.0203e-01,
         -2.2086e+00, -1.9595e+00,  4.5198e-01],
        [ 1.3995e+00, -5.1583e-02, -7.6128e-01,  6.2276e-01,  1.4197e+00,
         -1.1195e+00,  2.6502e-01,  9.7265e-02, -1.3257e+00,  5.2765e-01,
         -9.0406e-01,  1.0977e+00,  1.0775e+00, -1.1202e+00, -5.3005e-01,
          1.1657e+00,  5.2906e-01, -3.4296e-01, -1.0341e+00, -9.9314e-02,
          2.4160e-01,  1.0506e+00, -2.5196e-01, -1.2585e+00,  7.7441e-01,
         -3.8052e-02,  1.4004e+00,  4.0364e-01],
        [-1.9422e+00, -1.1669e-01,  2.4155e+00, -6.0575e-01,  1.1378e-01,
         -8.1691e-01,  2.8678e-01, -2.6922e+00,  1.9804e+00,  2.7446e+00,
          1.9828e-01, -1.5773e+00, -5.2589e-01,  2.2252e+00, -2.9130e-01,
         -4.2694e+00,  2.4834e+00, -3.3346e+00, -2.5167e-01, -2.8141e+00,
          1.3780e+00, -1.5563e-01, -1.4588e+00,  5.3617e-01, -5.3745e-01,
         -7.6528e-01,  1.2408e+00,  3.5827e+00],
        [ 5.3134e+00,  3.5967e+00,  7.1373e+00,  5.9613e+00,  6.1520e+00,
          5.0065e+00,  4.2107e+00,  5.2589e+00,  9.2143e-01,  6.5614e+00,
          2.7412e+00,  4.6712e+00,  4.9725e+00,  2.2118e+00,  5.2451e+00,
          4.4219e+00,  4.5800e+00,  2.9179e+00,  2.2116e+00,  5.3678e+00,
          5.7133e+00,  7.1016e+00,  3.7317e+00,  5.1325e+00,  4.1306e+00,
          9.4941e+00,  5.6733e+00,  9.7489e+00]], grad_fn=<MmBackward0>)

参考链接

Positional Encoding: Everything You Need to Know - inovex GmbH
Build your own Transformer from scratch using Pytorch | by Arjun Sarkar | Towards Data Science
Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch (sebastianraschka.com)
2021-03-18-Transformers - Multihead Self Attention Explanation & Implementation in Pytorch.ipynb - Colaboratory (google.com)
通俗易懂的理解傅里叶变换(一)[收藏] - 知乎 (zhihu.com)
Linear Relationships in the Transformer’s Positional Encoding - Timo Denk's Blog
Transformer 中的 positional embedding - 知乎 (zhihu.com)
transformer中使用的position embedding为什么是加法? - 知乎 (zhihu.com)

multi-head self-attention

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # 这里d_k是每个key和query的size,同时在后面归一化也需要使用
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # 计算attention score,Q和K反正得转置一个,看怎么定义
        # 比如现在的attn_scores的第(i,j)位置表示:
        # 第i个query对第k个key的attention(相关性高低)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

这样就大概看懂了orz。

view

维度变化

randmat = torch.rand((3, 2, 5))
print("view(2, 3, 5): \n{}".format(randmat.view(2,3,5)))
print("view(2, 5, 3): \n{}".format(randmat.view(2,3,5)))
view(2, 3, 5): 
tensor([[[0.8058, 0.3869, 0.7523, 0.1501, 0.1501],
         [0.3409, 0.5355, 0.3474, 0.8371, 0.6785],
         [0.6564, 0.8204, 0.0539, 0.7422, 0.2216]],

        [[0.9450, 0.7839, 0.7118, 0.8868, 0.4249],
         [0.1633, 0.5220, 0.7583, 0.7841, 0.0838],
         [0.4304, 0.5082, 0.3141, 0.1689, 0.0869]]])
view(2, 5, 3): 
tensor([[[0.8058, 0.3869, 0.7523, 0.1501, 0.1501],
         [0.3409, 0.5355, 0.3474, 0.8371, 0.6785],
         [0.6564, 0.8204, 0.0539, 0.7422, 0.2216]],

        [[0.9450, 0.7839, 0.7118, 0.8868, 0.4249],
         [0.1633, 0.5220, 0.7583, 0.7841, 0.0838],
         [0.4304, 0.5082, 0.3141, 0.1689, 0.0869]]])
view(5, 2, 3): 
tensor([[[0.8058, 0.3869, 0.7523],
         [0.1501, 0.1501, 0.3409],
         [0.5355, 0.3474, 0.8371],
         [0.6785, 0.6564, 0.8204],
         [0.0539, 0.7422, 0.2216]],

        [[0.9450, 0.7839, 0.7118],
         [0.8868, 0.4249, 0.1633],
         [0.5220, 0.7583, 0.7841],
         [0.0838, 0.4304, 0.5082],
         [0.3141, 0.1689, 0.0869]]])

transpose

randmat = torch.rand((3, 2, 5))
print(randmat)
print("tanspose(-2,-1): \n{}".format(randmat.transpose(-2,-1)))
print("transpose(1,2): \n{}".format(randmat.transpose(1,2)))
tensor([[[0.3440, 0.9779, 0.9154, 0.6843, 0.9358],
         [0.5081, 0.7446, 0.0274, 0.6329, 0.6427]],

        [[0.6770, 0.6826, 0.2888, 0.8483, 0.9896],
         [0.1457, 0.3154, 0.6381, 0.6555, 0.2204]],

        [[0.4549, 0.0385, 0.1135, 0.8426, 0.8534],
         [0.7915, 0.4030, 0.8209, 0.3390, 0.6290]]])
tanspose(-2,-1): 
tensor([[[0.3440, 0.5081],
         [0.9779, 0.7446],
         [0.9154, 0.0274],
         [0.6843, 0.6329],
         [0.9358, 0.6427]],

        [[0.6770, 0.1457],
         [0.6826, 0.3154],
         [0.2888, 0.6381],
         [0.8483, 0.6555],
         [0.9896, 0.2204]],

        [[0.4549, 0.7915],
         [0.0385, 0.4030],
         [0.1135, 0.8209],
         [0.8426, 0.3390],
         [0.8534, 0.6290]]])
transpose(1,2): 
tensor([[[0.3440, 0.5081],
         [0.9779, 0.7446],
         [0.9154, 0.0274],
         [0.6843, 0.6329],
         [0.9358, 0.6427]],

        [[0.6770, 0.1457],
         [0.6826, 0.3154],
         [0.2888, 0.6381],
         [0.8483, 0.6555],
         [0.9896, 0.2204]],

        [[0.4549, 0.7915],
         [0.0385, 0.4030],
         [0.1135, 0.8209],
         [0.8426, 0.3390],
         [0.8534, 0.6290]]])

参考链接

Build your own Transformer from scratch using Pytorch | by Arjun Sarkar | Towards Data Science
Python numpy.transpose 详解 - 我的明天不是梦 - 博客园 (cnblogs.com)

附录

"""Self-attention module

1. Read the code and explain the following:
    - The nature of the dataset
    - The data flow
    - The shapes of the tensors
    - Why can the attention module be used for this dataset?
2. Create a training loop and evaluate the model according to the instructions
"""
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.auto import tqdm


class SampleDataset(Dataset):
    def __init__(
        self,
        size: int = 1024,
        emb_dim: int = 32,
        sequence_length: int = 8,
        n_classes: int = 3,
    ):
        self.embeddings = torch.randn(size, emb_dim)
        self.sequence_length = sequence_length
        self.n_classes = n_classes

    def __len__(self):
        return len(self.embeddings) - self.sequence_length + 1

    def __getitem__(self, idx):
        indices = np.random.choice(
            np.arange(0, len(self.embeddings)), self.sequence_length
        )
        # np.random.shuffle(indices)
        return (
            self.embeddings[indices],  # sequence_length x emb_dim
            torch.tensor(np.max(indices) % self.n_classes),
        )


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    # The length of the key and value sequences need to be the same
    d_k = query.size(-1)
    # N *
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super().__init__()
        assert d_model % heads == 0
        # We assume d_v always equals d_k
        self.d = d_model // heads # d_model: 32 heads: 4
        self.h = heads # h: 4
        self.linears = clones(nn.Linear(d_model, d_model), 4) # 4 identical layers (input: 32, output: 32)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ] # 4 x 

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d)
        return self.linears[-1](x)


class SequenceClassifier(nn.Module):
    def __init__(self, heads: int = 4, d_model: int = 32, n_classes: int = 3):
        super().__init__()
        self.attention = MultiHeadAttention(heads, d_model)
        self.linear = nn.Linear(d_model, n_classes)

    def forward(self, x):
        # x: N x sequence_length x emb_dim
        x = self.attention(x, x, x)
        x = self.linear(x[:, 0])
        return x


def main(
    n_epochs: int = 1000,
    size: int = 256,
    emb_dim: int = 128,
    sequence_length: int = 8,
    n_classes: int = 3,
):
    dataset = SampleDataset(
        size=size, emb_dim=emb_dim, sequence_length=sequence_length, n_classes=n_classes
    )
    # TODO: create a training loop

    # TODO: Evaluate with the same dataset

    # TODO: Evaluate with a different sequence length (12)


if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/368606.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

030-安全开发-JS应用NodeJS指南原型链污染Express框架功能实现审计

030-安全开发-JS应用&NodeJS指南&原型链污染&Express框架&功能实现&审计 #知识点&#xff1a; 1、NodeJS-开发环境&功能实现 2、NodeJS-安全漏洞&案例分析 3、NodeJS-开发指南&特有漏洞 演示案例&#xff1a; ➢环境搭建-NodeJS-解析安装&…

常用换源总结

1.Ubuntu16.04更换国内源 在Ubuntu系统上使用apt-get install进行软件安装或更新的时候&#xff0c;由于使用的是国外源&#xff0c;导致下载速度很慢或者连接超时&#xff0c;需要更换下载源。 1.将系统原始的源文件进行备份 sudo cp /etc/apt/sources.list /etc/apt/source…

c语言--二进制和其他进制之间的转换

目录 一、前言二、二进制、十进制、十六进制、八进制的组成2.1二进制的组成2.2十进制的组成2.3八进制的组成2.4十六进制的组成 三、二进制转换为十进制3.1 二进制转换为十进制3.2十进制转换为二进制 四、二进制转八进制和十六进制4.1二进制转八进制4.2二进制转换为十六进制 五、…

【安装指南】maven下载、安装与配置详细教程

&#x1f33c;一、概述 maven功能与python的pip类似。 Apache Maven是一个用于软件项目管理和构建的强大工具。它是基于项目对象模型的&#xff0c;用于描述项目的构建配置和依赖关系。以下是一些关键的 Maven 特性和概念&#xff1a; POM&#xff08;Project Object Model&…

Mybatis基础教程及使用细节

本篇主要对Mybatis基础使用进行总结&#xff0c;包括Mybatis的基础操作&#xff0c;使用注解进行增删改查的练习&#xff1b;详细介绍xml映射文件配置过程并且使用xml映射文件进行动态sql语句进行条件查询&#xff1b;为了简化java开发提高效率&#xff0c;介绍一下依赖&#x…

安科瑞电气火灾监控系统在海尔(合肥)创新产业园一期厂房改扩建项目的设计与应用

摘要&#xff1a;介绍海尔&#xff08;合肥&#xff09;创新产业园一期厂房改扩建项目采用安科瑞剩余电流式电气火灾探测器&#xff0c;就地组网方式&#xff0c;通过现场总线通讯远传至后台&#xff0c;从而实现剩余电流式电气火灾监控系统的搭建&#xff0c;完成对现场配电回…

万户 ezOFFICE wpsservlet SQL注入漏洞

免责声明&#xff1a;文章来源互联网收集整理&#xff0c;请勿利用文章内的相关技术从事非法测试&#xff0c;由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;所产生的一切不良后果与文章作者无关。该…

微信小程序应用商店源码系统 带完整的安装代码包以及搭建教程

随着微信小程序的普及&#xff0c;越来越多的企业和开发者开始关注小程序的开发与运营。为了满足市场需求&#xff0c;小编给大家分享一款微信小程序应用商店源码系统。该系统集成了完整的安装代码包&#xff0c;方便用户快速搭建自己的小程序应用商店。 以下是部分代码示例&a…

typedef

typedef typedef &#xff0c;type表示类型&#xff0c; def就是define&#xff0c; 定义的意思。所以&#xff0c;根据名字我们就可以知道typedef就是类型定义的意思。可以对一个类型进行重新定义。 一般对一个类型重新定义都是这种形式: typedef 类型 重定义 如&#xff…

web学习笔记(十九)

目录 1.作用域 1.1作用域的概念 1.2作用域的分类 1.2.1全局作用域 1.2.2局部作用域 1.2.3块级作用域&#xff08;ES6新增 &#xff09; 2.变量作用域 2.1全局变量 2.2局部变量 3.作用域链 3.1作用域链的定义 4.垃圾回收机制 4.1定义 4.2如何避免内存泄漏 5.预…

1895_分离进程的能力

1895_分离进程的能力 全部学习汇总&#xff1a; g_unix: UNIX系统学习笔记 (gitee.com) 有些理念可能在控制类的嵌入式系统中不好实施&#xff0c;尤其是没有unix这样的系统搭载的情况下。如果是考虑在RTOS的基础上看是否有一些理念可以做尝试&#xff0c;我觉得还是可以有一定…

Android: 深入理解 ‘companion object {}‘

Android: 深入理解 ‘companion object {}’ Kotlin是一种现代的、静态类型的编程语言&#xff0c;它在设计时充分考虑了开发者的生产力和代码的可读性。其中一个独特的特性就是companion object。在本篇博客中&#xff0c;我们将深入探讨这个特性&#xff0c;理解它的工作原理…

一款轻量级、高性能、功能强大的内网穿透代理服务器

简介 nps是一款轻量级、高性能、功能强大的内网穿透代理服务器。目前支持tcp、udp流量转发&#xff0c;可支持任何tcp、udp上层协议&#xff08;访问内网网站、本地支付接口调试、ssh访问、远程桌面&#xff0c;内网dns解析等等……&#xff09;&#xff0c;此外还支持内网htt…

Node需要了解的知识

Node能执行javascript的原因。 浏览器之所以能执行Javascript代码&#xff0c;因为内部含有v8引擎。Node.js基于v8引擎封装&#xff0c;因此可以执行javascript代码。Node.js环境没有DOM和BOM。DOM能访问HTML所有的节点对象&#xff0c;BOM是浏览器对象。但是node中提供了cons…

网络攻防模拟与城市安全演练 | 图扑数字孪生

在数字化浪潮的推动下&#xff0c;网络攻防模拟和城市安全演练成为维护社会稳定的不可或缺的环节。基于数字孪生技术我们能够在虚拟环境中进行高度真实的网络攻防模拟&#xff0c;为安全专业人员提供实战经验&#xff0c;从而提升应对网络威胁的能力。同时&#xff0c;在城市安…

【Android】二级分类双列表联动Demo

先上图&#xff1a; Demo解释 demo使用的是双列表展示&#xff08;准确的说是三个&#xff0c;二级分类那里嵌套了一个&#xff09;&#xff0c;点击左边的条目&#xff0c;右边的列表会跳转相应的条目&#xff0c;滑动右边的列表&#xff0c;左边的列表也会相应的滑动。 代…

2024 年 10 款最佳免费无限的数据恢复软件工具

十大无限的数据恢复软件工具 数据丢失可能是一场噩梦&#xff0c;无论是由于意外删除、系统崩溃还是硬件故障。值得庆幸的是&#xff0c;有多种数据恢复软件工具可以帮助您检索珍贵的文件和文档。在本文中&#xff0c;我们将探讨可以拯救世界的十大最佳免费无限数据恢复软件工…

第二百零五回

文章目录 1. 概念介绍2. 实现方法2.1 文字信息2.2 红色边框 3. 示例代码4. 内容总结 我们在上一章回中介绍了"如何实现密码输入框"相关的内容&#xff0c;本章回中将介绍如何在在输入框中提示错误.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念介绍 我们…

RT-Thread线程管理(使用篇)

layout: post title: “RT-Thread线程管理” date: 2024-1-26 15:39:08 0800 tags: RT-Thread 线程管理(使用篇) 之后会做源码分析 线程是任务的载体&#xff0c;是RTT中最基本的调度单位。 线程执行时的运行环境称为上下文&#xff0c;具体来说就是各个变量和数据&#xff0c…

stable diffusion学习笔记——高清修复

ai画图中通常存在以下痛点&#xff1a; 受限于本地设备的性能&#xff08;主要是显卡显存&#xff09;&#xff0c;无法跑出分辨率较高的图片。生图的时候分辨率一调大就爆显存。即便显存足够。目前主流的模型大多基于SD1.0和SD1.5&#xff0c;这些模型在训练的时候通常使用小…