深度学习框架:Pytorch与Keras的区别与使用方法

  

☁️主页 Nowl

🔥专栏《机器学习实战》 《机器学习》

📑君子坐而论道,少年起而行之 

文章目录

Pytorch与Keras介绍

Pytorch

模型定义

模型编译

模型训练

输入格式

完整代码

Keras

模型定义

模型编译

模型训练

输入格式

完整代码

区别与使用场景

结语


Pytorch与Keras介绍

pytorch和keras都是一种深度学习框架,使我们能很便捷地搭建各种神经网络,但它们在使用上有一些区别,也各自有其特性,我们一起来看看吧

Pytorch

模型定义

我们以最简单的网络定义来学习pytorch的基本使用方法,我们接下来要定义一个神经网络,包括一个输入层,一个隐藏层,一个输出层,这些层都是线性的,给隐藏层添加一个激活函数Relu,给输出层添加一个Sigmoid函数

import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.Sigmoid(x)
        return x

模型编译

我们在之前的机器学习文章中反复提到过,模型的训练是怎么进行的呢,要有一个损失函数与优化方法,我们接下来看看在pytorch中怎么定义这些

import torch.optim as optim


# 实例化模型对象
model = SimpleNet()
# 定义损失函数
criterion = nn.MSELoss()

# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

我们上面创建的神经网络是一个类,所以我们实例化一个对象model,然后定义损失函数为mse,优化器为随机梯度下降并设置学习率

模型训练

# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值

# 训练模型
epochs = 100

for epoch in range(epochs):
    # 前向传播
    output = model(input_data)

    # 计算损失
    loss = criterion(output, target_data)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

以上步骤是先创建了一些随机样本,作为模型的训练集,然后定义训练轮次为100次,然后前向传播数据集,计算损失,再优化,如此反复

输入格式

关于输入格式是很多人在实战中容易出现问题的,对于pytorch创建的神经网络,我们的输入内容是一个torch张量,怎么创建呢

data = torch.Tensor([[1], [2], [3]])

很简单对吧,上面这个例子创建了一个torch张量,有三组数据,每组数据有1个特征

我们可以把这个数据输入到训练好的模型中,得到输出结果,如果输出不是torch张量,代码就会报错

完整代码

import torch
import torch.nn as nn
import torch.optim as optim


# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x



model = SimpleNet()
criterion = nn.MSELoss()

# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值

# 训练模型
epochs = 100

for epoch in range(epochs):
    # 前向传播
    output = model(input_data)

    # 计算损失
    loss = criterion(output, target_data)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


data = torch.Tensor([[1], [2], [3]])
prediction = model(data)

print(prediction)

可以看到模型输出了三个预测值

注意,这个任务本身没有意义,因为我们的训练集是随机生成的,这里主要学习框架的使用方法

Keras

我们在这里把和上面相同的神经网络结构使用keras框架实现一遍

模型定义

from keras.models import Sequential
from keras.layers import Dense


model = Sequential([
    Dense(32, input_dim=1, activation='relu'),
    Dense(1, activation='sigmoid')
])

注意这里也是一层输入层,一层隐藏层,一层输出层,和pytorch一样,输入层是隐式的,我们的输入数据就是输入层,上述代码定义了一个隐藏层,输入维度是1,输出维度是32,还定义了一个输出层,输入维度是32,输出维度是1,和pytorch环节的模型结构是一样的 

模型编译

那么在Keras中模型又是怎么编译的呢

model.compile(loss='mse', optimizer='sgd')

非常简单,只需要这一行代码 ,设置损失函数为mse,优化器为随机梯度下降

模型训练

模型的训练也非常简单

# 训练模型
model.fit(input_data, target_data, epochs=100)

 因为我们已经编译好了损失函数和优化器,在fit里只需要输入数据,输出数据和训练轮次这些参数就可以训练了

输入格式

对于Keras模型的输入,我们要把它转化为numpy数组,不然会报错

data = np.array([[1], [2], [3]])

完整代码

from keras.models import Sequential
from keras.layers import Dense
import numpy as np


# 定义模型
model = Sequential([
    Dense(32, input_dim=1, activation='relu'),
    Dense(1, activation='sigmoid')
])

# 创建随机输入数据和目标数据
input_data = np.random.randn(100, 1)  # 100个样本,每个样本有10个特征
target_data = np.random.randn(100, 1)  # 100个样本,每个样本有5个目标值

