1#version 450
  2
  3#include "generic_binary_head.glsl"
  4#include "types.glsl"
  5
  6#if RMS_NORM_ROPE_FUSION
  7
  8layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  9layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 10
 11// data is passed from rms_norm -> rope through shared memory.
 12// rms_norm calls this data_d, rope calls this rope_data_a.
 13// Binding 2 is not used
 14shared FLOAT_TYPE rope_data_a[1024];
 15#define data_d rope_data_a
 16
 17layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};
 18layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};
 19layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};
 20layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows
 21
 22#include "rope_params.glsl"
 23#include "rope_funcs.glsl"
 24
 25#define GGML_ROPE_TYPE_NORMAL 0
 26#define GGML_ROPE_TYPE_NEOX   2
 27#define GGML_ROPE_TYPE_MROPE  8
 28#define GGML_ROPE_TYPE_VISION 24
 29
 30#endif
 31
 32#extension GL_EXT_control_flow_attributes : enable
 33#define BLOCK_SIZE 512
 34
 35layout (constant_id = 1) const bool do_multiply = false;
 36
 37layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 38
 39shared FLOAT_TYPE sumsh[BLOCK_SIZE];
 40
 41void rms_norm(uint num_iters) {
 42    const uint ncols     = p.ne00;
 43    const uint nrows     = gl_NumWorkGroups.x;
 44    const uint nchannels = gl_NumWorkGroups.y;
 45
 46    const uint row       = gl_WorkGroupID.x;
 47    const uint channel   = gl_WorkGroupID.y;
 48    const uint samp      = gl_WorkGroupID.z;
 49    const uint tid       = gl_LocalInvocationID.x;
 50
 51    const uint stride_row       = p.nb01;
 52    const uint stride_channel   = p.nb02;
 53    const uint stride_sample    = p.nb03;
 54
 55    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
 56    uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
 57#if RMS_NORM_ROPE_FUSION
 58    // Per-row offset in shared memory
 59    uint32_t d_offset = 0;
 60#else
 61    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
 62#endif
 63    FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 64
 65    [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
 66        FLOAT_TYPE xi = FLOAT_TYPE(0);
 67        if (col < ncols) {
 68            xi = FLOAT_TYPE(data_a[a_offset + col]);
 69        }
 70        sum += xi * xi;
 71    }
 72
 73    sumsh[tid] = sum;
 74    // sum up partial sums and write back result
 75    barrier();
 76    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
 77        if (tid < s) {
 78            sum += sumsh[tid + s];
 79            sumsh[tid] = sum;
 80        }
 81        barrier();
 82    }
 83    sum = sumsh[0];
 84
 85    const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
 86    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
 87
 88    if (do_multiply) {
 89        if (ncols > p.ne10) {
 90            [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
 91                if (col >= ncols) {
 92                    continue;
 93                }
 94                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
 95            }
 96        } else {
 97            [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
 98                if (col >= ncols) {
 99                    continue;
100                }
101                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
102            }
103        }
104    } else {
105        [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
106            if (col >= ncols) {
107                continue;
108            }
109            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
110        }
111    }
112#if RMS_NORM_ROPE_FUSION
113    barrier();
114    rope_params rp = p.rope;
115    for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
116        if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
117            rope_neox(t, row, channel, samp, rp);
118        } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
119            rope_norm(t, row, channel, samp, rp);
120        }
121    }
122#endif
123}
124
// Entry point. Computes how many BLOCK_SIZE-wide passes the row needs and
// calls rms_norm with a literal pass count wherever possible, so that each
// call site can be unrolled for its own trip count. Counts between the listed
// sizes are rounded up to the next listed size; rows needing more than 32
// passes fall back to the runtime count.
void main() {
    const uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;

    if (num_blocks == 0) {
        // Empty row: nothing to do.
        return;
    }

    if (num_blocks == 1) {
        rms_norm(1);
    } else if (num_blocks == 2) {
        rms_norm(2);
    } else if (num_blocks == 3) {
        rms_norm(3);
    } else if (num_blocks == 4) {
        rms_norm(4);
    } else if (num_blocks <= 8) {
        rms_norm(8);
    } else if (num_blocks <= 10) {
        rms_norm(10);
    } else if (num_blocks <= 12) {
        rms_norm(12);
    } else if (num_blocks <= 16) {
        rms_norm(16);
    } else if (num_blocks <= 32) {
        rms_norm(32);
    } else {
        // Too many passes for a dedicated instantiation.
        rms_norm(num_blocks);
    }
}