pytorch实现水果2分类(蓝莓,苹果)

1.数据集的路径,结构

dataset.py

目的:

        输入:没有输入,路径是写死了的。

        输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签)

        -data[item]返回的是(img_tensor , one-hot编码)。one-hot编码是[0,1]或者[1,0]

import glob
import os.path

import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class DtataAndLabel(Dataset):
    def __init__(self,path='fruits',is_train=True):
        self.tran=transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(size=(88,88))
        ])
        is_train='train' if True else 'test'
        self.data=[]
        path=os.path.join(path,is_train)
        print('path=',path)
        print(os.path.join(path, '*', '*'))
        img_paths=glob.glob(os.path.join(path,'*','*'))
        for img_path in img_paths:
            label=0 if img_path.split('\\')[-2]=='blueberry' else 1
            self.data.append((img_path,label))
    def __getitem__(self, idx):
        #每一张图片返回一个img_tensor,one_hot
        img_path,label =self.data[idx]
        img=cv2.imread(img_path)
        # img_gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
        img_tensor=self.tran(img)
        img_tensor=img_tensor/255
        img_tensor=torch.flatten(img_tensor)
        one_hot=torch.zeros(2)
        one_hot[label]=1
        return img_tensor,one_hot
    def __len__(self):
        return len(self.data)

if __name__ == '__main__':
    # 测试
    data=DtataAndLabel()
    print(data[1][0].shape)
    print(data[1][1])

net.py

目的:将输入维度(k(k是加载进去的图片数),88,88,3)三通道的宽高是88,88,通过网络变化为(k,2)。

import torch.nn
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(88*88*3, 800),
            nn.ReLU(),
            nn.Linear(800, 500),
            nn.ReLU(),
            nn.Linear(500, 800),
            nn.ReLU(),
            nn.Linear(800, 200),
            nn.ReLU(),
            nn.Linear(200, 2),

        )
        self.softmax=nn.Softmax(dim=1)
    def forward(self,x):
        x=self.model(x)
        x=self.softmax(x)
        return x
if __name__ == '__main__':
    net=Net()
    #测试一下
    x=torch.randn(1,100*100)
    out=net(x)
    print(out.shape)

test_train.py

目的:将图像丢进模型,然后训练出最优模型

步骤:

       1.定义初始化

                -定义拿到data对象

                -定义加载器分批加载,这里可以变换维度

                -定义初始化网络

                -定义损失函数,这里采用了均方差函数

                -定义优化器

        2.实现训练

                -将每一批数据丢给网络,此时维度发生了变化,产生了升维

                -使用优化器        

                        ---自动梯度清0

                        ---自动求导更新参数

                -计算损失值和准确度

        ·~自己建一个文件夹

import torch.optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from net import Net
from dataset import DtataAndLabel
import torch.nn as nn
class TrainAndTest():
    def __init__(self):
        self.writer = SummaryWriter("logs")
        self.train_data=DtataAndLabel(is_train=True)
        self.test_data=DtataAndLabel(is_train=False)
        #使用加载器分批加载
        self.train_loader=DataLoader(self.train_data,batch_size=10,shuffle=True)
        self.test_loader=DataLoader(self.test_data,batch_size=10,shuffle=True)
        #初始化网络
        #损失函数
        #优化器

        net=Net()
        self.net=net
        self.loss=nn.MSELoss()
        self.opt=torch.optim.Adam(net.parameters(),lr=0.001)
        self.min_loss=100.0
        self.weight_path='weight/best.pt'

    def train(self,epoch):
        sum_loss = 0
        sum_acc = 0
        for img_tensors, targets in tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):
            out = self.net(img_tensors)
            loss = self.loss(out, targets)

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            sum_loss += loss.item()
            pred_cls = torch.argmax(out, dim=1)
            target_cls = torch.argmax(targets, dim=1)
            accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))
            sum_acc += accuracy.item()
        avg_loss = sum_loss / len(self.train_loader)
        avg_acc = sum_acc / len(self.train_loader)
        print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
        self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)
        self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)

    def test(self,epoch):
        sum_loss = 0
        sum_acc = 0
        for img_tensors, targets in tqdm(self.test_loader, desc="test...", total=len(self.test_loader)):
            out = self.net(img_tensors)
            loss = self.loss(out, targets)
            sum_loss += loss.item()
            pred_cls = torch.argmax(out, dim=1)
            target_cls = torch.argmax(targets, dim=1)
            accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))
            sum_acc += accuracy.item()
        avg_loss = sum_loss / len(self.test_loader)
        avg_acc = sum_acc / len(self.test_loader)
        print(f'test:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
        self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)
        self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)
        if avg_loss<self.min_loss:
            self.min_loss=min(self.min_loss,avg_loss)
            torch.save(self.net.state_dict(), self.weight_path)
    def run(self):
        for epo in range(100):
            self.train(epo)
            self.test(epo)

