机器学习第四十四周周报 SAMformer

文章目录

  • week44 SAMformer
  • 摘要
  • Abstract
    • 1. 题目
    • 2. Abstract
    • 3. 网络架构
      • 3.1 问题提出
      • 3.2 微型示例
      • 3.3 SAMformer
    • 4. 文献解读
      • 4.1 Introduction
      • 4.2 创新点
      • 4.3 实验过程
    • 5. 结论
    • 6.代码复现
      • 小结
      • 参考文献

week44 SAMformer

摘要

本周阅读了题为SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention的论文。研究发现,Transformer在小规模线性预测问题中表达能力虽强,但难以收敛至理想水平,其注意力机制导致泛化能力低。为此,该文提出一种轻量级Transformer模型,结合锐度感知优化,成功避免不良局部最小值。实验证明,该模型在多元时间序列数据集上表现优越,超越当前先进方法,且参数显著减少。

Abstract

This week’s weekly newspaper decodes the paper entitled SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention. Research has found that while the Transformer model possesses strong expressive power in small-scale linear prediction problems, it fails to converge to the ideal level due to its attention mechanism, which leads to low generalization ability. Therefore, this article proposes a lightweight Transformer model that, combined with sharpness-aware optimization, successfully avoids undesirable local minima. Experiments have proven that this model performs superbly on multivariate time series datasets, surpassing current advanced methods while significantly reducing parameters.

1. 题目

标题:SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention

作者:Romain Ilbert, Ambroise Odonnat, Vasilii Feofanov, Aladin Virmaux, Giuseppe Paolo, Themis Palpanas, Ievgen Redko

发布:Accepted as an Oral at ICML 2024, Vienna

链接:https://arxiv.org/abs/2402.10198

2. Abstract

首先研究一个小规模线性预测问题,结果表明Transformer尽管具有很高的表达能力,但无法收敛到期望中水平。进一步确定Transformer的注意力是造成这种低泛化能力的原因。基于这一见解,提出了一种浅层轻量级Transformer模型,当通过锐度感知优化进行优化时,该模型成功地避免了不良的局部最小值。凭经验证明,这一结果可以扩展到所有常用的现实世界多元时间序列数据集。特别是,SAMformer 超越了当前最先进的方法,与最大的基础模型 MOIRAI 相当,但参数却少得多。

First, this article investigated a small-scale linear prediction problem and found that although the Transformer model possesses high expressive power, it fails to converge to the desired level. Further analysis revealed that the attention mechanism in the Transformer is the cause of this low generalization ability. Based on this insight, we propose a shallow and lightweight Transformer model that successfully avoids undesirable local minima when optimized using sharpness-aware minimization. Empirical evidence demonstrates that this result can be extended to all commonly used real-world multivariate time series datasets. Specifically, SAMformer outperforms current state-of-the-art methods, comparable to the largest baseline model MOIRAI, but with significantly fewer parameters.

3. 网络架构

3.1 问题提出

考虑多元长期预测框架:给定长度为 L(回溯窗口)的 D 维时间序列,排列在矩阵 X ∈ R D × L X ∈ R^{D×L} XRD×L 中以促进通道关注,目标是预测其下一个 H 值(预测范围),用 Y ∈ R D × H Y ∈ R^{D×H} YRD×H 表示。假设我们可以访问由N个观测值 ( X , Y ) = ( { X ( i ) } i = 0 N , { Y ( i ) } i = 0 N ) (X, Y) = (\{X^{(i)}\}^N_{i=0}, \{Y^{(i)}\}^N_{i=0}) (X,Y)=({X(i)}i=0N,{Y(i)}i=0N) 组成的训练集,并表示为 X d ( i ) ∈ R 1 × L X^{(i)}_d ∈ R^{1×L} Xd(i)R1×L(分别为 Y d ( i ) ∈ R 1 × H Y^{(i)}_d ∈ R^{1×H} Yd(i)R1×H)第 i 个输入(分别为目标)时间序列的第 d 个特征。目标是训练一个由 ω 参数化的预测器 f ω : R D × L → R D × H f_ω:R^{D×L} \rightarrow R^{D×H} fω:RD×LRD×H ,以最小化训练集上的均方误差 (MSE):
L t r a i n ( ω ) = 1 N D ∑ i = 0 N ∣ ∣ Y ( i ) − f ω ( X ( i ) ) ∣ ∣ F 2 (1) L_{train}(ω)=\frac1{ND}\sum^N_{i=0}||Y^{(i)}-f_ω(X^{(i)})||^2_F \tag{1} Ltrain(ω)=ND1i=0N∣∣Y(i)fω(X(i))F2(1)

