1#pragma once
#include "common.h"

#include <cstddef>
#include <cstdint>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>
 6
 7// common debug functions and structs
 8
// Print a tensor's detailed data.
//   data - the tensor's data, as raw bytes
//   type - the tensor's ggml (quantization) type
//   ne   - the tensor's dimensions array
//   nb   - the tensor's strides array (presumably byte strides in ggml's nb layout - confirm)
//   n    - the number of rows/columns to print in full
// NOTE(review): the abort_on_nan template parameter is not described here; presumably it
// mirrors common_debug_cb_eval (treat the first NaN as an error) - confirm against the definition.
template <bool abort_on_nan> void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n);
16
// Evaluation callback intended for use as a ggml_backend_sched_eval_callback.
// Prints the tensors that are processed in the computation graph.
// By default it prints all tensors. Output can be restricted by creating a
// `base_callback_data` instance with non-empty filter_patterns; see examples/debug.cpp
// for possible usage patterns.
//
//   t         - the graph tensor the scheduler is reporting
//   ask       - NOTE(review): per the ggml_backend_sched_eval_callback contract in
//               ggml-backend.h, presumably true when the scheduler asks whether the
//               callback wants to observe t, and false once t has been computed - confirm
//   user_data - the callback data, passed as the third parameter; base_callback_data's
//               constructor registers itself here (params.cb_eval_user_data = this)
//
// The template parameter determines whether an error should be thrown whenever a NaN is
// encountered in a tensor (useful for stopping a debug session at the first erroneous tensor).
template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data);
25struct base_callback_data {
26    std::vector<uint8_t>    data;
27    std::vector<std::regex> tensor_filters;
28
29    base_callback_data() = default;
30
31    base_callback_data(common_params & params, const std::vector<std::string> & filter_patterns) {
32        for (const auto & pattern : filter_patterns) {
33            try {
34                std::string anchored_pattern = "^" + pattern;
35                tensor_filters.emplace_back(anchored_pattern, std::regex::optimize);
36            } catch (const std::regex_error & e) {
37                throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what());
38            }
39        }
40        params.cb_eval           = common_debug_cb_eval<false>;
41        params.cb_eval_user_data = this;
42    }
43};