# 编译模型
model.compile(loss='mse', optimizer='sgd')
# 训练模型
model.fit(input_data, target_data, epochs=10)

data = np.array([[1], [2], [3]])

prediction = model(data)
print(prediction)

可以看到,同样的任务,Keras的代码量小很多

区别与使用场景

Keras代码量少,使用便捷,适用于快速实验和快速神经网络设计

而pytorch由于结构是由类定义的,可以更加灵活地组建神经网络层,这对于要求细节的任务更有利,同时,pytorch还采用动态计算图,使得模型的结构可以在运行时根据输入数据动态调整,但这个特点我还没有接触到,之后可能会详细讲解

结语

Keras和Pytorch都各有各的优点,请读者根据需求选择,同时有些深度学习教程偏向于使用某一种框架,最好都学习一点,以适应不同的场景

 

感谢阅读,觉得有用的话就订阅下本专栏吧 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/203113.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【网络安全】-安全常见术语介绍

文章目录 介绍1. 防火墙(Firewall)定义通俗解释 2. 恶意软件(Malware)定义通俗解释 3. 加密(Encryption)定义通俗解释 4. 多因素认证(Multi-Factor Authentication,MFA)定…

springMVC实验(二)—调式工具APIFOX的使用

【知识要点】 后端开发调试工具 前后端分离已经成为互联网类软件开发主流模式,没有前端操作的支持,如何调试后端程序的就是开发人员必须解决的问题。如:get类请求可以直接使用浏览器就能模拟测试,但是post、put等类型的请求&…

22款奔驰GLS450升级香氛负离子 淡淡的幽香

香氛负离子系统是由香氛系统和负离子发生器组成的一套配置,也可以单独加装香氛系统或者是负离子发生器,香氛的主要作用就是通过香氛外壳吸收原厂的香水再通过空调管输送到内饰中,而负离子的作用就是安装在空气管中通过释放电离子来打击空气中…

支持向量机,硬间隔,软间隔,核技巧,超参数设置,分类与回归

SVM(Support Vector Machine,支持向量机)是一种非常常用并且有效的监督学习算法,在许多领域都有广泛应用。它可以用于二分类问题和多分类问题,并且在处理高维数据和特征选择方面非常强大。SVM算法的核心思想是通过找到…

nodejs 沙盒逃逸

