CUBLAS库入门教程(从环境配置讲起)

文章目录

  • 前言
  • 一、搭建环境
  • 二、简单介绍
  • 三、 具体例子
  • 四、疑问


前言

CUBLAS库是NVIDIA CUDA用于线性代数计算的库。使用CUBLAS库的原因是我不想去直接写核函数。
(当然,你还是得学习核函数该怎么写。但是人家写好的肯定比我自己写的更准确!)


一、搭建环境

  1. 安装CUDA库,具体可以看我上一篇文章:在C++项目中集成CUDA程序加速(从环境配置讲起)
  2. 如果你是装在默认路径下,那么 CUBLAS库的头文件就在:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.0\include 路径下面,以cublas开头的.h文件。
  3. 所以,还是按照步骤1的文章进行环境配置,然后只需要多在添加依赖项中增加一个cublas.lib就可以了。

二、简单介绍

  1. CUBLAS Introdution 是官方文档。(全英文的,还有不少数学公式。大家有不理解的可以直接留言区问相关API,我们一起讨论学习。)
  2. CUBLAS Samples 是官方示例,所有API都有。
  3. 对于API名称,都是cublasl<t>...,其中有下述类型选择:
    在这里插入图片描述
  4. CUBLAS库的矩阵是列向量的,跟glm一致。
  5. CUBLAS对于矩阵或者向量的index是从1开始的。所以,如果有函数的返回结果是个index(比如查找矩阵中的最大值),记得要index - 1才是我们要的。

三、 具体例子

下面我以矩阵与向量相乘的函数进行举例,看看是怎么用的。

  1. 首先,通过查找官方文档,知道是如下的函数:
cublasStatus_t cublasDgemv(cublasHandle_t handle, cublasOperation_t trans,
                           int m, int n,
                           const double          *alpha,
                           const double          *A, int lda,
                           const double          *x, int incx,
                           const double          *beta,
                           double          *y, int incy)
/*
* handle		: CUBLAS的句柄,用以管理CUBLAS库的上下文和资源
* CUBLAS_OP_N	: 指定矩阵操作模式。CUBLAS_OP_N代表正常模式(列向量);CUBLAS_OP_T代表转置模式(行向量)
* m				: 矩阵A的行数
* n				: 矩阵A的列数
* alpha			: 与矩阵A相乘的标量
* A				: 指向存储在device上面的矩阵数据指针
* lda			: 矩阵的列数,代表矩阵在内存中的存储方式
* x				: 向量X
* incx			: 向量x中相邻两个元素的index间隔,一般为1
* beta			: 与向量y相乘的标量
* y				: 向量y
* incy			: 向量y中相邻两个元素的index间隔,一般为1
*/

具体计算公式如下:
这是具体的计算公式

  1. 如果我们只是想计算矩阵和向量相乘,那么我们只需要令 α = 1.0, β = 0.0,然后传入我们要的Ax就行了。
  2. 最后,具体代码如下:
/// MyCublas.cuh

#pragma once

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include "cublas_v2.h"

extern "C" void MatrixMulVectorCublas(
	const double* matrix, const int row, const int col, 
	const double* vector, double* result
);
/// MyCublas.cu

#include "MyCublas.cuh"
#include "CublasUtility.h"

void MatrixMulVectorCublas(
	const double* matrix, const int row, const int col,
	const double* vector, double* result)
{
	// 1. 初始化句柄
	cublasHandle_t handle;
	CUBLAS_CHECK(cublasCreate(&handle));

	// 2. 分配内存
	double* dev_matrix = NULL;
	double* dev_vector = NULL;
	CUDA_CHECK(cudaMalloc((void**)&dev_matrix, sizeof(double) * row * col));
	CUDA_CHECK(cudaMalloc((void**)&dev_vector, sizeof(double) * row));

	CUDA_CHECK(cudaMemcpy(dev_matrix, matrix, sizeof(double) * row * col, cudaMemcpyHostToDevice));
	CUDA_CHECK(cudaMemcpy(dev_vector, vector, sizeof(double) * row, cudaMemcpyHostToDevice));

	// 3. 执行矩阵乘法
	double* dev_result = NULL;
	CUDA_CHECK(cudaMalloc((void**)&dev_result, sizeof(double) * col));
	CUDA_CHECK(cudaMemset(dev_result, 0, sizeof(double) * col));

	const double alpha = 1.0;
	const double beta = 0.0;

	CUBLAS_CHECK(cublasDgemv(handle, CUBLAS_OP_N, row, col, &alpha, dev_matrix, col, dev_vector, 1, &beta, dev_result, 1));

	CUDA_CHECK(cudaMemcpy(result, dev_result, sizeof(double) * col, cudaMemcpyDeviceToHost));

	// 4. 释放内存
	CUDA_CHECK(cudaFree(dev_matrix));
	CUDA_CHECK(cudaFree(dev_vector));
	CUDA_CHECK(cudaFree(dev_result));
	CUBLAS_CHECK(cublasDestroy(handle));

}
/// CublasUtility.h

