使用 Hugging Face 的 Transformers 库加载预训练模型遇到的问题

题意:

Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Huggingface PyTorch

这个错误信息 "Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Huggingface PyTorch" 通常出现在使用 Hugging Face 的 Transformers 库加载预训练模型时,模型的某些参数与预训练模型检查点(checkpoint)中的参数形状不匹配。

问题背景:

I want to finetune an LLM. I am able to successfully finetune LLM. But when reload the model after save, gets error. Below is the code

import argparse
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from trl import DPOTrainer, DPOConfig
def preprocess_data(item):
    return {
        'prompt': 'Instruct: ' + item['prompt'] + '\n',
        'chosen': 'Output: ' + item['chosen'],
        'rejected': 'Output: ' + item['rejected']
    }        

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--seed", type=int, default=2003)
    parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-14m")
    parser.add_argument("--dataset_name", type=str, default="jondurbin/truthy-dpo-v0.1")
    parser.add_argument("--local_rank", type=int, default=0)

    args = parser.parse_args()

    # Determine device based on local_rank
    device = torch.device("cuda", args.local_rank) if torch.cuda.is_available() else torch.device("cpu")


    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
    ref_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)

    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.map(preprocess_data)

    # Split the dataset into training and validation sets
    dataset = dataset.train_test_split(test_size=0.1, seed=args.seed)
    train_dataset = dataset['train']
    val_dataset = dataset['test']

    training_args = DPOConfig(
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        logging_steps=10,
        remove_unused_columns=False,
        max_length=1024,
        max_prompt_length=512,
        fp16=True        
    )

    

    # Verify and print embedding dimensions before finetuning
    print("Base model embedding dimension:", model.config.hidden_size)

    model.train()
    ref_model.eval()

    dpo_trainer = DPOTrainer(
        model,
        ref_model,
        beta=args.beta,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        args=training_args,
    )

    dpo_trainer.train()
    # Evaluate
    evaluation_results = dpo_trainer.evaluate()
    print("Evaluation Results:", evaluation_results)

    save_model_name = 'finetuned_model'
    model.save_pretrained(save_model_name)

if __name__ == "__main__":
    main()

Error I was getting as below

return model_class.from_pretrained(
    File "/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3838, in from_pretrained
        ) = cls._load_pretrained_model(
    File "/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4349, in _load_pretrained_model
        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
        RuntimeError: Error(s) in loading state_dict for GPTNeoXForCausalLM:
            size mismatch for gpt_neox.embed_in.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 128]).
            size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 128]).
            You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

After finetuning, model works perfectly. But after reloading the saved trained model its not working. Any idea why gets this error when reloading the model ?

问题解决:

Instead of

model.save_pretrained(save_model_name)

try this

