summaryrefslogtreecommitdiff
path: root/llama.cpp/tests/testing.h
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/tests/testing.h')
-rw-r--r--llama.cpp/tests/testing.h243
1 files changed, 243 insertions, 0 deletions
diff --git a/llama.cpp/tests/testing.h b/llama.cpp/tests/testing.h
new file mode 100644
index 0000000..7949483
--- /dev/null
+++ b/llama.cpp/tests/testing.h
@@ -0,0 +1,243 @@
+#pragma once
+
+#include "common.h"
+
+#include <chrono>
+#include <exception>
+#include <iostream>
+#include <string>
+#include <regex>
+#include <vector>
+
+struct testing {
+ std::ostream &out;
+ std::vector<std::string> stack;
+ std::regex filter;
+ bool filter_tests = false;
+ bool throw_exception = false;
+ bool verbose = false;
+ int tests = 0;
+ int assertions = 0;
+ int failures = 0;
+ int unnamed = 0;
+ int exceptions = 0;
+
+ static constexpr std::size_t status_column = 80;
+
+ explicit testing(std::ostream &os = std::cout) : out(os) {}
+
+ std::string indent() const {
+ if (stack.empty()) {
+ return "";
+ }
+ return std::string((stack.size() - 1) * 2, ' ');
+ }
+
+ std::string full_name() const {
+ return string_join(stack, ".");
+ }
+
+ void log(const std::string & msg) {
+ if (verbose) {
+ out << indent() << " " << msg << "\n";
+ }
+ }
+
+ void set_filter(const std::string & re) {
+ filter = std::regex(re);
+ filter_tests = true;
+ }
+
+ bool should_run() const {
+ if (filter_tests) {
+ if (!std::regex_match(full_name(), filter)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ template <typename F>
+ void run_with_exceptions(F &&f, const char *ctx) {
+ try {
+ f();
+ } catch (const std::exception &e) {
+ ++failures;
+ ++exceptions;
+ out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n";
+ if (throw_exception) {
+ throw;
+ }
+ } catch (...) {
+ ++failures;
+ ++exceptions;
+ out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n";
+ if (throw_exception) {
+ throw;
+ }
+ }
+ }
+
+ void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const {
+ std::string line = indent() + label;
+
+ std::string details;
+ if (new_assertions > 0) {
+ if (new_failures == 0) {
+ details = std::to_string(new_assertions) + " assertion(s)";
+ } else {
+ details = std::to_string(new_failures) + " of " +
+ std::to_string(new_assertions) + " assertion(s) failed";
+ }
+ }
+ if (!extra.empty()) {
+ if (!details.empty()) {
+ details += ", ";
+ }
+ details += extra;
+ }
+
+ if (!details.empty()) {
+ line += " (" + details + ")";
+ }
+
+ std::string status = (new_failures == 0) ? "[PASS]" : "[FAIL]";
+
+ if (line.size() + 1 < status_column) {
+ line.append(status_column - line.size(), ' ');
+ } else {
+ line.push_back(' ');
+ }
+
+ out << line << status << "\n";
+ }
+
+ template <typename F>
+ void test(const std::string &name, F f) {
+ stack.push_back(name);
+ if (!should_run()) {
+ stack.pop_back();
+ return;
+ }
+
+ ++tests;
+ out << indent() << name << "\n";
+
+ int before_failures = failures;
+ int before_assertions = assertions;
+
+ run_with_exceptions([&] { f(*this); }, "test");
+
+ int new_failures = failures - before_failures;
+ int new_assertions = assertions - before_assertions;
+
+ print_result(name, new_failures, new_assertions);
+
+ stack.pop_back();
+ }
+
+ template <typename F>
+ void test(F f) {
+ test("test #" + std::to_string(++unnamed), f);
+ }
+
+ template <typename F>
+ void bench(const std::string &name, F f, int iterations = 100) {
+ stack.push_back(name);
+ if (!should_run()) {
+ stack.pop_back();
+ return;
+ }
+
+ ++tests;
+ out << indent() << "[bench] " << name << "\n";
+
+ int before_failures = failures;
+ int before_assertions = assertions;
+
+ using clock = std::chrono::high_resolution_clock;
+
+ std::chrono::microseconds duration(0);
+
+ run_with_exceptions([&] {
+ for (auto i = 0; i < iterations; i++) {
+ auto start = clock::now();
+ f();
+ duration += std::chrono::duration_cast<std::chrono::microseconds>(clock::now() - start);
+ }
+ }, "bench");
+
+ auto avg_elapsed = duration.count() / iterations;
+ auto avg_elapsed_s = std::chrono::duration_cast<std::chrono::duration<double>>(duration).count() / iterations;
+ auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0;
+
+ int new_failures = failures - before_failures;
+ int new_assertions = assertions - before_assertions;
+
+ std::string extra =
+ "n=" + std::to_string(iterations) +
+ " avg=" + std::to_string(avg_elapsed) + "us" +
+ " rate=" + std::to_string(int(rate)) + "/s";
+
+ print_result("[bench] " + name, new_failures, new_assertions, extra);
+
+ stack.pop_back();
+ }
+
+ template <typename F>
+ void bench(F f, int iterations = 100) {
+ bench("bench #" + std::to_string(++unnamed), f, iterations);
+ }
+
+ // Assertions
+ bool assert_true(bool cond) {
+ return assert_true("", cond);
+ }
+
+ bool assert_true(const std::string &msg, bool cond) {
+ ++assertions;
+ if (!cond) {
+ ++failures;
+ out << indent() << "ASSERTION FAILED";
+ if (!msg.empty()) {
+ out << " : " << msg;
+ }
+ out << "\n";
+ return false;
+ }
+ return true;
+ }
+
+ template <typename A, typename B>
+ bool assert_equal(const A &expected, const B &actual) {
+ return assert_equal("", expected, actual);
+ }
+
+ template <typename A, typename B>
+ bool assert_equal(const std::string &msg, const A &expected, const B &actual) {
+ ++assertions;
+ if (!(actual == expected)) {
+ ++failures;
+ out << indent() << "ASSERT EQUAL FAILED";
+ if (!msg.empty()) {
+ out << " : " << msg;
+ }
+ out << "\n";
+
+ out << indent() << " expected: " << expected << "\n";
+ out << indent() << " actual : " << actual << "\n";
+ return false;
+ }
+ return true;
+ }
+
+ // Print summary and return an exit code
+ int summary() const {
+ out << "\n";
+ out << "tests : " << tests << "\n";
+ out << "assertions : " << assertions << "\n";
+ out << "failures : " << failures << "\n";
+ out << "exceptions : " << exceptions << "\n";
+ return failures == 0 ? 0 : 1;
+ }
+};