【Python】科研代码学习:三 PreTrainedModel, PretrainedConfig

【Python】科研代码学习:三 PreTrainedModel, PretrainedConfig

  • 前言
  • Models : PreTrainedModel
    • PreTrainedModel 中重要的方法
  • tensorflow & pytorch 简单对比
  • Configuration : PretrainedConfig
    • PretrainedConfig 中重要的方法

前言

  • HF 官网API
    本文主要从官网API与源代码中学习调用HF的关键模组

Models : PreTrainedModel

  • HF 提供的基础模型类有 PreTrainedModel, TFPreTrainedModel, and FlaxPreTrainedModel
  • 这三者有什么区别呢
    PreTrainedModel 指的是用 torch 的框架
    在这里插入图片描述
    TFPreTrainedModel 指的是用 tensorflow 框架
    在这里插入图片描述
    FlaxPreTrainedModel 指的是用 flax 框架,是用 jax 做的
    在这里插入图片描述
    (哈哈,搜了好久都没搜到,去看源码导包瞬间明白了,也可能是我比较笨)
  • Transformers的大部分模型都会继承PretrainedModel基类。PretrainedModel主要负责管理模型的配置,模型的参数加载、下载和保存。
  • PretrainedModel继承自 nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin
    在初始化时需要提供给它一个 config: PretrainedConfig
  • 所以,我们可以视为它是所有模型的基类
    可以看到很多其他代码在判断模型类型时,一般写 model: Union[PreTrainedModel, nn.Module]

PreTrainedModel 中重要的方法

  • push_to_hub:将模型传到HF hub
from transformers import AutoModel

model = AutoModel.from_pretrained("google-bert/bert-base-cased")

# Push the model to your namespace with the name "my-finetuned-bert".
model.push_to_hub("my-finetuned-bert")

# Push the model to an organization with the name "my-finetuned-bert".
model.push_to_hub("huggingface/my-finetuned-bert")
  • from_pretrained:根据config实例化预训练pytorch模型(Instantiate a pretrained pytorch model from a pre-trained model configuration.)
    默认使用评估模式 .eval()
    可以打开训练模式 .train()

    看下面的例子,可以从官方加载,也可以从本地模型参数加载。如果本地参数是tf的,转pytorch需要设置 from_tf=True,并且会慢些;本地参数是flax的话类似同理。
from transformers import BertConfig, BertModel

# Download model and configuration from huggingface.co and cache.
model = BertModel.from_pretrained("google-bert/bert-base-uncased")
# Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
model = BertModel.from_pretrained("./test/saved_model/")
# Update configuration during loading.
model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
assert model.config.output_attentions == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
# Loading from a Flax checkpoint file instead of a PyTorch model (slower)
model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)

可以给 torch_dtype 设置数据类型。若不给,则默认为 torch.float16。也可以给 torch_dtype="auto"

  • get_input_embeddings:获得输入的词嵌入在这里插入图片描述
    对应还有 get_output_embeddings
  • init_weights:设置参数初始化
    如果需要自己调整参数初始化的,在 _init_weights_initialize_weights 中设置
  • save_pretrained:把模型和配置参数保存在文件夹中
    保存完后,便可以通过 from_pretrained 再次加载模型了
    在这里插入图片描述

tensorflow & pytorch 简单对比

  • 知乎:Tensorflow 到底比 Pytorch 好在哪里?
    下面截取了比较重要的图
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
  • 里面还提到了一个内容叫做 Keras

Keras是一个由Python编写的开源人工神经网络库,可以作为Tensorflow、Microsoft-CNTK和Theano的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化

Configuration : PretrainedConfig

  • 刚才看了,对于 PretrainedModel 初始化提供的参数是 PretrainedConfig 类型的参数。
    它主要为不同的任务,提供了不同的重要参数
    HF官网:PretrainedConfig
  • 列一下对于NLP中比较重要的参数吧,所有的就看官方文档吧
