【chatglm3】(3):在AutoDL上,使用4090显卡,部署ChatGLM3API服务,并微调AdvertiseGen数据集,完成微调并测试成功!附视频演示。

在AutoDL上,使用4090显卡,部署ChatGLM3API服务,并微调AdvertiseGen数据集,完成微调并测试成功!

其他chatgpt 和chatglm3 资料:
https://blog.csdn.net/freewebsys/category_12270092.html

视频地址:
https://www.bilibili.com/video/BV1zQ4y1t7x7/?vd_source=4b290247452adda4e56d84b659b0c8a2

在AutoDL上,使用4090显卡,部署ChatGLM3API服务,并微调AdvertiseGen数据集,完成微调并测试成功!

1,显卡市场,租个显卡性价比最高!

https://www.autodl.com/
在这里插入图片描述
创建完成可以使用 juypter 进入:

在这里插入图片描述
也可以监控服务器运行状况:

在这里插入图片描述

2,下载源代码,下载模型,启动服务

下载模型速度超级快 :

apt update && apt install git-lfs -y
git clone https://www.modelscope.cn/ZhipuAI/chatglm3-6b.git chatglm3-6b-models
Cloning into 'chatglm3-6b-models'...
remote: Enumerating objects: 101, done.
remote: Counting objects: 100% (101/101), done.
remote: Compressing objects: 100% (58/58), done.
remote: Total 101 (delta 42), reused 89 (delta 38), pack-reused 0
Receiving objects: 100% (101/101), 40.42 KiB | 1.84 MiB/s, done.
Resolving deltas: 100% (42/42), done.
Filtering content: 100% (8/8), 11.63 GiB | 203.56 MiB/s, done.

再下载github 项目:
https://github.com/THUDM/ChatGLM3/tree/main
或者上传代码

然后安装依赖库:

# 安装完成才可以启动:
pip3 install uvicorn fastapi loguru sse_starlette transformers sentencepiece
cd /root/ChatGLM3-main/openai_api_demo
python3 openai_api.py

启动成功,端口 8000
可以运行命令进行测试:

curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
     "model": "chatglm3-6b",
     "messages": [{"role": "user", "content": "北京景点"}],
     "temperature": 0.7
   }' 

3,使用脚本进行token测试,速度50 tokens/s 速度挺快的

然后使用测试脚本进行 token 测试,修改的 fastcaht的测试脚本:

# coding=utf-8
"""

token测试工具:

python3 test_throughput.py
或者:
python3 test_throughput.py --api-address http://localhost:8000 --n-thread 20


"""
import argparse
import json

import requests
import threading
import time