3.2 微型示例

Transformer 的性能与经过训练直接将输入投影到输出的简单线性神经网络相当或更差。考虑以下小型回归问题的生成模型,模仿稍后考虑的时间序列预测设置
Y = X W t o y + ϵ (2) Y=XW_{toy}+\epsilon\tag{2} Y=XWtoy+ϵ(2)
令L = 512,H = 96,D = 7 且 W t o y ε R L × H W^{toy} ε R^{L×H} WtoyεRL×H, $\epsilon \in R^{D×H} $具有随机正态条目,并生成 15000 个输入目标对 (X,Y)(10000 个用于训练,5000 个用于验证)。考虑到这个生成模型,希望开发一种Transformer 架构,可以有效地解决方程(1)中的问题。 (2)没有不必要的复杂性。为了实现这一目标,建议通过将注意力应用于 X 并结合将 X 添加到注意力输出的残差连接来简化常用的 Transformer 编码器。没有在此残差连接之上添加前馈块,而是直接采用线性层进行输出预测。正式地,模型定义如下:
f ( X ) = [ X + A ( X ) X W V W O ] W (3) f(X)=[X+A(X)XW_VW_O]W \tag{3} f(X)=[X+A(X)XWVWO]W(3)
A(X) 是输入序列 X ∈ R D × L X \in R^{D×L} XRD×L 的注意力矩阵,定义为
A ( X ) = softmax ( X W Q W W K T X T d m ) ∈ R D × D (4) A(X)=\text{softmax}(\frac{XW_QWW^T_KX^T}{\sqrt d_m})\in R^{D\times D} \tag{4} A(X)=softmax(d mXWQWWKTXT)RD×D(4)
首先,注意力渠道化,这简化了问题,降低了过度参数化的风险,因为矩阵W与Eq.(2)中的形状相同,并且由于L > d,注意矩阵变得更小。

根据方程生成的数据拟合Transformer 的优化问题。 式2理论上允许无限多个最优分类器W。

image-20240621194727942

如上图,尽管 Transformer 很简单,但它却存在严重的过度拟合问题。随机Transformer中的注意力权重可以提高泛化能力,暗示注意力在防止收敛到最佳局部最小值方面的作用。随机Transformer仅优化W,自注意力权重固定。Transformer泛化能力差的主要原因是注意力模块的可训练性问题。

3.3 SAMformer

为了实现更好的泛化性能和训练稳定性,采用了锐度感知最小框架。将式1迭代为
L t r a i n S A M ( ω ) = max ⁡ ∣ ∣ ϵ ∣ ∣ < ρ L t r a i n ( ω + ϵ ) L^{SAM}_{train}(ω)=\max_{||\epsilon||<\rho}L_{train}(ω+\epsilon) LtrainSAM(ω)=∣∣ϵ∣∣<ρmaxLtrain(ω+ϵ)

提议的 SAMformer 基于式(3),有两个重要的修改。

首先,为其配备了应用于 X 的可逆实例归一化(RevIN),因为该技术被证明可以有效处理时间序列中训练数据和测试数据之间的转换。其次,使用 SAM 优化模型,使其收敛到更平坦的局部最小值。总的来说,这给出了图 4 中带有一个编码器的浅层变压器模型。

image-20240621200234066

c2e3c65c2c108989d27e8748988f310

4. 文献解读

