Pytorch深度学习实践笔记8(b站刘二大人)

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili​

目录

1 Pytorch 数据加载

2 Dataset和DataLoader

3 程序


1 Pytorch 数据加载

  • epoch、Batch-size 、iteration


例如下图:
8个样本、shuffle是打乱样本的顺序,Batch-szie为2,iteration 就是 8 / 2 为4,epoch是训练集进行几个轮次的迭代。

 




2 Dataset和DataLoader

 




Dataset 是一个抽象类,使用时必须进行重写,from 在torch.utils.data Dataset
(1)重写时,需要根据数据来进行构造__init__(self,filepath)
(2)__getitem__(self,index)用来让数据可以进行索引操作
(3)__len__(self)用来获取数据集的大小
DataLoader 用来加载数据为mini-Batch ,支持Batch-size 的设置,shuffle支持数据的打乱顺序。

  • 参数说明:
from torch.utils.data import DataLoader
 
test_load = DataLoader(dataset=test_data, batch_size=4 , shuffle= True, num_workers=0,drop_last=False)


batch_size=4表示每次取四个数据
shuffle= True表示开启数据集随机重排,即每次取完数据之后,打乱剩余数据的顺序,然后再进行下一次取
num_workers=0表示在主进程中加载数据而不使用任何额外的子进程,如果大于0,表示开启多个进程,进程越多,处理数据的速度越快,但是会使电脑性能下降,占用更多的内存
drop_last=False表示不丢弃最后一个批次,假设我数据集有10个数据,我的batch_size=3,即每次取三个数据,那么我最后一次只有一个数据能取,如果设置为true,则不丢弃这个包含1个数据的子集数据,反之则丢弃

 

  • 数据转换为dataset形式,进行DataLoader的使用
x_data = torch.tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0],[7.0],[8.0],[9.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0],[14.0],[16.0],[18.0]])

dataset = Data.TensorDataset(x_data,y_data)

loader = Data.DataLoader(  
    dataset=dataset,  
    batch_size=BATCH_SIZE,  
    shuffle=True,  
    num_workers=0  
)

pytorch中的DataLoader_pytorch dataloader-CSDN博客​


3 程序


数据分为训练集和测试集:Adam 训练

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
 
 
# 读取原始数据,并划分训练集和测试集
raw_data = np.loadtxt('./dataset/diabetes.csv.gz', delimiter=',', dtype=np.float32)
X = raw_data[:, :-1]
Y = raw_data[:, [-1]]
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,Y,test_size=0.1)
Xtest = torch.from_numpy(Xtest)
Ytest = torch.from_numpy(Ytest)
 
# 将训练数据集进行批量处理
# prepare dataset
 
class DiabetesDataset(Dataset):
    def __init__(self, data,label):
 
        self.len = data.shape[0] # shape(多少行,多少列)
        self.x_data = torch.from_numpy(data)
        self.y_data = torch.from_numpy(label)
 
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
 
    def __len__(self):
        return self.len
 
 
train_dataset = DiabetesDataset(Xtrain,Ytrain)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=0) #num_workers 多线程
 
# design model using class
 
 
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 2)
        self.linear4 = torch.nn.Linear(2, 1)
        self.sigmoid = torch.nn.Sigmoid()
 
    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        x = self.sigmoid(self.linear4(x))
        return x
 
 
model = Model()
 
# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
epoch_list = []
loss_list = []

# training cycle forward, backward, update
def train(epoch):
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        y_pred = model(inputs)
 
        loss = criterion(y_pred, labels)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

 
 
def test():
    with torch.no_grad():
        y_pred = model(Xtest)
        y_pred_label = torch.where(y_pred>=0.5,torch.tensor([1.0]),torch.tensor([0.0]))
        acc = torch.eq(y_pred_label, Ytest).sum().item() / Ytest.size(0)
        print("test acc:", acc)
 
