一个简单的图像分类项目(六)编写脚本:初步训练

训练的脚本 ,用于训练和测试。lib.train.py: 

import time

from load_imags import train_loader, train_num
from nets import *


def main():
    # 定义网络
    print('Please choose a network:')
    print('1. ResNet18')
    print('2. VGG')

    # 选择网络
    while True:
        net_choose = input('')
        if net_choose == '1':
            net = resnet18_model().to(device)
            print('You choose ResNet18,now start training')
            break
        elif net_choose == '2':
            net = vgg_model().to(device)
            print('You choose VGG,now start training')
            break
        else:
            print('Please input a correct number!')

    # 定义损失函数和优化器
    loss_func = nn.CrossEntropyLoss()  # 交叉熵损失函数
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)  # 优化器使用Adam
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=5,
                                                gamma=0.9)  # 学习率衰减, 每5个epoch,学习率乘以0.9

    # 训练模型
    for epoch in range(num_epoches):
        trained_num = 0  # 记录训练过的图片数量
        total_correct = 0  # 记录正确数量
        print('-' * 100)
        print('Epoch {}/{}'.format(epoch + 1, num_epoches))
        begin_time = time.time()  # 记录开始时间
        net.train()  # 训练模式
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)  # 每batch_size个图像的数据
            labels = labels.to(device)  # 每batch_size个图像的标签
            trained_num += images.size(0)  # 记录训练过的图片数量
            outputs = net(images)  # 前向传播
            loss = loss_func(outputs, labels)  # 计算损失
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 反向传播
            optimizer.step()  # 优化器更新参数

            _, predicted = torch.max(outputs.data, 1)  # 预测结果
            correct = predicted.eq(labels).cpu().sum()  # 计算本batch_size的正确数量
            total_correct += correct  # 记录正确数量
            if (i + 1) % 50 == 0:  # 每50个batch_size打印一次
                print('trained: {}/{}'.format(trained_num, train_num))
                print('Loss: {:.4f}, Accuracy: {:.2f}%'.format(loss.item(), 100 * correct / images.size(0)))

        # 每5个epoch,学习率衰减
        scheduler.step()
        end_time = time.time()  # 记录结束时间
        print('Each train_epoch take time: {} s'.format(end_time - begin_time))
        print('This train_epoch accuracy: {:.2f}%'.format(100 * total_correct / train_num))


if __name__ == '__main__':
    main()
