PyTorch – 逻辑回归

data

首先导入torch里面专门做图形处理的一个库,torchvision,根据官方安装指南,你在安装pytorch的时候torchvision也会安装。

我们需要使用的是torchvision.transforms和torchvision.datasets以及torch.utils.data.DataLoader

首先DataLoader是导入图片的操作,里面有一些参数,比如batch_size和shuffle等,默认load进去的图片类型是PIL.Image.open的类型,如果你不知道PIL,简单来说就是一种读取图片的库

torchvision.transforms里面的操作是对导入的图片做处理,比如可以随机取(50, 50)这样的窗框大小,或者随机翻转,或者去中间的(50, 50)的窗框大小部分等等,但是里面必须要用的是transforms.ToTensor(),这可以将PIL的图片类型转换成tensor,这样pytorch才可以对其做处理

torchvision.datasets里面有很多数据类型,里面有官网处理好的数据,比如我们要使用的MNIST数据集,可以通过torchvision.datasets.MNIST()来得到,还有一个常使用的是torchvision.datasets.ImageFolder(),这个可以让我们按文件夹来取图片,和keras里面的flow_from_directory()类似,具体的可以去看看官方文档的介绍。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

# 定义超参数

batch_size = 32

learning_rate = 1e-3

num_epoches = 100

# 下载训练集 MNIST 手写数字训练集

train_dataset = datasets.MNIST(root=\'./data\', train=True,

                               transform=transforms.ToTensor(),

                               download=True)

test_dataset = datasets.MNIST(root=\'./data\', train=False,

                              transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

以上就是我们对图片数据的读取操作

model

之前讲过模型定义的框架,废话不多说,直接上代码

1

2

3

4

5

6

7

8

9

10

class Logstic_Regression(nn.Module):

    def __init__(self, in_dim, n_class):

        super(Logstic_Regression, self).__init__()

        self.logstic = nn.Linear(in_dim, n_class)

    def forward(self, x):

        out = self.logstic(x)

        return out

model = Logstic_Regression(28*28, 10)  # 图片大小是28x28

我们需要向这个模型传入参数,第一个参数定义为数据的维度,第二维数是我们分类的数目。

接着我们可以在gpu上跑模型,怎么做呢?

首先可以判断一下你是否能在gpu上跑

1

torh.cuda.is_available()

如果返回True就说明有gpu支持

接着你只需要一个简单的命令就可以了

1

2

3

4

5

model = model.cuda()

或者

model.cuda()

然后需要定义loss和optimizer

1

2

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=learning_rate)

这里我们使用的loss是交叉熵,是一种处理分类问题的loss,optimizer我们还是使用随机梯度下降

train

接着就可以开始训练了

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

