0基础学会在亚马逊云科技AWS上利用SageMaker、PEFT和LoRA高效微调AI大语言模型(含具体教程和代码)

项目简介:

小李哥今天将继续介绍亚马逊云科技AWS云计算平台上的前沿前沿AI技术解决方案,帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS上的AI软甲开发最佳实践,并应用到自己的日常工作里。本次介绍的是如何在Amazon SageMaker上微调(Fine-tune)大语言模型dolly-v2-3b,满足日常生活中不同的场景需求,并将介分享如何在SageMaker上优化模型性能并节省计算资源实现成本控制,最后将部署后的大语言模型URL集成到自己云上的软件应用中。

本方案包括通过Amazon Cloudfront和S3托管前端页面,并通过Amazon API Gateway和AWS Lambda将应用程序与AI模型集成,调用大模型实现推理。本方案的解决方案架构图如下:

利用微调模型创建的对话机器人前端UI

利用本方案小李哥用微调后的模型搭建了一个Q&A对话机器人助手,可以生成代码、文字总结、回答问题。

在开始分享案例之前,我们来了解一下本方案的技术背景,帮助大家更好的理解方案架构。

什么是Amazon SageMaker?

Amazon SageMaker 是一个完全托管的机器学习服务(大家可以理解为Serverless的Jupyter Notebook),专为应用开发和数据科学家设计,帮助他们快速构建、训练和部署机器学习模型。使用 SageMaker,您无需担心底层基础设施的管理,可以专注于模型的开发和优化。它提供了一整套工具和功能,包括数据准备、模型训练、超参数调优、模型部署和监控,简化了整个机器学习工作流程。

本方案将介绍以下内容:

1. 使用 SageMaker Jupyter Notebook进行dolly-v2-3b模型开发和微调

2. 在SageMaker部署微调后的大语言模型LLM并基于数据进行推理

3. 使用多场景的测试案例验证推理结果表现,并将部署的模型API节点集成进云端应用

项目搭建具体步骤:

下面跟着小李哥手把手微调一个亚马逊云科技AWS上的生成式AI模型(dolly-v2-3b)的软件应用,并将AI大模型部署与应用集成。

1. 在控制台进入Amazon SageMaker, 点击Notebook

2. 打开Jupyter Notebook

3. 创建一个新的Notebook:“lab-notebook.ipynb”并打开

4. 接下来我们在单元格内一步一步运行代码,检查CUDA的内存状态

!nvidia-smi

5.接下来,我们安装必要依赖并导入

%%capture

!pip3 install -r requirements.txt --quiet
!pip install sagemaker --quiet --upgrade --force-reinstall
%%capture

import os
import numpy as np
import pandas as pd
from typing import Any, Dict, List, Tuple, Union
from datasets import Dataset, load_dataset, disable_caching
disable_caching() ## disable huggingface cache

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import TextDataset

import torch
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer
import accelerate
import bitsandbytes

from IPython.display import Markdown

6. 导入提前准备好的FAQs数据集

sagemaker_faqs_dataset = load_dataset("csv", 
                                      data_files='data/amazon_sagemaker_faqs.csv')['train']
sagemaker_faqs_dataset
sagemaker_faqs_dataset[0]

7. 我们定义用于模型推理的提示词格式

from utils.helpers import INTRO_BLURB, INSTRUCTION_KEY, RESPONSE_KEY, END_KEY, RESPONSE_KEY_NL, DEFAULT_SEED, PROMPT
'''
PROMPT = """{intro}
            {instruction_key}
            {instruction}
            {response_key}
            {response}
            {end_key}"""
'''
Markdown(PROMPT)

8. 下面我们进入重头戏,导入一个提前预训练好的LLM大语言模型“databricks/dolly-v2-3b”。

tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", 
                                          padding_side="left")

tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": 
                              [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})

model = AutoModelForCausalLM.from_pretrained(
    "databricks/dolly-v2-3b",
    # use_cache=False,
    device_map="auto", #"balanced",
    load_in_8bit=True,
)

