034、test

之——全纪录

目录

之——全纪录

杂谈

正文

1.下载处理数据

2.数据集概览

3.构建自定义dataset

4.初始化网络

5.训练


杂谈

        综合方法试一下。


leaves

1.下载处理数据

        从官网下载数据集:Classify Leaves | Kaggle

        解压后有一个图片集,一个提交示例,一个测试集,一个训练集。

        images,27153个树叶图片:

        test.csv,8800个:

        train.csv,18353个:


2.数据集概览

        训练集、测试集、类别:

#导包
import random
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
import torchvision
import pandas as pd
import matplotlib.pyplot as plt
from d2l import torch as d2l
from PIL import Image

train_data=pd.read_csv(r"D:\apycharmblackhorse\leaves\train.csv")
test_data=pd.read_csv(r"D:\apycharmblackhorse/leaves/test.csv")

train_images=train_data.iloc[:,0].values #把所有的训练集图片路径读进来成list
print("训练集数量:",len(train_images))
n_train=len(train_images)
test_images=test_data.iloc[:,0].values
print("测试集数量:",len(test_images))
n_test=len(test_images)

train_labels = pd.get_dummies(train_data.iloc[:, 1]).values.astype(int).argmax(1)
#独热编码后找到每行最大的索引记下来就是类别号,而顺序与独热编码colums,也就是与下方排序一致
# print(len(train_labels),train_labels)

#记录并排序所有的类别名
train_labels_header = pd.get_dummies(train_data.iloc[:, 1]).columns.values
print("总类别:",len(train_labels_header))
classes=len(train_labels_header)


3.构建自定义dataset

       继承 torch.utils.Dataset 类,自定义树叶分类数据集:

#继承 torch.utils.Dataset 类,自定义树叶分类数据集
class leaves_dataset(torch.utils.data.Dataset):
    #root数据目录, images图片路径, labels图片标签, transform数据增强
    def __init__(self, root, images, labels, transform):
        super(leaves_dataset, self).__init__()
        self.root = root
        self.images = images
        if labels is None:
            self.labels = None
        else:
            self.labels = labels
        self.transform = transform
    #获得指定样本
    def __getitem__(self, index):
        image_path = self.root + self.images[index]
        image = Image.open(image_path)
        #预处理
        image = self.transform(image)
        if self.labels is None:
            return image
        label = torch.tensor(self.labels[index])
        return image, label
    #获得数据集长度
    def __len__(self):
        return self.images.shape[0]

        构建读取数据与预处理:

def load_data(images, labels, batch_size, train):
    aug = []
    normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if (train):
        aug = [torchvision.transforms.CenterCrop(224),
               transforms.RandomHorizontalFlip(),
               transforms.RandomVerticalFlip(),
               transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
               transforms.ToTensor(),
               normalize]
    else:
        aug = [torchvision.transforms.Resize([256, 256]),
               torchvision.transforms.CenterCrop(224),
               transforms.ToTensor(),
               normalize]
    transform = transforms.Compose(aug)
    dataset = leaves_dataset(r"D:\apycharmblackhorse\leaves\\", images, labels, transform=transform)
    if train==True:type="训练"
    else:type="测试"
    print("载入:",dataset.__len__(),type)
    return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0, shuffle=train)

train_iter = load_data(train_images, train_labels, 512, train=True)

4.初始化网络

        使用官方预训练模型初始化网络,并修改输出类别数:

#初始化网络
net = torchvision.models.resnet18(pretrained=True)

net.fc = nn.Linear(net.fc.in_features, classes)
nn.init.xavier_uniform_(net.fc.weight)
net.fc


