#include <algorithm>
#include <array>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <random>
#include <utility>
#include <vector>

#include <ggml.h>
#include <ggml-cpu.h>
 13
 14#if defined(_MSC_VER)
 15#pragma warning(disable: 4244 4267) // possible loss of data
 16#endif
 17
 18constexpr int kVecSize = 1 << 18;
 19
 20static float drawFromGaussianPdf(std::mt19937& rndm) {
 21    constexpr double kScale = 1./(1. + std::mt19937::max());
 22    constexpr double kTwoPiTimesScale = 6.28318530717958647692*kScale;
 23    static float lastX;
 24    static bool haveX = false;
 25    if (haveX) { haveX = false; return lastX; }
 26    auto r = sqrt(-2*log(1 - kScale*rndm()));
 27    auto phi = kTwoPiTimesScale * rndm();
 28    lastX = r*sin(phi);
 29    haveX = true;
 30    return r*cos(phi);
 31}
 32
 33static void fillRandomGaussianFloats(std::vector<float>& values, std::mt19937& rndm, float mean = 0) {
 34    for (auto& v : values) v = mean + drawFromGaussianPdf(rndm);
 35}
 36
 37// Copy-pasted from ggml.c
 38#define QK4_0 32
 39typedef struct {
 40    float   d;          // delta
 41    uint8_t qs[QK4_0 / 2];  // nibbles / quants
 42} block_q4_0;
 43static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
 44
 45#define QK4_1 32
 46typedef struct {
 47    float   d;          // delta
 48    float   m;          // min
 49    uint8_t qs[QK4_1 / 2];  // nibbles / quants
 50} block_q4_1;
 51static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
 52
 53// Copy-pasted from ggml.c
 54#define QK8_0 32
 55typedef struct {
 56    float   d;          // delta
 57    int8_t  qs[QK8_0];  // quants
 58} block_q8_0;
 59static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 60
 61// "Scalar" dot product between the quantized vector x and float vector y
 62inline double dot(int n, const block_q4_0* x, const float* y) {
 63    const static float kValues[16] = {-8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
 64    constexpr uint32_t kMask1 = 0x0f0f0f0f;
 65    uint32_t u1, u2;
 66    auto q1 = (const uint8_t*)&u1;
 67    auto q2 = (const uint8_t*)&u2;
 68    double sum = 0;
 69    for (int i=0; i<n; ++i) {
 70        float d = x->d;
 71        auto u = (const uint32_t*)x->qs;
 72        float s = 0;
 73        for (int k=0; k<4; ++k) {
 74            u1 = u[k] & kMask1;
 75            u2 = (u[k] >> 4) & kMask1;
 76            s += y[0]*kValues[q1[0]] + y[1]*kValues[q2[0]] +
 77                 y[2]*kValues[q1[1]] + y[3]*kValues[q2[1]] +
 78                 y[4]*kValues[q1[2]] + y[5]*kValues[q2[2]] +
 79                 y[6]*kValues[q1[3]] + y[7]*kValues[q2[3]];
 80            y += 8;
 81        }
 82        sum += s*d;
 83        ++x;
 84    }
 85    return sum;
 86}
 87// Alternative version of the above. Faster on my Mac (~45 us vs ~55 us per dot product),
 88// but about the same on X86_64 (Ryzen 7950X CPU).
 89inline double dot3(int n, const block_q4_0* x, const float* y) {
 90    const static std::pair<float,float> kValues[256] = {
 91        {-8.f, -8.f}, {-7.f, -8.f}, {-6.f, -8.f}, {-5.f, -8.f}, {-4.f, -8.f}, {-3.f, -8.f}, {-2.f, -8.f}, {-1.f, -8.f},
 92        { 0.f, -8.f}, { 1.f, -8.f}, { 2.f, -8.f}, { 3.f, -8.f}, { 4.f, -8.f}, { 5.f, -8.f}, { 6.f, -8.f}, { 7.f, -8.f},
 93        {-8.f, -7.f}, {-7.f, -7.f}, {-6.f, -7.f}, {-5.f, -7.f}, {-4.f, -7.f}, {-3.f, -7.f}, {-2.f, -7.f}, {-1.f, -7.f},
 94        { 0.f, -7.f}, { 1.f, -7.f}, { 2.f, -7.f}, { 3.f, -7.f}, { 4.f, -7.f}, { 5.f, -7.f}, { 6.f, -7.f}, { 7.f, -7.f},
 95        {-8.f, -6.f}, {-7.f, -6.f}, {-6.f, -6.f}, {-5.f, -6.f}, {-4.f, -6.f}, {-3.f, -6.f}, {-2.f, -6.f}, {-1.f, -6.f},
 96        { 0.f, -6.f}, { 1.f, -6.f}, { 2.f, -6.f}, { 3.f, -6.f}, { 4.f, -6.f}, { 5.f, -6.f}, { 6.f, -6.f}, { 7.f, -6.f},
 97        {-8.f, -5.f}, {-7.f, -5.f}, {-6.f, -5.f}, {-5.f, -5.f}, {-4.f, -5.f}, {-3.f, -5.f}, {-2.f, -5.f}, {-1.f, -5.f},
 98        { 0.f, -5.f}, { 1.f, -5.f}, { 2.f, -5.f}, { 3.f, -5.f}, { 4.f, -5.f}, { 5.f, -5.f}, { 6.f, -5.f}, { 7.f, -5.f},
 99        {-8.f, -4.f}, {-7.f, -4.f}, {-6.f, -4.f}, {-5.f, -4.f}, {-4.f, -4.f}, {-3.f, -4.f}, {-2.f, -4.f}, {-1.f, -4.f},
100        { 0.f, -4.f}, { 1.f, -4.f}, { 2.f, -4.f}, { 3.f, -4.f}, { 4.f, -4.f}, { 5.f, -4.f}, { 6.f, -4.f}, { 7.f, -4.f},
101        {-8.f, -3.f}, {-7.f, -3.f}, {-6.f, -3.f}, {-5.f, -3.f}, {-4.f, -3.f}, {-3.f, -3.f}, {-2.f, -3.f}, {-1.f, -3.f},
102        { 0.f, -3.f}, { 1.f, -3.f}, { 2.f, -3.f}, { 3.f, -3.f}, { 4.f, -3.f}, { 5.f, -3.f}, { 6.f, -3.f}, { 7.f, -3.f},
103        {-8.f, -2.f}, {-7.f, -2.f}, {-6.f, -2.f}, {-5.f, -2.f}, {-4.f, -2.f}, {-3.f, -2.f}, {-2.f, -2.f}, {-1.f, -2.f},
104        { 0.f, -2.f}, { 1.f, -2.f}, { 2.f, -2.f}, { 3.f, -2.f}, { 4.f, -2.f}, { 5.f, -2.f}, { 6.f, -2.f}, { 7.f, -2.f},
105        {-8.f, -1.f}, {-7.f, -1.f}, {-6.f, -1.f}, {-5.f, -1.f}, {-4.f, -1.f}, {-3.f, -1.f}, {-2.f, -1.f}, {-1.f, -1.f},
106        { 0.f, -1.f}, { 1.f, -1.f}, { 2.f, -1.f}, { 3.f, -1.f}, { 4.f, -1.f}, { 5.f, -1.f}, { 6.f, -1.f}, { 7.f, -1.f},
107        {-8.f,  0.f}, {-7.f,  0.f}, {-6.f,  0.f}, {-5.f,  0.f}, {-4.f,  0.f}, {-3.f,  0.f}, {-2.f,  0.f}, {-1.f,  0.f},
108        { 0.f,  0.f}, { 1.f,  0.f}, { 2.f,  0.f}, { 3.f,  0.f}, { 4.f,  0.f}, { 5.f,  0.f}, { 6.f,  0.f}, { 7.f,  0.f},
109        {-8.f,  1.f}, {-7.f,  1.f}, {-6.f,  1.f}, {-5.f,  1.f}, {-4.f,  1.f}, {-3.f,  1.f}, {-2.f,  1.f}, {-1.f,  1.f},
110        { 0.f,  1.f}, { 1.f,  1.f}, { 2.f,  1.f}, { 3.f,  1.f}, { 4.f,  1.f}, { 5.f,  1.f}, { 6.f,  1.f}, { 7.f,  1.f},
111        {-8.f,  2.f}, {-7.f,  2.f}, {-6.f,  2.f}, {-5.f,  2.f}, {-4.f,  2.f}, {-3.f,  2.f}, {-2.f,  2.f}, {-1.f,  2.f},
112        { 0.f,  2.f}, { 1.f,  2.f}, { 2.f,  2.f}, { 3.f,  2.f}, { 4.f,  2.f}, { 5.f,  2.f}, { 6.f,  2.f}, { 7.f,  2.f},
113        {-8.f,  3.f}, {-7.f,  3.f}, {-6.f,  3.f}, {-5.f,  3.f}, {-4.f,  3.f}, {-3.f,  3.f}, {-2.f,  3.f}, {-1.f,  3.f},
114        { 0.f,  3.f}, { 1.f,  3.f}, { 2.f,  3.f}, { 3.f,  3.f}, { 4.f,  3.f}, { 5.f,  3.f}, { 6.f,  3.f}, { 7.f,  3.f},
115        {-8.f,  4.f}, {-7.f,  4.f}, {-6.f,  4.f}, {-5.f,  4.f}, {-4.f,  4.f}, {-3.f,  4.f}, {-2.f,  4.f}, {-1.f,  4.f},
116        { 0.f,  4.f}, { 1.f,  4.f}, { 2.f,  4.f}, { 3.f,  4.f}, { 4.f,  4.f}, { 5.f,  4.f}, { 6.f,  4.f}, { 7.f,  4.f},
117        {-8.f,  5.f}, {-7.f,  5.f}, {-6.f,  5.f}, {-5.f,  5.f}, {-4.f,  5.f}, {-3.f,  5.f}, {-2.f,  5.f}, {-1.f,  5.f},
118        { 0.f,  5.f}, { 1.f,  5.f}, { 2.f,  5.f}, { 3.f,  5.f}, { 4.f,  5.f}, { 5.f,  5.f}, { 6.f,  5.f}, { 7.f,  5.f},
119        {-8.f,  6.f}, {-7.f,  6.f}, {-6.f,  6.f}, {-5.f,  6.f}, {-4.f,  6.f}, {-3.f,  6.f}, {-2.f,  6.f}, {-1.f,  6.f},
120        { 0.f,  6.f}, { 1.f,  6.f}, { 2.f,  6.f}, { 3.f,  6.f}, { 4.f,  6.f}, { 5.f,  6.f}, { 6.f,  6.f}, { 7.f,  6.f},
121        {-8.f,  7.f}, {-7.f,  7.f}, {-6.f,  7.f}, {-5.f,  7.f}, {-4.f,  7.f}, {-3.f,  7.f}, {-2.f,  7.f}, {-1.f,  7.f},
122        { 0.f,  7.f}, { 1.f,  7.f}, { 2.f,  7.f}, { 3.f,  7.f}, { 4.f,  7.f}, { 5.f,  7.f}, { 6.f,  7.f}, { 7.f,  7.f}
123    };
124    double sum = 0;
125    for (int i=0; i<n; ++i) {
126        float d = x->d;
127        auto q = x->qs;
128        float s = 0;
129        for (int k=0; k<4; ++k) {
130            s += y[0]*kValues[q[0]].first + y[1]*kValues[q[0]].second +
131                 y[2]*kValues[q[1]].first + y[3]*kValues[q[1]].second +
132                 y[4]*kValues[q[2]].first + y[5]*kValues[q[2]].second +
133                 y[6]*kValues[q[3]].first + y[7]*kValues[q[3]].second;
134            y += 8; q += 4;
135        }
136        sum += s*d;
137        ++x;
138    }
139    return sum;
140}
141
142inline double dot41(int n, const block_q4_1* x, const float* y) {
143    const static float kValues[16] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f};
144    constexpr uint32_t kMask1 = 0x0f0f0f0f;
145    uint32_t u1, u2;
146    auto q1 = (const uint8_t*)&u1;
147    auto q2 = (const uint8_t*)&u2;
148    double sum = 0;
149    for (int i=0; i<n; ++i) {
150        auto u = (const uint32_t*)x->qs;
151        float s = 0, s1 = 0;
152        for (int k=0; k<4; ++k) {
153            u1 = u[k] & kMask1;
154            u2 = (u[k] >> 4) & kMask1;
155            s += y[0]*kValues[q1[0]] + y[1]*kValues[q2[0]] +
156                 y[2]*kValues[q1[1]] + y[3]*kValues[q2[1]] +
157                 y[4]*kValues[q1[2]] + y[5]*kValues[q2[2]] +
158                 y[6]*kValues[q1[3]] + y[7]*kValues[q2[3]];
159            s1 += y[0] + y[1] + y[2] + y[3] + y[4] + y[5] + y[6] + y[7];
160            y += 8;
161        }
162        sum += s*x->d + s1*x->m;
163        ++x;
164    }
165    return sum;
166}
167
168// Copy-pasted from ggml.c
169static void quantize_row_q8_0_reference(const float *x, block_q8_0 *y, int k) {
170    assert(k % QK8_0 == 0);
171    const int nb = k / QK8_0;
172
173    for (int i = 0; i < nb; i++) {
174        float amax = 0.0f; // absolute max
175
176        for (int l = 0; l < QK8_0; l++) {
177            const float v = x[i*QK8_0 + l];
178            amax = std::max(amax, fabsf(v));
179        }
180
181        const float d = amax / ((1 << 7) - 1);
182        const float id = d ? 1.0f/d : 0.0f;
183
184        y[i].d = d;
185
186        for (int l = 0; l < QK8_0; ++l) {
187            const float   v  = x[i*QK8_0 + l]*id;
188            y[i].qs[l] = roundf(v);
189        }
190    }
191}
192
193// Copy-pasted from ggml.c
194static void dot_q4_q8(const int n, float* s, const void* vx, const void* vy) {
195    const int nb = n / QK8_0;
196    const block_q4_0* x = (const block_q4_0*)vx;
197    const block_q8_0* y = (const block_q8_0*)vy;
198    float sumf = 0;
199    for (int i = 0; i < nb; i++) {
200        const float d0 = x[i].d;
201        const float d1 = y[i].d;
202
203        const uint8_t * p0 = x[i].qs;
204        const  int8_t * p1 = y[i].qs;
205
206        int sumi = 0;
207        for (int j = 0; j < QK8_0/2; j++) {
208            const uint8_t v0 = p0[j];
209
210            const int i0 = (int8_t) (v0 & 0xf) - 8;
211            const int i1 = (int8_t) (v0 >> 4)  - 8;
212
213            const int i2 = p1[2*j + 0];
214            const int i3 = p1[2*j + 1];
215
216            sumi += i0*i2 + i1*i3;
217        }
218        sumf += d0*d1*sumi;
219    }
220    *s = sumf;
221}
222
223int main(int argc, char** argv) {
224
225    int nloop = argc > 1 ? atoi(argv[1]) : 10;
226    bool scalar = argc > 2 ? atoi(argv[2]) : false;
227    bool useQ4_1 = argc > 3 ? atoi(argv[3]) : false;
228
229    if (scalar && useQ4_1) {
230        printf("It is not possible to use Q4_1 quantization and scalar implementations\n");
231        return 1;
232    }
233
234    std::mt19937 rndm(1234);
235
236    std::vector<float> x1(kVecSize), y1(kVecSize);
237    int n4 = useQ4_1 ? kVecSize / QK4_1 : kVecSize / QK4_0; n4 = 64*((n4 + 63)/64);
238    int n8 = kVecSize / QK8_0; n8 = 64*((n8 + 63)/64);
239
240    const auto * funcs_cpu = ggml_get_type_traits_cpu(useQ4_1 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q4_0);
241
242    std::vector<block_q4_0> q40;
243    std::vector<block_q4_1> q41;
244    if (useQ4_1) q41.resize(n4);
245    else q40.resize(n4);
246    std::vector<block_q8_0> q8(n8);
247    double sumt = 0, sumt2 = 0, maxt = 0;
248    double sumqt = 0, sumqt2 = 0, maxqt = 0;
249    double sum = 0, sumq = 0, exactSum = 0;
250    for (int iloop=0; iloop<nloop; ++iloop) {
251
252        // Fill vector x with random numbers
253        fillRandomGaussianFloats(x1, rndm);
254
255        // Fill vector y with random numbers
256        fillRandomGaussianFloats(y1, rndm);
257
258        // Compute the exact dot product
259        for (int k=0; k<kVecSize; ++k) exactSum += x1[k]*y1[k];
260
261        // quantize x.
262        // Note, we do not include this in the timing as in practical application
263        // we already have the quantized model weights.
264        if (useQ4_1) {
265            funcs_cpu->from_float(x1.data(), q41.data(), kVecSize);
266        } else {
267            funcs_cpu->from_float(x1.data(), q40.data(), kVecSize);
268        }
269
270        // Now measure time the dot product needs using the "scalar" version above
271        auto t1 = std::chrono::high_resolution_clock::now();
272        if (useQ4_1) sum += dot41(kVecSize / QK4_1, q41.data(), y1.data());
273        else sum += dot(kVecSize / QK4_0, q40.data(), y1.data());
274        auto t2 = std::chrono::high_resolution_clock::now();
275        auto t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
276        sumt += t; sumt2 += t*t; maxt = std::max(maxt, t);
277
278        // And now measure the time needed to quantize y and perform the dot product with the quantized y
279        t1 = std::chrono::high_resolution_clock::now();
280        float result;
281        if (scalar) {
282            quantize_row_q8_0_reference(y1.data(), q8.data(), kVecSize);
283            dot_q4_q8(kVecSize, &result, q40.data(), q8.data());
284        }
285        else {
286            const auto * vdot = ggml_get_type_traits_cpu(funcs_cpu->vec_dot_type);
287            vdot->from_float(y1.data(), q8.data(), kVecSize);
288            if (useQ4_1) funcs_cpu->vec_dot(kVecSize, &result, 0, q41.data(), 0, q8.data(), 0, 1);
289            else funcs_cpu->vec_dot(kVecSize, &result, 0, q40.data(), 0, q8.data(), 0, 1);
290        }
291        sumq += result;
292        t2 = std::chrono::high_resolution_clock::now();
293        t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
294        sumqt += t; sumqt2 += t*t; maxqt = std::max(maxqt, t);
295
296    }
297
298    // Report the time (and the average of the dot products so the compiler does not come up with the idea
299    // of optimizing away the function calls after figuring that the result is not used).
300    sum /= nloop; sumq /= nloop;
301    exactSum /= nloop;
302    printf("Exact result: <dot> = %g\n",exactSum);
303    printf("<dot> = %g, %g\n",sum,sumq);
304    sumt /= nloop; sumt2 /= nloop; sumt2 -= sumt*sumt;
305    if (sumt2 > 0) sumt2 = sqrt(sumt2);
306    printf("time = %g +/- %g us. maxt = %g us\n",sumt,sumt2,maxt);
307    sumqt /= nloop; sumqt2 /= nloop; sumqt2 -= sumqt*sumqt;
308    if (sumqt2 > 0) sumqt2 = sqrt(sumqt2);
309    printf("timeq = %g +/- %g us. maxt = %g us\n",sumqt,sumqt2,maxqt);
310    return 0;
311}