DNN二分类模型

import os
import datetime

#打印时间
def printbar():
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n"+"=========="*8 + "%s"%nowtime)

#mac系统上pytorch和matplotlib在jupyter中同时跑需要更改环境变量
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 
printbar()

准备数据:

import numpy as np 
import pandas as pd 
from matplotlib import pyplot as plt
import torch
from torch import nn
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

#正负样本数量
n_positive,n_negative = 2000,2000

#生成正样本, 小圆环分布              torch.normal 是 PyTorch 的一个函数,用于根据正态分布(高斯分布)生成随机数张量
r_p = 5.0 + torch.normal(0.0,1.0,size = [n_positive,1])   # 0.0表示均值,1.0表示标准差的正态分布, torch.normal(mean, std, size)
#print(r_p.shape)
theta_p = 2*np.pi*torch.rand([n_positive,1])  #torch.rand 是 PyTorch 中用于生成均匀分布(uniform distribution)的随机数张量的函数
# torch.rand()生成0到之间的均匀分布的随机数,torch.normal生成指定均值和标准差的正太分布的随机数
#print(theta_p.shape)   theta_p形状是[2000,1],每个数范围是0到2Π,cos(theta_p)范围是余弦函数一个周期的图像,*r_p表示(rcos角度,rsin角度)
Xp = torch.cat([r_p*torch.cos(theta_p),r_p*torch.sin(theta_p)],axis = 1)
# torch.cat 拼接两个生成一个二维张量,axis = 1表示在axis = 1方向上拼接,就是Xp(x,y)坐标,x,y都是-5到5之内的数,Xp的每一行表示一个点
print(Xp.shape)
Yp = torch.ones_like(r_p)
print(Yp.shape)

#生成负样本, 大圆环分布
r_n = 8.0 + torch.normal(0.0,1.0,size = [n_negative,1]) 
theta_n = 2*np.pi*torch.rand([n_negative,1])
Xn = torch.cat([r_n*torch.cos(theta_n),r_n*torch.sin(theta_n)],axis = 1)
Yn = torch.zeros_like(r_n)

#汇总样本
X = torch.cat([Xp,Xn],axis = 0)   # 正负样本所有点进行拼接。在0维度上,也就是把两个n行,每行一系列点的张量在行上往下拼接
Y = torch.cat([Yp,Yn],axis = 0)   # Yp,Yn表示正负样本的标签,也就是二分类的类别标签而已
 

#可视化
plt.figure(figsize = (6,6))
plt.scatter(Xp[:,0].numpy(),Xp[:,1].numpy(),c = "r")
plt.scatter(Xn[:,0].numpy(),Xn[:,1].numpy(),c = "g")
plt.legend(["positive","negative"]);

# 构建数据管道迭代器
def data_iter(features, labels, batch_size=8):
    num_examples = len(features)  #样本总个数,就是点的总个数  len()函数来获取列表的长度,也就是样本的数量
    indices = list(range(num_examples))  # 转换成列表类型  range(num_examples)生成一个有序数组,转成列表形式,
    # Python中,range()函数生成的整数序列范围为 [start, end) (包左不包右)
    # torch.range()函数生成的整数序列范围为 [start, end] (包左包右) 
    np.random.shuffle(indices)  #样本的读取顺序是随机的,打乱样本点的索引
    for i in range(0, num_examples, batch_size):
        # 列表切片,根据样本个数长度范围内的随机列表生成一个批次的列表索引
        indexs = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])   
        # features输入的是X,[4000,2]的4000个点的张量格式,
        # torch.index_select是PyTorch中的一个函数,其主要用途是允许您在特定维度上选择索引,从而获取所需的张量子集2。
        yield  features.index_select(0, indexs), labels.index_select(0, indexs)  # 取出总样本中batch_size个数据
        
