1#include "im2col.cuh"
  2
  3#define MAX_GRIDDIM_Z 65535
  4
  5template <typename T>
  6static  __global__ void im2col_kernel(
  7        const float * x, T * dst,
  8        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
  9        int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
 10        int s0, int s1, int p0, int p1, int d0, int d1) {
 11    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
 12    if (i >= IC_KH_KW) {
 13        return;
 14    }
 15
 16    const int64_t iic = i / (KH_KW);
 17    const int64_t rem = i - iic * KH_KW;
 18    const int64_t ikh = rem / KW;
 19    const int64_t ikw = rem - ikh * KW;
 20
 21    const int64_t  iow = blockIdx.y;
 22    for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
 23        const int64_t  in = iz / OH;
 24        const int64_t  ioh = iz - in * OH;
 25
 26        const int64_t iiw = iow * s0 + ikw * d0 - p0;
 27        const int64_t iih = ioh * s1 + ikh * d1 - p1;
 28
 29        const int64_t offset_dst =
 30            ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
 31
 32        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
 33            dst[offset_dst] = 0.0f;
 34        } else {
 35            const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
 36            dst[offset_dst] = x[offset_src + iih * IW + iiw];
 37        }
 38    }
 39
 40    GGML_UNUSED(IC);
 41    GGML_UNUSED(KH);
 42}
 43
 44// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 45template <typename T>
 46static void im2col_cuda(const float * x, T* dst,
 47    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
 48    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
 49    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
 50    const int64_t IC_KH_KW = IC * KH * KW;
 51    const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
 52    const int64_t N_OH = N * OH;
 53    const int64_t KH_KW = KW*KH;
 54    dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
 55    im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
 56                                                                                     IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
 57                                                                                     s0, s1, p0, p1, d0, d1);
 58}
 59
 60static void im2col_cuda_f16(const float * x, half * dst,
 61    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
 62    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
 63    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
 64
 65    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 66}
 67
 68static void im2col_cuda_f32(const float * x, float * dst,
 69    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
 70    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
 71    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
 72
 73    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 74}
 75
 76void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 77    const ggml_tensor * src0 = dst->src[0];
 78    const ggml_tensor * src1 = dst->src[1];
 79    const float * src1_d = (const float *)src1->data;
 80    float * dst_d = (float *)dst->data;
 81    cudaStream_t stream = ctx.stream();
 82
 83    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 84    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
 85
 86    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
 87    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
 88    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
 89    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
 90    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
 91    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
 92
 93    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
 94
 95    const int64_t IC = src1->ne[is_2D ? 2 : 1];
 96    const int64_t IH = is_2D ? src1->ne[1] : 1;
 97    const int64_t IW =         src1->ne[0];
 98
 99    const int64_t KH = is_2D ? src0->ne[1] : 1;
100    const int64_t KW =         src0->ne[0];
101
102    const int64_t OH = is_2D ? dst->ne[2] : 1;
103    const int64_t OW =         dst->ne[1];
104
105    const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
106    const int64_t N        = src1->ne[is_2D ? 3 : 2];
107    const int64_t IH_IW    = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
108
109    if(dst->type == GGML_TYPE_F16) {
110        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
111    } else {
112        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
113    }
114}
115
// im2col for 3D convolutions: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
//
// Launch layout (matches im2col_3d_cuda below):
//   x: one thread per patch element i = (iic, ikd, ikh, ikw), IC*KD*KH*KW threads in total,
//   y: output column iow, looped with stride gridDim.y,
//   z: flattened (in, iod, ioh) triple, looped with stride gridDim.z.
// gridDim.y and gridDim.z are limited to 65535 by the hardware, hence the strided loops.
// Input reads go through the element strides stride_x/y/z/q rather than assuming a packed
// layout. Taps that fall into the padding are written as zero. Several size arguments are
// unused by the kernel (see GGML_UNUSED below).
template <typename T>
static  __global__ void im2col_3d_kernel(
        const float * src, T * dst,
        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
        int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
        int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
        int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
        int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
        int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
        int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise an unsigned 32-bit product.
    const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= IC_KD_KH_KW) {
        return;
    }
    GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(KD); GGML_UNUSED(KH);
    GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW);

    // Split the flat patch index into (input channel, kernel depth, kernel row, kernel column).
    const int64_t iic = i / KD_KH_KW;
    const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
    const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
    const int64_t ikw = i % KW;

    for (int64_t iow = blockIdx.y; iow < OW; iow += gridDim.y) {
        const int64_t iiw = iow * s0 + ikw * d0 - p0;

        for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += gridDim.z) {
            const int64_t in  = iz / OD_OH;
            const int64_t iod = (iz - in*OD_OH) / OH;
            const int64_t ioh = iz % OH;

            const int64_t iih = ioh * s1 + ikh * d1 - p1;
            const int64_t iid = iod * s2 + ikd * d2 - p2;

            const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;

            if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
                dst[offset_dst] = 0.0f;
            } else {
                const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
                dst[offset_dst] = src[offset_src];
            }
        }
    }
}

// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
// Host-side launcher: precomputes the size products the kernel indexes with, derives the
// launch grid and enqueues im2col_3d_kernel on `stream` (no synchronization here).
template <typename T>
static void im2col_3d_cuda(const float * src, T* dst,
    int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
    int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
    int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
    int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
    // Hardware limit of gridDim.y; the kernel strides over output columns beyond it.
    constexpr int64_t max_griddim_y = 65535;

    const int64_t OH_OW = OH*OW;
    const int64_t KD_KH_KW = KD*KH*KW;
    const int64_t ID_IH_IW = ID*IH*IW;
    const int64_t KH_KW = KH*KW;
    const int64_t IH_IW = IH*IW;
    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
    const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
    const int64_t N_OD_OH = N*OD*OH;
    const int64_t OD_OH = OD*OH;
    const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
    const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
    const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
    const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
    const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;

    dim3 block_nums(num_blocks, MIN(OW, max_griddim_y), MIN(N_OD_OH, MAX_GRIDDIM_Z));
    im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
                                                                                          OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
                                                                                          IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
                                                                                          OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
                                                                                          stride_q, stride_z, stride_y, stride_x,
                                                                                          s0, s1, s2, p0, p1, p2, d0, d1, d2);
}
189
// F16-output entry point for 3D im2col: forwards every argument unchanged to
// im2col_3d_cuda, instantiated with T = half (deduced from dst).
static void im2col_3d_cuda_f16(const float * src, half * dst,
    int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
    int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
    int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
    int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
    im2col_3d_cuda(src, dst,
                   N, IC, ID, IH, IW, OC,
                   KD, KH, KW,
                   OD, OH, OW,
                   stride_q, stride_z, stride_y, stride_x,
                   s0, s1, s2,
                   p0, p1, p2,
                   d0, d1, d2,
                   stream);
}
200
// F32-output entry point for 3D im2col: forwards every argument unchanged to
// im2col_3d_cuda, instantiated with T = float (deduced from dst).
static void im2col_3d_cuda_f32(const float * src, float * dst,
    int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
    int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
    int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
    int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
    im2col_3d_cuda(src, dst,
                   N, IC, ID, IH, IW, OC,
                   KD, KH, KW,
                   OD, OH, OW,
                   stride_q, stride_z, stride_y, stride_x,
                   s0, s1, s2,
                   p0, p1, p2,
                   d0, d1, d2,
                   stream);
}
211
// Backend entry point for 3D im2col.
// src[0] only supplies the kernel extents (and OC); src[1] is the F32 input volume.
// dst may be F16 or F32; the work is enqueued on the context's stream.
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);

    // Provides the ne0x / ne1x / ne* shape locals used below.
    GGML_TENSOR_BINARY_OP_LOCALS

    // op_params layout: stride s0..s2, padding p0..p2, dilation d0..d2, then IC.
    const int32_t * params = (const int32_t *) dst->op_params;
    const int32_t s0 = params[0], s1 = params[1], s2 = params[2];
    const int32_t p0 = params[3], p1 = params[4], p2 = params[5];
    const int32_t d0 = params[6], d1 = params[7], d2 = params[8];
    const int32_t IC = params[9];

    // The input fuses batch and channels in ne13 (N*IC); dst fuses batch and depth in ne3 (N*OD).
    const int64_t N  = ne13 / IC;
    const int64_t IW = ne10, IH = ne11, ID = ne12;

    const int64_t KW = ne00, KH = ne01, KD = ne02;
    const int64_t OC = ne03 / IC;

    const int64_t OW = ne1, OH = ne2;
    const int64_t OD = ne3 / N;

    // Byte strides of src1 converted to element strides.
    const size_t  es       = ggml_element_size(src1);
    const int64_t stride_x = src1->nb[0] / es;
    const int64_t stride_y = src1->nb[1] / es;
    const int64_t stride_z = src1->nb[2] / es;
    const int64_t stride_q = src1->nb[3] / es;

    const float * src1_d = (const float *) src1->data;

    if (dst->type == GGML_TYPE_F16) {
        im2col_3d_cuda_f16(src1_d, (half *) dst->data, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
                           stride_q, stride_z, stride_y, stride_x,
                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
    } else {
        im2col_3d_cuda_f32(src1_d, (float *) dst->data, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
                           stride_q, stride_z, stride_y, stride_x,
                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
    }
}