基于安卓的虫害识别软件设计--(1)模型训练与可视化

引言

  • 简介:使用pytorch框架,从模型训练、模型部署完整地实现了一个基础的图像识别项目
  • 计算资源:使用的是Kaggle(每周免费30h的GPU)

1.创建名为“utils_1”的模块

模块中包含:训练和验证的加载器函数训练函数验证函数

import os
import sys

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_train_loader(image_path):
    train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform = train_transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32,
                                              shuffle=True, num_workers= 0)
    return train_loader

def get_val_loader(image_path):
    val_transform = transforms.Compose([transforms.Resize((224,224)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "validation"),
                                       transform = val_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=32,
                                             shuffle = False, num_workers = 0)
    return val_loader

def train(train_loader,net):
    net.train()
    train_correct = 0.0
    train_loss = 0.0  # 初始化训练损失
    train_bar = tqdm(train_loader, file=sys.stdout)
    loss_function = nn.CrossEntropyLoss()
    loss_function = loss_function.to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    for step, data in enumerate(train_bar):
        images, labels = data
        images, labels = images.to(device),labels.to(device)
        # 梯度清零
        optimizer.zero_grad()
        # 训练
        outputs = net(images)
        # 计算损失
        loss = loss_function(outputs, labels)
        # 反向传播
        loss.backward()
        # 更新权重
        optimizer.step()
        # 统计
        _, preds = outputs.max(1)
        correct = preds.eq(labels).sum()
        train_correct += correct
        train_loss += loss.item()  # 累加损失值
        train_bar.desc = 'Training Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(
                loss.item(),
                (100. * correct) / len(outputs),
                trained_samples=step * train_loader.batch_size + len(images),
                total_samples=len(train_loader.dataset))
    train_correct = (100. * train_correct) / len(train_loader.dataset)
    train_loss /= len(train_loader)  # 计算平均损失值
    return train_correct, train_loss  # 返回训练正确率和平均损失值

def val(val_loader,net):
    net.eval()
    val_correct = 0.0
    val_loss = 0.0  # 初始化验证损失
    loss_function = nn.CrossEntropyLoss()
    loss_function = loss_function.to(device)

    val_bar = tqdm(val_loader, file=sys.stdout)
    for step, data in enumerate(val_bar):
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        with torch.no_grad():
            # 验证
            outputs = net(images)
            # 计算损失
            loss = loss_function(outputs, labels)
            # 统计
            _, preds = outputs.max(1)
            correct = preds.eq(labels).sum()
            val_correct += correct
            val_loss += loss.item()  # 累加损失值
            val_bar.desc = 'Valing Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(
                loss.item(),
                (100. * correct) / len(outputs),
                trained_samples=step * val_loader.batch_size + len(images),
                total_samples=len(val_loader.dataset))
    val_correct = (100. * val_correct) / len(val_loader.dataset)
    val_loss /= len(val_loader)  # 计算平均损失值
    return val_correct , val_loss  # 返回验证正确率和平均损失值

注意:若使用Kaggle,想要导入该模块,需要添加以下代码

import sys
sys.path.append(r'/kaggle/input/mycode2')

其中,模块路径如下图


2.主函数 

主函数包含:使用模型函数训练主函数画图代码

2.1使用模型函数 

【若使用其他模型,可chatgpt创建其函数】

(1)resnet101 

def get_resnet101(class_num):
    net_name = "resnet101"
    net = torchvision.models.resnet101(pretrained=True)
    net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input features
    net = net.to(device)
    return net_name, net

(2)resnet34 

def get_resnet34(class_num):
    net_name = "resnet34"
    net = torchvision.models.resnet34(pretrained=True)
    net.fc = Linear(in_features=512, out_features=class_num, bias=True)
    net = net.to(device)
    return net_name,net

(3)mobilenetv2

def get_mobilenet_v2(class_num):
    net_name = "mobilenet_v2"
    net = torchvision.models.mobilenet_v2(pretrained=True)
    net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)
    net = net.to(device)
    return net_name,net

 2.2画图代码 

    save_path="/kaggle/working/"  
  
    plt.figure(figsize=(12, 4))
    # loss
    plt.subplot(1, 2, 1)
    plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')
    plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    # acc
    plt.subplot(1, 2, 2)
    plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')
    plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Acc')
    plt.legend()
    plt.savefig(os.path.join(save_path, 'result.png')) # 保存
    plt.show()

2.3完整代码 

import torch
import torchvision.models
from matplotlib import pyplot as plt
from torch.nn import Linear
import os

# 导入自己创建的模块
from utils_1 import get_train_loader, train, val, get_val_loader