if __name__ == '__main__':
    trainer=TrainAndTest()
    trainer.run()



精度的计算:

                比如通过网络出现的维度是(1,2),其数值是[[0.9 , 0.1]](0.9与0.1表示预测的两个类别的概率)。我们通过maxarg取到其中最大的索引0,与之前真实的标签0或者1做比较。从而可以得出结果

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/792666.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

错误处理在网络爬虫开发中的重要性:Perl示例_引言

错误处理的必要性 在网络爬虫的开发过程中&#xff0c;可能会遇到多种错误&#xff0c;包括但不限于&#xff1a; 网络连接问题服务器错误&#xff08;如404或500错误&#xff09;目标网站结构变化超时问题权限问题 错误处理机制可以确保在遇到这些问题时&#xff0c;爬虫能…

防火墙之安全策略

目录 前言&#xff1a; 一、实验需求 二、需求分析 三、具体操作 1&#xff0c;在安全区域中新建拓扑图中所显示的区域 2&#xff0c;在对象模块中的时间段中添加一个9&#xff1a;00-18&#xff1a;00 3&#xff0c;写两条安全策略&#xff0c;一个是生产区&#xff0c…

记录文字视差背景学习

效果图 文字背景会随鼠标上下移动变成红色或透明 html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><titl…

前端使用pinia中存入的值

导入pinia,创建pinia实例 使用pinia中的值

【Spring Boot】关系映射开发(二):一对多映射

《JPA 从入门到精通》系列包含以下文章&#xff1a; Java 持久层 API&#xff1a;JPA认识 JPA 的接口JPA 的查询方式基于 JPA 开发的文章管理系统&#xff08;CRUD&#xff09;关系映射开发&#xff08;一&#xff09;&#xff1a;一对一映射关系映射开发&#xff08;二&#…

设计已死?AIGC时代创意设计师的“智能设计”和“人智协作”

在当今这个AIGC&#xff08;AI Generated Content&#xff09;时代&#xff0c;有人宣称“设计已死”&#xff0c;认为人工智能将会完全取代设计师的工作。然而&#xff0c;事实真的如此吗&#xff1f;本文将深入探讨AIGC时代下创意设计师的“智能设计”与“人智协作”。 AIGC是…

Elasticsearch:Node.js ECS 日志记录 - Pino

在我的上一篇文章 “Beats&#xff1a;使用 Filebeat 从 Python 应用程序中提取日志” 里&#xff0c;我详述了如何使用 Python 来生成日志&#xff0c;并使用 Filebeat 来收集日志到 Elasticsearch 中。在今天的文章中&#xff0c;我来详细描述如何使用 Node.js 来生成 ECS 相…

什么是数据标准?企业进行数据标准管理有这么多好处?

随着大数据、云计算和人工智能等技术的兴起&#xff0c;企业对数据的依赖程度日益增加。然而&#xff0c;数据的快速增长带来了一系列挑战&#xff0c;其中之一就是如何确保数据的一致性和准确性。这正是数据标准管理发挥作用的地方。 数据标准管理构成了数据管理的核心部分&a…

C语言学习网站推荐

1.C语言中文网&#xff1a; http://c.biancheng.net/view/510.html 此网站会提供很多关于C语言的入门算法&#xff0c;教程等等。分享的C语言知识结构是比较完整的。此网站还提供了很多关于C语言的资料&#xff0c;代码。有些资料是免费的&#xff0c;有些是需要付费的。 2.5…

常用知识碎片 Vue3 ref和reactive (内含其他常用知识)

目录 ref和reactive ref reactive 总结&#xff1a; setup语法糖 语法糖是啥&#xff1f; Vue3 setup语法糖 Vue3 不使用setup语法糖示例&#xff1a; Vue3 使用setup语法糖示例&#xff1a; ref和eative主要区别 ref和reactive 在 Vue 3 中&#xff0c;ref 和 reac…

