python模型训练

目录

1、新建模型   train_model.py

2、运行模型

(1)首先会下载data文件库

(2)完成之后会开始训练模型(10次)

3、 训练好之后,进入命令集

 4、输入命令:python -m tensorboard.main --logdir="C:\Users\15535\Desktop\day6\train"

(1)目录的绝对路径获得方法

 5、打开网页可视化图形

(1)运行完之后会自动有一个网址,点进去

 (2)显示


1、新建模型   train_model.py

import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.nn import CrossEntropyLoss


#step1.下载数据集

train_data=datasets.CIFAR10('./data',train=True,\
                            transform=torchvision.transforms.ToTensor(),
                            download=True)
test_data=datasets.CIFAR10('./data',train=False,\
                            transform=torchvision.transforms.ToTensor(),
                            download=True)

print(len(train_data))
print(len(test_data))


#step2.数据集打包
train_data_loader=DataLoader(train_data,batch_size=64,shuffle=False)
test_data_loader=DataLoader(test_data,batch_size=64,shuffle=False)

#step3.搭建网络模型

class My_Module(nn.Module):
    def __init__(self):
        super(My_Module,self).__init__()
        #64*32*32*32
        self.conv1=nn.Conv2d(in_channels=3,out_channels=32,\
                             kernel_size=5,padding=2)

        #64*32*16*16
        self.maxpool1=nn.MaxPool2d(2)

        #64*32*16*16
        self.conv2=nn.Conv2d(in_channels=32,out_channels=32,\
                             kernel_size=5,padding=2)

        #64*32*8*8
        self.maxpool2=nn.MaxPool2d(2)

        #64*64*8*8
        self.conv3=nn.Conv2d(in_channels=32,out_channels=64,\
                             kernel_size=5,padding=2)

        #64*64*4*4
        self.maxpool3=nn.MaxPool2d(2)

        #线性化
        self.flatten=nn.Flatten()
        self.linear1=nn.Linear(in_features=1024,out_features=64)
        self.linear2=nn.Linear(in_features=64,out_features=10)

    def forward(self,input):
        #input:64,3,32,32
        output1=self.conv1(input)
        output2=self.maxpool1(output1)
        output3=self.conv2(output2)
        output4=self.maxpool2(output3)
        output5=self.conv3(output4)
        output6=self.maxpool3(output5)
        output7=self.flatten(output6)
        output8=self.linear1(output7)
        output9=self.linear2(output8)

        return output9


my_model=My_Module()
# print(my_model)
loss_func=CrossEntropyLoss()#衡量模型训练的过程(输入输出之间的差值)
#优化器,lr越大模型就越“聪明”
optim = torch.optim.SGD(my_model.parameters(),lr=0.001)

writer=SummaryWriter('./train')
#################################训练###############################
for looptime in range(10):             #模型训练的次数:10
    print("------looptime:{}------".format(looptime+1))
    num=0
    loss_all=0
    for data in (train_data_loader):
        num+=1
        #前向
        imgs, targets = data
        output = my_model(imgs)
        loss_train = loss_func(output,targets)
        loss_all=loss_all+loss_train
        if num%100==0:
            print(loss_train)

        #后向backward 三步法  获取最小的损失函数
        optim.zero_grad()
        loss_train.backward()
        optim.step()

        # print(output.shape)
    loss_av=loss_all/len(test_data_loader)
    print(loss_av)
    writer.add_scalar('train_loss',loss_av,looptime)
    writer.close()
#################################验证#########################
    with torch.no_grad():
        accuracy=0
        test_loss_all=0
        for data in test_data_loader:
            imgs,targets = data
            output = my_model(imgs)
            loss_test = loss_func(output,targets)
            #output.argmax(1)---输出标签
            accuracy=(output.argmax(1)==targets).sum()

            test_loss_all = test_loss_all+loss_test
            test_loss_av = test_loss_all/len(test_data_loader)
            acc_av = accuracy/len(test_data_loader)

        print("测试集的平均损失{},测试集的准确率{}".format(test_loss_av,acc_av))
        writer.add_scalar('test_loss',test_loss_av,looptime)
        writer.add_scalar('acc',acc_av,looptime)

