1#version 450
2
3#include "types.glsl"
4
// Push constants for the depthwise 2D convolution. The field order and types
// define the block's memory layout; presumably a host-side struct mirrors it
// exactly — confirm before reordering or retyping anything here.
layout (push_constant) uniform parameter
{
    uint ne;         // total number of output elements; invocations with idx >= ne exit early (main)
    uint batches;    // batch count; not read by this shader (batch index is derived from idx)
    uint channels;   // channel count; one kernel plane per channel (depthwise)
    uint dst_w;      // output width
    uint dst_h;      // output height
    uint src_w;      // input width
    uint src_h;      // input height
    uint knl_w;      // kernel width
    uint knl_h;      // kernel height
    int stride_x;    // horizontal step in the input per output column
    int stride_y;    // vertical step in the input per output row
    int pad_x;       // subtracted from the computed input x (left padding)
    int pad_y;       // subtracted from the computed input y (top padding)
    int dilation_x;  // horizontal spacing between kernel taps in the input
    int dilation_y;  // vertical spacing between kernel taps in the input
} p;
23
// Kernel weights (binding 0), input tensor (binding 1) and output tensor
// (binding 2). A_TYPE / B_TYPE / D_TYPE are not defined in this file;
// presumably they are supplied as compile-time defines by the shader build.
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};

// 512 invocations per workgroup along x. main() also uses 512 as the row width
// when flattening gl_GlobalInvocationID into an element index; presumably the
// two (and the host dispatch) must stay in sync — confirm before changing.
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
29
// Computes one output element of a depthwise 2D convolution in WHCN order
// (x fastest, then y, then channel, then batch).
//
// idx: flat index of the output element in dst_data.
// Returns the sum over kernel taps of src * knl for the element's channel.
// Source plane (n, c) is src_h x src_w and kernel plane c is knl_h x knl_w,
// both stored x-fastest. Taps that land outside the source (padding)
// contribute nothing.
FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
    // Decompose idx -> (n, c, dst_y, dst_x), x varying fastest.
    uint i0 = idx / p.dst_w;
    uint dst_x = idx - i0 * p.dst_w;
    uint i1 = i0 / p.dst_h;
    uint dst_y = i0 - i1 * p.dst_h;
    uint n = i1 / p.channels;
    uint c = i1 - n * p.channels;

    // Start of source plane (n, c) and of kernel plane c.
    uint src_i = (n * p.channels + c) * p.src_h * p.src_w;
    uint knl_i = c * p.knl_h * p.knl_w;

    // Explicit construction so the initializer also type-checks when
    // FLOAT_TYPE is a 16-bit float type (no implicit float -> float16 narrowing).
    FLOAT_TYPE sum = FLOAT_TYPE(0.0);
    for (uint knl_y = 0u; knl_y < p.knl_h; ++knl_y) {
        // Evaluated in unsigned arithmetic: a negative coordinate (inside the
        // padding) wraps to a large value and fails the bounds check below.
        uint src_y = dst_y * uint(p.stride_y) + knl_y * uint(p.dilation_y) - uint(p.pad_y);
        if (src_y >= p.src_h) {
            continue;
        }
        // Row bases are invariant across the inner loop.
        uint src_row = src_i + src_y * p.src_w;
        uint knl_row = knl_i + knl_y * p.knl_w;
        for (uint knl_x = 0u; knl_x < p.knl_w; ++knl_x) {
            uint src_x = dst_x * uint(p.stride_x) + knl_x * uint(p.dilation_x) - uint(p.pad_x);
            if (src_x >= p.src_w) { // same wrap-around trick as src_y
                continue;
            }
            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_row + src_x]);
            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_row + knl_x]);
            sum = fma(v, k, sum);
        }
    }
    return sum;
}
59
// Computes one output element of a depthwise 2D convolution in CWHN order
// (channel fastest, then x, then y, then batch).
//
// idx: flat index of the output element in dst_data.
// Returns the sum over kernel taps of src * knl for the element's channel.
// The source is laid out [n][y][x][c]. The kernel is laid out [y][x][c], with
// no batch dimension. Taps that land outside the source (padding) contribute
// nothing.
FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
    // Decompose idx -> (n, dst_y, dst_x, c), channel varying fastest.
    uint i0 = idx / p.channels;
    uint c = idx - i0 * p.channels;
    uint i1 = i0 / p.dst_w;
    uint dst_x = i0 - i1 * p.dst_w;
    uint n = i1 / p.dst_h;
    uint dst_y = i1 - n * p.dst_h;

    // Element distance between consecutive rows of the source / kernel.
    uint src_row_stride = p.src_w * p.channels;
    uint knl_row_stride = p.knl_w * p.channels;
    // Start of batch n in the source.
    uint src_i = n * p.src_h * src_row_stride;

    // Explicit construction so the initializer also type-checks when
    // FLOAT_TYPE is a 16-bit float type (no implicit float -> float16 narrowing).
    FLOAT_TYPE sum = FLOAT_TYPE(0.0);
    for (uint knl_y = 0u; knl_y < p.knl_h; ++knl_y) {
        // Evaluated in unsigned arithmetic: a negative coordinate (inside the
        // padding) wraps to a large value and fails the bounds check below.
        uint src_y = dst_y * uint(p.stride_y) + knl_y * uint(p.dilation_y) - uint(p.pad_y);
        if (src_y >= p.src_h) {
            continue;
        }
        // Row base plus channel offset; invariant across the inner loop.
        uint src_row = src_i + src_y * src_row_stride + c;
        uint knl_row = knl_y * knl_row_stride + c;
        for (uint knl_x = 0u; knl_x < p.knl_w; ++knl_x) {
            uint src_x = dst_x * uint(p.stride_x) + knl_x * uint(p.dilation_x) - uint(p.pad_x);
            if (src_x >= p.src_w) { // same wrap-around trick as src_y
                continue;
            }
            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_row + src_x * p.channels]);
            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_row + knl_x * p.channels]);
            sum = fma(v, k, sum);
        }
    }
    return sum;
}
90
// Entry point: each invocation produces at most one output element.
// The 3D global invocation id is flattened into a linear element index using a
// fixed row width of 512 and plane size of 512 * 512; ids at or beyond p.ne do
// nothing. The WHCN define selects which memory layout is computed.
void main() {
    const uint row_width = 512u;
    const uint plane_size = row_width * 512u; // 262144

    uvec3 gid = gl_GlobalInvocationID;
    uint idx = gid.x + gid.y * row_width + gid.z * plane_size;

    if (idx < p.ne) {
#ifdef WHCN
        FLOAT_TYPE result = conv_2d_dw_whcn(idx);
#else
        FLOAT_TYPE result = conv_2d_dw_cwhn(idx);
#endif
        dst_data[idx] = D_TYPE(result);
    }
}
105