1#include "upscale.cuh"
  2
  3static __global__ void upscale_f32(const float * x, float * dst,
  4        const int nb00, const int nb01, const int nb02, const int nb03,
  5        const int ne10, const int ne11, const int ne12, const int ne13,
  6        const float sf0, const float sf1, const float sf2, const float sf3) {
  7    int index = threadIdx.x + blockIdx.x * blockDim.x;
  8    if (index >= ne10 * ne11 * ne12 * ne13) {
  9        return;
 10    }
 11
 12    int i10 = index % ne10;
 13    int i11 = (index / ne10) % ne11;
 14    int i12 = (index / (ne10 * ne11)) % ne12;
 15    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
 16
 17    int i00 = i10 / sf0;
 18    int i01 = i11 / sf1;
 19    int i02 = i12 / sf2;
 20    int i03 = i13 / sf3;
 21
 22    dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
 23}
 24
 25static __global__ void upscale_f32_bilinear(const float * x, float * dst,
 26        const int nb00, const int nb01, const int nb02, const int nb03,
 27        const int ne00_src, const int ne01_src,
 28        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
 29        const float sf0, const float sf1, const float sf2, const float sf3,
 30        const float pixel_offset) {
 31    const int64_t index              = threadIdx.x + blockIdx.x * blockDim.x;
 32    const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
 33
 34    if (index >= dst_total_elements) {
 35        return;
 36    }
 37
 38    const int i10_dst = index % ne10_dst;
 39    const int i11_dst = (index / ne10_dst) % ne11_dst;
 40    const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
 41    const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
 42
 43    const int i02_src = (int)(i12_dst / sf2);
 44    const int i03_src = (int)(i13_dst / sf3);
 45
 46    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
 47    int y0_src    = (int)floorf(y_src_f);
 48    int y1_src    = y0_src + 1;
 49
 50    y0_src = max(0, min(y0_src, ne01_src - 1));
 51    y1_src = max(0, min(y1_src, ne01_src - 1));
 52
 53    float dy = y_src_f - (float)y0_src;
 54    dy       = max(0.0f, min(dy, 1.0f));
 55
 56    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
 57    int x0_src    = (int)floorf(x_src_f);
 58    int x1_src    = x0_src + 1;
 59
 60    x0_src = max(0, min(x0_src, ne00_src - 1));
 61    x1_src = max(0, min(x1_src, ne00_src - 1));
 62
 63    float dx = x_src_f - (float)x0_src;
 64    dx = max(0.0f, min(dx, 1.0f));
 65
 66    const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
 67    const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
 68    const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
 69    const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
 70
 71    const float val_a = *p_a;
 72    const float val_b = *p_b;
 73    const float val_c = *p_c;
 74    const float val_d = *p_d;
 75
 76    float result = val_a * (1.0f - dx) * (1.0f - dy) +
 77                   val_b * dx * (1.0f - dy) +
 78                   val_c * (1.0f - dx) * dy +
 79                   val_d * dx * dy;
 80
 81    dst[index] = result;
 82}
 83
 84// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
 85// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
 86static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,
 87        const int nb00, const int nb01, const int nb02, const int nb03,
 88        const int ne00_src, const int ne01_src,
 89        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
 90        const float sf0, const float sf1, const float sf2, const float sf3,
 91        const float pixel_offset) {
 92    const int64_t index              = threadIdx.x + blockIdx.x * blockDim.x;
 93    const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
 94
 95    if (index >= dst_total_elements) {
 96        return;
 97    }
 98
 99    const int i10_dst = index % ne10_dst;
100    const int i11_dst = (index / ne10_dst) % ne11_dst;
101    const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
102    const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
103
104    const int i02_src = (int)(i12_dst / sf2);
105    const int i03_src = (int)(i13_dst / sf3);
106
107    const float y = ((float)i11_dst + pixel_offset) / sf1;
108    const float x = ((float)i10_dst + pixel_offset) / sf0;
109
110    // support and invscale, minimum 1 pixel for bilinear
111    const float support1  = max(1.0f / sf1, 1.0f);
112    const float invscale1 = 1.0f / support1;
113    const float support0  = max(1.0f / sf0, 1.0f);
114    const float invscale0 = 1.0f / support0;
115
116    // the range of source pixels that contribute
117    const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset));
118    const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));
119    const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset));
120    const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));
121
122    // bilinear filter with antialiasing
123    float val = 0.0f;
124    float total_weight = 0.0f;
125
126    auto triangle_filter = [](float x) -> float {
127        return max(1.0f - fabsf(x), 0.0f);
128    };
129
130    for (int64_t sy = y_min; sy < y_max; sy++) {
131        const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
132
133        for (int64_t sx = x_min; sx < x_max; sx++) {
134            const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
135            const float weight = weight_x * weight_y;
136
137            if (weight <= 0.0f) {
138                continue;
139            }
140
141            const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03);
142            val += pixel * weight;
143            total_weight += weight;
144        }
145    }
146
147    if (total_weight > 0.0f) {
148        val /= total_weight;
149    }
150
151    dst[index] = val;
152}
153
namespace bicubic_interpolation {
// Cubic convolution kernel coefficients, see:
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)

// weight1: kernel value for taps at distance |x| <= 1 from the sample point
static __device__ float weight1(float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; }
// weight2: kernel value for taps at distance 1 < |x| < 2 from the sample point
static __device__ float weight2(float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; }