writer.close()

2、运行模型

(1)首先会下载data文件库

(2)完成之后会开始训练模型(10次)

3、 训练好之后,进入命令集

 4、输入命令:python -m tensorboard.main --logdir="C:\Users\15535\Desktop\day6\train"

(1)目录的绝对路径获得方法

执行下面的操作自动复制

 

 

 5、打开网页可视化图形

(1)运行完之后会自动有一个网址,点进去

 (2)显示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/421562.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

解决Unable to load class ‘org.gradle.api.attributes.VerificationType‘

在使用AdnroidStudio开发过程中难免会遇到Unable to load class org.gradle.api.attributes.VerificationType报错,可以尝试清理缓存重启解决 打开 File-》Invalidate Caches... 重启AndroidStudio后,重新加载即可,但也不是百分百解决。

java数据结构与算法刷题-----LeetCode437. 路径总和 III(前缀和必须掌握)

java数据结构与算法刷题目录(剑指Offer、LeetCode、ACM)-----主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/123063846 文章目录 1. 深度优先2. 前缀和 1. 深度优先 解题思路:时间复…

leetcode刷题(javaScript)——链表相关场景题总结

链表中的元素在内存中不是顺序存储的,而是通过next指针联系在一起的。常见的链表有单向链表、双向链表、环形链表等 在 JavaScript 刷题中涉及链表的算法有很多,常见的包括: 1. 遍历链表:从头到尾遍历链表,处理每个节点…

08、关于语法:resp?.data?.data 的含义与实际操作中可能遇到的问题

1、数据情况: 其一、从后端拿到的数据为: let resp.data {"data": [],"lag_mode": 3,"totol": 0 }或: let resp.data {"data": [],"totol": 0 }其二、目标数据为: // 想要…

1小时网络安全事件报告要求,持安零信任如何帮助用户应急响应?

12月8日,国家网信办起草发布了《网络安全事件报告管理办法(征求意见稿)》(以下简称“办法”)。拟规定运营者在发生网络安全事件时应当及时启动应急预案进行处置。 1小时报告 按照《网络安全事件分级指南》&#xff0c…

命令行启动mongodb服务器的问题及解决方案 -- Unrecognized option: storage.journal

目录 mongodb命令行启动问题 -- Unrecognized option: storage.journal问题日志:问题截图:问题来源:错误原因:解决方式: mongodb命令行启动问题 – Unrecognized option: storage.journal 同样是格式出问题的问题分析和…

ThreadLocal 为什么会内存泄漏吗?是怎么产生的?

ThreadLocal是什么 ThreadLocalMap 如何避免泄漏 ThreadLocal是什么 ThreadLocal是一个本地线程副本变量工具类。主要用于将私有线程和该线程存放的副本对象做一个映射,各个线程之间的变量互不干扰,在高并发场景下,可以实现无状态的调用&…

WPF 滑动条样式

效果图&#xff1a; 浅色&#xff1a; 深色&#xff1a; 滑动条部分代码&#xff1a; <Style x:Key"RepeatButtonTransparent" TargetType"{x:Type RepeatButton}"><Setter Property"OverridesDefaultStyle" Value"true"/&g…

[攻防世界]-Web:fileinclude解析(文件包含,添加后缀)

查看网页 查看源代码 意思就是&#xff0c;如果变量lan被设置就会触发文件包含。 但是要注意&#xff0c;这里的文件包含会自动加上后缀&#xff0c;所以payload要注意一点 payload&#xff1a; languagephp://filter/readconvert.base64-encode/resourceflag

基带信号处理设计原理图:2-基于6U VPX的双TMS320C6678+Xilinx FPGA K7 XC7K420T的图像信号处理板

基于6U VPX的双TMS320C6678Xilinx FPGA K7 XC7K420T的图像信号处理板 综合图像处理硬件平台包括图像信号处理板2块&#xff0c;视频处理板1块&#xff0c;主控板1块&#xff0c;电源板1块&#xff0c;VPX背板1块。 一、板卡概述 图像信号处理板包括2片TI 多核DSP处理…