返回信息
output_hidden_states (bool, optional, defaults to False) — Whether or not the model should return all hidden-states.
output_attentions (bool, optional, defaults to False) — Whether or not the model should returns all attentions.
return_dict (bool, optional, defaults to True) — Whether or not the model should return a ModelOutput instead of a plain tuple.
output_scores (bool, optional, defaults to False) — Whether the model should return the logits when used for generation.
return_dict_in_generate (bool, optional, defaults to False) — Whether the model should return a ModelOutput instead of a torch.LongTensor.

序列生成
max_length (int, optional, defaults to 20) — Maximum length that will be used by default in the generate method of the model.
min_length (int, optional, defaults to 0) — Minimum length that will be used by default in the generate method of the model.
do_sample (bool, optional, defaults to False) — Flag that will be used by default in the generate method of the model. Whether or not to use sampling ; use greedy decoding otherwise.
num_beams (int, optional, defaults to 1) — Number of beams for beam search that will be used by default in the generate method of the model. 1 means no beam search.
diversity_penalty (float, optional, defaults to 0.0) — Value to control diversity for group beam search. that will be used by default in the generate method of the model. 0 means no diversity penalty. The higher the penalty, the more diverse are the outputs.
temperature (float, optional, defaults to 1.0) — The value used to module the next token probabilities that will be used by default in the generate method of the model. Must be strictly positive.
top_k (int, optional, defaults to 50) — Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model.
top_p (float, optional, defaults to 1) — Value that will be used by default in the generate method of the model for top_p. If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
epetition_penalty (float, optional, defaults to 1) — Parameter for repetition penalty that will be used by default in the generate method of the model. 1.0 means no penalty.
length_penalty (float, optional, defaults to 1) — Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
bad_words_ids (List[int], optional) — List of token ids that are not allowed to be generated that will be used by default in the generate method of the model. In order to get the tokens of the words that should not appear in the generated text, use tokenizer.encode(bad_word, add_prefix_space=True).

tokenizer相关
bos_token_id (int, optional) — The id of the beginning-of-stream token.
pad_token_id (int, optional) — The id of the padding token.
eos_token_id (int, optional) — The id of the end-of-stream token.

