图像识别网络与训练策略——基于经典网络架构训练图像分类模型

基于经典网络架构训练图像分类模型

总体框架

数据预处理部分:

- 数据增强:torchvision中transforms模块自带功能,比较实用
- 数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可
- DataLoader模块直接读取batch数据
网络模块设置:

- 加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习
- 需要注意的是别人训练好的任务跟咱们的可不是完全一样,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务
- 训练时可以全部重头训练,也可以只训练最后咱们任务的层,因为前几层都是做特征提取的,本质任务目标是一致的
网络模型保存与测试
- 模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存
- 读取模型进行实际测试

0.导入包

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
import imageio
import time
import warnings
warnings.filterwarnings("ignore")
import random
import sys
import copy
import json
from PIL import Image

1. 数据读取与预处理操作

data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

2.制作好数据源

- data_transforms中指定了所有图像预处理操作
- ImageFolder假设所有的文件按文件夹保存好,每个文件夹下面存贮同一类别的图片,文件夹的名字为分类的名字
data_transforms = {
    'train': 
        transforms.Compose([
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),#从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    ]),
    'valid': 
        transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

3.准备数据集和加载器

batch_size = 128

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

4.读取标签对应的实际名字

with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

5. 加载models中提供的模型

并且直接用训练的好权重当做初始化参数

model_name = 'resnet'  #可选的比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
#是否用人家训练好的特征来做
feature_extract = True #都用人家特征,咱先不更新
# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')
    
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
模型参数要不要更新
有时候用人家模型,就一直用了,更不更新咱们可以自己定
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
model_ft = models.resnet18()#18层的能快点,条件好点的也可以选152

6.把模型输出层改成自己的

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    
    model_ft = models.resnet18(pretrained=use_pretrained)
    set_parameter_requires_grad(model_ft, feature_extract)
    
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 102)#类别数自己根据自己任务来
                            
    input_size = 64#输入大小根据自己配置来

    return model_ft, input_size

7.设置哪些层需要训练

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

#GPU还是CPU计算
model_ft = model_ft.to(device)

# 模型保存,名字自己起
filename='checkpoint.pth'

# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

 8.优化器设置

optimizer_ft = optim.Adam(params_to_update, lr=1e-2)#要训练啥参数,你来定
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
criterion = nn.CrossEntropyLoss()

9.训练模块

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,filename='best.pt'):
    #咱们要算时间的
    since = time.time()
    #也要记录最好的那一次
    best_acc = 0
    #模型也得放到你的CPU或者GPU
    model.to(device)
    #训练过程中打印一堆损失和指标
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    #学习率
    LRs = [optimizer.param_groups[0]['lr']]
    #最好的那次模型,后续会变的,先初始化
    best_model_wts = copy.deepcopy(model.state_dict())
    #一个个epoch来遍历
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 训练和验证
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # 训练
            else:
                model.eval()   # 验证

            running_loss = 0.0
            running_corrects = 0

            # 把数据都取个遍
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)#放到你的CPU或GPU
                labels = labels.to(device)

                # 清零
                optimizer.zero_grad()
                # 只有训练的时候计算和更新梯度
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                # 训练阶段更新权重
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # 计算损失
                running_loss += loss.item() * inputs.size(0)#0表示batch那个维度
                running_corrects += torch.sum(preds == labels.data)#预测结果最大的和真实值是否一致
                
            
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)#算平均
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            time_elapsed = time.time() - since#一个epoch我浪费了多少时间
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            

            # 得到最好那次的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                  'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重
                  'best_acc': best_acc,
                  'optimizer' : optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                #scheduler.step(epoch_loss)#学习率衰减
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
        
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()
        scheduler.step()#学习率衰减

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 训练完后用最好的一次当做模型最终的结果,等着一会测试
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

开始训练!

- 我们现在只训练了输出层

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20)
### 再继续训练所有层
for param in model_ft.parameters():
    param.requires_grad = True

# 再继续训练所有的参数,学习率调小一点
optimizer = optim.Adam(model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# 损失函数
criterion = nn.CrossEntropyLoss()
# 加载之前训练好的权重参数

checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=10,)

 

