1#version 450
 2
 3#include "types.glsl"
 4
 5layout (binding = 0) readonly buffer A {A_TYPE data_a[];};   // src0 - kernel:    [K, Cout, Cin]
 6layout (binding = 1) readonly buffer B {B_TYPE data_b[];};   // src1 - input:     [L, Cin]
 7layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};     // dst - result      [KL, Cout]
 8
 9layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
10
// Push constants for the kernel. Member order and types determine the
// push-constant layout, so they are kept exactly as declared.
// All nb* values are used directly as array-index multipliers, i.e. they are
// strides expressed in elements.
layout (push_constant) uniform parameter {
    // Not referenced in this file; presumably the number of output channels.
    uint32_t Cout;
    // Number of input channels summed over for every output value.
    uint32_t Cin;
    // Number of kernel taps applied per input position.
    uint32_t K;
    // Number of input positions; invocations with L_idx >= L contribute nothing.
    uint32_t L;
    // Number of output positions per channel; writes with KL_idx >= KL are dropped.
    uint32_t KL;

    // data_a stride between consecutive output channels.
    uint32_t nb01;
    // data_a stride between consecutive input channels.
    uint32_t nb02;
    // data_b stride between consecutive input channels.
    uint32_t nb11;
    // data_d stride between consecutive output channels.
    uint32_t nb1;

    // Output offset between consecutive input positions (the stride of the
    // transposed convolution).
    int32_t s0;
} p;
25
26
// Workgroup size along x; one input position is handled per invocation per block.
const uint32_t bs = gl_WorkGroupSize.x;
// Output channel handled by this workgroup.
uint32_t Cout_idx = gl_WorkGroupID.x;
// Index of this invocation inside its workgroup.
uint32_t tid = gl_LocalInvocationID.x;
// Number of tmp entries in use for one block of bs input positions.
// The exact span is (bs-1)*s0+K; rounding it up to bs*s0+K keeps the
// indexing in main() simpler.
uint32_t tmp_len = bs*p.s0+p.K;
// Workgroup-shared accumulator for the output window currently being built.
// Its capacity is fixed at 4096 entries and nothing in this file checks
// tmp_len against it, so the dispatching code presumably has to guarantee
// bs*s0+K <= 4096.
shared D_TYPE tmp[4096];
33
// Returns how many bs-sized chunks are needed to cover workSize items,
// i.e. ceil(workSize / bs) computed with unsigned integer arithmetic.
uint splitWork(uint workSize){
    uint padded = workSize + bs - 1;
    return padded / bs;
}
37
// Transposed 1-D convolution; each workgroup produces one output channel
// (Cout_idx).
//
// The input is consumed in blocks of bs positions, one position per
// invocation. For every block, each invocation adds its K kernel-tap products
// into the shared window tmp at offset tid*s0 + K_idx. After a block, the
// first bs*s0 entries of tmp are final and are written to data_d; the
// remaining K entries are carried over to the next block by shifting them to
// the front of tmp. The last block is flushed in full after the loop.
void main(){
    // Loop-invariant quantities, computed once instead of in every loop test.
    uint32_t tmp_chunks = splitWork(tmp_len);   // passes needed for bs invocations to cover tmp[0, tmp_len)
    uint32_t window = bs*p.s0;                  // output positions completed by one input block

    // Clear the accumulation window.
    for(uint32_t i = 0; i < tmp_chunks; i++){
        uint32_t idx = i*bs+tid;
        if(idx < tmp_len){
            tmp[idx] = 0.0;
        }
    }

    uint32_t L_blocks = splitWork(p.L);
    for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
        if(L_block_id > 0){
            // The previous block's results must have been read out of tmp
            // by every invocation before tmp is modified.
            barrier();
            // Shift values in tmp to the current processing window.
            // window is a multiple of bs, so idx and idx-window belong to the
            // same invocation and are visited in ascending order: a slot is
            // always consumed before it is overwritten.
            for(uint32_t i = 0; i < tmp_chunks; i++){
                uint32_t idx = i*bs+tid;
                if(idx >= window && idx < tmp_len){
                    tmp[idx-window] = tmp[idx];
                    tmp[idx] = 0.0;
                }else if(idx >= p.K && idx < window){
                    tmp[idx] = 0.0;
                }
            }
        }
        // Makes the cleared/shifted window visible to all invocations.
        barrier();

        // Save contributions of the block to tmp
        uint32_t L_idx = L_block_id*bs + tid;
        // Invocations past the end of the input still take part in every
        // barrier below but contribute 0.
        bool has_input = L_idx < p.L;
        for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
            D_TYPE dp = 0.0;
            if(has_input){
                // Dot product over the input channels for this kernel tap.
                for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
                    A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
                    B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
                    dp = fma(elemKrn, elemInp, dp);
                }
            }
            tmp[tid*p.s0 + K_idx] += dp;
            // Different invocations can target the same tmp slot for
            // different K_idx values, so the taps are serialized.
            barrier();
        }

        // Save the computed values except the last block that can have different size
        uint32_t KLb_idx = L_block_id*window;
        if(L_block_id < L_blocks-1){
            // Each invocation stores the s0 consecutive outputs starting at tid*s0.
            for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
                uint32_t sh_idx = p.s0*tid+s0_idx;
                uint32_t KL_idx = KLb_idx+sh_idx;
                if(KL_idx < p.KL){
                    data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
                }
            }
        }
    }

    // Flush the whole window of the last block, including the trailing K
    // entries. idx is also bounded by tmp_len so tmp is never read beyond
    // the region that was initialized above.
    for(uint32_t i = 0; i < tmp_chunks; i++){
        uint32_t idx = i*bs+tid;
        uint32_t KL_idx = (L_blocks-1)*window+idx;
        if(idx < tmp_len && KL_idx < p.KL){
            data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
        }
    }
}