# 模型选择
def get_resnet101(class_num):
    net_name = "resnet101"
    net = torchvision.models.resnet101(pretrained=True)
    net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input features
    net = net.to(device)
    return net_name, net

# def get_resnet34(class_num):
#     net_name = "resnet34"
#     net = torchvision.models.resnet34(pretrained=True)
#     net.fc = Linear(in_features=512, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net

# def get_mobilenet_v2(class_num):
#     net_name = "mobilenet_v2"
#     net = torchvision.models.mobilenet_v2(pretrained=True)
#     net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net

# 训练主函数
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #1 加载数据
    image_path = r"/kaggle/input/fruits3"
    train_loader = get_train_loader(image_path)
    val_loader = get_val_loader(image_path)
    #2 加载模型
    net_name,net = get_resnet34(class_num=5)
    #3 训练
    epochs = 5
    best_acc = 0
    
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    
    for epoch in range(epochs):
        train_acc,train_loss = train(train_loader, net)
        val_acc,val_loss = val(val_loader, net)
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc.item())
        val_accs.append(val_acc.item())
        
        if best_acc<val_acc:
            best_acc = val_acc
            torch.save(net, os.path.join("/kaggle/working/", net_name + ".pt"))
    
    # 画图
    save_path="/kaggle/working/" # 图片保存路径
    
    plt.figure(figsize=(12, 4))
    # loss
    plt.subplot(1, 2, 1)
    plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')
    plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    # acc
    plt.subplot(1, 2, 2)
    plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')
    plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Acc')
    plt.legend()
    plt.savefig(os.path.join(save_path, 'result.png')) # 保存
    plt.show()

2.4训练效果与模型文件

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/665842.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++ 特殊运算符

一 赋值运算符 二 等号作用 三 优先级和结合顺序 四 左值和右值 五 字节数运算符 条件运算符 使用条件运算符注意 逗号运算符 优先级和结合顺序 总结

【C++】问题及补充(2)

string s2“hello word”;是怎么进行隐式类型转换的 在这里&#xff0c;"hello world"是一个C字符串常量&#xff0c;而s2是一个std::string类型的变量。当你将C字符串常量赋值给一个std::string类型的变量时&#xff0c;会发生隐式类型转换。编译器会将C字符串常量转…

Vue常用自定义指令、纪录篇

文章目录 一、元素尺寸发生变化时二、点击元素外自定义指令三、元素拖拽自定义指令四、防抖自定义指令五、节流自定义指令六、权限判断自定义指令 一、元素尺寸发生变化时 使用场景&#xff1a; 当元素的尺寸发生变化时需要去适配一些元素时。 或者在元素尺寸发生变化时要去适配…

下载安装nvm,使用nvm管理node.js版本

目录 一、下载安装nvm&#xff08;windows&#xff09; 二、使用nvm管理node.js版本 &#xff08;1&#xff09;nvm命令行 &#xff08;2&#xff09; 使用nvm管理node.js版本 ①查看nvm版本 ②显示活动的node.js版本 ③列出可供下载的node.js版本 ④安装node.js指定版本 ⑤列出…

19.Redis之集群

1.集群的基本介绍 集群 这个词.广义的集群,只要你是多个机器,构成了分布式系统, 都可以称为是一个"集群"前面主从结构,哨兵模式,也可以称为是"广义的集群”狭义的集群,redis 提供的集群模式, 这个集群模式之下,主要是要解决,存储空间不足的问题(拓展存储空间) …

原生小程序一键获取手机号

1.效果图 2.代码index.wxml <!-- 获取手机号 利用手机号快速填写的功能&#xff0c;将button组件 open-type 的值设置为 getPhoneNumber--><button open-type"getPhoneNumber" bindgetphonenumber"getPhoneNumber">获取手机号</button> …

【再探】设计模式—访问者模式、策略模式及状态模式

访问者模式是用于访问复杂数据结构的元素&#xff0c;对不同的元素执行不同的操作。策略模式是对于具有多种实现的算法&#xff0c;在运行过程中可动态选择使用哪种具体的实现。状态模式是用于具有不同状态的对象&#xff0c;状态之间可以转换&#xff0c;且不同状态下对象的行…

Threejs(WebGL)绘制线段优化:Shader修改gl.LINES模式为gl.LINE_STRIP

目录 背景 思路 Threejs实现 记录每条线的点数 封装原始裁剪索引数据 封装合并几何体的缓冲数据&#xff1a;由裁剪索引组成的 IntArray 守住该有的线段&#xff01; 修改顶点着色器 修改片元着色器 完整代码 WebGL实现类似功能&#xff08;简易版&#xff0c;便于测…

极验4点选逆向 JS逆向分析 最新版验证码