for epoch in range(num_epoches):

    print(\'epoch {}\'.format(epoch 1))

    print(\'*\'*10)

    running_loss = 0.0

    running_acc = 0.0

    for i, data in enumerate(train_loader, 1):

        img, label = data

        img = img.view(img.size(0), -1)  # 将图片展开成 28x28

        if use_gpu:

            img = Variable(img).cuda()

            label = Variable(label).cuda()

        else:

            img = Variable(img)

            label = Variable(label)

        # 向前传播

        out = model(img)

        loss = criterion(out, label)

        running_loss  = loss.data[0] * label.size(0)

        _, pred = torch.max(out, 1)

        num_correct = (pred == label).sum()

        running_acc  = num_correct.data[0]

        # 向后传播

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

注意我们如果将模型放到了gpu上,相应的我们的Variable也要放到gpu上,也很简单

1

2

img = Variable(img).cuda()

label = Variable(label).cuda()

然后可以测试模型,过程与训练类似,只是注意要将模型改成测试模式

1

model.eval()

这是跑完100 epoch的结果

具体的结果多久打印一次,如何打印可以自己在for循环里面去设计。

相关代码:pytorch-beginner: pytorch-beginner

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/391401.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【plt.imshow显示图像】:从入门到精通,只需一篇文章!【Matplotlib】

【plt.imshow显示图像】:从入门到精通,只需一篇文章!【Matplotlib】 🚀 利用Matplotlib进行数据可视化示例 🌵文章目录🌵 📘 1. plt.imshow入门:认识并安装Matplotlib库&#x1f308…

Java编程在工资信息管理中的最佳实践

✍✍计算机编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java实战 |…

Crypto-RSA1

题目: 已知p,q,dp,dq,c求明文: 首先有如下公式: dp≡d mod (p-1),dq≡d mod (q-1) , m≡c^d(mod n) , npq python代码解题如下: import libnump 863763376725700856709965348654109117132049…

浅谈语义分割、图像分类与目标检测中的TP、TN、FP、FN

语义分割 TP:正确地预测出了正类,即原本是正类,识别的也是正类 TN:正确地预测出了负类,即原本是负类,识别的也是负类 FP:错误地预测为了正类,即原本是负类,识别的是正类…

建造者模式-Builder Pattern

原文地址:https://jaune162.blog/design-pattern/builder-pattern/ 引言 现在一般大型的业务系统中的消息通知的形式都会有多种,比如短信、站内信、钉钉通知、邮箱等形式。虽然信息内容相同,但是展现形式缺不同。如短信使用的是纯文本的形式,钉钉使用的一般是Markdown的形…

挖掘在线零售数据:基于RFM理论的用户细分分析与营销策略

挖掘在线零售数据:基于RFM理论的用户细分分析与营销策略 基于RFM理论的用户细分分析项目背景和意义数据准备和预处理RFM分析1. 计算RFM指标2. 数据转换和处理 K-Means聚类分析结果和建议总结 基于RFM理论的用户细分分析 在商业运营中,了解客户并将其分组…

使用WGCLOUD监测摄像头的运行状态

WGCLOUD WGCLOUD是一款开源运维工具,免费高效,可以用来监测摄像头的工作状态,如果发现故障,那么WGCLOUD会发送告警通知消息,提醒我们的工程师进行处理 我们可以用WGCLOUD的PING监测模块,或者端口监测模块…

【开源】JAVA+Vue.js实现大学计算机课程管理平台

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 实验课程档案模块2.2 实验资源模块2.3 学生实验模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 实验课程档案表3.2.2 实验资源表3.2.3 学生实验表 四、系统展示五、核心代码5.1 一键生成实验5.2 提交实验5.3 批阅实…

C++入门学习(三十)一维数组的三种定义方式

数组是什么? 数组(Array)是有序的元素序列。 若将有限个类型相同的变量的集合命名,那么这个名称为数组名。组成数组的各个变量称为数组的分量,也称为数组的元素,有时也称为下标变量。用于区分数组的各个元素…

re-captioning技术是什么

参考https://zhuanlan.zhihu.com/p/664192860 模型对图片进行caption操作时,输出的标题一般描述图片中的主体,而忽视了背景、常识关系等更为细节的描述。 图片比较重要的细节的描述应当包括: 物体存在的场景。如:在厨房的水槽&am…

视频号小店怎么做?新手必须掌握的三点核心步骤,建议收藏

大家好,我是电商花花。 现在短视频的快速发展,电商和直播、短视频不断结合发展,在去年视频号小店也迎来了大爆发,有不少朋友都靠着做视频号小店赚到了自己做电商的第一捅金,直接让很多朋友接触视频号小店,…

TOP100 图论

1.200. 岛屿数量 给你一个由 1(陆地)和 0(水)组成的的二维网格,请你计算网格中岛屿的数量。 岛屿总是被水包围,并且每座岛屿只能由水平方向和/或竖直方向上相邻的陆地连接形成。 此外,你可以…

ADSelfService Plus发布离线MFA功能,强化远程工作安全性

ManageEngine ADSelfService Plus推出离线多因素身份验证,提升远程工作安全性确保通过先进的验证方法对企业数据进行授权访问,无论时间、地点或连接问题如何允许远程用户安全进行身份验证,即使未连接到认证服务器或互联网使用高度安全的基于T…

安装cockpit

1、下载cockpit yum -y install cockpit 下载相关环境 yum install qemu-kvm libvirt libvirt-daemon virt-install virt-manager libvirt-dbus 2、启动libvirtd systemctl start libvirtd.service systemctl enable libvirtd.service 3、设置开机自启动 systemctl enabl…

后端学习:Maven模型与Springboot框架

Maven 初识Maven Maven:是Apache旗下的一个开源项目,是一款用于管理和构建java项目的工具。 Maven的作用1.依赖管理2.统一项目结构3.项目构建依赖管理:方便快捷的管理项目依赖的资源(jar包),避免版本冲突问题   当使用maven进行项目依赖…

面试经典150题 -- 链表 (总结)

总的地址 : 面试经典 150 题 - 学习计划 - 力扣(LeetCode)全球极客挚爱的技术成长平台 c链表总结 : 链表总结 -- 《数据结构》-- c/c-CSDN博客 141 . 环形链表 详细题解参考 : 141 . 环形链表-CSDN博客 这里给出慢双指针的代码 : /*** Defini…

《白话C++》第10章 STL和boost,Page70~72 boost::scoped_ptr(未完待续)

《泛型》篇中提到的某个IT项目的辩论会, 一派坚持智能指针和裸指针可以“离婚”,它们是std::auto_ptr的支持者, 一派认为智能指针和裸指针不可以“离婚”,boost::scoped_ptr体现了他们的观点: boost::scoped_ptr基本…

OpenCV 4基础篇| OpenCV简介

目录 1. 什么是OpenCV2. OpenCV的发展历程3. 为什么用OpenCV4. OpenCV应用领域5. OpenCV的功能模块5.1 基本模块5.2 扩展模块5.3 常用函数目录 1. 什么是OpenCV OpenCV(Open Source Computer Vision Library)是一个开源的计算机视觉和机器学习软件库。它…

PyTorch-线性回归

已经进入大模微调的时代&#xff0c;但是学习pytorch&#xff0c;对后续学习rasa框架有一定帮助吧。 <!-- 给出一系列的点作为线性回归的数据&#xff0c;使用numpy来存储这些点。 --> x_train np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],[9.779], [6.1…

多维时序 | Matlab实现TCN-RVM时间卷积神经网络结合相关向量机多变量时间序列预测

多维时序 | Matlab实现TCN-RVM时间卷积神经网络结合相关向量机多变量时间序列预测 目录 多维时序 | Matlab实现TCN-RVM时间卷积神经网络结合相关向量机多变量时间序列预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 Matlab实现TCN-RVM时间卷积神经网络结合相关向量机…