4.1 Introduction

当前方法的局限性:最近将 Transformer 应用于时间序列数据的工作主要集中在:

  1. 降低注意力的二次成本的有效实现
  2. 分解时间序列以更好地捕捉其中的潜在模式

上述研究没有很好的解决现有困境

Transformer 的可训练性

在时间序列预测的情况下,存在如何有效地训练 Transformer 架构而不出现过度拟合的问题。该研究目标是证明,通过消除训练的不稳定性,变压器可以在多元长期预测方面表现出色,这与之前对其局限性的看法相反。

4.2 创新点

该问题提出了SAMformer,主要贡献大致如下:

  1. 研究表明,即使Transformer 架构是为了解决简单的微型线性预测问题而定制的,它的泛化能力仍然很差并且收敛到尖锐的局部最小值。进一步确定注意力是造成这种现象的主要原因;
  2. 提出了一种浅层Transformer 模型,称为 SAMformer,它结合了研究界提出的最佳实践,包括可逆实例归一化和通道注意力最近在计算机视觉社区中引入。结果证明,通过锐度感知最小化(SAM)优化这样一个简单的Transformer 可以收敛到局部最小值,并具有更好的泛化能力;
  3. 凭经验证明了该方法在常见的多元长期预测数据集上的优越性。 SAMformer 超越了当前最先进的方法,与最大的基础模型 MOIRAI 相当,但参数却少得多。

4.3 实验过程

在该部分,提供了实证证明 SAMformer 在通用基准的多元长期时间序列预测中的定量和定性优势。具体来说,证明 SAMformer 比当前最先进的多元 TSMixer高出 14.33%,同时参数减少了约 4 倍。

数据集:在现实世界多元时间序列的 8 个公开数据集上进行了实验,四个电力变压器温度数据集 ETTh1、ETTh2、ETTm1 和 ETTm2、Electricity (UCI, 2015)、Exchange (Lai et al., 2018b)、Traffic (California Department of Transportation, 2021)以及Weather (Max Planck Institute, 2021)。

所有时间序列均以输入长度 L = 512、预测范围 H ∈ {96, 192, 336, 720} 和步幅为 1 进行分段,这意味着每个后续窗口都会移动一步。

基线模型:Transformer 和 TSMixer(Chen 等人,2023)。其中TSMixer 是完全基于 MLP 构建的最先进的多元基线。iTransformer、PatchTST、FEDformer、Informer (Zhou et al., 2022)、Informer (Zhou et al., 2021)和Autoformer。所有报告的结果都是使用 RevIN(Kim 等人,2021b)获得的,以便在 SAMformer 及其竞争对手之间进行更公平的比较。

评估方法:所有模型都经过训练,以式1最大限度地减少方程式中定义的 MSE 损失。 报告测试集上的平均 MSE,以及使用不同种子运行 5 次的标准差。其他详细信息和结果,包括平均绝对误差 (MAE)。除非另有说明,所有的结果都是使用不同种子进行 5 次运行而获得的。

实验结果

image-20240621204746376

在SAMformer的训练中引入SAM,使其损耗比Transformer更平滑。我们在上图a中通过比较
在ETTh1和Exchange上训练后Transformer和SAMformer的值来说明这一点。我们的观察表明,Transformer表现出相当高的清晰度,而SAMformer有一个理想的行为,损失景观清晰度是一个数量级小。

SAMformer演示了针对随机初始化的反业务。图5b - 1给出了SAMformer和Transformer在ETTh1
和Exchange上5种不同种子的试验MSE分布,预测水平为H = 96。SAMformer在不同的种子选择中始终保持性能稳定性,而Transformer表现出显著的差异,因此高度依赖于权重初始化。

image-20240621203607915

在表 1 中,报告了使用不同种子进行多次运行的性能,从而获得更可靠的评估。为了公平比较,还包括了经过训练的 TSMixer 的性能