# 测试数据管道效果   
batch_size = 8
# X,Y 是正负样本也就是两个类别拼接后的张量,X是点,Y是点的类别
(features,labels) = next(data_iter(X,Y,batch_size))
print(features)
print(labels)
class DNNModel(nn.Module):
    def __init__(self):
        super(DNNModel, self).__init__()
        self.w1 = nn.Parameter(torch.randn(2,4))  # 两行四列
        self.b1 = nn.Parameter(torch.zeros(1,4))  # 偏置全是0
        self.w2 = nn.Parameter(torch.randn(4,8))
        self.b2 = nn.Parameter(torch.zeros(1,8))
        self.w3 = nn.Parameter(torch.randn(8,1))
        self.b3 = nn.Parameter(torch.zeros(1,1))   # 最后预测结果形状[8,1]  是8个数的大小,就是这组权重偏置预测出来的结果

    # 正向传播
    def forward(self,x):
        x = torch.relu(x@self.w1 + self.b1)
        x = torch.relu(x@self.w2 + self.b2)
        y = torch.sigmoid(x@self.w3 + self.b3)
        return y   # y就是预测
    
    # 损失函数(二元交叉熵)
    def loss_fn(self,y_pred,y_true):  
        #将预测值限制在1e-7以上, 1- (1e-7)以下,避免log(0)错误
        eps = 1e-7
        y_pred = torch.clamp(y_pred,eps,1.0-eps)
        bce = - y_true*torch.log(y_pred) - (1-y_true)*torch.log(1-y_pred)
        return torch.mean(bce)
    
    # 评估指标(准确率)
    def metric_fn(self,y_pred,y_true):
        y_pred = torch.where(y_pred>0.5,torch.ones_like(y_pred,dtype = torch.float32),
                          torch.zeros_like(y_pred,dtype = torch.float32))
        acc = torch.mean(1-torch.abs(y_true-y_pred))
        return acc
    
model = DNNModel()
# 测试模型结构
batch_size = 10
(features,labels) = next(data_iter(X,Y,batch_size))

predictions = model(features)

loss = model.loss_fn(labels,predictions)
metric = model.metric_fn(labels,predictions)

print("init loss:", loss.item())
print("init metric:", metric.item())

训练模型

def train_step(model, features, labels):   
    
    # 正向传播求损失
    predictions = model.forward(features)
    loss = model.loss_fn(predictions,labels)
    metric = model.metric_fn(predictions,labels)
        
    # 反向传播求梯度
    loss.backward()
    
    # 梯度下降法更新参数
    for param in model.parameters():
        #注意是对param.data进行重新赋值,避免此处操作引起梯度记录
        param.data = (param.data - 0.01*param.grad.data) 
        
    # 梯度清零
    model.zero_grad()
        
    return loss.item(),metric.item()
 

def train_model(model,epochs):
    for epoch in range(1,epochs+1):
        loss_list,metric_list = [],[]
        for features, labels in data_iter(X,Y,20):
            # 分批训练
            lossi,metrici = train_step(model,features,labels)
            loss_list.append(lossi)
            metric_list.append(metrici)
        loss = np.mean(loss_list)
        metric = np.mean(metric_list)

        if epoch%10==0:
            printbar()  # 最开始定义的打印时间的函数
            print("epoch =",epoch,"loss = ",loss,"metric = ",metric)
        
train_model(model,epochs = 100)

# 结果可视化
fig, (ax1,ax2) = plt.subplots(nrows=1,ncols=2,figsize = (12,5))
ax1.scatter(Xp[:,0],Xp[:,1], c="r")
ax1.scatter(Xn[:,0],Xn[:,1],c = "g")
ax1.legend(["positive","negative"]);
ax1.set_title("y_true");

Xp_pred = X[torch.squeeze(model.forward(X)>=0.5)]
Xn_pred = X[torch.squeeze(model.forward(X)<0.5)]

ax2.scatter(Xp_pred[:,0],Xp_pred[:,1],c = "r")
ax2.scatter(Xn_pred[:,0],Xn_pred[:,1],c = "g")
ax2.legend(["positive","negative"]);
ax2.set_title("y_pred");

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/254416.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【网络安全】-Linux操作系统—操作系统发展历史与Linux