// 1D cubic interpolation of four consecutive samples p0..p3 at fractional
// position x in [0,1], measured from p1 (p0 and p3 are the outer support taps).
static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {
    const float w0 = weight2(x + 1);
    const float w1 = weight1(x + 0);
    const float w2 = weight1(1 - x);
    const float w3 = weight2(2 - x);
    return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
}
} // namespace bicubic_interpolation
169
// Bicubic upscale: one thread per destination element (1D launch, see
// upscale_f32_bicubic_cuda). Interpolates a 4x4 source neighborhood with the
// cubic convolution kernel from the bicubic_interpolation namespace; source
// coordinates are clamped, which replicates edge pixels at the borders.
static __global__ void upscale_f32_bicubic(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset) {
    using bicubic_interpolation::bicubic;

    // compute the thread index and the bounds product in 64 bits — the 32-bit
    // products would overflow before being widened for assignment
    const int64_t index              = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;
    const int64_t dst_total_elements = (int64_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;

    if (index >= dst_total_elements) {
        return;
    }

    // decompose the flat index into destination coordinates (i10 fastest)
    const int i10_dst = index % ne10_dst;
    const int i11_dst = (index / ne10_dst) % ne11_dst;
    const int i12_dst = (index / ((int64_t)ne10_dst * ne11_dst)) % ne12_dst;
    const int i13_dst = index / ((int64_t)ne10_dst * ne11_dst * ne12_dst);

    // dims 2 and 3 are not interpolated — nearest (truncating) mapping only
    const int i02_src = (int)(i12_dst / sf2);
    const int i03_src = (int)(i13_dst / sf3);

    // vertical source position: integer base row + fractional part dy
    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
    const int y0_src    = (int)floorf(y_src_f);
    const float dy      = y_src_f - (float)y0_src;

    // horizontal source position: integer base column + fractional part dx
    const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
    const int x0_src    = (int)floorf(x_src_f);
    const float dx      = x_src_f - (float)x0_src;

    // base address of the (i02_src, i03_src) plane (byte-stride addressing)
    const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;

    // fetch one source sample at an offset from (x0_src, y0_src), clamped to
    // the valid range (edge replication)
    auto load = [=](int x_off, int y_off) -> float {
        int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));
        int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));
        return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
    };

    // separable bicubic: interpolate each of the four rows horizontally,
    // then interpolate the four row results vertically
    const float result = bicubic(
        bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),
        bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),
        bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),
        bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);

    dst[index] = result;
}
217
// Launches the nearest-neighbor upscale kernel on `stream`.
// nb0x: source byte strides, ne1x: destination extents, sfx: dst/src scale factors.
static void upscale_f32_cuda(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne10, const int ne11, const int ne12, const int ne13,
        const float sf0, const float sf1, const float sf2, const float sf3,
        cudaStream_t stream) {
    // widen before multiplying: the 32-bit product of the four extents can
    // overflow before being assigned to the int64_t
    const int64_t dst_size   = (int64_t)ne10 * ne11 * ne12 * ne13;
    // ceil-div: one thread per destination element
    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

    upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
}
228
// Launches the bilinear upscale kernel on `stream`, selecting the antialiased
// variant when `antialias` is set. nb0x: source byte strides; ne0x_src: source
// width/height; ne1x_dst: destination extents; sfx: dst/src scale factors;
// pixel_offset: 0.5 for half-pixel centers, 0.0 for align-corners.
static void upscale_f32_bilinear_cuda(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset, bool antialias, cudaStream_t stream) {
    // widen before multiplying: the 32-bit product of the four extents can
    // overflow before being assigned to the int64_t
    const int64_t dst_size   = (int64_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;
    // ceil-div: one thread per destination element
    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

    if (antialias) {
        upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
    } else {
        upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
    }
}
244
// Launches the bicubic upscale kernel on `stream`. nb0x: source byte strides;
// ne0x_src: source width/height; ne1x_dst: destination extents; sfx: dst/src
// scale factors; pixel_offset: 0.5 for half-pixel centers, 0.0 for align-corners.
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset, cudaStream_t stream) {
    // widen before multiplying: the 32-bit product of the four extents can
    // overflow before being assigned to the int64_t
    const int64_t dst_size   = (int64_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;
    // ceil-div: one thread per destination element
    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

    upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
}
256
// Dispatch entry point for GGML_OP_UPSCALE: resamples src0 (F32) into dst (F32)
// on the context's CUDA stream, using the scale mode packed into op_params[0].
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    // op_params[0] packs the scale mode in the low byte and flag bits above it
    const int mode_flags = dst->op_params[0];
    const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);

    // per-dimension scale factors: destination extent / source extent
    float sf0 = (float)dst->ne[0]/src0->ne[0];
    float sf1 = (float)dst->ne[1]/src0->ne[1];
    float sf2 = (float)dst->ne[2]/src0->ne[2];
    const float sf3 = (float)dst->ne[3]/src0->ne[3];

    // default is half-pixel-center sampling; align-corners maps the corner
    // pixels of src and dst onto each other, which changes the effective
    // scale factor for dims 0 and 1 to (N_dst - 1) / (N_src - 1)
    float pixel_offset = 0.5f;
    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
        sf0          = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
        sf1          = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
        pixel_offset = 0.0f;
    }

    // NOTE(review): an unrecognized mode silently leaves dst untouched here —
    // presumably the graph validator rejects such modes earlier; confirm.
    if (mode == GGML_SCALE_MODE_NEAREST) {
        upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
        // antialias flag selects the PyTorch-style antialiased bilinear kernel
        const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
        upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                                 src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                                 sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
    } else if (mode == GGML_SCALE_MODE_BICUBIC) {
        upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                                 src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                                 sf0, sf1, sf2, sf3, pixel_offset, stream);
    }
}