SAMformer在8个数据集中的7个上明显优于其竞争对手。特别是,它比其最
佳竞争对手TSMixer+SAM提高了5.25%,比独立TSMixer提高了14.33%,比基于变压器的最佳模型FEDformer提高了12.36%。此外,它比Transformer提高了16.96%。对于每个数据集和视界,SAMformer被排名第一或第二。值得注意的是,SAM的集成提高了TSMixer的泛化能力,平均提高了9.58%。

5. 结论

SAMformer 通过锐度感知最小化进行了优化,与现有的预测基线(包括当前最大的基础模型 MOIRAI)相比,带来了显着的性能增益,并受益于跨数据集和预测范围的高通用性和鲁棒性。最后,我们还表明,时间序列预测中的通道注意力在计算和性能方面比以前常用的时间注意力更加有效。我们相信,这一令人惊讶的发现可能会刺激在我们简单的架构之上进行许多进一步的工作,以进一步改进它。

6.代码复现

该代码可从 https://github.com/romilbert/samformer 获取。

attention实现

这段代码定义了一个使用PyTorch库的函数,实现了缩放点积注意力机制,用于模型中的注意力计算。它接受查询(query)、键(key)和值(value)张量,并可选地接受注意力掩码和丢失概率参数。函数首先计算查询和键的点积,并根据查询的维度大小进行缩放。如果指定了因果关系或提供了注意力掩码,它会修改注意力权重以避免未来信息的泄露或应用额外的掩码。最后,它使用Softmax函数规范化注意力权重,并通过点积操作与值张量结合,输出最终的注意力加权结果。

import torch
 
