pytorch 数据集处理以及模型训练

1.基础类说明       

        为了统一数据的加载和处理代码,pytorch提供了两个类,用来处理数据加载:

torch.utils.data.DataLoader
torch.utils.data.Dataset

        通过这两个类,可以使数据集加载和预处理代码,与模型训练代码脱钩,是代码模块化和可读性更高,DataLoader具有乱序和批次输出的功能非常实用,,datasettensor数据容器,DataLoader可以批量乱序输出dataset容器里面的数据,举例说明DataLoaderTensorDataset(TensorDataset继承自Dataset)。

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

#举例使用方法:

x = np.random.randn(100)
y = 100*x+10

x = torch.from_numpy(x)
y = torch.from_numpy(y)

ds = TensorDataset(x,y)
print(ds)
dl = DataLoader(ds,batch_size=10)
print(dl)
for a,s in dl:
	print(a,s)
#x,y = next(iter(s))#使用迭代方式访问
'''
打印输出
<torch.utils.data.dataset.TensorDataset object at 0x000001C25C8B0A50>
<torch.utils.data.dataloader.DataLoader object at 0x000001C255F36190>
tensor([[ 0.4691],
        [ 0.6049],
        [ 0.5738],
        [-1.0429],
        [ 1.1271],
        [-0.0227],
        [-0.7122],
        [-0.0492],
        [-0.3878],
        [ 0.2161]], dtype=torch.float64) tensor([[ 56.9061],
        [ 70.4902],
        [ 67.3768],
        [-94.2906],
        [122.7052],
        [  7.7305],
        [-61.2175],
        [  5.0850],
        [-28.7823],
        [ 31.6116]], dtype=torch.float64)
tensor([[-0.9782],
        [ 1.2216],
        [-1.1242],
        [-1.2297],
        [-0.1155],
        [-0.4263],
        [-0.3141],
        [-0.2565],
        [-1.0121],
        [-0.6660]], dtype=torch.float64) tensor([[ -87.8212],
        [ 132.1552],
        [-102.4236],
        [-112.9692],
        [  -1.5480],
        [ -32.6288],
        [ -21.4060],
        [ -15.6508],
        [ -91.2147],
        [ -56.6039]], dtype=torch.float64)
tensor([[ 1.3065],
        [ 0.2994],
        [-1.1172],
        [-0.0549],
        [ 0.7360],
        [-0.5772],
        [ 0.2071],
        [ 0.1534],
        [-1.2489],
        [-0.1326]], dtype=torch.float64) tensor([[ 140.6500],
        [  39.9404],
        [-101.7177],
        [   4.5059],
        [  83.6023],
        [ -47.7202],
        [  30.7051],
        [  25.3379],
        [-114.8920],
        [  -3.2643]], dtype=torch.float64)
tensor([[ 0.6008],
        [ 1.0718],
        [-1.2174],
        [ 2.5375],
        [-0.3207],
        [ 1.3478],
        [ 0.7117],
        [ 0.1565],
        [ 1.5195],
        [-0.8144]], dtype=torch.float64) tensor([[  70.0802],
        [ 117.1795],
        [-111.7351],
        [ 263.7489],
        [ -22.0686],
        [ 144.7844],
        [  81.1660],
        [  25.6479],
        [ 161.9522],
        [ -71.4358]], dtype=torch.float64)
tensor([[-1.2483],
        [-1.9078],
        [ 0.5961],
        [ 0.0194],
        [-0.1173],
        [ 0.3140],
        [-0.9329],
        [ 0.0038],
        [-0.4335],
        [-0.6057]], dtype=torch.float64) tensor([[-114.8282],
        [-180.7820],
        [  69.6107],
        [  11.9363],
        [  -1.7310],
        [  41.4012],
        [ -83.2908],
        [  10.3779],
        [ -33.3497],
        [ -50.5697]], dtype=torch.float64)
tensor([[-0.0495],
        [ 0.2895],
        [-0.6009],
        [ 1.0616],
        [ 0.3481],
        [-0.4579],
        [-1.8343],
        [ 1.6204],
        [-0.8834],
        [-1.0749]], dtype=torch.float64) tensor([[   5.0491],
        [  38.9459],
        [ -50.0886],
        [ 116.1596],
        [  44.8123],
        [ -35.7889],
        [-173.4340],
        [ 172.0404],
        [ -78.3447],
        [ -97.4876]], dtype=torch.float64)
tensor([[ 0.8313],
        [-0.7213],
        [ 0.3275],
        [ 1.3682],
        [-0.8968],
        [ 0.0987],
        [-0.1118],
        [-1.3022],
        [-0.9787],
        [ 0.9574]], dtype=torch.float64) tensor([[  93.1274],
        [ -62.1344],
        [  42.7548],
        [ 146.8156],
        [ -79.6817],
        [  19.8716],
        [  -1.1838],
        [-120.2235],
        [ -87.8683],
        [ 105.7363]], dtype=torch.float64)
tensor([[ 0.7923],
        [ 1.3725],
        [ 0.3167],
        [ 0.1243],
        [ 0.7679],
        [-0.1851],
        [-1.5475],
        [-0.0633],
        [ 1.0783],
        [-0.4816]], dtype=torch.float64) tensor([[  89.2304],
        [ 147.2453],
        [  41.6672],
        [  22.4315],
        [  86.7898],
        [  -8.5120],
        [-144.7464],
        [   3.6742],
        [ 117.8263],
        [ -38.1595]], dtype=torch.float64)
tensor([[ 2.0525],
        [ 0.7787],
        [-0.3905],
        [ 0.3564],
        [ 0.0701],
        [-0.9325],
        [-0.0311],
        [ 1.1144],
        [-0.7584],
        [-0.5550]], dtype=torch.float64) tensor([[215.2454],
        [ 87.8650],
        [-29.0486],
        [ 45.6430],
        [ 17.0132],
        [-83.2503],
        [  6.8893],
        [121.4400],
        [-65.8450],
        [-45.5046]], dtype=torch.float64)
tensor([[ 0.1598],
        [ 0.4774],
        [-0.3246],
        [ 0.4640],
        [-2.7714],
        [-0.5616],
        [ 1.8471],
        [ 1.1289],
        [ 1.5057],
        [-0.0776]], dtype=torch.float64) tensor([[  25.9807],
        [  57.7353],
        [ -22.4564],
        [  56.4036],
        [-267.1440],
        [ -46.1577],
        [ 194.7096],
        [ 122.8872],
        [ 160.5707],
        [   2.2443]], dtype=torch.float64)

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()
<torch.utils.data.dataloader.DataLoader object at 0x0000026CC3992450>
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([0, 0, 6, 1, 4, 2, 9, 8, 3, 8])
torch.Size([10, 1, 28, 28]) torch.Size([10])
'''

