pytorch实战-图像分类(一)(数据预处理)

目录

1.导入各种库

2.数据预处理

2.1数据读取

2.2图像增强

3.构建数据网络 

3.1网络构建

3.2读取标签对应的名字

4.展示数据

4.1数据转换

4.2画图

5.模型训练


1.导入各种库

上代码:

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2.数据预处理

2.1数据读取

先看以下训练集和验证集存放的位置

 上代码

data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

2.2图像增强

目的:我们所收集准备训练的数据都是很可贵的,数据越多成本也就越高,所以希望将有限的数据集最大化利用,这就时图像增强的目的。

定义:如下图小灰猫,进行翻转操作,小黄猫,进行不同角度的旋转操作,这样实现了一图多用的效果,在原数据的基础上,将数据集翻了几倍。比方说你现在有一个1w的数据集,经过数据增强,可以完成10w的数据集。

 上代码

data_transforms = {
    'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(224),#从中心开始裁剪(224×224),因为训练集收集的图大小可能不同,但神经网络需要同样大小的输入.
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率,p=0.5就是说,有50%概率执行该操作。
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(), #将数据转化成tensor格式输入
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#因为本例是要用别人的模型训练,所以要参考别人例子中提供的均值,标准差,对自己的的训练集进行标准化操作。
    ]),
    'valid': transforms.Compose([transforms.Resize(256), #验证集不需要做数据增强,其他处理方法和train一样。
        transforms.CenterCrop(224), #验证集数据裁剪成和训练集一样,才能对比
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

3.构建数据网络 

3.1网络构建

batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']} # 构建分类任务数据集,注意不同任务数据集构建方式不同。
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']} # 按照batch_size = 8大小加载数据。
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} # 看一下数据的数量,该例'train': 6552, 'valid': 818
class_names = image_datasets['train'].classes

3.2读取标签对应的名字

网络最后的输出是一个代表类别的数值,比方说1,2,3,但我们希望看到这个数值对应的类别,所以json存这些信息,比方说{'1': 'pink primrose'}。

with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f) 

4.展示数据

4.1数据转换

注意:进行训练时需要tensor格式的数据,所以展示的时候tensor的数据需要转换成numpy的格式,而且还需要还原回标准化的结果。

def im_convert(tensor): #im_convert转化函数
    """ 展示数据"""
    
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image

4.2画图

fig=plt.figure(figsize=(20, 12))
columns = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs, classes = dataiter.next()

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

5.模型训练

下接该文:pytorch实战-图像分类(二)(模型训练及验证)(基于迁移学习(理解+代码))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/58286.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

一台电脑给另外一台电脑共享网络

这里写自定义目录标题 有网的电脑上操作一根网线连接两台电脑没网的电脑上 有网的电脑上操作 右键->属性->共享 如同选择以太网,勾选。确认。 一根网线连接两台电脑 没网的电脑上 没网的电脑为mips&麒麟V10 新增个网络配置ww,设置如下。 …

2.05 购物车后台刷新并显示

一.用户登录添加商品使用cookie存入购物车,并把购物车商品传入到后台 步骤1:创建购物车BO对象 public class ShopcartBO {private String itemId;private String itemImgUrl;private String itemName;private String specId;private String specName;p…

7.物联网操作系统互斥信号量

1.使用互斥信号量解决信号量导致的优先级反转, 2.使用递归互斥信号量解决互斥信号量导致的死锁。 3.高优先级主函数中多次使用同一信号量的使用,使用递归互斥信号量,但要注意每个信号量的使用要对应一个释放 优先级翻转问题 优先级翻转功能需…

牛客网Verilog刷题——VL48

牛客网Verilog刷题——VL48 题目答案 题目 在data_en为高期间,data_in将保持不变,data_en为高至少保持3个B时钟周期。表明,当data_en为高时,可将数据进行同步。本题中data_in端数据变化频率很低,相邻两个数据间的变化&…

【计算机视觉|人脸建模】SOFA:基于风格、由单一示例的2D关键点驱动的3D面部动画

本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 标题:SOFA: Style-based One-shot 3D Facial Animation Driven by 2D landmarks 链接:SOFA: Style-based One-shot 3D Facial Animation Driven by 2D landmarks | Proceedings of …

磁盘均衡器:HDFS Disk Balancer

HDFS Disk Balancer 背景产生的问题以及解决方法 hdfs disk balancer简介HDFS Disk Balancer功能数据传播报告 HDFS Disk Balancer开启相关命令 背景 相比较于个人PC,服务器一般可以通过挂载多块磁盘来扩大单机的存储能力在Hadoop HDFS中,DataNode负责最…

【数据结构与算法】线索化二叉树

线索化二叉树 n 个节点的二叉链表中含有 n 1 【公式 2n - (n - 1) n 1】个空指针域。利用二叉链表中的空指针域,存放指向该节点在某种遍历次序下的前驱和后继节点的指针(这种附加的指针称为“线索”)。这种加上了线索的二叉链表称为线索链…

网站是如何进行访问的?在浏览器地址栏输入网址并回车的一瞬间到页面能够展示回来,经历了什么?

这个问题是检验web和计网学习程度的经典问题。 网站访问流程: 1.域名->ip地址 1) 在输入完一个域名之后,首先是检查浏览器自身的DNS缓存是否有相应IP地址映射,如果没有对应的解析记录,浏览器会查找本机的hosts配置文件&…

