【深度学习】快速制作图像标签数据集以及训练

快速制作图像标签数据集以及训练

制作DataSet

  • 先从网络收集十张图片 每种十张
    在这里插入图片描述

  • 定义dataSet和dataloader

import glob
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt


# 通过创建data.Dataset子类Mydataset来创建输入
class Mydataset(data.Dataset):
    # init() 初始化方法,传入数据文件夹路径
    def __init__(self, root):
        self.imgs_path = root

    # getitem() 切片方法,根据索引下标,获得相应的图片
    def __getitem__(self, index):
        img_path = self.imgs_path[index]

    # len() 计算长度方法,返回整个数据文件夹下所有文件的个数
    def __len__(self):
        return len(self.imgs_path)


# 使用glob方法来获取数据图片的所有路径
all_imgs_path = glob.glob(r"./Data/*/*.jpg")  # 数据文件夹路径

# 利用自定义类Mydataset创建对象brake_dataset
# 将所有的路径塞进dataset  使用每张图片的路径进行索引图片
brake_dataset = Mydataset(all_imgs_path)
# print("图片总数:{}".format(len(brake_dataset)))  # 返回文件夹中图片总个数



# 制作dataloader
brake_dataloader = torch.utils.data.DataLoader(brake_dataset, batch_size=2)  # 每次迭代时返回4个数据
# print(next(iter(break_dataloader)))

制作标签


# 为每张图片制作对应标签
species = ['sun', 'rain', 'cloud']
species_to_id = dict((c, i) for i, c in enumerate(species))
# print(species_to_id)

id_to_species = dict((v, k) for k, v in species_to_id.items())
# print(id_to_species)

# 对所有图片路径进行迭代
all_labels = []
for img in all_imgs_path:
	# 区分出每个img,应该属于什么类别
	for i, c in enumerate(species):
		if c in img:
			all_labels.append(i)
# print(all_labels)

制作数据和标签一起的dataset和dataloader

  • 上面的dataset不够完善
# 将数据转换为张量数据
# 对数据进行转换处理
transform = transforms.Compose([
	transforms.Resize((256, 256)),  # 做的第一步转换
	transforms.ToTensor()  # 第二步转换,作用:第一转换成Tensor,第二将图片取值范围转换成0-1之间,第三会将channel置前
])


class Mydatasetpro(data.Dataset):
	def __init__(self, img_paths, labels, transform):
		self.imgs = img_paths
		self.labels = labels
		self.transforms = transform

	# 进行切片
	def __getitem__(self, index):
		img = self.imgs[index]
		label = self.labels[index]
		pil_img = Image.open(img)  # pip install pillow
		pil_img = pil_img.convert('RGB')
		data = self.transforms(pil_img)
		return data, label

	# 返回长度
	def __len__(self):
		return len(self.imgs)


BATCH_SIZE = 4
brake_dataset = Mydatasetpro(all_imgs_path, all_labels, transform)
brake_dataloader = data.DataLoader(
	brake_dataset,
	batch_size=BATCH_SIZE,
	shuffle=True
)

imgs_batch, labels_batch = next(iter(brake_dataloader))

# 4 X 3 X 256 X 256
print(imgs_batch.shape)

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[:10], labels_batch[:10])):
	img = img.permute(1, 2, 0).numpy()
	plt.subplot(2, 3, i + 1)
	plt.title(id_to_species.get(label.item()))
	plt.imshow(img)
plt.show()  # 展示图片


制作训练集和测试集

# 划分数据集和测试集
index = np.random.permutation(len(all_imgs_path))

#  打乱所有图片的索引
print(index)

# 根据索引获取所有图片的路径
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]

print("打乱顺序之后的所有图片路径{}".format(all_imgs_path))
print("打乱顺序之后的所有图片索引{}".format(all_labels))

# 80%做训练集
s = int(len(all_imgs_path) * 0.8)
# print(s)

train_imgs = all_imgs_path[:s]
# print(train_imgs)
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]


# 将训练集和标签 制作dataset 需要转换为张量
train_ds = Mydatasetpro(train_imgs, train_labels, transform)  # TrainSet TensorData
test_ds = Mydatasetpro(test_imgs, test_labels, transform)  # TestSet TensorData
# print(train_ds)
# print(test_ds)
print("**********")
# 制作trainLoader
train_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)  # TrainSet Labels
test_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)  # TestSet Labels




