1#include <string.h>
  2#include <stdlib.h>
  3#include <math.h>
  4#include <HAP_farf.h>
  5#include <HAP_perf.h>
  6
  7#define GGML_COMMON_DECL_C
  8#include "ggml-common.h"
  9#include "ggml.h"
 10
 11#include "hvx-utils.h"
 12#include "hex-dma.h"
 13
 14#include "htp-ctx.h"
 15#include "htp-msg.h"
 16#include "htp-ops.h"
 17
 18#ifndef MIN
 19#define MIN(a, b) ((a) < (b) ? (a) : (b))
 20#endif
 21
// Per-dispatch state shared by all argsort worker jobs.
struct htp_argsort_context {
    struct htp_ops_context * octx;              // op context: tensors, VTCM scratchpad, op params
    uint32_t                 nrows_per_thread;  // rows assigned per job (last job may process fewer)
};
 26
// Returns true iff every f32 lane of x is strictly greater than the
// corresponding lane of y (one 128-byte HVX vector = 32 f32 lanes).
// NOTE(review): lanes comparing against NaN yield false, so a block
// containing NaN is never skipped and falls through to the scalar path.
static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
{
    const HVX_Vector one  = Q6_V_vsplat_R(1);
    const HVX_Vector zero = Q6_V_vzero();

    // Per-lane predicate: x[lane] > y[lane] (single-precision float compare).
    HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
    // Materialize the predicate as 0/1 per i32 lane, then sum the flags.
    HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
    HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
    // All 32 lanes matched <=> the 0/1 flags sum to exactly 32.
    return hvx_vec_get_i32(sum) == 32;
}
 37
