1#version 450
  2
  3#extension GL_EXT_shader_16bit_storage : require
  4#extension GL_EXT_nonuniform_qualifier : enable
  5#extension GL_EXT_control_flow_attributes : require
  6#if ADD_RMS
  7#extension GL_KHR_shader_subgroup_arithmetic : enable
  8#extension GL_KHR_shader_subgroup_basic : enable
  9#endif
 10
 11#include "rte.glsl"
 12#include "types.glsl"
 13#include "utils.glsl"
 14
layout (push_constant) uniform parameter2
{
    // shape for dst
    uint ne20; uint ne21; uint ne22; uint ne23;

    // strides for srcs+dst
    // nb[s] holds the 4 per-dimension strides for source s (s < num_srcs);
    // nb[num_srcs] holds the dst strides (see dst_idx()). Up to 12 slots to
    // match the 12 descriptor bindings below.
    // NOTE(review): strides appear to be in elements, not bytes, since they
    // index typed arrays directly — confirm against the host-side setup.
    uint nb[12][4];

    // nonzero => also emit per-workgroup partial sums of squares (ADD_RMS path)
    uint rms_partials;
} p;
 25
// No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
// Each of the 12 bindings is declared three times with different element
// types: as an A_TYPE source view (aN), a D_TYPE destination view (dN), and
// a float partial-sum view (partialsN). All three alias the same binding;
// which view is touched depends on the role the buffer plays (sources occupy
// bindings [0, num_srcs), dst is at num_srcs, partials at num_srcs + 1).
layout (binding = 0)  buffer A0 {A_TYPE data_a[];} a0;
layout (binding = 1)  buffer A1 {A_TYPE data_a[];} a1;
layout (binding = 2)  buffer A2 {A_TYPE data_a[];} a2;
layout (binding = 3)  buffer A3 {A_TYPE data_a[];} a3;
layout (binding = 4)  buffer A4 {A_TYPE data_a[];} a4;
layout (binding = 5)  buffer A5 {A_TYPE data_a[];} a5;
layout (binding = 6)  buffer A6 {A_TYPE data_a[];} a6;
layout (binding = 7)  buffer A7 {A_TYPE data_a[];} a7;
layout (binding = 8)  buffer A8 {A_TYPE data_a[];} a8;
layout (binding = 9)  buffer A9 {A_TYPE data_a[];} a9;
layout (binding = 10) buffer A10 {A_TYPE data_a[];} a10;
layout (binding = 11) buffer A11 {A_TYPE data_a[];} a11;
layout (binding = 0)  buffer D0 {D_TYPE data_d[];} d0;
layout (binding = 1)  buffer D1 {D_TYPE data_d[];} d1;
layout (binding = 2)  buffer D2 {D_TYPE data_d[];} d2;
layout (binding = 3)  buffer D3 {D_TYPE data_d[];} d3;
layout (binding = 4)  buffer D4 {D_TYPE data_d[];} d4;
layout (binding = 5)  buffer D5 {D_TYPE data_d[];} d5;
layout (binding = 6)  buffer D6 {D_TYPE data_d[];} d6;
layout (binding = 7)  buffer D7 {D_TYPE data_d[];} d7;
layout (binding = 8)  buffer D8 {D_TYPE data_d[];} d8;
layout (binding = 9)  buffer D9 {D_TYPE data_d[];} d9;
layout (binding = 10) buffer D10 {D_TYPE data_d[];} d10;
layout (binding = 11) buffer D11 {D_TYPE data_d[];} d11;
layout (binding = 0, std430)  buffer PartialBuf0 {float partial_sums[];} partials0;
layout (binding = 1, std430)  buffer PartialBuf1 {float partial_sums[];} partials1;
layout (binding = 2, std430)  buffer PartialBuf2 {float partial_sums[];} partials2;
layout (binding = 3, std430)  buffer PartialBuf3 {float partial_sums[];} partials3;
layout (binding = 4, std430)  buffer PartialBuf4 {float partial_sums[];} partials4;
layout (binding = 5, std430)  buffer PartialBuf5 {float partial_sums[];} partials5;
layout (binding = 6, std430)  buffer PartialBuf6 {float partial_sums[];} partials6;
layout (binding = 7, std430)  buffer PartialBuf7 {float partial_sums[];} partials7;
layout (binding = 8, std430)  buffer PartialBuf8 {float partial_sums[];} partials8;
layout (binding = 9, std430)  buffer PartialBuf9 {float partial_sums[];} partials9;
layout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10;
layout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11;