C:\Users\DY\.conda\envs\torch\python.exe E:\AI_test\image_classification\lib\train.py 
Please choose a network:
1. ResNet18
2. VGG
2
You choose VGG,now start training
----------------------------------------------------------------------------------------------------
Epoch 1/100
trained: 6400/50000
Loss: 2.3902, Accuracy: 10.16%
trained: 12800/50000
Loss: 2.3063, Accuracy: 11.72%
trained: 19200/50000
Loss: 2.1875, Accuracy: 18.75%
trained: 25600/50000
Loss: 2.1349, Accuracy: 19.53%
trained: 32000/50000
Loss: 1.9848, Accuracy: 26.56%
trained: 38400/50000
Loss: 2.0000, Accuracy: 16.41%
trained: 44800/50000
Loss: 2.0151, Accuracy: 25.78%
Each train_epoch take time: 71.04850149154663 s
This train_epoch accuracy: 19.34%
----------------------------------------------------------------------------------------------------
Epoch 2/100
trained: 6400/50000
Loss: 1.8815, Accuracy: 28.12%
trained: 12800/50000
Loss: 1.8677, Accuracy: 34.38%
trained: 19200/50000
Loss: 1.7808, Accuracy: 39.06%
trained: 25600/50000
Loss: 1.9118, Accuracy: 29.69%
trained: 32000/50000
Loss: 1.6296, Accuracy: 39.84%
trained: 38400/50000
Loss: 1.6648, Accuracy: 35.94%
trained: 44800/50000
Loss: 1.7854, Accuracy: 33.59%
Each train_epoch take time: 66.71016025543213 s
This train_epoch accuracy: 33.65%
----------------------------------------------------------------------------------------------------
Epoch 3/100
trained: 6400/50000
Loss: 1.4987, Accuracy: 44.53%
trained: 12800/50000
Loss: 1.6677, Accuracy: 41.41%
trained: 19200/50000
Loss: 1.6952, Accuracy: 43.75%
trained: 25600/50000
Loss: 1.6941, Accuracy: 38.28%
trained: 32000/50000
Loss: 1.4057, Accuracy: 49.22%
trained: 38400/50000
Loss: 1.5183, Accuracy: 44.53%
trained: 44800/50000
Loss: 1.6591, Accuracy: 37.50%
Each train_epoch take time: 68.37232995033264 s
This train_epoch accuracy: 41.65%
----------------------------------------------------------------------------------------------------
Epoch 4/100
trained: 6400/50000
Loss: 1.6636, Accuracy: 43.75%
trained: 12800/50000
Loss: 1.5985, Accuracy: 42.19%
trained: 19200/50000
Loss: 1.4054, Accuracy: 52.34%
trained: 25600/50000
Loss: 1.4520, Accuracy: 40.62%
trained: 32000/50000
Loss: 1.4574, Accuracy: 46.09%
trained: 38400/50000
Loss: 1.4711, Accuracy: 42.19%
trained: 44800/50000
Loss: 1.4806, Accuracy: 43.75%
Each train_epoch take time: 68.32443571090698 s
This train_epoch accuracy: 46.48%
----------------------------------------------------------------------------------------------------
Epoch 5/100
trained: 6400/50000
Loss: 1.2265, Accuracy: 57.03%
trained: 12800/50000
Loss: 1.3454, Accuracy: 52.34%
trained: 19200/50000
Loss: 1.3527, Accuracy: 49.22%
trained: 25600/50000
Loss: 1.2874, Accuracy: 53.12%
trained: 32000/50000
Loss: 1.3666, Accuracy: 55.47%
trained: 38400/50000
Loss: 1.4465, Accuracy: 50.00%
trained: 44800/50000
Loss: 1.2802, Accuracy: 52.34%
Each train_epoch take time: 68.22098922729492 s
This train_epoch accuracy: 50.72%
----------------------------------------------------------------------------------------------------
Epoch 6/100
trained: 6400/50000
Loss: 1.3402, Accuracy: 51.56%
trained: 12800/50000
Loss: 1.2873, Accuracy: 53.91%
trained: 19200/50000
Loss: 1.3183, Accuracy: 52.34%
trained: 25600/50000
Loss: 1.3688, Accuracy: 48.44%
trained: 32000/50000
Loss: 1.2143, Accuracy: 55.47%
trained: 38400/50000
Loss: 1.2132, Accuracy: 56.25%
trained: 44800/50000
Loss: 1.3172, Accuracy: 53.12%
Each train_epoch take time: 68.76534986495972 s
This train_epoch accuracy: 54.53%
----------------------------------------------------------------------------------------------------
Epoch 7/100
trained: 6400/50000
Loss: 1.3156, Accuracy: 53.12%
trained: 12800/50000
Loss: 1.1412, Accuracy: 60.16%
trained: 19200/50000
Loss: 1.1978, Accuracy: 57.03%
trained: 25600/50000
Loss: 1.0312, Accuracy: 55.47%
trained: 32000/50000
Loss: 1.3486, Accuracy: 50.00%
trained: 38400/50000
Loss: 1.1591, Accuracy: 60.16%
trained: 44800/50000
Loss: 1.0707, Accuracy: 63.28%
Each train_epoch take time: 68.1180489063263 s
This train_epoch accuracy: 56.99%
----------------------------------------------------------------------------------------------------
Epoch 8/100

看得出,模型是在逐步收敛的。下一步,完善训练脚本,加入测试的代码。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/903117.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C++】How the C++ Compiler Works