考取ORACLE数据库OCP的必要性 Oracle数据库

OCP证书是什么&#xff1f; OCP&#xff0c;全称Oracle Certified Professional&#xff0c;是Oracle公司的Oracle数据库DBA&#xff08;Database Administrator&#xff0c;数据库管理员)认证课程。这是Oracle公司针对数据库管理领域设立的一项认证课程&#xff0c;旨在评估和…

(Sora模型风口)2024最新GPT4.0使用教程,AI绘画,一站式解决

一、前言 ChatGPT3.5、GPT4.0、GPT语音对话、Midjourney绘画&#xff0c;文档对话总结DALL-E3文生图&#xff0c;相信对大家应该不感到陌生吧&#xff1f;简单来说&#xff0c;GPT-4技术比之前的GPT-3.5相对来说更加智能&#xff0c;会根据用户的要求生成多种内容甚至也可以和…

云上攻防-云原生篇Docker安全权限环境检测容器逃逸特权模式危险挂载

知识点: 1、云原生-Docker安全-容器逃逸&特权模式 2、云原生-Docker安全-容器逃逸&挂载Procfs 3、云原生-Docker安全-容器逃逸&挂载Socket 4、云原生-Docker安全-容器逃逸条件&权限高低 章节点&#xff1a; 云场景攻防&#xff1a;公有云&#xff0c;私有云&…

GIT分支管理与远程操作

文章目录 10.分支操作-分支介绍(掌握)目标内容小结 11.分支操作-分支创建与切换目标内容小结 12.分支操作-分支合并与删除目标内容小结 13.GIT远程仓库介绍与码云仓库注册创建目标内容小结 14.GIT远程仓库操作-关联、拉取、推送、克隆(不用刻意记住命令)目标内容小结 10.分支操…

Zynq—AD9238数据采集DDR3缓存千兆以太网发送实验(一)

Zynq—AD9238数据采集DDR3缓存千兆以太网发送实验&#xff08;前导&#xff09; 四、AXI转FIFO接口模块设计 1.AXI接口知识 AXI协议是基于 burst的传输&#xff0c;并且定义了以下 5 个独立的传输通道&#xff1a; 读地址通道&#xff08;Read Address Channel&#xff0c; …

python统计分析——广义线性模型的评估

参考资料&#xff1a;用python动手学统计学 残差是表现数据与模型不契合的程度的重要指标。 1、导入库 # 导入库 # 用于数值计算的库 import numpy as np import pandas as pd import scipy as sp from scipy import stats # 导入绘图的库 import matplotlib.pyplot as plt i…

加密与安全_探索非对称加密算法_RSA算法

文章目录 Pre主流的非对称加密算法典型算法&#xff1a;RSACodeRSA的公钥和私钥的恢复小结 Pre 加密与安全_探索密钥交换算法&#xff08;Diffie-Hellman算法&#xff09; 中我们可以看到&#xff0c;公钥-私钥组成的密钥对是非常有用的加密方式&#xff0c;因为公钥是可以公开…

【Vue】npm run build 打包报错:请在[.env.local]中填入key后方可使用...

报错如下 根目录添加 .env.local 文件 .env.local &#xff1a;本地运行下的配置文件 配置&#xff1a;VUE_GITHUB_USER_NAME 及 VUE_APP_SECRET_KEY 原因

通过GitHub探索Python爬虫技术

1.检索爬取内容案例。 2.找到最近更新的。(最新一般都可以直接运行) 3.选择适合自己的项目&#xff0c;目前测试下面画红圈的是可行的。 4.方便大家查看就把代码粘贴出来了。 #图中画圈一代码 import requests import os import rewhile True:music_id input("请输入歌曲…

文本多分类

还在用BERT做文本分类&#xff1f;分享一套基于预训练模型ERNIR3.0的文本多分类全流程实例【文本分类】_ernir 文本分类-CSDN博客 /usr/bin/python3 -m pip install --upgrade pip python3-c"import platform;print(platform.architecture()[0]);print(platform.machine…