pytorch实战-图像分类(二)(模型训练及验证)(基于迁移学习(理解+代码))

目录

1.迁移学习概念

2.数据预处理

 3.训练模型(基于迁移学习)

3.1选择网络,这里用resnet

3.2如果用GPU训练,需要加入以下代码

3.3卷积层冻结模块

3.4加载resnet152模

3.5解释initialize_model函数

3.6迁移学习网络搭建

3.7优化器

3.8训练模块(可以理解为主函数)

3.9开始训练

3.10微调

4.测试模型

4.1加载训练好的模型

4.2测试数据预处理

4.3数据展示

4.4提取测试数据集

4.5计算提取数据集的预测结果

4.6展示预测结果

参考文献


1.迁移学习概念

先说一下深度学习常见的问题

        1.数据集不够,通常用数据增强解决。

        2.参数难以确定,训练时间长,这就需要用迁移学习来解决

什么叫迁移学习呢:比方说有一个对100w的自行车数据集,并用VGG模型训练好的网络,而此时你想训练一个1w自行车数据集(虽然对象一样,但采集的数据会不同),也用VGG模型进行训练,你发现,你们数据集的对象一样,选用的网络模型一样,此时在初始化自己模型权重(就是卷积层,池化层和全连接层的参数)时,可以用人家训练好的模型参数(如果不这样就需要随机初始化模型权重),这样做可以节省大量寻找最优参数的时间,又可以保证参数的准确。

总结:迁移学习就是用别人的东西训练自己的东西,但要注意,为了使用别人的模型参数,要保证自己的数据对象、网络结构、输入和输出数据的结构和别人相同。比方说,别人识别狗,你不能识别 猫,别人用VGG你不能用resnet,别人输入和输入图像大小是224×224.你不能是256×256。

进一步理解迁移学习的使用1:看下图最大的红框,表示卷积层,当用别人的模型时,对卷积层的两种处理方式。

        A:作为自己模型权重的初始化参数。

        B:冻结卷积层网络,意思是直接用别人的参数,不再更新。冻结卷积层网络又分几种情况。

                B1:当数据量小时,冻结第二大红框表示的卷积层,剩下卷积层进行更新。因为数据量小时,容易过拟合,直接用别人呢参数最好。

                B2:当数据量中等时冻结最小红框表示的卷积层,剩下的卷积层进行更行。

                B3:当数据量足够大时,不冻结卷积层,用A的方法,只作为自己模型权重的初始化参数。数据量大时,虽然对象一样,但毕竟数据不同,会有一定差异,更新参数是最优选择。

 进一步理解迁移学习的使用2:说完卷积层,在说一下全连接层,必须要注意不管卷积层选A还是B,全连接层都是要更新的,原因在于,别人模型进行图像分类可能是进行1000个分类,而你只进行100或者999个分类,那么全连接层的参数肯定是不同的。

2.数据预处理

上接该文:pytorch实战-图像分类(一)(数据预处理)

 3.训练模型(基于迁移学习)

3.1选择网络,这里用resnet

model_name = 'resnet'  #可选的比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
#是否用人家训练好的特征来做
feature_extract = True 

3.2如果用GPU训练,需要加入以下代码

# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')
    
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

3.3卷积层冻结模块

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

3.4加载resnet152模

注意:resnet152模型就是别人的模型。

model_ft = models.resnet152()
model_ft

3.5解释initialize_model函数

本小节只是截取pytorch官网的一个例子,用initialize_model说明在pytoch中迁移学习怎么使用,不属于本文代码

