Diffstat (limited to 'llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c')
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c  342
1 file changed, 342 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c
new file mode 100644
index 0000000..ce879bf
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c
@@ -0,0 +1,342 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#include <HAP_farf.h>
#include <HAP_perf.h>

#include <math.h>
#include <string.h>

#include "hex-dma.h"
#include "hvx-utils.h"

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"

#define htp_unary_preamble \
    const uint32_t ne00 = src->ne[0]; \
    const uint32_t ne01 = src->ne[1]; \
    const uint32_t ne02 = src->ne[2]; \
    const uint32_t ne03 = src->ne[3]; \
    \
    const uint32_t ne0 = dst->ne[0]; \
    const uint32_t ne1 = dst->ne[1]; \
    const uint32_t ne2 = dst->ne[2]; \
    const uint32_t ne3 = dst->ne[3]; \
    \
    const uint32_t nb00 = src->nb[0]; \
    const uint32_t nb01 = src->nb[1]; \
    const uint32_t nb02 = src->nb[2]; \
    const uint32_t nb03 = src->nb[3]; \
    \
    const uint32_t nb0 = dst->nb[0]; \
    const uint32_t nb1 = dst->nb[1]; \
    const uint32_t nb2 = dst->nb[2]; \
    const uint32_t nb3 = dst->nb[3];

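// Fast-path RMS norm for a single row: dst = src * 1/sqrt(mean(src^2) + epsilon).
// Works on whole 128-byte HVX vectors (32 f32 lanes at a time), so it expects
// VLEN-aligned rows whose length is a multiple of 32 elements (the opt_path case).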
static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
                                  uint8_t * restrict dst,
                                  uint8_t * restrict pad,
                                  const int num_elems,
                                  float epsilon) {
    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
    HVX_Vector * restrict v_dst = (HVX_Vector *) dst;

    HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
    HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);

    int step_of_1 = num_elems >> 5;
#pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = v_src[i];
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
    }

    HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
    sum_v = hvx_vec_repl4(reduced_sum);

    HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
    HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
    HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
    HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);

    HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));

#pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = v_src[i];
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
        v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
    }
}

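// GGML scale op: each row is passed to hvx_scale_offset_f32 with the scale and
// bias values unpacked from op_params; the next row is prefetched into L2 while
// the current one is processed.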
static void scale_htp_f32(const float * restrict src,
                          float * restrict dst,
                          uint8_t * restrict spad,
                          const uint32_t num_rows,
                          const uint32_t row_elems,
                          const size_t row_size,
                          int32_t * op_params,
                          int opt_path) {
    float scale = 0.f;
    float bias = 0.f;
    memcpy(&scale, &op_params[0], sizeof(float));
    memcpy(&bias, &op_params[1], sizeof(float));

    for (uint32_t ir = 0; ir < num_rows; ir++) {
        const float * restrict src_local = src + (ir * row_elems);
        float * restrict dst_local = dst + (ir * row_elems);

        if (ir + 1 < num_rows) {
            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
        }

        hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
    }
}

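// RMS norm per row. The aligned fast path uses hvx_fast_rms_norm_f32; otherwise
// the sum of squares is computed with hvx_sum_of_squares_f32 and the row is
// rescaled by 1/sqrt(mean + epsilon).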
static void rms_norm_htp_f32(const float * restrict src,
                             float * restrict dst,
                             uint8_t * restrict spad,
                             const uint32_t num_rows,
                             const uint32_t row_elems,
                             const size_t row_size,
                             int32_t * op_params,
                             int opt_path) {
    float epsilon = 0.f;
    memcpy(&epsilon, op_params, sizeof(float));

    for (uint32_t ir = 0; ir < num_rows; ir++) {
        const float * restrict src_local = src + (ir * row_elems);
        float * restrict dst_local = dst + (ir * row_elems);

        if (ir + 1 < num_rows) {
            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
        }

        if (1 == opt_path) {
            hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
        } else {
            float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);

            const float mean = sum / row_elems;
            const float scale = 1.0f / sqrtf(mean + epsilon);

            hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
        }
    }
}

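// Element-wise square of each row; the _aa variant handles the aligned fast path.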
static void sqr_htp_f32(const float * restrict src,
                        float * restrict dst,
                        uint8_t * restrict spad,
                        const uint32_t num_rows,
                        const uint32_t row_elems,
                        const size_t row_size,
                        int32_t * op_params,
                        int opt_path) {

    for (uint32_t ir = 0; ir < num_rows; ir++) {
        const float * restrict src_local = src + (ir * row_elems);
        float * restrict dst_local = dst + (ir * row_elems);

        if (ir + 1 < num_rows) {
            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
        }

        if (1 == opt_path) {
            hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
        } else {
            hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
        }
    }
}

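// Element-wise square root of each row, same aligned/unaligned split as sqr.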
static void sqrt_htp_f32(const float * restrict src,
                         float * restrict dst,
                         uint8_t * restrict spad,
                         const uint32_t num_rows,
                         const uint32_t row_elems,
                         const size_t row_size,
                         int32_t * op_params,
                         int opt_path) {

    for (uint32_t ir = 0; ir < num_rows; ir++) {
        const float * restrict src_local = src + (ir * row_elems);
        float * restrict dst_local = dst + (ir * row_elems);

        if (ir + 1 < num_rows) {
            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
        }

        if (1 == opt_path) {
            hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
        } else {
            hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
        }
    }
}

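// Per-thread worker: handles a contiguous slice of src0 rows. The fast path
// (opt_path == 1) is enabled only when the src/dst data pointers and the row
// stride are all VLEN-aligned.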
static void unary_job_f32_per_thread(const struct htp_tensor * src,
                                     struct htp_tensor * dst,
                                     uint8_t * spad,
                                     int htp_op,
                                     int32_t * op_params,
                                     uint32_t nth,
                                     uint32_t ith,
                                     uint32_t src0_nrows_per_thread) {
    htp_unary_preamble;

    const size_t src0_row_size = nb01;
    const size_t dst_row_size = nb1;

    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows

    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);

    // no work for this thread
    if (src0_start_row >= src0_end_row) {
        return;
    }

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    int is_aligned = 1;
    int opt_path = 0;
    if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
        is_aligned = 0;
    }
    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
        opt_path = 1;
    }

    const uint8_t * restrict data_src = (const uint8_t *) src->data;
    uint8_t * restrict data_dst = (uint8_t *) dst->data;

    const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
    float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
    uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);

    switch (htp_op) {
        case HTP_OP_RMS_NORM:
            rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;
        case HTP_OP_SCALE:
            scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;
        case HTP_OP_SQR:
            sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;
        case HTP_OP_SQRT:
            sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
            break;

        default:
            break;
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
         src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
         dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}

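// Worker-pool trampoline: n is the total number of jobs, i is this job's index.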
static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
    struct htp_ops_context * octx = (struct htp_ops_context *) data;

    unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
                             octx->src0_nrows_per_thread);
}

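// Top-level f32 handler: picks the op, sizes per-thread VTCM scratchpads (one
// 128-byte-rounded row per thread for src0 and dst), verifies the VTCM
// reservation, then spreads the rows across the worker pool.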
static int execute_op_unary_f32(struct htp_ops_context * octx) {
    int err = HTP_STATUS_OK;

    const struct htp_tensor * src0 = &octx->src0;
    struct htp_tensor * dst = &octx->dst;

    worker_callback_t unary_op_func;
    const char * op_type = NULL;

    switch (octx->op) {
        case HTP_OP_RMS_NORM:
            unary_op_func = unary_job_dispatcher_f32;
            op_type = "rmsnorm-f32";
            break;
        case HTP_OP_SCALE:
            unary_op_func = unary_job_dispatcher_f32;
            op_type = "scale-f32";
            break;
        case HTP_OP_SQR:
            unary_op_func = unary_job_dispatcher_f32;
            op_type = "sqr-f32";
            break;
        case HTP_OP_SQRT:
            unary_op_func = unary_job_dispatcher_f32;
            op_type = "sqrt-f32";
            break;

        default:
            FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
            return HTP_STATUS_NO_SUPPORT;
    }

    const int n_threads = octx->n_threads;
    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];

    const size_t src0_row_size = src0->nb[1];
    const size_t dst_row_size = dst->nb[1];

    // VTCM scratchpads for all tensors
    octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
    octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;

    size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;

    FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
         octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);

    // Make sure the reserved vtcm size is sufficient
    if (octx->ctx->vtcm_size < spad_size) {
        FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
             octx->ctx->vtcm_size, spad_size);
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;

    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        uint32_t n_jobs = MIN(n_threads, src0_nrows);

        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;

        worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
    }

    return err;
}

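// Public entry point for unary ops; only F32 tensors are supported.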
int op_unary(struct htp_ops_context * octx) {
    int err = HTP_STATUS_OK;

    switch (octx->src0.type) {
        case HTP_TYPE_F32:
            err = execute_op_unary_f32(octx);
            break;

        default:
            err = HTP_STATUS_NO_SUPPORT;
            break;
    }

    return err;
}