1//
2// MIT license
3// Copyright (C) 2024 Intel Corporation
4// SPDX-License-Identifier: MIT
5//
6
7//
8// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9// See https://llvm.org/LICENSE.txt for license information.
10// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11//
12
13#include "conv.hpp"
14
15static void conv_transpose_1d_kernel(
16 const int s0, const int output_size,
17 const int src0_ne0, const int src0_ne1, const int src0_ne2,
18 const int src1_ne0, const int dst_ne0,
19 const float * src0, const float * src1, float * dst,
20 const sycl::nd_item<3> &item_ct1) {
21 int global_index = item_ct1.get_local_id(2) +
22 item_ct1.get_group(2) * item_ct1.get_local_range(2);
23 if (global_index >= output_size) {
24 return;
25 }
26
27 int out_index = global_index / dst_ne0;
28
29 float accumulator = 0;
30
31 for (int c = 0; c < src0_ne2; c++) {
32 int idx = global_index % dst_ne0;
33
34 int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
35 int input_offset = src1_ne0 * c;
36
37 for (int i = 0; i < src1_ne0; i++) {
38 if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
39 continue;
40 }
41 int weight_idx = idx - i*s0;
42
43 float kernel_weight = src0[kernel_offset + weight_idx];
44 float input_value = src1[input_offset+i];
45
46 accumulator += kernel_weight * input_value;
47 }
48 }
49 dst[global_index] = accumulator;
50}
51
52static void conv_transpose_1d_f32_f32_sycl(
53 const int s0, const int output_size,
54 const int src0_ne0, const int src0_ne1, const int src0_ne2,
55 const int src1_ne0, const int dst_ne0,
56 const float *src0, const float *src1, float *dst,
57 const queue_ptr& stream) {
58
59 const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
60 const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
61 const sycl::range<3> block_nums(1, 1, num_blocks);
62 stream->parallel_for(
63 sycl::nd_range<3>(
64 block_nums * block_dims, block_dims),
65 [=](sycl::nd_item<3> item_ct1) {
66 conv_transpose_1d_kernel(
67 s0, output_size,
68 src0_ne0, src0_ne1, src0_ne2,
69 src1_ne0, dst_ne0,
70 src0, src1, dst, item_ct1);
71 });
72}
73
74void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
75 scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
76 const ggml_tensor *src0 = dst->src[0];
77 const ggml_tensor *src1 = dst->src[1];
78 const float * src0_d = (const float *)src0->data;
79 const float * src1_d = (const float *)src1->data;
80
81 float * dst_d = (float *)dst->data;
82 dpct::queue_ptr stream = ctx.stream();
83
84 GGML_ASSERT(src0->type == GGML_TYPE_F32);
85 GGML_ASSERT( dst->type == GGML_TYPE_F32);
86
87 GGML_ASSERT(ggml_is_contiguous(src0));
88 GGML_ASSERT(ggml_is_contiguous(src1));
89
90 const int32_t * opts = (const int32_t *)dst->op_params;
91
92 const int s0 = opts[0];
93
94 const int64_t output_size = ggml_nelements(dst);
95
96 conv_transpose_1d_f32_f32_sycl(s0, output_size,
97 src0->ne[0], src0->ne[1], src0->ne[2],
98 src1->ne[0], dst->ne[0],
99 src0_d, src1_d, dst_d, stream);
100}
101