【Spring Boot】Thymeleaf模板引擎 — Thymeleaf表达式

Thymeleaf表达式 本节介绍Thymeleaf的各种表达式&#xff0c;通过一些简单的例子来演示Thymeleaf的表达式及用法。 1.变量表达式 变量表达式即获取后台变量的表达式。使用${}获取变量的值&#xff0c;例如&#xff1a; <p th:text"${name}">hello</p>…

leetcode 763. 划分字母区间

2023.8.3 本题的关键是要确保同一字母需要在同一片段中&#xff0c;而这就需要关注到每个字母最后一次出现的位置。 思路&#xff1a;用一个哈希表保存每个字母&#xff08;26个&#xff09;最后一次出现的位置。然后从头遍历&#xff0c;不断更新最右边界&#xff0c;直到当前…

一个严肃的话题,ADR会取代WAF和RASP吗?

做安全的人应该都对WAF耳熟能详&#xff0c;也就是我们常说的Web应用防火墙&#xff0c;成为了应用安全防护的明星产品之一。从传统的防火墙、IDS、IPS&#xff0c;再到WAF横空出世&#xff0c;引领技术趋势若干年&#xff0c;这一阶段可以称为应用安全防护1.0时代。作为一款成…

计算机毕设 深度学习疫情社交安全距离检测算法 - python opencv cnn

文章目录 0 前言1 课题背景2 实现效果3 相关技术3.1 YOLOV43.2 基于 DeepSort 算法的行人跟踪 4 最后 0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&#xff0c;这两…

jar命令的安装与使用

场景&#xff1a; 项目中经常遇到使用WinR软件替换jar包中的文件&#xff0c;有时候存在WinRAR解压替换时提示没有权限&#xff0c;此时winRAR不能用还有有什么方法替换jar包中的文件。 方法&#xff1a; 使用jar命令进行修改替换 问题&#xff1a; 执行jar命令报错jar 不…

【从零开始学习JAVA | 第三十七篇】初识多线程

目录 前言&#xff1a; ​编辑 引入&#xff1a; 多线程&#xff1a; 什么是多线程&#xff1a; 多线程的意义&#xff1a; 多线程的应用场景&#xff1a; 总结&#xff1a; 前言&#xff1a; 本章节我们将开始学习多线程&#xff0c;多线程是一个很重要的知识点&#xff…

MYSQL进阶-事务

1.什么是数据库事务&#xff1f; 事务是一个不可分割的数据库操作序列&#xff0c;也是数据库并发控制的基本单位&#xff0c;其执 行的结果必须使数据库从一种一致性状态变到另一种一致性状态。事务是逻辑上 的一组操作&#xff0c;要么都执行&#xff0c;要么都不执行。 事务…

使用 LangChain 搭建基于 Amazon DynamoDB 的大语言模型应用

LangChain 是一个旨在简化使用大型语言模型创建应用程序的框架。作为语言模型集成框架&#xff0c;在这个应用场景中&#xff0c;LangChain 将与 Amazon DynamoDB 紧密结合&#xff0c;构建一个完整的基于大语言模型的聊天应用。 本次活动&#xff0c;我们特意邀请了亚马逊云科…

华为云CTS 使用场景

云审计服务 CTS 云审计服务&#xff08;Cloud Trace Service&#xff09;&#xff0c;帮助您监控并记录华为云账号的活动&#xff0c;包括通过控制台、API、开发者工具对云上产品和服务的访问和使用行为&#xff0c;提供对各种云资源操作记录的收集、存储和查询功能&#xff0…

应用在多媒体手机中的低功率立体声编解码器

多媒体手机一般是指可以录制或播放视频的手机。多媒体的定义是多种媒体的综合&#xff0c;一般是图像、文字、声音等多种结合&#xff0c;所以多媒体手机是可以处理和使用图像文字声音相结合的移动设备。目前流行的多媒体概念&#xff0c;主要是指文字、图形、图像、声音等多种…

【0803作业】创建两个线程:其中一个线程拷贝图片的前半部分,另一个线程拷贝后半部分(4种方法)

方法一&#xff1a;使用pthread_create、pthread_exit、pthread_join函数【两个线程不共用同一份资源】 先在主函数创建并清空拷贝的目标文件&#xff0c;再创建两个线程&#xff0c;在两个线程内部同时打开要读取的文件以及要拷贝的目标文件&#xff08;两个线程不共用同一份资…

Vulnhub: BlueMoon: 2021靶机

kali&#xff1a;192.168.111.111 靶机&#xff1a;192.168.111.174 信息收集 端口扫描 nmap -A -sC -v -sV -T5 -p- --scripthttp-enum 192.168.111.174 80端口目录爆破&#xff0c;发现文件&#xff1a;hidden_text gobuster dir -u http://192.168.111.174 -w /usr/sha…