1#version 450
  2
  3#extension GL_EXT_shader_16bit_storage : require
  4#extension GL_EXT_control_flow_attributes : require
  5
  6#include "rte.glsl"
  7#include "types.glsl"
  8
// Push-constant parameters for one im2col dispatch.
// NOTE(review): field order/types presumably mirror a host-side struct; do not
// reorder or retype fields without updating the host code.
layout (push_constant) uniform parameter
{
    // Destination base address for buffer-reference stores (only read when BDA is enabled).
    BDA_STORAGE_T dst_addr;
    // Source element strides between batches and between input channels (index units of data_a).
    uint batch_offset; uint offset_delta;
    // Number of input channels; the z index is split into batch = z / IC and ic = z % IC.
    uint IC;
    // Input width / height; source taps outside [0, IW) x [0, IH) are written as 0.
    uint IW; uint IH;
    // Output width / height (OH bounds the row loop, OW the fastest-varying linear index).
    uint OW; uint OH;
    // Kernel width / height.
    uint KW; uint KH;
    // Number of linear (kx, ky, ix) indices per (row, slice); presumably OW * KW * KH — set by host, confirm.
    uint pelements;
    // Destination element stride between consecutive output columns; presumably IC * KH * KW — confirm.
    uint CHW;
    // Stride (x, y): multiplies the output column / output row in the source coordinate.
    int s0; int s1;
    // Padding (x, y): subtracted from the source coordinate.
    int p0; int p1;
    // Dilation (x, y): multiplies the kernel column / kernel row in the source coordinate.
    int d0; int d1;
    // Exclusive upper bound for the z index; presumably batch count * IC.
    uint batch_IC;
} p;
 24
// Workgroup width along x; specialization constant 0 (default 32). The same
// constant id drives local_size_x below.
layout(constant_id = 0) const uint BLOCK_SIZE = 32;

// Linear indices handled per invocation, so one workgroup covers 512 indices
// along x when BLOCK_SIZE divides 512.
// NOTE(review): a BLOCK_SIZE larger than 512 would make this 0, leaving the
// NUM_ITER-sized local arrays in im2col invalid — presumably the host never
// specializes BLOCK_SIZE that large; confirm.
const uint NUM_ITER = 512 / BLOCK_SIZE;

// x size comes from BLOCK_SIZE; y and z are fixed at 1, which main() relies on
// when it strides its row/slice loops by gl_NumWorkGroups.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Source data (binding 0) and destination data (binding 1; written directly
// only when BDA is disabled).
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

#if BDA
// Single-element buffer reference used to store through the raw destination
// address passed in the push constants when BDA is enabled.
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
 37
// Gather one im2col slice for a single output row and (batch, input-channel) pair.
//
// The work for one (y, z) pair is a linear index space ordered as
//     linear_idx = kx * (OW * KH) + ky * OW + ix
// i.e. output column `ix` varies fastest, then kernel row `ky`, then kernel
// column `kx`; indices at or beyond p.pelements are skipped. Each invocation
// along x handles NUM_ITER consecutive indices starting at gidx * NUM_ITER.
//
// For each index the source pixel is
//     iiw = ix * s0 + kx * d0 - p0,   iih = oh * s1 + ky * d1 - p1
// and its value (or 0 when the pixel lies outside the IW x IH input) is stored at
//     dst[((batch * OH + oh) * OW + ix) * CHW + ic * KW * KH + ky * KW + kx].
//
// y: output row index (oh).
// z: flattened slice index; batch = z / IC, ic = z % IC.
void im2col(const uint y, const uint z) {
    const uint gidx = gl_GlobalInvocationID.x;

    const uint oh = y;
    const uint batch = z / p.IC;
    const uint ic = z % p.IC;

    // Source start of this (batch, ic) plane, and destination start of this
    // (batch, oh) output row offset to channel ic's KW * KH block.
    const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
    const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
    const int oh_s1 = int(oh) * p.s1;
    // Number of linear indices per kernel column kx (a full sweep of ky and ix).
    const uint ksize = p.OW * p.KH;

    const uint base_linear_idx = gidx * NUM_ITER;

    // Decompose the first index once; the loop below advances (ix, ky, kx)
    // incrementally instead of dividing for every element.
    uint current_kx = base_linear_idx / ksize;
    const uint rem = base_linear_idx - (current_kx * ksize);
    uint current_ky = rem / p.OW;
    uint current_ix = rem % p.OW;

    // Zero-init so taps that fall in the padding store 0. Entries of
    // offset_dst for skipped slots are left unset; the store pass skips the
    // same slots, so they are never read.
    A_TYPE values[NUM_ITER];
    BDA_OFFSET_T offset_dst[NUM_ITER];
    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
        values[idx] = A_TYPE(0);
    }

    // Pass 1: compute destination offsets and load source values.
    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {

        const uint linear_idx = base_linear_idx + idx;

        if (linear_idx >= p.pelements) {
            continue;
        }

        // Evaluated as uint (the int operands are implicitly converted): a
        // negative coordinate wraps to a large value, so the single unsigned
        // `< IW` / `< IH` compare below rejects both sides of the padding.
        const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
        const uint iih = oh_s1 + current_ky * p.d1 - p.p1;

        offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;

        if ((iih < p.IH) && (iiw < p.IW)) {
            values[idx] = data_a[src_base + iih * p.IW + iiw];
        }

        // Advance (ix, ky, kx) like an odometer; must stay after the uses above.
        if (++current_ix == p.OW) {
            current_ix = 0;
            if (++current_ky == p.KH) {
                current_ky = 0;
                current_kx++;
            }
        }
    }

    // Pass 2: stores are issued only after all loads — presumably so the loads
    // can be in flight together; keep the two-pass structure.
    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {

        const uint linear_idx = base_linear_idx + idx;

        if (linear_idx >= p.pelements) {
            continue;
        }

        // BDA path writes through a buffer reference at p.dst_addr scaled by
        // D_SIZE (defined outside this view; presumably the byte size of D_TYPE).
#if BDA
        D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
        dst_addr.d = D_TYPE(values[idx]);
#else
        data_d[offset_dst[idx]] = D_TYPE(values[idx]);
#endif
    }
}
105
// Entry point. Each invocation walks output rows (y) and flattened
// batch*channel slices (z), starting at its own global y / z invocation id and
// stepping by gl_NumWorkGroups.y / .z. Because local_size_y and local_size_z
// are 1, those counts equal the number of invocations along each axis. Every
// row below p.OH and every slice below p.batch_IC is therefore handled by
// exactly one y / z index, even when fewer workgroups than rows/slices are
// dispatched. The x axis is not looped here; im2col derives its range from
// gl_GlobalInvocationID.x.
void main() {
    for (uint row = gl_GlobalInvocationID.y; row < p.OH; row += gl_NumWorkGroups.y) {
        for (uint slice = gl_GlobalInvocationID.z; slice < p.batch_IC; slice += gl_NumWorkGroups.z) {
            im2col(row, slice);
        }
    }
}