#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <HAP_farf.h>
#include <HAP_perf.h>

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "ggml.h"

#include "hvx-utils.h"
#include "hex-dma.h"

#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

struct htp_argsort_context {
    struct htp_ops_context * octx;
    uint32_t                 nrows_per_thread;
};

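// Returns true iff every f32 lane of x is strictly greater than the
// corresponding lane of y. Used below to decide whether a whole 32-element
// block lies entirely on one side of the pivot.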
static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y) {
    const HVX_Vector one  = Q6_V_vsplat_R(1);
    const HVX_Vector zero = Q6_V_vzero();

    HVX_VectorPred pred    = Q6_Q_vcmp_gt_VsfVsf(x, y);
    HVX_Vector     matches = Q6_V_vmux_QVV(pred, one, zero);
    HVX_Vector     sum     = hvx_vec_reduce_sum_i32(matches);
    return hvx_vec_get_i32(sum) == 32; // a 128-byte HVX vector holds 32 f32 lanes
}
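
// For reference, a scalar sketch of the same predicate (illustrative only,
// not compiled in):
//
//   static inline bool all_greater_f32_ref(const float * x, const float * y) {
//       for (int k = 0; k < 32; k++) {
//           if (!(x[k] > y[k])) return false;
//       }
//       return true;
//   }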

// Quicksort that sorts values in ascending order and mirrors every swap into indices.
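// The partition is Hoare-style: i scans right for an element >= pivot, j scans
// left for an element <= pivot, then the pair is swapped. The HVX fast path
// tests 32 elements at a time and skips a whole block when every lane is
// already on the correct side of the pivot; a scalar loop handles the rest.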
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
    if (left >= right) return;

    int   pivot_idx = (left + right) / 2;
    float pivot     = values[pivot_idx];
    int   i         = left;
    int   j         = right;

    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
    while (i <= j) {
        // Vectorized scan for i
        while (i <= j) {
            // Check that a full vector of 32 elements fits below j
            if (i + 32 <= j) {
                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
                if (all_greater_f32(pivot_vec, vals_vec)) {
                    // All 32 elements are < pivot: skip the whole block
                    i += 32;
                    continue;
                }
            }

            // Scalar fallback / cleanup
            if (values[i] < pivot) {
                i++;
            } else {
                break;
            }
        }

        // Vectorized scan for j
        while (i <= j) {
            if (j - 32 >= i) {
                // Load the 32 elements ending at j, i.e. values[j-31..j],
                // since the scan condition is `values[j] > pivot`.
                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
                if (all_greater_f32(vals_vec, pivot_vec)) {
                    j -= 32;
                    continue;
                }
            }

            if (values[j] > pivot) {
                j--;
            } else {
                break;
            }
        }

        if (i <= j) {
            float tmp_val = values[i];
            values[i] = values[j];
            values[j] = tmp_val;

            int32_t tmp_idx = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp_idx;
            i++;
            j--;
        }
    }

    if (left < j) quicksort_values_indices_asc(values, indices, left, j);
    if (i < right) quicksort_values_indices_asc(values, indices, i, right);
}
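
// Example: values {3.0, 1.0, 2.0} with indices {0, 1, 2} sorts to
// values {1.0, 2.0, 3.0} and indices {1, 2, 0}, i.e. indices[k] is the
// original position of the k-th smallest element.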

// Descending variant: the same Hoare partition with the comparison directions flipped.
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
    if (left >= right) return;

    int   pivot_idx = (left + right) / 2;
    float pivot     = values[pivot_idx];
    int   i         = left;
    int   j         = right;

    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);

    while (i <= j) {
        // Vectorized scan for i (skip while values[i] > pivot)
        while (i <= j) {
            if (i + 32 <= j) {
                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
                if (all_greater_f32(vals_vec, pivot_vec)) {
                    i += 32;
                    continue;
                }
            }

            if (values[i] > pivot) {
                i++;
            } else {
                break;
            }
        }

        // Vectorized scan for j (skip while values[j] < pivot)
        while (i <= j) {
            if (j - 32 >= i) {
                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
                if (all_greater_f32(pivot_vec, vals_vec)) {
                    j -= 32;
                    continue;
                }
            }

            if (values[j] < pivot) {
                j--;
            } else {
                break;
            }
        }

        if (i <= j) {
            float tmp_val = values[i];
            values[i] = values[j];
            values[j] = tmp_val;

            int32_t tmp_idx = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp_idx;
            i++;
            j--;
        }
    }

    if (left < j) quicksort_values_indices_desc(values, indices, left, j);
    if (i < right) quicksort_values_indices_desc(values, indices, i, right);
}
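
// Per-thread worker: job i sorts a contiguous slice of rows. Each row is
// staged in VTCM, sorted together with a freshly initialized index row, and
// the resulting indices are written back to the destination tensor in DDR.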
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
    struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
    struct htp_ops_context *     octx = actx->octx;

    // Unpack context
    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * dst  = &octx->dst;

    // Scratchpad memory for this thread
    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;

    // Dimensions
    uint32_t ne00 = src0->ne[0];
    uint32_t ne01 = src0->ne[1];
    uint32_t ne02 = src0->ne[2];
    uint32_t ne03 = src0->ne[3];

    uint32_t nb01 = src0->nb[1];
    //uint32_t nb02 = src0->nb[2];
    //uint32_t nb03 = src0->nb[3];

    uint32_t nb1 = dst->nb[1];
    //uint32_t nb2 = dst->nb[2];
    //uint32_t nb3 = dst->nb[3];

    // Sort order
    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];

    // Rows to process by this job
    uint32_t total_rows      = ne01 * ne02 * ne03;
    uint32_t rows_per_thread = actx->nrows_per_thread;
    uint32_t start_row       = rows_per_thread * i;
    uint32_t end_row         = MIN(start_row + rows_per_thread, total_rows);

    // Scratchpad layout: one row of f32 values followed by one row of i32
    // indices, each padded to a 128-byte vector boundary.
    //   values:  ne00 * sizeof(float)
    //   indices: ne00 * sizeof(int32_t)
    size_t    values_size = hex_round_up(ne00 * sizeof(float), 128);
    float *   values_buf  = (float *) spad;
    int32_t * indices_buf = (int32_t *) (spad + values_size);

    for (uint32_t r = start_row; r < end_row; r++) {
        uint32_t src_offset = r * nb01;
        uint32_t dst_offset = r * nb1;

        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
        uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;

        // Prefetch the row into L2, then stage it in VTCM
        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
        hvx_copy_f32_au((uint8_t *) values_buf, src_ptr, ne00);

        // Initialize indices to the identity permutation
        for (uint32_t j = 0; j < ne00; j++) {
            indices_buf[j] = j;
        }

        // Sort values and mirror swaps to indices
        if (order == GGML_SORT_ORDER_ASC) {
            quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
        } else {
            quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
        }

        // Copy the sorted indices back to DDR; the f32 copy helper works here
        // because f32 and i32 elements are both 4 bytes wide.
        hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
    }
}

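// Top-level ARGSORT entry point: validates the input type, carves per-thread
// scratch out of VTCM, and fans the rows out across the worker pool.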
int op_argsort(struct htp_ops_context * octx) {
    // Check supported types
    if (octx->src0.type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    // Allocate scratchpad: one row of f32 values plus one row of i32 indices
    // per thread, each padded to a 128-byte vector boundary.
    uint32_t ne00 = octx->src0.ne[0];
    size_t values_size     = hex_round_up(ne00 * sizeof(float), 128);
    size_t indices_size    = hex_round_up(ne00 * sizeof(int32_t), 128);
    size_t spad_per_thread = values_size + indices_size;

    // Round the per-thread slice up to 256 bytes for alignment requirements
    spad_per_thread = hex_round_up(spad_per_thread, 256);

    size_t total_spad_size = spad_per_thread * octx->n_threads;

    if (octx->ctx->vtcm_size < total_spad_size) {
        FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.data            = octx->ctx->vtcm_base;
    octx->src0_spad.size            = total_spad_size;
    octx->src0_spad.size_per_thread = spad_per_thread;

    FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
         octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
         octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
         octx->src0.data, octx->dst.data);

    uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
    uint32_t n_jobs     = MIN(total_rows, octx->n_threads);

    struct htp_argsort_context actx;
    actx.octx             = octx;
    actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;

    // Run the jobs on the worker pool
    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);

    return HTP_STATUS_OK;
}