9. 对模型训练进行预准备, 处理数据集、优化模型训练(PEFT)效率

model.resize_token_embeddings(len(tokenizer))

from functools import partial
from utils.helpers import mlu_preprocess_batch

MAX_LENGTH = 256
_preprocessing_function = partial(mlu_preprocess_batch, max_length=MAX_LENGTH, tokenizer=tokenizer)

encoded_sagemaker_faqs_dataset = sagemaker_faqs_dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["instruction", "response", "text"],
)

processed_dataset = encoded_sagemaker_faqs_dataset.filter(lambda rec: len(rec["input_ids"]) < MAX_LENGTH)

split_dataset = processed_dataset.train_test_split(test_size=14, seed=0)
split_dataset

10. 同时我们使用LoRA(Low-Rank Adaptation)模型加速我们的模型微调

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

MICRO_BATCH_SIZE = 8  
BATCH_SIZE = 64
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LORA_R = 256 # 512
LORA_ALPHA = 512 # 1024
LORA_DROPOUT = 0.05

# Define LoRA Config
lora_config = LoraConfig(
                 r=LORA_R,
                 lora_alpha=LORA_ALPHA,
                 lora_dropout=LORA_DROPOUT,
                 bias="none",
                 task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

from utils.helpers import MLUDataCollatorForCompletionOnlyLM

data_collator = MLUDataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)

11. 接下来我们定义模型训练参数并开始训练。其中Batch=1,Step=20000,epoch为10.

EPOCHS = 10
LEARNING_RATE = 1e-4  
MODEL_SAVE_FOLDER_NAME = "dolly-3b-lora"

training_args = TrainingArguments(
                    output_dir=MODEL_SAVE_FOLDER_NAME,
                    fp16=True,
                    per_device_train_batch_size=1,
                    per_device_eval_batch_size=1,
                    learning_rate=LEARNING_RATE,
                    num_train_epochs=EPOCHS,
                    logging_strategy="steps",
                    logging_steps=100,
                    evaluation_strategy="steps",
                    eval_steps=100, 
                    save_strategy="steps",
                    save_steps=20000,
                    save_total_limit=10,
)

trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=split_dataset['train'],
        eval_dataset=split_dataset["test"],
        data_collator=data_collator,
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

12. 接下来我们将微调后的模型保存在本地

trainer.model.save_pretrained(MODEL_SAVE_FOLDER_NAME)

trainer.save_model()

trainer.model.config.save_pretrained(MODEL_SAVE_FOLDER_NAME)

tokenizer.save_pretrained(MODEL_SAVE_FOLDER_NAME)

13. 接下来,我们将保存到本地的模型进行部署,生成公开访问的API节点Endpoint

对部署所需要的参数进行定义和初始化

import boto3
import json
import sagemaker.djl_inference
from sagemaker.session import Session
from sagemaker import image_uris
from sagemaker import Model

sagemaker_session = Session()
print("sagemaker_session: ", sagemaker_session)

aws_role = sagemaker_session.get_caller_identity_arn()
print("aws_role: ", aws_role)

aws_region = boto3.Session().region_name
print("aws_region: ", aws_region)

image_uri = image_uris.retrieve(framework="djl-deepspeed",
                                version="0.22.1",
                                region=sagemaker_session._region_name)
print("image_uri: ", image_uri)

进行模型部署

model_data="s3://{}/lora_model.tar.gz".format(mybucket)

model = Model(image_uri=image_uri,
              model_data=model_data,
              predictor_cls=sagemaker.djl_inference.DJLPredictor,
              role=aws_role)

14.最后我们写入提示词,对大语言模型进行测试, 得到推理

outputs = predictor.predict({"inputs": "What solutions come pre-built with Amazon SageMaker JumpStart?"})

from IPython.display import Markdown
Markdown(outputs)

