基本理解
nn.Embedding(num_embeddings, embedding_dim)
其中 num_embeddings 是词表的大小,即 len(vocab);embedding_dim 是词向量的维度。
nn.Embedding()产生一个权重矩阵weight,其shape为(num_embeddings, embedding_dim),表示生成num_embeddings个具有embedding_dim大小的嵌入向量;
输入input的形状shape为(batch_size, Seq_len),batch_size表示样本数(NLP句子数),Seq_len表示序列的长度(每个句子单词个数);
nn.Embedding(input)的输出output具有(batch_size,Seq_len,embedding_dim)的形状大小;
示例:
import torch
import torch.nn as nn
torch.manual_seed(0) # 为了复现性
emb = nn.Embedding(4, 3)
print(emb)
print(emb.weight)
输出:
emb = nn.Embedding(4, 3) 即为 生成4
个3
维向量。
注意:emb中生成的值需要用emb.weigh
t进行读取。