1.[GFCTF 2021]ez_calc 一道很有意思的一道nodejs的题 沙箱逃逸和绕过: F12 看源码 if(req.body.username.toLowerCase() ! admin && req.body.username.toUpperCase() ADMIN && req.body.passwd admin123){ // 登录成功&am…

Windows下命令行启动与关闭WebLogic的相关服务

WebLogic 的服务器类型 WebLogic提供了三种类型的服务器: 管理服务器节点服务器托管服务器 示例和关系如下图: 对应三类服务器, 就有三种启动和关闭的方式。本篇介绍使用命令行脚本的方式启动和关闭这三种类型的服务器。 关于WebLogic 的…

分布式架构demo

1、外层创建pom 版本管理器 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.7.15</version><relativePath/> <!-- lookup parent from repository…

使用Java将yaml转为properties,保证顺序、实测无BUG版本

使用Java将yaml转为properties 一、 前言1.1 顺序错乱的原因1.2 遗漏子节点的原因 二、优化措施三、源码 一、 前言 浏览了一圈网上的版本&#xff0c;大多存在以下问题&#xff1a; 转换后顺序错乱遗漏子节点 基于此进行了优化&#xff0c;如果只是想直接转换&#xff0c;可…

用Java制作简易版的王者荣耀

第一步是创建项目 项目名自拟 第二部创建个包名 来规范class 创建类 GameFrame 运行类 package com.sxt;import java.awt.Graphics; import java.awt.Image; import java.awt.Toolkit; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import j…

Facebook广告投放效果不佳?这些投放技巧我不允许你不知道!

众所周知&#xff0c;Facebook广告对于跨境卖家来说是非常有效的站外引流渠道&#xff0c;通过Facebook广告投放可以提高跨境卖家的产品销量和排名&#xff0c;但是有时明明广告已经投放出去了&#xff0c;却无法被受众看到&#xff0c;完全没有获得成果&#xff0c;或许你会怪…

焊接设备行业分析:预计2029年将达到834亿元

近年来我国焊割设备行业的主要出口产品多为零部件以及部分中、低端设备&#xff0c;其出口单价和利润额均相对较低。 随着国内原材料价格上涨和人民币不断升值&#xff0c;出口产品的竞争力日趋下降&#xff0c;利润空间也随着出口价格的下降被进一步压缩。同时近年来国际经济形…

应用密码学期末复习(3)

目录 第三章 现代密码学应用案例 3.1安全电子邮件方案 3.1.1 PGP产生的背景 3.2 PGP提供了一个安全电子邮件解决方案 3.2.1 PGP加密流程 3.2.2 PGP解密流程 3.2.3 PGP整合了对称加密和公钥加密的方案 3.3 PGP数字签名和Hash函数 3.4 公钥分发与认证——去中心化模型 …

开源免费跨平台数据同步工具-Syncthing

Syncthing是一款开源免费跨平台的文件同步工具&#xff0c;是基于P2P技术实现设备间的文件同步&#xff0c;所以它的同步是去中心化的&#xff0c;即你并不需要一个服务器&#xff0c;故不需要担心这个中心的服务器给你带来的种种限制&#xff0c;而且类似于torrent协议&#x…

leetcode每日一题34

89.格雷编码 观察一下n不同时的格雷编码有什么特点 n1 [0,1] n2 [0,1,3,2] n3 [0,1,3,2,6,7,5,4] …… 可以看到nk时&#xff0c;编码数量是nk-1的数量的一倍 同时nk编码的前半部分和nk-1一模一样 nk编码的最后一位是2k-1 后半部分的编码是其对应的前半部分的对称的位置的数字…

ISCTF2023新生赛Misc部分WP

ISCTF2023新生赛部分WP MISC&#xff1a;签到&#xff1a;你说爱我&#xff1f;尊嘟假嘟&#xff1a;小蓝鲨的秘密&#xff1a;easy_zip:杰伦可是流量明星&#xff1a;蓝鲨的福利&#xff1a;Ez_misc:PNG的基本食用:小猫&#xff1a;MCSOG-猫猫&#xff1a;镜流:stream&#xf…

UG\NX二次开发 设置对象上属性的锁定状态UF_ATTR_set_locked

文章作者:里海 来源网站:里海NX二次开发3000例专栏 感谢粉丝订阅 感谢 Miracle UG开发 订阅本专栏,非常感谢。 简介 设置对象上属性的锁定状态UF_ATTR_set_locked,需要先在“用户默认设置”中勾选“通过NX Open锁定属性”(文件->实用工具->用户默认设置->基本环境…

Pycharm中使用matplotlib绘制动态图形

Pycharm中使用matplotlib绘制动态图形 最终效果 最近用pycharm学习D2L时发现官方在jupyter notebook交互式环境中能动态绘制图形&#xff0c;但是在pycharm脚本环境中只会在最终 plt.show() 后输出一张静态图像。于是有了下面这段自己折腾了一下午的代码&#xff0c;用来在pych…

orvibo旗下的VS30ZW网关分析之一

概述 从官网的APP支持的智能中枢来看,一共就两种大类: MixPad系列和网关系列 排除MixPad带屏网关外,剩余的设备如下图: 目前在市场上这四种网关已经下市,官方已经宣布停产。所以市场上流通的也几乎绝迹。 从闲鱼市场上可以淘到几个,拿来分析一下,这里我手头有如下的两…

M1安装RabbitMQ

1.查看centos内核版本 uname -a uname -r2.安装之前的准备工作 安装RabbitMQ必装Erlang(RabbitMQ官网添加链接描述) 2.1.Erlang简介 Erlang是一种通用的面向并发的编程语言&#xff0c;它由瑞典电信设备制造商爱立信所辖的CS-Lab开发&#xff0c;目的是创造一种可以应对…

TA-Lib学习研究笔记(一)

TA-Lib学习研究笔记&#xff08;一&#xff09; 1.介绍 TA-Lib&#xff0c;英文全称“Technical Analysis Library”,是一个用于金融量化的第三方库&#xff0c;涵盖了150多种交易软件中常用的技术分析指标&#xff0c;如RSI,KDJ,MACD, MACDEXT, MACDFIX, SAR, SAREXT, MA,SM…