HuggingFace学习笔记--Model的使用

1--Model介绍

        Transformer的 model 一般可以分为:编码器类型(自编码)、解码器类型(自回归)和编码器解码器类型(序列到序列);

        Model Head(任务头)是在base模型的基础上,根据不同任务而设置的模块;base模型只起到一个编码和建模特征的功能;

简单代码:

from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

if __name__ == "__main__":
    # 数据处理
    sen = "弱小的我也有大梦想!"
    tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")
    inputs = tokenizer(sen, return_tensors="pt")
        
    # 不带model head的模型调用
    model = AutoModel.from_pretrained("hfl/rbt3", output_attentions=True)
    output1 = model(**inputs)
    print(output1.last_hidden_state.size()) # [1, 12, 768]
    
    # 带model head的模型调用
    clz_model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3", num_labels=10)
    output2 = clz_model(**inputs)
    print(output2.logits.shape) # [1, 10]

2--AutoModel的使用

官方文档

        AutoModel 用于加载模型;

2-1--简单Demo

测试代码:

from transformers import AutoTokenizer, AutoModel

if __name__ == "__main__":
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    tokenlizer = AutoTokenizer.from_pretrained(checkpoint) 
    
    raw_input = ["I love kobe bryant.", "Me too."]
    inputs = tokenlizer(raw_input, padding = "longest", truncation = True, max_length = 512, return_tensors = "pt")
    
    # 加载指定的模型
    model = AutoModel.from_pretrained(checkpoint)
    print("model: \n", model)
    
    outputs = model(**inputs)
    print("last_hidden_state: \n", outputs.last_hidden_state.shape) # 打印最后一个隐层的输出维度
    # [2 7 768] batch_size为2,7个token,每个token的维度为768

输出结果:

last_hidden_state: 
 torch.Size([2, 7, 768])

# 最后一个隐层的输出
# batchsize为2,表示两个句子
# 7表示token数,每一个句子有7个token
# 768表示特征大小,每一个token的维度为768

测试代码:

from transformers import AutoTokenizer, AutoModelForSequenceClassification

if __name__ == "__main__":
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    tokenlizer = AutoTokenizer.from_pretrained(checkpoint) 
    
    raw_input = ["I love kobe bryant.", "Me too."]
    inputs = tokenlizer(raw_input, padding = "longest", truncation = True, max_length = 512, return_tensors = "pt")

    model2 = AutoModelForSequenceClassification.from_pretrained(checkpoint) # 二分类任务
    print(model2)
    outputs2 = model2(**inputs)
    print(outputs2.logits.shape)

运行结果:

torch.Size([2, 2])
# 两个句子,每个句子二分类的概率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/208302.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

java正则表达式字母开头后面跟12位数字

字母开头后面跟12位数字 ^[A-Za-z]\d{12}$ 验证: 验证工具地址: Java正则表达式测试

学习感悟一己之言

学习感悟一己之言 学习上克服困难实际上是克服心理上或认识上的障碍的过程。所谓的理解,就是化陌生为熟悉。看不懂,一方面是因为接触的材料太陌生,即远离你当前的背景知识;另一方面是材料或讲述者的描述刻画不准确或晦涩不当。有了…

修改sublime配置让其显示文件编码格式

1、下载sublime并安装 2、点击菜单栏Preferences,然后在Preferences里面点击Setting 3、然后在跳出来的窗口添加: "show_enconding":true, 4、随便打开一个文件就可以在底部查看文件编码格式:

openbabel 安装 生成指纹方法

今日踩坑小结: openbabel 安装: 可以装,但是得在 Linux 环境下,win 环境装会报错(安装不会报错,但是生成指纹的时候会) 指纹: 在下面这个链接里,官方给出了命令行调用 o…

一篇博客带你认识泛型

目录 泛型类(Generic Class): 泛型方法(Generic Method): Java 中的泛型是一种编程机制,允许你编写可以与多种数据类型一起工作的代码,同时提供编译时类型检查以确保类型的安全性。泛型的主要目的是提高代…

外贸获客的几种正确打开方式,还不快来GET!

做外贸还在愁没客户?作为外贸人,开发客户是我们的重要工作内容,想要高效地开发客户,首先就要知道外贸获客的方法有哪些,当下最主流的外贸获客渠道分为线下和线上两种方式,今天东哥就介绍几种获客渠道&#…

YOLOv5项目实战(5)— 算法模型优化和服务器部署

前言:Hello大家好,我是小哥谈。近期,作者所负责项目中的算法模型检测存在很多误报情况,为了减少这种误报情况,作者一直在不断优化算法模型。鉴于此,本节课就给大家详细介绍一下实际工作场景中如何去优化算法模型和进行部署,另外为了方便大家进行模型训练,作者在文章中提…

流量内存cpu使用率使用工具