训练代码

import torch
import torchvision.models as models
from torch import nn
from torch import optim
from DataSetMake import brake_dataloader
from DataSetMake import train_dl, test_dl


# 判断是否使用GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#  使用resnet 训练
model_ft = models.resnet50(pretrained=True)  # 使用迁移学习,加载预训练权


in_features = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(in_features, 256),
							nn.ReLU(),
							# nn.Dropout(0, 4),
							nn.Linear(256, 4),
							nn.LogSoftmax(dim=1))

model_ft = model_ft.to(DEVICE)  # 将模型迁移到gpu

# 优化器
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(DEVICE)  # 将loss_fn迁移到GPU

# Adam损失函数
optimizer = optim.Adam(model_ft.fc.parameters(), lr=0.003)


epochs = 50  # 迭代次数
steps = 0
running_loss = 0
print_every = 10
train_losses, test_losses = [], []

for epoch in range(epochs):
	model_ft.train()
	# 遍历训练集数据
	for imgs, labels in brake_dataloader:
		steps += 1

		# 标签转换为 tensor
		labels = torch.tensor(labels, dtype=torch.long)

		# 将图片和标签 放到设备上
		imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

		optimizer.zero_grad()  # 梯度归零

		#  前向推理
		outputs = model_ft(imgs)

		# 计算loss
		loss = loss_fn(outputs, labels)
		loss.backward()  # 反向传播计算梯度
		optimizer.step()  # 梯度优化

		# 累加loss
		running_loss += loss.item()

		if steps % print_every == 0:
			test_loss = 0
			accuracy = 0

			# 验证模式
			model_ft.eval()

			# 测试集 不需要计算梯度
			with torch.no_grad():
				# 遍历测试集数据
				for imgs, labels in test_dl:
					#  转换为tensor
					labels = torch.tensor(labels, dtype=torch.long)

					#  数据标签 部署到gpu
					imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

					#  前向推理
					outputs = model_ft(imgs)

					#  计算损失
					loss = loss_fn(outputs, labels)

					# 累加测试机的损失
					test_loss += loss.item()

					ps = torch.exp(outputs)
					top_p, top_class = ps.topk(1, dim=1)

					equals = top_class == labels.view(*top_class.shape)
					accuracy += torch.mean(equals.type(torch.FloatTensor)).item()

			train_losses.append(running_loss / len(train_dl))
			test_losses.append(test_loss / len(test_dl))

			print(f"Epoch {epoch + 1}/{epochs}.. "
				  f"Train loss: {running_loss / print_every:.3f}.. "
				  f"Test loss: {test_loss / len(test_dl):.3f}.. "
				  f"Test accuracy: {accuracy / len(test_dl):.3f}")

			#  回到训练模式 训练误差清0
			running_loss = 0
			model_ft.train()
torch.save(model_ft, "aerialmodel.pth")




在这里插入图片描述

预测代码

import os
import torch
from PIL import Image
from torch import nn
from torchvision import transforms, models

i = 0  # 识别图片计数
# 这里最好新建一个test_data文件随机放一些上面整理好的图片进去
root_path = r"D:\CODE\ImageClassify\Test"  # 待测试文件夹
names = os.listdir(root_path)

for name in names:
	print(name)
	i = i + 1
	data_class = ['sun', 'rain', 'cloud']  # 按文件索引顺序排列


	#  找出文件夹中的所有图片
	image_path = os.path.join(root_path, name)
	image = Image.open(image_path)
	print(image)

	#  张量定义格式
	transform = transforms.Compose([transforms.Resize((256, 256)),
									transforms.ToTensor()])
	# 图片转换为张量
	image = transform(image)
	print(image.shape)

	#  定义resnet模型
	model_ft = models.resnet50()

	# 模型结构
	in_features = model_ft.fc.in_features
	model_ft.fc = nn.Sequential(nn.Linear(in_features, 256),
								nn.ReLU(),
								# nn.Dropout(0, 4),
								nn.Linear(256, 4),
								nn.LogSoftmax(dim=1))


	# 加载已经训练好的模型参数
	model = torch.load("aerialmodel.pth", map_location=torch.device("cpu"))

	# 将每张图片 调整维度
	image = torch.reshape(image, (1, 3, 256, 256))  # 修改待预测图片尺寸,需要与训练时一致
	model.eval()

	#  速出预测结果
	with torch.no_grad():
		output = model(image)
	print(output)  # 输出预测结果
	# print(int(output.argmax(1)))
	# 对结果进行处理,使直接显示出预测的种类  根据索引判别是哪一类
	print("第{}张图片预测为:{}".format(i, data_class[int(output.argmax(1))]))