2.pytorch提供的数据集

         pytorch 的 torchvision 模块提供了一些关于图像的数据集,均继承自torch.utils.data.Dataset 因此可以直接使用torch.utils.data.Dataloader,还提供了一些图像的转换方法,使用方法用最常见的 MNIST 举例:

import torchvision

from torchvision.transforms import ToTensor 
'''
1.将输入转为Tensor,
2.将图片格式转换为通道在前,常见通道为(高,宽,通道(像素点rgb))转换为(通道(像素点rgb),高,宽)
3.将像素取值归一化
'''
minidat = torchvision.datasets.MNIST('data',#文件夹名字
									 train=True,#训练数据和测试数据选择
									 transform=ToTensor(),#转换方法
									 download=True)#第一次需要下载库
print(minidat)
s = DataLoader(minidat,batch_size=10)
print(s)
# print(x,y)
for x,y in s:
	print(x.shape,y.shape)
	torch.tensor()
'''
Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()
<torch.utils.data.dataloader.DataLoader object at 0x00000254130BAC10>
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]]) torch.Size([10])
'''

3.模型定义示例 

        这里举例定义一个简单的线性模型:


class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.L1 = nn.Linear(28*28,1024)
        self.L2 = nn.Linear(1024,256)
        self.L3 = nn.Linear(256,10)
        self.leakRelu = nn.LeakyReLU()

    def forward(self,input):
        x = input.view(-1,28*28)
        x = self.L1(x)
        x = self.leakRelu(x)
        x = self.L2(x)
        x = self.leakRelu(x)
        logist = self.L3(x)
        return logist #一般没有经过激活的返回取名logist