#pragma once
#include <string>
#include <stdexcept>
// CUDA API error checking
#define CUDA_CHECK(err)                                                                            \
    do {                                                                                           \
        cudaError_t err_ = (err);                                                                  \
        if (err_ != cudaSuccess) {                                                                 \
            std::printf("CUDA error %d at %s:%d\n", err_, __FILE__, __LINE__);                     \
            throw std::runtime_error("CUDA error");                                                \
        }                                                                                          \
    } while (0)

// cublas API error checking
#define CUBLAS_CHECK(err)                                                                          \
    do {                                                                                           \
        cublasStatus_t err_ = (err);                                                               \
        if (err_ != CUBLAS_STATUS_SUCCESS) {                                                       \
            std::printf("cublas error %d at %s:%d\n", err_, __FILE__, __LINE__);                   \
            throw std::runtime_error("cublas error");                                              \
        }                                                                                          \
    } while (0)

/// main.cpp

#include "MyCublas.cuh"

#include <iostream>

int main()
{
    double matrix[12] = { 1.0, 2.0, 3.0, 4.0,
                    5.0, 6.0, 7.0, 8.0,
                    9.0, 10.0, 11.0, 12.0};

    double vector[4] = { 1.0, 2.0, 3.0};

    double result[4] = { 0.0 };

    MatrixMulVectorCublas(matrix, 3, 4, vector, result);

    for (int i = 0; i < 4; ++i)
    {
        std::cout << result[i] << ", ";
    }
	return 0;
}

四、疑问

对于上述代码,我还有以下的疑问:

  1. 我在运行下面这句的时候,VS显示我的进程内存会到2.2GB左右,难道真的需要这么大吗?
CUBLAS_CHECK(cublasCreate(&handle));
  1. 上述代码运行的结果是:38, 44, 50, 0。但是实际结果应该是:38, 44, 50, 56。查了很久还是没差出来为什么。希望有细心的小伙伴帮我检查一下!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/98667.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序发布一个npm包

参考:https://developers.weixin.qq.com/miniprogram/dev/devtools/npm.html 同npm一样流程 npm install weixin_heath_apis

计算机视觉-LeNet

目录 LeNet LeNet在手写数字识别上的应用 LeNet在眼疾识别数据集iChallenge-PM上的应用 LeNet LeNet是最早的卷积神经网络之一。1998年&#xff0c;Yann LeCun第一次将LeNet卷积神经网络应用到图像分类上&#xff0c;在手写数字识别任务中取得了巨大成功。LeNet通过连续使用…

电脑组装教程分享!

案例&#xff1a;如何自己组装电脑&#xff1f; 【看到身边的小伙伴组装一台自己的电脑&#xff0c;我也想试试。但是我对电脑并不是很熟悉&#xff0c;不太了解具体的电脑组装步骤&#xff0c;求一份详细的教程&#xff01;】 电脑已经成为我们日常生活中不可或缺的一部分&a…

vue3组合式api bus总线式通信

vue2中可以创建一个 vue 实例&#xff0c; 做为 总结来完成组件间的通信 但是在vue3中&#xff0c; 这种方法是不能使用的。 因为vue3中main.js中&#xff0c; 使用的createApp() 没有机会再写 new Vue了 但是我们可以使用 mitt 的插件来解决这个问题 vue3 bus组件的用法 安装…

Python序列类型

序列&#xff08;Sequence&#xff09;是有顺序的数据列&#xff0c;Python 有三种基本序列类型&#xff1a;list, tuple 和 range 对象&#xff0c;序列&#xff08;Sequence&#xff09;是有顺序的数据列&#xff0c;二进制数据&#xff08;bytes&#xff09; 和 文本字符串&…

MATLAB中isequal函数转化为C语言

背景 有项目算法使用matlab中isequal函数进行运算&#xff0c;这里需要将转化为C语言&#xff0c;从而模拟算法运行&#xff0c;将算法移植到qt。 MATLAB中isequal简单介绍 语法 tf isequal(A,B) tf isequal(A1,A2,...,An) 说明 如果 A 和 B 等效&#xff0c;则 tf is…

飞书接入ChatGPT,实现智能化问答助手功能,提供高效的解答服务

文章目录 前言环境列表1.飞书设置2.克隆feishu-chatgpt项目3.配置config.yaml文件4.运行feishu-chatgpt项目5.安装cpolar内网穿透6.固定公网地址7.机器人权限配置8.创建版本9.创建测试企业10. 机器人测试 前言 在飞书中创建chatGPT机器人并且对话&#xff0c;在下面操作步骤中…

DataLoader的使用

示例代码&#xff1a; import torchvision from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter# 准备的测试数据集 test_data torchvision.datasets.CIFAR10("./dataset", trainFalse, transformtorchvision.transforms.…

