前言:
迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。
这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程
一 简介
迁移学习可以通过以下几种方式实现:
1.1 基于预训练模型的迁移:
将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。
1.2 网络结构迁移:
将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。
1.3 特征迁移:
将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。
word2vec
1.4 参数迁移:
将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。
本文主要例子用的是 参数迁移
二 Flatten
作用:
输入的向量x [batch, c, w, h]=>[batch, c*w*h]
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023
@author: chengxf2
"""
import torch
from torch import optim,nn
class Flatten(nn.Module):
def __init__(self):
super(Flatten,self).__init__()
def forward(self, x):
a = torch.tensor(x.shape[1:])
#dim 中 input 张量的每一行的乘积。
shape = torch.prod(a).item()
#print("\n ---new shape--- ",shape)
return x.view(-1,shape)
三 迁移学习
torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。
from torchvision.models import resnet152
from torchvision.models import resnet18
注意:
现有分类器分类的类型 > = 新分类器类型,再做transfer.
才能取得好的效果.
分类器 | 分类类型 |
已有分类器 | [猫,狗,鸡,鸭】 |
新分类器 | [猫,狗] |
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023
@author: chengxf2
"""
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023
@author: chengxf2
"""
import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18
from util import Flatten
batchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)
root ='pokemon'
resize =224
csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)
train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()
def evalute(model, loader):
total =len(loader.dataset)
correct =0
for x,y in loader:
x = x.to(device)
y = y.to(device)
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
acc = correct/total
return acc
def main():
trained_model = resnet152(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],
Flatten(),
nn.Linear(in_features=2048, out_features=5))
optimizer = optim.Adam(model.parameters(),lr =lr)
criteon = nn.CrossEntropyLoss()
best_epoch=0,
best_acc=0
viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))
viz.line([0],[-1],win='val_loss', opts =dict(title='val_acc'))
global_step =0
for epoch in range(epochs):
print("\n --main---: ",epoch)
for step, (x,y) in enumerate(train_loader):
#x:[b,3,224,224] y:[b]
x = x.to(device)
y = y.to(device)
#print("\n --x---: ",x.shape)
logits =model(x)
loss = criteon(logits, y)
#print("\n --loss---: ",loss.shape)
optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
global_step +=1
if epoch %2 ==0:
val_acc = evalute(model, val_loader)
if val_acc>best_acc:
best_acc = val_acc
best_epoch =epoch
torch.save(model.state_dict(),'best.mdl')
print("\n val_acc ",val_acc)
viz.line([val_acc],[global_step],win='val_loss',update='append')
print('\n best acc',best_acc, "best_epoch: ",best_epoch)
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt')
test_acc = evalute(model, test_loader)
print('\n test acc',test_acc)
if __name__ == "__main__":
main()
参考:
https://blog.csdn.net/qq_44089890/article/details/130460700
课时107 迁移学习实战_哔哩哔哩_bilibili