### 加载训练好的模型
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU模式
model_ft = model_ft.to(device)

# 保存文件的名字
filename='best.pt'

# 加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

测试数据预处理

 ​​​​
- 测试数据处理方法需要跟训练时一直才可以
- crop操作的目的是保证输入的大小是一致的
- 标准化操作也是必须的,用跟训练数据相同的mean和std,但是需要注意一点训练数据是在0-1上进行标准化,所以测试数据也需要先归一化
- 最后一点,PyTorch中颜色通道是第一个维度,跟很多工具包都不一样,需要转换

# 得到一个batch的测试数据
dataiter = iter(dataloaders['valid'])
images, labels = dataiter.next()

model_ft.eval()

if train_on_gpu:
    output = model_ft(images.cuda())
else:
    output = model_ft(images)
output表示对一个batch中每一个数据得到其属于各个类别的可能性
### 得到概率最大的那个
_, preds_tensor = torch.max(output, 1)

preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
preds

 

展示预测结果

def im_convert(tensor):
    """ 展示数据"""
    
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image
fig=plt.figure(figsize=(20, 20))
columns =4
rows = 2

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/525387.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

@四年级家长,这条香港优才计划+华侨生联考捷径,一定要看!

四年级家长,这条香港优才计划华侨生联考捷径,一定要看! 香港身份的优势有多大?进可参加华侨生联考400分上内地985/211大学,退可参加香港DSE轻松上香港本地大学和海外高校。 但香港身份对子女的教育优势大小&#xff0c…

C++从入门到精通——this指针

this指针 前言一、this指针的引出问题 二、this指针的特性三、例题什么时候会出现编译报错什么时候会出现运行崩溃this指针存在哪里this指针可以为空吗 四、C语言和C实现Stack的对比C语言实现C实现 前言 this指针是一个特殊的指针,在C类的成员函数中使用。它指向调…

PPT 操作

版式 PPT中,巧妙使用母版,可以提高效率。 双击母版,选择其中一个版式,插入装饰符号。 然后选择关闭。 这个时候,在该版式下的所有页面,就会出现新加入的符号。不在该版式下的页面,不会出现新加…

【Redis】Redis的使用

登录redis [roottest2 ~]# redis-cli 127.0.0.1:6379> 或[roottest2 ~]# redis-cli -h 192.168.67.12 -p 6379 192.168.67.12:6379> redis-benchmark 测试工具 redis-benchmark 是官方自带的Redis性能测试工具,可以有效的测试Redis服务的性能 基本的测试语…

2.类与对象(上篇)

1.面向过程和面向对象初步认识 C是基于面向对象的,关注的是对象,将一件事情拆分成不同的对象,靠对象之间的交互完成。 2.类的引入 C**语言结构体中只能定义变量,在C中,结构体内不仅可以定义变量,也可以定…

(三)LTspice学习交流分析

文章目录 前言一、Edit simulation cmd二、添加激励总结 前言 上一节我们学习了LTspice的安装,很简单,无脑安装 (一)LTspice简介 (二)LTspice学习之简介2 今天我们来学习一下LTspice另一个非常重要的仿真功…

血泪教训!程序员副业接单做私活避坑指南!!!

缘起 经常有粉丝朋友向我哭诉留言,告知大白自己兼职被骗的经历: 在大家找兼职踩坑过程中,我总结下来就是以下几点血泪教训: 1. 有没有靠谱兼职推荐? 2. 零基础怎么兼职接单? 3. 怎么渗透?怎么…

C++设计模式:桥模式(五)

1、定义与动机 桥模式定义:将抽象部分(业务功能)与实现部分(平台实现)分离,使他们可以独立地变化引入动机: 由于某些类型的固有的实现逻辑,使得它们具有两个变化的维度,…

2.k8s架构

目录 k8s集群架构 控制平面 kube-apiserver kube-scheduler etcd kube-controller-manager node 组件 kubelet kube-proxy 容器运行时(Container Runtime) cloud-controller-manager 相关概念 k8s集群架构 一个Kubernetes集群至少包含一个控制…

