pybind11:对比C++和Python解线性方程组的速度

前言

上篇博客介绍了如何在用pybind11实现ndarray和C++数组的转换自由,pybind11:实现ndarray转C++原生数组(没看过的朋友可以去看一看)下面我们以一个实际的算法例子演示一下如何使用这个技术,方便的实现 Python 调用 C++ 写的算法,并观察两个语言分别运行同一算法时,算法速率的显著性差异

算法选择

解线性方程组在工程乃至很多领域都是一项非常重要的技术,该算法的时间复杂度也是人们一直在追求的。下面就以手写最经典的 Gauss列主元素消去法 来作为实例算法演示C++和Python的运行速率的差距。

它的算法先后对比较简单,主要分为消去回代两个过程,消去即通过矩阵行变换将一个普通的矩阵(增广矩阵)转化为一个上三角矩阵,回代即从主对角线上最后一个元素( x n x_n xn)开始往回代,依次计算出( x n x_n xn, x n − 1 x_{n-1} xn1, … , x 1 x_1 x1)的值,具体过程如下

lll
按照这个思路,直接给出Python和C++的算法代码

C++

#include <iostream>

using namespace std;

// A: 系数矩阵 b: 右侧常数向量 n: 维数  
// 即计算 Ax = b (n个未知数)
double* Solve(double** A, double* d, int n) {

    double** a = new double*[n];
    for (int i = 0; i < n; ++i) {
        a[i] = new double[n+1];
        for (int j = 0; j < n; ++j) {
            a[i][j] = A[i][j];
        }
        a[i][n] = d[i];
    }

    for (int k = 0; k < n - 1; ++k) {
        // 选主元
        int p = k;
        for (int i = k + 1; i < n; ++i) {
            if (std::abs(a[i][k]) > std::abs(a[p][k])) {
                p = i;
            }
        }      
        // 交换行
        double* temp = a[k];
        a[k] = a[p];
        a[p] = temp;

        // 消元
        for (int i = k + 1; i < n; ++i) {
            double factor = a[i][k] / a[k][k];
            for (int j = k; j < n + 1; ++j) {
                a[i][j] -= factor * a[k][j];
            }
        }
    }

    // 回代
    double* x = new double[n];
    x[n - 1] = a[n - 1][n] / a[n - 1][n - 1];
    for (int i = n - 2; i >= 0; --i) {
        double sum = 0;
        for (int j = i + 1; j < n; ++j) {
            sum += a[i][j] * x[j];
        }
        x[i] = (a[i][n] - sum) / a[i][i];
    }

    // 释放动态分配的内存
    for (int i = 0; i < n; ++i) {
        delete[] a[i];
    }
    delete[] a;

    return x;
}

Python

import numpy as np

# A: 系数矩阵 b: 右侧常数向量 
# 即计算 Ax = b
def solve(A, d):
	# 确保以浮点型数据计算
	A.astype(np.float64)
    d.astype(np.float64)
    
    a = np.hstack((A, d.reshape(len(d), 1)))  # 水平拼接(先将d转化为列向量)

    n = len(A[0])
    for k in range(n - 1):
        # 选主元
        for p in range(k+1, n):
            if np.abs(a[p, k]) == np.max(np.abs(a[k:, k])):
                a[k, :], a[p, :] = a[p, :], a[k, :].copy()
        # 消元
        for i in range(k+1, n):
            a[i, k:] = a[i, k:] - a[i, k] / a[k, k] * a[k, k:]
	# 回代
    x = np.zeros(n)
    x[n-1] = a[n-1, n] / a[n-1, n-1]
    for i in range(n-2, -1, -1):
        x[i] = (a[i, n] - np.sum(a[i, i+1:n] * x[i+1:])) / a[i, i]

    return x

不难看出,C++代码远比Python代码复杂,单从使用的循环数量粗糙的分析,C++代码中使用了两个二重循环(水平拼接和选主元)和两个三重循环(消元和回代)而Python代码只用到了两个二重循环(选主元和消元)和一个一重循环(消元和回代),所以无论是算法的时间复杂度还是空间复杂度,Python代码都远小于C++的,那么它们实际的运行速度的快慢呢?

构建项目

构建项目,使用Pybind11将C++代码编译为pyd文件,项目结构即配置过程的详情参考:pybind11:实现Python调用C++代码(入门)

核心cpp代码如下:

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <iostream>
#include "pybind11_tools.cpp"
#include "cpp_code.cpp"

namespace py = pybind11;


py::array_t<double> SolvePybind11(py::array_t<double>& inputMatrix, py::array_t<double>& inputVector) {

    // ndarray转C数组
    NdarrayToCppArray<double>  InputMatrix(inputMatrix); 
    NdarrayToCppArray<double> InputVector(inputVector);

    // 记录长度 (待解方程的维度)
    int len = InputVector.lens[0];

    // 调用C函数
    double* result = Solve(InputMatrix.Matrix, InputVector.Vector, len);

    // 将C数组转化为ndarray
    py::array_t<double> outputArray = CToNdarray(result, len);

    return outputArray;
}