具体操作如下

        1.下载别人的模型参数,这里下载restnet152模型

        2.选择需要冻结的卷积层

        3.改变全连接层的输出个数,这里将1000改为102

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # 选择合适的模型,不同模型的初始化方法稍微有点区别
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ Resnet152
        """
        model_ft = models.resnet152(pretrained=use_pretrained) #下载resnet152模型
        set_parameter_requires_grad(model_ft, feature_extract) #选择冻结哪部分卷积层
        num_ftrs = model_ft.fc.in_features #全连接层的输入比方说全连接层是2048×1000,这就是2048.
        model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102),
                                   nn.LogSoftmax(dim=1)) #原resnet152的全连接层输出是1000,自己模型需要的输出是102,进行改动。
        input_size = 224
    return model_ft, input_size

3.6迁移学习网络搭建

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

#GPU计算
model_ft = model_ft.to(device)

# 模型保存
filename='checkpoint.pth'

# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

3.7优化器

就是用该方法更新模型参数

# 优化器设置
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
#最后一层已经LogSoftmax()了,所以不能nn.CrossEntropyLoss()来计算了,nn.CrossEntropyLoss()相当于logSoftmax()和nn.NLLLoss()整合
criterion = nn.NLLLoss()

3.8训练模块(可以理解为主函数)

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False,filename=filename):
    since = time.time() #
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    model.to(device)

    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 训练和验证
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # 训练
            else:
                model.eval()   # 验证

            running_loss = 0.0
            running_corrects = 0

            # 把数据都取个遍
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 清零
                optimizer.zero_grad()
                # 只有训练的时候计算和更新梯度
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:#resnet执行的是这里
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # 训练阶段更新权重
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 计算损失
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            
            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            

            # 得到最好那次的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                  'state_dict': model.state_dict(),
                  'best_acc': best_acc,
                  'optimizer' : optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step(epoch_loss)
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
        
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 训练完后用最好的一次当做模型最终的结果
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

3.9开始训练

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, is_inception=(model_name=="inception"))

3.10微调

在2.9中得到的模型,是冻结了卷积层,只训练了全连接层,所以此时希望在此基础上再对卷积层进行训练。

for param in model_ft.parameters():
    param.requires_grad = True

# 再继续训练所有的参数,学习率调小一点
optimizer = optim.Adam(params_to_update, lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# 损失函数
criterion = nn.NLLLoss()

# Load the checkpoint,加载自己的模型

checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
#model_ft.class_to_idx = checkpoint['mapping']

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=10, is_inception=(model_name=="inception"))

4.测试模型

4.1加载训练好的模型

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU模式
model_ft = model_ft.to(device)

# 保存文件的名字
filename='seriouscheckpoint.pth'

# 加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

4.2测试数据预处理

        1.测试数据处理方法需要跟训练时一直才可以

        2.crop操作的目的是保证输入的大小是一致的

        3.标准化操作也是必须的,用跟训练数据相同的mean和std,但是需要注意一点训练数据是在0-1上进行标准化,所以测试数据也需要先归一化

        4.PyTorch中颜色通道是第一个维度,跟很多工具包都不一样,需要转换

def process_image(image_path):
    # 读取测试数据
    img = Image.open(image_path)
    # Resize,thumbnail方法只能进行缩小,所以进行了判断
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))
    # Crop操作
    left_margin = (img.width-224)/2
    bottom_margin = (img.height-224)/2
    right_margin = left_margin + 224
    top_margin = bottom_margin + 224
    img = img.crop((left_margin, bottom_margin, right_margin,   
                      top_margin))
    # 相同的预处理方法
    img = np.array(img)/255
    mean = np.array([0.485, 0.456, 0.406]) #provided mean
    std = np.array([0.229, 0.224, 0.225]) #provided std
    img = (img - mean)/std
    
    # 注意颜色通道应该放在第一个位置
    img = img.transpose((2, 0, 1))
    
    return img

4.3数据展示

def imshow(image, ax=None, title=None):
    """展示数据"""
    if ax is None:
        fig, ax = plt.subplots()
    
    # 颜色通道还原
    image = np.array(image).transpose((1, 2, 0))
    
    # 预处理还原
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)
    
    ax.imshow(image)
    ax.set_title(title)
    
    return ax

4.4提取测试数据集

# 得到一个batch的测试数据
dataiter = iter(dataloaders['valid'])
images, labels = dataiter.next()

model_ft.eval()

if train_on_gpu:
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

4.5计算提取数据集的预测结果

_, preds_tensor = torch.max(output, 1)

preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
preds

4.6展示预测结果

fig=plt.figure(figsize=(20, 20))
columns =4
rows = 2

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()

参考文献

1.6-训练结果与模型保存_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/61826.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

迭代器模式(Iterator)

迭代器模式是一种行为设计模式,可以在不暴露底层实现(列表、栈或树等)的情况下,遍历一个聚合对象中所有的元素。 Iterator is a behavior design pattern that can traverse all elements of an aggregate object without exposing the internal imple…

SIFT算法原理:SIFT算法详细介绍

前面我们介绍了Harris和Shi-Tomasi角点检测算法,这两种算法具有旋转不变性,但不具有尺度不变性,以下图为例,在左侧小图中可以检测到角点,但是图像被放大后,在使用同样的窗口,就检测不到角点了。…

uniapp微信小程序中打开腾讯地图获取用户位置信息

实现的效果 第一步:首先登录微信公众平台 , 需要用到AppID 第二步: 注册登录腾讯位置服务 注册需要手机号和邮箱确认,然后创建应用 创建后点击添加key 添加后会生成key,后面会用到这个key 第三步: 登录微信公众平台&a…

Git Submodule 更新子库失败 fatal: Unable to fetch in submodule path

编辑本地目录 .git/config 文件 在 [submodule “Assets/CommonModule”] 项下 加入 fetch refs/heads/:refs/remotes/origin/

记录下:win10 AMD CPU 下载 Chromium 源码并编译(版本 103.0.5060.66)

文章目录 一、一些主要地址连接二、环境配置1、如何找官方环境文档:1.1 如何找到这个不同版本的文档: 2、电脑配置:3、visual studio 2019安装:3.1 社区版下载:3.2 安装配置: 4、debugtools配置4.1 如何判断…

LeetCode每日一题——1331.数组序号转换

题目传送门 题目描述 给你一个整数数组 arr ,请你将数组中的每个元素替换为它们排序后的序号。 序号代表了一个元素有多大。序号编号的规则如下: 序号从 1 开始编号。一个元素越大,那么序号越大。如果两个元素相等,那么它们的…

上市公司-上下游和客户数据匹配(2001-2022年)

参考《中国工业经济》中陶锋(2023)的做法,对上市公司的上下游供应商和客户数据进行匹配。形成“上游供应商—目标企业—下游客户一年度数据集”。该是关于上市公司的上下游以及客户数据匹配的详细信息。 它呈现出由各种上游供应商和下游客户…

mysql高级三:sql性能优化+索引优化+慢查询日志

内容介绍 单表索引失效案例 0、思考题:如果把100万数据插入MYSQL ,如何提高插入效率 (1)关闭自动提交,只手动提交一次 (2)删除除主键索引外其他索引 (3)拼写mysql可以执…

ELK日志分析系统简介

ELK日志分析系统简介 ElasticsearchLogstashKibana主要功能Kibana日志处理步骤ELK的工作原理 日志服务器 提高安全性 集中存放日志 缺陷 ​ 对日志的分析困难 ELK日志分析系统 Elasticsearch 概述:提供了一个分布式多用户能力的全文搜索引擎 核心概念 接近实时 集群 节…

特性Attribute

本文只提及常用的特性,更多特性请查看官方文档。 AddComponentMenu - Unity 脚本 API 常用特性 AddComponentMenu 添加组件菜单 使用 AddComponentMenu 属性可在“Component”菜单中的任意位置放置脚本,而不仅是“Component > Scripts”菜单。 使用…

GODOT游戏引擎简介,包含与unity性能对比测试,以及选型建议

GODOT,是一个免费开源的3D引擎。本文以unity作对比,简述两者区别和选型建议。由于是很久以前写的ppt,技术原因视频和部分章节丢失了。建议当做业务参考。 GODOT目前为止遇到3个比较重大的基于,第一个是oprea的合作奖,…

嵌入式开发学习(STC51-13-温度传感器)

内容 通过DS18B20温度传感器,在数码管显示检测到的温度值; DS18B20介绍 简介 DS18B20是由DALLAS半导体公司推出的一种的“一线总线(单总线)”接口的温度传感器; 与传统的热敏电阻等测温元件相比,它是一…

【C++】map和set

目录 一、容器补充1.序列式容器与关联式容器2.键值对3.树形结构的关联式容器 二、set1.set的介绍2.set的使用3.multset的介绍4.multset的使用 三、map1.map的介绍2.map的使用3.multimap的介绍4.multimap的使用 一、容器补充 1.序列式容器与关联式容器 我们已经接触过STL中的部…

Mysql自动同步的详细设置步骤

以下步骤是真实的测试过程,将其记录下来,与大家共同学习。 一、环境说明: 1、主数据库: (1)操作系统:安装在虚拟机中的CentOS Linux release 7.4.1708 (Core) [rootlocalhost ~]# cat /etc/redh…

Docker学习(二十四)报错速查手册

目录 一、This error may indicate that the docker daemon is not running 报错docker login 报错截图:原因分析:解决方案: 二、Get "https://harbor.xxx.cn/v2/": EOF 报错docker login 报错截图:原因分析&#xff1a…

使用ubuntu-base制作根文件系统

1:ubuntu官网下载最小根文件系统: 放置到电脑的ubuntu中, Mkdir Ubuntu_rootfs Cd Ubuntu_rootfs Sudo tar –zxvf Ubuntu-bash-xxxxxx.tar.gz 2:电脑的ubuntu安装qemu搭建arm模拟系统 将/usr/bin/qemu-arm-static/(64位拷贝…

(黑客)自学笔记

一、自学网络安全学习的误区和陷阱 1.不要试图先成为一名程序员(以编程为基础的学习)再开始学习 行为:从编程开始掌握,前端后端、通信协议、什么都学。 缺点:花费时间太长、实际向安全过渡后可用到的关键知识并不多。…

Java基础面试题2

Java基础面试题 一、IO和多线程专题 1.介绍下进程和线程的关系 进程:一个独立的正在执行的程序 线程:一个进程的最基本的执行单位,执行路径 多进程:在操作系统中,同时运行多个程序 多进程的好处:可以充…

[webpack] 基本配置 (一)

文章目录 1.基本介绍2.功能介绍3.简单使用3.1 文件目录和内容3.2 下载依赖3.3 启动webpack 4.基本配置4.1 五大核心概念4.2 基本使用 1.基本介绍 Webpack 是一个静态资源打包工具。它会以一个或多个文件作为打包的入口, 将我们整个项目所有文件编译组合成一个或多个文件输出出去…

macbook怎么卸载软件?2023年最新全新解析macbook电脑怎样删除软件

macbook怎么卸载软件?2023年最新全新解析macbook电脑怎样删除软件。关于Mac笔记本如何卸载软件_Mac笔记本卸载软件的四种方法的知识大家了解吗?以下就是小编整理的关于Mac笔记本如何卸载软件_Mac笔记本卸载软件的四种方法的介绍,希望可以给到…