diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c | 342 |
1 files changed, 342 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c new file mode 100644 index 0000000..ce879bf --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c | |||
| @@ -0,0 +1,342 @@ | |||
| 1 | #pragma clang diagnostic ignored "-Wunused-variable" | ||
| 2 | #pragma clang diagnostic ignored "-Wunused-function" | ||
| 3 | #pragma clang diagnostic ignored "-Wunused-but-set-variable" | ||
| 4 | |||
| 5 | #include <HAP_farf.h> | ||
| 6 | #include <HAP_perf.h> | ||
| 7 | |||
| 8 | #include <math.h> | ||
| 9 | #include <string.h> | ||
| 10 | |||
| 11 | #include "hex-dma.h" | ||
| 12 | #include "hvx-utils.h" | ||
| 13 | |||
| 14 | #define GGML_COMMON_DECL_C | ||
| 15 | #include "ggml-common.h" | ||
| 16 | #include "htp-ctx.h" | ||
| 17 | #include "htp-msg.h" | ||
| 18 | #include "htp-ops.h" | ||
| 19 | |||
// Pulls the dimension (ne*) and byte-stride (nb*) fields of the `src` and
// `dst` tensors into local constants. Expects `src` and `dst` htp_tensor
// pointers to be in scope at the expansion site; ne* are element counts,
// nb* are strides in bytes (per ggml convention).
#define htp_unary_preamble \
    const uint32_t ne00 = src->ne[0]; \
    const uint32_t ne01 = src->ne[1]; \
    const uint32_t ne02 = src->ne[2]; \
    const uint32_t ne03 = src->ne[3]; \
    \
    const uint32_t ne0 = dst->ne[0]; \
    const uint32_t ne1 = dst->ne[1]; \
    const uint32_t ne2 = dst->ne[2]; \
    const uint32_t ne3 = dst->ne[3]; \
    \
    const uint32_t nb00 = src->nb[0]; \
    const uint32_t nb01 = src->nb[1]; \
    const uint32_t nb02 = src->nb[2]; \
    const uint32_t nb03 = src->nb[3]; \
    \
    const uint32_t nb0 = dst->nb[0]; \
    const uint32_t nb1 = dst->nb[1]; \
    const uint32_t nb2 = dst->nb[2]; \
    const uint32_t nb3 = dst->nb[3];
// Single-pass HVX RMS-norm of one row: dst = src * 1/sqrt(mean(src^2) + epsilon).
// Processes whole 128-byte HVX vectors (32 f32 lanes each; see `num_elems >> 5`).
// Assumes src/dst are vector-aligned and num_elems is a multiple of 32 — the
// caller only selects this kernel when opt_path == 1 (VLEN-aligned data and a
// VLEN-multiple row stride) — TODO confirm no tail handling is ever needed here.
// `pad` is accepted but currently unused.
static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
                                  uint8_t * restrict dst,
                                  uint8_t * restrict pad,
                                  const int num_elems,
                                  float epsilon) {
    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
    HVX_Vector * restrict v_dst = (HVX_Vector *) dst;

    // Accumulator for sum(x^2), kept in qf32 format across the loop.
    HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
    HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);

    int step_of_1 = num_elems >> 5;  // 32 f32 elements per 128-byte HVX vector
#pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = v_src[i];
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
    }

    // Convert the per-lane partial sums back to IEEE sf, reduce them to a
    // scalar, then broadcast that scalar to every lane.
    HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
    sum_v = hvx_vec_repl4(reduced_sum);

    // scale = rsqrt(sum/num_elems + epsilon); the division is done as a
    // multiply by the vectorized reciprocal of num_elems.
    HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
    HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
    HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
    HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);

    HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));

    // Second pass: apply the broadcast scale to every element of the row.