// Number of source tensors being summed; set at pipeline creation time.
layout(constant_id = 0) const uint num_srcs = 2;
 65
 66FLOAT_TYPE load_a(uint b, uint i) {
 67    switch (b) {
 68    case 0:  return FLOAT_TYPE(a0.data_a[i]);
 69    case 1:  return FLOAT_TYPE(a1.data_a[i]);
 70    case 2:  return FLOAT_TYPE(a2.data_a[i]);
 71    case 3:  return FLOAT_TYPE(a3.data_a[i]);
 72    case 4:  return FLOAT_TYPE(a4.data_a[i]);
 73    case 5:  return FLOAT_TYPE(a5.data_a[i]);
 74    case 6:  return FLOAT_TYPE(a6.data_a[i]);
 75    case 7:  return FLOAT_TYPE(a7.data_a[i]);
 76    case 8:  return FLOAT_TYPE(a8.data_a[i]);
 77    case 9:  return FLOAT_TYPE(a9.data_a[i]);
 78    case 10: return FLOAT_TYPE(a10.data_a[i]);
 79    case 11: return FLOAT_TYPE(a11.data_a[i]);
 80    default: return FLOAT_TYPE(0);
 81    }
 82}
 83
 84void store_d(uint b, uint i, FLOAT_TYPE v) {
 85    switch (b) {
 86    case 0:  d0.data_d[i] = D_TYPE(v); break;
 87    case 1:  d1.data_d[i] = D_TYPE(v); break;
 88    case 2:  d2.data_d[i] = D_TYPE(v); break;
 89    case 3:  d3.data_d[i] = D_TYPE(v); break;
 90    case 4:  d4.data_d[i] = D_TYPE(v); break;
 91    case 5:  d5.data_d[i] = D_TYPE(v); break;
 92    case 6:  d6.data_d[i] = D_TYPE(v); break;
 93    case 7:  d7.data_d[i] = D_TYPE(v); break;
 94    case 8:  d8.data_d[i] = D_TYPE(v); break;
 95    case 9:  d9.data_d[i] = D_TYPE(v); break;
 96    case 10: d10.data_d[i] = D_TYPE(v); break;
 97    case 11: d11.data_d[i] = D_TYPE(v); break;
 98    default: break;
 99    }
100}
101
// Write partial-sum v to element i of the float view of buffer b.
// One branch per binding, mirroring store_d; out-of-range b is a no-op.
void store_partial(uint b, uint i, float v) {
    if (b == 0)  { partials0.partial_sums[i]  = v; return; }
    if (b == 1)  { partials1.partial_sums[i]  = v; return; }
    if (b == 2)  { partials2.partial_sums[i]  = v; return; }
    if (b == 3)  { partials3.partial_sums[i]  = v; return; }
    if (b == 4)  { partials4.partial_sums[i]  = v; return; }
    if (b == 5)  { partials5.partial_sums[i]  = v; return; }
    if (b == 6)  { partials6.partial_sums[i]  = v; return; }
    if (b == 7)  { partials7.partial_sums[i]  = v; return; }
    if (b == 8)  { partials8.partial_sums[i]  = v; return; }
    if (b == 9)  { partials9.partial_sums[i]  = v; return; }
    if (b == 10) { partials10.partial_sums[i] = v; return; }
    if (b == 11) { partials11.partial_sums[i] = v; return; }
}
119
// Flat element index into buffer s for logical coordinates (i00..i03),
// using the per-dimension strides from the push constants.
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
    uint off  = i00 * p.nb[s][0];
    off      += i01 * p.nb[s][1];
    off      += i02 * p.nb[s][2];
    off      += i03 * p.nb[s][3];
    return off;
}
123
// Flat element index into dst. The dst strides live in stride slot
// num_srcs (right after the sources), so this is just src_idx on that slot.
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
    return src_idx(num_srcs, i00, i01, i02, i03);
}
131
// Linearize the 3D global invocation id into a flat element index.
// The y and z extents are fixed at 512 (512 * 512 == 262144), matching the
// dispatch-side workgroup-count split.
uint get_idx() {
    return (gl_GlobalInvocationID.z * 512 + gl_GlobalInvocationID.y) * 512
         + gl_GlobalInvocationID.x;
}
135
// Threads per workgroup; each thread handles num_iter (=2) elements in main().
const uint num_threads = 256;

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

#if ADD_RMS
// Scratch for the cross-subgroup reduction: one slot per subgroup leader.
// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif
144
void main() {
    uint idx = get_idx();
    // Preserved for the partial-sum slot computation below (idx is advanced
    // inside the loop).
    uint orig_idx = idx;

    // Total number of dst elements.
    uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;

    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 2;

    // Per-thread accumulator of squared outputs (only consumed when ADD_RMS).
    FLOAT_TYPE sum_sq = 0;

    // Each thread processes num_iter elements, strided by num_threads.
    // Once idx >= ne it stays >= ne (idx only grows), so the continue is safe.
    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
        if (idx >= ne) {
            continue;
        }
        uint i00, i01, i02, i03;
        get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);

        // dst[idx] = sum over all sources at the same logical coordinates.
        FLOAT_TYPE sum = FLOAT_TYPE(0);
        [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
            sum += load_a(s, src_idx(s, i00, i01, i02, i03));
        }
        sum_sq += sum*sum;
        store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum);

        idx += num_threads;
    }

#if ADD_RMS
    // Workgroup-wide reduction of sum_sq, then one partial sum per workgroup.
    // p.rms_partials is a push constant, so this branch is uniform across the
    // workgroup and the barriers below are reached by all threads.
    if (p.rms_partials != 0) {
        // reduce the sum within each subgroup, then across subgroups
        const uint NumSubgroups = num_threads / gl_SubgroupSize;
        sum_sq = subgroupAdd(sum_sq);
        // Each subgroup leader publishes its subgroup's total.
        if (gl_SubgroupInvocationID == 0) {
            sumsh[gl_SubgroupID] = sum_sq;
        }
        barrier();
        // Tree reduction over the subgroup leaders' slots. Assumes
        // NumSubgroups is a power of two (num_threads = 256 and subgroup
        // sizes are powers of two).
        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
                sum_sq += sumsh[gl_SubgroupID + s];
                sumsh[gl_SubgroupID] = sum_sq;
            }
            barrier();
        }

        // Thread 0 writes the workgroup total. The partials buffer sits at
        // binding num_srcs + 1 (after the dst at num_srcs); the slot is the
        // linear workgroup index, since each workgroup covers
        // num_iter * num_threads consecutive elements.
        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
            store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq);
        }
    }
#endif
}