Firstly it needs to pre-process our code which means that any pre-processor statements get evaluated and once our code has been pre-processed we move on to more or less tokenizing(记号化) and parsing(解析) and basically sorting out(整理) this English C lan…

第2次CCF CSP认证真题解

1、相邻数对 题目链接&#xff1a;https://sim.csp.thusaac.com/contest/2/problem/0 本题和第1次认证的第1题“相反数”差不多&#xff0c;都是考察循环遍历比较和计数。 100分代码&#xff1a; #include <iostream> using namespace std; int main(int argc, char …

一款强大的开源OCR工具,支持90+语言识别

大家好&#xff0c;今天给大家分享一款功能强大的开源光学字符识别&#xff08;OCR&#xff09;工具Surya OCR&#xff0c;它基于先进的深度学习技术&#xff0c;提供了高效的字符识别能力&#xff0c;并支持多种语言的文本检测与识别。 项目介绍 核心功能 1.多语言支持 Sur…

破局:DLinear

1. Introduction (1) time series forecasting (TSF)&#xff1b; (2) 回顾 “ Transformer (Vaswani et al. 2017) ” 的各领域优秀表现&#xff1a; (3) IMS vs. DMS : → Consequently, IMS forecasting is preferable when there is a highly-accurate single-step fore…

量化交易打怪升级全攻略

上钟&#xff01; 继续分享量化干货~ 这次要唠的是Stat Arb的新作《Quant Roadmap》(中译名《量化交易路线图》)&#xff0c;为了方便&#xff0c;下文就称呼作者为“老S”&#xff0c;根据公开资料显示&#xff0c;他可是正儿八经的的量化研究员出身&#xff0c;在漂亮国头部对…

【electron8】electron实现“图片”的另存为

注&#xff1a;该列出的代码&#xff0c;都在文章内示例出 1. 另存为按钮事件&#xff1a; const saveAsHandler async () > {const { path, sessionId } recordInfoif(typeof message ! string) return;// 因为我的图片是加密的&#xff0c;所以我需要根据接口返回的路…

全国智能手机使用数据集-dta格式(包含2015、2017、2019三个版本)

数据简介&#xff1a;为推动经济社会高质量发展&#xff0c;缓解经济下行压力&#xff0c;中国加大推动5G基建、大数据中心等科技领域基础设施的建设和完善。数字技术深入各行各业&#xff0c;催生了新业态、新模式、新机会和新就业形式。智能手机作为劳动者使用数字技术的重要…

二叉树的存储方式和遍历方式

文章目录 二叉树的存储方式二叉树的遍历方式DFS--递归遍历DFS--迭代遍历BFS--层次遍历 LC102 二叉树的存储方式 链式存储&#xff08;指针&#xff09;或 顺序存储&#xff08;数组&#xff09; (1)链式存储&#xff1a;通过指针把分布在各个地址的节点串联一起。 (2)顺序存储…

docker上传离线镜像包到Artifactory

docker上传离线镜像包到Artifactory 原创 大阳 北京晓数神州科技有限公司 2024年10月25日 17:33 北京 随着docker官方源的封禁&#xff0c;最近国内资源也出现无法拉取的问题&#xff0c;Artifactory在生产环境中&#xff0c;很少挂外网代理去官方源拉取&#xff0c;小编提供…

大模型面试-Layer normalization篇

1. Layer Norm 的计算公式写一下&#xff1f; 2. RMS Norm 的计算公式写一下&#xff1f; 3. RMS Norm 相比于 Layer Norm 有什么特点&#xff1f; 4. Deep Norm 思路&#xff1f; 5. 写一下 Deep Norm 代码实现&#xff1f; 6.Deep Norm 有什么优点&#xff1f; 7.LN 在 LLMs …

每日一题之电话号码的字母组合

给定一个仅包含数字 2-9 的字符串&#xff0c;返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下&#xff08;与电话按键相同&#xff09;。注意 1 不对应任何字母。 示例 1&#xff1a; 输入&#xff1a;digits "23" 输出&#…

