1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3#ifdef cl_intel_subgroups
4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
5#else
6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
7#endif
8
9#ifdef cl_intel_required_subgroup_size
10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
11#define INTEL_GPU 1
12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
14#elif defined(cl_qcom_reqd_sub_group_size)
15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
16#define ADRENO_GPU 1
17#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
19#endif
20
// Number of per-subgroup partial results staged in local memory per round of
// the cross-subgroup combine in soft_max_wg_reduce().
#define SOFT_MAX_SCRATCH_SIZE 32

// Reduce `v` over the whole work-group: the maximum when `is_max` is true,
// otherwise the sum.
//
// sub_group_reduce_* on its own only covers one subgroup. That is enough only
// when the work-group is a single subgroup. The callers' element loops stride
// by get_local_size(0), so with several subgroups the partials must be combined.
// When the work-group is a single subgroup, this returns the plain subgroup
// reduction with no barriers.
//
// Every work-item of the work-group must call this with the same `is_max`,
// because it may execute barrier(). `scratch` must point to
// SOFT_MAX_SCRATCH_SIZE floats of local memory.
float soft_max_wg_reduce(float v, bool is_max, local float * scratch) {
    v = is_max ? sub_group_reduce_max(v) : sub_group_reduce_add(v);

    const uint n_sg = get_num_sub_groups();
    if (n_sg == 1) {
        // Nothing left to combine. n_sg is uniform across the work-group, so
        // every work-item skips the barriers below together.
        return v;
    }

    const uint sg_id   = get_sub_group_id();
    const bool sg_lead = get_sub_group_local_id() == 0;

    float r = is_max ? -INFINITY : 0.0f;
    // Stage the partials in rounds so that any number of subgroups fits into
    // the fixed-size scratch buffer.
    for (uint base = 0; base < n_sg; base += SOFT_MAX_SCRATCH_SIZE) {
        if (sg_lead && sg_id >= base && sg_id - base < SOFT_MAX_SCRATCH_SIZE) {
            scratch[sg_id - base] = v;
        }
        barrier(CLK_LOCAL_MEM_FENCE);

        const uint n = min(n_sg - base, (uint) SOFT_MAX_SCRATCH_SIZE);
        for (uint i = 0; i < n; ++i) {
            r = is_max ? fmax(r, scratch[i]) : r + scratch[i];
        }
        // All reads must finish before the next round, or a later call,
        // overwrites scratch.
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    return r;
}

// Row-wise softmax over src0 (float) with an optional half-precision additive
// mask (src1) and an optional per-head extra logit (src2).
//
// One work-group handles one row: group ids (0, 1, 2) select (i01, i02, i03).
// The work-items of the group stride over the ne00 elements of that row.
// For each element:
//   dst[i] = exp(x_i - M) / (sum_j exp(x_j - M) [+ exp(src2[i02] - M)])
//   x_i    = src0[i]*scale + slope*mask[i]     (mask term only if a mask is given)
//   M      = max over x_j [and src2[i02]]
// src2[i02] takes part in the max and in the denominator but gets no output
// element (presumably attention "sinks" — confirm with the host code).
//
// src0/src1/src2/dst  byte buffers; offset0/offset1/offset2/offsetd are byte
//                     offsets applied first.
// Mask / src2 absent  signalled by passing the same buffer and offset as src0
//                     (checked via src1 != src0 and src2 != src0).
// nb01..nb03, nb11..nb13, nb1..nb3
//                     byte strides of src0, mask and dst.
// ne12, ne13          mask extents in dims 2 and 3. The mask is broadcast via
//                     i02 % ne12 and i03 % ne13; the mask row index is i01.
// scale               multiplier applied to src0.
// max_bias, m0, m1, n_head_log2
//                     ALiBi parameters. When max_bias > 0 the mask is scaled by
//                     a per-head slope (head = i02).
//
// Any work-group size is supported (see soft_max_wg_reduce). The kernel writes
// exp(x_i - M) into dst first and normalises it in a second pass. Each element
// is read and written by the same work-item, so dst may alias src0.
// NOTE(review): a row whose logits are all -INFINITY (and no src2) yields
// NaN, because M = -INFINITY.
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_f16(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * src2,
        ulong offset2,
        global char * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne12,
        int ne13,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        float scale,
        float max_bias,
        float m0,
        float m1,
        int n_head_log2
) {
    // Scratch for the cross-subgroup combine. Local arrays must be declared at
    // kernel function scope.
    local float wg_scratch[SOFT_MAX_SCRATCH_SIZE];

    src0 = src0 + offset0;
    src1 = src1 + offset1;
    src2 = src2 + offset2;
    dst  = dst  + offsetd;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0);

    // Mask broadcasting over dims 2 and 3; the mask row follows the src0 row.
    const int i13 = i03 % ne13;
    const int i12 = i02 % ne12;
    const int i11 = i01;

    global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
    global half  * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
    global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
    global float * pdst  = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

    // ALiBi slope for head i02; stays 1 when ALiBi is disabled.
    float slope = 1.0f;
    if (max_bias > 0.0f) {
        const int   h         = i02;
        const float base      = h < n_head_log2 ? m0 : m1;
        const int   alibi_exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

        slope = pow(base, (float) alibi_exp);
    }

    // Row maximum, seeded with the optional extra logit.
    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
    }
    const float row_max = soft_max_wg_reduce(lmax, true, wg_scratch);

    // Exponentials and their sum. Each exp is stored in dst right away so it
    // is not recomputed in the normalisation pass.
    float lsum = 0.0f;
    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        const float e = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - row_max);
        lsum += e;
        pdst[i00] = e;
    }
    float row_sum = soft_max_wg_reduce(lsum, false, wg_scratch);

    // The extra logit contributes to the denominator only.
    if (psrc2) {
        row_sum += exp(psrc2[i02] - row_max);
    }

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        pdst[i00] /= row_sum;
    }
}