git仓库使用

一、没有仓库 首先要有gitee账号 创建仓库 有了仓库就按已有仓库进行操作 二、已有仓库 先让仓库负责人把自己拉进仓库 成为开发者或者管理员 git clone 仓库地址 开始工作 git add . git commit -m “ 提交说明” git pull 更新一下也叫同步 将线上代码更新到本地 git pu…

BUUCTF[堆][unsortbin]

fastbin Attack 、unsorted bin 思路&#xff1a; 利用double free的方式泄漏出unsortbin中的main_arena地址。 释放一个不属于fast bin 的 chunk&#xff0c;并且该 chunk 不和 top chunk 紧邻时&#xff0c;该 chunk 会被首先放到 unsorted bin 中。 当有一个(或几个) sma…

BUCK电源芯片,电气参数,极限参数,工作特性,引脚功能

概述 在应用DC-DC开关电源芯片时&#xff0c;通常需要关注以下参数&#xff0c;同步与非同步&#xff0c;输入电压&#xff0c;输入电流&#xff0c;输出电压&#xff0c;输出电流&#xff0c;输入输出电容的选择&#xff1b;mosfet选型&#xff0c;电感选型&#xff0c;功耗&a…

Kodcloud可道云安装与一键发布上线实现远程访问详细教程

文章目录 1.前言2. Kodcloud网站搭建2.1. Kodcloud下载和安装2.2 Kodcloud网页测试 3. cpolar内网穿透的安装和注册4. 本地网页发布4.1 Cpolar云端设置4.2 Cpolar本地设置 5. 公网访问测试6.结语 1.前言 本文主要为大家介绍一款国人自研的在线Web文件管理器可道云&#xff0c;…

[Python学习篇] Python多线程

多线程 Python 多线程编程是一种在单个程序中同时执行多个线程的技术&#xff0c;主要用于提高程序的并发性和性能&#xff0c;尤其是在 I/O 操作频繁的场景下。Python 提供了 threading 模块来支持多线程编程。 基本概念 线程&#xff1a;线程是一个独立的执行流&#xff0c;可…

【SpringBoot Web框架实战教程】08 SpringBoot 自定义异常处理输出

不积跬步&#xff0c;无以至千里&#xff1b;不积小流&#xff0c;无以成江海。大家好&#xff0c;我是闲鹤&#xff0c;公众号&#xff1a;xxh_zone&#xff0c;十多年开发、架构经验&#xff0c;先后在华为、迅雷服役过&#xff0c;也在高校从事教学3年&#xff1b;目前已创业…

基于FPGA的数字信号处理(16)--定点数的舍入模式(7)6种舍入模式的总结

前言 在前面的6篇文章中&#xff0c;分别对6中舍入模式做了详细的介绍&#xff0c;本文在前文的基础上&#xff0c;再对这6种舍入模式做一个对比和总结。 6种舍入模式 在数据处理过程中&#xff0c;为了防止数据溢出而增加位宽是一种很常见的处理方式&#xff0c;但是随着算法链…

中国科学院地理所牛书丽团队《Global Change Biology 》最新成果!

本文首发于“生态学者”微信公众号&#xff01; 在全球气候变化的背景下&#xff0c;干旱地区的扩张对生态系统的氮循环产生了深远影响。氮同位素&#xff08;δ15N&#xff09;的天然丰度&#xff0c;尤其是土壤中的δ15N&#xff0c;是评估陆地生态系统氮循环动态和氮限制的关…

固定资产定位追踪管理系统

固定资产定位跟踪管理系统是一个利用技术手段对公司资产进行定位、跟踪和管理的系统。它帮助企业实时了解资产的位置、情况和使用情况&#xff0c;提高固定资产管理的效率和准确性。以下是该系统的主要功能和优势&#xff1a;  固定资产识别和识别&#xff1a;系统通过使用识…

springboot大学校园二手书交易APP-计算机毕业设计源码25753

摘 要 在数字化与移动互联网迅猛发展的今天&#xff0c;人们对于图书的需求与消费方式也在悄然改变。为了满足广大读者对图书的热爱与追求&#xff0c;我们倾力打造了一款基于Android平台的图书交易APP。这款APP不仅汇聚了海量的图书资源&#xff0c;提供了便捷的交易平台&…