Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析

Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析

我们知道torch.meshgrid()函数的功能是生成网格,可以用于生成坐标;

在numpy中也有一样的函数np.meshgrid(),但是用法不太一样,我们直接上代码进行解释。

1、两者在用法上的区别

比如:我要生成下图的xy坐标点,看下两者的实现方式:

在这里插入图片描述

np.meshgrid()

>>> import numpy as np
>>> w, h = 4, 2
# 注意,此时输入的是由w和h生成的一维数组
#      此时输出的是网格x的坐标grid_x以及网格y的坐标grid_y
>>> grid_x, grid_y  = np.meshgrid(np.arange(w), np.arange(h)) 

>>> grid_x
array([[0, 1, 2, 3],  
       [0, 1, 2, 3]])
>>> grid_y
array([[0, 0, 0, 0],
       [1, 1, 1, 1]])

torch.meshgrid()

>>> import torch
# 注意,此时输入的是由h和w生成的一维数组(和numpy中的输入顺序相反)
#      此时输出的是网格y的坐标grid_y以及网格x的坐标grid_x(和numpy中的输出顺序相反)
>>> grid_y, grid_x =  torch.meshgrid(
...         torch.arange(h),
...         torch.arange(w)
...     )
>>> grid_x
tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])
>>> grid_y
tensor([[0, 0, 0, 0],
        [1, 1, 1, 1]])

2、应用案例

2.1 利用np.meshgrid()来画决策边界

我们可以利用np.meshgrid()来画等高线图

# 等高线图
import numpy as np
import matplotlib.pyplot as plt

# 模拟海拔高度
def fz(x, y):
  z = (1 -x / 2 + x**5 + y**3) * np.exp(-x**2-y**2)
  return z

w = np.linspace(-4, 4, 100)
h = np.linspace(-2, 2, 100)

grid_x, grid_y = np.meshgrid(w, h)
z = fz(grid_x, grid_y)

plt.figure('Contour Chart',facecolor='lightgray')
plt.title('contour',fontsize=16)
plt.grid(linestyle=':')

cntr = plt.contour(
    grid_x, # 网格坐标矩阵的x坐标(2维数组)
    grid_y, # 网格坐标矩阵的y坐标(2维数组)
    z,      # 网格坐标矩阵的z坐标(2维数组)
    8,      # 等高线绘制8部分
    colors = 'black', # 等高线图颜色
    linewidths = 0.5 # 等高线图线宽
)
# 设置标签
plt.clabel(cntr, inline_spacing = 1, fmt='%.2f', fontsize=10)
# 填充颜色  大的是红色  小的是蓝色
plt.contourf(grid_x, grid_y, z, 8, cmap='jet')

plt.legend()
plt.show()

在这里插入图片描述

我们可以利用np.meshgrid()来画决策边界。

from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import numpy as np

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC


# 使用sklearn自带的moon数据
X, y = make_moons(n_samples=100,noise=0.15,random_state=42)

# 绘制生成的数据
def plot_dataset(X,y,axis):
    plt.plot(X[:,0][y == 0],X[:,1][y == 0],'bs')
    plt.plot(X[:,0][y == 1],X[:,1][y == 1],'go')
    plt.axis(axis)
    plt.grid(True,which='both')


# 画出决策边界
def plot_pred(clf,axes):
    w = np.linspace(axes[0],axes[1], 100)
    h = np.linspace(axes[2],axes[3], 100)
    grid_x, grid_y = np.meshgrid(w, h)
    # grid_x 和 grid_y 被拉成一列,然后拼接成10000行2列的矩阵,表示所有点
    grid_xy = np.c_[grid_x.ravel(), grid_y.ravel()]
    # 二维点集才可以用来预测
    y_pred = clf.predict(grid_xy).reshape(grid_x.shape)
    # 等高线
    plt.contourf(grid_x, grid_y,y_pred,alpha=0.2)


ploy_kernel_svm_clf = Pipeline(
    steps=[
        ("scaler",StandardScaler()),
        ("svm_clf",SVC(kernel='poly', degree=3, coef0=1, C=5))
    ]
)