if __name__ == '__main__':
    for epoch in range(10000):
        loss_val = train(epoch)
        print("epoch: ",epoch," loss: ",loss_val)
        epoch_list.append(epoch)
        loss_list.append(loss_val)
    test()

    plt.plot(epoch_list,loss_list)
    plt.title("Adam")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("./data/pytorch7_1.png")



简单的程序
 

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
 
# prepare dataset
 
 
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0] # shape(多少行,多少列)
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
 
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
 
    def __len__(self):
        return self.len
 
 
dataset = DiabetesDataset('./dataset/diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) #num_workers 多线程
 
 
# design model using class
 
 
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
 
    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x
 
 
model = Model()
 
# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
 
# training cycle forward, backward, update
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batch
            inputs, labels = data
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
 
            optimizer.zero_grad()
            loss.backward()
 
            optimizer.step()

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/649506.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

上海市港股通交易佣金手续费最低是多少?万0.8?恒生港股通ETF今起发行!港股通的价值如何?

港股通交易佣金概述 港股通的交易佣金可能会因证券公司和投资者的不同而有所差异。 上海市港股通交易佣金最低可以万分之零点八(0.008%),但这需要投资者与证券公司客户经理了解,进行沟通和申请。 一般来说,证券公司…

[CISCN2024]-PWN:orange_cat_diary(glibc2.23.,仅可修改最新堆块,house of orange)

查看保护 查看ida 这里我们仅可以修改最新申请出来的堆块,但是有uaf漏洞。 完整exp: from pwn import* #context(log_leveldebug) pprocess(./orange) free_got0x201F78def alloc(size,content):p.sendlineafter(bPlease input your choice:,b1)p.send…

【Spring Cloud】服务熔断

目录 服务雪崩效应服务雪崩效应形成的原因及应对策略小结 Hystrix介绍Hystrix可以做什么1.资源隔离2.请求熔断3.服务降级 小结 Hystrix实现服务降级方式一:HystrixCommand注解方式1.服务提供者1.1业务接口和业务实现中添加方法hystrixTimeout1.2控制器中处理/provid…

信息安全基础(补充)

)的内容主要有数据备份、数据修复、系统恢复等。响应(Respons)的内容主要有应急策略、应急机制、应急手段、入侵过程分析及安全状态评估等。 面向数据挖掘的隐私保护技术主要解决高层应用中的隐私保护问题,致力于研究如何根据不同…

html5各行各业官网模板源码下载(2)

文章目录 1.来源2.源码模板2.1 HTML5好看的旅行网站模板源码2.2 HTML5自适应医院叫号大屏模板源码2.3 HTML5好看的高科技登录页面模板源码2.4 HTML5宠物美容服务公司网站模板源码2.5 HTML5创意品牌广告设计公司网站模板源码2.6 HTML5实现室内设计模板源码2.7 HTML5黄金首饰网站…

PL5358A 单芯锂离子/聚合物电池保护IC芯片

一般说明 PL5358A系列产品是锂离子/聚合物电池保护的高集成解决方案。PL5358A包含先进 的功率MOSFET,高精度电压检测电路和延迟电路。5358A被放入一个超小的SOT23-5封装,只有一个外部元件,使其成为理想的解决方案,在有限的…

美业SaaS收银系统源码-已过期卡项需要延期怎么操作?美业系统实操

美业SaaS系统 连锁多门店美业收银系统源码 多门店管理 / 会员管理 / 预约管理 / 排班管理 / 商品管理 / 促销活动 PC管理后台、手机APP、iPad APP、微信小程序 1.询问会员手机号和需要延期的卡项 2.PC运营后端-数据导入-修改已售卡项,搜索手机号 3.把需要卡项选…

D - Permutation Subsequence(AtCoder Beginner Contest 352)

题目链接: D - Permutation Subsequence (atcoder.jp) 题目大意: 分析: 相对于是记录一下每个数的位置 然后再长度为k的区间进行移动 然后看最大的pos和最小的pos的最小值是多少 有点类似于滑动窗口 用到了java里面的 TreeSet和Map TreeSet存的是数…

Vue3中如何自定义消息总线

