#version 450

// [[unroll]] loop attribute used in main().
#extension GL_EXT_control_flow_attributes : enable
// Presumably needed for 16-bit buffer element types declared by the
// included interface file -- TODO confirm against mul_mat_vec_iface.glsl.
#extension GL_EXT_shader_16bit_storage : require

// Number of invocations that cooperate on a single output element.
#define BLOCK_SIZE 32
// Type used for partial sums and the shared-memory reduction.
#define FLOAT_TYPE float

// One-dimensional workgroups; local_size_y and local_size_z default to 1.
layout(local_size_x = BLOCK_SIZE) in;

#include "mul_mat_vec_iface.glsl"
 12
 13layout (push_constant) uniform parameter
 14{
 15    uint ncols_x;
 16    uint nrows_x;
 17    uint row_stride_x;
 18    uint channel_stride_x;
 19    uint channel_stride_y;
 20    uint channel_x_divisor;
 21    uint ne12;
 22    uint b_offset;
 23    uint d_offset;
 24    uint nb03;
 25    uint nb13;
 26    uint nb23;
 27    uint fusion_flags;
 28} p;
 29
 30shared FLOAT_TYPE tmp[BLOCK_SIZE];
 31
 32void main() {
 33    const uint tid       = gl_LocalInvocationID.x;
 34    const uint row_x     = gl_GlobalInvocationID.y;
 35    const uint channel   = gl_GlobalInvocationID.z;
 36    const uint i3        = gl_WorkGroupID.x;
 37    const uint channel_x = channel / p.channel_x_divisor;
 38    const uint channel_y = channel % p.ne12;
 39
 40    const uint nrows_y   = p.ncols_x;
 41    const uint nrows_dst = p.nrows_x;
 42    const uint row_dst   = row_x;
 43
 44    const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;
 45
 46    FLOAT_TYPE temp = 0.0f;
 47
 48    // Detect alignment for vector loads
 49    bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
 50
 51    for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
 52
 53        // Unroll 2x and do vec4 loads if aligned
 54        const uint unroll_count = 2;
 55        if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
 56            [[unroll]] for (uint i = 0; i < unroll_count; ++i) {
 57                const uint col_x = col_x0 + 4*tid;
 58
 59                const uint row_y = col_x;
 60
 61                const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
 62                const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
 63
 64                const vec4 av4 = vec4(data_a_v4[ix / 4]);
 65                const vec4 bv4 = vec4(data_b_v4[iy / 4]);
 66
 67                temp += dot(av4, bv4);
 68
 69                col_x0 += 4*BLOCK_SIZE;
 70            }
 71        // do vec4 loads if aligned
 72        } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
 73            const uint col_x = col_x0 + 4*tid;
 74
 75            const uint row_y = col_x;
 76
 77            const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
 78            const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
 79
 80            const vec4 av4 = vec4(data_a_v4[ix / 4]);
 81            const vec4 bv4 = vec4(data_b_v4[iy / 4]);
 82
 83            temp += dot(av4, bv4);
 84
 85            col_x0 += 4*BLOCK_SIZE;
 86        } else {
 87            const uint col_x = col_x0 + tid;
 88            if (col_x >= p.ncols_x) {
 89                break;
 90            }
 91
 92            const uint row_y = col_x;
 93
 94            const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
 95            const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
 96
 97            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
 98
 99            temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
100            col_x0 += BLOCK_SIZE;
101        }
102    }
103
104    tmp[tid] = temp;
105
106    // sum up partial sums and write back result
107    barrier();
108    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
109        if (tid < s) {
110            tmp[tid] += tmp[tid + s];
111        }
112        barrier();
113    }
114
115    if (tid == 0) {
116        if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
117            tmp[0] += FLOAT_TYPE(data_fuse0[idst]);
118        }
119        if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
120            tmp[0] += FLOAT_TYPE(data_fuse1[idst]);
121        }
122        data_d[idst] = tmp[0];
123    }
124}