目录 声明&#xff01; 一、请求流程分析 二、加密参数w与payload 三、参数w生成位置 四、结果展示&#xff1a; 原创文章&#xff0c;请勿转载&#xff01; 本文内容仅限于安全研究&#xff0c;不公开具体源码。维护网络安全&#xff0c;人人有责。 声明&#xff01; 本文章…

mirth Connect 自定义JAVA_HOME

mirth Connect 自定义JAVA_HOME 1、背景 服务器上安装了两个不同版本的Java&#xff0c;我希望Mirth服务使用与默认系统不同的版本。自定义指定java版本 2、解决方法 2.1 优先级说明 系统变量JAVA_HOME (设置后&#xff0c;mirth会根据这个进行启动运行服务&#xff0c;优先级…

家政预约小程序10公众号集成

目录 1 使用测试号3 工作流配置4 配置关注事件脚本5 注册开放平台6 获取公众号access_token6 实现关注业务逻辑总结 我们本次实战项目构建的相当于一个预约平台&#xff0c;既有家政企业&#xff0c;也有家政服务人员还有用户。不同的人员需要收到不同的消息&#xff0c;比如用…

根据状态转移图实现时序电路 (三段式状态机)

看图编程 * ** 代码 module seq_circuit(input C ,input clk ,input rst_n,output wire Y ); reg [1:0] current_stage ; reg [1:0] next_stage ; reg Y_reg; //输出//第一段 &#xff1a; 初始化当前状态和…

vmware esxi虚拟化数据迁移

1、启用esxi的ssh 登录esxi的web界面&#xff0c;选择主机-》操作——》服务——》启动ssh 2.xshell登录esxi 3、找到虚拟机所在目录 blog.csdnimg.cn/direct/d57372536a4145f2bcc1189d02cc7da8.png)#### 3在传输数据前需关闭防火墙服务 查看防火墙状态&#xff1a;esxcli …

vue3学习(六)

前言 接上一篇学习笔记&#xff0c;今天主要是抽空学习了vue的状态管理&#xff0c;这里学习的是vuex&#xff0c;版本4.1。学习还没有学习完&#xff0c;里面有大坑&#xff0c;难怪现在官网出的状态管理用Pinia。 一、vuex状态管理知识点 上面的方式没有写全&#xff0c;还有…

如何修改开源项目中发现的bug?

如何修改开源项目中发现的bug&#xff1f; 目录 如何修改开源项目中发现的bug&#xff1f;第一步&#xff1a;找到开源项目并建立分支第二步&#xff1a;克隆分支到本地仓库第三步&#xff1a;在本地对项目进行修改第四步&#xff1a;依次使用命令行进行操作注意&#xff1a;Gi…

OAK相机如何将 YOLOv9 模型转换成 blob 格式?

编辑&#xff1a;OAK中国 首发&#xff1a;oakchina.cn 喜欢的话&#xff0c;请多多&#x1f44d;⭐️✍ 内容可能会不定期更新&#xff0c;官网内容都是最新的&#xff0c;请查看首发地址链接。 Hello&#xff0c;大家好&#xff0c;这里是OAK中国&#xff0c;我是Ashely。 专…

逆天工具一键修复图片,视频去码。简直不要太好用!

今天&#xff0c;我要向您推荐一款功能强大的本地部署软件&#xff0c;它能够在您的计算机上一键修复图片和视频&#xff0c;去除令人不悦的码赛克&#xff08;轻度马赛克&#xff09;。这款软件是开源的&#xff0c;并在GitHub上公开可用&#xff0c;您可以免费下载并使用。 …

智能制造案例专题|与MongoDB一起解锁工业4.0转型与增长的无限潜力!

MongoDB 智能制造 数字化技术的洪流在各个产业链的主干和枝节涌现。在工业制造领域&#xff0c;能否通过数字化技术实现各生产要素、生产环节之间的紧密配合&#xff0c;高效规划、管理整个生产流程&#xff0c;是企业提升韧性、赢得竞争的关键。随着工业4.0的深入发展和智能…

Kafka自定义分区器编写教程

1.创建java类MyPartitioner并实现Partitioner接口 点击灯泡选择实现方法&#xff0c;导入需要实现的抽象方法 2.实现方法 3.自定义分区器的使用 在自定义生产者消息发送时&#xff0c;属性配置上加入自定义分区器 properties.put(ProducerConfig.PARTITIONER_CLASS_CONFIG,&q…

stack和queue(1)

一、stack的简单介绍和使用 1.1 stack的介绍 1.stack是一种容器适配器&#xff0c;专门用在具有先进后出&#xff0c;后进先出操作的上下文环境中&#xff0c;其删除只能从容器的一端进行元素的插入和弹出操作。 2.stack是作为容器适配器被实现的&#xff0c;容器适配器即是…