文章目录
文章目录
- 00 写在前面
- 01 基于Python版本的滑动窗口代码
- 02 算法效果
00 写在前面
写这个算法原因是:训练了一个时序网络,该网络模型的时序维度为32,而测试数据的时序维度为90。因此需要采用滑动窗口的方法,生成一系列32维度的窗口,用于测试。
该算法中用到了一个python的关键字yield,其用于定义生成器函数。生成器函数与普通函数不同,它可以在执行过程中暂停,并在以后继续从暂停的地方恢复执行。每次调用生成器函数时,都会返回一个生成器对象,而不是直接返回一个值。在你的代码中,yield 用于产生一个滑动窗口。
01 基于Python版本的滑动窗口代码
def window(seq, size=3, stride=2):
"""
返回一个滑动窗口(宽度为'size')在数据序列'seq'上,具有指定的'stride'。
例如,seq -> (s0, s1, ..., s[size-1]), (s[stride], s[stride+1], ..., s[stride+size-1]), ...
"""
it = iter(seq) # 从输入序列创建一个迭代器
result = [] # 初始化一个空列表来存储当前窗口
# 遍历迭代器中的每个元素
for elem in it:
result.append(elem) # 将当前元素添加到窗口中
if len(result) == size: # 如果窗口达到所需大小
yield result # 生成当前窗口
result = result[stride:] # 根据步幅长度滑动窗口
# 如果主循环后结果列表中还有剩余元素
if result:
i = 0 # 初始化一个计数器来填充剩余窗口
while len(result) < size: # 当窗口小于所需大小时
result.append(seq[i % len(seq)]) # 从序列开始添加元素
i += 1 # 增加计数器
yield result # 生成最终窗口
02 算法效果
# 示例使用
seq = [1, 2, 3, 4, 5, 6, 7, 8]
for windowed in window(seq, size=3, stride=2):
print(windowed) # 打印每个滑动窗口
初始状态:result = []
添加元素:result = [1, 2, 3] 生成窗口 [1, 2, 3],重置 result = [3]
添加元素:result = [3, 4, 5] 生成窗口 [3, 4, 5],重置 result = [5]
添加元素:result = [5, 6, 7] 生成窗口 [5, 6, 7],重置 result = [7]
添加元素:result = [7, 8]
填充元素:result = [7, 8, 1],生成最后一个窗口 [7, 8, 1]