postgresql-窗口函数

postgresql-窗口函数 简介窗口函数的定义分区选项&#xff08;PARTITION BY&#xff09;排序选项&#xff08;ORDER BY&#xff09;窗口选项&#xff08;frame_clause&#xff09; 聚合窗口函数排名窗口函数演示了 CUME_DIST 和 NTILE 函数 取值窗口函数 简介 常见的聚合函数&…

java八股文面试[数据库]——MySql聚簇索引和非聚簇索引区别

聚集索引和非聚集索引 聚集索引和非聚集索引的根本区别是表记录的排列顺序和与索引的排列顺序是否一致。 1、聚集索引 聚集索引表记录的排列顺序和索引的排列顺序一致&#xff08;以InnoDB聚集索引的主键索引来说&#xff0c;叶子节点中存储的就是行数据&#xff0c;行数据在…

音频母带制作::AAMS V4.0 Crack

自动音频母带制作简介。 使用 AAMS V4 让您的音乐听起来很美妙&#xff01; 作为从事音乐工作的音乐家&#xff0c;您在向公众发布材料时需要尽可能最好的声音&#xff0c;而为所有音频扬声器系统提供良好的商业声音是一项困难且耗时的任务。AI掌握的力量&#xff01; 掌控您…

Web安全——信息收集上篇

Web安全 一、信息收集简介二、信息收集的分类三、常见的方法四、在线whois查询在线网站备案查询 五、查询绿盟的whois信息六、收集子域名1、子域名作用2、常用方式3、域名的类型3.1 A (Address) 记录&#xff1a;3.2 别名(CNAME)记录&#xff1a;3.3 如何检测CNAME记录&#xf…

flutter自定义按钮-文本按钮

目录 前言 需求 实现 前言 最近闲着无聊学习了flutter的一下知识&#xff0c;发现flutter和安卓之间&#xff0c;页面开发的方式还是有较大的差异的&#xff0c;众所周知&#xff0c;android的页面开发都是写在xml文件中的&#xff0c;而flutter直接写在代码里&#xff08;da…

Java 16进制字符串转换成GBK字符串

问题&#xff1a; 现在已知有一个16进制字符串 435550D3C3D3DAD4DABDBBD2D7CFECD3A6CFFBCFA2D6D0B4E6B7C5D5DBBFDBD0C5CFA2A3ACD5DBBFDBBDF0B6EE3130302E3036 而且知道这串的字符串对应的内容是&#xff1a; CUP用于在交易响应消息中存放折扣信息&#xff0c;折扣金额100.06 但…

如何为你的公司选择正确的AIGC解决方案?

如何为你的公司选择正确的AIGC解决方案&#xff1f; 摘要引言词汇解释&#xff08;详细版本&#xff09;详细介绍1. 确定需求2. 考虑技术能力3. 评估可行性4. 比较不同供应商 代码快及其注释注意事项知识总结 博主 默语带您 Go to New World. ✍ 个人主页—— 默语 的博客&…

java实现多文件压缩zip

1&#xff0c;需求 需求要求实现多个文件压缩为zip文件 2&#xff0c;代码 package com.example.demo;import java.io.*; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; import…

三极管,MOS管开关应用总结

目录 一、符号二、共同点1.三级管&#xff0c;mos管符号中都有一个PN结&#xff0c;可以根据PN结方向区分型号 二、区别1. 三级管导通条件&#xff1a;PN结正偏2. mos管导通条件&#xff1a; PN结反偏这样就不需要记哪个极的电压&#xff0c; 一、符号 1.三极管 2.mos管 MOS…

微信小程序scroll-view隐藏滚动条参数不生效问题

如题&#xff0c;先来看看问题是怎么出现的。 先看文档如何隐藏滚动条&#xff1a; 再根据文档实现wxml文件&#xff1a; <scroll-view show-scrollbar"{{false}}" enhanced><view wx:for"{{1000}}">11111</view> </scroll-view>…

(笔记一)利用open_cv在图像上进行点标记,文字注记,画圆、多边形、椭圆

&#xff08;1&#xff09;CV2中的绘图函数&#xff1a; cv2.line() 绘制线条cv2.circle() 绘制圆cv2.rectangle() 绘制矩形cv2.ellipse() 绘制椭圆cv2.putText() 添加注记 &#xff08;2&#xff09;注释 img表示需要绘制的图像color表示线条的颜色&#xff0c;采用颜色矩阵…

【跟小嘉学 Rust 编程】二十、进阶扩展

系列文章目录 【跟小嘉学 Rust 编程】一、Rust 编程基础 【跟小嘉学 Rust 编程】二、Rust 包管理工具使用 【跟小嘉学 Rust 编程】三、Rust 的基本程序概念 【跟小嘉学 Rust 编程】四、理解 Rust 的所有权概念 【跟小嘉学 Rust 编程】五、使用结构体关联结构化数据 【跟小嘉学…