15. 我们下面进入SageMaker Endpoint页面,得到刚部署的模型API端点的URL,通过这种方式我们就可以在应用中调用我们的微调后的大语言模型了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/793742.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【漏洞复现】Splunk Enterprise for Windows 任意文件读取漏洞 CVE-2024-36991

声明&#xff1a;本文档或演示材料仅用于教育和教学目的。如果任何个人或组织利用本文档中的信息进行非法活动&#xff0c;将与本文档的作者或发布者无关。 一、漏洞描述 Splunk Enterprise 是一款强大的机器数据管理和分析平台&#xff0c;广泛应用于企业中&#xff0c;用于实…

应用最优化方法及MATLAB实现——第3章代码实现

一、概述 在阅读最优方法及MATLAB实现后&#xff0c;想着将书中提供的代码自己手敲一遍&#xff0c;来提高自己对书中内容理解程度&#xff0c;巩固一下。 这部分内容主要针对第3章的内容&#xff0c;将其所有代码实现均手敲一遍&#xff0c;中间部分代码自己根据其公式有些许的…

百度安全大模型智能体实践入选信通院“安全守卫者计划”优秀案例

7月3日&#xff0c;由全球数字经济大会组委会主办&#xff0c;中国信息通信研究院&#xff08;以下简称中国信通院&#xff09;与中国通信标准化协会联合承办的2024全球数字经济大会“云和软件安全论坛暨第二届SecGo云和软件安全大会”在北京召开。本届论坛聚焦云和软件安全最新…

从基础到进阶:无线局域网技术解析

在局域网刚刚问世后的一段时间内&#xff0c;无线局域网的发展比较缓慢&#xff0c;其原因是价格贵、数据传输速率低、安全性较差。但自20世纪80年代末以来&#xff0c;由于人们工作和生活节奏的加快&#xff0c;以及移动通信技术的飞速发展&#xff0c;无线局域网逐步进入市场…

今年2024,而那一年是1984

那一年&#xff0c;是1984 对于经历了改革开放洪流的国人来说&#xff0c;1984年似乎没有什么特别。 可是这一年&#xff0c;又确确实实非同寻常&#xff0c;许多后来的巨大变迁&#xff0c;在这一年埋下了伏笔…… 文学创作: 余华、莫言等作家在这一年迎来了自己的创作高峰…

学习通er图和项目思路

ER图 项目构思&#xff1a; 用户功能&#xff1a; 主要功能逻辑&#xff1a;

Web3知识图谱,一篇读完

这张图展示了区块链生态系统的架构和主要组件。以下是对图中内容的概括总结&#xff1a; 基础层&#xff1a; 底层基础设施&#xff1a;包括光纤网络、P2P网络、非对称加密、哈希算法、默克尔树和随机数生成。共识机制&#xff1a; PoW&#xff08;工作量证明&#xff09;: 比特…

Elasticsearch:介绍 retrievers - 搜索一切事物

作者&#xff1a;来自 Elastic Jeff Vestal, Jack Conradson 在 8.14 中&#xff0c;Elastic 在 Elasticsearch 中引入了一项名为 “retrievers - 检索器” 的新搜索功能。继续阅读以了解它们的简单性和效率&#xff0c;以及它们如何增强你的搜索操作。 检索器是 Elasticsearc…

MyBatis框架学习笔记(三):MyBatis重要文件详解:配置文件与映射文件

1 mybatis-config.xml-配置文件详解 1.1 说明 &#xff08;1&#xff09;mybatis 的核心配置文件(mybatis-config.xml)&#xff0c;比如配置 jdbc 连接信息&#xff0c;注册 mapper 等等都是在这个文件中进行配置,我们需要对这个配置文件有详细的了解 &#xff08;2&#x…

如何做好漏洞扫描工作提高网络安全

在数字化浪潮席卷全球的今天&#xff0c;企业数字化转型已成为提升竞争力、实现可持续发展的关键路径。然而&#xff0c;这一转型过程并非坦途&#xff0c;其中网络安全问题如同暗礁般潜伏&#xff0c;稍有不慎便可能引发数据泄露、服务中断乃至品牌信誉受损等严重后果。因此&a…