4.模型训练

       数据集固定方式之后,模型训练可以写一个固定的模型训练函数:

'''
m :模型
dl:训练数据集
optfun:优化函数
bach:全部数据训练批次
'''
def train(m,dl,lossfun,optfun,bach):
    # m = model()
    # dl = torch.utils.data.DataLoader()
    # lossfun = nn.CrossEntropyLoss
    # optfun = torch.optim.SGD()

    m.train()
    for count in np.arange(bach):
        for x,y in dl:
            y_pred = m(x)
            loss = lossfun(y_pred,y)
            optfun.zero_grad()
            loss.backward()
            optfun.step()
            with torch.no_grad():
                a = y_pred.argmax(1).data.numpy()
                b = y.data.numpy()
                c=((a==b).astype(np.int32).sum()/len(b))
                print('Prediction accuracy: ',c)
        print("Training times:",count)

5.模型存储与装载 

        在模型训练好以后可以存储起来代码:

'''
保存模型
 m 是模型
 p 是存储的文件名和路径
'''
def SaveModel(m,p):
    torch.save(m.state_dict(),p)

'''
装载模型
 p 是存储的文件名和路径
'''
def LoadModel(p):
    m = model()
    m.load_state_dict(torch.load(p))
    m.eval()
    return m

6.总结

对于 torchvision.transforms提供的转换工具函数使用示例:

#该方法把图像数据转化为tensor数据
trans_img = torchvision.transforms.ToTensor()
#转化方法如下
img = trans_img(img)

'''
如果一次要进行好几个转换可以合并转换功能
'''
trans_img = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                torchvision.transforms.ToPILImage()])
img = trans_img(img)

模型训练,保存,运用示例:

trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    train_ds = torchvision.datasets.FashionMNIST('data',
                                          train=True,
                                          transform=trans,
                                          download=True)
    train_dl = DataLoader(train_ds,batch_size=128)
    test_ds = torchvision.datasets.FashionMNIST('data',
                                          train=True,
                                          transform=trans,
                                          download=False)
    test_dl = DataLoader(test_ds, batch_size=1)

    trans_img = torchvision.transforms.ToTensor()
    m_path = 'Fashion.pth'
    mod = model()
    lossfun = nn.CrossEntropyLoss()
    optfun = torch.optim.SGD(mod.parameters(),lr=0.001)

    # train(mod,train_dl,lossfun,optfun,100)
    # SaveModel(mod,m_path)

    mod = LoadModel(m_path)

    errCount = 0
    correctCount = 0
    for x,y in test_dl:
        y_pred=mod(x)

        print(y_pred.argmax(1),y)
        # cv.imshow(np.squeeze(x.data.numpy()))
        dy = y_pred.argmax(1)
        if dy.item()==y.item():
            correctCount+=1
        else:
            errCount+=1
    print('err:',errCount,'correct:',correctCount,correctCount/(correctCount+errCount))

后续关于继承Dataset,进行数据加载,会继续添加相关示例。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/415229.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

10.网络游戏逆向分析与漏洞攻防-游戏网络架构逆向分析-接管游戏发送数据的操作

