1#pragma once
2
3#include "common.h"
4
5#include <chrono>
6#include <exception>
7#include <iostream>
8#include <string>
9#include <regex>
10#include <vector>
11
12struct testing {
13 std::ostream &out;
14 std::vector<std::string> stack;
15 std::regex filter;
16 bool filter_tests = false;
17 bool throw_exception = false;
18 bool verbose = false;
19 int tests = 0;
20 int assertions = 0;
21 int failures = 0;
22 int unnamed = 0;
23 int exceptions = 0;
24
25 static constexpr std::size_t status_column = 80;
26
27 explicit testing(std::ostream &os = std::cout) : out(os) {}
28
29 std::string indent() const {
30 if (stack.empty()) {
31 return "";
32 }
33 return std::string((stack.size() - 1) * 2, ' ');
34 }
35
36 std::string full_name() const {
37 return string_join(stack, ".");
38 }
39
40 void log(const std::string & msg) {
41 if (verbose) {
42 out << indent() << " " << msg << "\n";
43 }
44 }
45
46 void set_filter(const std::string & re) {
47 filter = std::regex(re);
48 filter_tests = true;
49 }
50
51 bool should_run() const {
52 if (filter_tests) {
53 if (!std::regex_match(full_name(), filter)) {
54 return false;
55 }
56 }
57 return true;
58 }
59
60 template <typename F>
61 void run_with_exceptions(F &&f, const char *ctx) {
62 try {
63 f();
64 } catch (const std::exception &e) {
65 ++failures;
66 ++exceptions;
67 out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n";
68 if (throw_exception) {
69 throw;
70 }
71 } catch (...) {
72 ++failures;
73 ++exceptions;
74 out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n";
75 if (throw_exception) {
76 throw;
77 }
78 }
79 }
80
81 void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const {
82 std::string line = indent() + label;
83
84 std::string details;
85 if (new_assertions > 0) {
86 if (new_failures == 0) {
87 details = std::to_string(new_assertions) + " assertion(s)";
88 } else {
89 details = std::to_string(new_failures) + " of " +
90 std::to_string(new_assertions) + " assertion(s) failed";
91 }
92 }
93 if (!extra.empty()) {
94 if (!details.empty()) {
95 details += ", ";
96 }
97 details += extra;
98 }
99
100 if (!details.empty()) {
101 line += " (" + details + ")";
102 }
103
104 std::string status = (new_failures == 0) ? "[PASS]" : "[FAIL]";
105
106 if (line.size() + 1 < status_column) {
107 line.append(status_column - line.size(), ' ');
108 } else {
109 line.push_back(' ');
110 }
111
112 out << line << status << "\n";
113 }
114
115 template <typename F>
116 void test(const std::string &name, F f) {
117 stack.push_back(name);
118 if (!should_run()) {
119 stack.pop_back();
120 return;
121 }
122
123 ++tests;
124 out << indent() << name << "\n";
125
126 int before_failures = failures;
127 int before_assertions = assertions;
128
129 run_with_exceptions([&] { f(*this); }, "test");
130
131 int new_failures = failures - before_failures;
132 int new_assertions = assertions - before_assertions;
133
134 print_result(name, new_failures, new_assertions);
135
136 stack.pop_back();
137 }
138
139 template <typename F>
140 void test(F f) {
141 test("test #" + std::to_string(++unnamed), f);
142 }
143
144 template <typename F>
145 void bench(const std::string &name, F f, int iterations = 100) {
146 stack.push_back(name);
147 if (!should_run()) {
148 stack.pop_back();
149 return;
150 }
151
152 ++tests;
153 out << indent() << "[bench] " << name << "\n";
154
155 int before_failures = failures;
156 int before_assertions = assertions;
157
158 using clock = std::chrono::high_resolution_clock;
159
160 std::chrono::microseconds duration(0);
161
162 run_with_exceptions([&] {
163 for (auto i = 0; i < iterations; i++) {
164 auto start = clock::now();
165 f();
166 duration += std::chrono::duration_cast<std::chrono::microseconds>(clock::now() - start);
167 }
168 }, "bench");
169
170 auto avg_elapsed = duration.count() / iterations;
171 auto avg_elapsed_s = std::chrono::duration_cast<std::chrono::duration<double>>(duration).count() / iterations;
172 auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0;
173
174 int new_failures = failures - before_failures;
175 int new_assertions = assertions - before_assertions;
176
177 std::string extra =
178 "n=" + std::to_string(iterations) +
179 " avg=" + std::to_string(avg_elapsed) + "us" +
180 " rate=" + std::to_string(int(rate)) + "/s";
181
182 print_result("[bench] " + name, new_failures, new_assertions, extra);
183
184 stack.pop_back();
185 }
186
187 template <typename F>
188 void bench(F f, int iterations = 100) {
189 bench("bench #" + std::to_string(++unnamed), f, iterations);
190 }
191
192 // Assertions
193 bool assert_true(bool cond) {
194 return assert_true("", cond);
195 }
196
197 bool assert_true(const std::string &msg, bool cond) {
198 ++assertions;
199 if (!cond) {
200 ++failures;
201 out << indent() << "ASSERTION FAILED";
202 if (!msg.empty()) {
203 out << " : " << msg;
204 }
205 out << "\n";
206 return false;
207 }
208 return true;
209 }
210
211 template <typename A, typename B>
212 bool assert_equal(const A &expected, const B &actual) {
213 return assert_equal("", expected, actual);
214 }
215
216 template <typename A, typename B>
217 bool assert_equal(const std::string &msg, const A &expected, const B &actual) {
218 ++assertions;
219 if (!(actual == expected)) {
220 ++failures;
221 out << indent() << "ASSERT EQUAL FAILED";
222 if (!msg.empty()) {
223 out << " : " << msg;
224 }
225 out << "\n";
226
227 out << indent() << " expected: " << expected << "\n";
228 out << indent() << " actual : " << actual << "\n";
229 return false;
230 }
231 return true;
232 }
233
234 // Print summary and return an exit code
235 int summary() const {
236 out << "\n";
237 out << "tests : " << tests << "\n";
238 out << "assertions : " << assertions << "\n";
239 out << "failures : " << failures << "\n";
240 out << "exceptions : " << exceptions << "\n";
241 return failures == 0 ? 0 : 1;
242 }
243};