dpo_trainer.save_model(save_model_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/788427.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[高频 SQL 50 题(基础版)]第一千七百五十七题,可回收且低脂产品

题目: 表:Products ---------------------- | Column Name | Type | ---------------------- | product_id | int | | low_fats | enum | | recyclable | enum | ---------------------- product_id 是该表的主键(具有唯…

解决树形表格 第一列中文字没有对齐

二级分类与一级分类的文字没有对齐 <el-table:data"templateStore.hangyeList"style"width: 100%"row-key"id":tree-props"{ children: subData, hasChildren: hasChildren }" ><el-table-column prop"industryCode&quo…

Kettle常用参数配置

目录 一、时区二、时间戳三、tinyint类型转换 一、时区 Kettle链接mysql出现报错&#xff1a;Connection failed. Verify all connection parameters and confirm that the appropriate driver is installed. The server time zone value is unrecognized or represents more…

无人机之穿越机注意事项篇

一、检查设备 每次飞行前都要仔细检查穿越机的每个部件&#xff0c;确保所有功能正常&#xff0c;特别是电池和电机。 二、遵守法律 了解并遵循你所在地区关于无人机的飞行规定&#xff0c;避免非法飞行。 三、评估环境 在飞行前检查周围环境&#xff0c;确保没有障碍物和…

[激光原理与应用-102]:南京科耐激光-激光焊接-焊中检测-智能制程监测系统IPM介绍 - 6 - 激光焊接系统的组成

目录 一、激光焊接系统的组成概述 1.1、核心部件 1.2、焊接执行部件 1.3、辅助系统 1.4、控制系统 1.5、其他辅助设备 二、激光器 2.1 按出光类型分 1. 脉冲激光器 2. 连续激光器 3. 准连续激光器&#xff08;QCW&#xff09; 4. 其他常见激光器 5. 应用领域 2.2…

HTTP入门

目录 1. 原理介绍 2. HTTP协议简介 2.1简介 ​编辑 2.2基本工作原理 2.3HTTP三个要点 3. chrome浏览器和开发者工具 4. HTTP的消息结构 4.1主要流程和概念 4.2请求和响应例子 5. 完整的网页请求过程 6. 请求 6.1 请求行 6.2请求方法 6.3请求参数 GET请求的参数…

【问题记录】Windows中Node的express无法直接识别

问题描述 在使用express_generator的时候windows平台中出现无法识别express命令的问题&#xff0c;另外就算添加了全局环境变量也没用。 问题解决 查看官方文档发现在node版本8之前的时候使用的是express&#xff0c;但是之后的版本使用npx&#xff0c;这个工具的出现主要想…

头歌资源库(23)资源分配

一、 问题描述 某工业生产部门根据国家计划的安排&#xff0c;拟将某种高效率的5台机器&#xff0c;分配给所属的3个工厂A,B,C&#xff0c;各工厂在获得这种机器后&#xff0c;可以为国家盈利的情况如表1所示。问&#xff1a;这5台机器如何分配给各工厂&#xff0c;才能使国家盈…

Flutter 开启混淆打包apk,并反编译apk确认源码是否被混淆

第一步&#xff1a;开启混淆并打包apk flutter build apk --obfuscate --split-debug-info./out/android/app.android-arm64.symbols 第二步&#xff1a;从dex2jar download | SourceForge.net 官网下载dex2jar 下载完终端进入该文件夹&#xff0c;然后运行以下命令就会在该…

C语言程序题(一)

一.三个整数从大到小输出 首先做这个题目需要知道理清排序的思路&#xff0c;通过比较三个整数的值&#xff0c;使之从大到小输出。解这道题有很多方法我就总结了两种方法&#xff1a;一是通过中间变量比较和交换&#xff0c;二是可以用冒泡排序法&#xff08;虽然三个数字排序…

Go语言入门之基础语法

Go语言入门之基础语法 1.简单语法概述 行分隔符&#xff1a; 一行代表一个语句结束&#xff0c;无需写分号。将多个语句写在一行可以用分号分隔&#xff0c;但是不推荐 注释&#xff1a; // 或者/* */ 标识符&#xff1a; 用来命名变量、类型等程序实体。 支持大小写字母、数字…

【区块链+跨境服务】湾区金融科技人才链 | FISCO BCOS应用案例

湾区金融科技人才链于 2020 年 8 月正式发布&#xff0c;是全国首创的金融科技人才创新举措&#xff0c;对推动金融科技人才机制和认证标准建立&#xff0c;促进金融科技人才要素自由流通&#xff0c;推进产业 链、技术链、人才链深度融合具有重大意义。以深港澳金 融科技师专才…

部署前端项目

常见部署方式有&#xff1a;静态托管服务、服务器部署 1. 静态托管服务 使用平台部署代码&#xff0c;比如 GitHub。 | 创建一个仓库&#xff0c;仓库名一般是 yourGithubName.github.io。 | 将打包后的静态文件文件上传到仓库。 | 在“Settings”&#xff08;选项&#xff0…

Effective C++笔记之二十一:One Definition Rule(ODR)

ODR细节有点复杂&#xff0c;跨越各种情况。基本内容如下&#xff1a; ●普通&#xff08;非模板&#xff09;的noninline函数和成员函数、noninline全局变量、静态数据成员在整个程序中都应当只定义一次。 ●class类型&#xff08;包括structs和unions&#xff09;、模板&…

leetcode--验证二叉搜索树

leetcode地址&#xff1a;验证二叉搜索树 给你一个二叉树的根节点 root &#xff0c;判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下&#xff1a; 节点的左 子树 只包含 小于 当前节点的数。 节点的右子树只包含 大于 当前节点的数。 所有左子树和右子树自身必…

基于与STM32的加湿器之雾化片驱动

基于与STM32的加湿器之雾化片驱动 加湿器是一种由电力驱动&#xff0c;用于增加环境湿度的家用电器。加湿器通过特定的方式&#xff08;如蒸发、超声波振动或加热&#xff09;将水转化为水蒸气&#xff0c;并将这些水蒸气释放到空气中&#xff0c;从而增加空气中的湿度。主要功…

java通过poi-tl导出word实战详细步骤

文章目录 与其他模版引擎对比1.引入maven依赖包2.新建Word文档exportWprd.docx模版3.编写导出word接口代码4.导出成果 poi-tl是一个基于Apache POI的Word模板引擎&#xff0c;也是一个免费开源的Java类库&#xff0c;你可以非常方便的加入到你的项目中&#xff0c;并且拥有着让…

自动化数据集成的BI工具,为你提供决策洞察力

传统的商业智能&#xff08;BI&#xff09;报表系统采用的是“业务提报表需求&#xff0c;IT进行开发”的模式。决策管理者和业务人员提出用报表等来展示经营管理数据的需求&#xff1b;接着IT响应需求&#xff0c;进行需求沟通、数据处理加工、报表开发等主体工作&#xff1b;…

企业化运维(7)_Zabbix企业级监控平台

官网&#xff1a;Zabbix :: The Enterprise-Class Open Source Network Monitoring Solution ###1.Zabbix部署### &#xff08;1&#xff09;zabbix安装 安装源 修改安装路径为清华镜像 [rootserver1 zabbix]# cd /etc/yum.repos.d/ [rootserver1 yum.repos.d]# vim zabbix.r…

PageDTO<T>,PageQuery,BeanUtils,CollUtils的封装

一、PageDTO<T> import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.fasterxml.jackson.annotation.JsonIgnore; import com.tianji.common.utils.BeanUtils; import com.tianji.common.utils.CollUtils; import com.tianji.common.utils.…