【Linux】磁盘性能压测-FIO工具

一、FIO工具介绍 fio&#xff08;Flexible I/O Tester&#xff09;是一个用于评估计算机系统中 I/O 性能的强大工具。 官网&#xff1a;fio - fio - Flexible IO Tester 注意事项&#xff01; 1、不要指定文件系统名称&#xff08;如/dev/mapper/centos-root)&#xff0c;避…

socket编程(2) -- TCP通信

TCP通信 2. 使用 Socket 进行TCP通信2.1 socket相关函数介绍socket()bind()listen()accept()connect()2.2 TCP协议 C/S 模型基础通信代码 最后 2. 使用 Socket 进行TCP通信 Socket通信流程图如下&#xff1a; 这里服务器段listen是监听socket套接字的监听文件描述符。如果客户…

Excel第30享:基于辅助列的条件求和

1、需求描述 如下图所示&#xff0c;现要统计2022年YTD&#xff08;Year To Date&#xff1a;年初至今日&#xff09;各个人员的“上班工时&#xff08;a2&#xff09;”。 下图为系统直接导出的工时数据明细样例。 2、解决思路 Step1&#xff1a;确定逻辑。“从日期中提取出…

[spring] Spring MVC - security(上)

[spring] Spring MVC - security&#xff08;上&#xff09; 这部分的内容基本上和 [spring] rest api security 是重合的&#xff0c;主要就是添加 验证&#xff08;authentication&#xff09;和授权&#xff08;authorization&#xff09;这两个功能 即&#xff1a; 用户…

构造函数的初始化列表,static成员,友元,内部类【类和对象(下)】

P. S.&#xff1a;以下代码均在VS2022环境下测试&#xff0c;不代表所有编译器均可通过。 P. S.&#xff1a;测试代码均未展示头文件stdio.h的声明&#xff0c;使用时请自行添加。 博主主页&#xff1a;LiUEEEEE                        …

2-31 基于matlab的微表情识别

基于matlab的微表情识别。通过gabor小波提取表情特征&#xff0c;pca进行降维&#xff0c;ELM分类器训练&#xff0c;然后选择待识别的微表情&#xff0c;提取特征后输入训练好的模型进行分类&#xff0c;识别结果由MATLAB的GUI输出。程序已调通&#xff0c;可直接运行。 2-31 …

Tomcat多实例

一、Tomcat多实例 Tomcat多实例是指在同一台服务器上运行多个独立的tomcat实例&#xff0c;每个tomcat实例都具有独立的配置文件、日志文件、应用程序和端口&#xff0c;通过配置不同的端口和文件目录&#xff0c;可以实现同时运行多个独立的Tomcat服务器&#xff0c;每个服务…

Fastjson2使用JSONOObject或者mao转换为JSON字符串时丢失Null值字段

最近在工作中发现问题fastJson转换为JSONString时丢失值为null的问题特此解决。 public class test001 {public static void main(String[] args) {JSONObject jsonObject new JSONObject();jsonObject.put("foo1", "bar");jsonObject.put("foo2&quo…

19. 地址转换

地址转换 题目描述 Excel 是最常用的办公软件。每个单元格都有唯一的地址表示。比如&#xff1a;第 12 行第 4 列表示为&#xff1a;"D12"&#xff0c;第 5 行第 255 列表示为"IU5"。 事实上&#xff0c;Excel 提供了两种地址表示方法&#xff0c;还有一…

代码随想录第50天|单调栈

739. 每日温度 参考 思路1: 暴力解法 思路2: 单调栈 使用场合: 寻找任一个元素的右边或者左边第一个比自己大或者小的元素位置, 存放的是遍历过的元素 记忆: 单调栈是对遍历过的元素做记录, 一般是对栈顶的元素 nums[mystack.top()] 做赋值操作的 如果想找到右边的元素大于左…