1#version 450
  2
  3#include "types.glsl"
  4
// Push-constant parameters describing one depthwise 2D convolution dispatch.
// NOTE(review): field order and types presumably must match the host-side
// struct that fills the push constants — confirm before reordering.
layout (push_constant) uniform parameter
{
    uint ne;          // number of output elements; invocations with idx >= ne exit early
    uint batches;     // not read by the shader code in this file (presumably the batch count)
    uint channels;    // channel count; each channel is convolved with its own kernel
    uint dst_w;       // output width
    uint dst_h;       // output height
    uint src_w;       // input width
    uint src_h;       // input height
    uint knl_w;       // kernel width
    uint knl_h;       // kernel height
    int stride_x;     // input step per output column
    int stride_y;     // input step per output row
    int pad_x;        // offset subtracted from the computed input x coordinate
    int pad_y;        // offset subtracted from the computed input y coordinate
    int dilation_x;   // input step per kernel column
    int dilation_y;   // input step per kernel row
} p;
 23
// Storage buffers. A_TYPE / B_TYPE / D_TYPE are not defined in this file
// (presumably supplied by the build or by types.glsl — confirm).
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};   // kernel weights
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};   // input values
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};  // output, one element per invocation

// 512 invocations per workgroup along x; main() uses the same 512 when
// flattening the global invocation id.
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 29
 30FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
 31    uint i0 = idx / p.dst_w;
 32    uint dst_x = idx - i0 * p.dst_w;
 33    uint i1 = i0 / p.dst_h;
 34    uint dst_y = i0 - i1 * p.dst_h;
 35    uint n = i1 / p.channels;
 36    uint c = i1 - n * p.channels;
 37
 38    uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
 39    uint knl_i = c * p.knl_h * p.knl_w;
 40
 41    FLOAT_TYPE sum = 0.0;
 42    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
 43        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
 44        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
 45            continue;
 46        }
 47        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
 48            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
 49            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
 50                continue;
 51            }
 52            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
 53            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
 54            sum = fma(v, k, sum);
 55        }
 56    }
 57    return sum;
 58}
 59
 60FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
 61    uint i0 = idx / p.channels;
 62    uint c = idx - i0 * p.channels;
 63    uint i1 = i0 / p.dst_w;
 64    uint dst_x = i0 - i1 * p.dst_w;
 65    uint n = i1 / p.dst_h;
 66    uint dst_y = i1 - n * p.dst_h;
 67
 68    uint src_i = n * p.channels * p.src_h * p.src_w;
 69    uint src_row = p.src_w * p.channels;
 70    uint knl_row = p.knl_w * p.channels;
 71
 72    FLOAT_TYPE sum = 0.0;
 73    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
 74        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
 75        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
 76            continue;
 77        }
 78        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
 79            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
 80            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
 81                continue;
 82            }
 83            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
 84            FLOAT_TYPE k = FLOAT_TYPE(knl_data[        knl_y * knl_row + knl_x * p.channels + c]);
 85            sum = fma(v, k, sum);
 86        }
 87    }
 88    return sum;
 89}
 90
 91void main() {
 92    uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
 93    if (idx >= p.ne) {
 94        return;
 95    }
 96
 97    FLOAT_TYPE result =
 98#ifdef WHCN
 99        conv_2d_dw_whcn(idx);
100#else
101        conv_2d_dw_cwhn(idx);
102#endif
103    dst_data[idx] = D_TYPE(result);
104}
105