// Nearest-neighbour upscale of a 4-D float tensor.
//
// One work-item per destination element (1-D NDRange over get_global_id(0)).
// The destination is written as a contiguous float array (dst_base[index]);
// the source is addressed through its byte strides nb00..nb03, so it need not
// be contiguous.
//
// Destination coordinate i1k maps to source coordinate (int)(i1k / sfk).
// sf0..sf3 are therefore per-dimension scale factors. Presumably they are
// dst_extent / src_extent, as computed by the host; confirm against the caller.
//
// Parameters:
//   p_src0, off_src0 : source buffer and byte offset of element (0,0,0,0)
//   p_dst,  off_dst  : destination buffer and byte offset of its first float
//   nb00..nb03       : source byte strides for dims 0..3
//   ne10..ne13       : destination extents for dims 0..3
//   sf0..sf3         : scale factors for dims 0..3
kernel void kernel_upscale(
    global const void * p_src0,
    ulong off_src0,
    global void * p_dst,
    ulong off_dst,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne10,
    int ne11,
    int ne12,
    int ne13,
    float sf0,
    float sf1,
    float sf2,
    float sf3
) {
    global const char * src_base = (global const char *)p_src0 + off_src0;
    global float * dst_base = (global float *)((global char *)p_dst + off_dst);

    // Flat-index arithmetic is done in 64 bits. get_global_id() returns size_t,
    // and the product of the four extents can exceed INT_MAX for large
    // outputs; in 32-bit int that would be signed overflow (undefined behaviour).
    const long index          = (long)get_global_id(0);
    const long ne10_ne11      = (long)ne10 * ne11;
    const long ne10_ne11_ne12 = ne10_ne11 * ne12;
    const long dst_total_elements = ne10_ne11_ne12 * ne13;

    if (index >= dst_total_elements) {
        return;
    }

    // Decompose the flat destination index into (i10, i11, i12, i13).
    // Each coordinate is strictly below its int extent, so narrowing is safe.
    const int i10 = (int)(index % ne10);
    const int i11 = (int)((index / ne10) % ne11);
    const int i12 = (int)((index / ne10_ne11) % ne12);
    const int i13 = (int)(index / ne10_ne11_ne12);

    // Nearest source element. The cast truncates toward zero, which equals
    // floor() for these non-negative coordinates.
    const int i00 = (int)(i10 / sf0);
    const int i01 = (int)(i11 / sf1);
    const int i02 = (int)(i12 / sf2);
    const int i03 = (int)(i13 / sf3);

    const ulong src_offset = (ulong)i03 * nb03 + (ulong)i02 * nb02
                           + (ulong)i01 * nb01 + (ulong)i00 * nb00;

    dst_base[index] = *(global const float *)(src_base + src_offset);
}
 44
 45kernel void kernel_upscale_bilinear(
 46    global const void * p_src0,
 47    ulong off_src0,
 48    global void * p_dst,
 49    ulong off_dst,
 50    ulong nb00,
 51    ulong nb01,
 52    ulong nb02,
 53    ulong nb03,
 54    int ne00_src,
 55    int ne01_src,
 56    int ne10_dst,
 57    int ne11_dst,
 58    int ne12_dst,
 59    int ne13_dst,
 60    float sf0,
 61    float sf1,
 62    float sf2,
 63    float sf3,
 64    float pixel_offset
 65) {
 66    global const char * src_base = (global const char *)p_src0 + off_src0;
 67    global float * dst_base = (global float *)((global char *)p_dst + off_dst);
 68
 69    int index = get_global_id(0);
 70    int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
 71
 72    if (index >= dst_total_elements) {
 73        return;
 74    }
 75
 76    int i10_dst = index % ne10_dst;
 77    int i11_dst = (index / ne10_dst) % ne11_dst;
 78    int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
 79    int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
 80
 81    int i02_src = (int)(i12_dst / sf2);
 82    int i03_src = (int)(i13_dst / sf3);
 83
 84    float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
 85    long y0_src = (long)floor(y_src_f);
 86    long y1_src = y0_src + 1;
 87
 88    y0_src = max(0L, min(y0_src, (long)ne01_src - 1));
 89    y1_src = max(0L, min(y1_src, (long)ne01_src - 1));
 90
 91    float dy = y_src_f - (float)y0_src;
 92    dy = max(0.0f, min(dy, 1.0f));
 93
 94    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
 95    long x0_src = (long)floor(x_src_f);
 96    long x1_src = x0_src + 1;
 97
 98    x0_src = max(0L, min(x0_src, (long)ne00_src - 1));
 99    x1_src = max(0L, min(x1_src, (long)ne00_src - 1));
100
101    float dx = x_src_f - (float)x0_src;
102    dx = max(0.0f, min(dx, 1.0f));
103
104    global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
105    global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
106    global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
107    global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
108
109    const float val_a = *p_a;
110    const float val_b = *p_b;
111    const float val_c = *p_c;
112    const float val_d = *p_d;
113
114    float result = val_a * (1.0f - dx) * (1.0f - dy) +
115                   val_b * dx * (1.0f - dy) +
116                   val_c * (1.0f - dx) * dy +
117                   val_d * dx * dy;
118
119    dst_base[index] = result;
120}