Transformers实战03-PEFT库使用LORA方法微调。

文章目录

  • 简介
    • PEFT
    • LORA方法
    • Vision Transformer (ViT)
  • lora方法实战
    • 模型选择
      • google/vit-base-patch16-224-in21k
      • google/vit-base-patch16-224
    • 数据集
    • 模型
    • PEFT configuration and model
    • 训练
    • 预测

简介

PEFT

PEFT(Parameter-Efficient Fine-Tuning)是一个用于高效地将大型预训练模型适配到各种下游应用的库,而无需对模型的所有参数进行微调,因为这在计算上是非常昂贵的。PEFT 方法只微调少量的(额外的)模型参数——显著降低了计算和存储成本——同时其性能与完全微调的模型相当。这使得在消费者硬件上训练和存储大型语言模型(LLMs)变得更加可行。

PEFT 集成了 Transformers、Diffusers 和 Accelerate 库,以提供更快、更简单的方法来加载、训练和使用大型模型进行推理。

LORA方法

一种高效训练大型模型的流行方法是在注意力块中插入较小的可训练矩阵,这些矩阵是微调期间要学习的增量权重矩阵的低秩分解。预训练模型的原始权重矩阵被冻结,仅更新较小的矩阵。这减少了可训练参数的数量,降低了内存使用和训练时间,而这些在大型模型中可能非常昂贵。

有几种不同的方法可以将权重矩阵表示为低秩分解,但最常见的方法是低秩适应(LoRA原理)。PEFT 库支持几种其他 LoRA 变体,例如低秩Hadamard积(LoHa)、低秩Kronecker积(LoKr)和自适应低秩适应(AdaLoRA)。你可以在适配器指南中了解这些方法的概念工作原理。如果你有兴趣将这些方法应用于其他任务和用例,比如语义分割、标记分类,可以看看我们的笔记本集合!

Vision Transformer (ViT)

Vision Transformer(ViT)模型是由Alexey Dosovitskiy,Lucas Beyer,Alexander Kolesnikov,Dirk Weissenborn,Xiaohua Zhai,Thomas Unterthiner,Mostafa Dehghani,Matthias Minderer,Georg Heigold,Sylvain Gelly,Jakob Uszkoreit,Neil Houlsby在《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中提出的。这是第一篇成功在ImageNet上训练Transformer编码器并获得非常好结果的论文。

这篇论文的摘要是:

虽然Transformer架构已经成为自然语言处理任务的事实标准,但它在计算机视觉中的应用仍然有限。在视觉领域,注意力要么与卷积网络一起应用,要么用来替换卷积网络的某些组件,同时保持其总体结构不变。我们展示了在这种对CNN的依赖不是必要的,纯Transformer直接应用于图像块序列可以在图像分类任务上表现得非常好。当在大量数据上进行预训练并转移到多个中等规模或小型图像识别基准数据集(ImageNet,CIFAR-100,VTAB等)时,Vision Transformer(ViT)与最先进的卷积网络相比取得了出色的结果,同时训练所需的计算资源大大减少。

具体关于该模型得结构参考:https://huggingface.co/docs/transformers/model_doc/vit

lora方法实战

本指南将向你展示如何快速训练一个图像分类模型——使用低秩分解方法——来识别图像中显示的食物类别。
案例来自官网:https://huggingface.co/docs/peft/task_guides/lora_based_methods

模型选择

google/vit-base-patch16-224-in21k

是一个基于Transformer编码器的模型(类似于BERT),在监督方式下,即ImageNet-21k上以224x224像素的分辨率预训练了大量图像。

图像被呈现给模型作为固定大小的补丁序列(分辨率为16x16),这些补丁被线性嵌入。在序列的开头还添加了一个[CLS]标记,用于分类任务。在将序列馈送到Transformer编码器的层之前,还会添加绝对位置嵌入。

需要注意的是,这个模型不提供任何经过微调的头部,因为这些头部被Google研究人员清零了。但是,模型包括预训练的汇聚层,可以用于下游任务(如图像分类)。

通过预训练模型,它学习了图像的内部表示,然后可以用于提取对下游任务有用的特征:例如,如果您有一个带标签的图像数据集,可以通过在预训练编码器顶部放置一个线性层来训练标准分类器。通常将线性层放置在[CLS]标记的顶部,因为该标记的最后隐藏状态可以被视为整个图像的表示。

from transformers import ViTImageProcessor, FlaxViTModel
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

inputs = processor(images=image, return_tensors="np")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
print(last_hidden_states.shape)

不包含分类信息,不包含label信息

google/vit-base-patch16-224

是一个在大规模图像数据集上进行监督预训练的转换器编码器模型(类似于BERT),即ImageNet-21k,分辨率为224x224像素。接下来,该模型在ImageNet上进行微调(也称为ILSVRC2012),这是一个包含100万张图像和1000个类别的数据集,分辨率也为224x224。

图像被呈现给模型作为固定大小补丁(分辨率为16x16)的序列,这些补丁被线性嵌入。还在序列开始处添加了一个[CLS]标记,以用于分类任务。在将序列馈送到Transformer编码器的层之前,还会添加绝对位置嵌入。

通过对模型进行预训练,它学习了图像的内部表示,然后可以用于提取对下游任务有用的特征:例如,如果您有一个带标签的图像数据集,您可以在预训练编码器之上放置一个标准分类器的线性层。通常将线性层放置在[CLS]标记的顶部,因为该标记的最后隐藏状态可以被视为整个图像的表示。

from transformers import AutoImageProcessor, ViTForImageClassification
from PIL import Image
import requests,torch

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')

model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

inputs = processor(images=image, return_tensors="pt")
print(inputs)
print(inputs["pixel_values"].shape)
outputs = model(**inputs)
with torch.no_grad():
    logits = model(**inputs).logits
    predicted_label = logits.argmax(-1).item()
    print(model.config.id2label[predicted_label])

输出:

{'pixel_values': tensor([[[[ 0.1137,  0.1686,  0.1843,  ..., -0.1922, -0.1843, -0.1843],
          [ 0.1373,  0.1686,  0.1843,  ..., -0.1922, -0.1922, -0.2078],
          [ 0.1137,  0.1529,  0.1608,  ..., -0.2314, -0.2235, -0.2157],
          ...,
          [ 0.8353,  0.7882,  0.7333,  ...,  0.7020,  0.6471,  0.6157],
          [ 0.8275,  0.7961,  0.7725,  ...,  0.5843,  0.4667,  0.3961],
          [ 0.8196,  0.7569,  0.7569,  ...,  0.0745, -0.0510, -0.1922]],

         [[-0.8039, -0.8118, -0.8118,  ..., -0.8902, -0.8902, -0.8980],
          [-0.7882, -0.7882, -0.7882,  ..., -0.8745, -0.8745, -0.8824],
          [-0.8118, -0.8039, -0.7882,  ..., -0.8902, -0.8902, -0.8902],
          ...,
          [-0.2706, -0.3176, -0.3647,  ..., -0.4275, -0.4588, -0.4824],
          [-0.2706, -0.2941, -0.3412,  ..., -0.4824, -0.5451, -0.5765],
          [-0.2784, -0.3412, -0.3490,  ..., -0.7333, -0.7804, -0.8353]],

         [[-0.5451, -0.4667, -0.4824,  ..., -0.7412, -0.6941, -0.7176],
          [-0.5529, -0.5137, -0.4902,  ..., -0.7412, -0.7098, -0.7412],
          [-0.5216, -0.4824, -0.4667,  ..., -0.7490, -0.7490, -0.7647],
          ...,
          [ 0.5686,  0.5529,  0.4510,  ...,  0.4431,  0.3882,  0.3255],
          [ 0.5451,  0.4902,  0.5137,  ...,  0.3020,  0.2078,  0.1294],
          [ 0.5686,  0.5608,  0.5137,  ..., -0.2000, -0.4275, -0.5294]]]])}
