1//
  2// MIT license
  3// Copyright (C) 2025 Codeplay Software Ltd.
  4// Copyright (C) 2025 Intel Corporation
  5// SPDX-License-Identifier: MIT
  6//
  7
  8//
  9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 10// See https://llvm.org/LICENSE.txt for license information.
 11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 12//
 13
 14#ifndef GGML_SYCL_QUANTS_HPP
 15#define GGML_SYCL_QUANTS_HPP
 16
 17#include <utility>
 18
 19#include "ggml-common.h"
 20#include "ggml.h"
 21
 22namespace ggml_sycl_reordered {
 23
// The reordered layout moves the quants (qs) and the scales (d) into two
// separate, uniform regions of memory that are contiguous within the same tensor.
// That is, instead of the interleaved layout:
// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN]
// we have:
// [qs0, qs1, qs2, ..., qsN]  [d0, d1, d2, ..., dN]
//
// Notes: out-of-bounds reads of qs will run into the d values.
// Alignment of the d region depends on the allocated size of the qs region,
// since d starts immediately after qs.
 33
 34template <ggml_type type> struct block_q_t;
 35
// qk: number of weights / quants in a block.
// qr: number of weights stored in one byte (described as 'before dequantization');
//     for quantization types whose low and high bits are split, qr is calculated
//     using only the lower bits, e.g. for Q6 quants QR6 is 2.
// qi: number of 32-bit integers needed to represent all the quants of a block (the `qs` field).
// See ggml-common.h for how these are calculated.
 42template <> struct block_q_t<GGML_TYPE_Q4_0> {
 43    struct traits {
 44        static constexpr uint32_t qk       = QK4_0;
 45        static constexpr uint32_t qi       = QI4_0;
 46        static constexpr uint32_t qr       = QR4_0;
 47        static constexpr uint32_t vdr_mmvq = 2;
 48    };
 49
 50    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
 51        return { block_index * (QK4_0 / QR4_0), 0 };
 52    }
 53
 54    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
 55        return { (ncols / QR4_0 * nrows) + block_index * sizeof(ggml_half), 0 };
 56    }
 57
 58    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
 59};
 60
 61template <> struct block_q_t<GGML_TYPE_Q4_K> {
 62    struct traits {
 63        static constexpr uint32_t qk       = QK_K;
 64        static constexpr uint32_t qi       = QI4_K;
 65        static constexpr uint32_t qr       = QR4_K;
 66        static constexpr uint32_t vdr_mmvq = 2;
 67    };
 68
 69    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
 70        return { block_index * (traits::qk / traits::qr), 0 };
 71    }
 72
 73    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
 74        auto nblocks = (nrows * (ncols / QK_K));
 75        return { nblocks * (QK_K / 2) + (block_index * K_SCALE_SIZE),
 76                 (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
 77    }
 78
 79    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
 80};
 81
 82template <> struct block_q_t<GGML_TYPE_Q6_K> {
 83    struct traits {
 84        static constexpr uint32_t qk       = QK_K;
 85        static constexpr uint32_t qi       = QI6_K;
 86        static constexpr uint32_t qr       = QR6_K;
 87        static constexpr uint32_t vdr_mmvq = 1;
 88    };
 89
 90    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
 91        auto low_bits_index  = block_index * (QK_K / QR6_K);
 92        // the index of high bits it's after all low bits
 93        auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
 94        return { low_bits_index, high_bits_index };
 95    }
 96
 97    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
 98        auto nblocks        = (nrows * (ncols / QK_K));
 99        auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
100        auto block_scales   = total_qs_bytes + block_index * (QK_K / 16);
101        auto sb_scale       = total_qs_bytes + nblocks * (QK_K / 16) + block_index * sizeof(ggml_half);
102        return { block_scales, sb_scale };
103    }
104
105    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
106};
107
108}  // namespace ggml_sycl_reordered
109
110#endif  // GGML_SYCL_QUANTS_HPP