前馈神经网络正则化例子

直接看代码:

import torch  
import numpy as np  
import random  
from IPython import display  
from matplotlib import pyplot as plt  
import torchvision  
import torchvision.transforms as transforms   

mnist_train = torchvision.datasets.MNIST(root='/MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  

batch_size = 256 

train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  

num_inputs,num_hiddens,num_outputs =784, 256,10

def init_param():
    W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  
    b1 = torch.zeros(1, dtype=torch.float32)  
    W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  
    b2 = torch.zeros(1, dtype=torch.float32)  
    params =[W1,b1,W2,b2]
    for param in params:
        param.requires_grad_(requires_grad=True)  
    return W1,b1,W2,b2

def relu(x):
    x = torch.max(input=x,other=torch.tensor(0.0))  
    return x

def net(X):  
    X = X.view((-1,num_inputs))  
    H = relu(torch.matmul(X,W1.t())+b1)  
    #myrelu =((matmal x,w1)+b1),return  matmal(myrelu,w2 )+ b2
    return relu(torch.matmul(H,W2.t())+b2 )
    return torch.matmul(H,W2.t())+b2


def SGD(paras,lr):  
    for param in params:  
        param.data -= lr * param.grad  
        
def l2_penalty(w):
    return (w**2).sum()/2


def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None,mylambda=0):  
    
    train_ls, test_ls = [], []
    
    for epoch in range(num_epochs):
        
        ls, count = 0, 0
        
        for X,y in train_iter :
            X = X.reshape(-1,num_inputs)
            l=loss(net(X),y)+ mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)
            
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            
            ls += l.item()
            count += y.shape[0]
            
        train_ls.append(ls)
        
        ls, count = 0, 0
        
        for X,y in test_iter:
            X = X.reshape(-1,num_inputs)
            l=loss(net(X),y) + mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)
            ls += l.item()
            count += y.shape[0]
            
        test_ls.append(ls)
        
        if(epoch)%2==0:
            print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))
            
    return train_ls,test_ls



lr = 0.01

num_epochs = 20

Lamda = [0,0.1,0.2,0.3,0.4,0.5]

Train_ls, Test_ls = [], []

for lamda in Lamda:
    print("current lambda is %f"%lamda)
    W1,b1,W2,b2 = init_param()
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)
    train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer,lamda)   
    Train_ls.append(train_ls)
    Test_ls.append(test_ls)
    
x = np.linspace(0,len(Train_ls[1]),len(Train_ls[1]))

plt.figure(figsize=(10,8))

for i in range(0,len(Lamda)):
    plt.plot(x,Train_ls[i],label= f'L2_Regularization:{Lamda [i]}',linewidth=1.5)
    
    plt.xlabel('different epoch')
    
    plt.ylabel('loss')
    
plt.legend(loc=2, bbox_to_anchor=(1.1,1.0),borderAxesPad = 0.)

plt.title('train loss with L2_penalty')

plt.show()

运行结果:

在这里插入图片描述

疑问和心得:

  1. 画图的实现和细节还是有些模糊。
  2. 正则化系数一般是一个可以根据算法有一定变动的常数。
  3. 前馈神经网络中,二分类最后使用logistic函数返回,多分类一般返回softmax值,若是一般的回归任务,一般是直接relu返回。
  4. 前馈神经网络的实现,从物理层上应该是全连接的,但是网上的代码一般都是两层单个神经元,这个容易产生误解。个人感觉,还是要使用nn封装的函数比较正宗。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/80514.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

链游再进化 Web3版CSGO来袭

过去几年,游戏开发者们一直希望借Web3这个价值流通网络,改造传统游戏的经济系统,将虚拟资产的掌管权交给用户,让资产自由地在市场流通。 Web3游戏发展史上,涌现过CryptoKitties、Axie Infinity两大爆款,但…

爬虫框架- feapder + 爬虫管理系统 - feaplat 的学习简记

文章目录 feapder 的使用feaplat 爬虫管理系统部署 feapder 的使用 feapder是一款上手简单,功能强大的Python爬虫框架 feapder 官方文档 文档写的很详细,可以直接上手。 基本命令: 创建爬虫项目 feapder create -p first-project创建爬虫 …

LRU算法源码实现

算法介绍: 最近最久未使用(Least Recently Used LRU)算法是⼀种缓存淘汰策略。该算法的思路是,将最近一段时间内最久未使用的页面置换出去。 升级版LRUK算法见 基于LRU-K算法设计本地缓存实现流量削峰https://blog.csdn.net/l…

Deep Learning With Pytorch - 最基本的感知机、贯序模型/分类、拟合

文章目录 如何利用pytorch创建一个简单的网络模型?Step1. 感知机,多层感知机(MLP)的基本结构Step2. 超平面 ω T ⋅ x b 0 \omega^{T}xb0 ωT⋅xb0 or ω T ⋅ x b \omega^{T}xb ωT⋅xb感知机函数 Step3. 利用感知机进行决策…

虚拟机问题

虚拟机无法识别USB设备 经排查为VMware USB Arbitration Service 没有启动,但是VMware USB Arbitration Service依赖于VMware Workstation Server启动 VMware USB Arbitration Service(VMUSBArbService)是由 VMware 虚拟化软件提供的一个服务,用于协调和管理主机系统上的…

