1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3#ifdef cl_intel_subgroups
4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
5#else
6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
7#endif
8
9#ifdef cl_intel_required_subgroup_size
10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
11#define INTEL_GPU 1
12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
14#elif defined(cl_qcom_reqd_sub_group_size)
15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
16#define ADRENO_GPU 1
17#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
19#endif
20
// Row-wise softmax on float4 vectors, with optional mask, ALiBi slope and
// per-head "sink" logits:
//
//   x[i]   = src0[i]*scale + slope*mask[i]          (mask term only if present)
//   m      = max(sink, max_i x[i])                  (sink term only if present)
//   dst[i] = exp(x[i] - m) / (sum_i exp(x[i] - m) + exp(sink - m))
//
// Launch layout: one work-group per row. get_group_id(0/1/2) select
// i01/i02/i03, and the work-items of the group stride over the row.
//
// Parameters:
//   src0/offset0, nb01/nb02/nb03 : input logits; byte strides of dims 1..3.
//   src1/offset1, nb11/nb12/nb13 : optional float mask. It is treated as absent
//                                  when src1+offset1 == src0+offset0. The mask
//                                  is broadcast over dims 2/3 (i02 % ne12, i03 % ne13).
//   src2/offset2                 : optional per-head sink values indexed by i02.
//                                  Absent when src2+offset2 == src0+offset0. A sink
//                                  enters the max and the denominator, but it is
//                                  never written to dst.
//   dst/offsetd, nb1/nb2/nb3     : output; byte strides of dims 1..3.
//   ne00                         : row length. Only ne00/4 float4 vectors are
//                                  processed, so a tail of ne00 % 4 elements is
//                                  ignored. NOTE(review): presumably the host only
//                                  selects this kernel when ne00 % 4 == 0 -- confirm.
//   scale                        : multiplier applied to src0.
//   max_bias, m0, m1, n_head_log2: ALiBi parameters. The slope stays 1.0f unless
//                                  max_bias > 0.
//
// NOTE(review): the max/sum reductions use sub_group_reduce_*, and those only
// combine values within a single sub-group. The loops stride by
// get_local_size(0), so the result is correct only if the work-group is exactly
// one sub-group (local size == sub-group size). Confirm the host-side dispatch
// guarantees this; the Adreno attribute below pins the sub-group size there.
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * src2,
        ulong offset2,
        global char * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne12,
        int ne13,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        float scale,
        float max_bias,
        float m0,
        float m1,
        int n_head_log2
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    src2 = src2 + offset2;
    dst  = dst  + offsetd;

    // One work-group per row.
    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0);

    // Mask coordinates: broadcast over dims 2 and 3.
    const int i13 = i03 % ne13;
    const int i12 = i02 % ne12;
    const int i11 = i01;

    // Number of whole float4 vectors in the row (loop-invariant).
    const int ne00_4 = ne00/4;

    global const float4 * psrc4 = (global const float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
    // A src1/src2 buffer that aliases src0 is the "not provided" sentinel.
    global const float4 * pmask = src1 != src0 ? (global const float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
    global const float  * psrc2 = src2 != src0 ? (global const float  *)(src2) : 0;
    global       float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

    float slope = 1.0f;

    // ALiBi: the head index h selects base m0 or m1 and an integer exponent.
    // The local is named `exponent` so it does not shadow the built-in exp().
    if (max_bias > 0.0f) {
        const int h = i02;

        const float base     = h < n_head_log2 ? m0 : m1;
        const int   exponent = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

        // pown is the OpenCL built-in for integer exponents.
        slope = pown(base, exponent);
    }

    // Parallel max. Seeding with the sink (if any) makes it part of the max.
    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
    for (int i00 = get_local_id(0); i00 < ne00_4; i00 += get_local_size(0)) {
        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
    }
    const float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));

    // Named row_max so it does not shadow the built-in max().
    const float row_max = sub_group_reduce_max(lmax);

    // Parallel sum. dst temporarily holds the unnormalized exponentials.
    float4 lsum4 = 0.0f;
    for (int i00 = get_local_id(0); i00 < ne00_4; i00 += get_local_size(0)) {
        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - row_max);
        lsum4 += exp_psrc4;
        pdst4[i00] = exp_psrc4;
    }
    const float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;

    float sum = sub_group_reduce_add(lsum);

    // Add the sink to the denominator once. Every work-item holds the same
    // reduced sum, so each one adds the same term.
    if (psrc2) {
        sum += exp(psrc2[i02] - row_max);
    }

    // Normalize. Each work-item only touches the elements it wrote above.
    for (int i00 = get_local_id(0); i00 < ne00_4; i00 += get_local_size(0)) {
        pdst4[i00] /= sum;
    }
}