文章目录 操作系统发展历史初期的操作系统分时操作系统个人计算机操作系统 Linux的诞生UNIX与GNU项目Linux内核的创建 Linux的特点开放源代码多样性社区支持 Linux的应用服务器和超级计算机嵌入式系统桌面系统 总结 操作系统发展历史 操作系统&#xff08;Operating System&am…

详细教程 - 从零开发 Vue 鸿蒙harmonyOS应用 第五节 (基于uni-app封装鸿蒙接口请求库)

随着鸿蒙系统的兴起,越来越多的app会采用鸿蒙开发。而鸿蒙开发必不可少的就是调用各种接口服务。为了简化接口的调用流程,我们通常会做一层封装。今天就来讲解一下,如何用uni-app封装鸿蒙的接口请求库。 一、新建项目 首先我们要新建一个鸿蒙项目啦&#xff01;当然选择第一个…

neuq-acm预备队训练week 9 P1119 灾后重建

解题思路 本题可以用最短路算法——Floyd AC代码 #include<bits/stdc.h> #define inf 1e9 using namespace std; const int N 2e2 50; int n, m, q, now 0, a, b, c, t[N], G[N][N];int main() {scanf("%d%d", &n, &m);for(int i 0;i<n;i)sc…

044.Python异常处理_手动抛出异常自定义异常

我 的 个 人 主 页&#xff1a;&#x1f449;&#x1f449; 失心疯的个人主页 &#x1f448;&#x1f448; 入 门 教 程 推 荐 &#xff1a;&#x1f449;&#x1f449; Python零基础入门教程合集 &#x1f448;&#x1f448; 虚 拟 环 境 搭 建 &#xff1a;&#x1f449;&…

ubuntu 磁盘挂载

1.前提 给自己的计算机加了一个新硬盘&#xff0c;怎么在ubuntu中使用呢 特别提示&#xff01;对磁盘操作存在一定丢失数据的风险&#xff0c;本篇是在一个新购买的硬盘上进行操作&#xff01;如果你使用的是一个带数据的硬盘&#xff0c;请勿参考本篇文章&#xff01; 2.找…

解决:AttributeError: module ‘scipy.misc’ has no attribute ‘imread’

解决&#xff1a;AttributeError: module ‘scipy.misc’ has no attribute ‘imread’ 文章目录 解决&#xff1a;AttributeError: module scipy.misc has no attribute imread背景报错问题报错翻译报错位置代码报错原因解决方法方法一 scipy版本回退&#xff08;不推荐&#…

记录 | Visual Studio报错:const char*类型的值不能用于初始化char*类型

Visual Studio 报错&#xff1a; const char *”类型的值不能用于初始化“char *”类型的实体错误 解决办法&#xff1a; 1&#xff0c;强制类型转换&#xff0c;例如&#xff1a; char * Singer::pv[] {(char*)"other", (char*)"alto", (char*)"c…

会旋转的树,你见过吗?

&#x1f388;个人主页:&#x1f388; :✨✨✨初阶牛✨✨✨ &#x1f43b;强烈推荐优质专栏: &#x1f354;&#x1f35f;&#x1f32f;C的世界(持续更新中) &#x1f43b;推荐专栏1: &#x1f354;&#x1f35f;&#x1f32f;C语言初阶 &#x1f43b;推荐专栏2: &#x1f354;…

工业数据的特殊性和安全防护体系探索思考

随着工业互联网的发展&#xff0c;工业企业在生产运营管理过程中会产生各式各样数据&#xff0c;主要有研发设计数据、用户数据、生产运营数据、物流供应链数据等等&#xff0c;这样就形成了工业大数据&#xff0c;这些数据需要依赖企业的网络环境和应用系统进行内外部流通才能…

字体包引入以及使用

将UI给的字体包下载到assets文件夹下 app.vue全局定义 <style> * {margin: 0;padding: 0; }font-face {font-family: PangMenZhengDao-3;src: url(/assets/fonts/庞门正道标题体3.0.TTF) format(truetype);font-weight: bold;font-style: normal; } </style>页面使…