def main():

    headers = {"User-Agent": "openai client", "Content-Type": "application/json"}
    ploads = {
        "model": args.model_name,
        "messages": [{"role": "user", "content": "生成一个50字的故事,内容随即生成。"}],
        "temperature": 1,
    }
    thread_api_addr = args.api_address

    def send_request(results, i):
        print(f"thread {i} goes to {thread_api_addr}")
        response = requests.post(
            thread_api_addr + "/v1/chat/completions",
            headers=headers,
            json=ploads,
            stream=False,
        )
        print(response.text)
        response_new_words = json.loads(response.text)["usage"]["completion_tokens"]
        print(f"=== Thread {i} ===, words: {response_new_words} ")
        results[i] = response_new_words

    # use N threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_thread
    for i in range(args.n_thread):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        # time.sleep(0.5)
        threads.append(t)

    for t in threads:
        t.join()

    print(f"Time (POST): {time.time() - tik} s")
    n_words = sum(results)
    time_seconds = time.time() - tik
    print(
        f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
        f"throughput: {n_words / time_seconds} words/s."
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--api-address", type=str, default="http://localhost:8000")
    parser.add_argument("--model-name", type=str, default="chatglm3-6b")
    parser.add_argument("--n-thread", type=int, default=10)
    args = parser.parse_args()

    main()

测下下服务:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:42:00.0 Off |                  Off |
| 30%   39C    P2              56W / 450W |  12429MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

速度特别快:

Time (POST): 22.13719415664673 s
Time (Completion): 22.137234687805176, n threads: 10, throughput: 51.22591037193507 words/s.

完全可以满足内部使用了。

3,下载微调数据,并进行模型训练

https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1

AdvertiseGen以商品网页的标签与文案的信息对应关系为基础构造

载处理好的 AdvertiseGen 数据集,将解压后的 AdvertiseGen 目录放到本目录下。

./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json"

来下载和将数据集处理成上述格式。

微调模型

# 安装依赖库
pip install transformers==4.30.2 accelerate sentencepiece astunparse deepspeed

./scripts/finetune_pt.sh  # P-Tuning v2 微调

为了验证演示,调整参数,快速训练:

#! /usr/bin/env bash

set -ex

PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=1
MAX_SOURCE_LEN=1024
MAX_TARGET_LEN=128
DEV_BATCH_SIZE=1
GRAD_ACCUMULARION_STEPS=8
MAX_STEP=10
SAVE_INTERVAL=10

DATESTR=`date +%Y%m%d-%H%M%S`
RUN_NAME=advertise_gen_pt

BASE_MODEL_PATH=/root/chatglm3-6b-models
DATASET_PATH=formatted_data/advertise_gen.jsonl
OUTPUT_DIR=output/${RUN_NAME}-${DATESTR}-${PRE_SEQ_LEN}-${LR}

mkdir -p $OUTPUT_DIR

torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
    --train_format input-output \
    --train_file $DATASET_PATH \
    --preprocessing_num_workers 1 \
    --model_name_or_path $BASE_MODEL_PATH \
    --output_dir $OUTPUT_DIR \
    --max_source_length $MAX_SOURCE_LEN \
    --max_target_length $MAX_TARGET_LEN \
    --per_device_train_batch_size $DEV_BATCH_SIZE \
    --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \
    --max_steps $MAX_STEP \
    --logging_steps 1 \
    --save_steps $SAVE_INTERVAL \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN 2>&1 | tee ${OUTPUT_DIR}/train.log

4,推理验证,使用命令行的方式

对于输入输出格式的微调,可使用 inference.py 进行基本的推理验证。

python inference.py \
    --model /root/chatglm3-6b-models \
    --pt-checkpoint "output/advertise_gen_pt-20231113-222811-128-2e-2" 
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.32it/s]
Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at /root/chatglm3-6b-models and are newly initialized: ['transformer.prefix_encoder.embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Prompt:类型#裙*材质#网纱*颜色#粉红色*裙袖长#短袖*裙领型#圆领
Response: *裙下摆流苏设计,轻轻松松演绎甜美可爱风。这条裙子真的太仙了,粉红色网纱,在阳光的照耀下,真的太仙了,仿佛置身于童话故事中。短袖的设计,既不会过于露肤,也能展示出修长的身材线条。裙摆处流苏的设计,让整个裙子的层次感更加明显,给人一种飘逸的感觉。

5,总结

在 4090 上面运行 chatgm3 速度还是挺快的。
然后找到官方的 AdvertiseGen 数据集,就是对商品的标签和文案的匹配数据。
然后根据内容进行训练,然后再输入相关类似的标签,就可以自动生成广告文案了。
这个是AIGC的挺好的落地场景。

可以在 4090 上完成训练,并验证成功了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/140902.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C++入门篇】保姆级教程篇【下】

目录 一、运算符重载 1&#xff09;比较、赋值运算符重载 2&#xff09; 流插入留提取运算符重载 二、剩下的默认成员函数 1&#xff09;赋值运算符重载 2&#xff09;const成员函数 3&#xff09;取地址及const取地址操作符重载 三、再谈构造函数 1&#xff09;初始化列表 …

SparkSQL之Analyzed LogicalPlan生成过程

经过AstBuilder的处理&#xff0c;得到了Unresolved LogicalPlan。该逻辑算子树中未被解析的有UnresolvedRelation和UnresolvedAttribute两种对象。Analyzer所起到的主要作用就是将这两种节点或表达式解析成有类型的&#xff08;Typed&#xff09;对象。在此过程中&#xff0c;…

链表相关部分OJ题

&#x1f493;作者简介&#x1f44f;&#xff1a;在校大二迷茫大学生 &#x1f496;个人主页&#x1f389;&#xff1a;小李很执着 &#x1f497;系列专栏&#xff1a;Leetcode经典题 每日分享&#xff1a;人总是在离开一个地方后开始原谅它❣️❣️❣️———————————…

“第六十七天”

各位&#xff0c;昨天查找子串的方法想起来了&#xff0c;就是那个KMP算法......自己理解都有点困难&#xff0c;还看看能不能想一下&#xff0c;确实很困难啊。 不要忘了toupper函数和tolower函数不是直接改变字符的大小写&#xff0c;而是返回对应的大小写的值&#xff0c;需…

pytest-bdd快速示例和问题解决

BDD 与 pytest-bdd BDD 即 Behavior-driven development&#xff0c;行为驱动开发。BDD行为驱动是一种敏捷开发模式, 重点在于消除开发/测试对需求了解的歧义及用户场景的验证。 pytest-bdd 是一个BDD测试框架&#xff0c;类似于behave, cucumber。它可以统一单元测试和功能测…

【Git】第四篇:基本操作(理解工作区、暂存区、版本库)

Git 工作区、暂存区和版本库 工作区&#xff1a;就是我们创建的本地仓库所在的目录暂存区&#xff1a; stage或index&#xff0c;一般放在.git(可隐藏文件)目录下的index文件&#xff08;.git/index&#xff09;中&#xff0c;所以我们把暂存区有时候也叫做索引&#xff08;in…

飞书开发学习笔记(五)-Python快速开发网页应用

飞书开发学习笔记(五)-Python快速开发网页应用 一.下载示例代码 首先进入飞书开放平台: https://open.feishu.cn/app 凭证与基础信息 页面&#xff0c;在 应用凭证 中获取 App ID 和 App Secret 值。 教程和示例代码位置:https://open.feishu.cn/document/home/integrating-…

C语言 每日一题 牛客网 11.13 Day17

找零 Z国的货币系统包含面值1元、4元、16元、64元共计4种硬币&#xff0c;以及面值1024元的纸币。 现在小Y使用1024元的纸币购买了一件价值为N(0 < N≤1024)的商品&#xff0c;请问最少他会收到多少硬币&#xff1f; 思路 运用if语句进行判断分类 代码实现 int main() {…

基于php+thinkphp的网上书店购物商城系统

运行环境 开发语言&#xff1a;PHP 数据库:MYSQL数据库 应用服务:apache服务器 使用框架:ThinkPHPvue 开发工具:VScode/Dreamweaver/PhpStorm等均可 项目简介 系统主要分为管理员和用户二部分&#xff0c;管理员主要功能包括&#xff1a;首页、个人中心、用户管理、图书分类…

jupyter lab常用插件集合

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️ &#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…

毕业设计项目:基于java+springboot的共享单车信息网站

运行环境 开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09; 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclipse/idea Ma…

【Java 进阶篇】JQuery DOM操作:通用属性操作的绝妙魔法

在前端的舞台上&#xff0c;JQuery犹如一位魔法师&#xff0c;为我们展现了操纵HTML元素的奇妙技巧。而在这个技巧的精妙组成中&#xff0c;通用属性操作是一门绝妙的魔法。在本篇博客中&#xff0c;我们将深入研究JQuery DOM操作中的通用属性操作&#xff0c;揭示这段魔法的神…

Linux进程间通信之命名管道及SystemV共享内存

命名管道及SystemV共享内存 命名管道1. 什么是命名管道2. 用命名管道实现server&client通信Log.hppcomm.hppserver.cppclient.cppclient.cppMakefile编译 system V共享内存1. 共享内存示意图2. 共享内存数据结构3. 共享内存函数3.1 shmget函数3.2 shmat函数3.3 shmdt函数3.…

一招验收测试自动化天下知

今天下午给同事就自动化验收测试做了一个简单的介绍&#xff0c;引起了大家的阵阵讨论。同时还有其他Team的人来分享各自的经验&#xff0c;他们也都做得相当不错。 测试包括很多种&#xff0c;单元测试、集成测试、功能测试、验收测试、数据库测试等等。撇开大家都熟悉的单元测…

基于JavaWeb+SSM+基于微信小程序的“生鲜食品”团购平台生鲜商城系统的设计和实现

基于JavaWebSSM基于微信小程序的“生鲜食品”团购平台生鲜商城系统的设计和实现 源码获取入口前言主要技术系统设计功能截图Lun文目录订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 源码获取入口 前言 生鲜食品是我们生活中的必需品&#xff0c;为了确保食品的新鲜…

leetCode 25.K 个一组翻转链表

给你链表的头节点 head &#xff0c;每 k 个节点一组进行翻转&#xff0c;请你返回修改后的链表。k 是一个正整数&#xff0c;它的值小于 或 等于链表的长度。如果节点总数不是 k 的整数倍&#xff0c;那么请将最后剩余的节点保持原有顺序。你不能只是单纯的改变节点内部的值&a…

设计模式—结构型模式之代理模式

设计模式—结构型模式之代理模式 代理模式(Proxy Pattern) ,给某一个对象提供一个代理&#xff0c;并由代理对象控制对原对象的引用,对象结构型模式。 静态代理 比如我们有一个直播平台&#xff0c;提供了直播功能&#xff0c;但是如果不进行美颜&#xff0c;可能就比较冷清…

基于springboot+vue的学生毕业离校信息网站

项目介绍 该学生毕业离校系统包括管理员、学生和教师。其主要功能包括管理员&#xff1a;首页、个人中心、学生管理、教师管理、离校信息管理、费用结算管理、论文审核管理、管理员管理、留言板管理、系统管理等&#xff0c;前台首页&#xff1b;首页、离校信息、网站公告、留…

博弈论入门

目录 什么是博弈&#xff1f; 博弈论的发展历史&#xff1f; 博弈的要素有哪些&#xff1f; 博弈的分类&#xff1f; 博弈论的应用 收益矩阵 纳什均衡的定义 博弈论的例子 1、田忌赛马 2、穷途困境 2.1优化反应函数法 2.2Nashpy库 2.3顶点枚举算法 3、Nash游戏 …

WebUI 自动化测试框架搭建详解

&#x1f4e2;专注于分享软件测试干货内容&#xff0c;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01;&#x1f4e2;交流讨论&#xff1a;欢迎加入我们一起学习&#xff01;&#x1f4e2;资源分享&#xff1a;耗时200小时精选的「软件测试」资…