内容参考于&#xff1a;易道云信息技术研究院VIP课 上一个内容&#xff1a;接管游戏连接服务器的操作 码云地址&#xff08;master 分支&#xff09;&#xff1a;染指/titan 码云版本号&#xff1a;00820853d5492fa7b6e32407d46b5f9c01930ec6 代码下载地址&#xff0c;在 ti…

RTF文件格式解析(二)图像问题

图片 一个RTF文件可以包含由其他应用创建的图象。这些图象可以是16进制(默认的)或2进制格式。图象属于目标引用&#xff0c;由\pict 控制字开始。如后面的例子中将描述的&#xff0c;\pict关键字应在\*\shppict引用控制关键字之后。一个图象引用具有如下语法&#xff1a; <p…

frp 内网穿透 linux部署版

frp 内网穿透 linux部署版 前提安装 frp阿里云服务器配置测试服务器配置访问公网 前提 使用 frp&#xff0c;您可以安全、便捷地将内网服务暴露到公网&#xff0c;通过访问公网 IP 直接可以访问到内网的测试环境。准备如下&#xff1a; 公网 IP已部署好的测试服务 IP:端口号阿…

v68.指针

1.取地址运算 1.1 1.2 打印出变量的地址&#xff0c;需要使用 %p&#xff0c;注意后面加运算符 & 。注意输出地址的代码格式。%p会把这个值作地址来输出&#xff0c;输出的结果前面会加0x&#xff0c;并且以16进制的方式来输出地址 注意int 的大小是否和地址大小相同取决…

嵌入式 Linux 下的 LVGL 移植

目录 准备创建工程修改配置修改 lv_drv_conf.h修改 lv_conf.h修改 main.c修改 Makefile 编译运行更多内容 LVGL&#xff08;Light and Versatile Graphics Library&#xff0c;轻量级通用图形库&#xff09;是一个轻量化的、开源的、在嵌入式系统中广泛使用的图形库&#xff0c…

算法:动态规划全解(上)

一、动态规划初识 1.介绍 动态规划&#xff0c;英文&#xff1a;Dynamic Programming&#xff0c;简称DP&#xff0c;如果某一问题有很多重叠子问题&#xff0c;使用动态规划是最有效的。所以动态规划中每一个状态一定是由上一个状态推导出来的。 例如&#xff1a;有N件物品…

逆向案例三:动态xhr包中AES解密的一般步骤,以精灵数据为例

补充知识&#xff1a;进行AES解密需要知道四个关键字&#xff0c;即密钥key,向量iv,模式mode,填充方式pad 一般网页AES都是16位的&#xff0c;m3u8视频加密一般是AES-128格式 网页链接:https://www.jinglingshuju.com/articles 进行抓包结果返回的是密文&#xff1a; 一般思…

软考48-上午题-【数据库】-数据查询语言DQL3-表的连接查询

一、表的连接查询 数据查询中&#xff0c;经常需要提取两个或者多个表的数据&#xff0c;需要用表的连接来实现若干个表数据的联合查询。格式如下&#xff1a; select 列名1, 列名2, 列名3, ...... from 表1, 表2, ...... where 连接条件 在SQL SERVER中&#xff0c;连接分为…

Apipost自动化测试持续集成配置方法

安装 Apipost-cli npm install -g apipost-cli 运行脚本 安装好Apipost-cli后&#xff0c;在命令行输入生成的命令&#xff0c;即可执行测试用例&#xff0c;运行完成后会展示测试进度并生成测试报告。 Jenkins配置 Apipost cli基于Node js运行 需要在jenkins上配置NodeJs依…

Python中re(正则)模块的使用

re 是 Python 标准库中的一个模块&#xff0c;用于支持正则表达式操作。通过 re 模块&#xff0c;可以使用各种正则表达式来搜索、匹配和操作字符串数据。 使用 re 模块可以帮助在处理字符串时进行高效的搜索和替换操作&#xff0c;特别适用于需要处理文本数据的情况。 # 导入…

【MySQL】MySQL复合查询--多表查询自连接子查询 - 副本