Flink CDC系列之:基于 Flink CDC 构建 MySQL 和 Postgres 的 Streaming ETL

Flink CDC系列之:基于 Flink CDC 构建 MySQL 和 Postgres 的 Streaming ETL 一、技术路线二、MySQL数据库建表三、PostgreSQL数据库建表四、在 Flink SQL CLI 中使用 Flink DDL 创建表五、关联订单数据并且将其写入 Elasticsearch 中六、Kibana查看商品和物流信息的…

leetcode611. 有效三角形的个数(java)

有效三角形的个数 有效三角形的个数排序加二分排序 双指针 上期算法 有效三角形的个数 给定一个包含非负整数的数组 nums ,返回其中可以组成三角形三条边的三元组个数。 示例 1: 输入: nums [2,2,3,4] 输出: 3 解释:有效的组合是: 2,3,4 (使用第一个 2) 2,3,4 (使…

如何修复损坏的DOC和DOCX格式Word文件?

我们日常办公中,经常用到Word文档。但是有时会遇到word文件损坏、无法打开的情况。这时该怎么办?接着往下看,小编在这里就给大家带来最简单的Word文件修复方法! 很多时候DOC和DOCX Word文件会无缘无故的损坏无法打开,一…

【C++ 记忆站】引用

文章目录 一、引用概念二、引用特性1、引用在定义时必须初始化2、一个变量可以有多个引用3、引用一旦引用一个实体,再不能引用其他实体 三、常引用四、使用场景1、做参数1、输出型参数2、大对象传参 2、做返回值1、传值返回2、传引用返回 五、传值、传引用效率比较六…

【C语言】每日一题(找到所有数组中消失的数字)

找到所有数组中消失的数字,链接奉上。 这里简单说一下,因为还没有接触到动态内存,数据结构,所以知识有限,也是尽力而为,结合题库的评论区找到了适合我的解法,以后有机会,会补上各种…

图数据库_Neo4j中文版_Centos7.9安装Neo4j社区版3.5.9_基于jdk1.8---Neo4j图数据库工作笔记0012

由于我们在国内使用啊,具体还是要用中文版滴,找了好久这个neo4j,原来还是有中文版的, https://we-yun.com/doc/neo4j-chs/ 中文版下载地址在这里: 所有版本都在这里了,需要哪个自己去下载就可以了,要注意下载以后,参考: https://we-yun.com/blog/prod-56.html 在这个位置下载…

画质提升+带宽优化,小红书音视频团队端云结合超分落地实践

随着视频业务和短视频播放规模不断增长,小红书一直致力于研究:如何在保证提升用户体验质量的同时降低视频带宽成本? 在近日结束的音视频技术大会「LiveVideoStackCon 2023」上海站中,小红书音视频架构视频图像处理算法负责人剑寒向…

2023.8 - java - 对象和类

public class Dog {String breed;int size;String colour;int age;void eat() {}void run() {}void sleep(){}void name(){} } 一个类可以包含以下类型变量: 局部变量:在方法、构造方法或者语句块中定义的变量被称为局部变量。变量声明和初始化都是在方…

实现Java异步调用的高效方法

文章目录 为什么需要异步调用?Java中的异步编程方式1. 使用多线程2. 使用Java异步框架 异步调用的关键细节结论 🎉欢迎来到Java学习路线专栏~实现Java异步调用的高效方法 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒🍹✨博客主页:IT陈寒的博…

LabVIEW开发最小化5G系统测试平台

LabVIEW开发最小化5G系统测试平台 由于具有大量存储能力和数据的应用程序的智能手机的激增,当前一代产品被迫提高其吞吐效率。正交频分复用由于其卓越的品质,如单抽头均衡和具有成本效益的实施,现在被广泛用作物理层技术。这些好处是以严格的…

Azure存储访问层

blob数据的热访问层,冷访问层和存档访问层 Azure Blob 存储是一种托管对象存储服务,可用于存储和访问大量非结构化数据,如文本和二进制数据。Azure Blob 存储提供了三个不同层级的访问方式,以适应不同数据的使用模式和成本效益需…

手把手教学——终端工具xshell与文件传输工具xftp使用步骤及详解

前言 xshell是一款常用于连接本地linux服务以及云服务器的终端远程连接工具,该款终端工具常搭配远程文件传输工具xftp一起使用,由于还有很多小伙伴还不知道这两款终端工具的使用流程及步骤,Darren洋在这里给小伙伴们进行详细讲解。 一、下载工…

proteus结合keil-arm编译器构建STM32单片机项目进行仿真

proteus是可以直接创建设计图和源码的,但是源码编译它需要借助keil-arm编译器,也就是我们安装keil-mdk之后自带的编译器。 下面给出一个完整的示例,主要是做一个LED灯闪烁的效果。 新建工程指定路径,Schematic,PCB layout都选择默…

【马蹄集】第二十三周——进位制专题

进位制专题 目录 MT2186 二进制?不同!MT2187 excel的烦恼MT2188 单条件和MT2189 三进制计算机1MT2190 三进制计算机2 MT2186 二进制?不同! 难度:黄金    时间限制:1秒    占用内存:128M 题目…

推荐一个绘图平台(可替代Visio)

不废话,简易记网址: draw.io 网站会重定向到:https://app.diagrams.net/