Table of Contents
- 1. EmbeddingBag
- 2. PyTorch
1. EmbeddingBag
The bag-of-words module nn.EmbeddingBag is built on top of the plain embedding table. nn.Embedding constructs a vocabulary table: its forward pass takes integer word indices and returns the corresponding word vectors. nn.EmbeddingBag adds an offsets argument on top of that: the offsets cut the flat index sequence into bags, like a knife, with each offset marking where a new bag begins, and the word vectors inside each bag are then reduced to a single vector by taking their mean or sum. A runnable sketch of this follows the illustration below.
- Excel illustration: (figure not shown)
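To make the role of offsets concrete before the full example in section 2, here is a minimal sketch (added for illustration; the tensor values and variable names are chosen here, the API calls are standard PyTorch) that reproduces nn.EmbeddingBag(mode='mean') by hand with nn.Embedding:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

emb = nn.Embedding(num_embeddings=10, embedding_dim=5)
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=5, mode='mean')
bag.weight = emb.weight  # share one lookup table between both modules

indices = torch.tensor([0, 1, 2, 2, 3, 3, 5], dtype=torch.long)
offsets = torch.tensor([0, 2, 5], dtype=torch.long)  # bag starts: [0:2], [2:5], [5:]

# Manual equivalent: look up all vectors, cut at each offset, average each slice
vectors = emb(indices)                      # shape (7, 5)
bounds = offsets.tolist() + [len(indices)]  # [0, 2, 5, 7]
manual = torch.stack([vectors[s:e].mean(dim=0)
                      for s, e in zip(bounds[:-1], bounds[1:])])

print(torch.allclose(bag(indices, offsets), manual))  # True
```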
2. PyTorch
- PyTorch source code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)
torch.manual_seed(323424)

if __name__ == "__main__":
    run_code = 0
    vocab_size = 10
    feature_dim = 5
    # mode='mean': the word vectors inside each bag are averaged
    my_embedding_bag = nn.EmbeddingBag(num_embeddings=vocab_size, embedding_dim=feature_dim, mode='mean')
    # Overwrite the random weights with 0..49 so the output is easy to verify by hand
    my_weight_total = vocab_size * feature_dim
    my_weight = torch.arange(my_weight_total).reshape((vocab_size, feature_dim)).to(torch.float32)
    my_input = torch.tensor([0, 1, 2, 2, 3, 3, 5], dtype=torch.long)
    my_embedding_bag.weight = nn.Parameter(my_weight)
    print(f"my_embedding_bag.weight=\n{my_embedding_bag.weight}")
    # The second argument is the offsets: the bags are my_input[0:2], my_input[2:5], my_input[5:]
    input_embedding = torch.tensor([0, 2, 5], dtype=torch.long)
    output_embedding = my_embedding_bag(my_input, input_embedding)
    print(f"my_input={my_input}")
    print(f"input_embedding={input_embedding}")
    print(f"output_embedding=\n{output_embedding}")
```
- Output:
```
my_embedding_bag.weight=
Parameter containing:
tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24.],
        [25., 26., 27., 28., 29.],
        [30., 31., 32., 33., 34.],
        [35., 36., 37., 38., 39.],
        [40., 41., 42., 43., 44.],
        [45., 46., 47., 48., 49.]], requires_grad=True)
my_input=tensor([0, 1, 2, 2, 3, 3, 5])
input_embedding=tensor([0, 2, 5])
output_embedding=
tensor([[ 2.500,  3.500,  4.500,  5.500,  6.500],
        [11.667, 12.667, 13.667, 14.667, 15.667],
        [20.000, 21.000, 22.000, 23.000, 24.000]],
       grad_fn=<EmbeddingBagBackward0>)
```
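Each row of output_embedding can be checked against the weight table by hand: bag 0 averages rows 0 and 1 ((0 + 5) / 2 = 2.5 in the first column), bag 1 averages rows 2, 2 and 3 ((10 + 10 + 15) / 3 ≈ 11.667), and bag 2 averages rows 3 and 5 ((15 + 25) / 2 = 20). The minimal sketch below (not part of the original post; the variable names are chosen here for illustration) redoes that check with plain indexing and torch.mean:

```python
import torch

# Same weight table and bags as above
weight = torch.arange(50, dtype=torch.float32).reshape(10, 5)
bags = [[0, 1], [2, 2, 3], [3, 5]]  # my_input cut at offsets [0, 2, 5]

for i, bag in enumerate(bags):
    # e.g. bag 1: rows 2, 2, 3 -> (10 + 10 + 15) / 3 = 11.667 in the first column
    print(f"bag {i}: {weight[torch.tensor(bag)].mean(dim=0)}")
```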