#pragma unroll(4)
    for (int i = 0; i < step_of_1; i++) {
        HVX_Vector v1 = v_src[i];
        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
        v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
    }
}
| 77 | |||
| 78 | static void scale_htp_f32(const float * restrict src, | ||
| 79 | float * restrict dst, | ||
| 80 | uint8_t * restrict spad, | ||
| 81 | const uint32_t num_rows, | ||
| 82 | const uint32_t row_elems, | ||
| 83 | const size_t row_size, | ||
| 84 | int32_t * op_params, | ||
| 85 | int opt_path) { | ||
| 86 | float scale = 0.f; | ||
| 87 | float bias = 0.f; | ||
| 88 | memcpy(&scale, &op_params[0], sizeof(float)); | ||
| 89 | memcpy(&bias, &op_params[1], sizeof(float)); | ||
| 90 | |||
| 91 | for (uint32_t ir = 0; ir < num_rows; ir++) { | ||
| 92 | const float * restrict src_local = src + (ir * row_elems); | ||
| 93 | float * restrict dst_local = dst + (ir * row_elems); | ||
| 94 | |||
| 95 | if (ir + 1 < num_rows) { | ||
| 96 | hex_l2fetch(src_local + row_elems, row_size, row_size, 1); | ||
| 97 | } | ||
| 98 | |||
| 99 | hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); | ||
| 100 | } | ||
| 101 | } | ||
| 102 | |||
| 103 | static void rms_norm_htp_f32(const float * restrict src, | ||
| 104 | float * restrict dst, | ||
| 105 | uint8_t * restrict spad, | ||
| 106 | const uint32_t num_rows, | ||
| 107 | const uint32_t row_elems, | ||
| 108 | const size_t row_size, | ||
| 109 | int32_t * op_params, | ||
| 110 | int opt_path) { | ||
| 111 | float epsilon = 0.f; | ||
| 112 | memcpy(&epsilon, op_params, sizeof(float)); | ||
| 113 | |||
| 114 | for (uint32_t ir = 0; ir < num_rows; ir++) { | ||
| 115 | const float * restrict src_local = src + (ir * row_elems); | ||
| 116 | float * restrict dst_local = dst + (ir * row_elems); | ||
| 117 | |||
| 118 | if (ir + 1 < num_rows) { | ||
| 119 | hex_l2fetch(src_local + row_elems, row_size, row_size, 1); | ||
| 120 | } | ||
| 121 | |||
| 122 | if (1 == opt_path) { | ||
| 123 | hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); | ||
| 124 | } else { | ||
| 125 | float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems); | ||
| 126 | |||
| 127 | const float mean = sum / row_elems; | ||
| 128 | const float scale = 1.0f / sqrtf(mean + epsilon); | ||
| 129 | |||
| 130 | hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); | ||
| 131 | } | ||
| 132 | } | ||
| 133 | } | ||
| 134 | |||
| 135 | static void sqr_htp_f32(const float * restrict src, | ||
| 136 | float * restrict dst, | ||
| 137 | uint8_t * restrict spad, | ||
| 138 | const uint32_t num_rows, | ||
| 139 | const uint32_t row_elems, | ||
| 140 | const size_t row_size, | ||
| 141 | int32_t * op_params, | ||
| 142 | int opt_path) { | ||
| 143 | |||
| 144 | for (uint32_t ir = 0; ir < num_rows; ir++) { | ||
| 145 | const float * restrict src_local = src + (ir * row_elems); | ||
| 146 | float * restrict dst_local = dst + (ir * row_elems); | ||
| 147 | |||
| 148 | if (ir + 1 < num_rows) { | ||
| 149 | hex_l2fetch(src_local + row_elems, row_size, row_size, 1); | ||
| 150 | } | ||
| 151 | |||
| 152 | if (1 == opt_path) { | ||
| 153 | hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); | ||
| 154 | } else { | ||
| 155 | hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); | ||
| 156 | } | ||
| 157 | } | ||
| 158 | } | ||
| 159 | |||
| 160 | static void sqrt_htp_f32(const float * restrict src, | ||
| 161 | float * restrict dst, | ||
| 162 | uint8_t * restrict spad, | ||
| 163 | const uint32_t num_rows, | ||
| 164 | const uint32_t row_elems, | ||
| 165 | const size_t row_size, | ||
| 166 | int32_t * op_params, | ||
| 167 | int opt_path) { | ||
| 168 | |||
| 169 | for (uint32_t ir = 0; ir < num_rows; ir++) { | ||
| 170 | const float * restrict src_local = src + (ir * row_elems); | ||
| 171 | float * restrict dst_local = dst + (ir * row_elems); | ||
| 172 | |||
| 173 | if (ir + 1 < num_rows) { | ||
| 174 | hex_l2fetch(src_local + row_elems, row_size, row_size, 1); | ||
| 175 | } | ||
| 176 | |||
| 177 | if (1 == opt_path) { | ||
| 178 | hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); | ||
| 179 | } else { | ||
| 180 | hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); | ||
| 181 | } | ||
| 182 | } | ||
| 183 | } | ||
| 184 | |||
| 185 | static void unary_job_f32_per_thread(const struct htp_tensor * src, | ||
| 186 | struct htp_tensor * dst, | ||
| 187 | uint8_t * spad, | ||
| 188 | int htp_op, | ||
| 189 | int32_t * op_params, | ||
| 190 | uint32_t nth, | ||
| 191 | uint32_t ith, | ||
| 192 | uint32_t src0_nrows_per_thread) { | ||
| 193 | htp_unary_preamble; | ||
| 194 | |||
| 195 | const size_t src0_row_size = nb01; | ||
| 196 | const size_t dst_row_size = nb1; | ||
| 197 | |||
| 198 | const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows | ||
| 199 | |||
| 200 | const uint32_t src0_start_row = src0_nrows_per_thread * ith; | ||
| 201 | const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); | ||
| 202 | |||
| 203 | // no work for this thread | ||
| 204 | if (src0_start_row >= src0_end_row) { | ||
| 205 | return; | ||
| 206 | } | ||
| 207 | |||
| 208 | uint64_t t1, t2; | ||
| 209 | t1 = HAP_perf_get_qtimer_count(); | ||
| 210 | |||
| 211 | int is_aligned = 1; | ||
| 212 | int opt_path = 0; | ||
| 213 | if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) { | ||
| 214 | is_aligned = 0; | ||
| 215 | } | ||
| 216 | if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { | ||
| 217 | opt_path = 1; | ||
| 218 | } | ||
| 219 | |||
| 220 | const uint8_t * restrict data_src = (const uint8_t *) src->data; | ||
| 221 | uint8_t * restrict data_dst = (uint8_t *) dst->data; | ||
| 222 | |||
| 223 | const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); | ||
| 224 | float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); | ||
| 225 | uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01); | ||
| 226 | |||
| 227 | switch (htp_op) { | ||
| 228 | case HTP_OP_RMS_NORM: | ||
| 229 | rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); | ||
| 230 | break; | ||
| 231 | case HTP_OP_SCALE: | ||
| 232 | scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); | ||
| 233 | break; | ||
| 234 | case HTP_OP_SQR: | ||
| 235 | sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); | ||
| 236 | break; | ||
| 237 | case HTP_OP_SQRT: | ||
| 238 | sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); | ||
| 239 | break; | ||
| 240 | |||
| 241 | default: | ||
| 242 | break; | ||
| 243 | } | ||
| 244 | |||
| 245 | t2 = HAP_perf_get_qtimer_count(); | ||
| 246 | |||
| 247 | FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0], | ||
| 248 | src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2], | ||
| 249 | dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); | ||
| 250 | } | ||
| 251 | |||
| 252 | static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { | ||
| 253 | struct htp_ops_context * octx = (struct htp_ops_context *) data; | ||
| 254 | |||
| 255 | unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i, | ||
| 256 | octx->src0_nrows_per_thread); | ||
| 257 | } | ||
| 258 | |||
| 259 | static int execute_op_unary_f32(struct htp_ops_context * octx) { | ||
| 260 | int err = HTP_STATUS_OK; | ||
| 261 | |||
| 262 | const struct htp_tensor * src0 = &octx->src0; | ||
| 263 | struct htp_tensor * dst = &octx->dst; | ||
| 264 | |||
| 265 | worker_callback_t unary_op_func; | ||
| 266 | const char * op_type = NULL; | ||
| 267 | |||
| 268 | switch (octx->op) { | ||
| 269 | case HTP_OP_RMS_NORM: | ||
| 270 | unary_op_func = unary_job_dispatcher_f32; | ||
| 271 | op_type = "rmsnorm-f32"; | ||
| 272 | break; | ||
| 273 | case HTP_OP_SCALE: | ||
| 274 | unary_op_func = unary_job_dispatcher_f32; | ||
| 275 | op_type = "scale-f32"; | ||
| 276 | break; | ||
| 277 | case HTP_OP_SQR: | ||
| 278 | unary_op_func = unary_job_dispatcher_f32; | ||
| 279 | op_type = "sqr-f32"; | ||
| 280 | break; | ||
| 281 | case HTP_OP_SQRT: | ||
| 282 | unary_op_func = unary_job_dispatcher_f32; | ||
| 283 | op_type = "sqrt-f32"; | ||
| 284 | break; | ||
| 285 | |||
| 286 | default: | ||
| 287 | FARF(ERROR, "Unsupported unary Op %u\n", octx->op); | ||
| 288 | return HTP_STATUS_NO_SUPPORT; | ||
| 289 | } | ||
| 290 | |||
| 291 | const int n_threads = octx->n_threads; | ||
| 292 | const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; | ||
| 293 | |||
| 294 | const size_t src0_row_size = src0->nb[1]; | ||
| 295 | const size_t dst_row_size = dst->nb[1]; | ||
| 296 | |||
| 297 | // VTCM scratchpads for all tensors | ||
| 298 | octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; | ||
| 299 | octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; | ||
| 300 | |||
| 301 | size_t spad_size = octx->src0_spad.size + octx->dst_spad.size; | ||
| 302 | |||
| 303 | FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, | ||
| 304 | src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], | ||
| 305 | octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); | ||
| 306 | |||
| 307 | // Make sure the reserved vtcm size is sufficient | ||
| 308 | if (octx->ctx->vtcm_size < spad_size) { | ||
| 309 | FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, | ||
| 310 | spad_size); | ||
| 311 | return HTP_STATUS_VTCM_TOO_SMALL; | ||
| 312 | } | ||
| 313 | |||
| 314 | octx->src0_spad.data = octx->ctx->vtcm_base; | ||
| 315 | octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; | ||
| 316 | |||
| 317 | if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { | ||
| 318 | uint32_t n_jobs = MIN(n_threads, src0_nrows); | ||
| 319 | |||
| 320 | octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; | ||
| 321 | |||
| 322 | worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs); | ||
| 323 | } | ||
| 324 | |||
| 325 | return err; | ||
| 326 | } | ||
| 327 | |||
| 328 | int op_unary(struct htp_ops_context * octx) { | ||
| 329 | int err = HTP_STATUS_OK; | ||
| 330 | |||
| 331 | switch (octx->src0.type) { | ||
| 332 | case HTP_TYPE_F32: | ||
| 333 | err = execute_op_unary_f32(octx); | ||
| 334 | break; | ||
| 335 | |||
| 336 | default: | ||
| 337 | err = HTP_STATUS_NO_SUPPORT; | ||
| 338 | break; | ||
| 339 | } | ||
| 340 | |||
| 341 | return err; | ||
| 342 | } | ||
