diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/common/ngram-map.cpp | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/common/ngram-map.cpp')
| -rw-r--r-- | llama.cpp/common/ngram-map.cpp | 530 |
1 files changed, 530 insertions, 0 deletions
diff --git a/llama.cpp/common/ngram-map.cpp b/llama.cpp/common/ngram-map.cpp new file mode 100644 index 0000000..ebf771a --- /dev/null +++ b/llama.cpp/common/ngram-map.cpp | |||
| @@ -0,0 +1,530 @@ | |||
| 1 | #include "common.h" | ||
| 2 | #include "log.h" | ||
| 3 | #include "ngram-map.h" | ||
| 4 | |||
| 5 | #include <cinttypes> | ||
| 6 | #include <cstdint> | ||
| 7 | #include <cstdio> | ||
| 8 | #include <sstream> | ||
| 9 | |||
| 10 | // prime number used for LCG hash function (32 bit), it is near (sqrt(5) - 1)/2 * 2^32. | ||
| 11 | #define LCG_FACTOR 2654435761UL | ||
| 12 | |||
| 13 | // Compute the LCG hash of a n-gram of size len at offset start. | ||
| 14 | static uint32_t common_ngram_map_hash(const llama_tokens & tokens, size_t start, size_t len) { | ||
| 15 | uint32_t hash = 0; | ||
| 16 | for (size_t i = 0; i < len; ++i) { | ||
| 17 | hash = hash * LCG_FACTOR + tokens[start + i]; | ||
| 18 | } | ||
| 19 | return hash; | ||
| 20 | } | ||
| 21 | |||
| 22 | // Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...]. | ||
| 23 | static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) { | ||
| 24 | std::ostringstream oss; | ||
| 25 | oss << '['; | ||
| 26 | for (size_t i = 0; i < length; ++i) { | ||
| 27 | if (i > 0) { | ||
| 28 | oss << ", "; | ||
| 29 | } | ||
| 30 | oss << inp[start + i]; | ||
| 31 | } | ||
| 32 | oss << ']'; | ||
| 33 | return oss.str(); | ||
| 34 | } | ||
| 35 | |||
| 36 | |||
| 37 | // n-gram simple | ||
| 38 | // | ||
| 39 | |||
| 40 | /** | ||
| 41 | * Perform speculative generation using the model's own token history. | ||
| 42 | * Searches for a matching pattern in the token history and returns draft tokens. | ||
| 43 | * | ||
| 44 | * @param state Current state of this implementation | ||
| 45 | * @param tokens Token history to search in | ||
| 46 | * @param sampled Last sampled token | ||
| 47 | * @return Vector of draft tokens, empty if no matching pattern is found | ||
| 48 | */ | ||
| 49 | llama_tokens common_ngram_simple_draft( | ||
| 50 | const common_ngram_simple_config & config, | ||
| 51 | const llama_tokens & tokens, llama_token sampled) { | ||
| 52 | |||
| 53 | // Simple implementation of self-speculative decoding without a draft model. | ||
| 54 | // | ||
| 55 | const size_t cur_len = tokens.size(); | ||
| 56 | |||
| 57 | const size_t n_draft_min = config.size_ngram; // size of n-gram to lookup in token history | ||
| 58 | const size_t n_draft_max = config.size_mgram; // the m-gram following the found n-gram is used for draft | ||
| 59 | |||
| 60 | // vector for tokens we want to verify. | ||
| 61 | // return empty vector if there is no match. | ||
| 62 | llama_tokens draft_tokens; | ||
| 63 | |||
| 64 | // We need at least n_draft_min + n_draft_max + 1 tokens. | ||
| 65 | if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) { | ||
| 66 | return draft_tokens; | ||
| 67 | } | ||
| 68 | |||
| 69 | // pattern search | ||
| 70 | llama_tokens pattern; | ||
| 71 | pattern.reserve(n_draft_min); | ||
| 72 | for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) { | ||
| 73 | pattern.push_back(tokens[j]); | ||
| 74 | } | ||
| 75 | pattern.push_back(sampled); // add the last token to the pattern | ||
| 76 | |||
| 77 | size_t match_pos = 0; // we ignore position 0, position 0 == no match | ||
| 78 | // search backwards, but skip the current match (we are currently there) | ||
| 79 | for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) { | ||
| 80 | bool match = true; | ||
| 81 | for (size_t k = 0; k < pattern.size(); ++k) { | ||
| 82 | if (tokens[j + k] != pattern[k]) { | ||
| 83 | match = false; | ||
| 84 | break; | ||
| 85 | } | ||
| 86 | } | ||
| 87 | if (match) { | ||
| 88 | match_pos = j; | ||
| 89 | break; | ||
| 90 | } | ||
| 91 | } | ||
| 92 | if (match_pos == 0) { | ||
| 93 | return draft_tokens; | ||
| 94 | } | ||
| 95 | |||
| 96 | const size_t copy_max = std::min( | ||
| 97 | n_draft_max, | ||
| 98 | cur_len - (match_pos + n_draft_min) | ||
| 99 | ); | ||
| 100 | if (copy_max < n_draft_min) { | ||
| 101 | return draft_tokens; | ||
| 102 | } | ||
| 103 | LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n", | ||
| 104 | __func__, cur_len, | ||
| 105 | match_pos, pattern.size(), copy_max); | ||
| 106 | |||
| 107 | draft_tokens.reserve(copy_max); | ||
| 108 | for (size_t j = 0; j < copy_max; ++j) { | ||
| 109 | draft_tokens.push_back(tokens[match_pos + n_draft_min + j]); | ||
| 110 | } | ||
| 111 | return draft_tokens; | ||
| 112 | } | ||
| 113 | |||
| 114 | |||
| 115 | // n-gram map | ||
| 116 | // | ||
| 117 | |||
| 118 | // maximum number of counted values of a ngram map value. | ||
| 119 | #define COMMON_NGRAM_MAX_VALUE_COUNT 16380 | ||
| 120 | |||
| 121 | void common_ngram_map_begin( | ||
| 122 | common_ngram_map & map, const llama_tokens & tokens) { | ||
| 123 | size_t size_begin = tokens.size(); | ||
| 124 | |||
| 125 | LOG_DBG("%s: begin, idx_last_draft=%zu, new begin=%zu, #keys=%zu\n", __func__, | ||
| 126 | map.idx_last_check, size_begin, map.keys.size()); | ||
| 127 | |||
| 128 | size_t count_map_entries_upd = 0; | ||
| 129 | if (!map.key_map.empty() && size_begin < map.idx_last_check) { | ||
| 130 | if (map.show_key_map_stats) { | ||
| 131 | // Print statistics of hash map map_key. | ||
| 132 | size_t count_nonzero = 0; | ||
| 133 | uint32_t min_idx = UINT32_MAX; | ||
| 134 | uint32_t max_idx = 0; | ||
| 135 | for (size_t i = 0; i < map.key_map.size(); ++i) { | ||
| 136 | uint32_t key_idx = map.key_map[i]; | ||
| 137 | if (key_idx != 0) { | ||
| 138 | ++count_nonzero; | ||
| 139 | if (key_idx < min_idx) min_idx = key_idx; | ||
| 140 | if (key_idx > max_idx) max_idx = key_idx; | ||
| 141 | } | ||
| 142 | } | ||
| 143 | if (count_nonzero == 0) { | ||
| 144 | min_idx = 0; | ||
| 145 | } | ||
| 146 | LOG_INF("%s: key_map stats: entries=%zu, min_idx=%u, max_idx=%u, key_map_last_idx=%u\n", | ||
| 147 | __func__, count_nonzero, min_idx, max_idx, map.key_map_last_idx); | ||
| 148 | } | ||
| 149 | |||
| 150 | // Update the map from hash to key index (clear outdated entries). | ||
| 151 | for (size_t i = 0; i < map.key_map.size(); ++i) { | ||
| 152 | uint32_t key_idx = map.key_map[i]; | ||
| 153 | if (key_idx >= map.size_last_begin) { | ||
| 154 | map.key_map[i] = 0; | ||
| 155 | count_map_entries_upd++; | ||
| 156 | } | ||
| 157 | } | ||
| 158 | map.key_map_last_idx = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0; | ||
| 159 | } | ||
| 160 | |||
| 161 | if (size_begin < map.idx_last_check && !map.keys.empty()) { | ||
| 162 | // The next token generation will start at index size_begin. | ||
| 163 | // The tokens between map.size_last_begin and size_begin are no longer valid. | ||
| 164 | // | ||
| 165 | // Refresh map: Remove all entries with index >= map.size_last_begin. | ||
| 166 | size_t count_keys = map.keys.size(); | ||
| 167 | size_t count_keys_del = 0; | ||
| 168 | size_t count_values_del = 0; | ||
| 169 | for (int32_t i = map.keys.size() - 1; i >= 0; --i) { | ||
| 170 | common_ngram_map_key & key = map.keys[i]; | ||
| 171 | if (key.key_idx >= map.size_last_begin) { | ||
| 172 | // Delete the key. | ||
| 173 | LOG_DBG("%s: delete key %d at index %zu (>= size_last_begin=%zu)\n", __func__, i, key.key_idx, map.size_last_begin); | ||
| 174 | map.keys.erase(map.keys.begin() + i); | ||
| 175 | count_keys_del++; | ||
| 176 | continue; | ||
| 177 | } | ||
| 178 | if (map.key_only) { | ||
| 179 | continue; | ||
| 180 | } | ||
| 181 | |||
| 182 | // Check the indices of the values. | ||
| 183 | for (int16_t j = COMMON_NGRAM_MAX_VALUES - 1; j >= 0; --j) { | ||
| 184 | common_ngram_map_value & value = key.values[j]; | ||
| 185 | if (value.value_idx >= map.size_last_begin) { | ||
| 186 | // Delete the value. | ||
| 187 | count_values_del++; | ||
| 188 | |||
| 189 | // Move all values after this value to the left. | ||
| 190 | for (uint16_t k = j; k < COMMON_NGRAM_MAX_VALUES - 1; ++k) { | ||
| 191 | key.values[k] = key.values[k + 1]; | ||
| 192 | } | ||
| 193 | // Clear the last value. | ||
| 194 | key.values[COMMON_NGRAM_MAX_VALUES - 1].value_idx = 0; | ||
| 195 | key.values[COMMON_NGRAM_MAX_VALUES - 1].value_num = 0; | ||
| 196 | } | ||
| 197 | } | ||
| 198 | if (key.values[0].value_idx == 0) { | ||
| 199 | // No values left, delete the key. | ||
| 200 | LOG_DBG("%s: delete key %d at index %zu (no values left)\n", __func__, i, key.key_idx); | ||
| 201 | map.keys.erase(map.keys.begin() + i); | ||
| 202 | count_keys_del++; | ||
| 203 | } | ||
| 204 | } | ||
| 205 | |||
| 206 | LOG_INF("%s: refresh map: idx_last_draft=%zu, new begin=%zu, #keys_checked=%zu, #keys_del=%zu, #values_del=%zu, #hashes_upd=%zu\n", __func__, | ||
| 207 | map.idx_last_check, size_begin, | ||
| 208 | count_keys, count_keys_del, count_values_del, count_map_entries_upd); | ||
| 209 | } | ||
| 210 | |||
| 211 | map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0; | ||
| 212 | map.size_last_begin = size_begin; | ||
| 213 | } | ||
| 214 | |||
| 215 | void common_ngram_map_draft(common_ngram_map & map, | ||
| 216 | const llama_tokens & inp, llama_token sampled, | ||
| 217 | llama_tokens & draft) { | ||
| 218 | // reset last key and value. | ||
| 219 | map.last_draft_created = false; | ||
| 220 | map.last_draft_key_idx = 0; | ||
| 221 | map.last_draft_value_idx = 0; | ||
| 222 | |||
| 223 | const size_t cur_len = inp.size(); | ||
| 224 | const uint16_t n = map.size_key; | ||
| 225 | const uint16_t m = map.size_value; | ||
| 226 | if (cur_len < static_cast<size_t>(2 * n + m)) { | ||
| 227 | return; | ||
| 228 | } | ||
| 229 | if (cur_len >= static_cast<size_t>(UINT32_MAX)) { | ||
| 230 | // key_map uses uint32_t instead of size_t. | ||
| 231 | GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len); | ||
| 232 | } | ||
| 233 | |||
| 234 | if (map.idx_last_check > cur_len) { | ||
| 235 | // Should not happen because of common_ngram_map_begin(). | ||
| 236 | GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len); | ||
| 237 | } | ||
| 238 | map.idx_last_check = cur_len; | ||
| 239 | |||
| 240 | // search pattern, the key n-gram | ||
| 241 | std::vector<llama_token> key_tokens; | ||
| 242 | key_tokens.reserve(n); | ||
| 243 | for (size_t j = cur_len - n + 1; j < cur_len; ++j) { | ||
| 244 | key_tokens.push_back(inp[j]); | ||
| 245 | } | ||
| 246 | key_tokens.push_back(sampled); | ||
| 247 | |||
| 248 | // search for the key in the map | ||
| 249 | size_t match_pos = 0; | ||
| 250 | if (map.size_last_begin > cur_len) { | ||
| 251 | GGML_ABORT("%s: map.size_last_begin > cur_len: %zu > %zu", __func__, map.size_last_begin, cur_len); | ||
| 252 | } | ||
| 253 | if (!map.key_map.empty()) { | ||
| 254 | // Search for the key in the map key_map from hash of ngrams to index of ngram. | ||
| 255 | uint32_t idx_hash = (common_ngram_map_hash(key_tokens, 0, n) % map.key_map.size()); | ||
| 256 | uint32_t idx_key = map.key_map[idx_hash]; | ||
| 257 | if (idx_key != 0 && idx_key < cur_len - n - m - 1) { | ||
| 258 | // Check if the key matches the key at idx_key (because of possible collisions). | ||
| 259 | bool match = true; | ||
| 260 | for (size_t k = 0; k < n; ++k) { | ||
| 261 | if (inp[idx_key + k] != key_tokens[k]) { | ||
| 262 | match = false; | ||
| 263 | break; | ||
| 264 | } | ||
| 265 | } | ||
| 266 | LOG_DBG("%s: key hash %x -> idx_key %d: match %d\n", __func__, idx_hash, idx_key, match ? 1 : 0); | ||
| 267 | if (match) { | ||
| 268 | match_pos = idx_key; | ||
| 269 | } | ||
| 270 | } | ||
| 271 | } | ||
| 272 | if (match_pos == 0 && map.size_last_begin > (size_t) (n + m + 1)) { | ||
| 273 | // Search for the key in [1, map.size_last_begin - n - m -1], descending. | ||
| 274 | for (size_t j = map.size_last_begin - n - m - 1; j > map.key_map_last_idx; --j) { | ||
| 275 | // Check if the key matches the key. | ||
| 276 | bool match = true; | ||
| 277 | for (size_t k = 0; k < n; ++k) { | ||
| 278 | if (inp[j + k] != key_tokens[k]) { | ||
| 279 | match = false; | ||
| 280 | break; | ||
| 281 | } | ||
| 282 | } | ||
| 283 | if (match) { | ||
| 284 | match_pos = j; | ||
| 285 | break; | ||
| 286 | } | ||
| 287 | } | ||
| 288 | } | ||
| 289 | if (match_pos == 0) { | ||
| 290 | // In case of a reasoning chat, the part after size_last_begin may be deleted/reordered later. | ||
| 291 | // | ||
| 292 | // Search in [size_last_begin, cur_len - n - m - 1], descending. | ||
| 293 | for (size_t j = cur_len - n - m - 1; j > map.size_last_begin && j > map.key_map_last_idx; --j) { | ||
| 294 | bool match = true; | ||
| 295 | for (size_t k = 0; k < n; ++k) { | ||
| 296 | if (inp[j + k] != key_tokens[k]) { | ||
| 297 | match = false; | ||
| 298 | break; | ||
| 299 | } | ||
| 300 | } | ||
| 301 | if (match) { | ||
| 302 | match_pos = j; | ||
| 303 | break; | ||
| 304 | } | ||
| 305 | } | ||
| 306 | } | ||
| 307 | if (match_pos > 0) { | ||
| 308 | LOG_DBG("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__, | ||
| 309 | cur_len, n, m, key_tokens.size(), sampled, match_pos); | ||
| 310 | } | ||
| 311 | |||
| 312 | if (!map.key_map.empty()) { | ||
| 313 | // Add hashes of new ngrams in key_map. | ||
| 314 | // | ||
| 315 | // Use the same order as above. | ||
| 316 | if (map.size_last_begin > (size_t) (n + m + 1)) { | ||
| 317 | for (size_t j = map.size_last_begin - n - m - 1; j > map.key_map_last_idx; --j) { | ||
| 318 | // compute hash and store index of ngram at idx j in the map. | ||
| 319 | uint32_t idx_hash = (common_ngram_map_hash(inp, j, n) % map.key_map.size()); | ||
| 320 | if (map.key_map[idx_hash] == 0) { | ||
| 321 | map.key_map[idx_hash] = j; // collisions may occur | ||
| 322 | } | ||
| 323 | } | ||
| 324 | } | ||
| 325 | |||
| 326 | for (size_t j = cur_len - n - m - 1; j > map.size_last_begin && j > map.key_map_last_idx; --j) { | ||
| 327 | // compute hash and store index of ngram at idx j in the map. | ||
| 328 | uint32_t idx_hash = (common_ngram_map_hash(inp, j, n) % map.key_map.size()); | ||
| 329 | if (map.key_map[idx_hash] == 0) { | ||
| 330 | map.key_map[idx_hash] = j; | ||
| 331 | } | ||
| 332 | } | ||
| 333 | map.key_map_last_idx = std::max(static_cast<uint32_t>(cur_len - n - m - 1), map.key_map_last_idx); | ||
| 334 | } | ||
| 335 | |||
| 336 | if (match_pos == 0) { | ||
| 337 | return; | ||
| 338 | } | ||
| 339 | |||
| 340 | // We have a match, now we look for the statistics of the key. | ||
| 341 | size_t key_offset = map.keys.size(); // offset in the map | ||
| 342 | // We iterate through the std::vector<common_ngram_map_key> map->keys. | ||
| 343 | for (size_t i = 0; i < map.keys.size(); ++i) { | ||
| 344 | bool match = true; | ||
| 345 | for (size_t j = 0; j < n; ++j) { | ||
| 346 | if (inp[map.keys[i].key_idx + j] != key_tokens[j]) { | ||
| 347 | match = false; | ||
| 348 | break; | ||
| 349 | } | ||
| 350 | } | ||
| 351 | if (match) { | ||
| 352 | key_offset = i; | ||
| 353 | break; | ||
| 354 | } | ||
| 355 | } | ||
| 356 | if (key_offset == map.keys.size()) { | ||
| 357 | // We create a new key-entry, it will get offset key_offset. | ||
| 358 | common_ngram_map_key new_key; | ||
| 359 | new_key.key_idx = match_pos; | ||
| 360 | new_key.stat_idx = 0; | ||
| 361 | new_key.key_num = 0; | ||
| 362 | for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) { | ||
| 363 | new_key.values[i].value_num = 0; | ||
| 364 | new_key.values[i].n_accepted = m; | ||
| 365 | } | ||
| 366 | map.keys.push_back(new_key); | ||
| 367 | } | ||
| 368 | |||
| 369 | // our key n-gram: | ||
| 370 | common_ngram_map_key & curr_key = map.keys[key_offset]; | ||
| 371 | |||
| 372 | // update number of key hits | ||
| 373 | curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1, | ||
| 374 | (int) COMMON_NGRAM_MAX_VALUE_COUNT); | ||
| 375 | |||
| 376 | if (map.key_only) { | ||
| 377 | // simple mode: | ||
| 378 | // Fill in the draft with the m tokens following the key. | ||
| 379 | // We work with value values[0] only. | ||
| 380 | int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted); | ||
| 381 | |||
| 382 | for (int i = 0; i < n_draft_tokens; ++i) { | ||
| 383 | draft.push_back(inp[match_pos + n + i]); | ||
| 384 | } | ||
| 385 | |||
| 386 | LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__, | ||
| 387 | curr_key.key_idx, key_offset, curr_key.key_num, draft.size()); | ||
| 388 | |||
| 389 | map.last_draft_created = false; | ||
| 390 | map.last_draft_key_idx = key_offset; | ||
| 391 | map.last_draft_value_idx = 0; // value 0 is used for simple mode | ||
| 392 | return; | ||
| 393 | } | ||
| 394 | |||
| 395 | if (curr_key.key_num < map.min_hits) { | ||
| 396 | // not enough hits to consider this a good draft | ||
| 397 | LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__, | ||
| 398 | key_offset, curr_key.key_num, map.min_hits); | ||
| 399 | return; | ||
| 400 | } | ||
| 401 | |||
| 402 | // complex mode: examine the different m-grams after this key n-gram. | ||
| 403 | // | ||
| 404 | |||
| 405 | // determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram. | ||
| 406 | for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) { | ||
| 407 | // begins the key n-gram at index i? | ||
| 408 | bool match_key = true; | ||
| 409 | for (size_t k = 0; k < n; ++k) { | ||
| 410 | if (inp[i + k] != key_tokens[k]) { | ||
| 411 | match_key = false; | ||
| 412 | break; | ||
| 413 | } | ||
| 414 | } | ||
| 415 | if (!match_key) { | ||
| 416 | continue; | ||
| 417 | } | ||
| 418 | |||
| 419 | // Do we haven a existing value m-gram or a new one after the key at index i? | ||
| 420 | size_t idx_begin_value_key = i + n; | ||
| 421 | int idx_value = -1; | ||
| 422 | for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { | ||
| 423 | size_t idx_begin_value_v = curr_key.values[v].value_idx; | ||
| 424 | if (idx_begin_value_v == 0) { | ||
| 425 | // We found an empty value slot => we found a new value m-gram after the key n-gram. | ||
| 426 | curr_key.values[v].value_idx = idx_begin_value_key; | ||
| 427 | curr_key.values[v].value_num = 0; | ||
| 428 | curr_key.values[v].n_accepted = m; | ||
| 429 | idx_value = v; | ||
| 430 | break; | ||
| 431 | } | ||
| 432 | bool match = true; | ||
| 433 | for (size_t j = 0; j < m; ++j) { | ||
| 434 | if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) { | ||
| 435 | match = false; | ||
| 436 | break; | ||
| 437 | } | ||
| 438 | } | ||
| 439 | if (match) { | ||
| 440 | // We found an existing value m-gram after the key n-gram. | ||
| 441 | idx_value = v; | ||
| 442 | break; | ||
| 443 | } | ||
| 444 | } | ||
| 445 | if (idx_value >= 0) { | ||
| 446 | // We found a value m-gram of the key n-gram. | ||
| 447 | curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1, | ||
| 448 | (int) COMMON_NGRAM_MAX_VALUE_COUNT); | ||
| 449 | } | ||
| 450 | } | ||
| 451 | // the statistics are updated up to match_pos. | ||
| 452 | curr_key.stat_idx = match_pos; | ||
| 453 | |||
| 454 | // Do we have a value we could use for the draft? | ||
| 455 | uint16_t max_occur = 0; | ||
| 456 | int slot_max = 0; | ||
| 457 | for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { | ||
| 458 | uint16_t curr_occur = curr_key.values[v].value_num; | ||
| 459 | if (curr_occur > max_occur) { | ||
| 460 | max_occur = curr_occur; | ||
| 461 | slot_max = v; | ||
| 462 | } | ||
| 463 | } | ||
| 464 | // What is sum of the other occurrences? | ||
| 465 | uint32_t sum_occur = 0; | ||
| 466 | for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { | ||
| 467 | if (v == slot_max) { | ||
| 468 | continue; | ||
| 469 | } | ||
| 470 | uint16_t curr_occur = curr_key.values[v].value_num; | ||
| 471 | sum_occur += curr_occur; | ||
| 472 | } | ||
| 473 | |||
| 474 | LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__, | ||
| 475 | key_offset, | ||
| 476 | max_occur, sum_occur, slot_max, | ||
| 477 | curr_key.values[0].value_idx, curr_key.values[0].value_num, | ||
| 478 | curr_key.values[1].value_idx, curr_key.values[1].value_num, | ||
| 479 | curr_key.values[2].value_idx, curr_key.values[2].value_num, | ||
| 480 | curr_key.values[3].value_idx, curr_key.values[3].value_num | ||
| 481 | ); | ||
| 482 | // Print the tokens of the four values (if idx != 0), use LOG_INF | ||
| 483 | for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { | ||
| 484 | if (curr_key.values[v].value_idx != 0) { | ||
| 485 | LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str()); | ||
| 486 | } | ||
| 487 | } | ||
| 488 | |||
| 489 | if (sum_occur > 0 && max_occur < 2 * sum_occur) { | ||
| 490 | // The most frequent value is not much more frequent than the other values. | ||
| 491 | // We do not use the draft. | ||
| 492 | return; | ||
| 493 | } | ||
| 494 | |||
| 495 | // We use the most frequent value values[slot_max] for the draft. | ||
| 496 | // Fill in the draft with the m tokens following the key. | ||
| 497 | int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted); | ||
| 498 | |||
| 499 | for (int i = 0; i < n_draft_tokens; ++i) { | ||
| 500 | draft.push_back(inp[match_pos + n + i]); | ||
| 501 | } | ||
| 502 | |||
| 503 | LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__, | ||
| 504 | key_offset, slot_max, | ||
| 505 | curr_key.key_num, draft.size()); | ||
| 506 | |||
| 507 | map.last_draft_created = true; | ||
| 508 | map.last_draft_key_idx = key_offset; | ||
| 509 | map.last_draft_value_idx = slot_max; // value used for draft generation. | ||
| 510 | } | ||
| 511 | |||
| 512 | void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) { | ||
| 513 | if (!map.last_draft_created) { | ||
| 514 | return; | ||
| 515 | } | ||
| 516 | |||
| 517 | // find the key and its chosen value. | ||
| 518 | const size_t key_idx = map.last_draft_key_idx; | ||
| 519 | const size_t val_idx = map.last_draft_value_idx; | ||
| 520 | |||
| 521 | // find key corresponding to key_idx. | ||
| 522 | common_ngram_map_key & curr_key = map.keys[key_idx]; | ||
| 523 | // find value corresponding to val_idx. | ||
| 524 | struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation. | ||
| 525 | |||
| 526 | // update the value statistics | ||
| 527 | LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", | ||
| 528 | n_accepted, curr_value.n_accepted); | ||
| 529 | curr_value.n_accepted = n_accepted; | ||
| 530 | } | ||
