【HuggingFace Transformer库学习笔记】基础组件学习:pipeline

一、Transformer基础知识

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

pip install transformers datasets evaluate peft accelerate gradio optimum sentencepiece
pip install jupyterlab scikit-learn pandas matplotlib tensorboard nltk rouge

在host文件里添加途中信息,可以避免运行代码下载模型时候报错。

在这里插入图片描述
Transformers测试

#导入gradio
import gradio as gr
#导入transformersi相关包
from transformers import *
#通过Interface)加载pipeline并启动文本分类服务
gr.Interface.from_pipeline(pipeline("text-classification", model="uer/roberta-base-finetuned-dianping-chinese")).launch()

在这里插入图片描述

1、基础组件——pipeline

在这里插入图片描述
在这里插入图片描述
导入包

from transformers.pipelines import SUPPORTED_TASKS

查看pipeline支持的任务类型

# 查看SUPPORTED_TASK所有可支持的任务
print(SUPPORTED_TASKS.items())

dict_items([('audio-classification', {'impl': <class 'transformers.pipelines.audio_classification.AudioClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForAudioClassification'>,), 'default': {'model': {'pt': ('superb/wav2vec2-base-superb-ks', '372e048')}}, 'type': 'audio'}), ('automatic-speech-recognition', {'impl': <class 'transformers.pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForCTC'>, <class 'transformers.models.auto.modeling_auto.AutoModelForSpeechSeq2Seq'>), 'default': {'model': {'pt': ('facebook/wav2vec2-base-960h', '55bb623')}}, 'type': 'multimodal'}), ('text-to-audio', {'impl': <class 'transformers.pipelines.text_to_audio.TextToAudioPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForTextToWaveform'>, <class 'transformers.models.auto.modeling_auto.AutoModelForTextToSpectrogram'>), 'default': {'model': {'pt': ('suno/bark-small', '645cfba')}}, 'type': 'text'}), ('feature-extraction', {'impl': <class 'transformers.pipelines.feature_extraction.FeatureExtractionPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModel'>,), 'default': {'model': {'pt': ('distilbert-base-cased', '935ac13'), 'tf': ('distilbert-base-cased', '935ac13')}}, 'type': 'multimodal'}), ('text-classification', {'impl': <class 'transformers.pipelines.text_classification.TextClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSequenceClassification'>,), 'default': {'model': {'pt': ('distilbert-base-uncased-finetuned-sst-2-english', 'af0f99b'), 'tf': ('distilbert-base-uncased-finetuned-sst-2-english', 'af0f99b')}}, 'type': 'text'}), ('token-classification', {'impl': <class 'transformers.pipelines.token_classification.TokenClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForTokenClassification'>,), 'default': {'model': {'pt': ('dbmdz/bert-large-cased-finetuned-conll03-english', 'f2482bf'), 'tf': ('dbmdz/bert-large-cased-finetuned-conll03-english', 'f2482bf')}}, 'type': 'text'}), ('question-answering', {'impl': <class 'transformers.pipelines.question_answering.QuestionAnsweringPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForQuestionAnswering'>,), 'default': {'model': {'pt': ('distilbert-base-cased-distilled-squad', '626af31'), 'tf': ('distilbert-base-cased-distilled-squad', '626af31')}}, 'type': 'text'}), ('table-question-answering', {'impl': <class 'transformers.pipelines.table_question_answering.TableQuestionAnsweringPipeline'>, 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForTableQuestionAnswering'>,), 'tf': (), 'default': {'model': {'pt': ('google/tapas-base-finetuned-wtq', '69ceee2'), 'tf': ('google/tapas-base-finetuned-wtq', '69ceee2')}}, 'type': 'text'}), ('visual-question-answering', {'impl': <class 'transformers.pipelines.visual_question_answering.VisualQuestionAnsweringPipeline'>, 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForVisualQuestionAnswering'>,), 'tf': (), 'default': {'model': {'pt': ('dandelin/vilt-b32-finetuned-vqa', '4355f59')}}, 'type': 'multimodal'}), ('document-question-answering', {'impl': <class 'transformers.pipelines.document_question_answering.DocumentQuestionAnsweringPipeline'>, 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForDocumentQuestionAnswering'>,), 'tf': (), 'default': {'model': {'pt': ('impira/layoutlm-document-qa', '52e01b3')}}, 'type': 'multimodal'}), ('fill-mask', {'impl': <class 'transformers.pipelines.fill_mask.FillMaskPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForMaskedLM'>,), 'default': {'model': {'pt': ('distilroberta-base', 'ec58a5b'), 'tf': ('distilroberta-base', 'ec58a5b')}}, 'type': 'text'}), ('summarization', {'impl': <class 'transformers.pipelines.text2text_generation.SummarizationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSeq2SeqLM'>,), 'default': {'model': {'pt': ('sshleifer/distilbart-cnn-12-6', 'a4f8f3e'), 'tf': ('t5-small', 'd769bba')}}, 'type': 'text'}), ('translation', {'impl': <class 'transformers.pipelines.text2text_generation.TranslationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSeq2SeqLM'>,), 'default': {('en', 'fr'): {'model': {'pt': ('t5-base', '686f1db'), 'tf': ('t5-base', '686f1db')}}, ('en', 'de'): {'model': {'pt': ('t5-base', '686f1db'), 'tf': ('t5-base', '686f1db')}}, ('en', 'ro'): {'model': {'pt': ('t5-base', '686f1db'), 'tf': ('t5-base', '686f1db')}}}, 'type': 'text'}), ('text2text-generation', {'impl': <class 'transformers.pipelines.text2text_generation.Text2TextGenerationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSeq2SeqLM'>,), 'default': {'model': {'pt': ('t5-base', '686f1db'), 'tf': ('t5-base', '686f1db')}}, 'type': 'text'}), ('text-generation', {'impl': <class 'transformers.pipelines.text_generation.TextGenerationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>,), 'default': {'model': {'pt': ('gpt2', '6c0e608'), 'tf': ('gpt2', '6c0e608')}}, 'type': 'text'}), ('zero-shot-classification', {'impl': <class 'transformers.pipelines.zero_shot_classification.ZeroShotClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSequenceClassification'>,), 'default': {'model': {'pt': ('facebook/bart-large-mnli', 'c626438'), 'tf': ('roberta-large-mnli', '130fb28')}, 'config': {'pt': ('facebook/bart-large-mnli', 'c626438'), 'tf': ('roberta-large-mnli', '130fb28')}}, 'type': 'text'}), ('zero-shot-image-classification', {'impl': <class 'transformers.pipelines.zero_shot_image_classification.ZeroShotImageClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForZeroShotImageClassification'>,), 'default': {'model': {'pt': ('openai/clip-vit-base-patch32', 'f4881ba'), 'tf': ('openai/clip-vit-base-patch32', 'f4881ba')}}, 'type': 'multimodal'}), ('zero-shot-audio-classification', {'impl': <class 'transformers.pipelines.zero_shot_audio_classification.ZeroShotAudioClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModel'>,), 'default': {'model': {'pt': ('laion/clap-htsat-fused', '973b6e5')}}, 'type': 'multimodal'}), ('conversational', {'impl': <class 'transformers.pipelines.conversational.ConversationalPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSeq2SeqLM'>, <class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>), 'default': {'model': {'pt': ('microsoft/DialoGPT-medium', '8bada3b'), 'tf': ('microsoft/DialoGPT-medium', '8bada3b')}}, 'type': 'text'}), ('image-classification', {'impl': <class 'transformers.pipelines.image_classification.ImageClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForImageClassification'>,), 'default': {'model': {'pt': ('google/vit-base-patch16-224', '5dca96d'), 'tf': ('google/vit-base-patch16-224', '5dca96d')}}, 'type': 'image'}), ('image-segmentation', {'impl': <class 'transformers.pipelines.image_segmentation.ImageSegmentationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForImageSegmentation'>, <class 'transformers.models.auto.modeling_auto.AutoModelForSemanticSegmentation'>), 'default': {'model': {'pt': ('facebook/detr-resnet-50-panoptic', 'fc15262')}}, 'type': 'multimodal'}), ('image-to-text', {'impl': <class 'transformers.pipelines.image_to_text.ImageToTextPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForVision2Seq'>,), 'default': {'model': {'pt': ('ydshieh/vit-gpt2-coco-en', '65636df'), 'tf': ('ydshieh/vit-gpt2-coco-en', '65636df')}}, 'type': 'multimodal'}), ('object-detection', {'impl': <class 'transformers.pipelines.object_detection.ObjectDetectionPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForObjectDetection'>,), 'default': {'model': {'pt': ('facebook/detr-resnet-50', '2729413')}}, 'type': 'multimodal'}), ('zero-shot-object-detection', {'impl': <class 'transformers.pipelines.zero_shot_object_detection.ZeroShotObjectDetectionPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForZeroShotObjectDetection'>,), 'default': {'model': {'pt': ('google/owlvit-base-patch32', '17740e1')}}, 'type': 'multimodal'}), ('depth-estimation', {'impl': <class 'transformers.pipelines.depth_estimation.DepthEstimationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForDepthEstimation'>,), 'default': {'model': {'pt': ('Intel/dpt-large', 'e93beec')}}, 'type': 'image'}), ('video-classification', {'impl': <class 'transformers.pipelines.video_classification.VideoClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForVideoClassification'>,), 'default': {'model': {'pt': ('MCG-NJU/videomae-base-finetuned-kinetics', '4800870')}}, 'type': 'video'}), ('mask-generation', {'impl': <class 'transformers.pipelines.mask_generation.MaskGenerationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForMaskGeneration'>,), 'default': {'model': {'pt': ('facebook/sam-vit-huge', '997b15')}}, 'type': 'multimodal'}), ('image-to-image', {'impl': <class 'transformers.pipelines.image_to_image.ImageToImagePipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForImageToImage'>,), 'default': {'model': {'pt': ('caidas/swin2SR-classical-sr-x2-64', '4aaedcb')}}, 'type': 'image'})])