前言 在 Vue 开发中,组件之间的通信是一个常见的需求,无论是父组件向子组件传递数据,还是子组件向父组件传递数据,甚至是兄弟组件之间的数据交换。这些通信需求在构建复杂的 Vue 应用时尤为关键。 Vue 提供了多种组件通信的方式…

Linux系统安装AMH服务器管理面板并实现远程访问管理维护

目录 前言 1. Linux 安装AMH 面板 2. 本地访问AMH 面板 3. Linux安装Cpolar 4. 配置AMH面板公网地址 5. 远程访问AMH面板 6. 固定AMH面板公网地址 1. 部署Docker Registry 2. 本地测试推送镜像 3. Linux 安装cpolar 4. 配置Docker Registry公网访问地址 5. 公网远程…

如何被谷歌收录?

最简单的方法就是提交网站给谷歌,但这种方法可操作空间不大,一天一般也就只有十条左右的链接可以提交,对于一些大网站来说,这种方法显然不适用,这时候GPC爬虫池的好处就体现了,GPC爬虫池对希望提升Google搜…

re:记录下正则的使用方法

1、match pattern r(\d{4})[-\/](\d{1,2})[-\/](\d{1,2}) match re.search(pattern, text) if match:year, month, day match.groups()

SolidWorks教育版 学生使用的优势

在工程技术领域的学习中,计算机辅助设计软件(CAD)如SolidWorks已经成为学生掌握专业知识和技能的必要工具。SolidWorks教育版作为专为教育机构和学生设计的版本,不仅提供了与商业版相同的强大功能,还为学生带来了诸多独…

TypeScript 语言在不改变算法复杂度前提下,细节上性能优化,运行时性能提升效果明显吗?

有经验的专家写的代码,和无经验的新手写的代码,在运行时性能上大概会有多少差异? 个人感觉,常规业务逻辑代码通常可以差 1 倍;如果算上框架的影响,可以差 2~4 倍。 仅考虑业务代码的话,新手容易…

科普:水冷负载的工作原理

水冷负载是一种利用水作为冷却介质,将电子设备产生的热量传递到外部环境的散热方式。它广泛应用于各种电子设备,如服务器、数据中心、电力设备等,以提高设备的运行效率和稳定性。本文将对水冷负载的工作原理进行简要科普。 水冷负载的工作原理…

我被恐吓了,对方扬言要压测我的网站

大家好我是聪,昨天真是水逆,在技术群里交流问题,竟然被人身攻击了!骂的话太难听具体就不加讨论了,人身攻击我可以接受,我接受不了他竟然说要刷我接口!!!!这下…

C#多维数组不同读取方式的性能差异

背景 近来在优化一个图像显示程序,图像数据存储于一个3维数组data[x,y,z]中,三维数组为一张张图片数据的叠加而来,其中x为图片的张数,y为图片行,Z为图片的列,也就是说这个三维数组存储的为一系列图片的数据…

记录下所遇到远程桌面连接方法winSCP跟mstsc

之前公司使用过连接远程桌面,今天又遇到要使用远程桌面问题,来记录下。 之前公司使用的是winR 然后回车弹出 后面按照用户名密码就能登陆了 今天后台给了我一张图片准备接着用这个方法,后台就说这个东西要下载winSCP 后台发给我图片 然后去…

mfc140u.dll丢失的解决方法有哪些?怎么全面修复mfc140u.dll文件

mfc140u.dll丢失其实相对来说不太常见到,因为这个文件一般是不丢失的,不过既然有人遇到这种问题,那么小编一定满足各位,给大家详细的唠叨一下mfc140u.dll丢失的各种解决方法,教大家以最快最有效率的方法去解决mfc140u.…

python文件IO基础知识

目录 1.open函数打开文件 2.文件对象读写数据和关闭 3.文本文件和二进制文件的区别 4.编码和解码 读写文本文件时 读写二进制文件时 5.文件指针位置 6.文件缓存区与flush()方法 1.open函数打开文件 使用 open 函数创建一个文件对象,read 方法来读取数据&…