1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3#ifdef cl_intel_subgroups
  4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
  5#else
  6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
  7#endif
  8
  9#ifdef cl_intel_required_subgroup_size
 10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
 11#define INTEL_GPU 1
 12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
 13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
 14#elif defined(cl_qcom_reqd_sub_group_size)
 15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
 16#define ADRENO_GPU 1
 17#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
 18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
 19#endif
 20
 21#ifdef ADRENO_GPU
 22REQD_SUBGROUP_SIZE_64
 23#endif
 24kernel void kernel_soft_max_f16(
 25        global char * src0,
 26        ulong offset0,
 27        global char * src1,
 28        ulong offset1,
 29        global char * src2,
 30        ulong offset2,
 31        global char * dst,
 32        ulong offsetd,
 33        int ne00,
 34        ulong nb01,
 35        ulong nb02,
 36        ulong nb03,
 37        int ne12,
 38        int ne13,
 39        ulong nb11,
 40        ulong nb12,
 41        ulong nb13,
 42        ulong nb1,
 43        ulong nb2,
 44        ulong nb3,
 45        float scale,
 46        float max_bias,
 47        float m0,
 48        float m1,
 49        int n_head_log2
 50) {
 51    src0 = src0 + offset0;
 52    src1 = src1 + offset1;
 53    src2 = src2 + offset2;
 54    dst  = dst  + offsetd;
 55
 56    int i03 = get_group_id(2);
 57    int i02 = get_group_id(1);
 58    int i01 = get_group_id(0);
 59
 60    int i13 = i03%ne13;
 61    int i12 = i02%ne12;
 62    int i11 = i01;
 63
 64    global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
 65    global half  * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
 66    global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
 67    global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 68
 69    float slope = 1.0f;
 70
 71    // ALiBi
 72    if (max_bias > 0.0f) {
 73        int h = i02;
 74
 75        float base = h < n_head_log2 ? m0 : m1;
 76        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 77
 78        slope = pow(base, exp);
 79    }
 80
 81    // parallel max
 82    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
 83    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
 84        lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
 85    }
 86    float max = sub_group_reduce_max(lmax);
 87
 88    // parallel sum
 89    float lsum = 0.0f;
 90    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
 91        float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
 92        lsum += exp_psrc0;
 93        // Remember the result of exp here. exp is expensive, so we really do not
 94        // wish to compute it twice.
 95        pdst[i00] = exp_psrc0;
 96    }
 97
 98    float sum = sub_group_reduce_add(lsum);
 99
100    if (psrc2) {
101        sum += exp(psrc2[i02] - max);
102    }
103
104    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
105        pdst[i00] /= sum;
106    }
107}