Resnet与Pytorch花图像分类

1、介绍

1.1数据集介绍

flower_data
    ├── train
    │   └── 1-102(102个文件夹)
    │   	└── XXX.jpg(每个文件夹含若干张图像)
    ├── valid
    │   └── 1-102(102个文件夹)
    └── ───	└── XXX.jpg(每个文件夹含若干张图像)  
     
cat_to_name.json:每一类花朵的"名称-编号"对应关系

1.2 任务介绍

实现102种花朵的分类任务,即通过训练train数据集后,从valid数据集中选取某一花朵图像,能准确判别其属于哪一类花朵

1.3Resnet介绍

在ResNet网络中有如下两个亮点:

  1. 提出residual结构(残差结构),并搭建超深的网络结构(突破1000层)
  2. 使用Batch Normalization加速训练(丢弃dropout)

在ResNet网络提出之前,传统的卷积神经网络都是通过将一系列卷积层与下采样层进行堆叠得到的。但是当堆叠到一定网络深度时,就会出现两个问题:

  1. 梯度消失或梯度爆炸
  2. 退化问题(degradation problem)

2、数据预处理

2.1引入头文件

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2.2数据读取

#数据读取与预处理操作
data_dir = './flower_data/'
# 训练集
train_dir = data_dir + '/train'
#验证集
valid_ir = data_dir + '/valid'

2.3制作数据源

#制作数据源
data_transfroms = {
    'train':transforms.Compose([transforms.RandomRotation(45), #随机旋转(-45~45)
    transforms.CenterCrop(224), #从中心开始裁剪
    transforms.RandomHorizontalFlip(p = 0.5), #随机水平翻转
    transforms.RandomVerticalFlip(p = 0.5), #随机垂直翻转
    transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue = 0.1),
    transforms.RandomGrayscale(p = 0.025), #概率转换成灰度率,3通道就是R=G=B
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
    'valid':transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
}

2.4batch数据制作

#batch数据制作
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x),data_transfroms[x]) for x in ['train','valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size = batch_size,shuffle = True) for x in ['train','valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes

2.5读取数据标签

#读取标签对应的实际名字
with open('cat_to_name.json','r') as f:
    cat_to_name = json.load(f)

查看cat_to_name.json文件:

{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}

3、数据展示

3.1图像处理函数

#展示数据
def im_convert(tensor):
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229,0.224,0.225)) + np.array((0.485,0.456,0.406))
    image = image.clip(0.1)

    return image

3.2展示图像

fig=plt.figure(figsize=(20, 12))
columns = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs, classes = dataiter.next()

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

4、进行迁移学习

迁移学习的关键点:

  • 研究可以用哪些知识在不同的领域或者任务中进行迁移学习,即不同领域之间有哪些共有知识可以迁移
  • 研究在找到了迁移对象之后,针对具体问题所采用哪种迁移学习的特定算法,即如何设计出合适的算法来提取和迁移共有知识
  • 研究什么情况下适合迁移,迁移技巧是否适合具体应用,其中涉及到负迁移的问题。

4.1训练全连接层

加载models中提供的模型,并且直接用训练好的权重当做初始化参数 

下载链接:https://download.pytorch.org/models/resnet152-394f9c45.pth

 选择resnet网络

model_name = 'resnet'  #可选的有: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']

#是否用官方训练好的特征来做
feature_extract = True 

设置用GPU训练

#是否用GPU来训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('cuda is not available. Training on CPU')
else:
    print('cuda is available. Training on GPU')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

屏蔽预训练模型的权重,只训练全连接层的权重: 

def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameter():
            param.requires_grad = False

选择resnet152网络

model_ft = models.resnet152()

设置优化器:

#优化器设置
optimizer_ft = optim.Adam(params_to_update,lr = 1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1) #学习率每7个epoch衰减成原来的1/10
criterion = nn.NLLLoss()

定义训练模块:

