1/***************************************************************************
  2 *
  3 *  Copyright (C) 2025 Codeplay Software Ltd.
  4 *  Copyright (C) 2025 Intel Corporation
  5 *
  6 *  MIT License
  7 *
  8 *  Unless required by applicable law or agreed to in writing, software
  9 *  distributed under the License is distributed on an "AS IS" BASIS,
 10 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 11 *  See the License for the specific language governing permissions and
 12 *  limitations under the License.
 13 *
 14 *  quantize.hpp
 15 *
 16 *  Description:
 17 *     Sycl backend specific quantization functions
 18 **************************************************************************/
 19
 20#pragma once
 21
 22#include <sycl/nd_item.hpp>
 23
 24#include "ggml-sycl/dpct/helper.hpp"
 25
 26template <int ElementsPerWI>
 27__dpct_inline__ static void quantize_q8_1_impl(const float * __restrict__ x,
 28                                               sycl::vec<int8_t, ElementsPerWI> & quantized_values, float & d,
 29                                               float & sum, const sycl::nd_item<1> & it) {
 30    auto subgroup_id = it.get_group(0);
 31    auto wi_id       = it.get_local_id(0);
 32
 33    sycl::vec<float, ElementsPerWI> wi_f32_vals;
 34
 35    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
 36    wi_f32_vals           = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
 37
 38    float amax = 0.0f;
 39
 40#pragma unroll(ElementsPerWI)
 41    for (int i = 0; i < ElementsPerWI; i++) {
 42        sum += wi_f32_vals[i];
 43        amax                = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
 44        quantized_values[i] = 0;
 45    }
 46    sum  = sycl::reduce_over_group(it.get_sub_group(), sum, sycl::plus<float>());
 47    amax = sycl::reduce_over_group(it.get_sub_group(), amax, sycl::maximum<float>());
 48    d    = amax == 0 ? 1 : amax / 127;
 49
 50#pragma unroll(ElementsPerWI)
 51    for (int i = 0; i < ElementsPerWI; i++) {
 52        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
 53    }
 54
 55    d = amax == 0 ? 0 : d;
 56}
 57
 58// No op to control codepath in ggml_sycl_op_mul_mat
 59template <int ElementsPerWI> struct no_quantize_q8_1 {
 60    void operator()(const float *, void *, int, int, const sycl::nd_item<1> &) const {}
 61};
 62
 63template <int ElementsPerWI> struct quantize_and_reorder_q8_1_soa {
 64    __dpct_inline__ void operator()(const float * __restrict__ x, void * reordered_q8_tensor, const int kx,
 65                                    const int kx_padded, const sycl::nd_item<1> & it) const {
 66        /*
 67        Quantizes and reorders the resultant q8 tensor in a per row fashion
 68        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
 69    */
 70        auto subgroup_id = it.get_group(0);
 71        auto wi_id       = it.get_local_id(0);
 72
 73        sycl::vec<int8_t, ElementsPerWI> quantized_values;
 74        float                            d   = 0.0f;
 75        float                            sum = 0.0f;
 76        quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);
 77
 78        const int num_blocks_per_row = kx / QK8_1;
 79        auto      row                = subgroup_id / num_blocks_per_row;
 80        auto      col                = subgroup_id % num_blocks_per_row;
 81        auto      row_offset         = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
 82        auto      col_offset         = QK8_1 * col + wi_id * ElementsPerWI;
 83
 84        auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
 85        *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
 86
 87        auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
 88        if (wi_id == 0) {
 89            *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
 90        }
 91    }
 92};
 93
 94template <int ElementsPerWI> struct quantize_q8_1 {
 95    __dpct_inline__ void operator()(const float * __restrict__ x, void * q8_tensor, const int kx, const int kx_padded,
 96                                    const sycl::nd_item<1> & it) const {
 97        auto subgroup_id = it.get_group(0);
 98        auto wi_id       = it.get_local_id(0);
 99
100        const int num_blocks_per_row = kx / QK8_1;
101        auto      row                = subgroup_id / num_blocks_per_row;
102        const int pitch              = kx_padded / QK8_1;
103
104        sycl::vec<int8_t, ElementsPerWI> quantized_values;
105        float                            d   = 0.0f;
106        float                            sum = 0.0f;
107        quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);
108
109        block_q8_1 * quant_ptr = (block_q8_1 *) q8_tensor;
110        auto         block_id  = subgroup_id % num_blocks_per_row + row * pitch;
111
112        int8_t * qs                                               = &(quant_ptr[block_id].qs[wi_id * ElementsPerWI]);
113        *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(qs) = quantized_values;
114        if (wi_id == 0) {
115            quant_ptr[block_id].ds = sycl::half2(sycl::half(d), sycl::half(sum));
116        }
117    }
118};
119
120template <template <int> typename quantize_f>
121void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
122                            dpct::queue_ptr stream) {
123    static_assert(QK8_1 % WARP_SIZE == 0);
124    auto local_range      = std::size_t(WARP_SIZE);
125    auto num_quant_blocks = ky * (kx / QK8_1);
126    auto global_range     = num_quant_blocks * local_range;
127    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
128
129    stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
130                         [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
131                             quantize_f<QK8_1 / WARP_SIZE>()(x, vy, kx, kx_padded, it);
132                         });
133}