DFSMN
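DFSMN (Deep Feed-forward Sequential Memory Network) stacks FSMN-style memory blocks: each block aggregates a window of past and future frames with a per-channel (depthwise) convolution over time and adds the result back through a skip connection, which is what lets very deep stacks train stably. Below is a minimal sketch of one such memory block; the class name and the l_order/r_order window sizes are illustrative assumptions, and the stride/dilation options of the full DFSMN are omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FSMNMemoryBlock(nn.Module):
    """Illustrative DFSMN-style memory block: depthwise 1-D convolution over time plus a skip connection."""

    def __init__(self, d_model, l_order=10, r_order=1):
        super().__init__()
        self.l_order, self.r_order = l_order, r_order
        # Depthwise convolution: one filter per feature channel, no cross-channel mixing.
        self.conv = nn.Conv1d(d_model, d_model, l_order + r_order + 1,
                              groups=d_model, bias=False)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        h = x.transpose(1, 2)                        # (batch, d_model, seq_len)
        h = F.pad(h, (self.l_order, self.r_order))   # look-back / look-ahead context window
        h = self.conv(h).transpose(1, 2)             # back to (batch, seq_len, d_model)
        return x + h                                 # skip connection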
SAN-M
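SAN-M combines multi-head self-attention with a DFSMN-style memory block: the memory branch runs a depthwise convolution over the value projection and its output is added to the attention output, so the layer sees both global context (attention) and local context (memory). The sketch below follows that description under stated assumptions; the class name, window sizes, and the exact point where the two branches are merged are illustrative, not taken from a particular toolkit.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MemoryEquippedAttention(nn.Module):
    """Illustrative SAN-M-style layer: multi-head attention plus a value-side memory branch."""

    def __init__(self, d_model, n_heads=4, l_order=10, r_order=1, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.fc_out = nn.Linear(d_model, d_model, bias=False)
        # Memory branch: depthwise conv over time, applied to the value projection.
        self.memory = nn.Conv1d(d_model, d_model, l_order + r_order + 1,
                                groups=d_model, bias=False)
        self.pad = (l_order, r_order)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model)
        b, t, _ = x.size()
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Memory branch on the (un-split) values.
        mem = self.memory(F.pad(v.transpose(1, 2), self.pad)).transpose(1, 2)
        # Standard multi-head scaled dot-product attention.
        q = q.view(b, t, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.d_k ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = self.dropout(torch.softmax(scores, dim=-1))
        ctx = torch.matmul(attn, v).transpose(1, 2).reshape(b, t, -1)
        # Merge the attention context with the memory branch.
        return self.fc_out(ctx) + mem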
Python implementation
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for inputs of shape (batch, seq_len, d_model)."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, in_features, out_features, n_heads=4, dropout=0.1):
        super(SelfAttention, self).__init__()
        assert out_features % n_heads == 0, "out_features must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = out_features // n_heads
        self.w_qs = nn.Linear(in_features, out_features, bias=False)
        self.w_ks = nn.Linear(in_features, out_features, bias=False)
        self.w_vs = nn.Linear(in_features, out_features, bias=False)
        self.fc_out = nn.Linear(out_features, out_features, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        # Project and split into heads: (batch, n_heads, seq_len, d_k)
        q = self.w_qs(q).view(q.size(0), q.size(1), self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_ks(k).view(k.size(0), k.size(1), self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_vs(v).view(v.size(0), v.size(1), self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention scores: (batch, n_heads, seq_q, seq_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = self.dropout(self.softmax(scores))
        # Merge the heads back into (batch, seq_len, out_features)
        output = torch.matmul(attn, v).transpose(1, 2).contiguous()
        output = output.view(output.size(0), output.size(1), -1)
        output = self.fc_out(output)
        return output, attn
class SANMEncoderLayer(nn.Module):
    """Pre-norm encoder layer: a self-attention sub-layer followed by a feed-forward sub-layer."""

    def __init__(self, size, self_attn, feed_forward, dropout=0.1):
        super(SANMEncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size)
        self.norm2 = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention sub-layer with residual connection
        residual = x
        x = self.norm1(x)
        x, _ = self.self_attn(x, x, x, mask)
        x = residual + self.dropout(x)
        # Feed-forward sub-layer with residual connection
        residual = x
        x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        return x
class SANMEncoder(nn.Module):
    """Stack of encoder layers with an input projection and positional encoding."""

    def __init__(self, input_dim, num_layers, size, num_heads, ff_size, dropout=0.1):
        super(SANMEncoder, self).__init__()
        self.input_proj = nn.Linear(input_dim, size)
        self.embedding = PositionalEncoding(size, dropout)
        self.layers = nn.ModuleList([
            SANMEncoderLayer(
                size,
                SelfAttention(size, size, num_heads, dropout),
                # Position-wise feed-forward network: size -> ff_size -> size
                nn.Sequential(nn.Linear(size, ff_size), nn.ReLU(), nn.Linear(ff_size, size)),
                dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(size)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, input_dim)
        x = self.embedding(self.input_proj(x))
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
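A quick smoke test of the encoder above, with illustrative dimensions (80-dim filterbank features, 4 layers, 256-dim model, 4 attention heads):

if __name__ == '__main__':
    encoder = SANMEncoder(input_dim=80, num_layers=4, size=256,
                          num_heads=4, ff_size=1024, dropout=0.1)
    feats = torch.randn(2, 100, 80)   # (batch, frames, feature_dim)
    out = encoder(feats)              # no mask: attend over all frames
    print(out.shape)                  # torch.Size([2, 100, 256])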