1#version 450
2
3#include "generic_binary_head.glsl"
4#include "types.glsl"
5
6#if RMS_NORM_ROPE_FUSION
7
8layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
9layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
10
11// data is passed from rms_norm -> rope through shared memory.
12// rms_norm calls this data_d, rope calls this rope_data_a.
13// Binding 2 is not used
14shared FLOAT_TYPE rope_data_a[1024];
15#define data_d rope_data_a
16
17layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};
18layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};
19layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};
20layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows
21
22#include "rope_params.glsl"
23#include "rope_funcs.glsl"
24
25#define GGML_ROPE_TYPE_NORMAL 0
26#define GGML_ROPE_TYPE_NEOX 2
27#define GGML_ROPE_TYPE_MROPE 8
28#define GGML_ROPE_TYPE_VISION 24
29
30#endif
31
32#extension GL_EXT_control_flow_attributes : enable
33#define BLOCK_SIZE 512
34
35layout (constant_id = 1) const bool do_multiply = false;
36
37layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
38
39shared FLOAT_TYPE sumsh[BLOCK_SIZE];
40
// Normalizes one row of src0 (data_a) by its root-mean-square and writes the result.
//
// Workgroup mapping (from gl_WorkGroupID): x = row, y = channel, z = sample.
// Each of the BLOCK_SIZE invocations handles columns tid, tid + BLOCK_SIZE,
// tid + 2*BLOCK_SIZE, ... for num_iters iterations. Columns >= ncols are skipped,
// so num_iters only needs to be at least ceil(ncols / BLOCK_SIZE). main() passes a
// literal where it can so the [[unroll]] loops below can be fully unrolled.
//
// Per element: out = x * inversesqrt(mean(x^2) + p.param1), additionally multiplied
// by src1 (data_b) when the do_multiply specialization constant is true.
// NOTE(review): p.param1 is presumably the norm's epsilon - confirm on the host side.
//
// Without RMS_NORM_ROPE_FUSION the result is stored to data_d. That buffer is not
// declared in this file; it presumably comes from generic_binary_head.glsl.
// With RMS_NORM_ROPE_FUSION, data_d is #defined to the shared array rope_data_a.
// After a barrier, rope_neox / rope_norm then read it to produce the final output.
//
// Every barrier() below is reached by all invocations of the workgroup. Keep them
// out of any invocation-dependent branch when editing.
void rms_norm(uint num_iters) {
    const uint ncols = p.ne00;
    const uint nrows = gl_NumWorkGroups.x;
    const uint nchannels = gl_NumWorkGroups.y;

    const uint row = gl_WorkGroupID.x;
    const uint channel = gl_WorkGroupID.y;
    const uint samp = gl_WorkGroupID.z;
    const uint tid = gl_LocalInvocationID.x;

    // Strides are added directly to element indices below, so they are presumably
    // already expressed in elements rather than bytes - TODO confirm on the host side.
    const uint stride_row = p.nb01;
    const uint stride_channel = p.nb02;
    const uint stride_sample = p.nb03;

    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
    uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
#if RMS_NORM_ROPE_FUSION
    // Per-row offset in shared memory: each workgroup handles exactly one row,
    // so the shared output array always starts at 0.
    uint32_t d_offset = 0;
#else
    // Destination is indexed as a densely packed [samp][channel][row][col] tensor.
    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
#endif
    FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // this invocation's partial sum of squares over its strided columns

    // Pass 1: accumulate x^2 over this invocation's columns (out-of-range columns contribute 0).
    [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
        FLOAT_TYPE xi = FLOAT_TYPE(0);
        if (col < ncols) {
            xi = FLOAT_TYPE(data_a[a_offset + col]);
        }
        sum += xi * xi;
    }

    sumsh[tid] = sum;
    // sum up partial sums and write back result
    // Tree reduction over the workgroup in shared memory. Each step halves the number
    // of active invocations; the barrier inside the loop makes each step's writes
    // visible before the next step reads them.
    barrier();
    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sum += sumsh[tid + s];
            sumsh[tid] = sum;
        }
        barrier();
    }
    // sumsh[0] now holds the sum of squares of the whole row (made visible by the last barrier).
    sum = sumsh[0];

    const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

    // Pass 2: write the scaled values. Column bounds are checked per iteration with
    // `continue`, so no barrier is placed in these loops.
    if (do_multiply) {
        if (ncols > p.ne10) {
            // src1 row is shorter than src0 row: repeat (broadcast) it with a modulo index.
            [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
                if (col >= ncols) {
                    continue;
                }
                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
            }
        } else {
            [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
                if (col >= ncols) {
                    continue;
                }
                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
            }
        }
    } else {
        [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
            if (col >= ncols) {
                continue;
            }
            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
        }
    }
#if RMS_NORM_ROPE_FUSION
    // Fused path: the normalized row now lives in shared rope_data_a. In this mode the
    // stores above convert to D_TYPE before landing in the FLOAT_TYPE shared array.
    // NOTE(review): rope_data_a holds 1024 elements, so this path assumes ncols <= 1024.
    // That is presumably enforced where the host decides to fuse - confirm.
    // The barrier makes every invocation's shared writes visible before rope reads them.
    barrier();
    rope_params rp = p.rope;
    // Each invocation processes element pairs starting at even columns 2*tid, 2*tid + 2*BLOCK_SIZE, ...
    // Only NEOX and NORMAL modes are handled here; any other rope_mode writes nothing in this loop.
    for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
        if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
            rope_neox(t, row, channel, samp, rp);
        } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
            rope_norm(t, row, channel, samp, rp);
        }
    }
#endif
}
124
// Entry point: chooses the iteration count handed to rms_norm.
//
// The number of BLOCK_SIZE-wide strides needed to cover one row is rounded up to a
// fixed bucket (1, 2, 3, 4, 8, 10, 12, 16 or 32). The bucket is passed as a literal, so
// each call site gives rms_norm a compile-time num_iters and its [[unroll]] loops can
// be unrolled. Rows needing more than 32 strides fall back to the exact runtime count.
// A row with ne00 == 0 makes no call at all.
void main() {
    const uint blocks_per_row = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;

    // p.ne00 is presumably the same for every invocation (push constants), so this
    // exit, like the branches below, is taken uniformly by the whole workgroup.
    if (blocks_per_row == 0) {
        return;
    }

    if (blocks_per_row == 1) {
        rms_norm(1);
    } else if (blocks_per_row == 2) {
        rms_norm(2);
    } else if (blocks_per_row == 3) {
        rms_norm(3);
    } else if (blocks_per_row == 4) {
        rms_norm(4);
    } else if (blocks_per_row <= 8) {
        rms_norm(8);
    } else if (blocks_per_row <= 10) {
        rms_norm(10);
    } else if (blocks_per_row <= 12) {
        rms_norm(12);
    } else if (blocks_per_row <= 16) {
        rms_norm(16);
    } else if (blocks_per_row <= 32) {
        rms_norm(32);
    } else {
        // Too many strides to be worth a dedicated unrolled instantiation.
        rms_norm(blocks_per_row);
    }
}