diff options
Diffstat (limited to 'llama.cpp/common/common.cpp')
| -rw-r--r-- | llama.cpp/common/common.cpp | 1786 |
1 files changed, 1786 insertions, 0 deletions
diff --git a/llama.cpp/common/common.cpp b/llama.cpp/common/common.cpp new file mode 100644 index 0000000..ec15804 --- /dev/null +++ b/llama.cpp/common/common.cpp | |||
| @@ -0,0 +1,1786 @@ | |||
| 1 | #include "ggml.h" | ||
| 2 | #include "gguf.h" | ||
| 3 | |||
| 4 | #include "common.h" | ||
| 5 | #include "log.h" | ||
| 6 | #include "llama.h" | ||
| 7 | #include "sampling.h" | ||
| 8 | #include "unicode.h" | ||
| 9 | |||
| 10 | #include <algorithm> | ||
| 11 | #include <cinttypes> | ||
| 12 | #include <climits> | ||
| 13 | #include <cmath> | ||
| 14 | #include <chrono> | ||
| 15 | #include <cstdarg> | ||
| 16 | #include <cstring> | ||
| 17 | #include <ctime> | ||
| 18 | #include <filesystem> | ||
| 19 | #include <fstream> | ||
| 20 | #include <iostream> | ||
| 21 | #include <iterator> | ||
| 22 | #include <regex> | ||
| 23 | #include <sstream> | ||
| 24 | #include <string> | ||
| 25 | #include <thread> | ||
| 26 | #include <unordered_set> | ||
| 27 | #include <vector> | ||
| 28 | |||
| 29 | #if defined(__APPLE__) && defined(__MACH__) | ||
| 30 | #include <sys/types.h> | ||
| 31 | #include <sys/sysctl.h> | ||
| 32 | #endif | ||
| 33 | |||
| 34 | #if defined(_WIN32) | ||
| 35 | #define WIN32_LEAN_AND_MEAN | ||
| 36 | #ifndef NOMINMAX | ||
| 37 | # define NOMINMAX | ||
| 38 | #endif | ||
| 39 | #include <locale> | ||
| 40 | #include <windows.h> | ||
| 41 | #include <string.h> | ||
| 42 | #include <fcntl.h> | ||
| 43 | #include <io.h> | ||
| 44 | #else | ||
| 45 | #include <sys/ioctl.h> | ||
| 46 | #include <sys/stat.h> | ||
| 47 | #include <unistd.h> | ||
| 48 | #endif | ||
| 49 | |||
| 50 | #if defined(__linux__) | ||
| 51 | #include <sys/types.h> | ||
| 52 | #include <pwd.h> | ||
| 53 | #endif | ||
| 54 | |||
| 55 | #if defined(_MSC_VER) | ||
| 56 | #pragma warning(disable: 4244 4267) // possible loss of data | ||
| 57 | #endif | ||
| 58 | |||
| 59 | common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} | ||
| 60 | |||
| 61 | common_time_meas::~common_time_meas() { | ||
| 62 | if (t_start_us >= 0) { | ||
| 63 | t_acc += ggml_time_us() - t_start_us; | ||
| 64 | } | ||
| 65 | } | ||
| 66 | |||
| 67 | // | ||
| 68 | // CPU utils | ||
| 69 | // | ||
| 70 | |||
| 71 | int32_t cpu_get_num_physical_cores() { | ||
| 72 | #ifdef __linux__ | ||
| 73 | // enumerate the set of thread siblings, num entries is num cores | ||
| 74 | std::unordered_set<std::string> siblings; | ||
| 75 | for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) { | ||
| 76 | std::ifstream thread_siblings("/sys/devices/system/cpu/cpu" | ||
| 77 | + std::to_string(cpu) + "/topology/thread_siblings"); | ||
| 78 | if (!thread_siblings.is_open()) { | ||
| 79 | break; // no more cpus | ||
| 80 | } | ||
| 81 | std::string line; | ||
| 82 | if (std::getline(thread_siblings, line)) { | ||
| 83 | siblings.insert(line); | ||
| 84 | } | ||
| 85 | } | ||
| 86 | if (!siblings.empty()) { | ||
| 87 | return static_cast<int32_t>(siblings.size()); | ||
| 88 | } | ||
| 89 | #elif defined(__APPLE__) && defined(__MACH__) | ||
| 90 | int32_t num_physical_cores; | ||
| 91 | size_t len = sizeof(num_physical_cores); | ||
| 92 | int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0); | ||
| 93 | if (result == 0) { | ||
| 94 | return num_physical_cores; | ||
| 95 | } | ||
| 96 | result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0); | ||
| 97 | if (result == 0) { | ||
| 98 | return num_physical_cores; | ||
| 99 | } | ||
| 100 | #elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later | ||
| 101 | // TODO: windows + arm64 + mingw64 | ||
| 102 | unsigned int n_threads_win = std::thread::hardware_concurrency(); | ||
| 103 | unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4; | ||
| 104 | |||
| 105 | DWORD buffer_size = 0; | ||
| 106 | if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) { | ||
| 107 | if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) { | ||
| 108 | return default_threads; | ||
| 109 | } | ||
| 110 | } | ||
| 111 | |||
| 112 | std::vector<char> buffer(buffer_size); | ||
| 113 | if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()), &buffer_size)) { | ||
| 114 | return default_threads; | ||
| 115 | } | ||
| 116 | |||
| 117 | int32_t num_physical_cores = 0; | ||
| 118 | PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()); | ||
| 119 | while (buffer_size > 0) { | ||
| 120 | if (info->Relationship == RelationProcessorCore) { | ||
| 121 | num_physical_cores += info->Processor.GroupCount; | ||
| 122 | } | ||
| 123 | buffer_size -= info->Size; | ||
| 124 | info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(reinterpret_cast<char*>(info) + info->Size); | ||
| 125 | } | ||
| 126 | |||
| 127 | return num_physical_cores > 0 ? num_physical_cores : default_threads; | ||
| 128 | #endif | ||
| 129 | unsigned int n_threads = std::thread::hardware_concurrency(); | ||
| 130 | return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; | ||
| 131 | } | ||
| 132 | |||
| 133 | #if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) | ||
| 134 | #include <pthread.h> | ||
| 135 | |||
| 136 | static void cpuid(unsigned leaf, unsigned subleaf, | ||
| 137 | unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) { | ||
| 138 | __asm__("movq\t%%rbx,%%rsi\n\t" | ||
| 139 | "cpuid\n\t" | ||
| 140 | "xchgq\t%%rbx,%%rsi" | ||
| 141 | : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx) | ||
| 142 | : "0"(leaf), "2"(subleaf)); | ||
| 143 | } | ||
| 144 | |||
| 145 | static int pin_cpu(int cpu) { | ||
| 146 | cpu_set_t mask; | ||
| 147 | CPU_ZERO(&mask); | ||
| 148 | CPU_SET(cpu, &mask); | ||
| 149 | return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask); | ||
| 150 | } | ||
| 151 | |||
| 152 | static bool is_hybrid_cpu(void) { | ||
| 153 | unsigned eax, ebx, ecx, edx; | ||
| 154 | cpuid(7, 0, &eax, &ebx, &ecx, &edx); | ||
| 155 | return !!(edx & (1u << 15)); | ||
| 156 | } | ||
| 157 | |||
| 158 | static bool is_running_on_efficiency_core(void) { | ||
| 159 | unsigned eax, ebx, ecx, edx; | ||
| 160 | cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx); | ||
| 161 | int intel_atom = 0x20; | ||
| 162 | int core_type = (eax & 0xff000000u) >> 24; | ||
| 163 | return core_type == intel_atom; | ||
| 164 | } | ||
| 165 | |||
| 166 | static int cpu_count_math_cpus(int n_cpu) { | ||
| 167 | int result = 0; | ||
| 168 | for (int cpu = 0; cpu < n_cpu; ++cpu) { | ||
| 169 | if (pin_cpu(cpu)) { | ||
| 170 | return -1; | ||
| 171 | } | ||
| 172 | if (is_running_on_efficiency_core()) { | ||
| 173 | continue; // efficiency cores harm lockstep threading | ||
| 174 | } | ||
| 175 | ++cpu; // hyperthreading isn't useful for linear algebra | ||
| 176 | ++result; | ||
| 177 | } | ||
| 178 | return result; | ||
| 179 | } | ||
| 180 | |||
| 181 | #endif // __x86_64__ && __linux__ | ||
| 182 | |||
| 183 | /** | ||
| 184 | * Returns number of CPUs on system that are useful for math. | ||
| 185 | */ | ||
| 186 | int32_t cpu_get_num_math() { | ||
| 187 | #if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) | ||
| 188 | int n_cpu = sysconf(_SC_NPROCESSORS_ONLN); | ||
| 189 | if (n_cpu < 1) { | ||
| 190 | return cpu_get_num_physical_cores(); | ||
| 191 | } | ||
| 192 | if (is_hybrid_cpu()) { | ||
| 193 | cpu_set_t affinity; | ||
| 194 | if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) { | ||
| 195 | int result = cpu_count_math_cpus(n_cpu); | ||
| 196 | pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity); | ||
| 197 | if (result > 0) { | ||
| 198 | return result; | ||
| 199 | } | ||
| 200 | } | ||
| 201 | } | ||
| 202 | #endif | ||
| 203 | return cpu_get_num_physical_cores(); | ||
| 204 | } | ||
| 205 | |||
| 206 | // Helper for setting process priority | ||
| 207 | |||
| 208 | #if defined(_WIN32) | ||
| 209 | |||
| 210 | bool set_process_priority(enum ggml_sched_priority prio) { | ||
| 211 | if (prio == GGML_SCHED_PRIO_NORMAL) { | ||
| 212 | return true; | ||
| 213 | } | ||
| 214 | |||
| 215 | DWORD p = NORMAL_PRIORITY_CLASS; | ||
| 216 | switch (prio) { | ||
| 217 | case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break; | ||
| 218 | case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break; | ||
| 219 | case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break; | ||
| 220 | case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break; | ||
| 221 | case GGML_SCHED_PRIO_REALTIME: p = REALTIME_PRIORITY_CLASS; break; | ||
| 222 | } | ||
| 223 | |||
| 224 | if (!SetPriorityClass(GetCurrentProcess(), p)) { | ||
| 225 | LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError()); | ||
| 226 | return false; | ||
| 227 | } | ||
| 228 | |||
| 229 | return true; | ||
| 230 | } | ||
| 231 | |||
| 232 | #else // MacOS and POSIX | ||
| 233 | #include <sys/types.h> | ||
| 234 | #include <sys/resource.h> | ||
| 235 | |||
| 236 | bool set_process_priority(enum ggml_sched_priority prio) { | ||
| 237 | if (prio == GGML_SCHED_PRIO_NORMAL) { | ||
| 238 | return true; | ||
| 239 | } | ||
| 240 | |||
| 241 | int p = 0; | ||
| 242 | switch (prio) { | ||
| 243 | case GGML_SCHED_PRIO_LOW: p = 5; break; | ||
| 244 | case GGML_SCHED_PRIO_NORMAL: p = 0; break; | ||
| 245 | case GGML_SCHED_PRIO_MEDIUM: p = -5; break; | ||
| 246 | case GGML_SCHED_PRIO_HIGH: p = -10; break; | ||
| 247 | case GGML_SCHED_PRIO_REALTIME: p = -20; break; | ||
| 248 | } | ||
| 249 | |||
| 250 | if (setpriority(PRIO_PROCESS, 0, p) != 0) { | ||
| 251 | LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno); | ||
| 252 | return false; | ||
| 253 | } | ||
| 254 | return true; | ||
| 255 | } | ||
| 256 | |||
| 257 | #endif | ||
| 258 | |||
| 259 | // | ||
| 260 | // CLI argument parsing | ||
| 261 | // | ||
| 262 | |||
| 263 | |||
| 264 | void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) { | ||
| 265 | int32_t n_set = 0; | ||
| 266 | |||
| 267 | if (cpuparams.n_threads < 0) { | ||
| 268 | // Assuming everything about cpuparams is invalid | ||
| 269 | if (role_model != nullptr) { | ||
| 270 | cpuparams = *role_model; | ||
| 271 | } else { | ||
| 272 | cpuparams.n_threads = cpu_get_num_math(); | ||
| 273 | } | ||
| 274 | } | ||
| 275 | |||
| 276 | for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) { | ||
| 277 | if (cpuparams.cpumask[i]) { | ||
| 278 | n_set++; | ||
| 279 | } | ||
| 280 | } | ||
| 281 | |||
| 282 | if (n_set && n_set < cpuparams.n_threads) { | ||
| 283 | // Not enough set bits, may experience performance issues. | ||
| 284 | LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads); | ||
| 285 | } | ||
| 286 | } | ||
| 287 | |||
| 288 | bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) { | ||
| 289 | size_t dash_loc = range.find('-'); | ||
| 290 | if (dash_loc == std::string::npos) { | ||
| 291 | LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n"); | ||
| 292 | return false; | ||
| 293 | } | ||
| 294 | |||
| 295 | size_t start_i; | ||
| 296 | size_t end_i; | ||
| 297 | |||
| 298 | if (dash_loc == 0) { | ||
| 299 | start_i = 0; | ||
| 300 | } else { | ||
| 301 | start_i = std::stoull(range.substr(0, dash_loc)); | ||
| 302 | if (start_i >= GGML_MAX_N_THREADS) { | ||
| 303 | LOG_ERR("Start index out of bounds!\n"); | ||
| 304 | return false; | ||
| 305 | } | ||
| 306 | } | ||
| 307 | |||
| 308 | if (dash_loc == range.length() - 1) { | ||
| 309 | end_i = GGML_MAX_N_THREADS - 1; | ||
| 310 | } else { | ||
| 311 | end_i = std::stoull(range.substr(dash_loc + 1)); | ||
| 312 | if (end_i >= GGML_MAX_N_THREADS) { | ||
| 313 | LOG_ERR("End index out of bounds!\n"); | ||
| 314 | return false; | ||
| 315 | } | ||
| 316 | } | ||
| 317 | |||
| 318 | for (size_t i = start_i; i <= end_i; i++) { | ||
| 319 | boolmask[i] = true; | ||
| 320 | } | ||
| 321 | |||
| 322 | return true; | ||
| 323 | } | ||
| 324 | |||
| 325 | bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREADS]) { | ||
| 326 | // Discard potential 0x prefix | ||
| 327 | size_t start_i = 0; | ||
| 328 | if (mask.length() >= 2 && mask.substr(0, 2) == "0x") { | ||
| 329 | start_i = 2; | ||
| 330 | } | ||
| 331 | |||
| 332 | size_t num_digits = mask.length() - start_i; | ||
| 333 | if (num_digits > 128) num_digits = 128; | ||
| 334 | |||
| 335 | size_t end_i = num_digits + start_i; | ||
| 336 | |||
| 337 | for (size_t i = start_i, n = (num_digits*4 - 1); i < end_i; i++, n-=4) { | ||
| 338 | char c = mask.at(i); | ||
| 339 | int8_t id = c; | ||
| 340 | |||
| 341 | if ((c >= '0' && c <= '9')) { | ||
| 342 | id -= '0'; | ||
| 343 | } else if (c >= 'a' && c <= 'f') { | ||
| 344 | id -= 'a' - 10; | ||
| 345 | } else if (c >= 'A' && c <= 'F') { | ||
| 346 | id -= 'A' - 10; | ||
| 347 | } else { | ||
| 348 | LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i)); | ||
| 349 | return false; | ||
| 350 | } | ||
| 351 | |||
| 352 | boolmask[ n ] = boolmask[ n ] || ((id & 8) != 0); | ||
| 353 | boolmask[n - 1] = boolmask[n - 1] || ((id & 4) != 0); | ||
| 354 | boolmask[n - 2] = boolmask[n - 2] || ((id & 2) != 0); | ||
| 355 | boolmask[n - 3] = boolmask[n - 3] || ((id & 1) != 0); | ||
| 356 | } | ||
| 357 | |||
| 358 | return true; | ||
| 359 | } | ||
| 360 | |||
| 361 | void common_init() { | ||
| 362 | llama_log_set(common_log_default_callback, NULL); | ||
| 363 | |||
| 364 | #ifdef NDEBUG | ||
| 365 | const char * build_type = ""; | ||
| 366 | #else | ||
| 367 | const char * build_type = " (debug)"; | ||
| 368 | #endif | ||
| 369 | |||
| 370 | LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type); | ||
| 371 | } | ||
| 372 | |||
| 373 | std::string common_params_get_system_info(const common_params & params) { | ||
| 374 | std::ostringstream os; | ||
| 375 | |||
| 376 | os << "system_info: n_threads = " << params.cpuparams.n_threads; | ||
| 377 | if (params.cpuparams_batch.n_threads != -1) { | ||
| 378 | os << " (n_threads_batch = " << params.cpuparams_batch.n_threads << ")"; | ||
| 379 | } | ||
| 380 | #if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later | ||
| 381 | // TODO: windows + arm64 + mingw64 | ||
| 382 | DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS); | ||
| 383 | os << " / " << logicalProcessorCount << " | " << llama_print_system_info(); | ||
| 384 | #else | ||
| 385 | os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); | ||
| 386 | #endif | ||
| 387 | |||
| 388 | return os.str(); | ||
| 389 | } | ||
| 390 | |||
| 391 | // | ||
| 392 | // String utils | ||
| 393 | // | ||
| 394 | |||
| 395 | std::string string_format(const char * fmt, ...) { | ||
| 396 | va_list ap; | ||
| 397 | va_list ap2; | ||
| 398 | va_start(ap, fmt); | ||
| 399 | va_copy(ap2, ap); | ||
| 400 | int size = vsnprintf(NULL, 0, fmt, ap); | ||
| 401 | GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT | ||
| 402 | std::vector<char> buf(size + 1); | ||
| 403 | int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); | ||
| 404 | GGML_ASSERT(size2 == size); | ||
| 405 | va_end(ap2); | ||
| 406 | va_end(ap); | ||
| 407 | return std::string(buf.data(), size); | ||
| 408 | } | ||
| 409 | |||
| 410 | std::string string_strip(const std::string & str) { | ||
| 411 | size_t start = 0; | ||
| 412 | size_t end = str.size(); | ||
| 413 | while (start < end && std::isspace(str[start])) { | ||
| 414 | start++; | ||
| 415 | } | ||
| 416 | while (end > start && std::isspace(str[end - 1])) { | ||
| 417 | end--; | ||
| 418 | } | ||
| 419 | return str.substr(start, end - start); | ||
| 420 | } | ||
| 421 | |||
| 422 | std::string string_get_sortable_timestamp() { | ||
| 423 | using clock = std::chrono::system_clock; | ||
| 424 | |||
| 425 | const clock::time_point current_time = clock::now(); | ||
| 426 | const time_t as_time_t = clock::to_time_t(current_time); | ||
| 427 | char timestamp_no_ns[100]; | ||
| 428 | std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t)); | ||
| 429 | |||
| 430 | const int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>( | ||
| 431 | current_time.time_since_epoch() % 1000000000).count(); | ||
| 432 | char timestamp_ns[11]; | ||
| 433 | snprintf(timestamp_ns, 11, "%09" PRId64, ns); | ||
| 434 | |||
| 435 | return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); | ||
| 436 | } | ||
| 437 | |||
| 438 | void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { | ||
| 439 | if (search.empty()) { | ||
| 440 | return; | ||
| 441 | } | ||
| 442 | std::string builder; | ||
| 443 | builder.reserve(s.length()); | ||
| 444 | size_t pos = 0; | ||
| 445 | size_t last_pos = 0; | ||
| 446 | while ((pos = s.find(search, last_pos)) != std::string::npos) { | ||
| 447 | builder.append(s, last_pos, pos - last_pos); | ||
| 448 | builder.append(replace); | ||
| 449 | last_pos = pos + search.length(); | ||
| 450 | } | ||
| 451 | builder.append(s, last_pos, std::string::npos); | ||
| 452 | s = std::move(builder); | ||
| 453 | } | ||
| 454 | |||
| 455 | bool string_ends_with(const std::string_view & str, const std::string_view & suffix) { | ||
| 456 | return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; | ||
| 457 | } | ||
| 458 | |||
| 459 | bool string_remove_suffix(std::string & str, const std::string_view & suffix) { | ||
| 460 | bool has_suffix = string_ends_with(str, suffix); | ||
| 461 | if (has_suffix) { | ||
| 462 | str = str.substr(0, str.size() - suffix.size()); | ||
| 463 | } | ||
| 464 | return has_suffix; | ||
| 465 | } | ||
| 466 | |||
| 467 | size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) { | ||
| 468 | if (!str.empty() && !stop.empty()) { | ||
| 469 | const char text_last_char = str.back(); | ||
| 470 | for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { | ||
| 471 | if (stop[char_index] == text_last_char) { | ||
| 472 | const auto current_partial = stop.substr(0, char_index + 1); | ||
| 473 | if (string_ends_with(str, current_partial)) { | ||
| 474 | return str.size() - char_index - 1; | ||
| 475 | } | ||
| 476 | } | ||
| 477 | } | ||
| 478 | } | ||
| 479 | |||
| 480 | return std::string::npos; | ||
| 481 | } | ||
| 482 | |||
| 483 | std::string regex_escape(const std::string & s) { | ||
| 484 | static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); | ||
| 485 | return std::regex_replace(s, special_chars, "\\$&"); | ||
| 486 | } | ||
| 487 | |||
| 488 | std::string string_join(const std::vector<std::string> & values, const std::string & separator) { | ||
| 489 | std::ostringstream result; | ||
| 490 | for (size_t i = 0; i < values.size(); ++i) { | ||
| 491 | if (i > 0) { | ||
| 492 | result << separator; | ||
| 493 | } | ||
| 494 | result << values[i]; | ||
| 495 | } | ||
| 496 | return result.str(); | ||
| 497 | } | ||
| 498 | |||
| 499 | std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) { | ||
| 500 | std::vector<std::string> parts; | ||
| 501 | size_t start = 0; | ||
| 502 | size_t end = str.find(delimiter); | ||
| 503 | |||
| 504 | while (end != std::string::npos) { | ||
| 505 | parts.push_back(str.substr(start, end - start)); | ||
| 506 | start = end + delimiter.length(); | ||
| 507 | end = str.find(delimiter, start); | ||
| 508 | } | ||
| 509 | |||
| 510 | parts.push_back(str.substr(start)); | ||
| 511 | |||
| 512 | return parts; | ||
| 513 | } | ||
| 514 | |||
| 515 | std::string string_repeat(const std::string & str, size_t n) { | ||
| 516 | if (n == 0) { | ||
| 517 | return ""; | ||
| 518 | } | ||
| 519 | |||
| 520 | std::string result; | ||
| 521 | result.reserve(str.length() * n); | ||
| 522 | |||
| 523 | for (size_t i = 0; i < n; ++i) { | ||
| 524 | result += str; | ||
| 525 | } | ||
| 526 | |||
| 527 | return result; | ||
| 528 | } | ||
| 529 | |||
| 530 | std::string string_from(bool value) { | ||
| 531 | return value ? "true" : "false"; | ||
| 532 | } | ||
| 533 | |||
| 534 | std::string string_from(const std::vector<int> & values) { | ||
| 535 | std::stringstream buf; | ||
| 536 | |||
| 537 | buf << "[ "; | ||
| 538 | bool first = true; | ||
| 539 | for (auto e : values) { | ||
| 540 | if (first) { | ||
| 541 | first = false; | ||
| 542 | } else { | ||
| 543 | buf << ", "; | ||
| 544 | } | ||
| 545 | buf << std::to_string(e); | ||
| 546 | } | ||
| 547 | buf << " ]"; | ||
| 548 | |||
| 549 | return buf.str(); | ||
| 550 | } | ||
| 551 | |||
| 552 | std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens) { | ||
| 553 | std::stringstream buf; | ||
| 554 | |||
| 555 | buf << "[ "; | ||
| 556 | |||
| 557 | bool first = true; | ||
| 558 | for (const auto & token : tokens) { | ||
| 559 | if (!first) { | ||
| 560 | buf << ", "; | ||
| 561 | } else { | ||
| 562 | first = false; | ||
| 563 | } | ||
| 564 | |||
| 565 | auto detokenized = common_token_to_piece(ctx, token); | ||
| 566 | |||
| 567 | buf << "'" << detokenized << "'" | ||
| 568 | << ":" << std::to_string(token); | ||
| 569 | } | ||
| 570 | |||
| 571 | buf << " ]"; | ||
| 572 | |||
| 573 | return buf.str(); | ||
| 574 | } | ||
| 575 | |||
| 576 | std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) { | ||
| 577 | std::stringstream buf; | ||
| 578 | |||
| 579 | buf << "[ "; | ||
| 580 | |||
| 581 | bool first = true; | ||
| 582 | for (int i = 0; i < batch.n_tokens; ++i) { | ||
| 583 | if (!first) { | ||
| 584 | buf << ", "; | ||
| 585 | } else { | ||
| 586 | first = false; | ||
| 587 | } | ||
| 588 | |||
| 589 | auto detokenized = common_token_to_piece(ctx, batch.token[i]); | ||
| 590 | |||
| 591 | buf << "\n" << std::to_string(i) | ||
| 592 | << ", token '" << detokenized << "'" | ||
| 593 | << ", pos " << std::to_string(batch.pos[i]) | ||
| 594 | << ", n_seq_id " << std::to_string(batch.n_seq_id[i]) | ||
| 595 | << ", seq_id " << std::to_string(batch.seq_id[i][0]) | ||
| 596 | << ", logits " << std::to_string(batch.logits[i]); | ||
| 597 | } | ||
| 598 | |||
| 599 | buf << " ]"; | ||
| 600 | |||
| 601 | return buf.str(); | ||
| 602 | } | ||
| 603 | |||
| 604 | void string_process_escapes(std::string & input) { | ||
| 605 | std::size_t input_len = input.length(); | ||
| 606 | std::size_t output_idx = 0; | ||
| 607 | |||
| 608 | for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { | ||
| 609 | if (input[input_idx] == '\\' && input_idx + 1 < input_len) { | ||
| 610 | switch (input[++input_idx]) { | ||
| 611 | case 'n': input[output_idx++] = '\n'; break; | ||
| 612 | case 'r': input[output_idx++] = '\r'; break; | ||
| 613 | case 't': input[output_idx++] = '\t'; break; | ||
| 614 | case '\'': input[output_idx++] = '\''; break; | ||
| 615 | case '\"': input[output_idx++] = '\"'; break; | ||
| 616 | case '\\': input[output_idx++] = '\\'; break; | ||
| 617 | case 'x': | ||
| 618 | // Handle \x12, etc | ||
| 619 | if (input_idx + 2 < input_len) { | ||
| 620 | const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 }; | ||
| 621 | char *err_p = nullptr; | ||
| 622 | const long val = std::strtol(x, &err_p, 16); | ||
| 623 | if (err_p == x + 2) { | ||
| 624 | input_idx += 2; | ||
| 625 | input[output_idx++] = char(val); | ||
| 626 | break; | ||
| 627 | } | ||
| 628 | } | ||
| 629 | // fall through | ||
| 630 | default: input[output_idx++] = '\\'; | ||
| 631 | input[output_idx++] = input[input_idx]; break; | ||
| 632 | } | ||
| 633 | } else { | ||
| 634 | input[output_idx++] = input[input_idx]; | ||
| 635 | } | ||
| 636 | } | ||
| 637 | |||
| 638 | input.resize(output_idx); | ||
| 639 | } | ||
| 640 | |||
| 641 | bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) { | ||
| 642 | const char * sep = strchr(data, '='); | ||
| 643 | if (sep == nullptr || sep - data >= 128) { | ||
| 644 | LOG_ERR("%s: malformed KV override '%s'\n", __func__, data); | ||
| 645 | return false; | ||
| 646 | } | ||
| 647 | llama_model_kv_override kvo; | ||
| 648 | std::strncpy(kvo.key, data, sep - data); | ||
| 649 | kvo.key[sep - data] = 0; | ||
| 650 | sep++; | ||
| 651 | if (strncmp(sep, "int:", 4) == 0) { | ||
| 652 | sep += 4; | ||
| 653 | kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; | ||
| 654 | kvo.val_i64 = std::atol(sep); | ||
| 655 | } else if (strncmp(sep, "float:", 6) == 0) { | ||
| 656 | sep += 6; | ||
| 657 | kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; | ||
| 658 | kvo.val_f64 = std::atof(sep); | ||
| 659 | } else if (strncmp(sep, "bool:", 5) == 0) { | ||
| 660 | sep += 5; | ||
| 661 | kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; | ||
| 662 | if (std::strcmp(sep, "true") == 0) { | ||
| 663 | kvo.val_bool = true; | ||
| 664 | } else if (std::strcmp(sep, "false") == 0) { | ||
| 665 | kvo.val_bool = false; | ||
| 666 | } else { | ||
| 667 | LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data); | ||
| 668 | return false; | ||
| 669 | } | ||
| 670 | } else if (strncmp(sep, "str:", 4) == 0) { | ||
| 671 | sep += 4; | ||
| 672 | kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; | ||
| 673 | if (strlen(sep) > 127) { | ||
| 674 | LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); | ||
| 675 | return false; | ||
| 676 | } | ||
| 677 | strncpy(kvo.val_str, sep, 127); | ||
| 678 | kvo.val_str[127] = '\0'; | ||
| 679 | } else { | ||
| 680 | LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data); | ||
| 681 | return false; | ||
| 682 | } | ||
| 683 | overrides.emplace_back(std::move(kvo)); | ||
| 684 | return true; | ||
| 685 | } | ||
| 686 | |||
| 687 | // | ||
| 688 | // Filesystem utils | ||
| 689 | // | ||
| 690 | |||
| 691 | // Validate if a filename is safe to use | ||
| 692 | // To validate a full path, split the path by the OS-specific path separator, and validate each part with this function | ||
| 693 | bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { | ||
| 694 | if (!filename.length()) { | ||
| 695 | // Empty filename invalid | ||
| 696 | return false; | ||
| 697 | } | ||
| 698 | if (filename.length() > 255) { | ||
| 699 | // Limit at common largest possible filename on Linux filesystems | ||
| 700 | // to avoid unnecessary further validation | ||
| 701 | // (On systems with smaller limits it will be caught by the OS) | ||
| 702 | return false; | ||
| 703 | } | ||
| 704 | |||
| 705 | size_t offset = 0; | ||
| 706 | while (offset < filename.size()) { | ||
| 707 | utf8_parse_result result = parse_utf8_codepoint(filename, offset); | ||
| 708 | |||
| 709 | if (result.status != utf8_parse_result::SUCCESS) { | ||
| 710 | return false; | ||
| 711 | } | ||
| 712 | uint32_t c = result.codepoint; | ||
| 713 | |||
| 714 | if ((result.bytes_consumed == 2 && c < 0x80) || | ||
| 715 | (result.bytes_consumed == 3 && c < 0x800) || | ||
| 716 | (result.bytes_consumed == 4 && c < 0x10000)) { | ||
| 717 | return false; | ||
| 718 | } | ||
| 719 | |||
| 720 | // Check for forbidden codepoints: | ||
| 721 | // - Control characters | ||
| 722 | // - Unicode equivalents of illegal characters | ||
| 723 | // - UTF-16 surrogate pairs | ||
| 724 | // - UTF-8 replacement character | ||
| 725 | // - Byte order mark (BOM) | ||
| 726 | // - Illegal characters: / \ : * ? " < > | | ||
| 727 | if (c <= 0x1F // Control characters (C0) | ||
| 728 | || c == 0x7F // Control characters (DEL) | ||
| 729 | || (c >= 0x80 && c <= 0x9F) // Control characters (C1) | ||
| 730 | || c == 0xFF0E // Fullwidth Full Stop (period equivalent) | ||
| 731 | || c == 0x2215 // Division Slash (forward slash equivalent) | ||
| 732 | || c == 0x2216 // Set Minus (backslash equivalent) | ||
| 733 | || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs | ||
| 734 | || c > 0x10FFFF // Max Unicode limit | ||
| 735 | || c == 0xFFFD // Replacement Character (UTF-8) | ||
| 736 | || c == 0xFEFF // Byte Order Mark (BOM) | ||
| 737 | || c == ':' || c == '*' // Illegal characters | ||
| 738 | || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') { | ||
| 739 | return false; | ||
| 740 | } | ||
| 741 | if (!allow_subdirs && (c == '/' || c == '\\')) { | ||
| 742 | // Subdirectories not allowed, reject path separators | ||
| 743 | return false; | ||
| 744 | } | ||
| 745 | offset += result.bytes_consumed; | ||
| 746 | } | ||
| 747 | |||
| 748 | // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename | ||
| 749 | // Unicode and other whitespace is not affected, only 0x20 space | ||
| 750 | if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') { | ||
| 751 | return false; | ||
| 752 | } | ||
| 753 | |||
| 754 | // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead) | ||
| 755 | if (filename.find("..") != std::string::npos) { | ||
| 756 | return false; | ||
| 757 | } | ||
| 758 | |||
| 759 | // Reject "." | ||
| 760 | if (filename == ".") { | ||
| 761 | return false; | ||
| 762 | } | ||
| 763 | |||
| 764 | return true; | ||
| 765 | } | ||
| 766 | |||
| 767 | #include <iostream> | ||
| 768 | |||
| 769 | |||
| 770 | #ifdef _WIN32 | ||
| 771 | static std::wstring utf8_to_wstring(const std::string & str) { | ||
| 772 | if (str.empty()) { | ||
| 773 | return std::wstring(); | ||
| 774 | } | ||
| 775 | |||
| 776 | int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0); | ||
| 777 | |||
| 778 | if (size <= 0) { | ||
| 779 | return std::wstring(); | ||
| 780 | } | ||
| 781 | |||
| 782 | std::wstring wstr(size, 0); | ||
| 783 | MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size); | ||
| 784 | |||
| 785 | return wstr; | ||
| 786 | } | ||
| 787 | #endif | ||
| 788 | |||
| 789 | // returns true if successful, false otherwise | ||
| 790 | bool fs_create_directory_with_parents(const std::string & path) { | ||
| 791 | #ifdef _WIN32 | ||
| 792 | std::wstring wpath = utf8_to_wstring(path); | ||
| 793 | |||
| 794 | // if the path already exists, check whether it's a directory | ||
| 795 | const DWORD attributes = GetFileAttributesW(wpath.c_str()); | ||
| 796 | if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { | ||
| 797 | return true; | ||
| 798 | } | ||
| 799 | |||
| 800 | size_t pos_slash = 0; | ||
| 801 | |||
| 802 | // process path from front to back, procedurally creating directories | ||
| 803 | while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { | ||
| 804 | const std::wstring subpath = wpath.substr(0, pos_slash); | ||
| 805 | |||
| 806 | pos_slash += 1; | ||
| 807 | |||
| 808 | // skip the drive letter, in some systems it can return an access denied error | ||
| 809 | if (subpath.length() == 2 && subpath[1] == ':') { | ||
| 810 | continue; | ||
| 811 | } | ||
| 812 | |||
| 813 | const bool success = CreateDirectoryW(subpath.c_str(), NULL); | ||
| 814 | |||
| 815 | if (!success) { | ||
| 816 | const DWORD error = GetLastError(); | ||
| 817 | |||
| 818 | // if the path already exists, ensure that it's a directory | ||
| 819 | if (error == ERROR_ALREADY_EXISTS) { | ||
| 820 | const DWORD attributes = GetFileAttributesW(subpath.c_str()); | ||
| 821 | if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { | ||
| 822 | return false; | ||
| 823 | } | ||
| 824 | } else { | ||
| 825 | return false; | ||
| 826 | } | ||
| 827 | } | ||
| 828 | } | ||
| 829 | |||
| 830 | return true; | ||
| 831 | #else | ||
| 832 | // if the path already exists, check whether it's a directory | ||
| 833 | struct stat info; | ||
| 834 | if (stat(path.c_str(), &info) == 0) { | ||
| 835 | return S_ISDIR(info.st_mode); | ||
| 836 | } | ||
| 837 | |||
| 838 | size_t pos_slash = 1; // skip leading slashes for directory creation | ||
| 839 | |||
| 840 | // process path from front to back, procedurally creating directories | ||
| 841 | while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { | ||
| 842 | const std::string subpath = path.substr(0, pos_slash); | ||
| 843 | struct stat info; | ||
| 844 | |||
| 845 | // if the path already exists, ensure that it's a directory | ||
| 846 | if (stat(subpath.c_str(), &info) == 0) { | ||
| 847 | if (!S_ISDIR(info.st_mode)) { | ||
| 848 | return false; | ||
| 849 | } | ||
| 850 | } else { | ||
| 851 | // create parent directories | ||
| 852 | const int ret = mkdir(subpath.c_str(), 0755); | ||
| 853 | if (ret != 0) { | ||
| 854 | return false; | ||
| 855 | } | ||
| 856 | } | ||
| 857 | |||
| 858 | pos_slash += 1; | ||
| 859 | } | ||
| 860 | |||
| 861 | return true; | ||
| 862 | #endif // _WIN32 | ||
| 863 | } | ||
| 864 | |||
| 865 | bool fs_is_directory(const std::string & path) { | ||
| 866 | std::filesystem::path dir(path); | ||
| 867 | return std::filesystem::exists(dir) && std::filesystem::is_directory(dir); | ||
| 868 | } | ||
| 869 | |||
| 870 | std::string fs_get_cache_directory() { | ||
| 871 | std::string cache_directory = ""; | ||
| 872 | auto ensure_trailing_slash = [](std::string p) { | ||
| 873 | // Make sure to add trailing slash | ||
| 874 | if (p.back() != DIRECTORY_SEPARATOR) { | ||
| 875 | p += DIRECTORY_SEPARATOR; | ||
| 876 | } | ||
| 877 | return p; | ||
| 878 | }; | ||
| 879 | if (getenv("LLAMA_CACHE")) { | ||
| 880 | cache_directory = std::getenv("LLAMA_CACHE"); | ||
| 881 | } else { | ||
| 882 | #if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) | ||
| 883 | if (std::getenv("XDG_CACHE_HOME")) { | ||
| 884 | cache_directory = std::getenv("XDG_CACHE_HOME"); | ||
| 885 | } else if (std::getenv("HOME")) { | ||
| 886 | cache_directory = std::getenv("HOME") + std::string("/.cache/"); | ||
| 887 | } else { | ||
| 888 | #if defined(__linux__) | ||
| 889 | /* no $HOME is defined, fallback to getpwuid */ | ||
| 890 | struct passwd *pw = getpwuid(getuid()); | ||
| 891 | if ((!pw) || (!pw->pw_dir)) { | ||
| 892 | throw std::runtime_error("Failed to find $HOME directory"); | ||
| 893 | } | ||
| 894 | |||
| 895 | cache_directory = std::string(pw->pw_dir) + std::string("/.cache/"); | ||
| 896 | #else /* defined(__linux__) */ | ||
| 897 | throw std::runtime_error("Failed to find $HOME directory"); | ||
| 898 | #endif /* defined(__linux__) */ | ||
| 899 | } | ||
| 900 | #elif defined(__APPLE__) | ||
| 901 | cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); | ||
| 902 | #elif defined(_WIN32) | ||
| 903 | cache_directory = std::getenv("LOCALAPPDATA"); | ||
| 904 | #elif defined(__EMSCRIPTEN__) | ||
| 905 | GGML_ABORT("not implemented on this platform"); | ||
| 906 | #else | ||
| 907 | # error Unknown architecture | ||
| 908 | #endif | ||
| 909 | cache_directory = ensure_trailing_slash(cache_directory); | ||
| 910 | cache_directory += "llama.cpp"; | ||
| 911 | } | ||
| 912 | return ensure_trailing_slash(cache_directory); | ||
| 913 | } | ||
| 914 | |||
| 915 | std::string fs_get_cache_file(const std::string & filename) { | ||
| 916 | GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos); | ||
| 917 | std::string cache_directory = fs_get_cache_directory(); | ||
| 918 | const bool success = fs_create_directory_with_parents(cache_directory); | ||
| 919 | if (!success) { | ||
| 920 | throw std::runtime_error("failed to create cache directory: " + cache_directory); | ||
| 921 | } | ||
| 922 | return cache_directory + filename; | ||
| 923 | } | ||
| 924 | |||
| 925 | std::vector<common_file_info> fs_list(const std::string & path, bool include_directories) { | ||
| 926 | std::vector<common_file_info> files; | ||
| 927 | if (path.empty()) return files; | ||
| 928 | |||
| 929 | std::filesystem::path dir(path); | ||
| 930 | if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) { | ||
| 931 | return files; | ||
| 932 | } | ||
| 933 | |||
| 934 | for (const auto & entry : std::filesystem::directory_iterator(dir)) { | ||
| 935 | try { | ||
| 936 | // Only include regular files (skip directories) | ||
| 937 | const auto & p = entry.path(); | ||
| 938 | if (std::filesystem::is_regular_file(p)) { | ||
| 939 | common_file_info info; | ||
| 940 | info.path = p.string(); | ||
| 941 | info.name = p.filename().string(); | ||
| 942 | info.is_dir = false; | ||
| 943 | try { | ||
| 944 | info.size = static_cast<size_t>(std::filesystem::file_size(p)); | ||
| 945 | } catch (const std::filesystem::filesystem_error &) { | ||
| 946 | info.size = 0; | ||
| 947 | } | ||
| 948 | files.push_back(std::move(info)); | ||
| 949 | } else if (include_directories && std::filesystem::is_directory(p)) { | ||
| 950 | common_file_info info; | ||
| 951 | info.path = p.string(); | ||
| 952 | info.name = p.filename().string(); | ||
| 953 | info.size = 0; // Directories have no size | ||
| 954 | info.is_dir = true; | ||
| 955 | files.push_back(std::move(info)); | ||
| 956 | } | ||
| 957 | } catch (const std::filesystem::filesystem_error &) { | ||
| 958 | // skip entries we cannot inspect | ||
| 959 | continue; | ||
| 960 | } | ||
| 961 | } | ||
| 962 | |||
| 963 | return files; | ||
| 964 | } | ||
| 965 | |||
| 966 | // | ||
| 967 | // TTY utils | ||
| 968 | // | ||
| 969 | |||
| 970 | bool tty_can_use_colors() { | ||
| 971 | // Check NO_COLOR environment variable (https://no-color.org/) | ||
| 972 | if (const char * no_color = std::getenv("NO_COLOR")) { | ||
| 973 | if (no_color[0] != '\0') { | ||
| 974 | return false; | ||
| 975 | } | ||
| 976 | } | ||
| 977 | |||
| 978 | // Check TERM environment variable | ||
| 979 | if (const char * term = std::getenv("TERM")) { | ||
| 980 | if (std::strcmp(term, "dumb") == 0) { | ||
| 981 | return false; | ||
| 982 | } | ||
| 983 | } | ||
| 984 | |||
| 985 | // Check if stdout and stderr are connected to a terminal | ||
| 986 | // We check both because log messages can go to either | ||
| 987 | bool stdout_is_tty = isatty(fileno(stdout)); | ||
| 988 | bool stderr_is_tty = isatty(fileno(stderr)); | ||
| 989 | |||
| 990 | return stdout_is_tty || stderr_is_tty; | ||
| 991 | } | ||
| 992 | |||
| 993 | // | ||
| 994 | // Model utils | ||
| 995 | // | ||
| 996 | |||
| 997 | // TODO: move to common/sampling | ||
| 998 | static void common_init_sampler_from_model( | ||
| 999 | const llama_model * model, | ||
| 1000 | common_params_sampling & sparams) { | ||
| 1001 | |||
| 1002 | const uint64_t config = sparams.user_sampling_config; | ||
| 1003 | |||
| 1004 | auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { | ||
| 1005 | if (config & user_config) { | ||
| 1006 | return; | ||
| 1007 | } | ||
| 1008 | |||
| 1009 | char buf[64] = {0}; | ||
| 1010 | if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { | ||
| 1011 | char * end = nullptr; | ||
| 1012 | int32_t v = strtol(buf, &end, 10); | ||
| 1013 | if (end && end != buf) { | ||
| 1014 | dst = v; | ||
| 1015 | } | ||
| 1016 | } | ||
| 1017 | }; | ||
| 1018 | |||
| 1019 | auto get_float = [&](const char * key, float & dst, uint64_t user_config) { | ||
| 1020 | if (config & user_config) { | ||
| 1021 | return; | ||
| 1022 | } | ||
| 1023 | |||
| 1024 | char buf[128] = {0}; | ||
| 1025 | if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { | ||
| 1026 | char * end = nullptr; | ||
| 1027 | float v = strtof(buf, &end); | ||
| 1028 | if (end && end != buf) { | ||
| 1029 | dst = v; | ||
| 1030 | } | ||
| 1031 | } | ||
| 1032 | }; | ||
| 1033 | |||
| 1034 | // Sampling sequence | ||
| 1035 | if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) { | ||
| 1036 | char buf[512] = {0}; | ||
| 1037 | if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) { | ||
| 1038 | const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';'); | ||
| 1039 | if (!sampler_names.empty()) { | ||
| 1040 | sparams.samplers = common_sampler_types_from_names(sampler_names, true); | ||
| 1041 | } | ||
| 1042 | } | ||
| 1043 | } | ||
| 1044 | |||
| 1045 | get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K); | ||
| 1046 | get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P); | ||
| 1047 | get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P); | ||
| 1048 | get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY); | ||
| 1049 | get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD); | ||
| 1050 | get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP); | ||
| 1051 | get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N); | ||
| 1052 | get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT); | ||
| 1053 | get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT); | ||
| 1054 | get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU); | ||
| 1055 | get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); | ||
| 1056 | } | ||
| 1057 | |||
| 1058 | struct common_init_result::impl { | ||
| 1059 | impl() = default; | ||
| 1060 | ~impl() = default; | ||
| 1061 | |||
| 1062 | // note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top | ||
| 1063 | |||
| 1064 | llama_model_ptr model; | ||
| 1065 | llama_context_ptr context; | ||
| 1066 | |||
| 1067 | std::vector<llama_adapter_lora_ptr> lora; | ||
| 1068 | |||
| 1069 | std::vector<common_sampler_ptr> samplers; | ||
| 1070 | std::vector<llama_sampler_seq_config> samplers_seq_config; | ||
| 1071 | }; | ||
| 1072 | |||
| 1073 | common_init_result::common_init_result(common_params & params) : | ||
| 1074 | pimpl(new impl{}) { | ||
| 1075 | auto mparams = common_model_params_to_llama(params); | ||
| 1076 | auto cparams = common_context_params_to_llama(params); | ||
| 1077 | |||
| 1078 | if (params.fit_params) { | ||
| 1079 | LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__); | ||
| 1080 | llama_params_fit(params.model.path.c_str(), &mparams, &cparams, | ||
| 1081 | params.tensor_split, | ||
| 1082 | params.tensor_buft_overrides.data(), | ||
| 1083 | params.fit_params_target.data(), | ||
| 1084 | params.fit_params_min_ctx, | ||
| 1085 | params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); | ||
| 1086 | } | ||
| 1087 | |||
| 1088 | llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); | ||
| 1089 | if (model == NULL) { | ||
| 1090 | return; | ||
| 1091 | } | ||
| 1092 | |||
| 1093 | pimpl->model.reset(model); | ||
| 1094 | |||
| 1095 | const llama_vocab * vocab = llama_model_get_vocab(model); | ||
| 1096 | |||
| 1097 | // load and optionally apply lora adapters (must be loaded before context creation) | ||
| 1098 | for (auto & la : params.lora_adapters) { | ||
| 1099 | llama_adapter_lora_ptr lora; | ||
| 1100 | lora.reset(llama_adapter_lora_init(model, la.path.c_str())); | ||
| 1101 | if (lora == nullptr) { | ||
| 1102 | LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str()); | ||
| 1103 | pimpl->model.reset(model); | ||
| 1104 | return; | ||
| 1105 | } | ||
| 1106 | |||
| 1107 | char buf[1024]; | ||
| 1108 | la.ptr = lora.get(); | ||
| 1109 | llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); | ||
| 1110 | la.task_name = buf; | ||
| 1111 | llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); | ||
| 1112 | la.prompt_prefix = buf; | ||
| 1113 | pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters | ||
| 1114 | } | ||
| 1115 | |||
| 1116 | // updates params.sampling | ||
| 1117 | // TODO: fix naming | ||
| 1118 | common_init_sampler_from_model(model, params.sampling); | ||
| 1119 | |||
| 1120 | if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { | ||
| 1121 | LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); | ||
| 1122 | params.sampling.ignore_eos = false; | ||
| 1123 | } | ||
| 1124 | |||
| 1125 | // initialize once | ||
| 1126 | for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { | ||
| 1127 | if (llama_vocab_is_eog(vocab, i)) { | ||
| 1128 | LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY); | ||
| 1129 | params.sampling.logit_bias_eog.push_back({i, -INFINITY}); | ||
| 1130 | } | ||
| 1131 | } | ||
| 1132 | |||
| 1133 | if (params.sampling.ignore_eos) { | ||
| 1134 | // add EOG biases to the active set of logit biases | ||
| 1135 | params.sampling.logit_bias.insert( | ||
| 1136 | params.sampling.logit_bias.end(), | ||
| 1137 | params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); | ||
| 1138 | } | ||
| 1139 | |||
| 1140 | //if (params.sampling.penalty_last_n == -1) { | ||
| 1141 | // LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); | ||
| 1142 | // params.sampling.penalty_last_n = llama_n_ctx(lctx); | ||
| 1143 | //} | ||
| 1144 | |||
| 1145 | //if (params.sampling.dry_penalty_last_n == -1) { | ||
| 1146 | // LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); | ||
| 1147 | // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); | ||
| 1148 | //} | ||
| 1149 | |||
| 1150 | // init the backend samplers as part of the context creation | ||
| 1151 | pimpl->samplers.resize(cparams.n_seq_max); | ||
| 1152 | pimpl->samplers_seq_config.resize(cparams.n_seq_max); | ||
| 1153 | |||
| 1154 | for (int i = 0; i < (int) cparams.n_seq_max; ++i) { | ||
| 1155 | pimpl->samplers[i].reset(common_sampler_init(model, params.sampling)); | ||
| 1156 | pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) }; | ||
| 1157 | } | ||
| 1158 | |||
| 1159 | if (params.sampling.backend_sampling) { | ||
| 1160 | cparams.samplers = pimpl->samplers_seq_config.data(); | ||
| 1161 | cparams.n_samplers = pimpl->samplers_seq_config.size(); | ||
| 1162 | } | ||
| 1163 | |||
| 1164 | llama_context * lctx = llama_init_from_model(model, cparams); | ||
| 1165 | if (lctx == NULL) { | ||
| 1166 | LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); | ||
| 1167 | return; | ||
| 1168 | } | ||
| 1169 | |||
| 1170 | pimpl->context.reset(lctx); | ||
| 1171 | } | ||
| 1172 | |||
| 1173 | llama_model * common_init_result::model() { | ||
| 1174 | return pimpl->model.get(); | ||
| 1175 | } | ||
| 1176 | |||
| 1177 | llama_context * common_init_result::context() { | ||
| 1178 | return pimpl->context.get(); | ||
| 1179 | } | ||
| 1180 | |||
| 1181 | common_sampler * common_init_result::sampler(llama_seq_id seq_id) { | ||
| 1182 | return pimpl->samplers[seq_id].get(); | ||
| 1183 | } | ||
| 1184 | |||
| 1185 | void common_init_result::reset_samplers() { | ||
| 1186 | for (int i = 0; i < (int) pimpl->samplers.size(); ++i) { | ||
| 1187 | llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get())); | ||
| 1188 | } | ||
| 1189 | } | ||
| 1190 | |||
| 1191 | std::vector<llama_adapter_lora_ptr> & common_init_result::lora() { | ||
| 1192 | return pimpl->lora; | ||
| 1193 | } | ||
| 1194 | |||
| 1195 | common_init_result_ptr common_init_from_params(common_params & params) { | ||
| 1196 | common_init_result_ptr res(new common_init_result(params)); | ||
| 1197 | |||
| 1198 | llama_model * model = res->model(); | ||
| 1199 | if (model == NULL) { | ||
| 1200 | LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); | ||
| 1201 | return res; | ||
| 1202 | } | ||
| 1203 | |||
| 1204 | llama_context * lctx = res->context(); | ||
| 1205 | if (lctx == NULL) { | ||
| 1206 | LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); | ||
| 1207 | return res; | ||
| 1208 | } | ||
| 1209 | |||
| 1210 | const llama_vocab * vocab = llama_model_get_vocab(model); | ||
| 1211 | |||
| 1212 | if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { | ||
| 1213 | LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); | ||
| 1214 | params.ctx_shift = false; | ||
| 1215 | } | ||
| 1216 | |||
| 1217 | if (!params.control_vectors.empty()) { | ||
| 1218 | if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; | ||
| 1219 | if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model); | ||
| 1220 | |||
| 1221 | const auto cvec = common_control_vector_load(params.control_vectors); | ||
| 1222 | if (cvec.n_embd == -1) { | ||
| 1223 | return res; | ||
| 1224 | } | ||
| 1225 | |||
| 1226 | int err = llama_apply_adapter_cvec( | ||
| 1227 | lctx, | ||
| 1228 | cvec.data.data(), | ||
| 1229 | cvec.data.size(), | ||
| 1230 | cvec.n_embd, | ||
| 1231 | params.control_vector_layer_start, | ||
| 1232 | params.control_vector_layer_end); | ||
| 1233 | if (err) { | ||
| 1234 | return res; | ||
| 1235 | } | ||
| 1236 | } | ||
| 1237 | |||
| 1238 | if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) { | ||
| 1239 | bool ok = true; | ||
| 1240 | |||
| 1241 | if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { | ||
| 1242 | LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); | ||
| 1243 | ok = false; | ||
| 1244 | } | ||
| 1245 | |||
| 1246 | bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; | ||
| 1247 | bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL; | ||
| 1248 | bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL; | ||
| 1249 | |||
| 1250 | if (!has_eos && !has_sep && !has_rerank_prompt) { | ||
| 1251 | LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__); | ||
| 1252 | ok = false; | ||
| 1253 | } else if (!has_eos) { | ||
| 1254 | LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__); | ||
| 1255 | } | ||
| 1256 | |||
| 1257 | if (!ok) { | ||
| 1258 | return res; | ||
| 1259 | } | ||
| 1260 | } | ||
| 1261 | |||
| 1262 | if (!params.lora_init_without_apply) { | ||
| 1263 | common_set_adapter_lora(lctx, params.lora_adapters); | ||
| 1264 | } | ||
| 1265 | |||
| 1266 | if (params.warmup) { | ||
| 1267 | LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); | ||
| 1268 | |||
| 1269 | llama_set_warmup(lctx, true); | ||
| 1270 | |||
| 1271 | std::vector<llama_token> tmp; | ||
| 1272 | llama_token bos = llama_vocab_bos(vocab); | ||
| 1273 | llama_token eos = llama_vocab_eos(vocab); | ||
| 1274 | |||
| 1275 | // some models (e.g. T5) don't have a BOS token | ||
| 1276 | if (bos != LLAMA_TOKEN_NULL) { | ||
| 1277 | tmp.push_back(bos); | ||
| 1278 | } | ||
| 1279 | if (eos != LLAMA_TOKEN_NULL) { | ||
| 1280 | tmp.push_back(eos); | ||
| 1281 | } | ||
| 1282 | if (tmp.empty()) { | ||
| 1283 | tmp.push_back(0); | ||
| 1284 | } | ||
| 1285 | |||
| 1286 | if (llama_model_has_encoder(model)) { | ||
| 1287 | llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); | ||
| 1288 | llama_token decoder_start_token_id = llama_model_decoder_start_token(model); | ||
| 1289 | if (decoder_start_token_id == LLAMA_TOKEN_NULL) { | ||
| 1290 | decoder_start_token_id = bos; | ||
| 1291 | } | ||
| 1292 | tmp.clear(); | ||
| 1293 | tmp.push_back(decoder_start_token_id); | ||
| 1294 | } | ||
| 1295 | if (llama_model_has_decoder(model)) { | ||
| 1296 | llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); | ||
| 1297 | } | ||
| 1298 | llama_memory_clear(llama_get_memory(lctx), true); | ||
| 1299 | llama_synchronize(lctx); | ||
| 1300 | llama_perf_context_reset(lctx); | ||
| 1301 | llama_set_warmup(lctx, false); | ||
| 1302 | |||
| 1303 | // reset samplers to reset RNG state after warmup to the seeded state | ||
| 1304 | res->reset_samplers(); | ||
| 1305 | } | ||
| 1306 | |||
| 1307 | return res; | ||
| 1308 | } | ||
| 1309 | |||
| 1310 | common_init_result::~common_init_result() = default; | ||
| 1311 | |||
| 1312 | std::string get_model_endpoint() { | ||
| 1313 | const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); | ||
| 1314 | // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. | ||
| 1315 | const char * hf_endpoint_env = getenv("HF_ENDPOINT"); | ||
| 1316 | const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env; | ||
| 1317 | std::string model_endpoint = "https://huggingface.co/"; | ||
| 1318 | if (endpoint_env) { | ||
| 1319 | model_endpoint = endpoint_env; | ||
| 1320 | if (model_endpoint.back() != '/') { | ||
| 1321 | model_endpoint += '/'; | ||
| 1322 | } | ||
| 1323 | } | ||
| 1324 | return model_endpoint; | ||
| 1325 | } | ||
| 1326 | |||
| 1327 | void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) { | ||
| 1328 | llama_clear_adapter_lora(ctx); | ||
| 1329 | for (auto & la : lora) { | ||
| 1330 | if (la.scale != 0.0f) { | ||
| 1331 | llama_set_adapter_lora(ctx, la.ptr, la.scale); | ||
| 1332 | } | ||
| 1333 | } | ||
| 1334 | } | ||
| 1335 | |||
| 1336 | struct llama_model_params common_model_params_to_llama(common_params & params) { | ||
| 1337 | auto mparams = llama_model_default_params(); | ||
| 1338 | |||
| 1339 | if (!params.devices.empty()) { | ||
| 1340 | mparams.devices = params.devices.data(); | ||
| 1341 | } | ||
| 1342 | |||
| 1343 | mparams.n_gpu_layers = params.n_gpu_layers; | ||
| 1344 | mparams.main_gpu = params.main_gpu; | ||
| 1345 | mparams.split_mode = params.split_mode; | ||
| 1346 | mparams.tensor_split = params.tensor_split; | ||
| 1347 | mparams.use_mmap = params.use_mmap; | ||
| 1348 | mparams.use_direct_io = params.use_direct_io; | ||
| 1349 | mparams.use_mlock = params.use_mlock; | ||
| 1350 | mparams.check_tensors = params.check_tensors; | ||
| 1351 | mparams.use_extra_bufts = !params.no_extra_bufts; | ||
| 1352 | mparams.no_host = params.no_host; | ||
| 1353 | |||
| 1354 | if (params.kv_overrides.empty()) { | ||
| 1355 | mparams.kv_overrides = NULL; | ||
| 1356 | } else { | ||
| 1357 | GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key"); | ||
| 1358 | mparams.kv_overrides = params.kv_overrides.data(); | ||
| 1359 | } | ||
| 1360 | |||
| 1361 | if (params.tensor_buft_overrides.empty()) { | ||
| 1362 | mparams.tensor_buft_overrides = NULL; | ||
| 1363 | } else { | ||
| 1364 | GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern"); | ||
| 1365 | mparams.tensor_buft_overrides = params.tensor_buft_overrides.data(); | ||
| 1366 | } | ||
| 1367 | |||
| 1368 | mparams.progress_callback = params.load_progress_callback; | ||
| 1369 | mparams.progress_callback_user_data = params.load_progress_callback_user_data; | ||
| 1370 | |||
| 1371 | return mparams; | ||
| 1372 | } | ||
| 1373 | |||
| 1374 | struct llama_context_params common_context_params_to_llama(const common_params & params) { | ||
| 1375 | auto cparams = llama_context_default_params(); | ||
| 1376 | |||
| 1377 | cparams.n_ctx = params.n_ctx; | ||
| 1378 | cparams.n_seq_max = params.n_parallel; | ||
| 1379 | cparams.n_batch = params.n_batch; | ||
| 1380 | cparams.n_ubatch = params.n_ubatch; | ||
| 1381 | cparams.n_threads = params.cpuparams.n_threads; | ||
| 1382 | cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? | ||
| 1383 | params.cpuparams.n_threads : params.cpuparams_batch.n_threads; | ||
| 1384 | cparams.embeddings = params.embedding; | ||
| 1385 | cparams.rope_scaling_type = params.rope_scaling_type; | ||
| 1386 | cparams.rope_freq_base = params.rope_freq_base; | ||
| 1387 | cparams.rope_freq_scale = params.rope_freq_scale; | ||
| 1388 | cparams.yarn_ext_factor = params.yarn_ext_factor; | ||
| 1389 | cparams.yarn_attn_factor = params.yarn_attn_factor; | ||
| 1390 | cparams.yarn_beta_fast = params.yarn_beta_fast; | ||
| 1391 | cparams.yarn_beta_slow = params.yarn_beta_slow; | ||
| 1392 | cparams.yarn_orig_ctx = params.yarn_orig_ctx; | ||
| 1393 | cparams.pooling_type = params.pooling_type; | ||
| 1394 | cparams.attention_type = params.attention_type; | ||
| 1395 | cparams.flash_attn_type = params.flash_attn_type; | ||
| 1396 | cparams.cb_eval = params.cb_eval; | ||
| 1397 | cparams.cb_eval_user_data = params.cb_eval_user_data; | ||
| 1398 | cparams.offload_kqv = !params.no_kv_offload; | ||
| 1399 | cparams.no_perf = params.no_perf; | ||
| 1400 | cparams.op_offload = !params.no_op_offload; | ||
| 1401 | cparams.swa_full = params.swa_full; | ||
| 1402 | cparams.kv_unified = params.kv_unified; | ||
| 1403 | |||
| 1404 | cparams.type_k = params.cache_type_k; | ||
| 1405 | cparams.type_v = params.cache_type_v; | ||
| 1406 | |||
| 1407 | return cparams; | ||
| 1408 | } | ||
| 1409 | |||
| 1410 | struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) { | ||
| 1411 | struct ggml_threadpool_params tpp; | ||
| 1412 | |||
| 1413 | ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults | ||
| 1414 | |||
| 1415 | if (params.mask_valid) { | ||
| 1416 | std::memcpy(&tpp.cpumask, ¶ms.cpumask, GGML_MAX_N_THREADS); | ||
| 1417 | } | ||
| 1418 | |||
| 1419 | tpp.prio = params.priority; | ||
| 1420 | tpp.poll = params.poll; | ||
| 1421 | tpp.strict_cpu = params.strict_cpu; | ||
| 1422 | |||
| 1423 | return tpp; | ||
| 1424 | } | ||
| 1425 | |||
| 1426 | // | ||
| 1427 | // Batch utils | ||
| 1428 | // | ||
| 1429 | |||
| 1430 | void common_batch_clear(struct llama_batch & batch) { | ||
| 1431 | batch.n_tokens = 0; | ||
| 1432 | } | ||
| 1433 | |||
| 1434 | void common_batch_add( | ||
| 1435 | struct llama_batch & batch, | ||
| 1436 | llama_token id, | ||
| 1437 | llama_pos pos, | ||
| 1438 | const std::vector<llama_seq_id> & seq_ids, | ||
| 1439 | bool logits) { | ||
| 1440 | GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); | ||
| 1441 | |||
| 1442 | batch.token [batch.n_tokens] = id; | ||
| 1443 | batch.pos [batch.n_tokens] = pos; | ||
| 1444 | batch.n_seq_id[batch.n_tokens] = seq_ids.size(); | ||
| 1445 | for (size_t i = 0; i < seq_ids.size(); ++i) { | ||
| 1446 | batch.seq_id[batch.n_tokens][i] = seq_ids[i]; | ||
| 1447 | } | ||
| 1448 | batch.logits [batch.n_tokens] = logits; | ||
| 1449 | |||
| 1450 | batch.n_tokens++; | ||
| 1451 | } | ||
| 1452 | |||
| 1453 | // | ||
| 1454 | // Vocab utils | ||
| 1455 | // | ||
| 1456 | |||
| 1457 | std::vector<llama_token> common_tokenize( | ||
| 1458 | const struct llama_context * ctx, | ||
| 1459 | const std::string & text, | ||
| 1460 | bool add_special, | ||
| 1461 | bool parse_special) { | ||
| 1462 | const llama_model * model = llama_get_model(ctx); | ||
| 1463 | const llama_vocab * vocab = llama_model_get_vocab(model); | ||
| 1464 | return common_tokenize(vocab, text, add_special, parse_special); | ||
| 1465 | } | ||
| 1466 | |||
| 1467 | std::vector<llama_token> common_tokenize( | ||
| 1468 | const struct llama_vocab * vocab, | ||
| 1469 | const std::string & text, | ||
| 1470 | bool add_special, | ||
| 1471 | bool parse_special) { | ||
| 1472 | // upper limit for the number of tokens | ||
| 1473 | int n_tokens = text.length() + 2 * add_special; | ||
| 1474 | std::vector<llama_token> result(n_tokens); | ||
| 1475 | n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); | ||
| 1476 | if (n_tokens == std::numeric_limits<int32_t>::min()) { | ||
| 1477 | throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit"); | ||
| 1478 | } | ||
| 1479 | if (n_tokens < 0) { | ||
| 1480 | result.resize(-n_tokens); | ||
| 1481 | int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); | ||
| 1482 | GGML_ASSERT(check == -n_tokens); | ||
| 1483 | } else { | ||
| 1484 | result.resize(n_tokens); | ||
| 1485 | } | ||
| 1486 | return result; | ||
| 1487 | } | ||
| 1488 | |||
| 1489 | std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { | ||
| 1490 | const llama_model * model = llama_get_model(ctx); | ||
| 1491 | const llama_vocab * vocab = llama_model_get_vocab(model); | ||
| 1492 | return common_token_to_piece(vocab, token, special); | ||
| 1493 | } | ||
| 1494 | |||
| 1495 | std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) { | ||
| 1496 | std::string piece; | ||
| 1497 | piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' | ||
| 1498 | const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); | ||
| 1499 | if (n_chars < 0) { | ||
| 1500 | piece.resize(-n_chars); | ||
| 1501 | int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); | ||
| 1502 | GGML_ASSERT(check == -n_chars); | ||
| 1503 | } | ||
| 1504 | else { | ||
| 1505 | piece.resize(n_chars); | ||
| 1506 | } | ||
| 1507 | |||
| 1508 | return piece; | ||
| 1509 | } | ||
| 1510 | |||
| 1511 | std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) { | ||
| 1512 | const llama_model * model = llama_get_model(ctx); | ||
| 1513 | const llama_vocab * vocab = llama_model_get_vocab(model); | ||
| 1514 | return common_detokenize(vocab, tokens, special); | ||
| 1515 | } | ||
| 1516 | |||
| 1517 | std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) { | ||
| 1518 | std::string text; | ||
| 1519 | text.resize(std::max(text.capacity(), tokens.size())); | ||
| 1520 | int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); | ||
| 1521 | if (n_chars < 0) { | ||
| 1522 | text.resize(-n_chars); | ||
| 1523 | n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); | ||
| 1524 | GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization | ||
| 1525 | } | ||
| 1526 | |||
| 1527 | text.resize(n_chars); | ||
| 1528 | |||
| 1529 | // NOTE: the original tokenizer decodes bytes after collecting the pieces. | ||
| 1530 | return text; | ||
| 1531 | } | ||
| 1532 | |||
| 1533 | // | ||
| 1534 | // Embedding utils | ||
| 1535 | // | ||
| 1536 | |||
| 1537 | void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) { | ||
| 1538 | double sum = 0.0; | ||
| 1539 | |||
| 1540 | switch (embd_norm) { | ||
| 1541 | case -1: // no normalisation | ||
| 1542 | sum = 1.0; | ||
| 1543 | break; | ||
| 1544 | case 0: // max absolute | ||
| 1545 | for (int i = 0; i < n; i++) { | ||
| 1546 | if (sum < std::abs(inp[i])) { | ||
| 1547 | sum = std::abs(inp[i]); | ||
| 1548 | } | ||
| 1549 | } | ||
| 1550 | sum /= 32760.0; // make an int16 range | ||
| 1551 | break; | ||
| 1552 | case 2: // euclidean | ||
| 1553 | for (int i = 0; i < n; i++) { | ||
| 1554 | sum += inp[i] * inp[i]; | ||
| 1555 | } | ||
| 1556 | sum = std::sqrt(sum); | ||
| 1557 | break; | ||
| 1558 | default: // p-norm (euclidean is p-norm p=2) | ||
| 1559 | for (int i = 0; i < n; i++) { | ||
| 1560 | sum += std::pow(std::abs(inp[i]), embd_norm); | ||
| 1561 | } | ||
| 1562 | sum = std::pow(sum, 1.0 / embd_norm); | ||
| 1563 | break; | ||
| 1564 | } | ||
| 1565 | |||
| 1566 | const float norm = sum > 0.0 ? 1.0 / sum : 0.0f; | ||
| 1567 | |||
| 1568 | for (int i = 0; i < n; i++) { | ||
| 1569 | out[i] = inp[i] * norm; | ||
| 1570 | } | ||
| 1571 | } | ||
| 1572 | |||
| 1573 | float common_embd_similarity_cos(const float * embd1, const float * embd2, int n){ | ||
| 1574 | double sum = 0.0; | ||
| 1575 | double sum1 = 0.0; | ||
| 1576 | double sum2 = 0.0; | ||
| 1577 | |||
| 1578 | for (int i = 0; i < n; i++) { | ||
| 1579 | sum += embd1[i] * embd2[i]; | ||
| 1580 | sum1 += embd1[i] * embd1[i]; | ||
| 1581 | sum2 += embd2[i] * embd2[i]; | ||
| 1582 | } | ||
| 1583 | |||
| 1584 | // Handle the case where one or both vectors are zero vectors | ||
| 1585 | if (sum1 == 0.0 || sum2 == 0.0) { | ||
| 1586 | if (sum1 == 0.0 && sum2 == 0.0) { | ||
| 1587 | return 1.0f; // two zero vectors are similar | ||
| 1588 | } | ||
| 1589 | return 0.0f; | ||
| 1590 | } | ||
| 1591 | |||
| 1592 | return sum / (sqrt(sum1) * sqrt(sum2)); | ||
| 1593 | } | ||
| 1594 | |||
| 1595 | // | ||
| 1596 | // Control vector utils | ||
| 1597 | // | ||
| 1598 | |||
| 1599 | static common_control_vector_data common_control_vector_load_one(const common_control_vector_load_info & load_info) { | ||
| 1600 | common_control_vector_data result = { -1, {} }; | ||
| 1601 | |||
| 1602 | ggml_context * ctx = nullptr; | ||
| 1603 | struct gguf_init_params meta_gguf_params = { | ||
| 1604 | /* .no_alloc = */ false, | ||
| 1605 | /* .ctx = */ &ctx, | ||
| 1606 | }; | ||
| 1607 | struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params); | ||
| 1608 | if (!ctx_gguf) { | ||
| 1609 | LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str()); | ||
| 1610 | return result; | ||
| 1611 | } | ||
| 1612 | |||
| 1613 | int32_t n_tensors = gguf_get_n_tensors(ctx_gguf); | ||
| 1614 | if (n_tensors == 0) { | ||
| 1615 | LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); | ||
| 1616 | } | ||
| 1617 | |||
| 1618 | for (int i = 0; i < n_tensors; i++) { | ||
| 1619 | std::string name = gguf_get_tensor_name(ctx_gguf, i); | ||
| 1620 | |||
| 1621 | int layer_idx = -1; | ||
| 1622 | |||
| 1623 | // split on '.' | ||
| 1624 | size_t dotpos = name.find('.'); | ||
| 1625 | if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") { | ||
| 1626 | try { | ||
| 1627 | layer_idx = std::stoi(name.substr(dotpos + 1)); | ||
| 1628 | } catch (...) { | ||
| 1629 | layer_idx = -1; | ||
| 1630 | } | ||
| 1631 | } | ||
| 1632 | if (layer_idx < 0) { | ||
| 1633 | LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); | ||
| 1634 | result.n_embd = -1; | ||
| 1635 | break; | ||
| 1636 | } else if (layer_idx == 0) { | ||
| 1637 | LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); | ||
| 1638 | result.n_embd = -1; | ||
| 1639 | break; | ||
| 1640 | } | ||
| 1641 | |||
| 1642 | struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str()); | ||
| 1643 | if (tensor->type != GGML_TYPE_F32) { | ||
| 1644 | LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str()); | ||
| 1645 | result.n_embd = -1; | ||
| 1646 | break; | ||
| 1647 | } | ||
| 1648 | if (ggml_n_dims(tensor) != 1) { | ||
| 1649 | LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str()); | ||
| 1650 | result.n_embd = -1; | ||
| 1651 | break; | ||
| 1652 | } | ||
| 1653 | |||
| 1654 | if (result.n_embd == -1) { | ||
| 1655 | result.n_embd = ggml_nelements(tensor); | ||
| 1656 | } else if (ggml_nelements(tensor) != result.n_embd) { | ||
| 1657 | LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str()); | ||
| 1658 | result.n_embd = -1; | ||
| 1659 | break; | ||
| 1660 | } | ||
| 1661 | |||
| 1662 | // extend if necessary - do not store data for layer 0 (it's not used) | ||
| 1663 | result.data.resize(std::max(result.data.size(), static_cast<size_t>(result.n_embd * layer_idx)), 0.0f); | ||
| 1664 | |||
| 1665 | const float * src = (const float *) tensor->data; | ||
| 1666 | float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0] | ||
| 1667 | for (int j = 0; j < result.n_embd; j++) { | ||
| 1668 | dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file | ||
| 1669 | } | ||
| 1670 | |||
| 1671 | } | ||
| 1672 | |||
| 1673 | if (result.n_embd == -1) { | ||
| 1674 | LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str()); | ||
| 1675 | result.data.clear(); | ||
| 1676 | } | ||
| 1677 | |||
| 1678 | gguf_free(ctx_gguf); | ||
| 1679 | ggml_free(ctx); | ||
| 1680 | |||
| 1681 | return result; | ||
| 1682 | } | ||
| 1683 | |||
| 1684 | common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos) { | ||
| 1685 | common_control_vector_data result = { -1, {} }; | ||
| 1686 | |||
| 1687 | for (const auto & info : load_infos) { | ||
| 1688 | auto cur = common_control_vector_load_one(info); | ||
| 1689 | |||
| 1690 | if (cur.n_embd == -1) { | ||
| 1691 | result.n_embd = -1; | ||
| 1692 | break; | ||
| 1693 | } | ||
| 1694 | if (result.n_embd != -1 && result.n_embd != cur.n_embd) { | ||
| 1695 | LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str()); | ||
| 1696 | result.n_embd = -1; | ||
| 1697 | break; | ||
| 1698 | } | ||
| 1699 | |||
| 1700 | if (result.n_embd == -1) { | ||
| 1701 | result = std::move(cur); | ||
| 1702 | } else { | ||
| 1703 | result.data.resize(std::max(result.data.size(), cur.data.size()), 0.0f); // extend if necessary | ||
| 1704 | for (size_t i = 0; i < cur.data.size(); i++) { | ||
| 1705 | result.data[i] += cur.data[i]; | ||
| 1706 | } | ||
| 1707 | } | ||
| 1708 | } | ||
| 1709 | |||
| 1710 | if (result.n_embd == -1) { | ||
| 1711 | LOG_ERR("%s: no valid control vector files passed\n", __func__); | ||
| 1712 | result.data.clear(); | ||
| 1713 | } | ||
| 1714 | |||
| 1715 | return result; | ||
| 1716 | } | ||
| 1717 | |||
| 1718 | ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) { | ||
| 1719 | const int64_t ne_datapoint = llama_n_ctx(ctx); | ||
| 1720 | const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride; | ||
| 1721 | ggml_opt_dataset_t result = ggml_opt_dataset_init( | ||
| 1722 | GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1); | ||
| 1723 | |||
| 1724 | llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data; | ||
| 1725 | llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data; | ||
| 1726 | |||
| 1727 | for (int64_t idata = 0; idata < ndata; ++idata) { | ||
| 1728 | memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token)); | ||
| 1729 | memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token)); | ||
| 1730 | } | ||
| 1731 | |||
| 1732 | return result; | ||
| 1733 | } | ||
| 1734 | |||
| 1735 | ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) { | ||
| 1736 | ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr); | ||
| 1737 | const lr_opt & d = *(lr_opt *) userdata; | ||
| 1738 | result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch); | ||
| 1739 | result.sgd.wd = result.adamw.wd = d.wd; | ||
| 1740 | return result; | ||
| 1741 | } | ||
| 1742 | |||
| 1743 | // TODO make all command line args case-insensitive | ||
| 1744 | static inline bool eq_case_insensitive(char const* a, char const* b) { | ||
| 1745 | return ! | ||
| 1746 | #if defined(_MSC_VER) | ||
| 1747 | _stricmp | ||
| 1748 | #else | ||
| 1749 | strcasecmp | ||
| 1750 | #endif // defined(_MSC_VER) | ||
| 1751 | (a, b); | ||
| 1752 | } | ||
| 1753 | |||
| 1754 | enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) { | ||
| 1755 | if (eq_case_insensitive("adamw", n)) { | ||
| 1756 | return GGML_OPT_OPTIMIZER_TYPE_ADAMW; | ||
| 1757 | } | ||
| 1758 | if (eq_case_insensitive("sgd", n)) { | ||
| 1759 | return GGML_OPT_OPTIMIZER_TYPE_SGD; | ||
| 1760 | } | ||
| 1761 | return GGML_OPT_OPTIMIZER_TYPE_COUNT; | ||
| 1762 | } | ||
| 1763 | |||
| 1764 | // TODO simplify to use just log and exp | ||
| 1765 | static float const k_log_2 = std::log(2.f); | ||
| 1766 | |||
| 1767 | void lr_opt::init() { | ||
| 1768 | if (lr_min > 0 && lr_min < lr0) { | ||
| 1769 | float nhalf = std::log(lr0 / lr_min) / k_log_2; | ||
| 1770 | float e = epochs; | ||
| 1771 | if (decay_epochs > 0 && decay_epochs < e) { | ||
| 1772 | e = decay_epochs; | ||
| 1773 | } else { | ||
| 1774 | decay_epochs = e; | ||
| 1775 | } | ||
| 1776 | scale_epoch = nhalf / e; | ||
| 1777 | } | ||
| 1778 | } | ||
| 1779 | |||
| 1780 | float lr_opt::get_lr(float epoch) const { | ||
| 1781 | float r = lr_min <= 0 ? lr0 : | ||
| 1782 | epoch >= decay_epochs ? lr_min : | ||
| 1783 | lr0 * std::pow(0.5f, epoch * scale_epoch); | ||
| 1784 | LOG_INF("epoch %.2g lr=%.2g\n", epoch, r); | ||
| 1785 | return r; | ||
| 1786 | } | ||