PyTorch相关
torch_dtype (str, optional) — The dtype of the weights. This attribute can be used to initialize the model to a non-default dtype (which is normally float32) and thus allow for optimal storage allocation. For example, if the saved model is float16, ideally we want to load it back using the minimal amount of memory needed to load float16 weights. Since the config object is stored in plain text, this attribute contains just the floating type string without the torch. prefix. For example, for torch.float16 `torch_dtype is the "float16" string.

常见参数
vocab_size (int) — The number of tokens in the vocabulary, which is also the first dimension of the embeddings matrix (this attribute may be missing for models that don’t have a text modality like ViT).
hidden_size (int) — The hidden size of the model.
num_attention_heads (int) — The number of attention heads used in the multi-head attention layers of the model.
num_hidden_layers (int) — The number of blocks in the model.

PretrainedConfig 中重要的方法

  • push_to_hub:依然是上传到 HF hub
  • from_dict:把一个 dict 类型转到 PretrainedConfig 类型
  • from_json_file:把一个 json 文件转到 PretrainedConfig 类型,传入的是文件路径
  • to_dict:转成 dict 类型
  • to_json_file:保存到 json 文件
  • to_json_string:转成 json 字符串
  • from_pretrained:从预训练模型配置文件中直接获取配置
    可以是HF模型,也可以是本地模型,见下方例子
# We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
# derived class: BertConfig
config = BertConfig.from_pretrained(
    "google-bert/bert-base-uncased"
)  # Download configuration from huggingface.co and cache.
config = BertConfig.from_pretrained(
    "./test/saved_model/"
)  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
assert config.output_attentions == True
config, unused_kwargs = BertConfig.from_pretrained(
    "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
)
assert config.output_attentions == True
assert unused_kwargs == {"foo": False}
  • save_pretrained:把配置文件保存到文件夹中,方便下次 from_pretrained 直接读取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/440644.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

乐高EV3硬件编程

文章目录&#xff1a; 一&#xff1a;软件 1.软件下载安装 2.软件的使用 二&#xff1a;乐高EV3电子元器件介绍 1.针对不同的版本 2.组合起来看 3.元器件栏 绿色部分&#xff1a;动作 橙色部分&#xff1a;流程控制 黄色部分&#xff1a;传感器 红色部分&#xff1…

【PyTorch】进阶学习:探索BCEWithLogitsLoss的正确使用---二元分类问题中的logits与标签形状问题

【PyTorch】进阶学习&#xff1a;探索BCEWithLogitsLoss的正确使用—二元分类问题中的logits与标签形状问题 &#x1f308; 个人主页&#xff1a;高斯小哥 &#x1f525; 高质量专栏&#xff1a;Matplotlib之旅&#xff1a;零基础精通数据可视化、Python基础【高质量合集】、Py…

Python爬虫——scrapy-4

免责声明 本文章仅用于学习交流&#xff0c;无任何商业用途 部分图片来自尚硅谷 meta简介 在Scrapy框架中&#xff0c;可以使用meta属性来传递额外的信息。meta属性可以在不同的组件之间传递数据&#xff0c;包括爬虫、中间件和管道等。 在爬虫中&#xff0c;可以使用meta属…

储能系统---交流充电桩(三)

一、充电模式及其功能要求 关注公众号 --- 小Q下午茶 新国标在标准 GB/T 18487.1-2015《电动汽车传导充电系统 第1部分&#xff1a;通用要求》中规定了 4 种充电模式&#xff0c;下面将对这 4 种充电模式及其功能要求进行介绍。 1.1 、模式 1 模式 1 是指在充电系统中应使用…

一次电脑感染Synaptics Pointing Device Driver病毒的经历,分享下经验

没想到作为使用电脑多年的老司机也会电脑中病毒&#xff0c;周末玩电脑的时候突然电脑很卡&#xff0c;然后自动重启&#xff0c;奇怪&#xff0c;之前没出现这个情况。 重启后电脑开机等了几十秒&#xff0c;打开任务管理器查看开机进程&#xff0c;果然发现有个Synaptics Po…

LeetCode 2482.行和列中一和零的差值

给你一个下标从 0 开始的 m x n 二进制矩阵 grid 。 我们按照如下过程&#xff0c;定义一个下标从 0 开始的 m x n 差值矩阵 diff &#xff1a; 令第 i 行一的数目为 onesRowi 。 令第 j 列一的数目为 onesColj 。 令第 i 行零的数目为 zerosRowi 。 令第 j 列零的数目为 zer…

用于审核、优化和跟踪的 18 种顶级 SEO 工具

DIY SEO工具 需要自己动手 &#xff08;DIY&#xff09; SEO 工具吗&#xff1f;以下是帮助您自己实现 SEO 目标的最佳工具&#xff1a; SEO Checker&#xff1a; 最适合评估和提高 SEO 性能。Google Analytics 4&#xff1a;最适合跟踪 SEO 结果。Moz Pro&#xff1a;最适合…

清华大学1748页CTF竞赛入门指南,完整版开放下载!

CTF是一种针对信息安全领域的经济性挑战&#xff0c;旨在通过解决一系列的难题来寻找隐藏的“flag”。CTF比赛战队一般是以高校、科研单位、企业、信息安全从业者或社会团体组成。对于网安爱好者及从业者来说&#xff0c;拥有“CTF参赛经验”也是求职中的加分项。 前几天分享的…

C++ Qt开发:QFileSystemWatcher文件监视组件

Qt 是一个跨平台C图形界面开发库&#xff0c;利用Qt可以快速开发跨平台窗体应用程序&#xff0c;在Qt中我们可以通过拖拽的方式将不同组件放到指定的位置&#xff0c;实现图形化开发极大的方便了开发效率&#xff0c;本章将重点介绍如何运用QFileSystemWatcher组件实现对文件或…

分享axios+MQTT简单封装示例

MQTT&#xff08;Message Queuing Telemetry Transport&#xff0c;消息队列遥测传输协议&#xff09;&#xff0c;是一种基于发布/订阅&#xff08;publish/subscribe&#xff09;模式的"轻量级"通讯协议&#xff0c;该协议构建于TCP/IP协议上&#xff0c;由IBM在19…

git fatal: detected dubious ownership in repository at ‘xxx‘ 彻底解决方法

前言 在 windows 重置后&#xff0c; git 仓库无法正常使用 git 的所有 命令&#xff0c;运行任何 git 命令&#xff0c;都会提示如下&#xff1a; $ git log fatal: detected dubious ownership in repository at D:/rk/rk3568/nanopi/uboot-rockchip D:/rk/rk3568/nanopi/u…

深入理解并发编程:解锁现代软件性能的关键

在当今快速发展的软件开发世界中&#xff0c;并发编程已经成为一种无法回避的重要议题。它涉及到如何在同一时间内处理多个任务&#xff0c;以此来提升应用程序的性能和响应速度。互联网服务的高并发需求以及多核处理器的普及使得并发编程成为了现代软件工程的一个核心组成部分…

瑞芯微RV系列-超级编码

参加开发者大会,逛了相应的workshop,对超级编码技术很感兴趣,RK做了很多事情,挺好的!!!

Type-C接口小家电使用PD诱骗芯片获取充电器的5V9V12V20V供电

随着Type-C接口的逐渐普及&#xff0c;小家电设备慢慢开始采用Type-C&#xff0c;淘汰了以往的DC接口&#xff0c;Type-C接口在小家电设备中的应用也越来越广泛。Type-C接口支持大电流宽电压范围&#xff0c;如何确保设备能够正确识别并使用各种电压&#xff08;例如5V、9V、12…

it-tools工具箱

it-tools 是一个在线工具集合&#xff0c;包含各种实用的开发工具、网络工具、图片视频工具、数学工具等 github地址&#xff1a;https://github.com/CorentinTh/it-tools 部署 docker run -d --name it-tools --restart unless-stopped -p 8080:80 corentinth/it-tools:lat…

【前端Vue】社交信息头条项目完整笔记第1篇:一、项目初始化【附代码文档】

社交媒体-信息头条项目完整开发笔记完整教程&#xff08;附代码资料&#xff09;主要内容讲述&#xff1a;一、项目初始化使用 Vue CLI 创建项目,加入 Git 版本管理,调整初始目录结构,导入图标素材。二、登录注册准备,实现基本登录功能,登录状态提示,表单验证。三、个人中心&am…

浏览器一键重新发起请求

一、需求场景 在前端开发过程中&#xff0c;经常会需要重新请求后台进行代码调试&#xff0c;之前的常规方法是刷新浏览器页面或者点击页面进行交互&#xff0c;这样对多个请求的场景就很方便&#xff0c;但是往往很多时候我们只是单纯的想重新发起一个请求&#xff08;多个请求…

找出单身狗1,2

目录 1. 单身狗12. 单身狗2 1. 单身狗1 题目如下&#xff1a; 思路&#xff1a;一部分人可能会使用对数组排序&#xff0c;遍历数组的方式去找出只出现一次的数字&#xff0c;但这种方法的时间复杂度过高&#xff0c;有时候可能会不满足要求。 有一种十分简便的方法是使用异或…

VITS 模型详解与公式推导:基于条件变分自编码器和对抗学习的端到端语音合成模型

参考文献&#xff1a; [1] Kim J, Kong J, Son J. Conditional variational autoencoder with adversarial learning for end-to-end text-to-speech[C]//International Conference on Machine Learning. PMLR, 2021: 5530-5540. [2] Su J, Wu G. f-VAEs: Improve VAEs with co…

Day25:安全开发-PHP应用文件管理模块包含上传遍历写入删除下载安全

目录 PHP文件操作安全 文件包含 文件删除 文件编辑 文件下载 云产品OSS存储对象去存储文件(泄漏安全) 思维导图 PHP知识点 功能&#xff1a;新闻列表&#xff0c;会员中心&#xff0c;资源下载&#xff0c;留言版&#xff0c;后台模块&#xff0c;模版引用&#xff0c;框…