亚马逊云科技re:Invent推出生成式AI技术堆栈及关键服务和工具

亚马逊云科技于29日推出“生成式AI技术堆栈”后&#xff0c;又在30日的re:Invent 2023大会上宣布了一系列支持这一全新堆栈的关键服务和工具。 亚马逊云科技数据和人工智能副总裁Swami Sivasubramanian在主题演讲中&#xff0c;将生成式人工智能与“超新星爆炸”进行了比较&am…

ffmpeg可以做什么

用途 FFmpeg是一个功能强大的多媒体处理工具&#xff0c;可以处理音频和视频文件。它是一个开源项目&#xff0c;可在各种操作系统上运行&#xff0c;包括Linux、Windows和Mac OS X等。以下是FFmpeg可以做的一些主要任务&#xff1a; 转换媒体格式&#xff1a;可将一个媒体格式…

云原生基础入门概念

文章目录 发现宝藏云原生的概念云原生的关键技术为何选择云原生&#xff1f;云原生的实际应用好书推荐 发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【宝藏入口】。 云原生的概念 当谈及现…

http跟https的区别

只要上过网的朋友一定接触过“HTTP”&#xff0c;每次开网页的时候&#xff0c;不管是什么网址&#xff0c;其前面都会出现HTTP字样&#xff0c;比如 “http&#xff1a;55049sjad.com”、“http&#xff1a;544.65.5.6.com”等等&#xff0c;而有些时候打开如银行等对安全性要…

【linux】SSH终端Putty配置:文件上传/下载、显示中文字体、自动登录

文章目录 写在前面putty上传/下载文件1. 下载2. 解压和配置3. 使用sz/rz3.1 下载文件:sz3.2 上传文件:rz 显示中文字体1. 下载合适的字体2. 解压和安装3. putty配置 putty自动登录1. putty配置2. putty快捷方式配置3. 使用putty 写在后面 写在前面 一篇博客介绍了12种SSH终端工…

D3132|贪心算法

435.无重叠区间 初始思路&#xff1a; 我的思路就是如果有两个区间重叠&#xff0c;保留end比较小的那个区间&#xff0c;删除end比较大的区间。 class Solution {public int eraseOverlapIntervals(int[][] intervals) {Arrays.sort(intervals, new Comparator<int[]>…

恐怖题材黑马大作,艾尔莎B760M-E D5和你玩转《心灵杀手2》

说起恐怖题材的游戏&#xff0c;相信不少朋友都会第一时间想到《生化危机》、《寂静岭》、《死亡空间》等经典系列与作品。而在最近这几年&#xff0c;恐怖题材游戏也有不少黑马出现&#xff0c;比如最近推出的《心灵杀手2》就是2010年《心灵杀手》的续作&#xff0c;它是由开发…

iPhone16:首款AI iPhone?

随着科技水平的不断发展&#xff0c;智能手机逐渐成为人们最依赖的电子产品之一。为能够满足用户需求&#xff0c;手机的硬件、外观设计与性能飞速提升&#xff0c;这也导致智能手机市场快速进入到瓶颈期。 为了能够带来更优秀的表现&#xff0c;苹果可能会为iPhone 16系列带来…

中国首家!腾讯云入选Gartner®视频平台服务市场指南代表厂商

近日&#xff0c; Gartner正式发布《Market Guide for Video Platform Services》&#xff08;《视频平台服务市场指南》&#xff0c;下称“《指南》”&#xff09;&#xff0c;凭借领先的音视频技术和产品组合优势&#xff0c;腾讯云成为中国首家且唯一入选的代表厂商。 腾讯…

GZ015 机器人系统集成应用技术样题8-学生赛

2023年全国职业院校技能大赛 高职组“机器人系统集成应用技术”赛项 竞赛任务书&#xff08;学生赛&#xff09; 样题8 选手须知&#xff1a; 本任务书共 25页&#xff0c;如出现任务书缺页、字迹不清等问题&#xff0c;请及时向裁判示意&#xff0c;并进行任务书的更换。参赛队…