1、输入数据中,源数据和目标数据的定义
def get_batch(source,i):
'''
用于获取每个批数据合理大小的源数据和目标数据
参数source 是通过batchfy 得到的划分batch个 ,的所有数据,并且转置列表示
i第几个batch
'''
bptt = 15 #超参数,一次输入多少个batch 数据,现在数据矩阵,一行表示一个batch, 一共有n个行,
# len(source) - 1 - i 从大往小变化,知道小到bptt,所以seq_len,大部分时间都是bptt 个=15个,最后几个训练才越来越少
seq_len = min(bptt, len(source) -1-i) #一共是列的元素长度,30个, 行是10个,一共三个batch ,
# 这是转置过的,现在,就变成30个batch,每个batch 长度是3
# 行数错一位,目标数据是原数据向下一位,
data = source[i:i+seq_len]
# 这里最后会越界,使用view(-1) 保证形状正常
target = source[i+1:i+1+seq_len]
return data,target #
文本数据,是每个单词对应的索引,需要对数据进行切分成整块的batch, (n行,batch列), 变成竖着的,
(batch行,n列)
然后,横着一个一个 切分成一个个batch数据,下移一个索引获取目标数据,
(n行,batch列)
【
[A,B,C,D,E,F]
[G,H,I,J,K,L]
[M,N,O,P,Q,R],
……
】
(batch行,n列)
横着看,每一位 AGMS 对应 BHNT, AB, GH, MN, ST, 是相邻的两个字