1#include "upscale.cuh"
2
// Nearest-neighbor upscale: one thread per destination element.
// nb0* are the source byte strides, ne1* the destination extents and
// sf* the per-dimension scale factors (dst_extent / src_extent).
static __global__ void upscale_f32(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne10, const int ne11, const int ne12, const int ne13,
        const float sf0, const float sf1, const float sf2, const float sf3) {
    // 64-bit indexing (consistent with the other upscale kernels): both the
    // flat index and the extent product can exceed 2^31-1 for large tensors,
    // and blockIdx.x * blockDim.x alone would wrap in 32-bit unsigned math
    const int64_t index = threadIdx.x + blockIdx.x * (int64_t)blockDim.x;
    const int64_t dst_total_elements = (int64_t)ne10 * ne11 * ne12 * ne13;

    if (index >= dst_total_elements) {
        return;
    }

    // decompose the flat destination index (dim 0 fastest)
    const int i10 = (int)( index % ne10);
    const int i11 = (int)((index / ne10) % ne11);
    const int i12 = (int)((index / ((int64_t)ne10 * ne11)) % ne12);
    const int i13 = (int)( index / ((int64_t)ne10 * ne11 * ne12));

    // nearest source coordinate: truncating division by the scale factor
    const int i00 = i10 / sf0;
    const int i01 = i11 / sf1;
    const int i02 = i12 / sf2;
    const int i03 = i13 / sf3;

    // widen each stride product to 64-bit before summing byte offsets
    dst[index] = *( (const float *)((const char *)x + (int64_t)i03 * nb03 + (int64_t)i02 * nb02 + (int64_t)i01 * nb01 + (int64_t)i00 * nb00) );
}
24
// Bilinear upscale: one thread per destination element. Source reads are
// clamped to the tensor edges; align-corners behavior is controlled by the
// caller through sf* and pixel_offset (0.5f for half-pixel centers, 0.0f for
// align_corners).
static __global__ void upscale_f32_bilinear(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset) {
    // blockIdx.x * blockDim.x must be widened before it can wrap in 32 bits
    const int64_t index = threadIdx.x + blockIdx.x * (int64_t)blockDim.x;
    // promote before multiplying: the product of four int extents can
    // overflow 32-bit arithmetic even though the result fits in int64_t
    const int64_t dst_total_elements = (int64_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;

    if (index >= dst_total_elements) {
        return;
    }

    // decompose the flat destination index (dim 0 fastest); the composite
    // divisors are widened for the same overflow reason as above
    const int i10_dst = (int)( index % ne10_dst);
    const int i11_dst = (int)((index / ne10_dst) % ne11_dst);
    const int i12_dst = (int)((index / ((int64_t)ne10_dst * ne11_dst)) % ne12_dst);
    const int i13_dst = (int)( index / ((int64_t)ne10_dst * ne11_dst * ne12_dst));

    // dims 2/3 are not interpolated, only rescaled (nearest)
    const int i02_src = (int)(i12_dst / sf2);
    const int i03_src = (int)(i13_dst / sf3);

    // vertical sample position and the two contributing source rows
    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
    int y0_src = (int)floorf(y_src_f);
    int y1_src = y0_src + 1;

    // clamp to valid rows (replicates the border row)
    y0_src = max(0, min(y0_src, ne01_src - 1));
    y1_src = max(0, min(y1_src, ne01_src - 1));

    float dy = y_src_f - (float)y0_src;
    dy = max(0.0f, min(dy, 1.0f));

    // horizontal sample position and the two contributing source columns
    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
    int x0_src = (int)floorf(x_src_f);
    int x1_src = x0_src + 1;

    x0_src = max(0, min(x0_src, ne00_src - 1));
    x1_src = max(0, min(x1_src, ne00_src - 1));

    float dx = x_src_f - (float)x0_src;
    dx = max(0.0f, min(dx, 1.0f));

    // the four neighboring source texels (a = top-left, d = bottom-right)
    const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
    const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
    const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
    const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);

    const float val_a = *p_a;
    const float val_b = *p_b;
    const float val_c = *p_c;
    const float val_d = *p_d;

    // standard bilinear blend of the four samples
    float result = val_a * (1.0f - dx) * (1.0f - dy) +
                   val_b * dx * (1.0f - dy) +
                   val_c * (1.0f - dx) * dy +
                   val_d * dx * dy;

    dst[index] = result;
}
83
84// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
85// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
// One thread per destination element; each thread accumulates a weighted sum
// over the window of source pixels covered by the (widened) triangle filter.
static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset) {
    // widen before multiplying so neither the thread index nor the extent
    // product wraps in 32-bit arithmetic
    const int64_t index = threadIdx.x + blockIdx.x * (int64_t)blockDim.x;
    const int64_t dst_total_elements = (int64_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;

    if (index >= dst_total_elements) {
        return;
    }

    // decompose the flat destination index (dim 0 fastest)
    const int i10_dst = (int)( index % ne10_dst);
    const int i11_dst = (int)((index / ne10_dst) % ne11_dst);
    const int i12_dst = (int)((index / ((int64_t)ne10_dst * ne11_dst)) % ne12_dst);
    const int i13_dst = (int)( index / ((int64_t)ne10_dst * ne11_dst * ne12_dst));

    // dims 2/3 are not filtered, only rescaled (nearest)
    const int i02_src = (int)(i12_dst / sf2);
    const int i03_src = (int)(i13_dst / sf3);

    // center of the destination pixel mapped into source coordinates
    const float y = ((float)i11_dst + pixel_offset) / sf1;
    const float x = ((float)i10_dst + pixel_offset) / sf0;

    // support and invscale, minimum 1 pixel for bilinear; when downscaling
    // (sf < 1) the filter is widened to average over the covered source area
    const float support1 = max(1.0f / sf1, 1.0f);
    const float invscale1 = 1.0f / support1;
    const float support0 = max(1.0f / sf0, 1.0f);
    const float invscale0 = 1.0f / support0;

    // the range of source pixels that contribute
    const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset));
    const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));
    const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset));
    const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));

    // bilinear filter with antialiasing
    float val = 0.0f;
    float total_weight = 0.0f;

    auto triangle_filter = [](float x) -> float {
        return max(1.0f - fabsf(x), 0.0f);
    };

    for (int64_t sy = y_min; sy < y_max; sy++) {
        const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);

        for (int64_t sx = x_min; sx < x_max; sx++) {
            const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
            const float weight = weight_x * weight_y;

            if (weight <= 0.0f) {
                continue;
            }

            // widen the int*int stride products for dims 2/3 to 64-bit, same
            // as the other kernels (sx/sy are already int64_t)
            const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + (int64_t)i02_src*nb02 + (int64_t)i03_src*nb03);
            val += pixel * weight;
            total_weight += weight;
        }
    }

    // normalize (the clipped window near the borders does not sum to 1)
    if (total_weight > 0.0f) {
        val /= total_weight;
    }

    dst[index] = val;
}
153
namespace bicubic_interpolation {
// Cubic convolution kernel (Keys), alpha = -0.75 to match PyTorch.
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
__device__ const float a = -0.75f;

// kernel weight for offsets with |t| <= 1
static __device__ float weight1(float t) {
    return ((a + 2) * t - (a + 3)) * t * t + 1;
}

// kernel weight for offsets with 1 < |t| <= 2
static __device__ float weight2(float t) {
    return ((a * t - 5 * a) * t + 8 * a) * t - 4 * a;
}

// interpolate four consecutive samples p0..p3 at fractional position x,
// where x is measured from p1 (0 <= x < 1)
static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {
    const float w0 = weight2(x + 1);
    const float w1 = weight1(x + 0);
    const float w2 = weight1(1 - x);
    const float w3 = weight2(2 - x);
    return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
}
} // namespace bicubic_interpolation
169
// Bicubic upscale: one thread per destination element; samples a clamped
// 4x4 source neighborhood and applies the separable cubic convolution from
// the bicubic_interpolation namespace.
static __global__ void upscale_f32_bicubic(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset) {
    using bicubic_interpolation::bicubic;

    // widen before multiplying so neither the thread index nor the extent
    // product wraps in 32-bit arithmetic
    const int64_t index = threadIdx.x + blockIdx.x * (int64_t)blockDim.x;
    const int64_t dst_total_elements = (int64_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;

    if (index >= dst_total_elements) {
        return;
    }

    // decompose the flat destination index (dim 0 fastest)
    const int i10_dst = (int)( index % ne10_dst);
    const int i11_dst = (int)((index / ne10_dst) % ne11_dst);
    const int i12_dst = (int)((index / ((int64_t)ne10_dst * ne11_dst)) % ne12_dst);
    const int i13_dst = (int)( index / ((int64_t)ne10_dst * ne11_dst * ne12_dst));

    // dims 2/3 are not interpolated, only rescaled (nearest)
    const int i02_src = (int)(i12_dst / sf2);
    const int i03_src = (int)(i13_dst / sf3);

    // source position and fractional offset along dim 1
    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
    const int y0_src = (int)floorf(y_src_f);
    const float dy = y_src_f - (float)y0_src;

    // source position and fractional offset along dim 0
    const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
    const int x0_src = (int)floorf(x_src_f);
    const float dx = x_src_f - (float)x0_src;

    const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;

    // fetch one source sample, clamping coordinates to the tensor edges
    auto load = [=](int x_off, int y_off) -> float {
        int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));
        int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));
        return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
    };

    // interpolate along x for each of the four rows, then along y
    const float result = bicubic(
        bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),
        bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),
        bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),
        bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);

    dst[index] = result;
}
217
// Launch the nearest-neighbor upscale kernel on `stream`
// (one thread per destination element).
static void upscale_f32_cuda(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne10, const int ne11, const int ne12, const int ne13,
        const float sf0, const float sf1, const float sf2, const float sf3,
        cudaStream_t stream) {
    // cast before multiplying: a product of four int extents can overflow
    // 32-bit arithmetic before it is widened to int64_t
    const int64_t dst_size   = (int64_t)ne10 * ne11 * ne12 * ne13;
    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

    upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
}
228
// Launch the bilinear upscale kernel on `stream`, selecting the
// antialiasing variant when requested.
static void upscale_f32_bilinear_cuda(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset, bool antialias, cudaStream_t stream) {
    // cast before multiplying: a product of four int extents can overflow
    // 32-bit arithmetic before it is widened to int64_t
    const int64_t dst_size   = (int64_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;
    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

    if (antialias) {
        upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
    } else {
        upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
    }
}
244
// Launch the bicubic upscale kernel on `stream`
// (one thread per destination element).
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset, cudaStream_t stream) {
    // cast before multiplying: a product of four int extents can overflow
    // 32-bit arithmetic before it is widened to int64_t
    const int64_t dst_size   = (int64_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;
    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

    upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
}
256
// Entry point for the UPSCALE op: decode the scale mode and flags from
// op_params, derive the per-dimension scale factors, and dispatch to the
// matching kernel launcher (nearest / bilinear[+antialias] / bicubic).
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const float * src0_d = (const float *) src0->data;
    float       * dst_d  = (float *) dst->data;
    cudaStream_t  stream = ctx.stream();

    // low byte holds the scale mode, upper bits hold option flags
    const int             mode_flags = dst->op_params[0];
    const ggml_scale_mode mode       = (ggml_scale_mode)(mode_flags & 0xFF);

    // scale factors: ratio of destination to source extent per dimension
    float sf0 = (float)dst->ne[0]/src0->ne[0];
    float sf1 = (float)dst->ne[1]/src0->ne[1];
    float sf2 = (float)dst->ne[2]/src0->ne[2];
    const float sf3 = (float)dst->ne[3]/src0->ne[3];

    float pixel_offset = 0.5f;
    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
        // align-corners: corner samples map exactly onto each other, so the
        // spatial scale factors use (extent - 1) and no half-pixel offset
        if (dst->ne[0] > 1 && src0->ne[0] > 1) {
            sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
        }
        if (dst->ne[1] > 1 && src0->ne[1] > 1) {
            sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
        }
        pixel_offset = 0.0f;
    }

    switch (mode) {
        case GGML_SCALE_MODE_NEAREST:
            upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
            break;
        case GGML_SCALE_MODE_BILINEAR: {
            const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0;
            upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                                      src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                                      sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
        } break;
        case GGML_SCALE_MODE_BICUBIC:
            upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                                     src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                                     sf0, sf1, sf2, sf3, pixel_offset, stream);
            break;
        default:
            // unknown mode: no-op, same as the original if/else chain
            break;
    }
}