1#version 450
  2
  3#extension GL_EXT_control_flow_attributes : enable
  4#extension GL_EXT_shader_16bit_storage : require
  5#if USE_SUBGROUP_ADD
  6#extension GL_KHR_shader_subgroup_arithmetic : enable
  7#endif
  8
  9#define FLOAT_TYPE float
 10
 11layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 12
 13#include "mul_mat_vec_iface.glsl"
 14
 15layout(constant_id = 0) const int BLOCK_SIZE = 32;
 16// gqa_ratio is in the range [1,8]
 17layout(constant_id = 1) const uint gqa_ratio = 1;
 18
 19layout (push_constant) uniform parameter
 20{
 21    uint ncols_x;
 22    uint nrows_x;
 23    uint nchannels_x;
 24    uint nchannels_y;
 25    uint b_offset;
 26    uint d_offset;
 27    uint fusion_flags;
 28} p;
 29
 30#if !USE_SUBGROUP_ADD
 31shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
 32#endif
 33
 34void main() {
 35    const uint tid = gl_LocalInvocationID.x;
 36    const uint row_x = gl_GlobalInvocationID.y;
 37
 38    uint channel, channel_x;
 39
 40    // When gqa_ratio > 1, each invocation does multiple rows.
 41    // The row in the A matrix is starting from channel / gqa_ratio and the
 42    // rows in the B matrix are [channel, channel+gqa_ratio).
 43    // When gpa_ratio is 1, each invocation does one row.
 44    if (gqa_ratio > 1) {
 45        channel_x = gl_GlobalInvocationID.z;
 46        channel = channel_x * gqa_ratio;
 47    } else {
 48        channel = gl_GlobalInvocationID.z;
 49        channel_x = channel / (p.nchannels_y / p.nchannels_x);;
 50    }
 51
 52    const uint nrows_y = p.ncols_x;
 53    const uint nrows_dst = p.nrows_x;
 54    const uint row_dst = row_x;
 55
 56    FLOAT_TYPE temp[8];
 57    [[unroll]] for (uint i = 0; i < 8; ++i) {
 58        temp[i] = FLOAT_TYPE(0.0f);
 59    }
 60
 61    // Detect alignment for vector loads
 62    bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
 63
 64    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
 65
 66        // Use vec4 loads if aligned
 67        if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
 68
 69            uint col_x = col_x0 + 4*tid;
 70            const uint row_y = col_x;
 71
 72            // x is transposed and permuted
 73            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
 74            const vec4 av4 = vec4(data_a_v4[ix / 4]);
 75
 76            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
 77                // y is not transposed but permuted
 78                const uint iy = (channel + c)*nrows_y + row_y;
 79
 80                vec4 bv4 = data_b_v4[iy / 4];
 81                temp[c] += dot(av4, bv4);
 82            }
 83
 84            col_x0 += 3*BLOCK_SIZE;
 85        } else {
 86            const uint col_x = col_x0 + tid;
 87
 88            if (col_x >= p.ncols_x) {
 89                break;
 90            }
 91
 92            // x is transposed and permuted
 93            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
 94            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
 95
 96            const uint row_y = col_x;
 97
 98            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
 99                // y is not transposed but permuted
100                const uint iy = (channel + c)*nrows_y + row_y;
101
102                temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
103            }
104        }
105    }
106
107#if USE_SUBGROUP_ADD
108    // reduce vec4 at a time
109    vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
110    t = subgroupAdd(t);
111    temp[0] = t[0];
112    temp[1] = t[1];
113    temp[2] = t[2];
114    temp[3] = t[3];
115    if (gqa_ratio > 4) {
116        t = vec4(temp[4], temp[5], temp[6], temp[7]);
117        t = subgroupAdd(t);
118        temp[4] = t[0];
119        temp[5] = t[1];
120        temp[6] = t[2];
121        temp[7] = t[3];
122    }
123#else
124    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
125        tmp[c][tid] = temp[c];
126    }
127    // sum up partial sums and write back result
128    barrier();
129    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
130        if (tid < s) {
131            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
132                temp[c] += tmp[c][tid + s];
133                tmp[c][tid] = temp[c];
134            }
135        }
136        barrier();
137    }
138    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
139        temp[c] = tmp[c][tid];
140    }
141#endif
142
143    if (tid == 0) {
144        [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
145            // dst is not transposed and not permuted
146            const uint idst = (channel + c)*nrows_dst + row_dst;
147            if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
148                temp[c] += FLOAT_TYPE(data_fuse0[idst]);
149            }
150            if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
151                temp[c] += FLOAT_TYPE(data_fuse1[idst]);
152            }
153            data_d[idst] = temp[c];
154        }
155    }
156}