# 训练模块
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False,filename = filename):
    since = time.time()
    best_acc = 0

    model.to(device)
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {} / {}'.format(epoch,num_epochs - 1))
        print('-' * 10)

        #训练与验证
        for phase in ['train','valid']:
            if phase == 'train':
                model.train()  #训练
            else:
                model.eval()  #验证

            running_loss = 0.0
            running_corrects = 0

            #把数据取个遍
            for inputs,labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                #清零
                optimizer.zero_grad()

                #只有训练的时候计算与更新梯度
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        outputs,aux_outputs = model(inputs)
                        loss1 = criterion(outputs,labels)
                        loss2 = criterion(aux_outputs,labels)
                        loss = loss1 + 0.4 * loss2
                    else: #resnet执行的是这里
                        outputs = model(inputs)
                        loss = criterion(outputs,labels)
                        _, preds = torch.max(outputs,1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                #计算损失
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}f'.format(time_elapsed // 60,time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase,epoch_loss,epoch_acc))

            #得到最好的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer':optimizer.state_dict(),
                }
                torch.save(state,filename)
                if phase == 'valid':
                    val_acc_history.append(epoch_acc)
                    valid_losses.append(epoch_loss)
                    scheduler.step(epoch_loss)
                if phase == 'train':
                    train_acc_history.append(epoch_acc)
                    train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed //60,time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    #训练完后用最好的一次当做模型最终的结果
    model.load_state_dict(best_model_wts)
    return model,val_acc_history,train_acc_history.valid_losses,train_losses,LRs

开始训练:

# 开始训练
model_ft,val_acc_history,train_acc_history,valid_lossea,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=20,is_inception=(model_name == 'inception'))

4.2训练所有层

我们从上次训练好最优的那个全连接层的参数开始,以此为基础训练所有层,设置param.requires_grad = True表明接下来训练全部网络,之后把学习率调小一点,衰减函数为每7次衰减为原来的1/10,损失函数不变

再继续训练所有层
for param in model_ft.parameters():
    param.requires_grad = True

#再继续训练所有的参数,学习率调小一点(lr)
optimizer = optim.Adam(params_to_update,lr = 1e-4)
#衰减函数(每七次衰减为原来的七分之一)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1)

#损失函数
criterion = nn.NLLLoss()

 导入之前的最优结果并开始训练:

#在之前训练得到最好的模型的基础上继续训练
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

model_ft,val_acc_history,train_acc_history,valid_lossea,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=10,is_inception=(model_name == 'inception'))

5、测试网络效果

5.1测试数据预处理

首先将新训练好的checkpoint.pth重命名为serious.pth,之后加载训练好的模型:

#加载训练好的模型
model_ft,input_size = initialize_model(model_name,102,feature_extract,use_pretrained=True)

#GPU模型
model_ft = model_ft.to(device)
#保存文件的名字
filename = 'serious.pth'
#加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['beat_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

定义图像处理函数:

def process_image(image_path):
    img = Image.open(image_path)

    #Resize,thumbnail方法只能进行缩小,所以进行判断
    if img.size[0] > img.size[1]:
        img.thumbnail((10000,256))
    else:
        img.thumbnail((256,10000))

    #Crop操作
    left_margin = (img.width-224)/2
    bottom_margin = (img.height-224)/2
    right_margin = (left_margin) + 224
    top_margin = bottom_margin + 224
    img  = img.crop(left_margin,bottom_margin,right_margin,top_margin)

    #相同的预处理方法
    img = np.array(img)/255
    mean = np.array([0.485,0.456,0.406])
    std = np.array([0.229,0.224,0.225])
    img = (img - mean)/std

    #注意颜色通道应该放在第一个位置
    img = img.transpose((2,0,1))

    return img

定义图像展示函数:

