Chronos:学习时间序列的大语言模型(代码解析)

前言

  • 《Chronos: Learning the Language of Time Series》原文地址,Github开源代码地址
  • Chronos:学习时间序列的大语言模型(论文解读)CSDN地址
  • GitHub项目地址Some-Paper-CN。本项目是译者在学习长时间序列预测、CV、NLP和机器学习过程中精读的一些论文,并对其进行了中文翻译。还有部分最佳示例教程
  • 如果有帮助到大家,请帮忙点亮Star,也是对译者莫大的鼓励,谢谢啦~
  • 本文代码已同步至项目Some-Paper-CN,后续可能会根据热度发布使用LoRA微调Chronos模型教程,浅浅期待一下吧~

先验知识

  • 建议先阅读Chronos论文解读篇,对大致原理有所了解,阅读代码效果会更好。
  • 在论文解读篇中,我们已经知道了Chronos是基于Google的开源模型T5(Huggingface)。因受篇幅影响,有关T5模型的解析不在本次讨论范围内,感兴趣的小伙伴可以去查询相关资料。
  • 论文基于Transformers框架,在阅读代码前,最好有一定Transformers库的基础知识。
  • 虽然本文模型为时间序列模型,但不管是在模型架构、训练方式还是数据组织上都与大语言模型几乎一致,在阅读代码前,最好有一定大语言模型领域的知识,比如术语tonkentokenizer

代码解析

  • 将开源代码从Github上下载到本地,关键文件在chronos-forecasting/src/chronos下,chronos.py文件。
  • ChronosConfig用于加载模型参数(注意!是参数不是权重),类ChronosTokenizer用于加载模型Tokenizer,类ChronosModel用于根据模型参数构建模型。上述类为Transformers库基础类,这里不多赘述。
  • 论文中的核心在类MeanScaleUniformBins用于数据均值缩放和量化分箱,类ChronosPipeline用于构架数据预测管道。

MeanScaleUniformBins

