// Enable the half-precision extension. Nothing in the kernel visible in this
// file section uses half types; the pragma is kept as-is.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Enable sub-group built-ins: the Intel extension when the compiler
// advertises it, otherwise the Khronos one.
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif

// Vendor detection, keyed on which "required sub-group size" extension the
// compiler advertises. Each branch defines a vendor flag (INTEL_GPU or
// ADRENO_GPU) plus REQD_SUBGROUP_SIZE_<N> macros that expand to the vendor's
// kernel attribute. When neither extension is present, none of these macros
// are defined.
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
// NOTE(review): the macro names imply "half" selects 64 work-items and "full"
// selects 128 -- confirm against the Qualcomm extension documentation.
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
 20
 21#ifdef ADRENO_GPU
 22REQD_SUBGROUP_SIZE_64
 23#endif
 24kernel void kernel_soft_max_4(
 25        global char * src0,
 26        ulong offset0,
 27        global char * src1,
 28        ulong offset1,
 29        global char * src2,
 30        ulong offset2,
 31        global char * dst,
 32        ulong offsetd,
 33        int ne00,
 34        ulong nb01,
 35        ulong nb02,
 36        ulong nb03,
 37        int ne12,
 38        int ne13,
 39        ulong nb11,
 40        ulong nb12,
 41        ulong nb13,
 42        ulong nb1,
 43        ulong nb2,
 44        ulong nb3,
 45        float scale,
 46        float max_bias,
 47        float m0,
 48        float m1,
 49        int n_head_log2
 50) {
 51    src0 = src0 + offset0;
 52    src1 = src1 + offset1;
 53    src2 = src2 + offset2;
 54    dst  = dst  + offsetd;
 55
 56    int i03 = get_group_id(2);
 57    int i02 = get_group_id(1);
 58    int i01 = get_group_id(0);
 59
 60    int i13 = i03%ne13;
 61    int i12 = i02%ne12;
 62    int i11 = i01;
 63
 64    global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
 65    global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
 66    global float  * psrc2 = src2 != src0 ? (global float  *)(src2) : 0;
 67    global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 68
 69    float slope = 1.0f;
 70
 71    // ALiBi
 72    if (max_bias > 0.0f) {
 73        int h = i02;
 74
 75        float base = h < n_head_log2 ? m0 : m1;
 76        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 77
 78        slope = pow(base, exp);
 79    }
 80
 81    // parallel max
 82    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
 83    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
 84        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
 85    }
 86    float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));
 87
 88    const float max = sub_group_reduce_max(lmax);
 89
 90    // parallel sum
 91    float4 lsum4 = 0.0f;
 92    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
 93        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
 94        lsum4 += exp_psrc4;
 95        pdst4[i00] = exp_psrc4;
 96    }
 97    float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
 98
 99    float sum = sub_group_reduce_add(lsum);
100
101    if (psrc2) {
102        sum += exp(psrc2[i02] - max);
103    }
104
105    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
106        pdst4[i00] /= sum;
107    }
108}