#展示数据
def imshow(image,ax = None,title = None):
    if ax is None:
        fig,ax = plt.subplots()

    #颜色通道还原
    image = np.array(image).transpose((1,2,0))

    #预处理还原
    mean = np.array([0.485,0.456,0.406])
    std = np.array([0.229,0.224,0.225])
    image = std * image + mean
    image = np.clip(image,0.1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

展示一个数据:

image_path = 'image_06621.jpg'
img = process_image(image_path)
imshow(img)

 

得到一个batch测试数据:

#测试一个batch数据
dataiter = iter(dataloaders['valid'])
images,labels = dataiter.next()

model_ft.eval()

if train_on_gpu:
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

利用torch.max()函数计算标签值:

#得到属于类别的八个编号
_,preds_tensor = torch.ax(output,1)
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())

5.2结果可视化

#展示预测结果
fig = plt.figure(figsize=(20,20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows,columns,idx+1,xticks=[],yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} {}".format(cat_to_name[str(preds[idx])],cat_to_name[str(labels[idx].item())]),
                 color = ("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()

结果如下(绿色标题代表识别成功,红色标题代表识别失败,括号里面为真实值,括号外为预测值)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/53913.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

html5播放器视频切换和连续播放的实例

当前播放器实例可以使用changeVid接口切换正在播放的视频。当有多个视频,在上一个视频播放完毕时,自动播放下一个视频时也可采用该处理方式。 const option {vid: 88083abbf5bcf1356e05d39666be527a_8,//autoplay: true,//playsafe: , //PC端播放加密视…

Maven右侧依赖Dependencies消失

项目右侧的Maven依赖Dependencies突然消失,项目中的注解都出现报错,出现这种情况应该是因为IDEA版本早于maven版本,重新检查项目中的Maven路径,选择File->Settings->搜索Maven,检查Maven home directory&#xf…

SHELL——备份脚本

编写脚本,使用mysqldump实现分库分表备份。 1、获取分库备份的库名列表 [rootweb01 scripts]# mysql -uroot -p123456 -e "show databases;" | egrep -v "Database|information_schema|mysql|performance_schema|sys" mysql: [Warning] Using …

StoneDB亮相2023数据技术嘉年华:增强AP、升级TP、信创替换,让万千DBA用得更省心,企业用得更省钱

2023 年 4 月 8 日,第十二届『数据技术嘉年华』(DTC 2023) 在北京圆满举办。本届大会以“开源 融合 数智化 —— 引领数据技术发展,释放数据要素价值”为主题。大会汇聚众多优秀厂商、先进技术、卓越产品和优秀案例,来自数据领域的领军人物…

离线情况下解决pyinstaller生成的可执行文件过大问题

由于工作原因,我的电脑没法上传和下载文件,所以一开始选择了anaconda完成python的工作。使用了pyinstaller将脚本生成可执行文件。但是生成出来的exe巨大无比(一个简单的脚本300多M,要花两分钟时间打开),于…

[每日习题]进制转换 参数解析——牛客习题

hello,大家好,这里是bang___bang_,本篇记录2道牛客习题,进制转换(简单),参数解析(中等),如有需要,希望能有所帮助! 目录 1️⃣进制转换 2️⃣参…

【Spring事务学习】事务分类 隔离级别 事务传播机制

目录 需要知道: 🍑1、什么是事务? 🍑2、事务的主要操作3个 一、Spring中事务的实现方式 🍑1、编程式事务(手动写代码操作事务)(了解) 🍑2、声明式事务&…

前端学习——Vue (Day8)

Vue3 create-vue搭建Vue3项目 注意要使用nodejs16.0版本以上,windows升级node可以西安使用where node查看本地node位置,然后到官网下载msi文件,在本地路径下安装即可 安装完可以使用node -v检查版本信息 项目目录和关键文件 组合式API - s…

ALLEGRO之Tools

本文主要介绍了ALLEGRO的Tools菜单。 (1)Create Module:暂不清楚; (2)Padstack:主要用于查看焊盘尺寸; (3)Pad:暂不清楚; &#xff…

Segment anything(图片分割大模型)

目录 1.Segment anything 2.补充图像分割和目标检测的区别 1.Segment anything 定义:图像分割通用大模型 延深:可以预计视觉检测大模型,也快了。 进一步理解:传统图像分割对于下图处理时,识别房子的是识别房子的模型…

app自动化测试之Appium问题分析及定位

使用 Appium 进行测试时,会产生大量日志,一旦运行过程中遇到报错,可以通过 Appium 服务端的日志以及客户端的日志分析排查问题。 Appium Server日志-开启服务 通过命令行的方式启动 Appium Server,下面来分析一下启动日志&#…

飞桨AI Studio可以玩多模态了?MiniGPT4实战演练!

MiniGPT4是基于GPT3的改进版本,它的参数量比GPT3少了一个数量级,但是在多项自然语言处理任务上的表现却不逊于GPT3。项目作者以MiniGPT4-7B作为实战演练项目。 创作者:衍哲 体验链接: https://aistudio.baidu.com/aistudio/proj…

Vue 基础语法(二)

一、背景: 我们对于基础语法,说白了就是实现元素赋值,循环,判断,以及事件响应即可! 二、v-bind 我们已经成功创建了第一个 Vue 应用!看起来这跟渲染一个字符串模板非常类似,但是 V…

IO流简述

IO流IO流使用场景 什么是IO流常用的IO流字节流字符流缓冲流 BIO、NIO、AIO的区别 IO流 IO流使用场景 如果操作的是纯文本文件,优先使用字符流如果操作的是图片、视频、音频等二进制文件。优先使用字节流如果不确定文件类型,优先使用字节流。字节流是万能…

Java版知识付费 Spring Cloud+Spring Boot+Mybatis+uniapp+前后端分离实现知识付费平台免费搭建

提供职业教育、企业培训、知识付费系统搭建服务。系统功能包含:录播课、直播课、题库、营销、公司组织架构、员工入职培训等。 提供私有化部署,免费售后,专业技术指导,支持PC、APP、H5、小程序多终端同步,支持二次开发…

springboot整合tio-websocket方案实现简易聊天

写在最前: 常用的http协议是无状态的,且不能主动响应到客户端。最初想实现状态动态跟踪只能用轮询或者其他效率低下的方式,所以引入了websocket协议,允许服务端主动向客户端推送数据。在WebSocket API中,浏览器和服务…

StopWatch与ThreadLocal

目录 1、StopWatch 1、1作用: 1、2方法: 1、3使用方法 2、ThreadLocal 2、1什么是ThreadLocal 2、2简单例子 2、3使用ThreadLocal带来的四个好处 2、4主要方法 2、5ThreadLocal内存泄漏问题 1、StopWatch 1、1作用: 统计代码块耗时时…

vue中使用axios发送请求时,后端同一个session获取不到值

问题描述: 在登录页面加载完成后通过axios请求后端验证码接口(这时后端会生成一个session用于保存验证码数值),当输入完用户名、密码、验证码后请求登录接口,报错验证码输入错误,打印后端保存验证码的sessi…

【ArcGIS Pro二次开发】(54):三调名称转用地用海名称

三调地类和用地用海地类之间有点相似但并不一致。 在做规划时,拿到的三调,都需要将三调地类转换为用地用海地类,然后才能做后续的工作。 一般情况下,三调转用地用海存在【一对一,多对一和一对多】3种情况。 前2种情况…

11-3_Qt 5.9 C++开发指南_QSqlQuery的使用(QSqlQuery 是能执行任意 SQL 语句的类)

文章目录 1. QSqlQuery基本用法2. QSqlQueryModel和QSqlQuery联合使用2.1 可视化UI设计框架2.1.1主窗口的可视化UI设计框架2.1.2 对话框的可视化UI设计框架 2.2 数据表显示2.3 编辑记录对话框2.4 编辑记录2.5 插入记录2.6 删除记录2.7 记录遍历2.8 程序框架及源码2.8.1 程序整体…