1#include "llama-impl.h"
  2
  3#include "gguf.h"
  4#include "llama.h"
  5
  6#include <cinttypes>
  7#include <climits>
  8#include <cstdarg>
  9#include <cstring>
 10#include <vector>
 11#include <sstream>
 12
 13struct llama_logger_state {
 14    ggml_log_callback log_callback = llama_log_callback_default;
 15    void * log_callback_user_data = nullptr;
 16};
 17
 18static llama_logger_state g_logger_state;
 19
 20time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
 21
 22time_meas::~time_meas() {
 23    if (t_start_us >= 0) {
 24        t_acc += ggml_time_us() - t_start_us;
 25    }
 26}
 27
 28void llama_log_get(ggml_log_callback * log_callback, void ** user_data) {
 29    ggml_log_get(log_callback, user_data);
 30}
 31
 32void llama_log_set(ggml_log_callback log_callback, void * user_data) {
 33    ggml_log_set(log_callback, user_data);
 34    g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
 35    g_logger_state.log_callback_user_data = user_data;
 36}
 37
 38static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
 39    va_list args_copy;
 40    va_copy(args_copy, args);
 41    char buffer[128];
 42    int len = vsnprintf(buffer, 128, format, args);
 43    if (len < 128) {
 44        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
 45    } else {
 46        char * buffer2 = new char[len + 1];
 47        vsnprintf(buffer2, len + 1, format, args_copy);
 48        buffer2[len] = 0;
 49        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
 50        delete[] buffer2;
 51    }
 52    va_end(args_copy);
 53}
 54
 55void llama_log_internal(ggml_log_level level, const char * format, ...) {
 56    va_list args;
 57    va_start(args, format);
 58    llama_log_internal_v(level, format, args);
 59    va_end(args);
 60}
 61
 62void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
 63    (void) level;
 64    (void) user_data;
 65    fputs(text, stderr);
 66    fflush(stderr);
 67}
 68
 69void replace_all(std::string & s, const std::string & search, const std::string & replace) {
 70    if (search.empty()) {
 71        return;
 72    }
 73    std::string builder;
 74    builder.reserve(s.length());
 75    size_t pos = 0;
 76    size_t last_pos = 0;
 77    while ((pos = s.find(search, last_pos)) != std::string::npos) {
 78        builder.append(s, last_pos, pos - last_pos);
 79        builder.append(replace);
 80        last_pos = pos + search.length();
 81    }
 82    builder.append(s, last_pos, std::string::npos);
 83    s = std::move(builder);
 84}
 85
 86std::string format(const char * fmt, ...) {
 87    va_list ap;
 88    va_list ap2;
 89    va_start(ap, fmt);
 90    va_copy(ap2, ap);
 91    int size = vsnprintf(NULL, 0, fmt, ap);
 92    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
 93    std::vector<char> buf(size + 1);
 94    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
 95    GGML_ASSERT(size2 == size);
 96    va_end(ap2);
 97    va_end(ap);
 98    return std::string(buf.data(), size);
 99}
100
// Renders a list of dimension sizes as a comma-separated string, each value
// right-aligned to a minimum width of 5 (e.g. " 4096, 32000").
//
// ne - dimension sizes; must contain at least one element, otherwise
//      std::out_of_range is thrown by at(0)
//
// The result is built incrementally in a std::string, so shapes with many
// dimensions are never truncated.
std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
    // scratch space for a single entry: ", " plus at most 20 characters for an
    // int64 (sign included) plus the terminating NUL
    char item[32];

    snprintf(item, sizeof(item), "%5" PRId64, ne.at(0));
    std::string shape = item;

    for (size_t i = 1; i < ne.size(); i++) {
        snprintf(item, sizeof(item), ", %5" PRId64, ne[i]);
        shape += item;
    }
    return shape;
}
109
110std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
111    char buf[256];
112    snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
113    for (int i = 1; i < GGML_MAX_DIMS; i++) {
114        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
115    }
116    return buf;
117}
118
119static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
120    switch (type) {
121        case GGUF_TYPE_UINT8:   return std::to_string(((const uint8_t  *)data)[i]);
122        case GGUF_TYPE_INT8:    return std::to_string(((const int8_t   *)data)[i]);
123        case GGUF_TYPE_UINT16:  return std::to_string(((const uint16_t *)data)[i]);
124        case GGUF_TYPE_INT16:   return std::to_string(((const int16_t  *)data)[i]);
125        case GGUF_TYPE_UINT32:  return std::to_string(((const uint32_t *)data)[i]);
126        case GGUF_TYPE_INT32:   return std::to_string(((const int32_t  *)data)[i]);
127        case GGUF_TYPE_UINT64:  return std::to_string(((const uint64_t *)data)[i]);
128        case GGUF_TYPE_INT64:   return std::to_string(((const int64_t  *)data)[i]);
129        case GGUF_TYPE_FLOAT32: return std::to_string(((const float    *)data)[i]);
130        case GGUF_TYPE_FLOAT64: return std::to_string(((const double   *)data)[i]);
131        case GGUF_TYPE_BOOL:    return ((const bool *)data)[i] ? "true" : "false";
132        default:                return format("unknown type %d", type);
133    }
134}
135
136std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
137    const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
138
139    switch (type) {
140        case GGUF_TYPE_STRING:
141            return gguf_get_val_str(ctx_gguf, i);
142        case GGUF_TYPE_ARRAY:
143            {
144                const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
145                int arr_n = gguf_get_arr_n(ctx_gguf, i);
146                const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i);
147                std::stringstream ss;
148                ss << "[";
149                for (int j = 0; j < arr_n; j++) {
150                    if (arr_type == GGUF_TYPE_STRING) {
151                        std::string val = gguf_get_arr_str(ctx_gguf, i, j);
152                        // escape quotes
153                        replace_all(val, "\\", "\\\\");
154                        replace_all(val, "\"", "\\\"");
155                        ss << '"' << val << '"';
156                    } else if (arr_type == GGUF_TYPE_ARRAY) {
157                        ss << "???";
158                    } else {
159                        ss << gguf_data_to_str(arr_type, data, j);
160                    }
161                    if (j < arr_n - 1) {
162                        ss << ", ";
163                    }
164                }
165                ss << "]";
166                return ss.str();
167            }
168        default:
169            return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
170    }
171}