类似360工具球的工具 我提供了夸克下载喜欢的朋友可以直接下载使用 我用夸克网盘分享了「TrafficMonitor」,点击链接即可保存。打开「夸克APP」,无需下载在线播放视频,畅享原画5倍速,支持电视投屏。 链接:https://pan…

wpf devexpress 使用IDataErrorInfo实现input验证

此处下载源码 当form初始化显示,Register按钮应该启动和没有输入错误应该显示。如果用户点击注册按钮在特定的输入无效数据,form将显示输入错误和禁用的注册按钮。实现逻辑在标准的IDataErrorInfo接口。请查阅IDataErrorInfo接口(System.Com…

微机原理——定时器学习1

目录 定时类型 8253内部结构框图 8253命令字 六种工作方式及输出波形 计数初值的计算与装入 8253的初始化 定时类型 可编程定时器8253:(内部采用的是16位 减法计数器) 8253内部结构框图 8253命令字 8253有三个命令字:方式命…

【广州华锐视点】VR云端看车:一键穿越!VR技术让你在家就能试驾各种豪车!

随着科技的不断发展,虚拟现实(VR)技术已经逐渐走进我们的生活。在汽车行业,VR线上看车已经成为了一种全新的购车体验。通过这种创新的方式,消费者可以在不出门的情况下,全方位地了解汽车的外观、内饰和性能…

PGSQL(PostgreSQL)数据库安装教程

安装包下载 下载地址 下载后点击exe安装包 设置的data存储路径 设置密码 设置端口 安装完毕,配置PGSQL的ip远程连接,pg_hba.conf,postgresql.conf,需要更改这两个文件 pg_hba.conf 最后增加一行 host all all …

1+x网络系统建设与运维(中级)-练习题

一.给设备重命名 同理可得&#xff0c;所有交换机和路由器都用一下命令配置 <Huawei>sys [Huawei]sysn LSW1 二.配置VLAN LSW1&#xff1a; [LSW1]vlan batch 10 20 [LSW1]int e0/0/1 [LSW1-Ethernet0/0/1]port link-type access [LSW1-Ethernet0/0/1]port default vlan…

P1012 [NOIP1998 提高组] 拼数( 字典序 )

字典序&#xff1a; 在字典中&#xff0c;单词是按照首字母在字母表中的顺序进行排列的 比如 alpha 在 beta 之前。 1.而第一个字母相同时&#xff0c;会 去比较两个单词的第二个字母在字母表中的顺序&#xff0c;比如 account 在 advanced 之前&#xff0c;以此类推。 2. 若…

2023年中国金融科技研究报告

第一章 行业概况 1.1 定义 金融科技&#xff08;FinTech, Financial Technology&#xff09;代表了金融和技术的交汇。这一领域虽然处于发展的初期阶段&#xff0c;但已经展现出深远的影响力。金融科技的业务模式多样&#xff0c;涵盖了从传统金融服务的数字化转型到新兴技术…

使用Xshell启动远程服务器上的tensorboard:本地浏览器打开

在远程服务器上启动的tensorboard产生的localhost网址用本地浏览器一般不能直接打开&#xff0c;我们需要建立本地PC与远程服务器的通信&#xff0c;将tensorboard的映射端口与本地端口连接起来&#xff08;参考解决方案&#xff09;。 一、连接远程服务器设置 二、添加SSH隧道…

MySQL实现(免密登录)

简介: MySQL免密登录是一种允许用户在没有输入密码的情况下直接登录到MySQL服务器的配置。这通常是通过在登录时跳过密码验证来实现的。 1、修改MySQL的配置文件 使用vi /etc/my.cnf&#xff0c;添加到【mysqld】后面 skip-grant-tables #配置项告诉mysql跳过权限验证&#…

运算放大器和常见运放电路

关于运算放大器 运算放大器(Operational Amplifier), 简称运放, 是一种直流耦合, 差模输入, 单端输出(Differential-in, single-ended output)的高增益电压放大器件. 运放能产生一个比输入端电势差大数十万倍的输出电势. 因为刚发明时主要用于加减法等运算电路中, 因而得名运算…

flutter 自定义TabBar 【top 0 级别】

flutter 自定义TabBar 【top 0 级别】 前言一、基础widget二、tab 标签三、barView总结 前言 在日常开发中&#xff0c;tab 标签选项&#xff0c;是一个我们特别常用的一个组件了&#xff0c;往往我们在一个项目中&#xff0c;有很多地方会使用到它&#xff0c;每次单独去写&am…

ESP32-Web-Server编程- 使用表格(Table)实时显示设备信息

ESP32-Web-Server编程- 使用表格&#xff08;Table&#xff09;实时显示设备信息 概述 上节讲述了通过 Server-Sent Events&#xff08;以下简称 SSE&#xff09; 实现在网页实时更新 ESP32 Web 服务器的传感器数据。 本节书接上会&#xff0c;继续使用 SSE 机制在网页实时显…