AlgoC++第七课:手写Matrix

目录

  • 手写Matrix
    • 前言
    • 1. 明确需求
    • 2. 基本实现
      • 2.1 创建矩阵
      • 2.2 外部访问
      • 2.3 <<操作符重载
    • 3. 矩阵运算
      • 3.1 矩阵标量运算
      • 3.2 通用矩阵乘法
      • 3.3 矩阵求逆
    • 4. 完整示例代码
    • 总结

手写Matrix

前言

手写AI推出的全新面向AI算法的C++课程 Algo C++,链接。记录下个人学习笔记,仅供自己参考。

本次课程主要是手写Matrix代码

课程大纲可看下面的思维导图

在这里插入图片描述

1. 明确需求

我们先来明确下 Matrix 类中应该实现那些功能

1.只能够表示 2 维的矩阵形式,即使是向量也会用 matrix 表示

​ -我们指表达 float 格式的矩阵,不表达其它形式

2.矩阵的乘法,通用矩阵的乘法形式

3.要能够求解逆矩阵

4.可以通过指定行和列进行矩阵的创建

5.可以允许使用 {1, 2, 3} 这种形式进行数据填充的方式创建

6.能够与标量进行常规的加减法

2. 基本实现

2.1 创建矩阵

我们先来实现 Matrix 类中矩阵创建的功能

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;

private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

#endif // MATRIX_HPP

在上面的示例代码中,我们定义了矩阵 Matrix 类的默认构造函数,以及定义了三个私有成员变量 rows_、cols、data_ 分别代表矩阵的行号、列号以及数据。#ifndef #define #endif 语句是为了防止头文件重复包含。值得注意的是,我们习惯在成员变量尾部加上_