// 绑定C++函数
PYBIND11_MODULE(tryPybind, m) {
    m.def("solve", &SolvePybind11);
}

“pybind11_tools.cpp” 中装有自主编写的实现ndarray和C++数组相互转化的API(InputMatrix, InputVector, CToNdarray),代码详情参考:pybind11:实现ndarray转C++原生数组

调用写好的接口可以很方便的将写好的C++代码打包为二进制文件(pyd文件)提供给Python使用。将生成的pyd文件(tryPybind)放置于测试代码一个目录下,即可开始测试

开始测试

首先测试是否正常(以numpy的(linalg.solve函数))计算的结果为标准答案)

import numpy as np
import python_tool as pytool
import tryPybind  # 编译好的C++二进制文件

# ax = b (随便定义一个方程组)
a = np.array([[1., 5., 3.], 
              [4., 2., 6.],
              [9., 8., 7.]])
b = np.array([10., 20. ,30.])

print('python运算结果:')
x1 = pytool.solve(a, b)
print(x1)

print('C++运算结果:')
x2 = tryPybind.solve(a, b)
print(x2)

print('numpy运算结果:')
x3 = np.linalg.solve(a, b)
print(x3)

结果如下:

yunsaunjieg

算法没大问题,接下来测试对比运行时间,用rand生成一组随机的1000维线性方程组,分别计算运行时间,代码如下:

import numpy as np
import python_tool as pytool
import time
import tryPybind

# ax = b (1000维)
a = np.random.rand(1000, 1000)
b = np.random.rand(1000)

print('计算一个1000维的线性方程组,分别耗时如下:')

t1 = time.time()
x1 = pytool.solve(a, b)
t2 = time.time()
print("Python耗时: " + str(t2 - t1) + '秒')

t1 = time.time()
x2 = tryPybind.solve(a, b)
t2 = time.time()
print("C++耗时:    " + str(t2 - t1) + '秒')

t1 = time.time()
x3 = np.linalg.solve(a, b)
t2 = time.time()
print('numpy耗时:  ' + str(t2 - t1) + '秒')

多次执行该代码,结果如下:

yunsuansj
不难发现,Python代码虽然算法简介,复杂度低,但实际的运算速率远慢于低于C++(接近6倍),以此可直观的看见解释性语言的效率之慢的通病,更能体现出在Python代码中调用写好的C++代码所带来的强大优势。同时也可发现numpy每次都能以不到 0.2 秒的速度算完了一个1000维的线性方程组,其算法之强大。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/251195.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于linux系统的Tomcat+Mysql+Jdk环境搭建(三)centos7 安装Tomcat

Tomcat下载官网&#xff1a; Apache Tomcat - Which Version Do I Want? JDK下载官网&#xff1a; Java Downloads | Oracle 中国 如果不知道Tomcat的哪个版本应该对应哪个版本的JDK可以打开官网&#xff0c;点击Whitch Version 下滑&#xff0c;有低版本的&#xff0c;如…

Caused by: java.net.ConnectException: 拒绝连接: hadoop104/192.168.124.130:4142

项目场景&#xff1a;hadoop102接收消息&#xff0c;自定义拦截器&#xff0c;包含hello的发往hadoop103,不包含的发往hadoop104 报错原因&#xff1a; 原因1&#xff1a; 应该先开启接收方&#xff08;服务端&#xff09;&#xff0c;hadoop103,hadoop104,最后开启hadoop10…

编译android的C版本Lua库

本文讲述如何使用android studio 编译最新版本的Lua开源库),请自行下载。 我们提供的Demo,可以自行下载,工程结构如下: 本文编译的是Lua 5.4.6的版本,编译采用cmake的方式,我们支持编译静态库和动态库(我们在这一讲里:“Lua与***C在Android上的互调”是使用静态库)…

智能优化算法应用:基于JAYA算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于JAYA算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于JAYA算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.JAYA算法4.实验参数设定5.算法结果6.参考文献7.MA…

大数据与深度挖掘:如何在数字营销中与研究互动

数字营销最吸引人的部分之一是对数据的内在关注。 如果一种策略往往有积极的数据&#xff0c;那么它就更容易采用。同样&#xff0c;如果一种策略尚未得到证实&#xff0c;则很难获得支持进行测试。 数字营销人员建立数据信心的主要方式是通过研究。这些研究通常分为两类&…

Vue3快速上手笔记

Vue3快速上手 1.Vue3简介 2020年9月18日&#xff0c;Vue.js发布3.0版本&#xff0c;代号&#xff1a;One Piece&#xff08;海贼王&#xff09;耗时2年多、2600次提交、30个RFC、600次PR、99位贡献者github上的tags地址&#xff1a;https://github.com/vuejs/vue-next/release…