飞书API(3):Python 自动读取多维表所有分页数据的三种方法

上一小节介绍了怎么使用 Python 读取多维表的数据,看似可以成功获取到了所有的数据,但是在实际生产使用过程中,我们会发现,上一小节的代码并不能获取到所有的多维表数据,它只能获取一页,默认是第一页。因为…

MySQL的存储引擎、索引与事务

常见的端口号 MySQL–3306http–80https–443tcp–23fcp–21tomcat–8080ssh–22oracle–1521rockermq–9876 存储引擎 使用指令查看所有引擎: show engines;从图中可以看出MySQL默认的存储引擎是InnoDB;并且在5.7版本所有的存储引擎中只有 InnoDB 是…

亮数据----教你轻松获取数据

文章目录 1. 数据采集遇到的瓶颈1.1 不会造数据?1.2 不会写爬虫代码? 2.IP 代理基础知识2.1 基本概念2.2 作用和分类2.3 IP 代理的配置和使用2.4 安全和合规 3. 为何使用亮数据 IP 代理3.1 拥有丰富的代理网络服务3.2 简单易操作的采集工具3.3 拥有各平台…

路由器对数据包的处理过程分析笔记

虽然TCP-IP协议中传输数据会在各个路由器再次经过物理层、链路层、网络层的解封装、加工、封装、转发,但是对于两个主机间的运输层,在逻辑上,应用进程是直接通信的。 路由器主要工作在网络层,但它也涉及到物理层和链路层的一些功能…

PWM 脉宽跟随方案介绍

1. 前言 数字电源产品在使用桥式电路拓扑或是多路交错控制中,有时会需要滞后臂的 PWM 脉宽严格跟随超前臂的 PWM 脉宽,或从路的 PWM 脉宽严格跟随主路的 PWM 脉宽,本文将介绍如何利用高精度定时器实现 PWM 输出脉宽跟随,一种使用…

ai智能电销机器人的核心技术,工作原理和作用

科技快速发展的同时,带来了人工智能产品的普及。而ai智能电销机器人则成为推进电销行业的产物,那么ai智能电销机器人是如何帮助企业高效触客,有效地工作,效果又如何呢?我们一起来看看吧! 一、ai智能电销机器…

软件的生命周期_瀑布模型

瀑布模型 描述软件生成到消亡的过程模型图 该模型目前实际工作中已不常用,但是该模型是其他新型模型的“鼻祖 瀑布模型的优点 每个阶段比较清楚,并且有对应的文档产生 当前一个阶段完成后,才开始后面的阶段(一次性的&#xff09…

「媒体邀约」天津媒体邀约资源有哪些?媒体宣传现场报道

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 天津作为中国北方的重要城市,拥有丰富的媒体资源,可以为各类活动提供全面的媒体宣传和现场报道。以下是天津地区的媒体邀约资源: 1. 报纸媒体 - 《天…

如何搭建APP分发平台分发平台搭建教程

搭建一个APP分发平台可以帮助开发者更好地分发和管理他们的应用程序。下面是一个简要的教程,介绍如何搭建一个APP分发平台。 1.确定需求和功能:首先,确定你的APP分发平台的需求和功能。考虑以下几个方面: 用户注册和登录&#xff…

Redis从入门到精通(九)Redis实战(六)基于Redis队列实现异步秒杀下单

文章目录 前言4.5 分布式锁-Redisson4.5.4 Redission锁重试4.5.5 WatchDog机制4.5.5 MutiLock原理 4.6 秒杀优化4.6.1 优化方案4.6.2 完成秒杀优化 4.7 Redis消息队列4.7.1 基于List实现消息队列4.7.2 基于PubSub的消息队列4.7.3 基于Stream的消息队列4.7.4 基于Stream的消息队…

DIY可视化UniApp表格组件

表格组件在移动端的用处非常广泛,特别是在那些需要展示结构化数据、进行比较分析或提供详细信息的场景中。数据展示与整理:表格是展示结构化数据的理想方式,特别是在需要展示多列和多行数据时。通过表格,用户可以轻松浏览和理解数…