Parallel optimization of softmax for high-dimensional tensors on Cambricon MLUs

For background on Cambricon BANG programming, see my earlier articles on the topic.

Basic softmax programming for high-dimensional tensors

Implementing softmax over a high-dimensional tensor is more involved. Recall the implementation on the NVIDIA platform: take a 6-D tensor of shape [1,2,3,4,5,6] and suppose the softmax axis is axis=2. We compute the length of that axis, dimsize=3, the product of the remaining dimensions, othersize=1×2×4×5×6 = 240, and the spacing between consecutive elements along the axis, stride=4×5×6 = 120. We then launch othersize=240 thread blocks, each handling one softmax slice, compute int tid = blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) × dimsize, and the global index of element threadIdx.x within that slice is tid + threadIdx.x × stride.
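To make the index bookkeeping concrete, here is a minimal host-side sketch in plain C++ (illustrative only, not part of the original code) that reproduces the arithmetic for shape [1,2,3,4,5,6] and axis=2, giving dimsize=3, stride=120, othersize=240:

#include <cstdio>

int main() {
  int shape[6] = {1, 2, 3, 4, 5, 6};
  int axis = 2;
  int dimsize = shape[axis], stride = 1, othersize = 1;
  for (int s = 5; s >= 0; s--) {
    if (s > axis) stride *= shape[s];      // product of the dimensions after axis
    if (s != axis) othersize *= shape[s];  // product of all non-axis dimensions
  }
  int otherIdx = 5;  // any slice id in [0, othersize)
  int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
  for (int i = 0; i < dimsize; i++)  // element i of the slice sits stride apart
    printf("slice %d, element %d -> flat index %d\n", otherIdx, i, tid + i * stride);
  return 0;
}

Following the same idea, here is the high-dimensional softmax on the Cambricon MLU: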

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 4;//the tree reduction later requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); //at most maxNum floats fit in NRAM
const int warpSize = 32;//__bang_reduce_sum adds 128 bytes (32 floats) of src per step: the sum of elements 0-31 lands in index 0, of 32-63 in index 1, and so on

__nram__ float src1[maxNum];//staging buffer: maxNum elements moved to NRAM per copy
__nram__ float destSum[maxNum];//running sum
__nram__ float destSumFinal[warpSize];//destSum is reduced into destSumFinal[0]
__nram__ float srcMax[2];

__mlu_entry__ void softmaxKernel(float* dst, float* source1, int othersize, int dimsize, int stride) {
  __nram__ float destOldMax;
  __nram__ float destNewMax;
  const bool useNaive = false;//true: plain synchronous version; false: the double-buffered version below that overlaps copy and compute
  if(useNaive){
    for(int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim){
      destOldMax = -INFINITY;
      destNewMax = -INFINITY;
      float sum_s = 1.0;
      int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
      for(int i = 0; i < dimsize; i++){
        __memcpy(src1, source1 + tid + i * stride, sizeof(float), GDRAM2NRAM);
        if(destNewMax < src1[0]){
          destNewMax = src1[0];
        }
        if(i > 0){
          sum_s = sum_s * exp(destOldMax - destNewMax) + exp(src1[0] - destNewMax);
        }
        destOldMax = destNewMax;
      }
      float globalSumInv = 1.0/sum_s;
      for(int i = 0; i < dimsize; i++){
        __memcpy(src1, source1 + tid + i * stride, sizeof(float), GDRAM2NRAM);
        src1[0] = exp(src1[0] - destNewMax) * globalSumInv;
        __memcpy(dst + tid + i * stride, src1, sizeof(float), NRAM2GDRAM);
      }
    }
  }
  else{
    for(int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim){
      destOldMax = -INFINITY;
      destNewMax = -INFINITY;
      float sum_s = 1.0;
      int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
      for(int i = 0; i < dimsize + 1; i++){
        if(i < dimsize){
          __memcpy_async(src1 + i%2, source1 + tid + i * stride, sizeof(float), GDRAM2NRAM);
        }
        if(i > 0){
          if(destNewMax < src1[(i - 1)%2]){
            destNewMax = src1[(i - 1)%2];
          }
          if(i > 1){
            sum_s = sum_s * exp(destOldMax - destNewMax) + exp(src1[(i - 1)%2] - destNewMax);
          }
          destOldMax = destNewMax;
        }
        __sync_all_ipu();
      }
      float globalSumInv = 1.0/sum_s;
      for(int i = 0; i < dimsize + 2; i++){
        if(i < dimsize){
          __memcpy(src1 + i%3, source1 + tid + i * stride, sizeof(float), GDRAM2NRAM);
        }
        if(i > 0 && i < dimsize + 1){
          src1[(i - 1)%3] = exp(src1[(i - 1)%3] - destNewMax) * globalSumInv;
        }
        if(i > 1){
          __memcpy(dst + tid + (i - 2) * stride, src1 + (i - 2)%3, sizeof(float), NRAM2GDRAM);
        }
        __sync_all_ipu();
      }
    }
  }
  
  
  
}


int main(void)
{
  int num = 32 * 16 * 64 * 128;//shape = {32, 16, 64, 128},axis = 2
  int stride = 128;
  int dimsize = 64;
  int othersize = 32 * 16 * 128;
  /***
  int num = 24;//shape = {2,3,2,2}, axis = 1
  int stride = 4;
  int dimsize = 3;
  int othersize = 8;
  ***/
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtSetDevice(0));
  CNRT_CHECK(cnrtQueueCreate(&queue));

  cnrtDim3_t dim = {4, 1, 1};
  int taskNum = dim.x * dim.y * dim.z;
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;

  cnrtNotifier_t start, end;
  CNRT_CHECK(cnrtNotifierCreate(&start));
  CNRT_CHECK(cnrtNotifierCreate(&end));

  float* host_dst = (float*)malloc(num * sizeof(float));
  float* host_src1 = (float*)malloc(num * sizeof(float));
  

  for (int i = 0; i < num; i++) {
    host_src1[i] = i%4;
    //host_src1[i] = i;
  }

  float* mlu_dst;
  float* mlu_src1;
  
  CNRT_CHECK(cnrtMalloc((void**)&mlu_dst, num * sizeof(float)));
  CNRT_CHECK(cnrtMalloc((void**)&mlu_src1, num * sizeof(float)));
  

  CNRT_CHECK(cnrtMemcpy(mlu_src1, host_src1, num * sizeof(float), cnrtMemcpyHostToDev));
  
  //----------------------------
  CNRT_CHECK(cnrtPlaceNotifier(start, queue));
  softmaxKernel<<<dim, ktype, queue>>>(mlu_dst, mlu_src1, othersize, dimsize, stride);
  CNRT_CHECK(cnrtPlaceNotifier(end, queue));
  cnrtQueueSync(queue);

  //---------------------------
  CNRT_CHECK(cnrtMemcpy(host_dst, mlu_dst, num * sizeof(float), cnrtMemcpyDevToHost));
  for(int i = 0; i < 24; i++){
    printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_dst[i], host_src1[i]);
  }
  float timeTotal;
  CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
  printf("Total Time: %.3f ms\n", timeTotal / 1000.0);

  CNRT_CHECK(cnrtQueueDestroy(queue));

  cnrtFree(mlu_dst);
  cnrtFree(mlu_src1);
  
  
  free(host_dst);
  free(host_src1);
  

  return 0;
}
                           

We use taskId to cover othersize. Since taskDim is usually a multiple of 2 or 4 while othersize need not be, we use a grid-stride-style loop: for(int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim).
Inside this loop we process the dimsize elements of one slice. Cambricon intrinsics operate on whole vectors rather than individual elements, so we again split the dimsize-long slice into chunks of length maxNum, handling the non-divisible tail specially, exactly as for the one-dimensional vector. In the kernel below, a two-level for loop loads the data: in a high-dimensional tensor the elements of one slice sit stride apart in memory, so we must walk the array and gather them into the NRAM buffer src1 before processing; the remaining steps are analogous to before and are not repeated here.
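The chunked max-and-sum update these kernels perform is the standard online-softmax recurrence: keep a running maximum M and a running sum S of exp(x - M); whenever a chunk raises the maximum, rescale S by exp(oldM - newM) before adding the new terms. A minimal CPU sketch in plain C++ (illustrative, not BANG):

#include <algorithm>
#include <cmath>
#include <vector>

void onlineSoftmax(const std::vector<float>& x, std::vector<float>& y, size_t chunk) {
  float M = -INFINITY, S = 0.0f;
  for (size_t start = 0; start < x.size(); start += chunk) {
    size_t end = std::min(start + chunk, x.size());
    float newM = M;
    for (size_t i = start; i < end; i++) newM = std::max(newM, x[i]);
    S *= std::exp(M - newM);  // rescale the old partial sum
    for (size_t i = start; i < end; i++) S += std::exp(x[i] - newM);
    M = newM;
  }
  for (size_t i = 0; i < x.size(); i++) y[i] = std::exp(x[i] - M) / S;
}

The vectorized BANG version of this scheme: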

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 4;//the tree reduction later requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); //at most maxNum floats fit in NRAM
const int warpSize = 32;//__bang_reduce_sum adds 128 bytes (32 floats) of src per step: the sum of elements 0-31 lands in index 0, of 32-63 in index 1, and so on

__nram__ float src1[maxNum];//staging buffer: maxNum elements moved to NRAM per copy
__nram__ float destSum[maxNum];//running sum
__nram__ float destSumFinal[warpSize];//destSum is reduced into destSumFinal[0]
__nram__ float srcMax[2];

__mlu_entry__ void softmaxKernel(float* dst, float* source1, int othersize, int dimsize, int stride) {
  int remain = dimsize%maxNum;
  int repeat = (dimsize - remain)/maxNum;
  __nram__ float destOldMax;
  __nram__ float destNewMax;
  //use taskId to divide the other (non-axis) dimensions below
  for(int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim){
    destOldMax = -INFINITY;
    destNewMax = -INFINITY;
    __bang_write_zero(destSum, maxNum);
    int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
    for(int i = 0; i < repeat; i++){
      for(int j = 0; j < maxNum; j++){//gather from source1 with spacing stride
        __memcpy(src1 + j, source1 + tid + (i * maxNum + j) * stride, sizeof(float), GDRAM2NRAM);
      }
      __bang_argmax(srcMax, src1, maxNum);
      if(destNewMax < srcMax[0]){
        destNewMax = srcMax[0];//update the running max
      }
      __bang_sub_scalar(src1, src1, destNewMax, maxNum);//src1 = src1 - max
      __bang_active_exp_less_0(src1, src1, maxNum);//src1 = exp(src1 - max)
      if(i > 0){
        __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);//rescale the old partial sum
      }
      __bang_add(destSum, destSum, src1, maxNum);//destSum += exp(src1 - destNewMax)
      destOldMax = destNewMax;
    }
    //-------------------------------------
    if(remain){
      __bang_write_value(src1, maxNum, -INFINITY);//the padding must be set to -INFINITY for the max
      for(int j = 0; j < remain; j++){
        __memcpy(src1 + j, source1 + tid + (repeat * maxNum + j) * stride, sizeof(float), GDRAM2NRAM);
      }
      __bang_argmax(srcMax, src1, maxNum);
      if(destNewMax < srcMax[0]){
        destNewMax = srcMax[0];
      }
      __bang_write_value(src1, maxNum, destNewMax);//must be re-initialized to destNewMax before the sum
      for(int j = 0; j < remain; j++){
        __memcpy(src1 + j, source1 + tid + (repeat * maxNum + j) * stride, sizeof(float), GDRAM2NRAM);
      }
      __bang_sub_scalar(src1, src1, destNewMax, maxNum);//the trailing maxNum - remain entries become 0
      __bang_active_exp_less_0(src1, src1, maxNum);//so the sum picks up an extra maxNum - remain ones
      if(repeat > 0){
        __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
      }
      __bang_add(destSum, destSum, src1, maxNum);
      destOldMax = destNewMax;
    }
    
    //--------------------------------
    __bang_write_zero(destSumFinal, warpSize);
    int segNum = maxNum / warpSize;//tree-reduce destSum down to one warpSize-long segment
    for(int strip = segNum/2; strip > 0; strip = strip / 2){
      for(int i = 0; i < strip ; i++){
        __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
      }
    }
    __bang_reduce_sum(destSumFinal, destSum, warpSize);

    if(remain){
      destSumFinal[0] = destSumFinal[0] - (maxNum - remain);//remove the extra ones from the padding
    }
    //__bang_printf("--max:%.3e,sum:%.6e,:%d\n",destNewMax,destSumFinal[0], maxNum - remain);
    //at this point destNewMax is the global max and destSumFinal[0] the global sum
    float globalSumInv = 1.0/destSumFinal[0];
    if(remain){
      __bang_mul_scalar(src1, src1, globalSumInv, maxNum);//multiply by the reciprocal to realize the division
      for(int j = 0; j < remain; j++){
        __memcpy(dst + tid + (repeat * maxNum + j) * stride, src1 + j, sizeof(float), NRAM2GDRAM);
      }
    }
    for(int i = 0; i < repeat; i++){
      for(int j = 0; j < maxNum; j++){
        __memcpy(src1 + j, source1 + tid + (i * maxNum + j) * stride, sizeof(float), GDRAM2NRAM);
      }
      __bang_sub_scalar(src1, src1, destNewMax, maxNum); 
      __bang_active_exp_less_0(src1, src1, maxNum);
      __bang_mul_scalar(src1, src1, globalSumInv, maxNum);//multiply by the reciprocal to realize the division
      for(int j = 0; j < maxNum; j++){
        __memcpy(dst + tid + (i * maxNum + j) * stride, src1 + j, sizeof(float), NRAM2GDRAM);
      }
    }
    
    
  }
  
  
}


int main(void)
{
  int num = 32 * 16 * 64 * 128;//shape = {32, 16, 64, 128},axis = 2
  int stride = 128;
  int dimsize = 64;
  int othersize = 32 * 16 * 128;
  /***
  int num = 24;//shape = {2,3,2,2}, axis = 1
  int stride = 4;
  int dimsize = 3;
  int othersize = 8;
  ***/
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtSetDevice(0));
  CNRT_CHECK(cnrtQueueCreate(&queue));

  cnrtDim3_t dim = {4, 1, 1};
  int taskNum = dim.x * dim.y * dim.z;
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;

  cnrtNotifier_t start, end;
  CNRT_CHECK(cnrtNotifierCreate(&start));
  CNRT_CHECK(cnrtNotifierCreate(&end));

  float* host_dst = (float*)malloc(num * sizeof(float));
  float* host_src1 = (float*)malloc(num * sizeof(float));
  

  for (int i = 0; i < num; i++) {
    host_src1[i] = i%4;
    //host_src1[i] = i;
  }

  float* mlu_dst;
  float* mlu_src1;
  
  CNRT_CHECK(cnrtMalloc((void**)&mlu_dst, num * sizeof(float)));
  CNRT_CHECK(cnrtMalloc((void**)&mlu_src1, num * sizeof(float)));
  

  CNRT_CHECK(cnrtMemcpy(mlu_src1, host_src1, num * sizeof(float), cnrtMemcpyHostToDev));
  
  //----------------------------
  CNRT_CHECK(cnrtPlaceNotifier(start, queue));
  softmaxKernel<<<dim, ktype, queue>>>(mlu_dst, mlu_src1, othersize, dimsize, stride);
  CNRT_CHECK(cnrtPlaceNotifier(end, queue));
  cnrtQueueSync(queue);

  //---------------------------
  CNRT_CHECK(cnrtMemcpy(host_dst, mlu_dst, num * sizeof(float), cnrtMemcpyDevToHost));
  for(int i = 0; i < 24; i++){
    printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_dst[i], host_src1[i]);
  }
  float timeTotal;
  CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
  printf("Total Time: %.3f ms\n", timeTotal / 1000.0);

  CNRT_CHECK(cnrtQueueDestroy(queue));

  cnrtFree(mlu_dst);
  cnrtFree(mlu_src1);
  
  
  free(host_dst);
  free(host_src1);
  

  return 0;
}
                           

Accelerating high-dimensional softmax with coalesced memory access

The code above is the simplest and most obvious approach, but it has a problem: every array access jumps through memory, so it is extremely slow and unusable for large tensors. To accelerate it, we now optimize the original scheme. For concreteness, take a 4-D tensor of shape [32,16,64,128] with softmax axis axis=2, giving stride=128, dimsize=64, othersize=32×16×128. The algorithm above lets different taskIds walk othersize to obtain otherIdx, then loops over dimsize with global index otherIdx + i×stride, repeatedly hopping by stride to fetch elements, gathering them into a single NRAM vector src of length maxNum, transforming them, and finally scattering the elements of src back to the destination dst in another loop. The dominant cost in this process is the strided array access. To eliminate it, we read the array in a coalesced fashion instead. Consider a 4-D tensor of shape [A,B,C,D] with flattened index i(BCD) + j(CD) + k(D) + s; we now case-split on the softmax axis:

axis=0

Observe that j(CD) + k(D) + s runs over othersize = BCD, and stride is also exactly BCD. We can therefore split the tensor into A units of length BCD each and loop for(i = 0; i < A; i++), reading source[i×BCD : (i+1)×BCD] per iteration. This yields A vectors of length BCD whose corresponding elements differ in index by exactly stride, so we can stream these A vectors and compare them elementwise, ending with a length-BCD vector tmpMax whose entries are precisely the maxima for each (j,k,s). The sums and the write-back to GDRAM work analogously.
__bang_maxequal performs the elementwise maximum of two vectors; for the elementwise sums, __bang_add suffices.
In this case taskId splits othersize, mainly because the othersize part of the data is exactly the part that is contiguous when reading; a CPU sketch of the columnwise streaming update follows.
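A plain-C++ sketch (illustrative, not BANG): tmpMax and tmpSum carry one lane per (j,k,s) column, and each of the dimsize rows of length stride (= BCD here) updates every lane at once, mirroring __bang_maxequal, __bang_mul and __bang_add:

#include <cmath>
#include <vector>

void softmaxAxis0Ref(const std::vector<float>& src, std::vector<float>& dst,
                     int dimsize, int stride) {  // stride = BCD = othersize for axis = 0
  std::vector<float> tmpMax(stride, -INFINITY), tmpSum(stride, 0.0f);
  for (int i = 0; i < dimsize; i++) {            // stream the dimsize rows
    for (int k = 0; k < stride; k++) {
      float x = src[i * stride + k];
      float newMax = std::fmax(tmpMax[k], x);
      tmpSum[k] = tmpSum[k] * std::exp(tmpMax[k] - newMax) + std::exp(x - newMax);
      tmpMax[k] = newMax;
    }
  }
  for (int i = 0; i < dimsize; i++)
    for (int k = 0; k < stride; k++)
      dst[i * stride + k] = std::exp(src[i * stride + k] - tmpMax[k]) / tmpSum[k];
}

The corresponding BANG kernel: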

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;//the tree reduction later requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); //at most maxNum floats fit in NRAM

__mlu_entry__ void softmaxKernelAxis_s(float* destination, float* source, int othersize, int dimsize, int stride) {// axis = 0
  __nram__ float src[maxNum];//staging buffer: maxNum elements moved to NRAM per copy
  __nram__ float tmpSum[maxNum];
  __nram__ float tmpNewMax[maxNum];
  __nram__ float tmpOldMax[maxNum];

  int remain = othersize % taskDim;
  int stepEasy = (othersize - remain)/taskDim;
  int stepHard = stepEasy + 1;
  int step = (taskId < remain ? stepHard : stepEasy);//the first remain taskIds each handle one extra element
  int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy);
  int remainNram = step%maxNum;
  int repeat = (step - remainNram)/maxNum;
  
  __bang_printf("taskId:%d, repeat:%d, step:%d, indStart:%d, remainNram:%d\n", taskId, repeat, step, indStart, remainNram);
  for(int j = 0; j < repeat; j++){
    __bang_write_value(tmpNewMax, maxNum, -INFINITY);
    __bang_write_zero(tmpSum, maxNum);
    for(int i = 0; i < dimsize; i++){
      __memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
      __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//keep updating the running max
      __bang_sub(src, src, tmpNewMax, maxNum);//x - M
      __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
      if(i > 0){
        __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
        __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
        __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
      }
      __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
      __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
    } 
    __bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
    //apply the exponential transform and write back to GDRAM
    __bang_mul(src, src, tmpSum, maxNum);//src still holds exp(x - M) from the last loop iteration, so reuse it
    __memcpy(destination + (dimsize - 1) * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
    for(int i = 0; i < dimsize - 1; i++){
      __memcpy(src, source + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
      __bang_sub(src, src, tmpNewMax, maxNum);//x - M
      __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
      __bang_mul(src, src, tmpSum, maxNum);
      __memcpy(destination + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
    } 
  }
  if(remainNram){
    __bang_write_value(tmpNewMax, maxNum, -INFINITY);
    __bang_write_zero(tmpSum, maxNum);
    __bang_write_zero(src, maxNum);
   
    
    for(int i = 0; i < dimsize; i++){
      __memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
      __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
      __bang_sub(src, src, tmpNewMax, maxNum);//x - M
      __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
      if(i > 0){
        __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
        __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
        __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);      //sum = sum * exp(oldM - newM)
      }
      __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
      __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
    } 
    /***
    for(int k = 0; k < remainNram; k++){
      __bang_printf("%d,max:%.2f,sum:%.2f, src:%.2f\n",k, tmpNewMax[k], tmpSum[k], src[k]);
    }
    ***/
    __bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
    //apply the exponential transform and write back to GDRAM
    __bang_mul(src, src, tmpSum, maxNum);//src still holds exp(x - M) from the last loop iteration, so reuse it
    __memcpy(destination + (dimsize - 1) * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
    for(int i = 0; i < dimsize - 1; i++){
      __memcpy(src, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
      __bang_sub(src, src, tmpNewMax, maxNum);//x - M
      __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
      __bang_mul(src, src, tmpSum, maxNum);
      __memcpy(destination + i * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
    } 
    
  }
  
}


int main(void)
{
  //int shape[4] = {1024,128,32,32};
  //int shape[4] = {1024,64,32,32};
  int shape[4] = {1024,32,32,32};
  //int shape[4] = {2, 3, 2, 2};
  int axis = 0;
  int stride = 1;
  int dimsize = shape[axis];
  int num = 1;
  int othersize = 1;
  for(int s = 3; s >= 0; s--){
    num *= shape[s];
    if(s > axis){
      stride *= shape[s];
    }
    if(s != axis){
      othersize *= shape[s];
    }
  }
  
  printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, num:%d\n", axis, dimsize, stride, othersize, num);
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtSetDevice(0));
  CNRT_CHECK(cnrtQueueCreate(&queue));

  cnrtDim3_t dim = {4, 1, 1};
  int taskNum = dim.x * dim.y * dim.z;
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;

  cnrtNotifier_t start, end;
  CNRT_CHECK(cnrtNotifierCreate(&start));
  CNRT_CHECK(cnrtNotifierCreate(&end));

  float* host_destination = (float*)malloc(num * sizeof(float));
  float* host_src = (float*)malloc(num * sizeof(float));
  

  for (int i = 0; i < num; i++) {
    host_src[i] = i%4;
    //host_src[i] = i;
  }

  float* mlu_destination;
  float* mlu_src;
  
  CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
  CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
  

  CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
  
  //----------------------------
  CNRT_CHECK(cnrtPlaceNotifier(start, queue));
  softmaxKernelAxis_s<<<dim, ktype, queue>>>(mlu_destination, mlu_src, othersize, dimsize, stride);
  CNRT_CHECK(cnrtPlaceNotifier(end, queue));
  cnrtQueueSync(queue);

  //---------------------------
  CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
  for(int i = 0; i < 24; i++){
    printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
  }
  float timeTotal;
  CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
  printf("Total Time: %.3f ms\n", timeTotal / 1000.0);

  CNRT_CHECK(cnrtQueueDestroy(queue));

  cnrtFree(mlu_destination);
  cnrtFree(mlu_src);
  
  
  free(host_destination);
  free(host_src);
  

  return 0;
}
                           

axis = -1

Now the softmax axis is the last one, which is even simpler: split the tensor into ABC units of length D each and loop for(i = 0; i < ABC; i++), reading source[i×D : (i+1)×D] per iteration; reducing each chunk yields its maximum M, and after the loop we have the maxima for all (i,j,k), i.e. for the othersize part. The sums and the write-back to GDRAM follow the same pattern. Here the data is contiguous along axis=-1, and two parallel strategies are possible:
First strategy: taskId splits othersize, e.g. for(i = taskId; i < ABC; i += taskDim), each iteration reading the corresponding length-D chunk. But D need not be a power of two, and a length-D vector may not fit in NRAM at once, so inside the loop we need an extra loop over source[i×D : (i+1)×D] that reads maxNum elements at a time until the chunk is exhausted.
Second strategy: process othersize serially, for(i = 0; i < ABC; i++), and split source[i×D : (i+1)×D] across taskIds inside the loop. Each taskId then receives only a part of the chunk (the step of our earlier code); step need not be a power of two either and may not fit in NRAM, and since the maximum we need is over the whole chunk, splitting it across taskIds forces an extra cross-task reduction at the end (exactly as for the one-dimensional vector).
Given this analysis we prefer the first strategy. Note that with for(i = taskId; i < ABC; i += taskDim), each taskId's reads still jump between iterations; if we instead precomputed a step so that each taskId handles the index range [taskId×step : (taskId+1)×step], its reads would be relatively contiguous across iterations (this would need experimental validation). For simplicity we stick with the for(i = taskId; i < ABC; i += taskDim) pattern; a CPU reference for checking the results follows.
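Since the host code below prints only the first 24 outputs, a CPU reference along the last axis is handy for spot-checking; a sketch in plain C++ (illustrative, not part of the original code):

#include <cmath>

void softmaxLastAxisRef(const float* src, float* dst, int othersize, int dimsize) {
  for (int o = 0; o < othersize; o++) {  // one contiguous row per otherIdx
    const float* x = src + o * dimsize;
    float* y = dst + o * dimsize;
    float m = -INFINITY;
    for (int i = 0; i < dimsize; i++) m = std::fmax(m, x[i]);
    float s = 0.0f;
    for (int i = 0; i < dimsize; i++) s += std::exp(x[i] - m);
    for (int i = 0; i < dimsize; i++) y[i] = std::exp(x[i] - m) / s;
  }
}

The BANG implementation: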

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;//the tree reduction later requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); //at most maxNum floats fit in NRAM
const int warpSize = 32;

__mlu_entry__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize) {// axis = -1
  __nram__ float src[maxNum];
  __nram__ float destSum[maxNum];//running sum
  __nram__ float destSumFinal[warpSize];//destSum is reduced into destSumFinal[0]
  __nram__ float srcMax[2];
  __nram__ float destOldMax;
  __nram__ float destNewMax;

  int remain = dimsize % maxNum;
  int repeat = (dimsize - remain)/maxNum;
  for(int otherIdx = taskId; otherIdx < othersize; otherIdx += taskDim){
    int tid = otherIdx * dimsize;
    destOldMax = -INFINITY;
    destNewMax = -INFINITY;
    __bang_write_zero(destSum, maxNum);
    for(int i = 0; i < repeat; i++){
      __memcpy(src, source + tid + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
      __bang_argmax(srcMax, src, maxNum);
      if(destNewMax < srcMax[0]){
        destNewMax = srcMax[0];//update the running max
      }
      __bang_sub_scalar(src, src, destNewMax, maxNum);//src = src - max
      __bang_active_exp_less_0(src, src, maxNum);//src = exp(src - max)
      if(i > 0){
        __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
      }
      __bang_add(destSum, destSum, src, maxNum);
      destOldMax = destNewMax;
    }
    //------------
    if(remain){
      __bang_write_value(src, maxNum, -INFINITY);//the padding must be set to -INFINITY for the max
      __memcpy(src, source + tid + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);

      __bang_argmax(srcMax, src, maxNum);
      if(destNewMax < srcMax[0]){
        destNewMax = srcMax[0];
      }
      __bang_write_value(src, maxNum, destNewMax);//must be re-initialized to destNewMax before the sum
      __memcpy(src, source + tid + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
      __bang_sub_scalar(src, src, destNewMax, maxNum);//the trailing maxNum - remain entries become 0
      __bang_active_exp_less_0(src, src, maxNum);//so the sum picks up an extra maxNum - remain ones
      if(repeat > 0){
        __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
      }
      __bang_add(destSum, destSum, src, maxNum);
      destOldMax = destNewMax;
    }
    //--------------
    //--------------------------------
    __bang_write_zero(destSumFinal, warpSize);
    int segNum = maxNum / warpSize;
    for(int strip = segNum/2; strip > 0; strip = strip / 2){
      for(int i = 0; i < strip ; i++){
        __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
      } 
    }
    __bang_reduce_sum(destSumFinal, destSum, warpSize);
    
    if(remain){
      destSumFinal[0] = destSumFinal[0] - (maxNum - remain);
    }
    //-----------
    float globalSumInv = 1.0/destSumFinal[0];
    for(int i = 0; i < repeat; i++){
      __memcpy(src, source + tid + i * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
      __bang_sub_scalar(src, src, destNewMax, maxNum); 
      __bang_active_exp_less_0(src, src, maxNum);
      __bang_mul_scalar(src, src, globalSumInv, maxNum);//multiply by the reciprocal to realize the division
      __memcpy(destination + tid + i * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
      
    }
    if(remain){
      __bang_write_value(src, maxNum, destNewMax);
      __memcpy(src, source + tid + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
      __bang_sub_scalar(src, src, destNewMax, maxNum);
      __bang_active_exp_less_0(src, src, maxNum);
      __bang_mul_scalar(src, src, globalSumInv, maxNum);//multiply by the reciprocal to realize the division
      __memcpy(destination + tid + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
    }
  }
  
}


int main(void)
{
  //int shape[4] = {1024,128,32,32};
  //int shape[4] = {1024,64,32,32};
  int shape[4] = {1024,32,32,32};
  //int shape[4] = {2, 3, 2, 2};
  int axis = 3;
  int stride = 1;
  int dimsize = shape[axis];
  int num = 1;
  int othersize = 1;
  for(int s = 3; s >= 0; s--){
    num *= shape[s];
    if(s > axis){
      stride *= shape[s];
    }
    if(s != axis){
      othersize *= shape[s];
    }
  }
  
  printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, num:%d\n", axis, dimsize, stride, othersize, num);
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtSetDevice(0));
  CNRT_CHECK(cnrtQueueCreate(&queue));

  cnrtDim3_t dim = {4, 1, 1};
  int taskNum = dim.x * dim.y * dim.z;
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;

  cnrtNotifier_t start, end;
  CNRT_CHECK(cnrtNotifierCreate(&start));
  CNRT_CHECK(cnrtNotifierCreate(&end));

  float* host_destination = (float*)malloc(num * sizeof(float));
  float* host_src = (float*)malloc(num * sizeof(float));
  

  for (int i = 0; i < num; i++) {
    //host_src[i] = i%4;
    host_src[i] = i;
  }

  float* mlu_destination;
  float* mlu_src;
  
  CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
  CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
  

  CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
  
  //----------------------------
  CNRT_CHECK(cnrtPlaceNotifier(start, queue));
  softmaxKernelAxis_e<<<dim, ktype, queue>>>(mlu_destination, mlu_src, othersize, dimsize);
  CNRT_CHECK(cnrtPlaceNotifier(end, queue));
  cnrtQueueSync(queue);

  //---------------------------
  CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
  for(int i = 0; i < 24; i++){
    printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
  }
  float timeTotal;
  CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
  printf("Total Time: %.3f ms\n", timeTotal / 1000.0);

  CNRT_CHECK(cnrtQueueDestroy(queue));

  cnrtFree(mlu_destination);
  cnrtFree(mlu_src);
  
  
  free(host_destination);
  free(host_src);
  

  return 0;
}
                           

0 < axis < dim - 1

Let dim denote the number of dimensions of the tensor. The case 0 < axis < dim - 1 is the most complex. Building on the axis=0 and axis=-1 analyses, we use axis=1 and axis=2 to explain how the data is read:
axis=1: for a tensor [A,B,C,D], set otherIdx = i(BCD) and loop for(j = 0; j < B; j++), each iteration reading the length-CD chunk source[otherIdx + j×stride : otherIdx + j×stride + CD]. For a fixed otherIdx, the loop yields dimsize=B vectors of length CD; comparing them elementwise gives a length-CD vector tmpMax whose entries are the maxima over (k,s) for that otherIdx. Sums and write-back follow analogously.
axis=2: set otherIdx = i(BCD) + j(CD) and loop for(k = 0; k < C; k++), each iteration reading the length-D chunk source[otherIdx + k×stride : otherIdx + k×stride + D]. For a fixed otherIdx the loop yields dimsize=C vectors of length D, and the elementwise maximum gives a length-D vector tmpMax holding the maxima over s. Again, sums and write-back are analogous.
The pattern: when axis is a middle dimension, fix the otherIdx contributed by the dimensions before axis, then loop, reading the data after axis in each iteration. We introduce two parameters, frontsize and behindsize, for the data before and after axis: for axis=1, frontsize=A and behindsize=CD; for axis=2, frontsize=AB and behindsize=D.
We now have to decide whether taskId should split frontsize or behindsize. Both can work; we analyze the two strategies below, using axis=2 as the example:
Strategy 1: taskId splits frontsize, i.e. for(ind = taskId; ind < frontsize; ind += taskDim). With axis=2 we have frontsize=AB, and ind corresponds to the 2-D index (i,j) via ind = iB + j, but we must further convert ind into frontIdx = ind×CD, or in general frontIdx = ind×dimsize×behindsize. Inside this loop we continue with for(k = 0; k < C; k++), reading behindsize elements at a time.
Strategy 2: taskId splits behindsize, which forces frontsize to be processed serially: for(ind = 0; ind < frontsize; ind += 1), again with frontIdx = ind×CD (in general frontIdx = ind×dimsize×behindsize). Inside this loop we continue with for(k = 0; k < C; k++), but since taskId splits behindsize, each taskId is assigned a share of size step and reads step elements at a time.
On a rough look we lean towards strategy 1. Note also that behindsize is just stride, so we will not distinguish the two from here on; the shape bookkeeping is sketched below.
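A plain-C++ sketch of this bookkeeping (it mirrors the loop in main() of the listings below):

#include <cstdio>

int main() {
  int shape[4] = {1024, 32, 32, 32};
  int axis = 2;
  int dimsize = shape[axis], frontsize = 1, stride = 1;
  for (int s = 3; s >= 0; s--) {
    if (s < axis) frontsize *= shape[s];  // product of the dimensions before axis
    if (s > axis) stride *= shape[s];     // behindsize = product after axis
  }
  // strategy 1: for(ind = taskId; ind < frontsize; ind += taskDim),
  // frontIdx = ind * dimsize * stride, then dimsize reads of length stride each
  printf("frontsize=%d dimsize=%d stride=%d frontIdx step=%d\n",
         frontsize, dimsize, stride, dimsize * stride);
  return 0;
}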
Strategy 1:

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;//the tree reduction later requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); //at most maxNum floats fit in NRAM

__mlu_entry__ void softmaxKernelAxis_m(float* destination, float* source, int frontsize, int dimsize, int stride) {
  // 0<axis<dim -1 
  __nram__ float src[maxNum];
  __nram__ float tmpSum[maxNum];
  __nram__ float tmpNewMax[maxNum];
  __nram__ float tmpOldMax[maxNum];

  int remain = stride % maxNum;
  int repeat = (stride - remain) / maxNum;
  for(int ind = taskId; ind < frontsize; ind += taskDim){
    int frontIdx = ind * dimsize * stride;
    for(int j = 0; j < repeat; j++){
      __bang_write_value(tmpNewMax, maxNum, -INFINITY);
      __bang_write_zero(tmpSum, maxNum);
      __bang_write_zero(src, maxNum);
      for(int i = 0; i < dimsize; i++){
        __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
        __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//keep updating the running max
        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
        if(i > 0){
          __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
          __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
          __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
        }
        __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
        __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
      }
      __bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
      //apply the exponential transform and write back to GDRAM
      __bang_mul(src, src, tmpSum, maxNum);//src still holds exp(x - M) from the last loop iteration, so reuse it
      __memcpy(destination + (dimsize - 1) * stride + frontIdx + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
      for(int i = 0; i < dimsize - 1; i++){
        __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
        __bang_mul(src, src, tmpSum, maxNum);
        __memcpy(destination + frontIdx + i * stride + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
      } 
    }
    if(remain){
      __bang_write_value(tmpNewMax, maxNum, -INFINITY);
      __bang_write_zero(tmpSum, maxNum);
      __bang_write_value(src, maxNum, -INFINITY);
      for(int i = 0; i < dimsize; i++){
        __memcpy(src, source + frontIdx + i * stride + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
        __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
        if(i > 0){
          __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
          __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
          __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);      //sum = sum * exp(oldM - newM)
        }
        __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
        __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
      }
      //-------------------
      __bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
      //apply the exponential transform and write back to GDRAM
      __bang_mul(src, src, tmpSum, maxNum);//src still holds exp(x - M) from the last loop iteration, so reuse it
      __memcpy(destination + (dimsize - 1) * stride + frontIdx + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
      for(int i = 0; i < dimsize - 1; i++){
        __memcpy(src, source + i * stride + frontIdx + repeat * maxNum, remain * sizeof(float), GDRAM2NRAM);
        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
        __bang_mul(src, src, tmpSum, maxNum);
        __memcpy(destination + i * stride + frontIdx + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
      } 
      //---------------------
    }
  }
  
}


int main(void)
{
  //int shape[4] = {1024,128,32,32};
  //int shape[4] = {1024,64,32,32};
  int shape[4] = {1024,32,32,32};
  //int shape[4] = {2, 3, 2, 2};
  int axis = 2;
  int stride = 1;
  int dimsize = shape[axis];
  int num = 1;
  int othersize = 1;
  int frontsize = 1;
  
  for(int s = 3; s >= 0; s--){
    num *= shape[s];
    if(s > axis){
      stride *= shape[s];
    }
    if(s < axis){
      frontsize *= shape[s];
    }
    if(s != axis){
      othersize *= shape[s];
    }
  }
  
  printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, frontsize:%d, num:%d\n", axis, dimsize, stride, othersize, frontsize, num);
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtSetDevice(0));
  CNRT_CHECK(cnrtQueueCreate(&queue));

  cnrtDim3_t dim = {4, 1, 1};
  int taskNum = dim.x * dim.y * dim.z;
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;

  cnrtNotifier_t start, end;
  CNRT_CHECK(cnrtNotifierCreate(&start));
  CNRT_CHECK(cnrtNotifierCreate(&end));

  float* host_destination = (float*)malloc(num * sizeof(float));
  float* host_src = (float*)malloc(num * sizeof(float));
  

  for (int i = 0; i < num; i++) {
    //host_src[i] = i%4;
    host_src[i] = i;
  }

  float* mlu_destination;
  float* mlu_src;
  
  CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
  CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
  

  CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
  
  //----------------------------
  CNRT_CHECK(cnrtPlaceNotifier(start, queue));
  softmaxKernelAxis_m<<<dim, ktype, queue>>>(mlu_destination, mlu_src, frontsize, dimsize, stride);
  CNRT_CHECK(cnrtPlaceNotifier(end, queue));
  cnrtQueueSync(queue);

  //---------------------------
  CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
  for(int i = 0; i < 24; i++){
    printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
  }
  float timeTotal;
  CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
  printf("Total Time: %.3f ms\n", timeTotal / 1000.0);

  CNRT_CHECK(cnrtQueueDestroy(queue));

  cnrtFree(mlu_destination);
  cnrtFree(mlu_src);
  
  
  free(host_destination);
  free(host_src);
  

  return 0;
}
                           

Strategy 2:

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;//the tree reduction later requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); //at most maxNum floats fit in NRAM

__mlu_entry__ void softmaxKernelAxis_m(float* destination, float* source, int frontsize, int dimsize, int stride) {
  // 0<axis<dim -1 
  __nram__ float src[maxNum];
  __nram__ float tmpSum[maxNum];
  __nram__ float tmpNewMax[maxNum];
  __nram__ float tmpOldMax[maxNum];

  int remain = stride % taskDim;
  int stepEasy = (stride - remain)/taskDim;
  int stepHard = stepEasy + 1;
  int step = (taskId < remain ? stepHard : stepEasy);//the first remain taskIds each handle one extra element
  int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy);

  int remainNram = step % maxNum;
  int repeat = (step - remainNram) / maxNum;
  for(int ind = 0; ind < frontsize; ind ++){
    int frontIdx = ind * dimsize * stride;
    for(int j = 0; j < repeat; j++){
      __bang_write_value(tmpNewMax, maxNum, -INFINITY);
      __bang_write_zero(tmpSum, maxNum);
      __bang_write_zero(src, maxNum);
      for(int i = 0; i < dimsize; i++){
        __memcpy(src, source + frontIdx + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
        __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//keep updating the running max
        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
        if(i > 0){
          __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
          __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
          __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
        }
        __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
        __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
      }
      __bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
      //apply the exponential transform and write back to GDRAM
      __bang_mul(src, src, tmpSum, maxNum);//src still holds exp(x - M) from the last loop iteration, so reuse it
      __memcpy(destination + (dimsize - 1) * stride + frontIdx + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
      for(int i = 0; i < dimsize - 1; i++){
        __memcpy(src, source + frontIdx + i * stride + indStart + j * maxNum, maxNum * sizeof(float), GDRAM2NRAM);
        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
        __bang_mul(src, src, tmpSum, maxNum);
        __memcpy(destination + frontIdx + i * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
      } 
    }
    if(remainNram){
      __bang_write_value(tmpNewMax, maxNum, -INFINITY);
      __bang_write_zero(tmpSum, maxNum);
      __bang_write_zero(src, maxNum);
      for(int i = 0; i < dimsize; i++){
        __memcpy(src, source + frontIdx + i * stride + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);
        __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
        if(i > 0){
          __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
          __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
          __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);      //sum = sum * exp(oldM - newM)
        }
        __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
        __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
      }
      //-------------------
      __bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
      //apply the exponential transform and write back to GDRAM
      __bang_mul(src, src, tmpSum, maxNum);//src still holds exp(x - M) from the last loop iteration, so reuse it
      __memcpy(destination + (dimsize - 1) * stride + indStart + frontIdx + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
      for(int i = 0; i < dimsize - 1; i++){
        __memcpy(src, source + i * stride + frontIdx + indStart + repeat * maxNum, remainNram * sizeof(float), GDRAM2NRAM);//indStart must be included here, matching the write-back below
        __bang_sub(src, src, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
        __bang_mul(src, src, tmpSum, maxNum);
        __memcpy(destination + i * stride + indStart + frontIdx + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
      } 
      //---------------------
    }
  }
  
}


int main(void)
{
  int shape[4] = {1024,128,32,32};
  //int shape[4] = {1024,64,32,32};
  //int shape[4] = {1024,32,32,32};
  //int shape[4] = {2, 3, 2, 2};
  int axis = 1;
  int stride = 1;
  int dimsize = shape[axis];
  int num = 1;
  int othersize = 1;
  int frontsize = 1;
  for(int s = 3; s >= 0; s--){
    num *= shape[s];
    if(s > axis){
      stride *= shape[s];
    }
    if(s < axis){
      frontsize *= shape[s];
    }
    if(s != axis){
      othersize *= shape[s];
    }
  }
  
  printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, frontsize:%d, num:%d\n", axis, dimsize, stride, othersize, frontsize, num);
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtSetDevice(0));
  CNRT_CHECK(cnrtQueueCreate(&queue));

  cnrtDim3_t dim = {16, 1, 1};
  int taskNum = dim.x * dim.y * dim.z;
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION4;

  cnrtNotifier_t start, end;
  CNRT_CHECK(cnrtNotifierCreate(&start));
  CNRT_CHECK(cnrtNotifierCreate(&end));

  float* host_destination = (float*)malloc(num * sizeof(float));
  float* host_src = (float*)malloc(num * sizeof(float));
  

  for (int i = 0; i < num; i++) {
    //host_src[i] = i%4;
    host_src[i] = i;
  }

  float* mlu_destination;
  float* mlu_src;
  
  CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
  CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
  

  CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
  
  //----------------------------
  CNRT_CHECK(cnrtPlaceNotifier(start, queue));
  softmaxKernelAxis_m<<<dim, ktype, queue>>>(mlu_destination, mlu_src, frontsize, dimsize, stride);
  CNRT_CHECK(cnrtPlaceNotifier(end, queue));
  cnrtQueueSync(queue);

  //---------------------------
  CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
  for(int i = 0; i < 24; i++){
    printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
  }
  float timeTotal;
  CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
  printf("Total Time: %.3f ms\n", timeTotal / 1000.0);

  CNRT_CHECK(cnrtQueueDestroy(queue));

  cnrtFree(mlu_destination);
  cnrtFree(mlu_src);
  
  
  free(host_destination);
  free(host_src);
  

  return 0;
}
                           

Let us look at the speedup the above parallel strategies bring at different sizes. For axis=1 and axis=2 the measurements refer to strategy 1 only; strategy 2 performs too poorly to be worth showing.

Further optimization of high-dimensional softmax

axis = -1

From the measurements above we see that for axis=-1, even though the reads are contiguous, the kernel is still very slow. The main culprit is the massive waste of src memory: in the example above the last dimension has length 32, yet src is allocated maxNum×sizeof(float) bytes, so each copy brings only 32 floats from GDRAM into NRAM and the rest of the buffer is wasted. To use this memory fully, we take a different approach below.
The previous scheme essentially assigns othersize to taskIds with one otherIdx per src: src only ever holds the axis=-1 data of a single fixed otherIdx. To fill the memory, we instead let one src hold the axis=-1 data of several otherIdx values. Assume for now that maxNum is divisible by shape[-1] and that shape[-1] is a power of two, and let multiple = maxNum/shape[-1] = maxNum/dimsize; one src then holds the data of multiple otherIdx values. There are othersize vectors of length dimsize in total, one src holds multiple of them, and with taskDim tasks we can hold size = multiple×taskDim such vectors at once. For the discussion we introduce some variables:
multiple = maxNum/shape[-1] = maxNum/dimsize: how many length-dimsize vectors one src can hold
size = multiple×taskDim: how many length-dimsize vectors the taskDim tasks hold together
remainS = othersize % size: the non-divisible remainder, spread across taskIds, each receiving step extra vectors
taskRepeat = (othersize - remainS)/size: the number of rounds needed to load the othersize part
Overall, each taskId processes (taskRepeat×multiple + step)×dimsize elements, from which we can compute each taskId's offset. With the offsets in hand, the computation from a single taskId's viewpoint goes as follows:
First, loop for(int s = 0; s < taskRepeat; s++). Each iteration adds the offset tid = s×multiple×dimsize on top of the task's base offset and loads multiple×dimsize contiguous floats from GDRAM into src; an inner loop for(int j = 0; j < multiple; j++) then works purely on src, taking one length-dimsize row at a time through the sum and the exponential transform, and writes the result back to GDRAM.
After this double loop, the extra step rows are handled with a single loop for(int s = 0; s < step; s++), each iteration reading one length-dimsize row directly from source, computing, and writing back to GDRAM.
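A plain-C++ sketch of this partition, mirroring the arithmetic at the top of the kernel below (shape {1024,32,32,32} and axis=-1 assumed):

#include <cstdio>

int main() {
  const int maxNum = 1024 * 128 / (int)sizeof(float);  // 32768 floats fit in NRAM
  int taskDim = 4, taskId = 0;
  int othersize = 1024 * 32 * 32, dimsize = 32;
  int multiple = maxNum / dimsize;                 // length-dimsize rows per src load
  int size = taskDim * multiple;                   // rows covered per round by all tasks
  int remainS = othersize % size;                  // leftover rows, spread over taskIds
  int taskRepeat = (othersize - remainS) / size;   // full rounds per taskId
  int remainT = remainS % taskDim;
  int stepEasy = (remainS - remainT) / taskDim;
  int step = (taskId < remainT ? stepEasy + 1 : stepEasy);
  printf("multiple=%d size=%d taskRepeat=%d step=%d elems/task=%d\n",
         multiple, size, taskRepeat, step, (taskRepeat * multiple + step) * dimsize);
  return 0;
}

The optimized kernel: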

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;//the tree reduction later requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); //at most maxNum floats fit in NRAM
const int warpSize = 32;
//dimS must be at least dimsize, rounded up to the nearest power of two; since it is reduced below, dimS must also be at least 32
//this kernel only handles the case dimsize < maxNum
template<int dimS>
__mlu_entry__ void softmaxKernelAxis_e(float* destination, float* source, int othersize, int dimsize) {// axis = -1
  int multiple = maxNum / dimsize;
  int size = taskDim * multiple;
  int remainS = othersize % size;
  int taskRepeat = (othersize - remainS) / size;
  int remainT = remainS % taskDim;
  int stepEasy = (remainS - remainT) / taskDim;
  int stepHard = stepEasy + 1;
  int step = (taskId < remainT ? stepHard : stepEasy);
  //each taskId's share of othersize is taskRepeat * multiple + step rows
  //overall, each taskId processes (taskRepeat * multiple + step) * dimsize elements
  int startHard = taskId * (taskRepeat * multiple + stepHard);
  int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
  int indStart = (taskId < remainT ? startHard: startEasy);
  source = source + indStart * dimsize;
  destination = destination + indStart * dimsize;
  //printf("taskRepeat:%d, indstart:%d, %d\n", taskRepeat, indStart, indStart * dimsize);
  __nram__ float src[maxNum];

  __nram__ float tmp[dimS];
  __nram__ float destSum[dimS];//running sum
  __nram__ float destSumFinal[warpSize];//destSum is reduced into destSumFinal[0]
  __nram__ float srcMax[2];
  

  int tid;
  for(int s = 0; s < taskRepeat; s++){
    
    tid = s * multiple * dimsize;
    __memcpy(src, source + tid, multiple * dimsize * sizeof(float), GDRAM2NRAM);
    for(int j = 0; j < multiple; j++){
      __bang_write_zero(destSum, dimS);
      __bang_write_zero(destSumFinal, warpSize);
      __bang_write_value(tmp, dimS, -INFINITY);

      __memcpy(tmp, src + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);
      __bang_argmax(srcMax, tmp, dimS);
      __bang_write_value(tmp, dimS, srcMax[0]);//must be re-initialized to srcMax[0]
      __memcpy(tmp, src + j * dimsize, dimsize * sizeof(float), NRAM2NRAM);//must be re-read
      __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
      __bang_active_exp_less_0(tmp, tmp, dimS);//the padded dimS - dimsize lanes hold exp(0) = 1; they are subtracted after the reduction
      __bang_add(destSum, destSum, tmp, dimS);

      int segNum = dimS / warpSize;//start the sum reduction
      for(int strip = segNum/2; strip > 0; strip = strip / 2){
        for(int i = 0; i < strip ; i++){
          __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
        }
      }
      __bang_reduce_sum(destSumFinal, destSum, warpSize);//destSumFinal[0] now holds the sum of this dimsize-long row
      destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);//remove the padded ones
      //__bang_printf("max:%.2f, sum:%.2f\n", srcMax[0], destSumFinal[0]);
      float globalSumInv = 1.0/destSumFinal[0];
      __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);//length dimS, not maxNum: tmp only has dimS elements
      //__memcpy(src + j * dimsize, tmp, dimsize * sizeof(float), NRAM2NRAM);
      __memcpy(destination + tid + j * dimsize, tmp, dimsize * sizeof(float), NRAM2GDRAM);
    }//must write back to GDRAM immediately: staging results in src first risks modifying src before its copy to GDRAM completes
    //__memcpy(destination + tid, src, multiple * dimsize * sizeof(float), NRAM2GDRAM);
  }
  
  for(int s = 0; s < step; s++){
    tid = taskRepeat * multiple * dimsize + s * dimsize;
    __bang_write_zero(destSum, dimS);
    __bang_write_zero(destSumFinal, warpSize);
    __bang_write_value(tmp, dimS, -INFINITY);
    __memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
    
    __bang_argmax(srcMax, tmp, dimS);
    __bang_write_value(tmp, dimS, srcMax[0]);
    __memcpy(tmp, source + tid, dimsize * sizeof(float), GDRAM2NRAM);
    __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
    
    __bang_active_exp_less_0(tmp, tmp, dimS);//the trailing dimS - dimsize lanes become 1
    __bang_add(destSum, destSum, tmp, dimS);
    
    int segNum = dimS / warpSize;//start the sum reduction
    for(int strip = segNum/2; strip > 0; strip = strip / 2){
      for(int i = 0; i < strip ; i++){
        __bang_add(destSum + i * warpSize, destSum + i * warpSize, destSum + (i + strip) * warpSize, warpSize);
      }
    }
    __bang_reduce_sum(destSumFinal, destSum, warpSize);//destSumFinal[0] now holds the sum of this dimsize-long row
    destSumFinal[0] = destSumFinal[0] - (dimS - dimsize);//remove the padded ones
    //__bang_printf(":%.2f,max:%.2f, sum:%.2f, final:%.2f\n",tmp[1], srcMax[0], destSum[1], destSumFinal[0]);
    float globalSumInv = 1.0/destSumFinal[0];
    __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);//length dimS, not maxNum: tmp only has dimS elements
    __memcpy(destination + tid, tmp, dimsize * sizeof(float), NRAM2GDRAM);
  }
  //__bang_printf("max:%.2f, sum:%.2f\n", srcMax[0], destSumFinal[0]);
  
}


int main(void)
{
  //int shape[4] = {1024,128,32,32};
  //int shape[4] = {1024,64,32,32};
  int shape[4] = {1024,32,32,32};
  //int shape[4] = {2, 3, 2, 2};
  int axis = 3;
  int stride = 1;
  int dimsize = shape[axis];
  int num = 1;
  int othersize = 1;
  for(int s = 3; s >= 0; s--){
    num *= shape[s];
    if(s > axis){
      stride *= shape[s];
    }
    if(s != axis){
      othersize *= shape[s];
    }
  }
  
  printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, num:%d\n", axis, dimsize, stride, othersize, num);
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtSetDevice(0));
  CNRT_CHECK(cnrtQueueCreate(&queue));

  cnrtDim3_t dim = {4, 1, 1};
  int taskNum = dim.x * dim.y * dim.z;
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;

  cnrtNotifier_t start, end;
  CNRT_CHECK(cnrtNotifierCreate(&start));
  CNRT_CHECK(cnrtNotifierCreate(&end));

  float* host_destination = (float*)malloc(num * sizeof(float));
  float* host_src = (float*)malloc(num * sizeof(float));
  

  for (int i = 0; i < num; i++) {
    host_src[i] = i%4;
    //host_src[i] = i;
  }

  float* mlu_destination;
  float* mlu_src;
  
  CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
  CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
  

  CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
  
  //----------------------------
  CNRT_CHECK(cnrtPlaceNotifier(start, queue));
  softmaxKernelAxis_e<32><<<dim, ktype, queue>>>(mlu_destination, mlu_src, othersize, dimsize);
  CNRT_CHECK(cnrtPlaceNotifier(end, queue));
  cnrtQueueSync(queue);

  //---------------------------
  CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
  for(int i = 0; i < 24; i++){
    printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
  }
  float timeTotal;
  CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
  printf("Total Time: %.3f ms\n", timeTotal / 1000.0);

  CNRT_CHECK(cnrtQueueDestroy(queue));

  cnrtFree(mlu_destination);
  cnrtFree(mlu_src);
  
  
  free(host_destination);
  free(host_src);
  

  return 0;
}
                           


0 < axis < dim - 1

This case is more particular still. From the analysis above, when axis is a middle dimension, e.g. a [A,B,C,D] tensor with axis=1 and index i(BCD)+j(CD)+k(D)+s, we split the index into three parts: i(BCD) is the frontIdx, k(D)+s corresponds to the behindsize part of length CD (and recall behindsize = stride), and j(CD) is the middle part. As analyzed above, for a fixed frontIdx the behindsize block is contiguous in memory, so we can loop for(j = 0; j < B; j++), each iteration reading the data [frontIdx + j×CD : frontIdx + (j+1)×CD], obtaining B vectors of length CD; comparing these B vectors elementwise yields a length-CD vector tmpNewMax whose entries hold, for this frontIdx, the maxima over (k,s).
As with axis=-1, if behindsize is far smaller than maxNum, src again wastes a lot of memory, so here too we want src to load as much data as possible.
We need to weigh maxNum against BCD. With axis=1, if BCD is about the size of maxNum, we would like src to load a full length-BCD vector at once; src then holds, for a fixed frontIdx, the data of all (k,s). We then loop over src with for(j = 0; j < B; j++), each iteration reading a length-CD chunk, continuously updating the maxima, and finally writing back to GDRAM. This approach suits the case where axis is relatively near the front, with CD smaller than maxNum and BCD smaller than but close to maxNum, because when axis is near the front, dimsize×stride has a better chance of exceeding maxNum. If stride is smaller than maxNum but dimsize×stride is larger, we must additionally split dimsize; see the code for the details.
If axis is relatively near the back, even dimsize×stride is far smaller than maxNum, so even loading a length-dimsize×stride block at once would still waste most of src. In that case we want src to hold several length-dimsize×stride blocks, filling its memory as much as possible (in the extreme, for the 4-D tensor [A,B,C,D] with axis=2, if D, CD, BCD, and even ABCD are all far smaller than maxNum, simply let src load everything at once). This requires an extra NRAM vector of length dimsize×stride that repeatedly reads its data from src, with the same computation loop as the original scheme, except that the length-dimsize×stride blocks are now read from the NRAM buffer src instead of from GDRAM.

axis near the front

Here stride < maxNum but dimsize×stride >= maxNum, so we let src load multiple×stride elements at a time, where multiple = maxNum/stride. The tiling arithmetic is sketched below, followed by the code.
One thing to note in this code: after computing frontIdx, never write source = source + frontIdx; apply the offset inside each copy instead, otherwise memory stomping occurs (the root cause of the stomping is still being investigated).
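The tiling arithmetic, as a plain-C++ sketch (it mirrors the top of the kernel below; shape {1024,32,32,32} with axis=1 assumed, so stride=32×32=1024):

#include <cstdio>

int main() {
  const int maxNum = 1024 * 128 / (int)sizeof(float);  // 32768 floats fit in NRAM
  int dimsize = 32, stride = 32 * 32;
  int multiple = maxNum / stride;              // rows of length stride per src load
  int size = multiple * stride;                // floats per src load
  int remain = dimsize % multiple;             // leftover rows, handled separately
  int repeat = (dimsize - remain) / multiple;  // full loads to cover dimsize
  printf("multiple=%d size=%d repeat=%d remain=%d\n", multiple, size, repeat, remain);
  return 0;
}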

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;//the tree reduction later requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); //at most maxNum floats fit in NRAM

//strideS is the smallest power of two greater than or equal to stride
template<int strideS>
__mlu_entry__ void softmaxKernelAxis_m(float* destination, float* source, int frontsize, int dimsize, int stride) {
  // 0<axis<dim -1 
  __nram__ float src[maxNum];
  __nram__ float tmp[strideS];
  __nram__ float tmpOldMax[strideS];
  __nram__ float tmpNewMax[strideS];
  __nram__ float tmpSum[strideS];
  if(dimsize * stride >= maxNum){
    int multiple = maxNum / stride;
    int size = multiple * stride;//the most data one src load can carry
    int remain = dimsize % multiple;//the non-divisible rows need special handling
    int repeat = (dimsize - remain) / multiple;//full loads needed to cover all of dimsize
    printf("maxNum:%d, dimsize * stride:%d, multiple:%d, size:%d, repeat:%d,remain:%d\n",maxNum, dimsize * stride, multiple, size, repeat,remain);
    for(int ind = taskId; ind < frontsize; ind += taskDim){
      int frontIdx = ind * dimsize * stride;

      __bang_write_value(tmpNewMax, strideS, -INFINITY);//must be initialized to -INFINITY
      __bang_write_value(tmp, strideS, -INFINITY);//must be initialized to -INFINITY
      __bang_write_zero(tmpSum, strideS);//must be initialized to 0
      
      for(int j = 0; j < repeat; j++){
        __memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
        for(int m = 0; m < multiple; m++){
          __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
          
          __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//the trailing strideS - stride lanes of tmpNewMax are never written back to GDRAM, so they do not affect the result
          
          __bang_sub(tmp, tmp, tmpNewMax, strideS);//the trailing strideS - stride lanes of tmp become 0
          __bang_active_exp_less_0(tmp, tmp, strideS);
          if(j != 0 || m != 0){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
          //if(m == 0) __bang_printf("tmp:%.2f, tmpMax[0]:%.2f,tmpSum[0]:%.2f\n", tmp[1], tmpNewMax[1],tmpSum[0]);
          __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
        }
      }
      //__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[0],tmpSum[0]);
      if(remain){
        __memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(float), GDRAM2NRAM);
        for(int m = 0; m < remain; m++){
          __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
          __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);//the trailing strideS - stride lanes of tmp become 0
          __bang_active_exp_less_0(tmp, tmp, strideS);
          if(repeat != 0 || m != 0){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
          __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
        }
      }
      
      //tmpNewMax now holds, for this fixed frontIdx, the maxima of the behindsize data, and tmpSum the corresponding sums
      //__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
      __bang_active_recip(tmpSum, tmpSum, strideS);
      //__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
      if(remain){
        for(int m = 0; m < remain; m++){
          __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);
          __bang_active_exp_less_0(tmp, tmp, strideS);
          __bang_mul(tmp, tmp, tmpSum, strideS);
          __memcpy(destination + frontIdx + repeat * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
        }
        
      }
      for(int j = 0 ; j < repeat; j++){
        __memcpy(src, source + frontIdx + j * multiple * stride, size * sizeof(float), GDRAM2NRAM);
        for(int m = 0; m < multiple; m++){
          __memcpy(tmp, src + m * stride, stride * sizeof(float), NRAM2NRAM);
          
          __bang_sub(tmp, tmp, tmpNewMax, strideS);
          __bang_active_exp_less_0(tmp, tmp, strideS);
          __bang_mul(tmp, tmp, tmpSum, strideS);
          __memcpy(destination + frontIdx + j * multiple * stride + m * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
        }
      }
    }

  }
  
  
}


int main(void)
{
  //int shape[4] = {1024,128,32,32};
  //int shape[4] = {1024,64,32,32};
  int shape[4] = {1024,32,32,32};
  //int shape[4] = {2, 3, 2, 2};
  int axis = 1;
  int stride = 1;
  int dimsize = shape[axis];
  int num = 1;
  int othersize = 1;
  int frontsize = 1;
  for(int s = 3; s >= 0; s--){
    num *= shape[s];
    if(s > axis){
      stride *= shape[s];
    }
    if(s < axis){
      frontsize *= shape[s];
    }
    if(s != axis){
      othersize *= shape[s];
    }
  }
  
  printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, frontsize:%d, num:%d\n", axis, dimsize, stride, othersize, frontsize, num);
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtSetDevice(0));
  CNRT_CHECK(cnrtQueueCreate(&queue));

  cnrtDim3_t dim = {4, 1, 1};
  int taskNum = dim.x * dim.y * dim.z;
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;

  cnrtNotifier_t start, end;
  CNRT_CHECK(cnrtNotifierCreate(&start));
  CNRT_CHECK(cnrtNotifierCreate(&end));

  float* host_destination = (float*)malloc(num * sizeof(float));
  float* host_src = (float*)malloc(num * sizeof(float));
  

  for (int i = 0; i < num; i++) {
    //host_src[i] = i%4;
    host_src[i] = i;
  }

  float* mlu_destination;
  float* mlu_src;
  
  CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
  CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
  

  CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
  
  //----------------------------
  CNRT_CHECK(cnrtPlaceNotifier(start, queue));
  softmaxKernelAxis_m<1024><<<dim, ktype, queue>>>(mlu_destination, mlu_src, frontsize, dimsize, stride);
  CNRT_CHECK(cnrtPlaceNotifier(end, queue));
  cnrtQueueSync(queue);

  //---------------------------
  CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
  for(int i = 0; i < 24; i++){
    printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
  }
  float timeTotal;
  CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
  printf("Total Time: %.3f ms\n", timeTotal / 1000.0);

  CNRT_CHECK(cnrtQueueDestroy(queue));

  cnrtFree(mlu_destination);
  cnrtFree(mlu_src);
  
  
  free(host_destination);
  free(host_src);
  

  return 0;
}
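
The result can be sanity-checked against a plain CPU implementation. Below is a minimal sketch (a hypothetical helper, not part of the original program) that computes softmax along axis using the same frontsize/dimsize/stride decomposition as the kernel; comparing it elementwise with host_destination should give differences on the order of float rounding error:

#include <math.h>

//Hypothetical CPU reference: softmax along axis, using the same
//frontsize/dimsize/stride decomposition as the kernel above.
void softmaxCpu(float* dst, const float* src, int frontsize, int dimsize, int stride) {
  for (int f = 0; f < frontsize; f++) {
    for (int s = 0; s < stride; s++) {
      const float* in = src + f * dimsize * stride + s;
      float* out = dst + f * dimsize * stride + s;
      float M = -INFINITY;
      for (int i = 0; i < dimsize; i++) {
        if (in[i * stride] > M) M = in[i * stride];
      }
      float sum = 0.0f;
      for (int i = 0; i < dimsize; i++) {
        sum += expf(in[i * stride] - M);
      }
      for (int i = 0; i < dimsize; i++) {
        out[i * stride] = expf(in[i * stride] - M) / sum;
      }
    }
  }
}

Calling softmaxCpu into a second host buffer and taking the maximum absolute difference against host_destination over all num elements makes a convenient acceptance test.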
                           

axis near the end

In this case not only is stride < maxNum, but dimsize × stride < maxNum as well, so we simply define behindsize = dimsize × stride and let src load multiple × behindsize elements at a time, where multiple = maxNum / behindsize. For the shape {1024,32,32,32} with axis = 2 used below, behindsize = 32 × 32 = 1024 and maxNum = 32768, so multiple = 32. The code is as follows:
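
To make the tiling arithmetic concrete before the full kernel, here is a small host-side sketch (plain C; the values assume the NRAM_MAX_SIZE, shape and taskDim used below) that reproduces the constants the kernel derives on-device, including the balanced split of leftover row groups across tasks:

#include <stdio.h>

//Hypothetical host-side rehearsal of the on-device tiling arithmetic,
//for shape = {1024, 32, 32, 32}, axis = 2, taskDim = 4.
int main(void) {
  const int maxNum = 1024 * 128 / sizeof(float); //32768 floats of NRAM
  int frontsize = 1024 * 32, dimsize = 32, stride = 32, taskDim = 4;
  int behindsize = dimsize * stride;             //1024 elements per row group
  int multiple = maxNum / behindsize;            //32 row groups per NRAM load
  int remainF = frontsize % (taskDim * multiple);//row groups left after the full loads
  int remainT = remainF % taskDim;
  int stepEasy = (remainF - remainT) / taskDim;
  int stepHard = stepEasy + 1;
  int taskRepeat = (frontsize - remainF) / (taskDim * multiple);
  for (int taskId = 0; taskId < taskDim; taskId++) {
    int step = (taskId < remainT ? stepHard : stepEasy);
    printf("taskId %d: %d full loads of %d row groups, %d leftover row groups\n",
           taskId, taskRepeat, multiple, step);
  }
  return 0;//for this shape remainF = 0, so every task gets 256 full loads and no leftovers
}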

#include <bang.h>
#include <bang_device_functions.h>
#define EPS 1e-7
const int NRAM_MAX_SIZE = 1024 * 128;//the later tree-style reduction requires NRAM_MAX_SIZE to be a power of two
const int maxNum = NRAM_MAX_SIZE/sizeof(float); //NRAM can hold at most maxNum float elements

//strideS is the smallest power of two greater than or equal to stride
template<int strideS>
__mlu_entry__ void softmaxKernelAxis_m(float* destination, float* source, int frontsize, int dimsize, int stride) {
  // 0<axis<dim -1 
  __nram__ float src[maxNum];
  __nram__ float tmp[strideS];
  __nram__ float tmpOldMax[strideS];
  __nram__ float tmpNewMax[strideS];
  __nram__ float tmpSum[strideS];
  if(dimsize * stride < maxNum){
    int behindsize = dimsize * stride;
    int multiple = maxNum / behindsize;//how many behindsize row groups of frontsize one NRAM load can cover
    int size = multiple * behindsize;//number of elements one src load handles per taskId
    int remainF = frontsize % (taskDim * multiple);
    int remainT = remainF % taskDim;
    int stepEasy = (remainF - remainT) / taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < remainT ? stepHard : stepEasy);
    int taskRepeat = (frontsize - remainF) / (taskDim * multiple);
    //with respect to frontsize, each taskId processes taskRepeat * multiple + step row groups
    int startHard = taskId * (taskRepeat * multiple + stepHard);
    int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
    int indStart = (taskId < remainT ? startHard: startEasy);
    source = source + indStart * behindsize;//indStart * behindsize is the global offset for this taskId
    destination = destination + indStart * behindsize;
    int tid;
    for(int s = 0; s < taskRepeat; s++){
      tid = s * multiple * behindsize;
      __memcpy(src, source + tid, multiple * behindsize * sizeof(float), GDRAM2NRAM);
      for(int m = 0; m < multiple; m++){
        __bang_write_zero(tmpSum, strideS);
        __bang_write_value(tmp, strideS, -INFINITY);
        __bang_write_value(tmpNewMax, strideS, -INFINITY);
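        //first pass: build the running max (tmpNewMax) and running sum of exponentials (tmpSum) across the dimsize slices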
        for(int i = 0; i < dimsize; i++){
          __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
          __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
          __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
          if(i > 0){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);      //sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
          __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
        }
        __bang_active_recip(tmpSum, tmpSum, strideS);
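        //second pass: scale every slice in place in NRAM by 1/sum, then write the whole row group back in one transfer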
        __bang_mul(tmp, tmp, tmpSum, strideS);//tmp still holds exp(x - M) for the last slice (i = dimsize - 1) from the loop above, so it can be reused here
        //__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
        __memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
        for(int i = 0; i < dimsize - 1; i++){
          __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
          __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
          __bang_mul(tmp, tmp, tmpSum, strideS);
          //__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
          __memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
        }
      }
      __memcpy(destination + tid, src, multiple * behindsize * sizeof(float), NRAM2GDRAM);
    }
    //__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n", taskId, multiple, taskRepeat, step, indStart * behindsize);//leftover debug print, disabled
    if(step){
      tid = taskRepeat * multiple * behindsize; 
      __memcpy(src, source + tid, step * behindsize * sizeof(float), GDRAM2NRAM);
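      //the leftover step row groups repeat the same two-pass scheme as above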
      for(int m = 0; m < step; m++){
        __bang_write_zero(tmpSum, strideS);
        __bang_write_value(tmp, strideS, -INFINITY);
        __bang_write_value(tmpNewMax, strideS, -INFINITY);
        for(int i = 0; i < dimsize; i++){
          __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
          __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
          __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
          if(i > 0){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);      //sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
          __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
        }
        //__bang_printf("max:%.2f,%.2f, sum:%.2f,sum:%.2f\n", tmpNewMax[0], tmpNewMax[1], tmpSum[0], tmpSum[0]);
        __bang_active_recip(tmpSum, tmpSum, strideS);
        __bang_mul(tmp, tmp, tmpSum, strideS);//tmp still holds exp(x - M) for the last slice (i = dimsize - 1) from the loop above, so it can be reused here
        //__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
        __memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
        for(int i = 0; i < dimsize - 1; i++){
          __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(float), NRAM2NRAM);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
          __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
          __bang_mul(tmp, tmp, tmpSum, strideS);
          //__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
          __memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(float), NRAM2NRAM);
        }
      }
      __memcpy(destination + tid, src, step * behindsize * sizeof(float), NRAM2GDRAM);
    }
  }
  
  
}


int main(void)
{
  //int shape[4] = {1024,128,32,32};
  //int shape[4] = {1024,64,32,32};
  int shape[4] = {1024,32,32,32};
  //int shape[4] = {2, 3, 2, 2};
  int axis = 2;
  int stride = 1;
  int dimsize = shape[axis];
  int num = 1;
  int othersize = 1;
  int frontsize = 1;
  //stride = product of dims after axis; frontsize = product of dims before axis; othersize = frontsize * stride
  for(int s = 3; s >= 0; s--){
    num *= shape[s];
    if(s > axis){
      stride *= shape[s];
    }
    if(s < axis){
      frontsize *= shape[s];
    }
    if(s != axis){
      othersize *= shape[s];
    }
  }
  
  printf("axis:%d, dimsize:%d, stride:%d, othersize:%d, frontsize:%d, num:%d\n", axis, dimsize, stride, othersize, frontsize, num);
  cnrtQueue_t queue;
  CNRT_CHECK(cnrtSetDevice(0));
  CNRT_CHECK(cnrtQueueCreate(&queue));

  cnrtDim3_t dim = {4, 1, 1};
  int taskNum = dim.x * dim.y * dim.z;
  cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION1;

  cnrtNotifier_t start, end;
  CNRT_CHECK(cnrtNotifierCreate(&start));
  CNRT_CHECK(cnrtNotifierCreate(&end));

  float* host_destination = (float*)malloc(num * sizeof(float));
  float* host_src = (float*)malloc(num * sizeof(float));
  

  for (int i = 0; i < num; i++) {
    //host_src[i] = i%4;
    host_src[i] = i;
  }

  float* mlu_destination;
  float* mlu_src;
  
  CNRT_CHECK(cnrtMalloc((void**)&mlu_destination, num * sizeof(float)));
  CNRT_CHECK(cnrtMalloc((void**)&mlu_src, num * sizeof(float)));
  

  CNRT_CHECK(cnrtMemcpy(mlu_src, host_src, num * sizeof(float), cnrtMemcpyHostToDev));
  
  //----------------------------
  CNRT_CHECK(cnrtPlaceNotifier(start, queue));
  softmaxKernelAxis_m<32><<<dim, ktype, queue>>>(mlu_destination, mlu_src, frontsize, dimsize, stride);//stride = 32 here, so strideS = 32, the smallest power of two >= stride
  CNRT_CHECK(cnrtPlaceNotifier(end, queue));
  cnrtQueueSync(queue);

  //---------------------------
  CNRT_CHECK(cnrtMemcpy(host_destination, mlu_destination, num * sizeof(float), cnrtMemcpyDevToHost));
  for(int i = 0; i < 24; i++){
    printf("softmax[%d]:%.6e,origin:%.6f\n", i, host_destination[i], host_src[i]);
  }
  float timeTotal;
  CNRT_CHECK(cnrtNotifierDuration(start, end, &timeTotal));
  printf("Total Time: %.3f ms\n", timeTotal / 1000.0);

  CNRT_CHECK(cnrtQueueDestroy(queue));

  cnrtFree(mlu_destination);
  cnrtFree(mlu_src);
  
  
  free(host_destination);
  free(host_src);
  

  return 0;
}
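
One caveat with the launches above: strideS is a template parameter, so it has to be fixed at compile time, and by the kernel's convention it should be the smallest power of two that is >= stride; the mains above simply hardcode the matching instantiation. A host-side dispatcher could pick it from the runtime stride. The sketch below uses a hypothetical helper name and covers only a few instantiations:

//Hypothetical dispatch helper: choose the instantiation whose strideS is the
//smallest covered power of two >= stride.
static void launchSoftmaxAxisM(float* dst, float* src, int frontsize, int dimsize, int stride,
                               cnrtDim3_t dim, cnrtFunctionType_t ktype, cnrtQueue_t queue) {
  if (stride <= 32) {
    softmaxKernelAxis_m<32><<<dim, ktype, queue>>>(dst, src, frontsize, dimsize, stride);
  } else if (stride <= 128) {
    softmaxKernelAxis_m<128><<<dim, ktype, queue>>>(dst, src, frontsize, dimsize, stride);
  } else if (stride <= 1024) {
    softmaxKernelAxis_m<1024><<<dim, ktype, queue>>>(dst, src, frontsize, dimsize, stride);
  }
  //strides above 1024 would need further instantiations, or a separate kernel path
}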
                           

All of the runs below use taskDim = 4 with task type Union1:

[figure: measured results]
