1#pragma once
  2
  3#include "common.h"
  4
  5#include <chrono>
  6#include <exception>
  7#include <iostream>
  8#include <string>
  9#include <regex>
 10#include <vector>
 11
 12struct testing {
 13    std::ostream &out;
 14    std::vector<std::string> stack;
 15    std::regex filter;
 16    bool filter_tests = false;
 17    bool throw_exception = false;
 18    bool verbose = false;
 19    int tests = 0;
 20    int assertions = 0;
 21    int failures = 0;
 22    int unnamed = 0;
 23    int exceptions = 0;
 24
 25    static constexpr std::size_t status_column = 80;
 26
 27    explicit testing(std::ostream &os = std::cout) : out(os) {}
 28
 29    std::string indent() const {
 30        if (stack.empty()) {
 31            return "";
 32        }
 33        return std::string((stack.size() - 1) * 2, ' ');
 34    }
 35
 36    std::string full_name() const {
 37        return string_join(stack, ".");
 38    }
 39
 40    void log(const std::string & msg) {
 41        if (verbose) {
 42            out << indent() << "  " << msg << "\n";
 43        }
 44    }
 45
 46    void set_filter(const std::string & re) {
 47        filter = std::regex(re);
 48        filter_tests = true;
 49    }
 50
 51    bool should_run() const {
 52        if (filter_tests) {
 53            if (!std::regex_match(full_name(), filter)) {
 54                return false;
 55            }
 56        }
 57        return true;
 58    }
 59
 60    template <typename F>
 61    void run_with_exceptions(F &&f, const char *ctx) {
 62        try {
 63            f();
 64        } catch (const std::exception &e) {
 65            ++failures;
 66            ++exceptions;
 67            out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n";
 68            if (throw_exception) {
 69                throw;
 70            }
 71        } catch (...) {
 72            ++failures;
 73            ++exceptions;
 74            out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n";
 75            if (throw_exception) {
 76                throw;
 77            }
 78        }
 79    }
 80
 81    void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const {
 82        std::string line = indent() + label;
 83
 84        std::string details;
 85        if (new_assertions > 0) {
 86            if (new_failures == 0) {
 87                details = std::to_string(new_assertions) + " assertion(s)";
 88            } else {
 89                details = std::to_string(new_failures) + " of " +
 90                          std::to_string(new_assertions) + " assertion(s) failed";
 91            }
 92        }
 93        if (!extra.empty()) {
 94            if (!details.empty()) {
 95                details += ", ";
 96            }
 97            details += extra;
 98        }
 99
100        if (!details.empty()) {
101            line += " (" + details + ")";
102        }
103
104        std::string status = (new_failures == 0) ? "[PASS]" : "[FAIL]";
105
106        if (line.size() + 1 < status_column) {
107            line.append(status_column - line.size(), ' ');
108        } else {
109            line.push_back(' ');
110        }
111
112        out << line << status << "\n";
113    }
114
115    template <typename F>
116    void test(const std::string &name, F f) {
117        stack.push_back(name);
118        if (!should_run()) {
119            stack.pop_back();
120            return;
121        }
122
123        ++tests;
124        out << indent() << name << "\n";
125
126        int before_failures   = failures;
127        int before_assertions = assertions;
128
129        run_with_exceptions([&] { f(*this); }, "test");
130
131        int new_failures   = failures   - before_failures;
132        int new_assertions = assertions - before_assertions;
133
134        print_result(name, new_failures, new_assertions);
135
136        stack.pop_back();
137    }
138
139    template <typename F>
140    void test(F f) {
141        test("test #" + std::to_string(++unnamed), f);
142    }
143
144    template <typename F>
145    void bench(const std::string &name, F f, int iterations = 100) {
146        stack.push_back(name);
147        if (!should_run()) {
148            stack.pop_back();
149            return;
150        }
151
152        ++tests;
153        out << indent() << "[bench] " << name << "\n";
154
155        int before_failures   = failures;
156        int before_assertions = assertions;
157
158        using clock = std::chrono::high_resolution_clock;
159
160        std::chrono::microseconds duration(0);
161
162        run_with_exceptions([&] {
163            for (auto i = 0; i < iterations; i++) {
164                auto start = clock::now();
165                f();
166                duration += std::chrono::duration_cast<std::chrono::microseconds>(clock::now() - start);
167            }
168        }, "bench");
169
170        auto avg_elapsed   = duration.count() / iterations;
171        auto avg_elapsed_s = std::chrono::duration_cast<std::chrono::duration<double>>(duration).count() / iterations;
172        auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0;
173
174        int new_failures   = failures   - before_failures;
175        int new_assertions = assertions - before_assertions;
176
177        std::string extra =
178            "n=" + std::to_string(iterations) +
179            " avg=" + std::to_string(avg_elapsed) + "us" +
180            " rate=" + std::to_string(int(rate)) + "/s";
181
182        print_result("[bench] " + name, new_failures, new_assertions, extra);
183
184        stack.pop_back();
185    }
186
187    template <typename F>
188    void bench(F f, int iterations = 100) {
189        bench("bench #" + std::to_string(++unnamed), f, iterations);
190    }
191
192    // Assertions
193    bool assert_true(bool cond) {
194        return assert_true("", cond);
195    }
196
197    bool assert_true(const std::string &msg, bool cond) {
198        ++assertions;
199        if (!cond) {
200            ++failures;
201            out << indent() << "ASSERTION FAILED";
202            if (!msg.empty()) {
203                out << " : " << msg;
204            }
205            out << "\n";
206            return false;
207        }
208        return true;
209    }
210
211    template <typename A, typename B>
212    bool assert_equal(const A &expected, const B &actual) {
213        return assert_equal("", expected, actual);
214    }
215
216    template <typename A, typename B>
217    bool assert_equal(const std::string &msg, const A &expected, const B &actual) {
218        ++assertions;
219        if (!(actual == expected)) {
220            ++failures;
221            out << indent() << "ASSERT EQUAL FAILED";
222            if (!msg.empty()) {
223                out << " : " << msg;
224            }
225            out << "\n";
226
227            out << indent() << "  expected: " << expected << "\n";
228            out << indent() << "  actual  : " << actual << "\n";
229            return false;
230        }
231        return true;
232    }
233
234    // Print summary and return an exit code
235    int summary() const {
236        out << "\n";
237        out << "tests      : " << tests << "\n";
238        out << "assertions : " << assertions << "\n";
239        out << "failures   : " << failures << "\n";
240        out << "exceptions : " << exceptions << "\n";
241        return failures == 0 ? 0 : 1;
242    }
243};