微信小程序学习实录11:精通表单数据绑定,构建高效用户界面

微信小程序中的表单数据绑定是一种非常实用的功能&#xff0c;它允许开发者将页面上的表单元素与数据进行关联&#xff0c;从而实现数据的双向绑定。这样做的好处是能够简化代码&#xff0c;提高开发效率&#xff0c;并且让数据管理变得更加直观。 一、基本概念 数据绑定&am…

Spring Cloud +UniApp智慧工地源码,智慧工地综合解决方案,建筑工程云平台源码

Spring Cloud UniApp智慧工地源码&#xff0c;智慧工地全套源代码包含&#xff1a;PC端大屏端移动端 智慧工地解决方案以工程建设现场管理需求为主线&#xff0c;以AI、物联网、BIM技术为手段&#xff0c;对施工现场进行立体化、全方位、全时段管理&#xff0c;实现规范施工管…

解决VMware虚拟机的字体过小问题

前言&#xff1a; &#xff08;1&#xff09;先装VMware VMware17Pro虚拟机安装教程(超详细)-CSDN博客 &#xff08;2&#xff09;通过清华等镜像网站安装好Ubuntu镜像&#xff0c;下面贴上链接 教程虚拟机配置我没有做&#xff0c;因为学校给了现成的虚拟机~~大家需要的自己…

数据结构之单链表——考研笔记

文章目录 一.单链表定义1.什么是单链表2.代码实现3.不带头结点的单链表4.带头结点的单链表 二.单链表插入删除1.按位序插入&#xff08;带头结点&#xff09;2.插入时不带头节点3.指定节点的后插操作4.指定节点的前插操作5.按位序删除&#xff08;带头结点&#xff09;6.删除指…

2024年【北京市安全员-A证】找解析及北京市安全员-A证考试试卷

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 北京市安全员-A证找解析考前必练&#xff01;安全生产模拟考试一点通每个月更新北京市安全员-A证考试试卷题目及答案&#xff01;多做几遍&#xff0c;其实通过北京市安全员-A证证考试很简单。 1、【多选题】《中华人…

保姆级教程 | 全流程免费:合并多份长宽不同的PDF成相同大小并进行瘦身

背景 由于老板需要&#xff0c;完成不同PDF文件&#xff08;a&#xff0c;b&#xff0c;c....&#xff09;合并&#xff0c;同时要求主文件&#xff08;A&#xff09;小于6M。合并过程中发现各个PDF大小&#xff08;长宽&#xff09;并不相同&#xff0c;造成合并后效果不好也…

网站安全问题都有哪些,分别详细说明

网站安全问题涉及多个方面&#xff0c;以下是一些常见的网站安全问题及其详细说明&#xff1a; 数据泄露 问题描述&#xff1a;数据泄露是指网站存储的用户敏感信息&#xff08;如用户名、密码、信用卡信息等&#xff09;被非法获取。黑客可能通过SQL注入、XSS攻击等手段窃取这…

Unity编辑器界面及其基础功能介绍

文章目录 Unity编辑器界面编辑器默认界面布局打开和关闭编辑界面自定义界面布局Unity资源商店Unity Assets Store什么是资源商店&#xff1f;资源商店中包含哪些东西&#xff1f;如何进行素材导入&#xff1f;Unity官网购买素材或插件导入方法非官网素材导入非官网插件导入 Sce…

【WRF数据准备】基于GEE下载静态地理数据-叶面积指数LAI及绿色植被率Fpar

【WRF数据准备】基于GEE下载静态地理数据 准备:WRF所需静态地理数据(Static geographical data)数据范围说明基于GEE下载叶面积指数及绿色植被率GEE数据集介绍数据下载:LAI(叶面积指数)和Fpar(绿色植被率)数据处理:基于Python处理为单波段LAI数据参考GEE的介绍可参见另…