5.训练

         定义迭代器、优化器以及其他超参数,进行训练:

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=64, num_epochs=20,
                      param_group=True):
    train_slices = random.sample(list(range(n_train)), 15000)
    test_slices = list(set(range(n_train)) - set(train_slices))

    train_iter = load_data(train_images[train_slices], train_labels[train_slices], batch_size, train=True)
    test_iter = load_data(train_images[test_slices], train_labels[test_slices], batch_size, train=False)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        #别的层不变,最后一层10倍学习率
        trainer = torch.optim.Adam([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.Adam(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    print(111)
    try:
        d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)
    except Exception as e:
        print(e)



#%%

#较小的学习率,通过微调预训练获得的模型参数
train_fine_tuning(net, 1e-3)

        小破脑跑得慢,之前不用预训练5个epoch后acc大概只能到0.3  ,使用预训练后到了0.6,但实际上感觉对于树叶的针对性分类还是需要从头开始才是最好的选择,资源不够这里就不做尝试了,大概尝试情况:


CIFAR-10

1.数据集


2.未完待续

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/158648.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++中静态成员变量和普通成员变量、私有成员变量和公有成员变量的区别

本文主要介绍和记录C中静态成员变量和普通成员变量、私有成员变量和公有成员变量的区别,并给出相关示例程序,最后结合相关工程应用中编译报错给出报错原因及介绍思路 一、静态成员变量和普通成员变量 C中,静态成员变量和普通成员变量有一些重…

C语言指针详解(1)(能看懂字就能明白系列)文章超长,慢慢品尝

目录 1、内存和地址 2、指针简介 与指针相关的运算符: 取地址操作符(&) 解引用操作符(间接操作符)(*) ​编辑 指针变量的声明 指针变量类型的意义 指针的基本操作 1、指针与整数相加…

网申线上测评,要不要找人代做在线测评?

这是知乎上看来的问题,感触颇多,于是决定针对这个问题写个稿子,希望能帮助到更多人朋友。 原文如下:现在各大公司在招聘时通常都会采取网申发OT筛选的形式,但是由于难度较大,不少人会选择付钱找别人代做的方…

常见面试题-MySQL软删除以及索引结构

为什么 mysql 删了行记录,反而磁盘空间没有减少? 答: 在 mysql 中,当使用 delete 删除数据时,mysql 会将删除的数据标记为已删除,但是并不去磁盘上真正进行删除,而是在需要使用这片存储空间时…

uart控制led与beep

仲裁模块代码: // 外设控制模块,根据uart接收到的数据,控制led与beep的标志信号。 module arbit(input wire sys_clk ,input wire sys_rst_n ,input wire pi_flag …

交易者最看重什么?anzo Capital这点最重要!

交易者最看重什么?有人会说技术,有人会说交易策略,有人会说盈利,但anzo Capital认为Vishal 最看重的应该是眼睛吧! 29岁的Vishal Agraval在9年前因某种原因失去了视力,然而,他的失明并未能阻…

【C#】类型转换-显式转换:括号强转、Parse法、Convert法、其他类型转string

目录 一、括号强转 1.有符号整型 2.无符号整型 3.浮点之间 4.无符号和有符号 5.浮点和整型 6.char和数值类型 7.bool和string是不能够通过 括号强转的 二、Parse法 1.有符号 2.无符号 3.浮点型 4.特殊类型 三、Convert法 1.转字符串 2.转浮点型 3.特殊类型转换…

R语言绘制精美图形 | 火山图 | 学习笔记

一边学习,一边总结,一边分享! 教程图形 前言 最近的事情较多,教程更新实在是跟不上,主要原因是自己没有太多时间来学习和整理相关的内容。一般在下半年基本都是非常忙,所有一个人的精力和时间有限&#x…

Rockdb简介

背景 最近在使用flink的过程中,由于要存储的状态很大,所以使用到了rockdb作为flink的后端存储,本文就来简单看下rockdb的架构设计 Rockdb设计 Rockdb采用了LSM的结构,它和hbase很像,不过严格的说,基于LS…

【Linux】安全审计-audit

文章目录 一、audit简介二、开启auditd服务三、相关文件四、审计规则五、审计日志查询及分析附录1:auditctl -h附录2:systemcall 类型 参考文章: 1、安全-linux audit审计使用入门 2、audit详细使用配置 3、Linux-有哪些常见的System Call&a…

redis实战篇(2)

优惠卷秒杀 通过本章节,我们可以学会Redis的计数器功能, 结合Lua完成高性能的redis操作,同时学会Redis分布式锁的原理,包括Redis的三种消息队列 3、优惠卷秒杀 3.1 -全局唯一ID 每个店铺都可以发布优惠券: 当用户抢…

卷积神经网络(CNN)鲜花的识别

文章目录 前期工作1. 设置GPU(如果使用的是CPU可以忽略这步)我的环境: 2. 导入数据3. 检查数据 二、数据预处理1. 加载数据2. 可视化数据3. 再次检查数据4. 配置数据集 三、构建CNN网络四、编译五、训练模型六、模型评估 前期工作 1. 设置GP…

IP池大小重要吗?

我们在寻找靠谱的IP代理时也经常遇到一个问题,IP代理池是什么?大小有何影响。今天就来跟大家普及一下,IP代理池大小的是否重要? 一、IP代理池是什么? I\P代理池是一个存储大量代理服务器IP地址的集合。它是一个由多个…

【MMC/SD/SDIO】读写操作

SD 总线是基于命令和数据流,它们由一个开始 Bit 发起,由一个停止 Bit 结束。 Command:命令开始一个操作。命令由 Host 驱动,或者给单卡(寻址命令),或者给所有连接的卡(广播命令&…

【操作系统】虚拟内存浅析

文章目录 虚拟内存的概念虚拟内存的实现请求分页存储管理缺页中断机构地址变换机构页面置换算法页面分配策略调入页面的时机 虚拟内存的概念 所谓的虚拟内存,是具有请求调入和置换功能,从逻辑上对内存容量加以扩充的一种存储器系统。他的组成如下&#…

微软 Gradle 强强联手,Gradle 构建服务器正式开源!

作者:Nick Zhu - Senior Program Manager, Developer Division At Microsoft 排版:Alan Wang Gradle 构建服务器 (Build Server for Gradle) 在九月份,我们宣布 Microsoft 和 Gradle 联手探索了一种基于 Build Server Protocol(B…

【SpringMvc】SpringMvc +MyBatis整理

🎄欢迎来到边境矢梦的csdn博文🎄 🎄本文主要梳理 Java 框架 中 SpringMVC的知识点和值得注意的地方 🎄 🌈我是边境矢梦,一个正在为秋招和算法竞赛做准备的学生🌈 🎆喜欢的朋友可以关…

为什么Go是后端开发的未来

近年来,Go 编程语言的流行度迅速增加。Go 最初由 Google 开发,迅速成为后端开发中最受欢迎的语言之一,特别是在分布式系统和微服务的开发中。本文将讨论为什么 Go 是后端开发的未来。 Go 简介 Go,又称为 Golang,是由…

【图数据库实战】HugeGraph架构

一、概述 作为一款通用的图数据库产品,HugeGraph需具备图数据的基本功能,如下图所示。HugeGraph包括三个层次的功能,分别是存储层、计算层和用户接口层。 HugeGraph支持OLTP和OLAP两种图计算类型,其中OLTP实现了Apache TinkerPop3…

如何从Android手机恢复已删除的联系人

联系人应该是最重要的信息之一。 如果您不小心从Android手机中删除了联系人,该怎么办? 如果不容易找回丢失的联系人,您可以使用奇客数据恢复安卓版。 从Android的手机中恢复已删除的联系人 只需删除Android联系人,然后您就可以通…