工程目录结构

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/114189.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

项目管理-组织战略类型和层次讲解

组织战略类型和层次 对于不同的组织战略可能会采用不同的项目管理形式,组织作为项目管理的载体,其战略决策对项目管理体系的架构,对组织与项目之间责权利的划分具有深远的影响,组织的战略文化也会影响到项目的组织文化氛围。因此…

c++实现观察者模式

前言 我觉得这是最有意思的模式&#xff0c;其中一个动&#xff0c;另外的自动跟着动。发布-订阅&#xff0c;我觉得很巧妙。 代码 头文件 #pragma once #include<vector> #include<string> #include<iostream>// 抽象观察者 class Aobserver { public:v…

自制目录扫描工具并由py文件转为exe可执行程序

心血来潮让ChatGPT写了一个目录扫描工具&#xff0c;然后进行了一定的修改和完善&#xff0c;可以实现对网站目录的一个简单扫描并输出扫描结果&#xff0c;主要包括存在页面、重定向页面和禁止访问页面。 虽然代码很简单&#xff0c;但是做这个东西的过程还是挺有意思的&…

NoSQL数据库使用场景以及架构介绍

文章目录 一. 什么是NoSQL&#xff1f;二. NoSQL分类三. NoSQL与关系数据库有什么区别四. NoSQL主要优势和缺点五. NoSQL体系框架 其它相关推荐&#xff1a; 系统架构之微服务架构 系统架构设计之微内核架构 鸿蒙操作系统架构 架构设计之大数据架构&#xff08;Lambda架构、Kap…

伊朗网络间谍组织针对中东金融和政府部门

导语 近日&#xff0c;以色列网络安全公司Check Point与Sygnia发现了一起针对中东金融、政府、军事和电信部门的网络间谍活动。这一活动由伊朗国家情报和安全部门&#xff08;MOIS&#xff09;支持的威胁行为者发起&#xff0c;被称为"Scarred Manticore"。该组织被认…

前端BOM、DOM

文章目录 BOM操作window对象navigator对象&#xff08;了解即可&#xff09;history对象location对象弹出框警告框确认框提示框 计时相关1.过一段时间之后触发&#xff08;一次&#xff09;2.每隔三秒时间触发一次 DOM操作HTML DOM树 查找标签直接查找间接查找 节点操作操作 获…

数据库连接池大小的调整原则

配置连接池是开发人员经常犯的错误。配置池时需要理解几个原则&#xff08;对于某些人来说可能违反直觉&#xff09;。 想象一下&#xff0c;您有一个网站&#xff0c;虽然可能不是 Facebook 规模的&#xff0c;但仍然经常有 10,000 个用户同时发出数据库请求&#xff0c;每秒…

GD32 单片机 硬件I2C死锁解决方法

死锁的复现方式 在I2C恢复函数下个断点&#xff08;检测到I2C多次超时之后&#xff0c;应该能跳转到I2C恢复函数&#xff09;使用镊子&#xff0c;将SCL与SDA短接&#xff0c;很快就能看到程序停到恢复函数的断点上&#xff0c;此时再执行恢复函数&#xff0c;看能否正常走出&…

CSS3网页布局基础

CSS布局始于第2个版本&#xff0c;CSS 2.1把布局分为3种模型&#xff1a;常规流、浮动、绝对定位。CSS 3推出更多布局方案&#xff1a;多列布局、弹性盒、模板层、网格定位、网格层、浮动盒等。本章重点介绍CSS 2.1标准的3种布局模型&#xff0c;它们获得所有浏览器的全面、一致…

「直播回放」使用 PLC + OPC + TDengine,快速搭建烟草生产监测系统