torch.Size([1, 3, 224, 224])
{0: 'tench, Tinca tinca', 1: 'goldfish, Carassius auratus', 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 3: 'tiger shark, Galeocerdo cuvieri', 4: 'hammerhead, hammerhead shark', 5: 'electric ray, crampfish, numbfish, torpedo', 6: 'stingray', 7: 'cock', 8: 'hen', 9: 'ostrich, Struthio camelus', 10: 'brambling, Fringilla montifringilla', 11: 'goldfinch, Carduelis carduelis', 12: 'house finch, linnet, Carpodacus mexicanus', 13: 'junco, snowbird', 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', 15: 'robin, American robin, Turdus migratorius', 16: 'bulbul', 17: 'jay', 18: 'magpie', 19: 'chickadee', 20: 'water ouzel, dipper', 21: 'kite', 22: 'bald eagle, American eagle, Haliaeetus leucocephalus', 23: 'vulture', 24: 'great grey owl, great gray owl, Strix nebulosa', 25: 'European fire salamander, Salamandra salamandra', 26: 'common newt, Triturus vulgaris', 27: 'eft', 28: 'spotted salamander, Ambystoma maculatum', 29: 'axolotl, mud puppy, Ambystoma mexicanum', 30: 'bullfrog, Rana catesbeiana', 31: 'tree frog, tree-frog', 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', 33: 'loggerhead, loggerhead turtle, Caretta caretta', 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', 35: 'mud turtle', 36: 'terrapin', 37: 'box turtle, box tortoise', 38: 'banded gecko', 39: 'common iguana, iguana, Iguana iguana', 40: 'American chameleon, anole, Anolis carolinensis', 41: 'whiptail, whiptail lizard', 42: 'agama', 43: 'frilled lizard, Chlamydosaurus kingi', 44: 'alligator lizard', 45: 'Gila monster, Heloderma suspectum', 46: 'green lizard, Lacerta viridis', 47: 'African chameleon, Chamaeleo chamaeleon', 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', 49: 'African crocodile, Nile crocodile, Crocodylus niloticus', 50: 'American alligator, Alligator mississipiensis', 51: 'triceratops', 52: 'thunder snake, worm snake, Carphophis amoenus', 53: 'ringneck snake, ring-necked snake, ring snake', 54: 'hognose snake, puff adder, sand viper', 55: 'green snake, grass snake', 56: 'king snake, kingsnake', 57: 'garter snake, grass snake', 58: 'water snake', 59: 'vine snake', 60: 'night snake, Hypsiglena torquata', 61: 'boa constrictor, Constrictor constrictor', 62: 'rock python, rock snake, Python sebae', 63: 'Indian cobra, Naja naja', 64: 'green mamba', 65: 'sea snake', 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus', 68: 'sidewinder, horned rattlesnake, Crotalus cerastes', 69: 'trilobite', 70: 'harvestman, daddy longlegs, Phalangium opilio', 71: 'scorpion', 72: 'black and gold garden spider, Argiope aurantia', 73: 'barn spider, Araneus cavaticus', 74: 'garden spider, Aranea diademata', 75: 'black widow, Latrodectus mactans', 76: 'tarantula', 77: 'wolf spider, hunting spider', 78: 'tick', 79: 'centipede', 80: 'black grouse', 81: 'ptarmigan', 82: 'ruffed grouse, partridge, Bonasa umbellus', 83: 'prairie chicken, prairie grouse, prairie fowl', 84: 'peacock', 85: 'quail', 86: 'partridge', 87: 'African grey, African gray, Psittacus erithacus', 88: 'macaw', 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', 90: 'lorikeet', 91: 'coucal', 92: 'bee eater', 93: 'hornbill', 94: 'hummingbird', 95: 'jacamar', 96: 'toucan', 97: 'drake', 98: 'red-breasted merganser, Mergus serrator', 99: 'goose', 100: 'black swan, Cygnus atratus', 101: 'tusker', 102: 'echidna, spiny anteater, anteater', 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', 104: 'wallaby, brush kangaroo', 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', 106: 'wombat', 107: 'jellyfish', 108: 'sea anemone, anemone', 109: 'brain coral', 110: 'flatworm, platyhelminth', 111: 'nematode, nematode worm, roundworm', 112: 'conch', 113: 'snail', 114: 'slug', 115: 'sea slug, nudibranch', 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore', 117: 'chambered nautilus, pearly nautilus, nautilus', 118: 'Dungeness crab, Cancer magister', 119: 'rock crab, Cancer irroratus', 120: 'fiddler crab', 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus', 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', 124: 'crayfish, crawfish, crawdad, crawdaddy', 125: 'hermit crab', 126: 'isopod', 127: 'white stork, Ciconia ciconia', 128: 'black stork, Ciconia nigra', 129: 'spoonbill', 130: 'flamingo', 131: 'little blue heron, Egretta caerulea', 132: 'American egret, great white heron, Egretta albus', 133: 'bittern', 134: 'crane', 135: 'limpkin, Aramus pictus', 136: 'European gallinule, Porphyrio porphyrio', 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana', 138: 'bustard', 139: 'ruddy turnstone, Arenaria interpres', 140: 'red-backed sandpiper, dunlin, Erolia alpina', 141: 'redshank, Tringa totanus', 142: 'dowitcher', 143: 'oystercatcher, oyster catcher', 144: 'pelican', 145: 'king penguin, Aptenodytes patagonica', 146: 'albatross, mollymawk', 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', 149: 'dugong, Dugong dugon', 150: 'sea lion', 151: 'Chihuahua', 152: 'Japanese spaniel', 153: 'Maltese dog, Maltese terrier, Maltese', 154: 'Pekinese, Pekingese, Peke', 155: 'Shih-Tzu', 156: 'Blenheim spaniel', 157: 'papillon', 158: 'toy terrier', 159: 'Rhodesian ridgeback', 160: 'Afghan hound, Afghan', 161: 'basset, basset hound', 162: 'beagle', 163: 'bloodhound, sleuthhound', 164: 'bluetick', 165: 'black-and-tan coonhound', 166: 'Walker hound, Walker foxhound', 167: 'English foxhound', 168: 'redbone', 169: 'borzoi, Russian wolfhound', 170: 'Irish wolfhound', 171: 'Italian greyhound', 172: 'whippet', 173: 'Ibizan hound, Ibizan Podenco', 174: 'Norwegian elkhound, elkhound', 175: 'otterhound, otter hound', 176: 'Saluki, gazelle hound', 177: 'Scottish deerhound, deerhound', 178: 'Weimaraner', 179: 'Staffordshire bullterrier, Staffordshire bull terrier', 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', 181: 'Bedlington terrier', 182: 'Border terrier', 183: 'Kerry blue terrier', 184: 'Irish terrier', 185: 'Norfolk terrier', 186: 'Norwich terrier', 187: 'Yorkshire terrier', 188: 'wire-haired fox terrier', 189: 'Lakeland terrier', 190: 'Sealyham terrier, Sealyham', 191: 'Airedale, Airedale terrier', 192: 'cairn, cairn terrier', 193: 'Australian terrier', 194: 'Dandie Dinmont, Dandie Dinmont terrier', 195: 'Boston bull, Boston terrier', 196: 'miniature schnauzer', 197: 'giant schnauzer', 198: 'standard schnauzer', 199: 'Scotch terrier, Scottish terrier, Scottie', 200: 'Tibetan terrier, chrysanthemum dog', 201: 'silky terrier, Sydney silky', 202: 'soft-coated wheaten terrier', 203: 'West Highland white terrier', 204: 'Lhasa, Lhasa apso', 205: 'flat-coated retriever', 206: 'curly-coated retriever', 207: 'golden retriever', 208: 'Labrador retriever', 209: 'Chesapeake Bay retriever', 210: 'German short-haired pointer', 211: 'vizsla, Hungarian pointer', 212: 'English setter', 213: 'Irish setter, red setter', 214: 'Gordon setter', 215: 'Brittany spaniel', 216: 'clumber, clumber spaniel', 217: 'English springer, English springer spaniel', 218: 'Welsh springer spaniel', 219: 'cocker spaniel, English cocker spaniel, cocker', 220: 'Sussex spaniel', 221: 'Irish water spaniel', 222: 'kuvasz', 223: 'schipperke', 224: 'groenendael', 225: 'malinois', 226: 'briard', 227: 'kelpie', 228: 'komondor', 229: 'Old English sheepdog, bobtail', 230: 'Shetland sheepdog, Shetland sheep dog, Shetland', 231: 'collie', 232: 'Border collie', 233: 'Bouvier des Flandres, Bouviers des Flandres', 234: 'Rottweiler', 235: 'German shepherd, German shepherd dog, German police dog, alsatian', 236: 'Doberman, Doberman pinscher', 237: 'miniature pinscher', 238: 'Greater Swiss Mountain dog', 239: 'Bernese mountain dog', 240: 'Appenzeller', 241: 'EntleBucher', 242: 'boxer', 243: 'bull mastiff', 244: 'Tibetan mastiff', 245: 'French bulldog', 246: 'Great Dane', 247: 'Saint Bernard, St Bernard', 248: 'Eskimo dog, husky', 249: 'malamute, malemute, Alaskan malamute', 250: 'Siberian husky', 251: 'dalmatian, coach dog, carriage dog', 252: 'affenpinscher, monkey pinscher, monkey dog', 253: 'basenji', 254: 'pug, pug-dog', 255: 'Leonberg', 256: 'Newfoundland, Newfoundland dog', 257: 'Great Pyrenees', 258: 'Samoyed, Samoyede', 259: 'Pomeranian', 260: 'chow, chow chow', 261: 'keeshond', 262: 'Brabancon griffon', 263: 'Pembroke, Pembroke Welsh corgi', 264: 'Cardigan, Cardigan Welsh corgi', 265: 'toy poodle', 266: 'miniature poodle', 267: 'standard poodle', 268: 'Mexican hairless', 269: 'timber wolf, grey wolf, gray wolf, Canis lupus', 270: 'white wolf, Arctic wolf, Canis lupus tundrarum', 271: 'red wolf, maned wolf, Canis rufus, Canis niger', 272: 'coyote, prairie wolf, brush wolf, Canis latrans', 273: 'dingo, warrigal, warragal, Canis dingo', 274: 'dhole, Cuon alpinus', 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', 276: 'hyena, hyaena', 277: 'red fox, Vulpes vulpes', 278: 'kit fox, Vulpes macrotis', 279: 'Arctic fox, white fox, Alopex lagopus', 280: 'grey fox, gray fox, Urocyon cinereoargenteus', 281: 'tabby, tabby cat', 282: 'tiger cat', 283: 'Persian cat', 284: 'Siamese cat, Siamese', 285: 'Egyptian cat', 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', 287: 'lynx, catamount', 288: 'leopard, Panthera pardus', 289: 'snow leopard, ounce, Panthera uncia', 290: 'jaguar, panther, Panthera onca, Felis onca', 291: 'lion, king of beasts, Panthera leo', 292: 'tiger, Panthera tigris', 293: 'cheetah, chetah, Acinonyx jubatus', 294: 'brown bear, bruin, Ursus arctos', 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus', 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', 297: 'sloth bear, Melursus ursinus, Ursus ursinus', 298: 'mongoose', 299: 'meerkat, mierkat', 300: 'tiger beetle', 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', 302: 'ground beetle, carabid beetle', 303: 'long-horned beetle, longicorn, longicorn beetle', 304: 'leaf beetle, chrysomelid', 305: 'dung beetle', 306: 'rhinoceros beetle', 307: 'weevil', 308: 'fly', 309: 'bee', 310: 'ant, emmet, pismire', 311: 'grasshopper, hopper', 312: 'cricket', 313: 'walking stick, walkingstick, stick insect', 314: 'cockroach, roach', 315: 'mantis, mantid', 316: 'cicada, cicala', 317: 'leafhopper', 318: 'lacewing, lacewing fly', 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", 320: 'damselfly', 321: 'admiral', 322: 'ringlet, ringlet butterfly', 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', 324: 'cabbage butterfly', 325: 'sulphur butterfly, sulfur butterfly', 326: 'lycaenid, lycaenid butterfly', 327: 'starfish, sea star', 328: 'sea urchin', 329: 'sea cucumber, holothurian', 330: 'wood rabbit, cottontail, cottontail rabbit', 331: 'hare', 332: 'Angora, Angora rabbit', 333: 'hamster', 334: 'porcupine, hedgehog', 335: 'fox squirrel, eastern fox squirrel, Sciurus niger', 336: 'marmot', 337: 'beaver', 338: 'guinea pig, Cavia cobaya', 339: 'sorrel', 340: 'zebra', 341: 'hog, pig, grunter, squealer, Sus scrofa', 342: 'wild boar, boar, Sus scrofa', 343: 'warthog', 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius', 345: 'ox', 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', 347: 'bison', 348: 'ram, tup', 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', 350: 'ibex, Capra ibex', 351: 'hartebeest', 352: 'impala, Aepyceros melampus', 353: 'gazelle', 354: 'Arabian camel, dromedary, Camelus dromedarius', 355: 'llama', 356: 'weasel', 357: 'mink', 358: 'polecat, fitch, foulmart, foumart, Mustela putorius', 359: 'black-footed ferret, ferret, Mustela nigripes', 360: 'otter', 361: 'skunk, polecat, wood pussy', 362: 'badger', 363: 'armadillo', 364: 'three-toed sloth, ai, Bradypus tridactylus', 365: 'orangutan, orang, orangutang, Pongo pygmaeus', 366: 'gorilla, Gorilla gorilla', 367: 'chimpanzee, chimp, Pan troglodytes', 368: 'gibbon, Hylobates lar', 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus', 370: 'guenon, guenon monkey', 371: 'patas, hussar monkey, Erythrocebus patas', 372: 'baboon', 373: 'macaque', 374: 'langur', 375: 'colobus, colobus monkey', 376: 'proboscis monkey, Nasalis larvatus', 377: 'marmoset', 378: 'capuchin, ringtail, Cebus capucinus', 379: 'howler monkey, howler', 380: 'titi, titi monkey', 381: 'spider monkey, Ateles geoffroyi', 382: 'squirrel monkey, Saimiri sciureus', 383: 'Madagascar cat, ring-tailed lemur, Lemur catta', 384: 'indri, indris, Indri indri, Indri brevicaudatus', 385: 'Indian elephant, Elephas maximus', 386: 'African elephant, Loxodonta africana', 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', 389: 'barracouta, snoek', 390: 'eel', 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', 392: 'rock beauty, Holocanthus tricolor', 393: 'anemone fish', 394: 'sturgeon', 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus', 396: 'lionfish', 397: 'puffer, pufferfish, blowfish, globefish', 398: 'abacus', 399: 'abaya', 400: "academic gown, academic robe, judge's robe", 401: 'accordion, piano accordion, squeeze box', 402: 'acoustic guitar', 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier', 404: 'airliner', 405: 'airship, dirigible', 406: 'altar', 407: 'ambulance', 408: 'amphibian, amphibious vehicle', 409: 'analog clock', 410: 'apiary, bee house', 411: 'apron', 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', 413: 'assault rifle, assault gun', 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack', 415: 'bakery, bakeshop, bakehouse', 416: 'balance beam, beam', 417: 'balloon', 418: 'ballpoint, ballpoint pen, ballpen, Biro', 419: 'Band Aid', 420: 'banjo', 421: 'bannister, banister, balustrade, balusters, handrail', 422: 'barbell', 423: 'barber chair', 424: 'barbershop', 425: 'barn', 426: 'barometer', 427: 'barrel, cask', 428: 'barrow, garden cart, lawn cart, wheelbarrow', 429: 'baseball', 430: 'basketball', 431: 'bassinet', 432: 'bassoon', 433: 'bathing cap, swimming cap', 434: 'bath towel', 435: 'bathtub, bathing tub, bath, tub', 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', 437: 'beacon, lighthouse, beacon light, pharos', 438: 'beaker', 439: 'bearskin, busby, shako', 440: 'beer bottle', 441: 'beer glass', 442: 'bell cote, bell cot', 443: 'bib', 444: 'bicycle-built-for-two, tandem bicycle, tandem', 445: 'bikini, two-piece', 446: 'binder, ring-binder', 447: 'binoculars, field glasses, opera glasses', 448: 'birdhouse', 449: 'boathouse', 450: 'bobsled, bobsleigh, bob', 451: 'bolo tie, bolo, bola tie, bola', 452: 'bonnet, poke bonnet', 453: 'bookcase', 454: 'bookshop, bookstore, bookstall', 455: 'bottlecap', 456: 'bow', 457: 'bow tie, bow-tie, bowtie', 458: 'brass, memorial tablet, plaque', 459: 'brassiere, bra, bandeau', 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', 461: 'breastplate, aegis, egis', 462: 'broom', 463: 'bucket, pail', 464: 'buckle', 465: 'bulletproof vest', 466: 'bullet train, bullet', 467: 'butcher shop, meat market', 468: 'cab, hack, taxi, taxicab', 469: 'caldron, cauldron', 470: 'candle, taper, wax light', 471: 'cannon', 472: 'canoe', 473: 'can opener, tin opener', 474: 'cardigan', 475: 'car mirror', 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig', 477: "carpenter's kit, tool kit", 478: 'carton', 479: 'car wheel', 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', 481: 'cassette', 482: 'cassette player', 483: 'castle', 484: 'catamaran', 485: 'CD player', 486: 'cello, violoncello', 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone', 488: 'chain', 489: 'chainlink fence', 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', 491: 'chain saw, chainsaw', 492: 'chest', 493: 'chiffonier, commode', 494: 'chime, bell, gong', 495: 'china cabinet, china closet', 496: 'Christmas stocking', 497: 'church, church building', 498: 'cinema, movie theater, movie theatre, movie house, picture palace', 499: 'cleaver, meat cleaver, chopper', 500: 'cliff dwelling', 501: 'cloak', 502: 'clog, geta, patten, sabot', 503: 'cocktail shaker', 504: 'coffee mug', 505: 'coffeepot', 506: 'coil, spiral, volute, whorl, helix', 507: 'combination lock', 508: 'computer keyboard, keypad', 509: 'confectionery, confectionary, candy store', 510: 'container ship, containership, container vessel', 511: 'convertible', 512: 'corkscrew, bottle screw', 513: 'cornet, horn, trumpet, trump', 514: 'cowboy boot', 515: 'cowboy hat, ten-gallon hat', 516: 'cradle', 517: 'crane', 518: 'crash helmet', 519: 'crate', 520: 'crib, cot', 521: 'Crock Pot', 522: 'croquet ball', 523: 'crutch', 524: 'cuirass', 525: 'dam, dike, dyke', 526: 'desk', 527: 'desktop computer', 528: 'dial telephone, dial phone', 529: 'diaper, nappy, napkin', 530: 'digital clock', 531: 'digital watch', 532: 'dining table, board', 533: 'dishrag, dishcloth', 534: 'dishwasher, dish washer, dishwashing machine', 535: 'disk brake, disc brake', 536: 'dock, dockage, docking facility', 537: 'dogsled, dog sled, dog sleigh', 538: 'dome', 539: 'doormat, welcome mat', 540: 'drilling platform, offshore rig', 541: 'drum, membranophone, tympan', 542: 'drumstick', 543: 'dumbbell', 544: 'Dutch oven', 545: 'electric fan, blower', 546: 'electric guitar', 547: 'electric locomotive', 548: 'entertainment center', 549: 'envelope', 550: 'espresso maker', 551: 'face powder', 552: 'feather boa, boa', 553: 'file, file cabinet, filing cabinet', 554: 'fireboat', 555: 'fire engine, fire truck', 556: 'fire screen, fireguard', 557: 'flagpole, flagstaff', 558: 'flute, transverse flute', 559: 'folding chair', 560: 'football helmet', 561: 'forklift', 562: 'fountain', 563: 'fountain pen', 564: 'four-poster', 565: 'freight car', 566: 'French horn, horn', 567: 'frying pan, frypan, skillet', 568: 'fur coat', 569: 'garbage truck, dustcart', 570: 'gasmask, respirator, gas helmet', 571: 'gas pump, gasoline pump, petrol pump, island dispenser', 572: 'goblet', 573: 'go-kart', 574: 'golf ball', 575: 'golfcart, golf cart', 576: 'gondola', 577: 'gong, tam-tam', 578: 'gown', 579: 'grand piano, grand', 580: 'greenhouse, nursery, glasshouse', 581: 'grille, radiator grille', 582: 'grocery store, grocery, food market, market', 583: 'guillotine', 584: 'hair slide', 585: 'hair spray', 586: 'half track', 587: 'hammer', 588: 'hamper', 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier', 590: 'hand-held computer, hand-held microcomputer', 591: 'handkerchief, hankie, hanky, hankey', 592: 'hard disc, hard disk, fixed disk', 593: 'harmonica, mouth organ, harp, mouth harp', 594: 'harp', 595: 'harvester, reaper', 596: 'hatchet', 597: 'holster', 598: 'home theater, home theatre', 599: 'honeycomb', 600: 'hook, claw', 601: 'hoopskirt, crinoline', 602: 'horizontal bar, high bar', 603: 'horse cart, horse-cart', 604: 'hourglass', 605: 'iPod', 606: 'iron, smoothing iron', 607: "jack-o'-lantern", 608: 'jean, blue jean, denim', 609: 'jeep, landrover', 610: 'jersey, T-shirt, tee shirt', 611: 'jigsaw puzzle', 612: 'jinrikisha, ricksha, rickshaw', 613: 'joystick', 614: 'kimono', 615: 'knee pad', 616: 'knot', 617: 'lab coat, laboratory coat', 618: 'ladle', 619: 'lampshade, lamp shade', 620: 'laptop, laptop computer', 621: 'lawn mower, mower', 622: 'lens cap, lens cover', 623: 'letter opener, paper knife, paperknife', 624: 'library', 625: 'lifeboat', 626: 'lighter, light, igniter, ignitor', 627: 'limousine, limo', 628: 'liner, ocean liner', 629: 'lipstick, lip rouge', 630: 'Loafer', 631: 'lotion', 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', 633: "loupe, jeweler's loupe", 634: 'lumbermill, sawmill', 635: 'magnetic compass', 636: 'mailbag, postbag', 637: 'mailbox, letter box', 638: 'maillot', 639: 'maillot, tank suit', 640: 'manhole cover', 641: 'maraca', 642: 'marimba, xylophone', 643: 'mask', 644: 'matchstick', 645: 'maypole', 646: 'maze, labyrinth', 647: 'measuring cup', 648: 'medicine chest, medicine cabinet', 649: 'megalith, megalithic structure', 650: 'microphone, mike', 651: 'microwave, microwave oven', 652: 'military uniform', 653: 'milk can', 654: 'minibus', 655: 'miniskirt, mini', 656: 'minivan', 657: 'missile', 658: 'mitten', 659: 'mixing bowl', 660: 'mobile home, manufactured home', 661: 'Model T', 662: 'modem', 663: 'monastery', 664: 'monitor', 665: 'moped', 666: 'mortar', 667: 'mortarboard', 668: 'mosque', 669: 'mosquito net', 670: 'motor scooter, scooter', 671: 'mountain bike, all-terrain bike, off-roader', 672: 'mountain tent', 673: 'mouse, computer mouse', 674: 'mousetrap', 675: 'moving van', 676: 'muzzle', 677: 'nail', 678: 'neck brace', 679: 'necklace', 680: 'nipple', 681: 'notebook, notebook computer', 682: 'obelisk', 683: 'oboe, hautboy, hautbois', 684: 'ocarina, sweet potato', 685: 'odometer, hodometer, mileometer, milometer', 686: 'oil filter', 687: 'organ, pipe organ', 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO', 689: 'overskirt', 690: 'oxcart', 691: 'oxygen mask', 692: 'packet', 693: 'paddle, boat paddle', 694: 'paddlewheel, paddle wheel', 695: 'padlock', 696: 'paintbrush', 697: "pajama, pyjama, pj's, jammies", 698: 'palace', 699: 'panpipe, pandean pipe, syrinx', 700: 'paper towel', 701: 'parachute, chute', 702: 'parallel bars, bars', 703: 'park bench', 704: 'parking meter', 705: 'passenger car, coach, carriage', 706: 'patio, terrace', 707: 'pay-phone, pay-station', 708: 'pedestal, plinth, footstall', 709: 'pencil box, pencil case', 710: 'pencil sharpener', 711: 'perfume, essence', 712: 'Petri dish', 713: 'photocopier', 714: 'pick, plectrum, plectron', 715: 'pickelhaube', 716: 'picket fence, paling', 717: 'pickup, pickup truck', 718: 'pier', 719: 'piggy bank, penny bank', 720: 'pill bottle', 721: 'pillow', 722: 'ping-pong ball', 723: 'pinwheel', 724: 'pirate, pirate ship', 725: 'pitcher, ewer', 726: "plane, carpenter's plane, woodworking plane", 727: 'planetarium', 728: 'plastic bag', 729: 'plate rack', 730: 'plow, plough', 731: "plunger, plumber's helper", 732: 'Polaroid camera, Polaroid Land camera', 733: 'pole', 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', 735: 'poncho', 736: 'pool table, billiard table, snooker table', 737: 'pop bottle, soda bottle', 738: 'pot, flowerpot', 739: "potter's wheel", 740: 'power drill', 741: 'prayer rug, prayer mat', 742: 'printer', 743: 'prison, prison house', 744: 'projectile, missile', 745: 'projector', 746: 'puck, hockey puck', 747: 'punching bag, punch bag, punching ball, punchball', 748: 'purse', 749: 'quill, quill pen', 750: 'quilt, comforter, comfort, puff', 751: 'racer, race car, racing car', 752: 'racket, racquet', 753: 'radiator', 754: 'radio, wireless', 755: 'radio telescope, radio reflector', 756: 'rain barrel', 757: 'recreational vehicle, RV, R.V.', 758: 'reel', 759: 'reflex camera', 760: 'refrigerator, icebox', 761: 'remote control, remote', 762: 'restaurant, eating house, eating place, eatery', 763: 'revolver, six-gun, six-shooter', 764: 'rifle', 765: 'rocking chair, rocker', 766: 'rotisserie', 767: 'rubber eraser, rubber, pencil eraser', 768: 'rugby ball', 769: 'rule, ruler', 770: 'running shoe', 771: 'safe', 772: 'safety pin', 773: 'saltshaker, salt shaker', 774: 'sandal', 775: 'sarong', 776: 'sax, saxophone', 777: 'scabbard', 778: 'scale, weighing machine', 779: 'school bus', 780: 'schooner', 781: 'scoreboard', 782: 'screen, CRT screen', 783: 'screw', 784: 'screwdriver', 785: 'seat belt, seatbelt', 786: 'sewing machine', 787: 'shield, buckler', 788: 'shoe shop, shoe-shop, shoe store', 789: 'shoji', 790: 'shopping basket', 791: 'shopping cart', 792: 'shovel', 793: 'shower cap', 794: 'shower curtain', 795: 'ski', 796: 'ski mask', 797: 'sleeping bag', 798: 'slide rule, slipstick', 799: 'sliding door', 800: 'slot, one-armed bandit', 801: 'snorkel', 802: 'snowmobile', 803: 'snowplow, snowplough', 804: 'soap dispenser', 805: 'soccer ball', 806: 'sock', 807: 'solar dish, solar collector, solar furnace', 808: 'sombrero', 809: 'soup bowl', 810: 'space bar', 811: 'space heater', 812: 'space shuttle', 813: 'spatula', 814: 'speedboat', 815: "spider web, spider's web", 816: 'spindle', 817: 'sports car, sport car', 818: 'spotlight, spot', 819: 'stage', 820: 'steam locomotive', 821: 'steel arch bridge', 822: 'steel drum', 823: 'stethoscope', 824: 'stole', 825: 'stone wall', 826: 'stopwatch, stop watch', 827: 'stove', 828: 'strainer', 829: 'streetcar, tram, tramcar, trolley, trolley car', 830: 'stretcher', 831: 'studio couch, day bed', 832: 'stupa, tope', 833: 'submarine, pigboat, sub, U-boat', 834: 'suit, suit of clothes', 835: 'sundial', 836: 'sunglass', 837: 'sunglasses, dark glasses, shades', 838: 'sunscreen, sunblock, sun blocker', 839: 'suspension bridge', 840: 'swab, swob, mop', 841: 'sweatshirt', 842: 'swimming trunks, bathing trunks', 843: 'swing', 844: 'switch, electric switch, electrical switch', 845: 'syringe', 846: 'table lamp', 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle', 848: 'tape player', 849: 'teapot', 850: 'teddy, teddy bear', 851: 'television, television system', 852: 'tennis ball', 853: 'thatch, thatched roof', 854: 'theater curtain, theatre curtain', 855: 'thimble', 856: 'thresher, thrasher, threshing machine', 857: 'throne', 858: 'tile roof', 859: 'toaster', 860: 'tobacco shop, tobacconist shop, tobacconist', 861: 'toilet seat', 862: 'torch', 863: 'totem pole', 864: 'tow truck, tow car, wrecker', 865: 'toyshop', 866: 'tractor', 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', 868: 'tray', 869: 'trench coat', 870: 'tricycle, trike, velocipede', 871: 'trimaran', 872: 'tripod', 873: 'triumphal arch', 874: 'trolleybus, trolley coach, trackless trolley', 875: 'trombone', 876: 'tub, vat', 877: 'turnstile', 878: 'typewriter keyboard', 879: 'umbrella', 880: 'unicycle, monocycle', 881: 'upright, upright piano', 882: 'vacuum, vacuum cleaner', 883: 'vase', 884: 'vault', 885: 'velvet', 886: 'vending machine', 887: 'vestment', 888: 'viaduct', 889: 'violin, fiddle', 890: 'volleyball', 891: 'waffle iron', 892: 'wall clock', 893: 'wallet, billfold, notecase, pocketbook', 894: 'wardrobe, closet, press', 895: 'warplane, military plane', 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', 897: 'washer, automatic washer, washing machine', 898: 'water bottle', 899: 'water jug', 900: 'water tower', 901: 'whiskey jug', 902: 'whistle', 903: 'wig', 904: 'window screen', 905: 'window shade', 906: 'Windsor tie', 907: 'wine bottle', 908: 'wing', 909: 'wok', 910: 'wooden spoon', 911: 'wool, woolen, woollen', 912: 'worm fence, snake fence, snake-rail fence, Virginia fence', 913: 'wreck', 914: 'yawl', 915: 'yurt', 916: 'web site, website, internet site, site', 917: 'comic book', 918: 'crossword puzzle, crossword', 919: 'street sign', 920: 'traffic light, traffic signal, stoplight', 921: 'book jacket, dust cover, dust jacket, dust wrapper', 922: 'menu', 923: 'plate', 924: 'guacamole', 925: 'consomme', 926: 'hot pot, hotpot', 927: 'trifle', 928: 'ice cream, icecream', 929: 'ice lolly, lolly, lollipop, popsicle', 930: 'French loaf', 931: 'bagel, beigel', 932: 'pretzel', 933: 'cheeseburger', 934: 'hotdog, hot dog, red hot', 935: 'mashed potato', 936: 'head cabbage', 937: 'broccoli', 938: 'cauliflower', 939: 'zucchini, courgette', 940: 'spaghetti squash', 941: 'acorn squash', 942: 'butternut squash', 943: 'cucumber, cuke', 944: 'artichoke, globe artichoke', 945: 'bell pepper', 946: 'cardoon', 947: 'mushroom', 948: 'Granny Smith', 949: 'strawberry', 950: 'orange', 951: 'lemon', 952: 'fig', 953: 'pineapple, ananas', 954: 'banana', 955: 'jackfruit, jak, jack', 956: 'custard apple', 957: 'pomegranate', 958: 'hay', 959: 'carbonara', 960: 'chocolate sauce, chocolate syrup', 961: 'dough', 962: 'meat loaf, meatloaf', 963: 'pizza, pizza pie', 964: 'potpie', 965: 'burrito', 966: 'red wine', 967: 'espresso', 968: 'cup', 969: 'eggnog', 970: 'alp', 971: 'bubble', 972: 'cliff, drop, drop-off', 973: 'coral reef', 974: 'geyser', 975: 'lakeside, lakeshore', 976: 'promontory, headland, head, foreland', 977: 'sandbar, sand bar', 978: 'seashore, coast, seacoast, sea-coast', 979: 'valley, vale', 980: 'volcano', 981: 'ballplayer, baseball player', 982: 'groom, bridegroom', 983: 'scuba diver', 984: 'rapeseed', 985: 'daisy', 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", 987: 'corn', 988: 'acorn', 989: 'hip, rose hip, rosehip', 990: 'buckeye, horse chestnut, conker', 991: 'coral fungus', 992: 'agaric', 993: 'gyromitra', 994: 'stinkhorn, carrion fungus', 995: 'earthstar', 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', 997: 'bolete', 998: 'ear, spike, capitulum', 999: 'toilet tissue, toilet paper, bathroom tissue'}
Egyptian cat

经过转码后得json结构是一个key=pixel_values的像素数组,维度是:[批次,通道数,宽度,高度]。
通过model.config.id2label可以看到总共1000个label。

数据集

food101包含多种食物类别,数据集地址:https://huggingface.co/datasets/ethz/food101。

from datasets import load_dataset
ds = load_dataset("food101")
print("数据集",ds)
#获取训练集数据
ds = load_dataset("food101",split="train")
print("训练集",ds)
print("第一个数据集",ds[0])
#获取所有label
labels = ds.features["label"].names
print(labels)
print(len(labels))

输出

数据集 DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 75750
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 25250
    })
})
训练集 Dataset({
    features: ['image', 'label'],
    num_rows: 75750
})
第一个数据集 {'image': <PIL.Image.Image image mode=RGB size=384x512 at 0x7A1FCE415750>, 'label': 6}
['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']
101