文章目录 1.基本查询回顾2.多表查询3.自连接4.子查询 4.1单行子查询4.2多行子查询4.3多列子查询4.4在from子句中使用子查询4.5合并查询 4.5.1 union4.5.2 union all 1.基本查询回顾 表的内容如下&#xff1a; mysql> select * from emp; ----------------------------…

ubuntu安装新版本的CMake

来到cmake官网选择版本 我需要在嵌入式板子上的Ubuntu18安装使用 故我选择aarch64版本。 按F12进入检查模式得到下载链接。 在板子上运行以下命令&#xff0c;获取安装脚本 wget https://github.com/Kitware/CMake/releases/download/v3.28.3/cmake-3.28.3-linux-aarch64.s…

Django模板(四)

一、include标签 加载一个模板,并在当前上下文中进行渲染。这是一种在模板中 “包含” 其他模板的方式 简单的理解:在当前模板中引入另外一个模板内容 1.1、使用方法 模板名称可以是变量,也可以是单引号或双引号的硬编码(带引号)的字符串 {% include "foo/bar.ht…

接口自动化测试之HTTP协议详解

协议 简单理解&#xff0c;计算机与计算机之间的通讯语言就叫做协议&#xff0c;不同的计算机之间只有使用相同的协议才能通信。所以网络协议就是为计算机网络中进行数据交换而建立的规则&#xff0c;标准或约定的集合。 OSI模型 1978年国际化标准组织提出了“开放系统互联网…

BerDiff: Conditional Bernoulli Diffusion Modelfor Medical Image Segmentation

BerDiff:用于医学图像分割的条件伯努利扩散模型 摘要&#xff1a; 医学图像分割是一项具有挑战性的任务&#xff0c;具有固有的模糊性和高度的不确定性&#xff0c;这主要是由于肿瘤边界不明确和多个似是而非的注释等因素。分割口罩的准确性和多样性对于在临床实践中为放射科…

【UVM_Introduction Factory_2024.02.28】

Introduction 通用验证方法学UVM&#xff08;2014年1.2版本延续至今&#xff09; 作用&#xff1a; 降低验证工程复杂度&#xff0c;保证验证可靠性&#xff0c;提升验证效率 提供一套标准的类库&#xff0c;减轻环境构建的负担&#xff0c;更多的投入制定验证计划和创建测试场…

springboot228高校教师电子名片系统

高校教师电子名片系统的设计与实现 摘 要 传统信息的管理大部分依赖于管理人员的手工登记与管理&#xff0c;然而&#xff0c;随着近些年信息技术的迅猛发展&#xff0c;让许多比较老套的信息管理模式进行了更新迭代&#xff0c;名片信息因为其管理内容繁杂&#xff0c;管理数…

【数据结构】数组

第一章、为什么数组的下标一般从0开始编号 提到数组&#xff0c;读者肯定不陌生&#xff0c;甚至还会很自信地说&#xff0c;数组很简单。编程语言中一般会有数组这种数据类型。不过&#xff0c;它不仅是编程语言中的一种数据类型&#xff0c;还是基础的数据结构。尽管数组看起…

代码随想录算法训练营29期|day64 任务以及具体安排

第十章 单调栈part03 有了之前单调栈的铺垫&#xff0c;这道题目就不难了。 84.柱状图中最大的矩形class Solution {int largestRectangleArea(int[] heights) {Stack<Integer> st new Stack<Integer>();// 数组扩容&#xff0c;在头和尾各加入一个元素int [] ne…

半小时到秒级,京东零售定时任务优化怎么做的?

导言&#xff1a; 京东零售技术团队通过真实线上案例总结了针对海量数据批处理任务的一些通用优化方法&#xff0c;除了供大家借鉴参考之外&#xff0c;也更希望通过这篇文章呼吁大家在平时开发程序时能够更加注意程序的性能和所消耗的资源&#xff0c;避免在流量突增时给系统…