在烟草工业场景里&#xff0c;多数设备的自动控制都是通过 PLC 可编程逻辑控制器来实现的&#xff0c;PLC 再将采集的数据汇聚至 OPC 服务器。传统的 PI System、实时数据库、组态软件等与 OPC 相连&#xff0c;提供分析、可视化、报警等功能&#xff0c;这类系统存在一些问题&…

历年网规上午真题(2017年)

解析:D/C 计算机主要性能指标:时钟频率(主频)、运算速度、运算精度、内存大小、数据处理速率(PDR)等 数据库主要指标:最大并发、负载均衡能力、最大连接数等 解析:A 敏捷开发是一种应对快速变化的需求的一种软件开发方法,是一种以人为核心、迭代、循序渐进的开发方…

项目实战:编辑页面加载库存信息

1、前端编辑页面加载水果库存信息逻辑edit.js let queryString window.location.search.substring(1) if(queryString){var fid queryString.split("")[1]window.onloadfunction(){loadFruit(fid)}loadFruit function(fid){axios({method:get,url:edit,params:{fi…

【使用Python编写游戏辅助工具】第四篇:Windows窗口操作

前言 这里是【使用Python编写游戏辅助工具】的第四篇&#xff1a;Windows窗口操作。本文主要介绍使用Python来实现Windows窗口的各种操作。 Windows窗口操作是游戏辅助功能中不可或缺的一部分。 Windows窗口操作指的是与Windows操作系统中的窗口进行交互和控制的操作&#xff…

【Redis】安装(Linuxwindow)及Redis的常用命令

Redis简介 Redis是一个开源&#xff08;BSD许可&#xff09;&#xff0c;内存存储的数据结构服务器&#xff0c;可用作数据库&#xff0c;高速缓存和消息队列代理。 它支持字符串、哈希表、列表、集合、有序集合&#xff0c;位图&#xff0c;hyperloglogs等数据类型。内置复…

【Java初阶练习题】-- 循环+递归练习题

循环练习题02 打印X图形计算1/1-1/21/3-1/41/5 …… 1/99 - 1/100 的值输出一个整数的每一位如&#xff1a;123的每一位是3&#xff0c;2&#xff0c;1模拟登录使用方法求最大值求斐波那契数列的第n项。(迭代实现)求和的重载求最大值方法的重载递归求N阶乘递归求 1 2 3 ...…

C++之初始化列表详细剖析

一、初始化列表定义 初始化列表&#xff1a;以一个冒号开始&#xff0c;接着是一个以逗号分隔的数据成员列表&#xff0c;每个"成员变量"后面跟一个放在括号中的初始值或表达式。 class Date { public:Date(int year, int month, int day): _year(year), _month(mont…

华纳云:centos系统中怎么查看cpu信息?

在CentOS系统中&#xff0c;我们可以使用一些命令来查看CPU的详细信息。下面介绍几个常用的命令&#xff1a; 1. lscpu lscpu命令可以显示CPU的架构、型号、核心数、线程数、频率等信息。 # lscpu 执行以上命令后&#xff0c;会输出类似以下内容&#xff1a; 2. cat /proc/…

3D医学三维技术影像PACS系统源码

一、系统概述 3D医学影像PACS系统&#xff0c;它集影像存储服务器、影像诊断工作站及RIS报告系统于一身,主要有图像处理模块、影像数据管理模块、RIS报告模块、光盘存档模块、DICOM通讯模块、胶片打印输出等模块组成&#xff0c; 具有完善的影像数据库管理功能&#xff0c;强大…

Oil Crop Science:DAP-seq技术揭示花生中AhTWRKY24和AhTWRKY106转录因子下游调控基因

2023年6月4日&#xff0c;青岛农业大学草业学院宋辉教授课题组的研究成果&#xff0c;发表在Oil Crop Science期刊上&#xff0c;文章题目为Identification of the target genes of AhTWRKY24 and AhTWRKY106 transcription factors reveals their regulatory network in Arach…

【好书推荐】AI时代架构师修炼之道:ChatGPT让架构师插上翅膀

目录 前言 ChatGPT对架构师工作的帮助 快速理解和分析需求 提供代码建议和解决方案 辅助系统设计和优化 提高团队协作效率 如何使用ChatGPT提高架构师工作效率 了解用户需求和分析问题 编码实践和问题解决 系统设计和优化建议 团队协作和沟通效率提升 知识管理和文…