// Pads a 4-D float tensor with zeros.
//
// Each work-item writes exactly one element of dst. Destination element
// (i0, i1, i2, i3) receives source element (i0-lp0, i1-lp1, i2-lp2, i3-lp3)
// when it lies inside the unpadded region, and 0.0f otherwise.
//
// Parameters:
//   src0, offset0 : source buffer and byte offset of the tensor data within it
//   dst,  offsetd : destination buffer and byte offset of the tensor data within it
//   ne00..ne03    : source extents per dimension (not read by this kernel)
//   nb00..nb03    : source strides in bytes per dimension
//   ne0..ne3      : destination extents per dimension
//   nb0..nb3      : destination strides in bytes per dimension
//   lpX, rpX      : number of padding elements before / after the source data
//                   along dimension X
//
// Work-item mapping: i0 comes from the global id of dimension 0, i1 from the
// group id of dimension 1, and i2/i3 are both packed into the group id of
// dimension 2 (i2 = id % ne2, i3 = id / ne2).
// NOTE(review): using group ids for dimensions 1 and 2 presumably relies on the
// host launching with a local size of 1 in those dimensions - confirm at the
// call site.
kernel void kernel_pad(
        global void * src0,
        ulong offset0,
        global void * dst,
        ulong offsetd,
        int ne00, int ne01, int ne02, int ne03,
        ulong nb00, ulong nb01, ulong nb02, ulong nb03,
        int ne0, int ne1, int ne2, int ne3,
        ulong nb0, ulong nb1, ulong nb2, ulong nb3,
        int lp0, int rp0,
        int lp1, int rp1,
        int lp2, int rp2,
        int lp3, int rp3
) {
    // Byte-addressed base pointers; all strides below are expressed in bytes.
    global const char * src_base = (global const char *)src0 + offset0;
    global char       * dst_base = (global char *)dst + offsetd;

    const int i0 = (int)get_global_id(0);
    const int i1 = (int)get_group_id(1);
    const int i2 = (int)(get_group_id(2) % ne2);
    const int i3 = (int)(get_group_id(2) / ne2);

    // The launch grid may be rounded up past the tensor extents.
    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
        return;
    }

    // Offsets are kept in 64 bits: the strides are ulong, and narrowing the sum
    // to 32 bits would address the wrong element once a buffer exceeds 4 GiB.
    const ulong dst_off = (ulong)i3*nb3 + (ulong)i2*nb2 + (ulong)i1*nb1 + (ulong)i0*nb0;

    const bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) &&
                               (i1 >= lp1 && i1 < ne1 - rp1) &&
                               (i2 >= lp2 && i2 < ne2 - rp2) &&
                               (i3 >= lp3 && i3 < ne3 - rp3);

    float value = 0.0f;
    if (in_src_bounds) {
        // Only form the source address inside the source region, where every
        // (i - lp) term is non-negative and the cast to ulong is exact.
        const ulong src_off = (ulong)(i3 - lp3)*nb03 + (ulong)(i2 - lp2)*nb02 +
                              (ulong)(i1 - lp1)*nb01 + (ulong)(i0 - lp0)*nb00;
        value = *(global const float *)(src_base + src_off);
    }

    *(global float *)(dst_base + dst_off) = value;
}