import numpy as np
 
 
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """
    A copy-paste from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / np.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
 
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

dataset处理

这段代码定义了一个Python类 LabeledDataset,它是一个继承自 torch.utils.data.Dataset 的自定义数据集类,用于处理带有标签的数据。该类的主要作用是将 NumPy 数组格式的数据和标签转换为 PyTorch 张量格式,并提供了一些基本的数据处理方法,使其可以用于 PyTorch 的数据加载和预处理流程中。

import torch
 
from torch.utils.data import Dataset
 
 
class LabeledDataset(Dataset):
    def __init__(self, x, y):
        """
        Converts numpy data to a torch dataset
        Args:
            x (np.array): data matrix
            y (np.array): class labels
        """
        self.x = torch.FloatTensor(x)
        self.y = torch.FloatTensor(y)
 
    def transform(self, x):
        return torch.FloatTensor(x)
 
    def __len__(self):
        return self.y.shape[0]
 
    def __getitem__(self, idx):
        examples = self.x[idx]
        labels = self.y[idx]
        return examples, labels

RevIN

这段代码定义了一个名为 RevIN(Reversible Instance Normalization)的 Python 类,它继承自 PyTorch 的 nn.Module。这个类实现了可逆的实例归一化,主要用于神经网络中,可以在正向传播时进行标准化,并在需要时进行反向去标准化。

import torch
import torch.nn as nn
 
 
class RevIN(nn.Module):
    """
    Reversible Instance Normalization (RevIN) https://openreview.net/pdf?id=cGDAkQo1C0p
    https://github.com/ts-kim/RevIN
    """
    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()
 
    def forward(self, x, mode:str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else: raise NotImplementedError
        return x
 
    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
 
    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim-1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
 
    def _normalize(self, x):
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x
 
    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps*self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x

SAM

这段代码定义了一个名为 SAM 的 Python 类,它是一个用于优化神经网络训练的自定义优化器,继承自 PyTorch 的 Optimizer。SAM 代表 Sharpness-Aware Minimization,这是一种用于改进模型泛化能力的优化技术,通过最小化损失函数的锐度来实现。

import torch

from torch.optim import Optimizer


class SAM(Optimizer):
    """
    SAM: Sharpness-Aware Minimization for Efficiently Improving Generalization https://arxiv.org/abs/2010.01412
    https://github.com/davda54/sam
    """
def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
    assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
 
    defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
    super(SAM, self).__init__(params, defaults)
 
    self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
    self.param_groups = self.base_optimizer.param_groups
 
@torch.no_grad()
def first_step(self, zero_grad=False):
    grad_norm = self._grad_norm()
    for group in self.param_groups:
        scale = group["rho"] / (grad_norm + 1e-12)
 
        for p in group["params"]:
            if p.grad is None:
                continue
            e_w = (
                (torch.pow(p, 2) if group["adaptive"] else 1.0)
                * p.grad
                * scale.to(p)
            )
            p.add_(e_w)  # climb to the local maximum "w + e(w)"
            self.state[p]["e_w"] = e_w
 
    if zero_grad:
        self.zero_grad()
 
@torch.no_grad()
def second_step(self, zero_grad=False):
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
 
    self.base_optimizer.step()  # do the actual "sharpness-aware" update
 
    if zero_grad:
        self.zero_grad()
 
@torch.no_grad()
def step(self, closure=None):
    assert (
        closure is not None
    ), "Sharpness Aware Minimization requires closure, but it was not provided"
    closure = torch.enable_grad()(
        closure
    )  # the closure should do a full forward-backward pass
 
    self.first_step(zero_grad=True)
    closure()
    self.second_step()
 
def _grad_norm(self):
    shared_device = self.param_groups[0]["params"][
        0
    ].device  # put everything on the same device, in case of model parallelism
    norm = torch.norm(
        torch.stack(
            [
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad)
                .norm(p=2)
                .to(shared_device)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
        ),
        p=2,
    )
    return norm

samformer

这段代码定义了两个主要的 Python 类,SAMFormerArchitecture 和 SAMFormer,它们基于 PyTorch 框架用于构建和训练一个深度学习模型,具体是用于时间序列预测任务。这些类利用了一些先进的技术如可逆实例归一化(RevIN)、注意力机制和锐度感知最小化(SAM)来改善模型的预测性能和泛化能力。

import torch
import random
import numpy as np

from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader

from .utils.attention import scaled_dot_product_attention
from .utils.dataset import LabeledDataset
from .utils.revin import RevIN
from .utils.sam import SAM


class SAMFormerArchitecture(nn.Module):
    def __init__(self, num_channels, seq_len, hid_dim, pred_horizon, use_revin=True):
        super().__init__()
        self.revin = RevIN(num_features=num_channels)
        self.compute_keys = nn.Linear(seq_len, hid_dim)
        self.compute_queries = nn.Linear(seq_len, hid_dim)
        self.compute_values = nn.Linear(seq_len, seq_len)
        self.linear_forecaster = nn.Linear(seq_len, pred_horizon)
        self.use_revin = use_revin
        
def forward(self, x):
    # RevIN Normalization
    if self.use_revin:
        x_norm = self.revin(x.transpose(1, 2), mode='norm').transpose(1, 2) # (n, D, L)
    else:
        x_norm = x
    # Channel-Wise Attention
    queries = self.compute_queries(x_norm) # (n, D, hid_dim)
    keys = self.compute_keys(x_norm) # (n, D, hid_dim)
    values = self.compute_values(x_norm) # (n, D, L)
    if hasattr(nn.functional, 'scaled_dot_product_attention'):
        att_score = nn.functional.scaled_dot_product_attention(queries, keys, values) # (n, D, L)
    else:
        att_score = scaled_dot_product_attention(queries, keys, values) # (n, D, L)
    out = x_norm + att_score # (n, D, L)
    # Linear Forecasting
    out = self.linear_forecaster(out) # (n, D, H)
    # RevIN Denormalization
    if self.use_revin:
        out = self.revin(out.transpose(1, 2), mode='denorm').transpose(1, 2) # (n, D, H)
    return out.reshape([out.shape[0], out.shape[1]*out.shape[2]])
class SAMFormer:
    """
    SAMFormer pytorch trainer implemented in the sklearn fashion
    """
    def __init__(self, device='cuda:0', num_epochs=100, batch_size=256, base_optimizer=torch.optim.Adam,
                 learning_rate=1e-3, weight_decay=1e-5, rho=0.5, use_revin=True, random_state=None):
        self.network = None
        self.criterion = nn.MSELoss()
        self.device = device
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.base_optimizer = base_optimizer
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.rho = rho
        self.use_revin = use_revin
        self.random_state = random_state
        
def fit(self, x, y):
    if self.random_state is not None:
        torch.manual_seed(self.random_state)
        random.seed(self.random_state)
        np.random.seed(self.random_state)
        torch.cuda.manual_seed_all(self.random_state)
 
    self.network = SAMFormerArchitecture(num_channels=x.shape[1], seq_len=x.shape[2], hid_dim=16,
                                         pred_horizon=y.shape[1] // x.shape[1], use_revin=self.use_revin)
    self.criterion = self.criterion.to(self.device)
    self.network = self.network.to(self.device)
    self.network.train()
 
    optimizer = SAM(self.network.parameters(), base_optimizer=self.base_optimizer, rho=self.rho,
                    lr=self.learning_rate, weight_decay=self.weight_decay)
 
    train_dataset = LabeledDataset(x, y)
    data_loader_train = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
 
    progress_bar = tqdm(range(self.num_epochs))
    for epoch in progress_bar:
        loss_list = []
        for (x_batch, y_batch) in data_loader_train:
            x_batch = x_batch.to(self.device)
            y_batch = y_batch.to(self.device)
            # =============== forward ===============
            out_batch = self.network(x_batch)
            loss = self.criterion(out_batch, y_batch)
            # =============== backward ===============
            if optimizer.__class__.__name__ == 'SAM':
                loss.backward()
                optimizer.first_step(zero_grad=True)
 
                out_batch = self.network(x_batch)
                loss = self.criterion(out_batch, y_batch)
 
                loss.backward()
                optimizer.second_step(zero_grad=True)
            else:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_list.append(loss.item())
        # =============== save model / update log ===============
        train_loss = np.mean(loss_list)
        self.network.train()
        progress_bar.set_description("Epoch {:d}: Train Loss {:.4f}".format(epoch, train_loss), refresh=True)
    return
 
def forecast(self, x, batch_size=256):
    self.network.eval()
    dataset = torch.utils.data.TensorDataset(torch.tensor(x, dtype=torch.float))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    outs = []
    for _, batch in enumerate(dataloader):
        x = batch[0].to(self.device)
        with torch.no_grad():
            out = self.network(x)
        outs.append(out.cpu())
    outs = torch.cat(outs)
    return outs.cpu().numpy()
 
def predict(self, x, batch_size=256):
    return self.forecast(x, batch_size=batch_size)

小结

本文讨论了在时间序列预测中运用转换器模型的挑战与创新。传统的基于Transformer的模型虽然在自然语言处理和计算机视觉领域表现出色,但在多变量长期预测任务上,它们的性能却不及简单的线性模型。研究中指出,这些模型在基本的线性预测场景中难以实现最佳解决方案,主要问题在于其注意力机制的泛化能力较差。为应对这一问题,研究提出了一种新型的浅层、轻量级Transformer模型,即SAMformer。该模型采用锐度感知优化(SAM),有效克服了不良的局部最小值,显著提高了模型在多变量时间序列数据集上的性能,性能提升幅度达14.33%,且模型的参数数量大约减少了四倍。此外,SAMformer展示了优越的泛化能力和鲁棒性,其在多个数据集上的表现均优于当前先进的多变量模型TSMixer。研究结果表明,采用channel-wise注意力机制的SAMformer在计算和性能方面都比传统的temporal attention更为有效,为时间序列预测领域提供了新的视角和方法。

参考文献

[1] Romain Ilbert, Ambroise Odonnat, Vasilii Feofanov, Aladin Virmaux, Giuseppe Paolo, Themis Palpanas, Ievgen Redko “SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention” [C], ICML 2024

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/730485.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

智谱AI GLM-4V-9B视觉大模型环境搭建推理

引子 最近在关注多模态大模型&#xff0c;之前4月份的时候关注过CogVLM&#xff08;CogVLM/CogAgent环境搭建&推理测试-CSDN博客&#xff09;。模型整体表现还不错&#xff0c;不过不支持中文。智谱AI刚刚开源了GLM-4大模型&#xff0c;套餐里面包含了GLM-4V-9B大模型&…

HTTP 状态码详解及使用场景

目录 1xx 信息性状态码2xx 成功状态码3xx 重定向状态码4xx 客户端错误状态码5xx 服务器错误状态码 HTTP思维导图连接&#xff1a;https://note.youdao.com/s/A7QHimm0 1xx 信息性状态码 100 Continue&#xff1a;表示客户端应继续发送请求的其余部分。 使用场景&#xff1a;客…

昇思25天学习打卡营第3天|数据集Dataset

一、简介&#xff1a; 数据是深度学习的基础&#xff0c;高质量的数据输入将在整个深度神经网络中起到积极作用。有一种说法是模型最终训练的结果&#xff0c;10%受到算法影响&#xff0c;剩下的90%都是由训练的数据质量决定。&#xff08;doge&#xff09; MindSpore提供基于…

公司怎么管理文档外发泄密?强化企业文档安全用迅软加密软件就行了!

一、文档加密软件原理 迅软DSE加密软件对各类需要加密的文件&#xff08;如&#xff1a;技术资料、商业数据、红头文件、会议纪要、机要文件、图纸、财务报表等&#xff09;进行加密。 使用加密算法对文件自动加密&#xff0c;只有拥有正确的解密密钥或密码的人才能打开文件&…

【uni-app学习手札】

uni-app&#xff08;vue3&#xff09;编写微信小程序 编写uni-app不必拘泥于HBuilder-X编辑器&#xff0c;可用vscode进行编写&#xff0c;在《微信开发者工具》中进行热加载预览&#xff0c; 主要记录使用uni-app过程中自我备忘一些api跟语法&#xff0c;方便以后编写查找使用…

OrangePi连接Wi-Fi步骤

下面介绍的是用终端命令行的方式配置WIFI&#xff1a; 首先输入以下命令用于扫描并查看周围的WiFi热点。也可以直接连接。 nmcli dev wifi之后会在终端打出周围所有可以连接的WiFi&#xff0c;按方向键上下可以查看显示更多&#xff0c;按q键退出。 然后同样使用nmcli命令连接…

如何修改外接移动硬盘的区号

- 问题介绍 当电脑自身内存不够使用的时候&#xff0c;使用外接硬盘扩展内存是一个不错的选择。但是当使用的外接硬盘数量过多的时候&#xff0c;会出现分配硬盘的区号变动的情况&#xff0c;这种情况下会极大的影响使用的体验情况。可以通过以下步骤手动调整恢复 - 配置 版本…

【CT】LeetCode手撕—143. 重排链表

目录 题目1- 思路2- 实现⭐143. 重排链表——题解思路 3- ACM 实现 题目 原题连接&#xff1a;143. 重排链表 1- 思路 模式识别&#xff1a;重排链表 ——> 逆向 ——> ① 找到中间节点 ——> ②逆置 mid.next 链表——> ③遍历 2- 实现 ⭐143. 重排链表——题解…

ELK Kibana搜索框模糊搜索包含不包含

默认是KQL,点击切换Lucene搜索&#xff0c;搜索日志中包含Exception关键字&#xff0c;不包含BizException、IllegalArgumentException、DATA_SYNC_EXCEPTION关键字的日志&#xff0c;如下&#xff1a; message: *Exception AND !(message : *BizException OR message : *Ille…

现代易货交易:重塑物品交换的新纪元

在数字时代的浪潮中&#xff0c;交易模式正在经历一场革命。其中&#xff0c;现代易货交易模式以其独特的魅力&#xff0c;逐渐在市场中崭露头角。这种交易模式不仅是对古老“以物换物”的复兴&#xff0c;更是对物品价值和交换方式的全新定义。 现代易货&#xff1a;物品交换的…

机器人系统工具箱的 Gazebo 模拟

Gazebo 联合仿真模块 机器人系统工具箱> Gazebo联合仿真模块库包含与仿真环境相关的 Simulink 模块。要查看该库&#xff0c;在 MATLAB 命令提示符下输入robotgazebolib。

张量 Tensor学习总结

张量 Tensor 张量是一种多线性函数&#xff0c;用于表示矢量、标量和其他张量之间的线性关系&#xff0c;其在n维空间内有n^r个分量&#xff0c;每个分量都是坐标的函数。张量在坐标变换时也会按照某些规则作线性变换&#xff0c;是一种特殊的数据结构&#xff0c;在MindSpore…

IDEA中SpringMVC的运行环境问题

文章目录 一、IEAD 清理缓存二、用阿里云和spring创建 SpringMVC 项目中 pom.xml 文件的区别 一、IEAD 清理缓存 springMVC 运行时存在一些之前运行过的缓存导致项目不能运行&#xff0c;可以试试清理缓存 二、用阿里云和spring创建 SpringMVC 项目中 pom.xml 文件的区别 以下…

容器之工具栏构件演示

代码; #include <gtk-2.0/gtk/gtk.h> #include <glib-2.0/glib.h> #include <gtk-2.0/gdk/gdkkeysyms.h> #include <stdio.h>int main(int argc, char *argv[]) {gtk_init(&argc, &argv);GtkWidget *window;window gtk_window_new(GTK_WINDO…

Meta-Llama-3-8B 部署

Meta-Llama-3-8B 模型文件地址 LLaMA-Factory 仓库地址 Download Ollama 环境准备 操作系统&#xff1a;Ubuntu 22.04.5 LTSAnaconda3&#xff1a;Miniconda3-latest-Linux-x86_64GPU&#xff1a; NVIDIA G…

第二十六篇——极简通信史:从1G到5G通信,到底经历了什么?

目录 一、背景介绍二、思路&方案三、过程1.思维导图2.文章中经典的句子理解3.学习之后对于投资市场的理解4.通过这篇文章结合我知道的东西我能想到什么&#xff1f; 四、总结五、升华 一、背景介绍 对于网络&#xff0c;1G到5G&#xff0c;我们都在享受它带来的进步成果&a…

3.3 Ubuntu24使用kubeadm部署高可用K8S集群

Ubuntu24使用kubeadm部署高可用K8S集群 使用kubeadm部署一个k8s集群&#xff0c;3个master1个worker节点。 1. 环境信息 操作系统&#xff1a;ubuntu24.04内存: 2GBCPU: 2网络: 能够互访&#xff0c;能够访问互联网 hostnameip备注k8s-master1192.168.0.51master1k8s-maste…

聚类算法(1)---最大最小距离、C-均值算法

本篇文章是博主在人工智能等领域学习时&#xff0c;用于个人学习、研究或者欣赏使用&#xff0c;并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记&#xff0c;若有不当和侵权之处&#xff0c;指出后将会立即改正&#xff0c;还望谅解。文章分类在AI学习笔记&#…

[Qt]Qt框架解析:从入门到精通,探索平台开发的无限可能

一、Qt的概述 Qt是一个跨平台的C图形用户界面应用程序框架&#xff08;GUI&#xff09;。它为应用程序开发者提供建立艺术级图形界面所需的所有功能。它是完全面向对象的&#xff0c;很容易扩展&#xff0c;并且允许真正的组件编程。开发环境为Qt creator5.8.0&#xff0c;下载…

小红书 2024 大模型论文分享会来啦,与多位顶会作者在线畅聊!

大模型正引领新一轮的研究热潮&#xff0c;业界和学术界都涌现出了众多的创新成果。 小红书技术团队也在这一浪潮中不断探索&#xff0c;多篇论文研究成果在 ICLR、ACL、CVPR、AAAI、SIGIR、WWW 等国际顶会上频频亮相。 在大模型与自然语言处理的交汇处&#xff0c;我们发现了…