我们再来实现使用 {1, 2, 3}进行数据填充的方式来创建矩阵

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        
        // 隐式转换,其实是执行了vector的赋值操作
        data_ = data;
        
        // 1. data的元素为空,说明是不指定数据情况下进行创建
        // 2. data的元素不空,说明是指定数据情况下创建
        	// 1. 元素数量等于rows * cols
        	// 2. 元素数量小于rows * cols
       	
        if(data_.empty()){
            
            // resize表示分配rows * cols个元素,在vector中
            // 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

#endif // MATRIX_HPP

在上面的示例代码中,我们实现了可以通过 Matrix m(3, 1, {1, 3, 2}) 这种方式来创建矩阵。值得注意的是

  • {} 这种类型的数据在 C++ 中叫做 initializer_list,是一种容器。
  • 在 C++ 的容器中,比如 STL 对象中的 vector、list 等,它们分配的空间,如果不进行初始化,则其内部的值是 0
  • 对于空,在 C 语言里面大部分是指 malloc 分配出来的内存没有初始化的情况
    • 此时 malloc 分配内存的值其实是随机的
    • 此时 new 分配的内存的值也是随机的
  • data 参数是以 & 引用的方式传递,防止拷贝;而 const 参数表示传入常引用,在函数内部对其不进行修改
  • 我们习惯在传值的时候,对于非基础类型,一般会传递常引用,使得效率更高,避免拷贝

2.2 外部访问

我们还希望能够外部访问到矩阵的行、列以及矩阵的元素变量

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        
        // 隐式转换,其实是执行了vector的赋值操作
        data_ = data;
        
        // 1. data的元素为空,说明是不指定数据情况下进行创建
        // 2. data的元素不空,说明是指定数据情况下创建
        	// 1. 元素数量等于rows * cols
        	// 2. 元素数量小于rows * cols
       	
        if(data_.empty()){
            
            // resize表示分配rows * cols个元素,在vector中
            // 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    int rows() {return rows_;}
    int cols() {return cols_;}
    std::vector<float>& data(){return data_;}
    
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

#endif // MATRIX_HPP

在上面的示例代码中,我们分别实现了 rows()、cols()、data() 函数用来访问矩阵,如下所示:

#include <iostream>
#include "matrix.hpp"
using namespace std;


int main(){

    Matrix m(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    std::cout << "矩阵行 = " << m.rows() << std::endl;
    std::cout << "矩阵列 = " << m.cols() << std::endl;
    
    std::cout << "矩阵数据如下: " << std::endl;
    for(int i = 0; i < m.data().size(); ++i){
        std::cout << m.data()[i] << std::endl;
    }    

    return 0;
}

在这里插入图片描述

我们希望能够通过 m[i][j] 这种形式去访问到矩阵的元素,需要用到操作符重载

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        
        // 隐式转换,其实是执行了vector的赋值操作
        data_ = data;
        
        // 1. data的元素为空,说明是不指定数据情况下进行创建
        // 2. data的元素不空,说明是指定数据情况下创建
        	// 1. 元素数量等于rows * cols
        	// 2. 元素数量小于rows * cols
       	
        if(data_.empty()){
            
            // resize表示分配rows * cols个元素,在vector中
            // 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    int rows() {return rows_;}
    int cols() {return cols_;}
    std::vector<float>& data(){return data_;}
    
    float& operator()(int ir, int ic){
        
        // data_在内存中是连续的
        // 比如说我们有3x3的矩阵,那么
        // data_就等于 = {1, 2, 3, 4, 5, 6, 7, 8, 9}
        // 它代表的就是:
        /*
        	1 2 3
        	4 5 6
        	7 8 9
        */
        // 如果要访问2行,0列。此时应该是对应的7
        // 把2d的索引,映射到连续1d空间的索引上
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

#endif // MATRIX_HPP

在上述示例代码中,我们对操作符 () 进行了重载,使得矩阵可以通过 m(i,j) 这种方式访问矩阵的元素,值得注意的是:

  • 我们需要将 operator() 看成一个整体

  • 在 C++ 中不允许 [] 提供更多参数,这种操作只能提供一个参数,因此,可以换成 m(i,j),此时是可以允许的

  • 我们返回的是 float & 引用而不是常引用,意味着我们可以直接修改其内部元素

#include <iostream>
#include "matrix.hpp"
using namespace std;


int main(){

    Matrix m(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    std::cout << "矩阵行 = " << m.rows() << std::endl;
    std::cout << "矩阵列 = " << m.cols() << std::endl;
    
    auto data = m.data();
    
    m(1, 1) = 123.5;

    std::cout << "矩阵数据如下: " << std::endl;
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            std::cout << m(ir, ic) << "\t";
        }
        std::cout << "\n";
    }

    return 0;
}

在这里插入图片描述

2.3 <<操作符重载

我们还需要重载操作符 <<,使其能直接打印矩阵即 std::cout << m

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        
        // 隐式转换,其实是执行了vector的赋值操作
        data_ = data;
        
        // 1. data的元素为空,说明是不指定数据情况下进行创建
        // 2. data的元素不空,说明是指定数据情况下创建
        	// 1. 元素数量等于rows * cols
        	// 2. 元素数量小于rows * cols
       	
        if(data_.empty()){
            
            // resize表示分配rows * cols个元素,在vector中
            // 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    // 在 rows() 后面加const,表示这个函数是常量函数
    // 潜台词是:它不会对内部成员做修改,仅仅只做访问查询
    int rows() const{return rows_;}
    int cols() const{return cols_;}
    std::vector<float>& data(){return data_;}
    
    // 这个表示,重载一个可以修改元素的函数
    float& operator(int ir, int ic){
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
    // 这个表示,重载一个只能读取的函数
    const float& operator()(int ir, int ic)const{
        
        // data_在内存中是连续的
        // 比如说我们有3x3的矩阵,那么
        // data_就等于 = {1, 2, 3, 4, 5, 6, 7, 8, 9}
        // 它代表的就是:
        /*
        	1 2 3
        	4 5 6
        	7 8 9
        */
        // 如果要访问2行,0列。此时应该是对应的7
        // 把2d的索引,映射到连续1d空间的索引上
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix( %d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
        printf("\n");
    }
    return out;
}

#endif // MATRIX_HPP

在上述示例代码中我们重载了操作符 <<,值得注意的是:

  • << 等价于 operator<<,因此 std::cout << 123 等价于 std::cout.operator<<(123)

  • << 操作符重载有两种方式

    • 第一种是存在于类内的,例如 std::cout.operator<<(m);由于这个是系统文件,最好不要修改
    • 第二种是存在于全局作用域的,例如 std::ostream& operator<<(std::ostream& out, const Matrix& m)
      • 首先,全局操作符重载,是特定操作符为函数名称
      • 其次,第一个参数,称之为左操作数;第二个参数,称之为右操作数
      • 左操作数 out 对象为引用而非常引用是因为 out 对象存在写操作,势必是修改,因此不能是常量了,必须是非常量引用
      • 右操作数 Matrix 对象为常引用是因为避免拷贝值得发生,我们在这里只需要读取就行了
      • 左操作数 << 右操作数 相当于 operator<<(左操作数,右操作数)
  • 在 C++ 中类的函数,分为常规函数(具有修改和访问权限)和常量函数(只有访问权限,没有修改权限)。右操作数是一个常量对象只能访问常量函数,因此在调用的函数应该是常量函数,所以在 rows()、cols()、operator() 函数后面都要加上 const 关键字,且 operator () 返回的应该是一个常量引用

  • 我们需要对 operator() 重载既能实现常量函数,又能完成修改操作,可以将 operator() 写两遍,一个表示重载一个可以修改元素的函数,一个表示重载一个只能读取的函数

#include <iostream>
#include "matrix.hpp"
using namespace std;


int main(){

    Matrix m(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    cout << m ;

    return 0;
}

在这里插入图片描述

最后我们来区分下

int rows() const{return rows_};
const int rows(){return rows_};

第一个函数 int rows() const{return rows_} 是一个 const 成员函数,其中的 const 关键字表示该成员函数不会修改对象的成员变量,即保证了函数内部不会修改 rows_ 的值,同时该成员函数可以被 const 对象和非 const 对象调用

而第二个函数 const int rows(){return rows_} 是一个非 const 成员函数,其中的 const 关键字表示函数返回的是 const 类型的值,但并没有对函数本身做出限制。该函数可以被非 const 对象调用,但不能被 const 对象调用

因此,第一个函数可以用于保证对象的成员变量不被修改,并且可以被 const 和非 const 对象调用,而第二个函数只能被非 const 对象调用,如果被 const 对象调用则会产生编译错误

3. 矩阵运算

3.1 矩阵标量运算

实现矩阵与标量的 + - * / 四则运算

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
#include <functional>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        data_ = data;

        if(data_.empty()){
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    // ========== + - * / ==========
    Matrix element_wise(const std::function<float(float)>& func) const{
        Matrix output = *this;	// 复制一份
        for(int i = 0; i < output.data_.size(); ++i){
            output.data_[i] = func(output.data_[i]);
        }
        return output;
    }
    
    Matrix operator*(float value) const{
        // lambda函数 C++11特性
        return element_wise([&](float x){return x * value;});
    }
    
    Matrix operator+(float value) const{
        return element_wise([&](float x){return x + value});
    }
    
    Matrix operator-(float value) const{
        return element_wise([&](float x){return x - value;});
    }
    
    Matrix operator/(float value) const{
        return element_wise([&](float x){return x / value;});
    }
    
    int rows() const{return rows_;}
    int cols() const{return cols_;}
    std::vector<float>& data(){return data_;}
    
    float& operator(int ir, int ic){
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
    const float& operator()(int ir, int ic)const{
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix( %d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
    }
    return out;
}

Matrix operator*(float value, const Matrix& m){
    return m * value;
}

#endif // MATRIX_HPP

上述示例代码中实现了矩阵 Matrix 类中加、减、乘、除四个运算符的重载实现,值得注意的是:

  • element_wise() 函数是对矩阵中的每个元素都应用一个函数,得到一个新的矩阵并返回。这样可以将四则运算统一起来,四则运算的函数采用匿名函数实现可以简化代码

  • 四个运算符重载函数 operator*()、operator+()、operator-()、operator/(),都是调用 element_wise() 函数,传入一个对应的 lambda 表达式,对矩阵中的每个元素都进行相应的四则运算,得到一个新的矩阵并返回。

  • lambda 表达式定义了一个匿名函数,并且可以捕获一定范围内的变量。其语法形式如下:

    [capture](params) opt -> ret {body;}
    
    • 其中 capture 是捕获列表,& 表示捕获全局引用,= 表示捕获全局值,&value 表示捕获特定的值/引用
    • params 是参数列表,和普通函数的参数一样
    • opt 是函数选项,不需要可以省略
    • ret 是返回值类型,可以省略,编译器会自动推导
    • body 是函数体
    • 关于 lambda 表达式的更多细节可参考 https://subingwen.cn/cpp/lambda/
  • 在矩阵与标量乘操作中可能存在 m*2 和 2*m 两种情形,我们分别采用了类内重载和全局重载两种方式实现

#include <iostream>
#include "matrix.hpp"
using namespace std;


int main(){

    Matrix m1(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});

    std::cout << m1 * 2 + 1 - 10;
    std::cout << 2 * m2;

    return 0;
}

在这里插入图片描述

3.2 通用矩阵乘法

通用矩阵乘法 gemm() 的实现

我们需要先编译 OpenBLAS 库,具体可参考 Ubuntu20.04软件安装大全

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
#include <functional>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        data_ = data;

        if(data_.empty()){
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    // ========== + - * / ==========
    Matrix element_wise(const std::function<float(float)>& func) const{
        Matrix output = *this;	// 复制一份
        for(int i = 0; i < output.data_.size(); ++i){
            output.data_[i] = func(output.data_[i]);
        }
        return output;
    }
    
    Matrix operator*(float value) const{
        // lambda函数 C++11特性
        return element_wise([&](float x){return x * value;});
    }
    
    Matrix operator+(float value) const{
        return element_wise([&](float x){return x + value});
    }
    
    Matrix operator-(float value) const{
        return element_wise([&](float x){return x - value;});
    }
    
    Matrix operator/(float value) const{
        return element_wise([&](float x){return x / value;});
    }
    
    int rows() const{return rows_;}
    int cols() const{return cols_;}
    const std::vector<float>& data(){return data_;}
    const float* ptr()const{return data_.data();}
    float* ptr(return data_.data();)
    
    float& operator(int ir, int ic){
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
    const float& operator()(int ir, int ic)const{
        int index = ir * cols_ + ic;
        return data_[index];
    }
     
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix( %d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
    }
    return out;
}

Matrix operator*(float value, const Matrix& m){
    return m * value;
}

// ========== gemm ==========
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta){
    // AB = C
    // A^T B = C
    // A B^T = C
    // AB * scale + bias
    // C = ta(A) * tb(B) * alpha + beta
    // Cmxn = ta(A)mxk ta(B)kxn
    
    int ta_rows = ta ? a.cols() : a.rows();
    int ta_cols = ta ? a.rows() : a.cols();
    int tb_rows = tb ? b.cols() : b.rows();
    int tb_cols = tb ? b.rows() : b.cols();
    
    Matrix c(ta_rows, tb_cols);
    int m = ta_rows;
    int n = tb_cols;
    int k = ta_cols;
    
    // 为了解决比如步长不等于列数的情况
    int lda = a.cols();		// A矩阵的每一行所需要的步长
    int ldb = b.cols();
    int ldc = c.cols();
    
    cblas_sgemm(
    	CblasRowMajor,
        ta ? CblasTrans : CblasNoTrans,
        tb ? CblasTrans : CblasNoTrans,
        m, n, k, alpha, a.ptr(), lda, b.ptr(), ldb, beta, c.ptr(), ldc
    );
    return c;
}

#endif // MATRIX_HPP

cblas_sgemm 函数是 BLAS 库中的矩阵乘法函数,其参数如下:(from chatGPT)

void cblas_sgemm(
    const enum CBLAS_ORDER Order,
    const enum CBLAS_TRANSPOSE TransA,
    const enum CBLAS_TRANSPOSE TransB,
    const int M,
    const int N,
    const int K,
    const float alpha,
    const float *A,
    const int lda,
    const float *B,
    const int ldb,
    const float beta,
    float *C,
    const int ldc
);

其中各参数的含义如下:

  • Order:矩阵的存储顺序。CBLAS_ORDER 枚举类型,取值可以是 CblasRowMajor 或 CblasColMajor,分别表示按行存储和按列存储。
  • TransA:A 矩阵的转置情况。CBLAS_TRANSPOSE 枚举类型,取值可以是 CblasNoTrans(不转置)、CblasTrans(转置)或 CblasConjTrans(共轭转置)。
  • TransB:B 矩阵的转置情况。CBLAS_TRANSPOSE 枚举类型,取值可以是 CblasNoTrans(不转置)、CblasTrans(转置)或 CblasConjTrans(共轭转置)。
  • M:C 矩阵的行数。
  • N:C 矩阵的列数。
  • K:A 和 B 矩阵中共享的维度,即 A 矩阵的列数或 B 矩阵的行数。
  • alpha:乘法操作的系数,通常取值为1。
  • A:存储 A 矩阵的数组。
  • lda:A 矩阵每行的元素个数,通常为 A 矩阵的列数。
  • B:存储 B 矩阵的数组。
  • ldb:B 矩阵每行的元素个数,通常为 B 矩阵的列数。
  • beta:加法操作的系数,通常取值为 0。
  • C:存储结果 C 矩阵的数组。
  • ldc:C 矩阵每行的元素个数,通常为 C 矩阵的列数。

cblas_sgemm 函数会对 A、B、C 矩阵进行矩阵乘法运算,并将结果存储在 C 矩阵中。其中 A 矩阵的大小为 MxK,B 矩阵的大小为 KxN,C 矩阵的大小为 MxN

注意实现和声明需要分离 class 不存在这种情况,只有函数存在。当函数的声明和实现都放在头文件时,可能会出现重复定义的问题

#include <iostream>
#include "matrix.hpp"
using namespace std;

Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta);


int main(){

    Matrix m1(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});

    std::cout << gemm(m1, false, m2, false, 1.0f, 0.0f);
    std::cout << gemm(m1, true, m2, false, 1.0f, 0.0f);

    return 0;
}

在这里插入图片描述

3.3 矩阵求逆

求逆矩阵

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
#include <functional>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        data_ = data;

        if(data_.empty()){
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    // ========== + - * / ==========
    Matrix element_wise(const std::function<float(float)>& func) const{
        Matrix output = *this;	// 复制一份
        for(int i = 0; i < output.data_.size(); ++i){
            output.data_[i] = func(output.data_[i]);
        }
        return output;
    }
    
    Matrix operator*(float value) const{
        // lambda函数 C++11特性
        return element_wise([&](float x){return x * value;});
    }
    
    Matrix operator+(float value) const{
        return element_wise([&](float x){return x + value});
    }
    
    Matrix operator-(float value) const{
        return element_wise([&](float x){return x - value;});
    }
    
    Matrix operator/(float value) const{
        return element_wise([&](float x){return x / value;});
    }
    
    int rows() const{return rows_;}
    int cols() const{return cols_;}
    const std::vector<float>& data(){return data_;}
    const float* ptr()const{return data_.data();}
    float* ptr(return data_.data();)
    
    float& operator(int ir, int ic){
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
    const float& operator()(int ir, int ic)const{
        int index = ir * cols_ + ic;
        return data_[index];
    }
     
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix( %d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
    }
    return out;
}

Matrix operator*(float value, const Matrix& m){
    return m * value;
}

// ========== gemm ==========
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta) const{
    // AB = C
    // A^T B = C
    // A B^T = C
    // AB * scale + bias
    // C = ta(A) + tb(B) * alpha + beta
    // Cmxn = ta(A)mxk ta(B)kxn
    
    int ta_rows = ta ? a.cols() : a.rows();
    int ta_cols = ta ? a.rows() : a.cols();
    int tb_rows = tb ? b.cols() : b.rows();
    int tb_cols = tb ? b.rows() : b.cols();
    
    Matrix c(ta_rows, tb_cols);
    int m = ta_rows;
    int n = tb_cols;
    int k = ta_cols;
    
    // 为了解决比如步长不等于列数的情况
    int lda = a.cols();		// A矩阵的每一行所需要的步长
    int ldb = b.cols();
    int ldc = c.cols();
    
    cblas_sgemm(
    	CblasRowMajor,
        ta ? CblasTrans : CblasNoTrans,
        tb ? CblasTrans : CblasNoTrans,
        m, n, k, alpha, a.ptr(), lda, b.ptr(), ldb, beta, c.ptr(), ldc
    );
    return c;
}

// ========== inv ========== 
Matrix inverse(const Matrix& a){
    
    if(a.rows() != a.cols()){
        printf("Invalid to compute inverse matrix by %d x %d\n", a.rows(), a.cols());
        return Matrix();
    }
    
    Matrix output = a;
    int n = a.rows();
    int *ipiv = new int[n];
    
    /* LU分解 */
    int code = LAPACKE_sgetrf(LAPACK_COL_MAJOR, n, n, output.ptr(), n, ipiv);
    if(code == 0){
        /* 使用LU分解求解通用逆矩阵 */
        code = LAPACKE_sgetri(LAPACK_COL_MAJOR, n, output.ptr(), n, ipiv);
    }

    if(code != 0){
        printf("LAPACKE inverse matrix failed, code = %d\n", code);
        return Matrix();
    }

    delete[] ipiv;
    return output;
}
#endif // MATRIX_HPP

LAPACK_sgetrf函数和LAPACK_sgetri 函数的使用可以实现对逆矩阵的求解。具体来说,可以先使用LAPACK_sgetrf函数进行 LU 分解,然后再使用LAPACK_sgetri` 函数对 LU 分解后的矩阵进行求逆。(from chatGPT)

LAPACK_sgetrf 函数的参数:

  • order:表示矩阵数据的存储顺序,可以是LAPACK_ROW_MAJOR或者LAPACK_COL_MAJOR。
  • m:表示矩阵 A 的行数。
  • n:表示矩阵 A 的列数。
  • A:指向矩阵 A 的指针。
  • lda:表示矩阵 A 的行宽。
  • ipiv:指向一个长度为 min(m,n) 的整数数组,存储 LU 分解的行置换信息。

LAPACK_sgetrf 函数的返回值:

  • 如果返回值等于零,则表示操作成功完成。
  • 如果返回值小于零,则表示参数错误或某个 U(i,i) 为零,无法进行 LU 分解。
  • 如果返回值大于零,则表示 A 的前返回值列的 LU 分解出现奇异矩阵,无法求解。

LAPACK_sgetri 函数的参数:

  • order:表示矩阵数据的存储顺序,可以是LAPACK_ROW_MAJOR或者LAPACK_COL_MAJOR。
  • n:表示矩阵A的行数和列数。
  • A:指向矩阵A的指针。
  • lda:表示矩阵A的行宽。
  • ipiv:指向一个长度为n的整数数组,存储LU分解的行置换信息。

LAPACK_sgetri 函数的返回值:

  • 如果返回值等于零,则表示操作成功完成。
  • 如果返回值小于零,则表示参数错误。
  • 如果返回值大于零,则表示某个 A(i,i) 为零,无法进行求解。
#include <iostream>
#include "matrix.hpp"
using namespace std;

Matrix inverse(const Matrix& a);


int main(){

    Matrix m1(3, 3, {3, 2, 3, 4, 5, 6, 7, 8, 9});
    Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});

    std::cout << inverse(m1);
    std::cout << inverse(m2);

    return 0;
}

在这里插入图片描述

4. 完整示例代码

我们将声明和实现分离,共三个文件为 matrix.hpp、matrix.cpp、main.cpp

matrix.hpp

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <iostream>
#include <ostream>
#include <initializer_list>
#include <functional>

class Matrix{
public:
    Matrix();
    Matrix(int rows, int cols, const std::initializer_list<float>& data={});

    Matrix element_wise(const std::function<float(float)>& func) const;
    Matrix operator*(float value) const;
    Matrix operator-(float value) const;
    Matrix operator+(float value) const;
    Matrix operator/(float value) const;

    const int rows() const{return rows_;};
    const int cols() const{return cols_;};
    const std::vector<float>& data() const{return data_;};
    const float* ptr()const{return data_.data();};
    float* ptr(){return data_.data();};

    float& operator()(int ir, int ic);

    const float& operator()(int ir, int ic) const;

    // Matrix gemm(const Matrix& other, bool ta, bool tb, float alpha=1.0f, float beta=0.0f){
    //     return ::gemm(*this, ta, other, tb, alpha, beta);
    // }

private:
    int rows_;
    int cols_;
    std::vector<float> data_;
};

std::ostream& operator<<(std::ostream& out, const Matrix& m);

Matrix operator*(float value, const Matrix& m);

#endif // MATRIX_HPP

matrix.cpp

#include "cblas.h"
#include "lapacke.h"
#include "matrix.hpp"

Matrix::Matrix(){}
Matrix::Matrix(int rows, int cols, const std::initializer_list<float>& data){
    this->rows_ = rows;
    this->cols_ = cols;
    this->data_ = data;

    if(this->data_.size() < rows * cols)
        this->data_.resize(rows * cols);
};

Matrix Matrix::element_wise(const std::function<float(float)>& func) const{
    Matrix output = *this;
    for(int i = 0; i < output.data_.size(); ++i){
        output.data_[i] = func(output.data_[i]);
    }
    return output;
}

Matrix Matrix::operator*(float value) const{
    return element_wise([&](float x){return x * value;});
}
Matrix Matrix::operator-(float value) const{
    return element_wise([&](float x){return x - value;});
}
Matrix Matrix::operator+(float value) const{
    return element_wise([&](float x){return x + value;});
}
Matrix Matrix::operator/(float value) const{
    return element_wise([&](float x){return x / value;});
}

float& Matrix::operator()(int ir, int ic){
    int index = ir * cols_ + ic;
    return data_[index];
}

const float& Matrix::operator()(int ir, int ic) const{
    int index = ir * cols_ + ic;
    return data_[index];
}


std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix (%d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
        printf("\n");
    }
    return out;
}

Matrix operator*(float value, const Matrix& m){
    return m * value;
}


// ========== gemm ==========
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta){
    // C = ta(A) * ta(B) * alpha + beta

    int ta_rows = ta ? a.cols() : a.rows();
    int ta_cols = ta ? a.rows() : a.cols();
    int tb_rows = tb ? b.cols() : b.rows();
    int tb_cols = tb ? b.rows() : b.cols();

    Matrix c(ta_rows, tb_cols);
    int m = ta_rows;
    int n = tb_cols;
    int k = ta_cols;
    
    // 为了解决比如步长不等于列数的情况
    int lda = a.cols();     // A矩阵的每一行所需要的步长
    int ldb = b.cols();
    int ldc = c.cols();

    cblas_sgemm(
        CblasRowMajor,
        ta ? CblasTrans : CblasNoTrans,
        tb ? CblasTrans : CblasNoTrans,
        m, n, k, alpha, a.ptr(), lda, b.ptr(), ldb, beta, c.ptr(), ldc 
    );
    return c;
}

// ========== inv ========== 
Matrix inverse(const Matrix& a){
    
    if(a.rows() != a.cols()){
        printf("Invalid to compute inverse matrix by %d x %d\n", a.rows(), a.cols());
        return Matrix();
    }
    
    Matrix output = a;
    int n = a.rows();
    int *ipiv = new int[n];
    
    /* LU分解 */
    int code = LAPACKE_sgetrf(LAPACK_COL_MAJOR, n, n, output.ptr(), n, ipiv);
    if(code == 0){
        /* 使用LU分解求解通用逆矩阵 */
        code = LAPACKE_sgetri(LAPACK_COL_MAJOR, n, output.ptr(), n, ipiv);
    }

    if(code != 0){
        printf("LAPACKE inverse matrix failed, code = %d\n", code);
        return Matrix();
    }

    delete[] ipiv;
    return output;
}

main.cpp

#include <iostream>
#include "matrix.hpp"
using namespace std;

Matrix inverse(const Matrix& a);
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta);


int main(){

    Matrix m1(3, 3, {3, 2, 3, 4, 5, 6, 7, 8, 9});
    Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});

    std::cout << gemm(m1, false, m2, false, 1.0f, 0.0f);
    std::cout << inverse(m2);

    std::cout << m1 * 2 + 5 - 1;

    return 0;
}

总结

本次课程跟随杜老师手写了 Matrix 类的具体实现,学习到了很多关于 C++ 语法、习惯的知识,同时也碰到了一些问题,学习的过程就是不断解决问题的过程😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/15672.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Mysql安装

目录&#xff1a; 1.Mysql安装 2. 安装MySQL出现1045错误 3.更改数据库编码格式 1.mysql图文安装教程(详细说明) &#xff11;、打开下载的mysql安装文件mysql-5.5.27-win32.zip&#xff0c;双击解压缩&#xff0c;运行“setup.exe” &#xff12;、选择安装类型&#xff0c;有…

【计算机网络】1.1——因特网概述

因特网概述&#xff08;了解&#xff09; 网络、互联网和因特网 网络由若干结点和连接这些结点的链路组成 多个网络还可以通过路由器互连起来&#xff0c;互联网是"网络的网络“ internet 和 Internet internet&#xff08;互联网或互连网&#xff09;是通用名词 泛指…

最近部门新的00后真是卷王,工作没1年,入职18K

都说00后躺平了&#xff0c;但是有一说一&#xff0c;该卷的还是卷。 这不&#xff0c;前段时间我们公司来了个00后&#xff0c;工作都没1年&#xff0c;到我们公司起薪18K&#xff0c;都快接近我了。后来才知道人家是个卷王&#xff0c;从早干到晚就差搬张床到工位睡觉了。 …

生成式模型与辨别式模型

分类模型可以分为两大类:生成式模型与辨别式模型。本文解释了这两种模型类型之间的区别&#xff0c;并讨论了每种方法的优缺点。 辨别式模型 辨别式模型是一种能够学习输入数据和输出标签之间关系的模型&#xff0c;它通过学习输入数据的特征来预测输出标签。在分类问题中&…

Vue+Echart实现利用率表盘效果【组件已封装,可直接使用】

效果演示 当利用超过70%&#xff08;可以自行设置&#xff09;&#xff0c;表盘变红 组件 里面对应两个图片资源&#xff0c;panelBackground_red.png 和 panelBackground_green.png&#xff0c;请前往百度网盘进行下载。如果喜欢其他颜色&#xff0c;可以使用.psd来修改导出…

基于html+css的图展示43

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

vCenter(PSC)正常更改或重置administrator@vsphere.local用户的密码方法

1. 正常更改administratorvsphere.local用户密码 在vCenter界面中选择“菜单”下的“系统管理”&#xff0c;如下图所示&#xff1a; 然后在Single Sign On下的用户和组中&#xff0c;选择“vsphere.local”域&#xff0c;再对Administrator用户进行编辑&#xff0c;即可进行…

Unsupervised Learning of Depth and Ego-Motion from Video 论文精读

视频中深度和自我运动的无监督学习 摘要 我们提出了一个无监督学习框架&#xff0c;用于从非结构化视频序列中进行单眼深度和相机运动估计。与其他工作[10&#xff0c;14&#xff0c;16]一样&#xff0c;我们使用端到端的学习方法&#xff0c;将视图合成作为监督信号。与之前…

[openwrt] valgrind定位内存泄漏

目录 要求 valgrind 简介 工具介绍 linux程序的内存布局 内存检查的原理 valgrind的使用 使用举例 内存泄漏 内存越界 内存覆盖 Linux分配虚拟内存&#xff08;申请内存&#xff09;的两种方式 brk和mmap 要求 被调试程序带有-g参数编译&#xff0c;携带debug参数…

Ubuntu22.04部署Pytorch2.0深度学习环境

文章目录 安装Anaconda创建新环境安装Pytorch2.0安装VS CodeUbuntu下实时查看GPU状态的方法小实验&#xff1a;Ubuntu、Windows10下GPU训练速度对比 Ubuntu安装完显卡驱动、CUDA和cudnn后&#xff0c;下面部署深度学习环境。 &#xff08;安装Ubuntu系统、显卡驱动、CUDA和cudn…

10 dubbo源码学习_线程池

1. 线程模型&线程池介绍1.1 线程池1.2 线程模型 2. 线程池源码分析2.1 FixedThreadPool2.2 CachedThreadPool2.3 LimitedThreadPool 3. 线程模型源码3.1 AllDispatcher3.2 DirectDispatcher3.3 MessageOnlyDispatcher3.4 ExecutionDispatcher3.5 ConnectionOrderedDispatch…

这些使用工具大推荐,现在知道不晚

1.Snip Snip是一款截图软件&#xff0c;它突出的优点就是可以制作滚动截图。 例如&#xff1a;对整个网页进行截图&#xff0c;使用Snip即可轻松获取&#xff0c;无需处理水印。 2.Sleep Cycle 快节奏、高压力的生活导致我们越来越晚睡觉&#xff0c;睡眠质量越来越差。 想提…

jsp家庭农场投入品信息管理系统Myeclipse开发mysql数据库web结构jsp编程计算机网页项目

一、源码特点 jsp家庭农场投入品信息管理系统是一套完善的java web信息管理系统 serlvet dao bean 开发&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发…

SQL学习日记

目录 一、数据定义&#xff08;create&#xff0c;alter&#xff0c;drop&#xff09; 1.1数据类型 补充注释 1.2定义基本表&#xff08;create&#xff0c;alter&#xff0c;drop&#xff09; 1.3约束 1.3.1主键约束 1.3.2外码约束 ​编辑 补充CASCADE 关键字 1.3.3…

深度学习 - 45.MMOE Gate 简单实现 By Keras

目录 一.引言 二.MMoE 模型分析 三.MMoE 逻辑实现 • Input • Expert Output • Gate Output • Weighted Sum • Sigmoid Output • 完整代码 四.总结 一.引言 上一篇文章介绍了 MMoE 借鉴 MoE 的思路&#xff0c;为每一类输出构建一个 Gate 并最终加权多个 Exper…

05 KVM虚拟化Linux Bridge环境部署

文章目录 05 KVM虚拟化Linux Bridge环境部署5.1 安装Linux Bridge5.1.1 安装bridge-utils软件包5.1.2 确认安装是否成功 5.2 配置Linux Bridge5.2.1 创建网桥br05.2.2 将物理网卡ens33绑定到Linux Bridge5.2.3 配置ens33的ip5.2.4 为Linux Bridge网桥br0分配ip5.2.4.1 DHCP设置…

sin(x) + cos(x) 的极大值和极小值

sinx cosx 的极大值和极小值 理论推导图像 今天遇到了一个问题&#xff0c;就是如何求解 sin ⁡ x cos ⁡ x \sin{x} \cos{x} sinxcosx 的极大值和极小值。这里特来记录一下。 理论推导 首先&#xff0c;我们假设&#xff1a; sin ⁡ x cos ⁡ x R sin ⁡ ( x α ) (…

Vue(Vue脚手架)

一、使用Vue脚手架&#xff08;Vue Cli&#xff09; Vue官方提供脚手架平台选择最新版本&#xff1a; 可以相加兼容的标准化开发工具&#xff08;开发平台&#xff09; 禁止&#xff1a;最新的开发技术版本和比较旧版本的开发平台 Vue CLI&#x1f6e0;️ Vue.js 开发的标准工…

所有知识付费都可以用 ChatGPT 再割一次?

伴随春天一起到来的&#xff0c;还有如雨后春笋般冒出的 ChatGPT / AI 相关的付费社群、课程训练营、知识星球等。 ChatGPT 吹来的这股 AI 热潮&#xff0c;这几个月想必大家多多少少都能感受到。 ▲ 图片来源&#xff1a;网络 这两张图是最近在圈子里看到的。 一张是国内各…

第五章——动态规划3

蒙德里安的梦想 我们在黑框内横着放红框&#xff0c;我们发现当横向小方格摆好之后&#xff0c;纵向小方格只能一次纵向摆好&#xff0c;即纵向小方格只有一种方案&#xff0c;即整个摆放小方格的方案数等于横着摆放小方格的方案数 f[i,j]表示的是现在要在第i列摆&#xff0c;j…