// Hoare-partition quicksort, ascending. Sorts values[left..right] in place and
// mirrors every swap into indices[], so after the top-level call indices[]
// holds the argsort permutation of the original row.
// HVX fast path: when a full vector (32 f32 lanes) lies strictly inside the
// unscanned region and is entirely on the correct side of the pivot, the whole
// block is skipped at once; the scalar code below each vector check provides
// the exact Hoare stopping condition (scans stop at elements == pivot, which
// guarantees termination since the pivot itself is in the range).
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
    if (left >= right) return;

    int pivot_idx = (left + right) / 2;
    float pivot = values[pivot_idx];
    int i = left;
    int j = right;

    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
    while (i <= j) {
        // Advance i while values[i] < pivot.
        while (i <= j) {
            // Vector fast path only when lanes i..i+31 are all <= j-1,
            // keeping the load inside the active [i, j] window.
            if (i + 32 <= j) {
                // Unaligned load of values[i..i+31].
                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
                if (all_greater_f32(pivot_vec, vals_vec)) {
                    // All 32 elements are < pivot: skip the whole block.
                    i += 32;
                    continue;
                }
            }

            // Scalar fallback / cleanup; stops at values[i] >= pivot.
            if (values[i] < pivot) {
                i++;
            } else {
                break;
            }
        }

        // Retreat j while values[j] > pivot.
        while (i <= j) {
            if (j - 32 >= i) {
                // Load the 32 elements ending at j, i.e. values[j-31..j],
                // since the scan direction is downward from j.
                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
                if (all_greater_f32(vals_vec, pivot_vec)) {
                    // All 32 elements are > pivot: skip the whole block.
                    j -= 32;
                    continue;
                }
            }

            // Scalar fallback; stops at values[j] <= pivot.
            if (values[j] > pivot) {
                j--;
            } else {
                break;
            }
        }

        // Both scans stopped on misplaced elements: swap value and index pairs.
        if (i <= j) {
            float tmp_val = values[i];
            values[i] = values[j];
            values[j] = tmp_val;

            int32_t tmp_idx = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp_idx;
            i++;
            j--;
        }
    }

    // Recurse on the two partitions [left, j] and [i, right].
    if (left < j) quicksort_values_indices_asc(values, indices, left, j);
    if (i < right) quicksort_values_indices_asc(values, indices, i, right);
}
104
// Descending-order counterpart of quicksort_values_indices_asc: same Hoare
// partition with HVX block-skip fast paths, but with the comparison directions
// mirrored (i advances over values > pivot, j retreats over values < pivot).
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
    if (left >= right) return;

    int pivot_idx = (left + right) / 2;
    float pivot = values[pivot_idx];
    int i = left;
    int j = right;

    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);

    while (i <= j) {
        // Advance i while values[i] > pivot.
        while (i <= j) {
            // Vector fast path: 32 contiguous lanes fully inside [i, j].
            if (i + 32 <= j) {
                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
                if (all_greater_f32(vals_vec, pivot_vec)) {
                    // All 32 elements are > pivot: skip the whole block.
                    i += 32;
                    continue;
                }
            }

            // Scalar fallback; stops at values[i] <= pivot.
            if (values[i] > pivot) {
                i++;
            } else {
                break;
            }
        }

        // Retreat j while values[j] < pivot.
        while (i <= j) {
            if (j - 32 >= i) {
                // Load values[j-31..j] — the block ending at j.
                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
                if (all_greater_f32(pivot_vec, vals_vec)) {
                    // All 32 elements are < pivot: skip the whole block.
                    j -= 32;
                    continue;
                }
            }

            // Scalar fallback; stops at values[j] >= pivot.
            if (values[j] < pivot) {
                j--;
            } else {
                break;
            }
        }

        // Swap misplaced value/index pairs and tighten the window.
        if (i <= j) {
            float tmp_val = values[i];
            values[i] = values[j];
            values[j] = tmp_val;

            int32_t tmp_idx = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp_idx;
            i++;
            j--;
        }
    }

    // Recurse on [left, j] and [i, right].
    if (left < j) quicksort_values_indices_desc(values, indices, left, j);
    if (i < right) quicksort_values_indices_desc(values, indices, i, right);
}
166
167static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
168    struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
169    struct htp_ops_context * octx = actx->octx;
170
171    // Unpack context
172    const struct htp_tensor * src0 = &octx->src0;
173    const struct htp_tensor * dst = &octx->dst;
174
175    // Scratchpad memory
176    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
177
178    // Dimensions
179    uint32_t ne00 = src0->ne[0];
180    uint32_t ne01 = src0->ne[1];
181    uint32_t ne02 = src0->ne[2];
182    uint32_t ne03 = src0->ne[3];
183
184    uint32_t nb01 = src0->nb[1];
185    //uint32_t nb02 = src0->nb[2];
186    //uint32_t nb03 = src0->nb[3];
187
188    uint32_t nb1 = dst->nb[1];
189    //uint32_t nb2 = dst->nb[2];
190    //uint32_t nb3 = dst->nb[3];
191
192    // Sort order
193    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
194
195    // Rows to process
196    uint32_t total_rows = ne01 * ne02 * ne03;
197    uint32_t rows_per_thread = actx->nrows_per_thread;
198    uint32_t start_row = rows_per_thread * i;
199    uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
200
201    // Scratchpad layout:
202    // We need space for one row of float data (values) and one row of int32 indices.
203    // values: ne00 * sizeof(float)
204    // indices: ne00 * sizeof(int32_t)
205    // Padded to 128 bytes.
206
207    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
208    float * values_buf = (float *) spad;
209    int32_t * indices_buf = (int32_t *) (spad + values_size);
210
211    for (uint32_t r = start_row; r < end_row; r++) {
212        uint32_t src_offset = r * nb01;
213        uint32_t dst_offset = r * nb1;
214
215        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
216        uint8_t * dst_ptr = (uint8_t *) dst->data  + dst_offset;
217
218        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
219        hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
220
221        // Initialize indices
222        for (uint32_t j = 0; j < ne00; j++) {
223            indices_buf[j] = j;
224        }
225
226        // Sort values and mirror swaps to indices
227        if (order == GGML_SORT_ORDER_ASC) {
228            quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
229        } else {
230            quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
231        }
232
233        // Copy indices back to DDR
234        hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
235    }
236}
237
238int op_argsort(struct htp_ops_context * octx) {
239    // Check supported types
240    if (octx->src0.type != HTP_TYPE_F32) {
241        return HTP_STATUS_NO_SUPPORT;
242    }
243
244    // Allocate scratchpad
245    // We need 1 row of float + 1 row of int32 per thread.
246    uint32_t ne00 = octx->src0.ne[0];
247    size_t values_size  = hex_round_up(ne00 * sizeof(float), 128);
248    size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
249    size_t spad_per_thread = values_size + indices_size;
250
251    // Make sure we round up to 256 for alignment requirements
252    spad_per_thread = hex_round_up(spad_per_thread, 256);
253
254    size_t total_spad_size = spad_per_thread * octx->n_threads;
255
256    if (octx->ctx->vtcm_size < total_spad_size) {
257        FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
258        return HTP_STATUS_VTCM_TOO_SMALL;
259    }
260
261    octx->src0_spad.data = octx->ctx->vtcm_base;
262    octx->src0_spad.size = total_spad_size;
263    octx->src0_spad.size_per_thread = spad_per_thread;
264
265    FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
266         octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
267         octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
268         octx->src0.data, octx->dst.data);
269
270    uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
271    uint32_t n_jobs = MIN(total_rows, octx->n_threads);
272
273    struct htp_argsort_context actx;
274    actx.octx = octx;
275    actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
276
277    // Run jobs
278    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
279
280    return HTP_STATUS_OK;
281}