查看pipeline都支持哪些任务和实现

for k, v in SUPPORTED_TASKS.items():
    print(k, v)     # k:任务名称,v:任务的实现。tf:tensorflow模型,pt:pytorch模型

audio-classification {'impl': <class 'transformers.pipelines.audio_classification.AudioClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForAudioClassification'>,), 'default': {'model': {'pt': ('superb/wav2vec2-base-superb-ks', '372e048')}}, 'type': 'audio'}
automatic-speech-recognition {'impl': <class 'transformers.pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForCTC'>, <class 'transformers.models.auto.modeling_auto.AutoModelForSpeechSeq2Seq'>), 'default': {'model': {'pt': ('facebook/wav2vec2-base-960h', '55bb623')}}, 'type': 'multimodal'}
text-to-audio {'impl': <class 'transformers.pipelines.text_to_audio.TextToAudioPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForTextToWaveform'>, <class 'transformers.models.auto.modeling_auto.AutoModelForTextToSpectrogram'>), 'default': {'model': {'pt': ('suno/bark-small', '645cfba')}}, 'type': 'text'}
feature-extraction {'impl': <class 'transformers.pipelines.feature_extraction.FeatureExtractionPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModel'>,), 'default': {'model': {'pt': ('distilbert-base-cased', '935ac13'), 'tf': ('distilbert-base-cased', '935ac13')}}, 'type': 'multimodal'}
text-classification {'impl': <class 'transformers.pipelines.text_classification.TextClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSequenceClassification'>,), 'default': {'model': {'pt': ('distilbert-base-uncased-finetuned-sst-2-english', 'af0f99b'), 'tf': ('distilbert-base-uncased-finetuned-sst-2-english', 'af0f99b')}}, 'type': 'text'}
token-classification {'impl': <class 'transformers.pipelines.token_classification.TokenClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForTokenClassification'>,), 'default': {'model': {'pt': ('dbmdz/bert-large-cased-finetuned-conll03-english', 'f2482bf'), 'tf': ('dbmdz/bert-large-cased-finetuned-conll03-english', 'f2482bf')}}, 'type': 'text'}
question-answering {'impl': <class 'transformers.pipelines.question_answering.QuestionAnsweringPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForQuestionAnswering'>,), 'default': {'model': {'pt': ('distilbert-base-cased-distilled-squad', '626af31'), 'tf': ('distilbert-base-cased-distilled-squad', '626af31')}}, 'type': 'text'}
table-question-answering {'impl': <class 'transformers.pipelines.table_question_answering.TableQuestionAnsweringPipeline'>, 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForTableQuestionAnswering'>,), 'tf': (), 'default': {'model': {'pt': ('google/tapas-base-finetuned-wtq', '69ceee2'), 'tf': ('google/tapas-base-finetuned-wtq', '69ceee2')}}, 'type': 'text'}
visual-question-answering {'impl': <class 'transformers.pipelines.visual_question_answering.VisualQuestionAnsweringPipeline'>, 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForVisualQuestionAnswering'>,), 'tf': (), 'default': {'model': {'pt': ('dandelin/vilt-b32-finetuned-vqa', '4355f59')}}, 'type': 'multimodal'}
document-question-answering {'impl': <class 'transformers.pipelines.document_question_answering.DocumentQuestionAnsweringPipeline'>, 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForDocumentQuestionAnswering'>,), 'tf': (), 'default': {'model': {'pt': ('impira/layoutlm-document-qa', '52e01b3')}}, 'type': 'multimodal'}
fill-mask {'impl': <class 'transformers.pipelines.fill_mask.FillMaskPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForMaskedLM'>,), 'default': {'model': {'pt': ('distilroberta-base', 'ec58a5b'), 'tf': ('distilroberta-base', 'ec58a5b')}}, 'type': 'text'}
summarization {'impl': <class 'transformers.pipelines.text2text_generation.SummarizationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSeq2SeqLM'>,), 'default': {'model': {'pt': ('sshleifer/distilbart-cnn-12-6', 'a4f8f3e'), 'tf': ('t5-small', 'd769bba')}}, 'type': 'text'}
translation {'impl': <class 'transformers.pipelines.text2text_generation.TranslationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSeq2SeqLM'>,), 'default': {('en', 'fr'): {'model': {'pt': ('t5-base', '686f1db'), 'tf': ('t5-base', '686f1db')}}, ('en', 'de'): {'model': {'pt': ('t5-base', '686f1db'), 'tf': ('t5-base', '686f1db')}}, ('en', 'ro'): {'model': {'pt': ('t5-base', '686f1db'), 'tf': ('t5-base', '686f1db')}}}, 'type': 'text'}
text2text-generation {'impl': <class 'transformers.pipelines.text2text_generation.Text2TextGenerationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSeq2SeqLM'>,), 'default': {'model': {'pt': ('t5-base', '686f1db'), 'tf': ('t5-base', '686f1db')}}, 'type': 'text'}
text-generation {'impl': <class 'transformers.pipelines.text_generation.TextGenerationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>,), 'default': {'model': {'pt': ('gpt2', '6c0e608'), 'tf': ('gpt2', '6c0e608')}}, 'type': 'text'}
zero-shot-classification {'impl': <class 'transformers.pipelines.zero_shot_classification.ZeroShotClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSequenceClassification'>,), 'default': {'model': {'pt': ('facebook/bart-large-mnli', 'c626438'), 'tf': ('roberta-large-mnli', '130fb28')}, 'config': {'pt': ('facebook/bart-large-mnli', 'c626438'), 'tf': ('roberta-large-mnli', '130fb28')}}, 'type': 'text'}
zero-shot-image-classification {'impl': <class 'transformers.pipelines.zero_shot_image_classification.ZeroShotImageClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForZeroShotImageClassification'>,), 'default': {'model': {'pt': ('openai/clip-vit-base-patch32', 'f4881ba'), 'tf': ('openai/clip-vit-base-patch32', 'f4881ba')}}, 'type': 'multimodal'}
zero-shot-audio-classification {'impl': <class 'transformers.pipelines.zero_shot_audio_classification.ZeroShotAudioClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModel'>,), 'default': {'model': {'pt': ('laion/clap-htsat-fused', '973b6e5')}}, 'type': 'multimodal'}
conversational {'impl': <class 'transformers.pipelines.conversational.ConversationalPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForSeq2SeqLM'>, <class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>), 'default': {'model': {'pt': ('microsoft/DialoGPT-medium', '8bada3b'), 'tf': ('microsoft/DialoGPT-medium', '8bada3b')}}, 'type': 'text'}
image-classification {'impl': <class 'transformers.pipelines.image_classification.ImageClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForImageClassification'>,), 'default': {'model': {'pt': ('google/vit-base-patch16-224', '5dca96d'), 'tf': ('google/vit-base-patch16-224', '5dca96d')}}, 'type': 'image'}
image-segmentation {'impl': <class 'transformers.pipelines.image_segmentation.ImageSegmentationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForImageSegmentation'>, <class 'transformers.models.auto.modeling_auto.AutoModelForSemanticSegmentation'>), 'default': {'model': {'pt': ('facebook/detr-resnet-50-panoptic', 'fc15262')}}, 'type': 'multimodal'}
image-to-text {'impl': <class 'transformers.pipelines.image_to_text.ImageToTextPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForVision2Seq'>,), 'default': {'model': {'pt': ('ydshieh/vit-gpt2-coco-en', '65636df'), 'tf': ('ydshieh/vit-gpt2-coco-en', '65636df')}}, 'type': 'multimodal'}
object-detection {'impl': <class 'transformers.pipelines.object_detection.ObjectDetectionPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForObjectDetection'>,), 'default': {'model': {'pt': ('facebook/detr-resnet-50', '2729413')}}, 'type': 'multimodal'}
zero-shot-object-detection {'impl': <class 'transformers.pipelines.zero_shot_object_detection.ZeroShotObjectDetectionPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForZeroShotObjectDetection'>,), 'default': {'model': {'pt': ('google/owlvit-base-patch32', '17740e1')}}, 'type': 'multimodal'}
depth-estimation {'impl': <class 'transformers.pipelines.depth_estimation.DepthEstimationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForDepthEstimation'>,), 'default': {'model': {'pt': ('Intel/dpt-large', 'e93beec')}}, 'type': 'image'}
video-classification {'impl': <class 'transformers.pipelines.video_classification.VideoClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForVideoClassification'>,), 'default': {'model': {'pt': ('MCG-NJU/videomae-base-finetuned-kinetics', '4800870')}}, 'type': 'video'}
mask-generation {'impl': <class 'transformers.pipelines.mask_generation.MaskGenerationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForMaskGeneration'>,), 'default': {'model': {'pt': ('facebook/sam-vit-huge', '997b15')}}, 'type': 'multimodal'}
image-to-image {'impl': <class 'transformers.pipelines.image_to_image.ImageToImagePipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForImageToImage'>,), 'default': {'model': {'pt': ('caidas/swin2SR-classical-sr-x2-64', '4aaedcb')}}, 'type': 'image'}    