ploy_kernel_svm_clf.fit(X,y)

plot_pred(ploy_kernel_svm_clf,[-1.5, 2.5, -1, 1.5])
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
plt.show()

在这里插入图片描述

2.2 利用torch.meshgrid()生成网格所有坐标的矩阵

在目标检测YOLO中将图像划分为单元网格的部分就用到了torch.meshgrid()函数。

import torch
import numpy as np


def create_grid(input_size, stride=32):
    # 1、获取原始图像的w和h
    w, h = input_size, input_size
    # 2、获取经过32倍下采样后的feature map
    ws, hs = w // stride, h // stride
    # 3、生成网格的y坐标和x坐标
    grid_y , grid_x = torch.meshgrid([
        torch.arange(hs),
        torch.arange(ws)
    ])
    # 4、将grid_x和grid_y进行拼接,拼接后的维度为【H, W, 2】
    grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
    # 【H, W, 2】 -> 【HW, 2】
    grid_xy = grid_xy.view(-1, 2)
    return grid_xy



if __name__ == '__main__':
    print(create_grid(input_size=32*4))
# 生成网格所有坐标的矩阵
tensor([[0., 0.],
        [1., 0.],
        [2., 0.],
        [3., 0.],
        
        [0., 1.],
        [1., 1.],
        [2., 1.],
        [3., 1.],
        
        [0., 2.],
        [1., 2.],
        [2., 2.],
        [3., 2.],
        
        [0., 3.],
        [1., 3.],
        [2., 3.],
        [3., 3.]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/264610.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何在Window系统下搭建Nginx服务器环境并部署前端项目

1.下载并安装Nginx 在nginx官网nginx: download 下载稳定版本至自己想要的目录。 解压后进入目录 2.启动Nginx服务器 启动方式有两种: (1)直接进入nginx安装目录下,双击nginx.exe运行,此时命令行窗口一闪而过&…

浏览器 cookie 的原理(详)

目录 1,cookie 的出现2,cookie 的组成浏览器自动发送 cookie 的条件 3,设置 cookie3.1,服务端设置3.1,客户端设置3.3,删除 cookie 4,使用流程总结 整理和测试花了很大时间,如果对你有…

python调用GPT API

每次让gpt给我生成一个调用api的程序时,他经常会调用以前的一些api的方法,导致我的程序运行错误,所以这期记录一下使用新的方法区调用api 参考网址 Migration Guide,这里简要地概括了一下新版本做了哪些更改 OpenAI Python API l…

引领汽车营销新趋势,3DCAT实时云渲染助力汽车三维可视化

当前,汽车产业发展正从电动化的上半场,向智能化的下半场迈进。除了车机技术体验的智能化之外,观车体验的智能化也不容忽视。 这是因为,随着数字化、智能化、个性化的趋势,消费者对汽车的需求和期待也越来越高&#xf…

2016年第五届数学建模国际赛小美赛B题直达地铁线路解题全过程文档及程序

2016年第五届数学建模国际赛小美赛 B题 直达地铁线路 原题再现: 在目前的大都市地铁网络中,在两个相距遥远的车站之间运送乘客通常需要很长时间。我们可以建议在两个长途车站之间设置直达班车,以节省长途乘客的时间。   第一部分&#xf…

Qt的简单游戏实现提供完整代码

文章目录 1 项目简介2 项目基本配置2.1 创建项目2.2 添加资源 3 主场景3.1 设置游戏主场景配置3.2 设置背景图片3.3 创建开始按钮3.4 开始按钮跳跃特效实现3.5 创建选择关卡场景3.6 点击开始按钮进入选择关卡场景 4 选择关卡场景4.1场景基本设置4.2 背景设置4.3 创建返回按钮4.…

Java面向对象(初级)

面向对象编程(基础) 面向对象编程(OOP)是一种编程范式,它强调程序设计是围绕对象、类和方法构建的。在面向对象编程中,程序被组织为一组对象,这些对象可以互相传递消息。面向对象编程的核心概念包括封装、继承和多态。…

2023.12.21 关于 Redis 常用数据结构 和 单线程模型

目录 各数据结构具体编码方式 查看 key 对应 value 的编码方式 Reids 单线程模型 经典面试题 IO 多路复用 Redis 常用数据结构 Redis 中所有的 key 均为 String 类型,而不同的是 value 的数据类型却有很多种以下介绍 5 种 value 常见的数据类型 注意&#xff1…

阿里云 ACK One 新特性:多集群网关,帮您快速构建同城容灾系统

云布道师 近日,阿里云分布式云容器平台 ACK One[1]发布“多集群网关”[2](ACK One Multi-cluster Gateways)新特性,这是 ACK One 面向多云、多集群场景提供的云原生网关,用于对多集群南北向流量进行统一管理。 基于 …

虚拟机的下载、安装(模拟出服务器)

下载 vmware workstation(收费的虚拟机) 下载vbox 网址:Oracle VM VirtualBox(免费的虚拟机) 以下选择一个下载即可,建议下载vbox,因为是免费的。安装的时候默认下一步即可(路径最好…

hiveserver负载均衡配置

一.安装nginx 参数我的另一篇文章:https://mp.csdn.net/mp_blog/creation/editor/135152478 二.配置nginx服务参数 worker_processes 1; events { worker_connections 1024; } stream { upstream hiveserver2 { # least_conn; # 使用最少连接路由…

八大排序算法@直接插入排序(C语言版本)

目录 直接插入排序概念算法思想代码实现核心算法:直接插入排序的算法实现: 特性总结 直接插入排序 概念 算法思想 把待排序的记录按其关键码值的大小逐个插入到一个已经排好序的有序序列中,直到所有的记录插入完为止,得到一个新…

【Spring实战】配置多数据源

文章目录 1. 配置数据源信息2. 创建第一个数据源3. 创建第二个数据源4. 创建启动类及查询方法5. 启动服务6. 创建表及做数据7. 查询验证8. 详细代码总结 通过上一节的介绍,我们已经知道了如何使用 Spring 进行数据源的配置以及应用。在一些复杂的应用中,…

文档 - - - Docsify文档创建

目录 1. Docsify 介绍2. 创建 Docsify 项目2.1 安装 Node.js2.1 安装 docsfiy-cli2.3 初始化项目2.4 运行项目2.5 使用 Python 运行项目(扩展,不推荐有bug) 3. 配置 Docsify 项目3.1 修改等待加载文字3.2 添加网站 ico 图标3.3 创建新页面写文…

python 用OpenCV 将图片转视频

import os import cv2 import numpy as npcv2.VideoWriter()参数 cv2.VideoWriter() 是 OpenCV 中用于创建视频文件的类。它的参数如下: filename:保存视频的文件名。 fourcc:指定视频编解码器的 FourCC 代码&#xf…

SVM —— 代码实现

SMO 算法的实现步骤: 代码如下: import numpy as np import matplotlib.pyplot as plt import seaborn as sns import random# 设置中文字体为宋体,英文字体为 times new roman sns.set(font"SimSun", style"ticks", fo…

webpack学习-7.创建库

webpack学习-7.创建库 1.暴露库1.1概念1.2验证1.2.1 不导出方法1.2.2 导出方法 2.外部化 lodash3.外部化的限制4.最终步骤5.使用自己的库5.1坑 6.总结 1.暴露库 这个模块学习有点坑。看名字就是把自己写的个包传到npm,而且还要在项目中使用到它,支持各种…

java类和对象的思想概述

0.面向对象Object OOP——名人名言:类是写出来的,对象是new出来的 **> 学习面向对象的三条路线 java类以及类成员:(重点)类成员——属性、方法、构造器、(熟悉)代码块、内部类面向对象特征&…

在Next.js和React中搭建Cesium项目

在Next.js和React中搭建Cesium项目,需要确保Cesium能够与服务端渲染(SSR)兼容,因为Next.js默认是SSR的。Cesium是一个基于WebGL的地理信息可视化库,通常用于在网页中展示三维地球或地图。下面是一个基本的步骤,用于在Next.js项目中…

VideoPoet: Google的一种用于零样本视频生成的大型语言模型

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…