Docker单机部署OceanBase

文章目录 说明机器软硬件要求指导文档本次部署环境说明 OceanBase单机部署&#xff08;Docker&#xff09;一&#xff1a;拉取 OceanBase 数据库相关镜像二&#xff1a;启动 OceanBase 数据库实例完整启动日志展示 三&#xff1a;连接实例遇到报错&#xff1a;没有mysql客户端 …

selenium-grid4.3.0两种模式记录

selenium-grid4.3.0两种模式记录 本文运行&#xff0c;需要提前配置好Java11以及安装好Chrom、Firefox、Safari其中一个浏览器&#xff0c;如果是Chrom、Firefox需要下载对应版本的驱动&#xff0c;并给 webdriver 配置环境变量&#xff0c;Safari浏览器Mac系统会自带&#xf…

HiveSql语法优化二 :join算法

Hive拥有多种join算法&#xff0c;包括Common Join&#xff0c;Map Join&#xff0c;Bucket Map Join&#xff0c;Sort Merge Buckt Map Join等&#xff0c;下面对每种join算法做简要说明&#xff1a; Common Join Common Join是Hive中最稳定的join算法&#xff0c;其通过一个M…

案例067:基于微信小程序的小区租拼车管理信息系统

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;SSM JDK版本&#xff1a;JDK1.8 数据库&#xff1a;mysql 5.7 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.5.4 小程序框架&#xff1a;uniapp 小程序开发软件&#xff1a;HBuilder X 小程序…

【MySQL】(DDL) 数据库操作

创建&#xff1a; create database 数据库名称; //创建数据库 create database if not exists 数据库名 ; //创建数据库并添加判断 &#xff08;如果存在就不创建不存在就创建 &#xff09; create database 数据库名 default charset 字符集 ; //创建数据库并设置字符集 查…

计算机毕业设计 基于SpringBoot的二手物品交易管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

漏刻有时数据可视化Echarts组件开发(42)动态创建DIV容器

效果展示 引入外部文件 <script src"js/jquery.min.js"></script><script type"text/javascript" src"js/echarts.5.4.3.min.js"></script>CSS层叠样式表 实现一行3列效果&#xff0c;自动换行&#xff1b; .ecbox {he…

HNU-计算机网络-实验3-应用层和传输层协议分析(PacketTracer)

计算机网络 课程基础实验三应用层和传输层协议分析&#xff08;PacketTracer&#xff09; 计科210X 甘晴void 202108010XXX 【给助教的验收建议】 如果是助教&#xff0c;比起听同学读报告&#xff0c;更好的验收方式是随机抽取一个场景&#xff08;URL/HTTPS/FTP&#xff09…

Kafka-日志索引

Kafka的Log日志梳理 Topic下的消息是如何存储的&#xff1f; 在搭建Kafka服务时&#xff0c;在server.properties配置文件中通过log.dir属性指定了Kafka的日志存储目录。 实际上&#xff0c;Kafka的所有消息就全都存储在这个目录下。 这些核心数据文件中&#xff0c;.log结尾…

Flask学习三:模型操作

ORM flask 通过Model操作数据库&#xff0c;不管你的数据库是MySQL还是Sqlite&#xff0c;flask自动帮你生成相应数据库类型的sql语句&#xff0c;所以不需要关注sql语句和类型&#xff0c;对数据的操作flask帮我们自动完成&#xff0c;只需要会写Model就可以了 flask使用对象关…

[渗透测试学习] Keeper - HackTheBox

信息搜集 nmap扫描一下端口 nmap -sV -sC -p- -v --min-rate 1000 10.10.11.227发现有两个端口&#xff0c;22端口为ssh服务&#xff0c;80端口是http服务 尝试访问一下80端口 提示我们要访问tickets.keeper.htb/rt/&#xff0c;访问发现不行 那么把该域名添加到hosts里面 …

数据库——关系数据的规范化:范式判断【知识点罗列+例题讲解】

知识点罗列&#xff1a; 各种范式之间的关系 1.第一范式1NF&#xff1a; 如果关系模式R中所有的属性都具有原子性&#xff0c;均是不可再分的&#xff08;一个属性不能再被分解成更小的数据单元&#xff09;&#xff0c;则称R属于第一范式&#xff0c;简称1NF&#xff0c;记作R…

springMVC-模型数据的处理

一、数据放入到request域当中 1、把获取的数据放入request域中&#xff0c; 方便在跳转页面去显示 <a>添加主人信息</a> <form action"vote/vote04" method"post" >主人id&#xff1a;<input type"text" name"id&q…

Redis List类型

列表类型是用来存储多个有序的字符串&#xff0c;如图所示&#xff0c;a、b、c、d、e 五个元素从左到右组成了一个有序的列表&#xff0c;列表中的每个字符串称为元素 (element)&#xff0c;一个列表最多可以存储2的32次方 -1个元素。在 Redis 中&#xff0c;可以对列表两端插入…