总共101个food总类。
显示第二张图片和label

import matplotlib.pyplot as plt
plt.imshow(ds[1]["image"])
plt.axis('off')  # 关闭坐标轴
plt.show()
print(labels[ds[1]["label"]])

显示
在这里插入图片描述
这里我们看到第二个数据集的label=6,也就是beignets。

我们需要生成label和id关系的字典。

labels = ds["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

print(id2label)
print(label2id)

输出:

{0: 'apple_pie', 1: 'baby_back_ribs', 2: 'baklava', 3: 'beef_carpaccio', 4: 'beef_tartare', 5: 'beet_salad', 6: 'beignets', 7: 'bibimbap', 8: 'bread_pudding', 9: 'breakfast_burrito', 10: 'bruschetta', 11: 'caesar_salad', 12: 'cannoli', 13: 'caprese_salad', 14: 'carrot_cake', 15: 'ceviche', 16: 'cheesecake', 17: 'cheese_plate', 18: 'chicken_curry', 19: 'chicken_quesadilla', 20: 'chicken_wings', 21: 'chocolate_cake', 22: 'chocolate_mousse', 23: 'churros', 24: 'clam_chowder', 25: 'club_sandwich', 26: 'crab_cakes', 27: 'creme_brulee', 28: 'croque_madame', 29: 'cup_cakes', 30: 'deviled_eggs', 31: 'donuts', 32: 'dumplings', 33: 'edamame', 34: 'eggs_benedict', 35: 'escargots', 36: 'falafel', 37: 'filet_mignon', 38: 'fish_and_chips', 39: 'foie_gras', 40: 'french_fries', 41: 'french_onion_soup', 42: 'french_toast', 43: 'fried_calamari', 44: 'fried_rice', 45: 'frozen_yogurt', 46: 'garlic_bread', 47: 'gnocchi', 48: 'greek_salad', 49: 'grilled_cheese_sandwich', 50: 'grilled_salmon', 51: 'guacamole', 52: 'gyoza', 53: 'hamburger', 54: 'hot_and_sour_soup', 55: 'hot_dog', 56: 'huevos_rancheros', 57: 'hummus', 58: 'ice_cream', 59: 'lasagna', 60: 'lobster_bisque', 61: 'lobster_roll_sandwich', 62: 'macaroni_and_cheese', 63: 'macarons', 64: 'miso_soup', 65: 'mussels', 66: 'nachos', 67: 'omelette', 68: 'onion_rings', 69: 'oysters', 70: 'pad_thai', 71: 'paella', 72: 'pancakes', 73: 'panna_cotta', 74: 'peking_duck', 75: 'pho', 76: 'pizza', 77: 'pork_chop', 78: 'poutine', 79: 'prime_rib', 80: 'pulled_pork_sandwich', 81: 'ramen', 82: 'ravioli', 83: 'red_velvet_cake', 84: 'risotto', 85: 'samosa', 86: 'sashimi', 87: 'scallops', 88: 'seaweed_salad', 89: 'shrimp_and_grits', 90: 'spaghetti_bolognese', 91: 'spaghetti_carbonara', 92: 'spring_rolls', 93: 'steak', 94: 'strawberry_shortcake', 95: 'sushi', 96: 'tacos', 97: 'takoyaki', 98: 'tiramisu', 99: 'tuna_tartare', 100: 'waffles'}
{'apple_pie': 0, 'baby_back_ribs': 1, 'baklava': 2, 'beef_carpaccio': 3, 'beef_tartare': 4, 'beet_salad': 5, 'beignets': 6, 'bibimbap': 7, 'bread_pudding': 8, 'breakfast_burrito': 9, 'bruschetta': 10, 'caesar_salad': 11, 'cannoli': 12, 'caprese_salad': 13, 'carrot_cake': 14, 'ceviche': 15, 'cheesecake': 16, 'cheese_plate': 17, 'chicken_curry': 18, 'chicken_quesadilla': 19, 'chicken_wings': 20, 'chocolate_cake': 21, 'chocolate_mousse': 22, 'churros': 23, 'clam_chowder': 24, 'club_sandwich': 25, 'crab_cakes': 26, 'creme_brulee': 27, 'croque_madame': 28, 'cup_cakes': 29, 'deviled_eggs': 30, 'donuts': 31, 'dumplings': 32, 'edamame': 33, 'eggs_benedict': 34, 'escargots': 35, 'falafel': 36, 'filet_mignon': 37, 'fish_and_chips': 38, 'foie_gras': 39, 'french_fries': 40, 'french_onion_soup': 41, 'french_toast': 42, 'fried_calamari': 43, 'fried_rice': 44, 'frozen_yogurt': 45, 'garlic_bread': 46, 'gnocchi': 47, 'greek_salad': 48, 'grilled_cheese_sandwich': 49, 'grilled_salmon': 50, 'guacamole': 51, 'gyoza': 52, 'hamburger': 53, 'hot_and_sour_soup': 54, 'hot_dog': 55, 'huevos_rancheros': 56, 'hummus': 57, 'ice_cream': 58, 'lasagna': 59, 'lobster_bisque': 60, 'lobster_roll_sandwich': 61, 'macaroni_and_cheese': 62, 'macarons': 63, 'miso_soup': 64, 'mussels': 65, 'nachos': 66, 'omelette': 67, 'onion_rings': 68, 'oysters': 69, 'pad_thai': 70, 'paella': 71, 'pancakes': 72, 'panna_cotta': 73, 'peking_duck': 74, 'pho': 75, 'pizza': 76, 'pork_chop': 77, 'poutine': 78, 'prime_rib': 79, 'pulled_pork_sandwich': 80, 'ramen': 81, 'ravioli': 82, 'red_velvet_cake': 83, 'risotto': 84, 'samosa': 85, 'sashimi': 86, 'scallops': 87, 'seaweed_salad': 88, 'shrimp_and_grits': 89, 'spaghetti_bolognese': 90, 'spaghetti_carbonara': 91, 'spring_rolls': 92, 'steak': 93, 'strawberry_shortcake': 94, 'sushi': 95, 'tacos': 96, 'takoyaki': 97, 'tiramisu': 98, 'tuna_tartare': 99, 'waffles': 100}

加载一个图像处理器,以正确调整大小并对训练和评估图像的像素值进行归一化。

from transformers import AutoImageProcessor
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

您还可以使用图像处理器来准备一些转换函数,用于数据增强和像素缩放。

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
train_transforms = Compose(
    [
        RandomResizedCrop(image_processor.size["height"]),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

val_transforms = Compose(
    [
        Resize(image_processor.size["height"]),
        CenterCrop(image_processor.size["height"]),
        ToTensor(),
        normalize,
    ]
)

def preprocess_train(example_batch):
    example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

def preprocess_val(example_batch):
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

定义训练和验证数据集,并使用set_transform函数在运行时应用转换。

train_ds = ds["train"]
val_ds = ds["validation"]

train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

最后,您需要一个数据整理器来创建训练和评估数据的批次,并将标签转换为torch.tensor对象。

import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

模型

现在让我们加载一个预训练模型作为基础模型。本指南使用了google/vit-base-patch16-224-in21k模型,但您可以使用任何您想要的图像分类模型。将label2id和id2label字典传递给模型,以便它知道如何将整数标签映射到它们的类标签,并且如果您正在微调已经微调过的检查点,可以选择传递ignore_mismatched_sizes=True参数。

from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

PEFT configuration and model

每个 PEFT 方法都需要一个配置,其中包含了指定 PEFT 方法应该如何应用的所有参数。一旦配置设置好了,就将其传递给 get_peft_model() 函数,同时还要传递基础模型,以创建一个可训练的 PeftModel。

调用 print_trainable_parameters() 方法来比较 PeftModel 的参数数量与基础模型的参数数量!

LoRA将权重更新矩阵分解为两个较小的矩阵。这些低秩矩阵的大小由其秩或r确定。更高的秩意味着模型有更多的参数需要训练,但也意味着模型有更大的学习能力。您还需要指定 target_modules,确定较小矩阵插入的位置。对于本指南,您将针对注意力块的查询和值矩阵进行目标指定。设置的其他重要参数包括 lora_alpha(缩放因子)、bias(是否应该训练none、all或只有 LoRA 偏置参数)、modules_to_save(除了 LoRA 层之外需要训练和保存的模块)。所有这些参数 - 以及更多 - 都可以在 LoraConfig 中找到。

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

输出:“trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.7712775047664294”

在LoRA中,为了简化和精简,可能只针对查询和值矩阵进行权重分解,而不对键矩阵进行处理。这样可以在一定程度上减少计算量和参数数量,同时仍然提高模型的学习能力和灵活性。

训练

对于训练,让我们使用Transformers中的Trainer类。Trainer类包含一个PyTorch训练循环,当您准备好时,调用train开始训练。要自定义训练运行,请在TrainingArguments类中配置训练超参数。对于类似LoRA的方法,您可以承受更高的批量大小和学习率。

batch_size = 128

args = TrainingArguments(
    #peft_model_id,
    output_dir="/kaggle/working",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-3,
    report_to="none",
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    num_train_epochs=5,
    logging_steps=10,
    load_best_model_at_end=True,
    label_names=["labels"],
)

这里是对TrainingArguments中参数的解释:

  • output_dir: 指定训练过程中输出模型和日志的目录。
  • remove_unused_columns: 控制是否在训练过程中删除未使用的列。
  • evaluation_strategy: 指定评估策略,这里设置为“epoch”表示在每个epoch结束时进行评估。
  • save_strategy: 指定模型保存策略,这里设置为“epoch”表示在每个epoch结束时保存模型。
  • learning_rate: 学习率设置为5e-3,即0.005。
  • report_to: 控制训练过程中的报告输出,这里设置为“none”表示不输出任何报告。
  • per_device_train_batch_size: 每个设备上的训练批量大小。
  • gradient_accumulation_steps: 梯度累积步数。
  • per_device_eval_batch_size: 每个设备上的评估批量大小。
  • fp16: 控制是否使用混合精度训练。
  • num_train_epochs: 训练的总epoch数。
  • logging_steps: 控制日志输出的步数。
  • load_best_model_at_end: 在训练结束时是否加载最佳模型。
  • label_names: 标签的名称列表。

这些参数是用来配置训练过程的,例如指定训练和评估的批量大小、学习率、训练时长等等。
开始训练

trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
    data_collator=collate_fn,
)
trainer.train()

使用kaggle的免费gpu T4*2(双倍时间消耗累计),gpu基本100%,gpu是一周30hhrs免费时间的,我为了节省时间,用2epoch,batch_size=128,因为kaggle的session有效期在12hours内,越快越好,否则session断开就白训练了,简单看下效果,大概1个小时左右,也可以save session让他在后台跑。
在这里插入图片描述
看下速度
在这里插入图片描述
第一次epoch完成
在这里插入图片描述
查看输出
在这里插入图片描述
点击后面的三个点下载所有的文件,然后将模型下载下来,点击输入的上传-new model
在这里插入图片描述
输入model名称,选择私有,点击create model
在这里插入图片描述
输入平台:transformer,点击addnewvariation
在这里插入图片描述
定义附件名称,选择协议,点击右下侧的create
在这里插入图片描述
关闭后点击return to notebook,就可以看到输入的模型了,点击右侧的复制路径即可。
在这里插入图片描述

这里input的是持久的不会丢失,output数据再页面关闭session关闭后就丢失,所以尽快保存下来,或者上传到huggingface。

预测

切换到p100(按分钟算,省钱)验证
在这里插入代码片
代码

model_name="/kaggle/input/image-classifity/transformers/version1/1/checkpoint-74" #复制输入的路径
from peft import PeftConfig, PeftModel
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests,torch
from datasets import load_dataset

ds = load_dataset("food101")
labels = ds["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label
    
config = PeftConfig.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(
    config.base_model_name_or_path,#google/vit-base-patch16-224-in21k
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)
model = PeftModel.from_pretrained(model, model_name)

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/beignets.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
print(image)
image_processor = AutoImageProcessor.from_pretrained(config.base_model_name_or_path)
encoding = image_processor(image.convert("RGB"), return_tensors="pt")
with torch.no_grad():
    outputs = model(**encoding)
    logits = outputs.logits

predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

输出:beignets
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/677754.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

开窗函数!

开窗函数&#xff08;Window Function&#xff09;是SQL中的一种高级功能&#xff0c;允许你在一组相关行&#xff08;一个“窗口”&#xff09;上执行聚合操作&#xff0c;而不像传统聚合函数&#xff08;如SUM(), AVG(), COUNT()&#xff09;那样将所有匹配行合并成单个汇总行…

这就是英伟达 CEO 黄仁勋所说的人工智能“下一波浪潮”|TodayAI

在台湾一年一度的科技展 COMPUTEX 开幕前的周日&#xff0c;英伟达&#xff08;Nvidia&#xff09;首席执行官黄仁勋&#xff08;Jensen Huang&#xff09;表示&#xff0c;机器人和“理解物理定律的 AI”将成为下一波技术浪潮。他指出&#xff0c;英伟达目前正在推动生成式人工…

[Redis]Zset类型

Zset有序集合相对于字符串、列表、哈希、集合来说会有一些陌生。 它保留了集合不能有重复成员的特点&#xff0c;但与集合不同的是&#xff0c;有序集合中的每个元素都有一个唯一的浮点类型的分数&#xff08;score&#xff09;与之关联&#xff0c;着使得有序集合中的元素是可…

c语言:自定义类型(枚举、联合体)

前言&#xff1a; c语言中中自定义类型不仅有结构体&#xff0c;还有枚举、联合体等类型&#xff0c;上一期我们详细讲解了结构体的初始化&#xff0c;使用&#xff0c;传参和内存对齐等知识&#xff0c;这一期我们来介绍c语言中的其他自定义类型枚举和联合体的知识。 1.位段 …

jupyter notebook anaconda环境下查看|加载|更换内核

文章目录 1 问题复现2 查看内核位置3 调整python解释器位置 1 问题复现 在conda虚拟环境中使用pip安装相应package&#xff0c; 但是在jupyter notebook中加载该package时报错 [ERROR]ModuleNotFoundError: No module named shap 此时&#xff0c;除去包安装出现问题以外&am…

ARM学习(28)NXP 双coreMCU IMX1160学习

笔者最近接触到一块IMXRT1160的双core板子&#xff0c;特依次来记录学习一下 1、IMXRT1160 板子介绍 介绍一下NXP的Demo板子&#xff0c;是一个双core的板子&#xff0c;Cortex-M7和Cortex-M4&#xff0c;总计1MB的RAM空间&#xff0c;256KB的ROM空间&#xff0c;提供了丰富的…

哥斯拉、冰蝎、中国蚁剑在护网中流量特征分析,收藏起来当资料吧,24年护网用得上

护网哥斯拉、冰蝎、中国蚁剑流量分析 【点击免费领取】CSDN大礼包&#xff1a;《黑客&网络安全入门&进阶学习资源包》&#x1f517;包含了应急响应工具、入侵排查、日志分析、权限维持、Windows应急实战、Linux应急实战、Web应急实战。 护网中最担心的是木马已经到了服…

CV每日论文--2024.6.4

1、Mixed Diffusion for 3D Indoor Scene Synthesis 中文 标题&#xff1a;用于 3D 室内场景合成的混合扩散 简介&#xff1a;这篇论文提出了一种名为MiDiffusion的混合离散-连续扩散模型,用于从给定的房间类型、平面图和可能存在的物体中合成逼真的3D室内场景。 作者指出,该…

芯片验证分享5 —— 激励开发3

大家好&#xff0c;我是谷公子&#xff0c;上节课跟大家分享了黑盒技术中的等价类分析和边界值分析方法。我们这次来分享下黑盒设计中的其它技术。 边界值分析和等价类划分的一个弱点是没有对输入条件的组合进行分析。对输入组合进行验证并不是简单的事情&#xff0c;因为即使…

Linux 36.3 + JetPack v6.0@jetson-inference之语义分割

Linux 36.3 JetPack v6.0jetson-inference之语义分割 1. 源由2. segNet2.1 命令选项2.2 下载模型2.2.1 Cityscapes2.2.2 DeepScene2.2.3 MHP2.2.4 VOC2.2.5 SUN 2.3 操作示例2.3.1 单张照片2.3.2 多张照片2.3.3 视频 3. 代码3.1 Python3.2 C 4. 参考资料 1. 源由 分类和目标识…

指针的认识(野指针、规避野指针、assert宏断言)

目录 a.野指针成因 1.指针未初始化 2.指针越界访问 3.指针指向的空间释放 b.规避野指针 1.指针初始化 2.小心指针越界 3.指针变量不再使用时&#xff0c;及时置NULL&#xff0c;指针使用之前检查有效性 4.避免返回局部变量的地址 c.assert宏断言的使用 概念&#xff1…

【Kubernetes】k8s的调度约束(亲和与反亲和)

一、调度约束 list-watch 组件 Kubernetes 是通过 List-Watch 的机制进行每个组件的协作&#xff0c;保持数据同步的&#xff0c;每个组件之间的设计实现了解耦。 用户是通过 kubectl 根据配置文件&#xff0c;向 APIServer 发送命令&#xff0c;在 Node 节点上面建立 Pod 和…

Qt/C++音视频开发76-获取本地有哪些摄像头名称/ffmpeg内置函数方式

一、前言 上一篇文章是写的用Qt的内置函数方式获取本地摄像头名称集合&#xff0c;但是有几个缺点&#xff0c;比如要求Qt5&#xff0c;或者至少要求安装了多媒体组件multimedia&#xff0c;如果没有安装呢&#xff0c;或者安装的是个空的呢&#xff0c;比如很多嵌入式板子&am…

js

<!DOCTYPE html> <html><head><meta charset"utf-8"><title>Document</title></head><body><input type"button" values"点击" onclick"alert(hahaha)"><script>alert(&…

【开源三方库】Fuse.js:强大、轻巧、零依赖的模糊搜索库

1.简介 Fuse.js是一款功能强大且轻量级的JavaScript模糊搜索库&#xff0c;支持OpenAtom OpenHarmony&#xff08;以下简称“OpenHarmony”&#xff09;操作系统&#xff0c;它具备模糊搜索和排序等功能。该库高性能、易于使用、高度可配置&#xff0c;支持多种数据类型和多语…

界面控件DevExpress WinForms的流程图组件 - 可完美复制Visio功能(二)

DevExpress WinForms的Diagram&#xff08;流程图&#xff09;组件允许您复制Microsoft Visio中的许多功能&#xff0c;并能在下一个Windows Forms项目中引入信息丰富的图表、流程图和组织图。 P.S&#xff1a;DevExpress WinForms拥有180组件和UI库&#xff0c;能为Windows F…

Solidworks 提取模型中的零件,并组合成一个新的零件,放入特征库

对方发来一个STP文件&#xff0c;其中有模型的部分零件想为我所用。 Shift键鼠标左键 选取需要的零件 在选好零件上右键&#xff0c;选择“孤立” 左边找到部件&#xff0c;ctrl左键选中&#xff0c;选择“插入到新零件” 点 绿色 勾 就选择保存类型&#xff0c;完成 。 打开这…

【技术】工业机器人机械臂安装高速电主轴打磨去毛刺

随着现代工业的发展&#xff0c;机械加工在制造业中扮演着至关重要的角色。然而&#xff0c;机械加工后的零件普遍存在着毛刺问题。这些毛刺不仅影响了零件的外观&#xff0c;更对工序的定位、产品的装配以及性能产生了不良影响&#xff0c;甚至可能导致机械设备损坏等严重事故…

计算机专业本科就业还是考研?考研有哪些热门方向?

考研并不是一个逃避就业的避难所&#xff0c;也不是一个简单的提升待遇的手段。考研是提升自我的途径&#xff0c;特别是对于那些对特定技术领域有浓厚兴趣并愿意深入研究的人来说 一个本科生能够认真学三年&#xff0c;那么他们所掌握的技能和知识不应该逊色于那些通过短期培…

作为表达式调用时,无法解析类修饰器的签名。vue3+ts+vite,使用装饰器时报错

作为表达式调用时&#xff0c;无法解析类修饰器的签名。 The runtime will invoke the decorator with 2 arguments, but the decorator expects 1.ts(1238) 页面也无法打开 解决方案&#xff1a; {"extends": "vue/tsconfig/tsconfig.dom.json","in…