pytorch官方示例Generating Names with a Character-Level RNN的部分理解
- 模型结构
- 功能
- 关键技术
- 模型输入
- 模型输出
- 预测实现
模型结构
功能
输入一个类别名和一个英文字符,就可以自动生成这个类别,且以英文字符开始的姓名
关键技术
- 将字符进行one-hot编码
- 名字最大长度20,就是使用模型预测20次,下一个字符根据上一个字符循环预测,最后将字符串连接在一起输出一个名字
- 如果是分类任务,输入的名字是字符串表示的,因此在循环这个名字长度,每个字符输入模型中,得到的hidden作为下一次字符预测模型的参数,如官方的分类示例:https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html。 因此最后一个名字字符的预测output将作为softmax的参数进行预测
模型输入
1x128是hidden层,1x18是某一类别的one-hot编码,1x59是某一个英文字符的ont-hot编码,如下面就可以实现某一个字符的one-hot
tensor = torch.zeros(1, n_categories)
tensor[0][li] = 1
模型输出
输出是1x59,就是预测的一个字符。
预测实现
- 名字是由多个字符串组合的,根据给的英文字符预测下一个字符,再根据这字符预测下一个字符,一直反复到EOS时停止
for i in range(max_length):
output, hidden = rnn(category_tensor, input[0], hidden)
topv, topi = output.topk(1)
topi = topi[0][0]
if topi == n_letters - 1:
break
else:
letter = all_letters[topi]
output_name += letter
input = inputTensor(letter)