在这里插入图片描述
导入包

from transformers import pipeline

根据任务类型直接创建pipeline,默认都是英文模型
加载模型

pipe = pipeline("text-classification", model="./model/distilbert-base-uncased-finetuned-sst-2-english")

测试分类效果

pipe(["very good!", "vary bad!", "not bad", "just so so", "oh, damn!"])

[{'label': 'POSITIVE', 'score': 0.9998525381088257},
 {'label': 'NEGATIVE', 'score': 0.9991207718849182},
 {'label': 'POSITIVE', 'score': 0.9995881915092468},
 {'label': 'POSITIVE', 'score': 0.9887603521347046},
 {'label': 'NEGATIVE', 'score': 0.5632225871086121}]

在这里插入图片描述
推理测试

from transformers import *

# 这种方式,必须同时指定model和tokenizer
model = AutoModelForSequenceClassification.from_pretrained("uer/roberta-base-finetuned-dianping-chinese")
tokenizer = AutoTokenizer.from_pretrained("uer/roberta-base-finetuned-dianping-chinese")
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device_map="auto")		# GPU自动分配

Model config DistilBertConfig {
  "_name_or_path": "model/roberta-base-finetuned-dianping-chinese",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "finetuning_task": "sst-2",
  "hidden_dim": 3072,
  "id2label": {
    "0": "NEGATIVE",
    "1": "POSITIVE"
  },
  "initializer_range": 0.02,
  "label2id": {
    "NEGATIVE": 0,
    "POSITIVE": 1
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
...
  "transformers_version": "4.35.2",
  "vocab_size": 30522
}

测试效果

pipe(["我觉得不太行!", "一般般", "还凑合吧", "太强了!"])

[{'label': 'NEGATIVE', 'score': 0.5539911389350891},
 {'label': 'POSITIVE', 'score': 0.5317790508270264},
 {'label': 'NEGATIVE', 'score': 0.5028885006904602},
 {'label': 'POSITIVE', 'score': 0.8547790050506592}]

速度测试

import torch
import time
times = []
for i in range(100):
    torch.cuda.synchronize()
    start = time.time()
    pipe("我觉得不太行!")
    torch.cuda.synchronize()
    end = time.time()
    times.append(end - start)
print(sum(times) / 100)

0.01336388111114502

当想用知道怎么使用某个库时候,可以先实例化这个库,然后再查看对应信息去查找。
例如

qa_pipe = pipeline("question-answering", model="model/robert-base-chinese-extractive-qa")

输入qa_pipe查看pipline

qa_pipe

<transformers.pipelines.question_answering.QuestionAnsweringPipeline at 0x7f40c65a75e0>

再在代码界面上输入QuestionAnsweringPipeline,按住Ctrl进去查看示例,查看__call___方法

QuestionAnsweringPipeline

测试

# question是问题,context是让模型根据context内容抽取可以回答问题的答案
qa_pipe(question="中国的首都是哪里?", context="北京是中国的政治和文化中心,上海是中国的经济中心")

{'score': 0.00011347973486408591, 'start': 0, 'end': 2, 'answer': '北京'}

设置输出答案字长度

qa_pipe(question="中国的首都是哪里?", context="中国的首都是北京", max_answer_len=1)

{'score': 0.0022874099668115377, 'start': 6, 'end': 7, 'answer': '北'}

解析pipline背后的实现过程

先初始化tokenizer和model

from transformers import *
import torch

tokenizer = AutoTokenizer.from_pretrained("model/roberta-base-finetuned-dianping-chinese")
model = AutoModelForSequenceClassification.from_pretrained("model/roberta-base-finetuned-dianping-chinese")

Model config DistilBertConfig {
  "_name_or_path": "model/roberta-base-finetuned-dianping-chinese",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "finetuning_task": "sst-2",
  "hidden_dim": 3072,
  "id2label": {
    "0": "NEGATIVE",
    "1": "POSITIVE"
  },
  "initializer_range": 0.02,
  "label2id": {
    "NEGATIVE": 0,
    "POSITIVE": 1
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
...
All model checkpoint weights were used when initializing DistilBertForSequenceClassification.

输入文本并进行token化

input_text = "我觉得不太行!"
inputs = tokenizer(input_text, return_tensors="pt")
inputs

{'input_ids': tensor([[ 101, 1855,  100,  100, 1744, 1812, 1945, 1986,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}

将inputs输入model

res = model(**inputs)
res

SequenceClassifierOutput(loss=None, logits=tensor([[2.1696e-01, 1.5108e-04]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

模型训练后,对最终全连接层的输出(logits)的最后一个维度进行归一化

logits = res.logits
logits = torch.softmax(logits, dim=-1)      # 对最后一个维度进行归一化
logits

tensor([[0.5540, 0.4460]], grad_fn=<SoftmaxBackward0>)

根据最后一层的输出结果,找到概率最大的类别作为最终输出

pred = torch.argmax(logits).item()      # 通过取概率最大值对应类的下表,取对应的类别
pred

0

查看一下0索引对应的类别

model.config.id2label       # model config里的id2label有的对应的类别信息

{0: 'NEGATIVE', 1: 'POSITIVE'}

输出最终结果

result = model.config.id2label.get(pred)
result

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/185174.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue中下载文件后无法打开的坑

今天在项目开发的时候临时要添加个导出功能我就写了一份请求加导出得代码&#xff0c; 代码&#xff1a; //导出按钮放开exportDutySummarizing (dataRangeInfo) {const params {departmentName: dataRangeInfo.name,departmentQode: dataRangeInfo.qode}//拼接所需得urlcons…

Tomcat注册为服务后,如何配置Tomcat内存大小

前提条件&#xff1a;tomcat已经注册为服务。 1.winR,输入regedit打开注册表 2.找到Tomcat注册表路径&#xff1a; HKEY_LOCAL_MACHINE\SOFTWARE\Wow6432Node\Apache Software Foundation\Procrun 2.0\Tomcat80603.找到jvm内存配置路径&#xff1a; HKEY_LOCAL_MACHINE\SOFTW…

Redis高可用之主从复制及哨兵模式

一、Redis的主从复制 1.1 Redis主从复制定义 主从复制是redis实现高可用的基础&#xff0c;哨兵模式和集群都是在主从复制的基础之上实现高可用&#xff1b; 主从复制实现数据的多级备份&#xff0c;以及读写分离(主服务器负责写&#xff0c;从服务器只能读) 1.2 主从复制流…

Windows从源码构建tensorflow

由一开始的在线编译&#xff0c;到后面的离线编译&#xff0c;一路踩坑无数。在此记录一下参考过的文章&#xff0c;有时间整理一下踩坑记录。 一、环境配置 在tensorflow官网上有版本对应关系 win10 bazel 3.1.0 msys2 tensorflow2.3.0 python3.5-3.8 MSVC2019 protobuf3.9.…

羊大师:强健身体是成功的关键

健康是一项无价的财富&#xff0c;拥有强健的身体是实现人生目标的关键。而如何保持健康并拥有一个强健的身体呢&#xff1f;下面就为大家分享一些有效的健身方法和建议&#xff0c;帮助您达到健美身材的目标。 良好的饮食习惯是形成强健身体的基石。我们要摄入足够的营养物质…

Java二级医院区域HIS信息管理系统源码(SaaS服务)

一个好的HIS系统&#xff0c;要具有开放性&#xff0c;便于扩展升级&#xff0c;增加新的功能模块&#xff0c;支撑好医院的业务的拓展&#xff0c;而且可以反过来给医院赋能&#xff0c;最终向更多的患者提供更好的服务。 系统采用前后端分离架构&#xff0c;前端由Angular、J…

首先啊骚年们我们必须先了解网络安全这个行业究竟是干啥的。

导 读 近年来&#xff0c;人工智能、5G、量子信息技术、工业互联网、大数据、云计算、物联网、虚拟现实、区块链等具有颠覆性的战略性新技术突飞猛进&#xff0c;但伴随着互联网技术的发展&#xff0c;网络安全问题也日趋多样化&#xff0c;甚至严重威胁到国家、企业&#xff…

【正点原子STM32连载】 第六十章 串口IAP实验(Julia分形)实验 摘自【正点原子】APM32F407最小系统板使用指南

1&#xff09;实验平台&#xff1a;正点原子APM32F407最小系统板 2&#xff09;平台购买地址&#xff1a;https://detail.tmall.com/item.htm?id609294757420 3&#xff09;全套实验源码手册视频下载地址&#xff1a; http://www.openedv.com/thread-340252-1-1.html## 第六十…

AD9528学习笔记

前言 AD9528是ADI的一款时钟芯片&#xff0c;由2-stage PLL组成&#xff0c;并且集成JESD204B/JESD204C SYSREF信号发生器&#xff0c;SYSREF发生器输出单次、N次或连续信号&#xff0c;并与PLL1和PLL2输出同步&#xff0c;从而可以实现多器件之间的同步。 AD9528总共有14路输…

<精学社>LCD1602移屏操作

本文为博主 日月同辉&#xff0c;与我共生&#xff0c;csdn原创首发。希望看完后能对你有所帮助&#xff0c;不足之处请指正&#xff01;一起交流学习&#xff0c;共同进步&#xff01; > 发布人&#xff1a;日月同辉,与我共生_单片机-CSDN博客 > 欢迎你为独创博主日月同…

Linux:配置Ubuntu系统的镜像软件下载地址

一、原理介绍 好处&#xff1a;从国内服务器下载APT软件&#xff0c;速度快。 二、配置 我这里配置的是清华大学的镜像服务器地址 https://mirrors.tuna.tsinghua.edu.cn/ 1、备份文件 sudo cp /etc/apt/sources.list /etc/apt/sources.list.bak2、清空sources.list ec…

CS5511规格书|CS5511方案应用说明|DP转双路LVDS/eDP芯片方案

概述&#xff1a;CS5511是一个将DP/eDP输入转换为LVDS信号的桥接芯片&#xff0c;此外&#xff0c;CS5511可以用作在DP/eDP输入到DP/eDP输出场景中桥接芯片。CS5511的高级接收器支持VEDA DisplayPort&#xff08;DP&#xff09;1.3和嵌入式DisplayPort&#xff08;eDP&#xf…

这是一个最简单的爱国主义为主题的网页首页

代码&#xff1a; <!DOCTYPE html> <html lang"zh-CN"> <head> <meta charset"UTF-8"> <title>爱国主题网页</title> <style> body { font-family: Arial, sans-serif; …

pytest系列——pytest-instafail插件之命令行实时输出错误信息

前言 1、pytest 运行全部用例的时候&#xff0c;在控制台会先显示用例的运行结果(.或F)&#xff1b;等待用例全部运行完成后最后把报错信息全部一起抛出到控制台。 2、这样我们每次都需要等用例运行结束&#xff0c;才知道为什么报错&#xff0c;不方便实时查看报错信息。 3…

数据挖掘之PCA-主成分分析

PCA的用处&#xff1a;找出反应数据中最大变差的投影&#xff08;就是拉的最开&#xff09;。 在减少需要分析的指标同时&#xff0c;尽量减少原指标包含信息的损失&#xff0c;以达到对所收集数据进行全面分析的目的 但是什么时候信息保留的最多呢&#xff1f;具体一点&#…

Day31| Leetcode 455. 分发饼干 Leetcode 376. 摆动序列 Leetcode 53. 最大子数组和

进入贪心了&#xff0c;我觉得本专题是最烧脑的专题 Leetcode 455. 分发饼干 题目链接 455 分发饼干 让大的饼干去满足需求量大的孩子即是本题的思路&#xff1a; class Solution { public:int findContentChildren(vector<int>& g, vector<int>& s) {…

joplin笔记同步 到腾讯云S3

创建存储桶 打开腾讯云的存储桶列表&#xff0c;点击“创建存储桶”&#xff0c;输入名称&#xff0c;选择地域&#xff08;建议选择离自己较近的地域以降低访问时延&#xff09;和访问权限&#xff08;建议选择“私有读写”&#xff09;。 s3 存储桶&#xff1a; 存储桶的名称…

windows环境下使用arthas不启动服务替换文件

场景&#xff1a; windows环境&#xff0c;如果现场环境客户正在使用项目&#xff0c;如果要替换项目中的一个class文件&#xff0c;但是又不能重启服务改怎么处理&#xff0c;今天介绍使用arthas中的retransform命令动态替换及使用注意的事项。 环境&#xff1a; windows,arth…

葡萄酒怎么按照饮用时间分类?

不同的葡萄酒搭配不同的餐食&#xff0c;会让饮酒人有不一样的感受和体会&#xff0c;所以&#xff0c;葡萄酒是分场合并且有饮用时间的。云仓酒庄的品牌雷盛红酒分享一般按照饮用时间分类可以把葡萄酒分为三大类&#xff0c;分别是餐前酒、佐餐酒和餐后酒。 餐前酒&#xff1…

java编程:使用递归 循环和位运算实现将10进制转为2进制

1 递归 /*** 递归&#xff1a;十进制转二进制* param decimal 待转换的十进制数* param binary 转换后的二进制数*/public static void decimalToBinaryByRecursion(int decimal,StringBuilder binary){if(decimal < 0){return;}decimalToBinaryByRecursion(decimal/2,bina…