class MeanScaleUniformBins(ChronosTokenizer):
    def __init__(
        self, low_limit: float, high_limit: float, config: ChronosConfig
    ) -> None:
        self.config = config
        # 线性平分向量torch.linspace(start, end, steps)
        self.centers = torch.linspace(
            low_limit,
            high_limit,
            config.n_tokens - config.n_special_tokens - 1,
        )
        # 首尾元素分别为-1e20、1e20
        # self.centers[1:]除第1个元素外的所有元素
        # self.centers[:-1]除最后1个元素外的所有元素
        # (self.centers[1:] + self.centers[:-1]) / 2表示相邻元素平均值
        self.boundaries = torch.concat(
            (
                torch.tensor([-1e20], device=self.centers.device),
                (self.centers[1:] + self.centers[:-1]) / 2,
                torch.tensor([1e20], device=self.centers.device),
            )
        )

    def input_transform(
        self, context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size, length = context.shape

        if length > self.config.context_length:
            # 保留最后context_length个元素
            context = context[..., -self.config.context_length :]

        # 空值的反向布尔值
        attention_mask = ~torch.isnan(context)
        # context绝对值和attention_mask的点积,除以attention_mask的和
        scale = torch.nansum(
            torch.abs(context) * attention_mask, dim=-1
        ) / torch.nansum(attention_mask, dim=-1)
        # scale是0或空值设为1.0
        scale[~(scale > 0)] = 1.0
        # 将context按scale缩放
        scaled_context = context / scale.unsqueeze(dim=-1)
        # torch.bucketize根据边界值将输入映射到相应bucket(桶)中
        token_ids = (
            torch.bucketize(
                input=scaled_context,
                boundaries=self.boundaries,
                right=True,
            )
            + self.config.n_special_tokens
        )
        # 不需要关注的地方使用padding
        token_ids[~attention_mask] = self.config.pad_token_id

        # 如果需要在末尾添加eos符
        if self.config.use_eos_token:
            eos_tokens = torch.full(
                (batch_size, 1), fill_value=self.config.eos_token_id
            )
            token_ids = torch.concat((token_ids, eos_tokens), dim=1)
            # mask置为true
            eos_mask = torch.full((batch_size, 1), fill_value=True)
            attention_mask = torch.concat((attention_mask, eos_mask), dim=1)

        return token_ids, attention_mask, scale

    def output_transform(
        self, samples: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        # 将scale扩展两个维度
        scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
        # 将值限制在0和centers长度间,确保索引值不超出centers
        indices = torch.clamp(
            samples - self.config.n_special_tokens,
            min=0,
            max=len(self.centers) - 1,
        )
        # 返回在原始context缩放级别下分桶值
        return self.centers[indices] * scale_unsqueezed
  • low_limithigh_limit包含在模型参数中,根据论文分别为-1515

  • input_transform函数中scale = torch.nansum(torch.abs(context) * attention_mask, dim=-1) / torch.nansum(attention_mask, dim=-1)看上去非常复杂,实际上在没有空值的情况下,相当于对序列求平均值。

  • input_transform函数中分箱函数torch.bucketize的使用可以参考官方文档。

  • input_transform函数中空值使用padding填充,并使用mask进行遮掩,是大语言模型训练的常用操作。

  • 在论文中,作者表示为了保持与大语言模型训练方式保持一致,会在序列结束后放置eos标识符,所以模型参数use_eos_token是为True的。

  • output_transform函数是input_transform函数的反操作,需要注意的是torch.clamp函数,确保token_id在词表中,否则就无法反归一化得到正常的值了。

ChronosPipeline

  • from_pretrained函数用于加载模型预训练权重,这里不在过多赘述,关键在于predict函数。
    def predict(
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        num_samples: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        limit_prediction_length: bool = True,
    ) -> torch.Tensor:
        """
        Get forecasts for the given time series.

        Parameters
        ----------
        context
            Input series. This is either a 1D tensor, or a list
            of 1D tensors, or a 2D tensor whose first dimension
            is batch. In the latter case, use left-padding with
            ``torch.nan`` to align series of different lengths.
        prediction_length
            Time steps to predict. Defaults to what specified
            in ``self.model.config``.
        num_samples
            Number of sample paths to predict. Defaults to what
            specified in ``self.model.config``.
        temperature
            Temperature to use for generating sample tokens.
            Defaults to what specified in ``self.model.config``.
        top_k
            Top-k parameter to use for generating sample tokens.
            Defaults to what specified in ``self.model.config``.
        top_p
            Top-p parameter to use for generating sample tokens.
            Defaults to what specified in ``self.model.config``.
        limit_prediction_length
            Force prediction length smaller or equal than the
            built-in prediction length from the model. True by
            default. When true, fail loudly if longer predictions
            are requested, otherwise longer predictions are allowed.

        Returns
        -------
        samples
            Tensor of sample forecasts, of shape
            (batch_size, num_samples, prediction_length).
        """
        context_tensor = self._prepare_and_validate_context(context=context)

        if prediction_length is None:
            prediction_length = self.model.config.prediction_length

        if prediction_length > self.model.config.prediction_length:
            msg = (
                f"We recommend keeping prediction length <= {self.model.config.prediction_length}. "
                "The quality of longer predictions may degrade since the model is not optimized for it. "
            )
            if limit_prediction_length:
                msg += "You can turn off this check by setting `limit_prediction_length=False`."
                raise ValueError(msg)
            warnings.warn(msg)

        predictions = []
        remaining = prediction_length

        while remaining > 0:
            # 根据MeanScaleUniformBins类对数据进行缩放和分箱
            token_ids, attention_mask, scale = self.tokenizer.input_transform(
                context_tensor
            )
            # 输入模型得到结果
            samples = self.model(
                token_ids.to(self.model.device),
                attention_mask.to(self.model.device),
                min(remaining, self.model.config.prediction_length),
                num_samples,
                temperature,
                top_k,
                top_p,
            )
            prediction = self.tokenizer.output_transform(
                samples.to(scale.device), scale
            )

            predictions.append(prediction)
            remaining -= prediction.shape[-1]
			
            # 判断是否预测完
            if remaining <= 0:
                break
			# 拼接操作
            context_tensor = torch.cat(
                [context_tensor, prediction.median(dim=1).values], dim=-1
            )

        return torch.cat(predictions, dim=-1)
  • 作者建议将prediction length保持在64以下,因为模型没有针对较长的预测长度进行优化,因此预测质量可能会下降。
  • 预测过程为:根据MeanScaleUniformBins类中input_transform函数对数据进行缩放和分箱,得到token_id、掩码矩阵attention_mask, 均值scale;将token_id和掩码矩阵attention_mask输入模型,得到输出samples。根据MeanScaleUniformBins类中output_transform函数和均值scale将输出samples反归一化得到实际值。
  • remaining变量用于检验prediction length是否全部预测完。

left_pad_and_stack_1D

  • 上述代码中函数predict调用了_prepare_and_validate_context函数,本质是left_pad_and_stack_1D函数。
def left_pad_and_stack_1D(tensors: List[torch.Tensor]):
    # tensors中最长元素的长度
    max_len = max(len(c) for c in tensors)
    padded = []
    # 遍历tensors中元素
    for c in tensors:
        assert isinstance(c, torch.Tensor)
        # c为一维张量
        assert c.ndim == 1
        # 填充torch.nan
        padding = torch.full(
            size=(max_len - len(c),), fill_value=torch.nan, device=c.device
        )
        # 拼接(c长度被扩展为max_len),并添加到列表padded中
        padded.append(torch.concat((padding, c), dim=-1))
    # 将padded列表中的所有元素沿着新维度折叠,形成二维张量
    return torch.stack(padded)
  • 该函数是大语言模型训练过程中为了补齐长度做的操作,如果不理解也没事,只要明白在干什么就行。

测试Demo

  • 如果想要进一步了解代码,还是希望大家用一个轻量的测试Demo从头到尾Debug一下。
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-tiny",
    device_map="cpu",
    torch_dtype=torch.float16,
)

df = pd.read_csv("AirPassengers.csv")

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(df["#Passengers"])
prediction_length = 12
forecast = pipeline.predict(
    context,
    prediction_length,
    num_samples=20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
) # forecast shape: [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
plt.legend()
plt.grid()
plt.show()
  • 预测结果效果图

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/612048.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【漫画版】指挥官的排序战术:快速排序算法解密

作者介绍&#xff1a;10年大厂数据\经营分析经验&#xff0c;现任字节跳动数据部门负责人。 会一些的技术&#xff1a;数据分析、算法、SQL、大数据相关、python&#xff0c;欢迎探讨交流 欢迎加入社区&#xff1a;码上找工作 作者专栏每日更新&#xff1a; LeetCode解锁1000题…

使用Python在PowerPoint演示文稿之间复制样式(复制幻灯片母版)

在专业演示文稿设计与制作领域&#xff0c;多场演示间保持一致性至关重要。在PowerPoint演示文稿之间复制幻灯片母版成为了一项关键技巧&#xff0c;用以维持统一的视觉风格&#xff0c;确保品牌形象的一致性&#xff0c;并提升观众的参与度。这一做法不仅能节省宝贵的时间&…

OC foudation框架(下)的学习

OCfoudation框架&#xff08;下&#xff09; 前面学习了有关OCfoudation框架的部分内容&#xff0c;我们现在对于后面的内容继续学习。 文章目录 OCfoudation框架&#xff08;下&#xff09;数组&#xff08;NSArray和NSMutableArray&#xff09;对集合元素整体调用方法排序使用…

SwinIR: Image Restoration Using Swin Transformer

ICCV2021 workshophttps://github.com/JingyunLiang/SwinIR 问题引入 将swim transformer使用到图像恢复任务当中&#xff0c;因为卷积存在不能建模长距离依赖以及使用相同的卷积核来恢复不同的图像区域&#xff1b;并不是首个将transformer引入图像恢复中的方法&#xff0c;…

简单的Python HTML 输出

1、问题背景 一名初学者在尝试将 Python 脚本输出到网页上时遇到了一些问题。他当前使用 Python 和 HTML 进行开发&#xff0c;并且遇到了以下问题&#xff1a; 担心自己的代码过于复杂&#xff0c;尤其是 WebOutput() 函数。希望通过 JavaScript 使用 HTML 模板文件更新数据。…

Java多线程与并发编程

1.多线程基础 1.1 线程相关概念 程序(program)&#xff1a;是为完成特定任务、用某种语言编写的一组指令的集合。简单的说:就是我们写的代码 进程: 1. 进程是指运行中的程序&#xff0c;比如我们使用QQ&#xff0c;就启动了一个进程&#xff0c;操作系统就会为该进程…

常见扩频系统的基础概念和模型

扩频系统是一种通信技术&#xff0c;它通过将信号的频谱扩展到一定程度来实现传输&#xff0c;这种系统的设计和实现涉及到多种不同的方法和技术。 扩频系统的主要特点和好处包括&#xff1a; 抗干扰能力强&#xff1a;由于信号被扩展到较宽的频带上&#xff0c;单位带宽内的功…

数据收集-分化轨迹推断

数据收集-分化轨迹推断 1参考内容 2参考内容 3参考内容 4参考内容 5参考内容 6&#xff1a;methods and datasets review参考内容 1 参考 Ranek, J.S., Stanley, N. & Purvis, J.E. Integrating temporal single-cell gene expression modalities for trajectory inferen…

【p7】正规式转正规文法

需要注意的是&#xff0c;有时候需要自己构造一个非终结符&#xff0c;非终结符推导到空&#xff0c;然后套用上面的公式即可

十大排序算法(java实现)

注&#xff1a;本篇仅用来自己学习&#xff0c;大量内容来自菜鸟教程&#xff08;地址&#xff1a;1.0 十大经典排序算法 | 菜鸟教程&#xff09; 排序算法可以分为内部排序和外部排序&#xff0c;内部排序是数据记录在内存中进行排序&#xff0c;而外部排序是因排序的数据很大…

SpringCloud面试题

SpringCloud常见组件有哪些 注册中心组件&#xff1a;Eureka、Nacos 负载均衡组件&#xff1a;Ribbon 远程调用组件&#xff1a;OpenFeign 网关组件&#xff1a;Zuul、Gateway 服务保护组件&#xff1a;Hystrix、Sentinel 服务配置管理组件&#xff1a;SpringCloudConfig、Nac…

OpenCompass大模型评估

作业链接&#xff1a; Tutorial/opencompass/homework.md at camp2 InternLM/Tutorial GitHub 项目链接&#xff1a; GitHub - open-compass/opencompass: OpenCompass is an LLM evaluation platform, supporting a wide range of models (Llama3, Mistral, InternLM2,GPT-…

Docker快速搭建NAS服务——FileBrowser

Docker快速搭建NAS服务——FileBrowser 文章目录 前言FileBrowser的搭建docker-compose文件编写运行及访问 总结 前言 本文主要讲解如何使用docker在本地快速搭建NAS服务&#xff0c;这里主要写如下两种&#xff1a; FileBrowser1&#xff1a;是一个开源的Web文件管理器&…

【吊打面试官系列】Java高并发篇 - 为什么 wait(), notify()和 notifyAll ()必须在同步方法或者同步块中被调用?

大家好&#xff0c;我是锋哥。今天分享关于 【为什么 wait(), notify()和 notifyAll ()必须在同步方法或者同步块中被调用&#xff1f;】面试题&#xff0c;希望对大家有帮助&#xff1b; 为什么 wait(), notify()和 notifyAll ()必须在同步方法或者同步块中被调用&#xff1f;…

这3种深拷贝实现,你都知道吗?

目录&#xff1a; 1、JSON.parse 2、structuredClone 3、cloneDeep

【竞技宝jjb.lol】MSI:换线战术或将成为BLG命门

北京时间2024年5月10日,英雄联盟2024MSI季中赛继续进行,昨日迎来胜败分组赛首轮BLG对阵PSG。本以为这场比赛没有任何悬念,BLG将会非常轻松地击败PSG,没想到最终PSG两度扳平比分,BLG决胜局抗住压力才艰难取胜。虽然赢下了比赛,但BLG低迷的状态还是在比赛结束后遭到网友们的热议。…

超全MySQL锁机制介绍

前言 MySQL作为关系型数据库管理系统中的佼佼者&#xff0c;为了保证数据的一致性和完整性&#xff0c;在并发控制方面采用了锁机制。锁机制是数据库管理系统用于控制对共享资源的访问&#xff0c;避免多个事务同时修改同一数据造成的数据不一致问题。了解MySQL的锁机制对于数…

【组合博弈】介绍

本文为学习笔记&#xff0c;详细内容参考"Lessons in Play,Michael H. Albert Richard J. Nowakowski David Wolfe" 文章目录 组合博弈介绍(Combinatorial Games)DOMINEERING游戏组合游戏选手介绍Options博弈树&#xff08;game tree&#xff09; 组合博弈介绍(Combi…

*****水上飞机:继承,虚函数,虚继承

一题目 请设计以下航行器、飞机、船、水上飞机等 4 个类。 CRAFT 为航行器类&#xff0c;是公共基类&#xff0c;提供航行器的基本特性。包括&#xff1a; 一个保护数据成员&#xff1a;speed(速度)。 三个公有成员函数&#xff1a;构造函数(初始化速度)、析构函数和 Show 函数…

ASP.NET学生成绩管理系统

摘要 本系统依据开发要求主要应用于教育系统&#xff0c;完成对日常的教育工作中学生成绩档案的数字化管理。开发本系统可使学院教职员工减轻工作压力&#xff0c;比较系统地对教务、教学上的各项服务和信息进行管理&#xff0c;同时&#xff0c;可以减少劳动力的使用&#xf…