1#pragma once
  2
  3#include "string.h"
  4#include "utils.h"
  5
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
 17
 18namespace jinja {
 19
// Polymorphic base of every jinja runtime value (defined below).
struct value_t;
// Values are handled through shared pointers to the polymorphic base; the is_val /
// cast_val / mk_val helpers below inspect, downcast and create them.
using value = std::shared_ptr<value_t>;
 22
 23
 24// Helper to check the type of a value
 25template<typename T>
 26struct extract_pointee {
 27    using type = T;
 28};
 29template<typename U>
 30struct extract_pointee<std::shared_ptr<U>> {
 31    using type = U;
 32};
 33template<typename T>
 34bool is_val(const value & ptr) {
 35    using PointeeType = typename extract_pointee<T>::type;
 36    return dynamic_cast<const PointeeType*>(ptr.get()) != nullptr;
 37}
 38template<typename T>
 39bool is_val(const value_t * ptr) {
 40    using PointeeType = typename extract_pointee<T>::type;
 41    return dynamic_cast<const PointeeType*>(ptr) != nullptr;
 42}
 43template<typename T, typename... Args>
 44std::shared_ptr<typename extract_pointee<T>::type> mk_val(Args&&... args) {
 45    using PointeeType = typename extract_pointee<T>::type;
 46    return std::make_shared<PointeeType>(std::forward<Args>(args)...);
 47}
 48template<typename T>
 49const typename extract_pointee<T>::type * cast_val(const value & ptr) {
 50    using PointeeType = typename extract_pointee<T>::type;
 51    return dynamic_cast<const PointeeType*>(ptr.get());
 52}
 53template<typename T>
 54typename extract_pointee<T>::type * cast_val(value & ptr) {
 55    using PointeeType = typename extract_pointee<T>::type;
 56    return dynamic_cast<PointeeType*>(ptr.get());
 57}
 58// End Helper
 59
 60
struct context; // forward declaration; the rendering context is defined elsewhere

// Converts a JSON document into jinja values stored in `ctx`.
// NOTE(review): judging by the name, the top-level keys presumably become template
// globals — confirm against the implementation.
//
// Example input JSON:
// {
//   "messages": [
//     {"role": "user", "content": "Hello!"},
//     {"role": "assistant", "content": "Hi there!"}
//   ],
//   "bos_token": "<s>",
//   "eos_token": "</s>"
// }
//
// To mark a string as user input, wrap it in an object of the form {"__input__": "<text>"}:
// {
//   "messages": [
//     {
//       "role": "user",
//       "content": {"__input__": "Hello!"}  // this string is user input
//     },
//     ...
//   ]
// }
//
// Marking input can be useful for tracking data provenance
// and for preventing template injection attacks.
//
// Note: T_JSON can be nlohmann::ordered_json.
// NOTE(review): the effect of `mark_input` is not visible here — presumably it controls
// whether converted strings are flagged as user input; confirm in the implementation.
template<typename T_JSON>
void global_from_json(context & ctx, const T_JSON & json_obj, bool mark_input);
 92
 93//
 94// base value type
 95//
 96
struct func_args; // function argument values (defined below)

// Signature shared by every callable: receives the call's arguments, returns a value.
using func_hptr = value(const func_args &);
// Type-erased callable with that signature (plain function, lambda or functor).
using func_handler = std::function<func_hptr>;
// Name -> callable table, e.g. what get_builtins() and global_builtins() return.
using func_builtins = std::map<std::string, func_handler>;

// Operators understood by value_compare(); the names suggest ==, >=, >, <, !=.
// NOTE(review): there is no `le` (<=) member — presumably callers swap operands and use
// `ge`; confirm in the value_compare implementation.
enum value_compare_op { eq, ge, gt, lt, ne };
// Compares `a` and `b` using `op` (implemented elsewhere).
bool value_compare(const value & a, const value & b, value_compare_op op);
105
106struct value_t {
107    int64_t val_int;
108    double val_flt;
109    string val_str;
110
111    std::vector<value> val_arr;
112    std::vector<std::pair<value, value>> val_obj;
113
114    func_handler val_func;
115
116    // only used if ctx.is_get_stats = true
117    struct stats_t {
118        bool used = false;
119        // ops can be builtin calls or operators: "array_access", "object_access"
120        std::set<std::string> ops;
121    } stats;
122
123    value_t() = default;
124    value_t(const value_t &) = default;
125    virtual ~value_t() = default;
126
127    // Note: only for debugging and error reporting purposes
128    virtual std::string type() const { return ""; }
129
130    virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
131    virtual double as_float() const { throw std::runtime_error(type() + " is not a float value"); }
132    virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
133    virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
134    virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
135    virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
136    virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
137    virtual bool is_none() const { return false; }
138    virtual bool is_undefined() const { return false; }
139    virtual const func_builtins & get_builtins() const {
140        throw std::runtime_error("No builtins available for type " + type());
141    }
142
143    virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); }
144    virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); }
145    virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
146    virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
147    virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
148    virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
149    virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); }
150    virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); }
151
152    virtual bool is_numeric() const { return false; }
153    virtual bool is_hashable() const { return false; }
154    virtual bool is_immutable() const { return true; }
155    virtual hasher unique_hash() const noexcept = 0;
156    // TODO: C++20 <=> operator
157    // NOTE: We are treating == as equivalent (for normal comparisons) and != as strict nonequal (for strict (is) comparisons)
158    virtual bool operator==(const value_t & other) const { return equivalent(other); }
159    virtual bool operator!=(const value_t & other) const { return nonequal(other); }
160
161    // Note: only for debugging purposes
162    virtual std::string as_repr() const { return as_string().str(); }
163
164protected:
165    virtual bool equivalent(const value_t &) const = 0;
166    virtual bool nonequal(const value_t & other) const { return !equivalent(other); }
167};
168
169//
170// utils
171//
172
// Builtin callables available to templates (implemented elsewhere).
// NOTE(review): presumably the builtins not tied to a specific value type, as opposed to
// value_t::get_builtins() — confirm.
const func_builtins & global_builtins();

// Serializes `val` as JSON text. `item_sep` is placed between items and `key_sep` between
// a key and its value. NOTE(review): presumably a negative `indent` (the default) means
// single-line output — confirm in the implementation.
std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");

// Returns a string representation of `val`.
// Note: besides debugging, value_array_t::as_string and value_object_t::as_string call this
// to render their elements, so it is not purely a debugging aid.
std::string value_to_string_repr(const value & val);
179
/// Thrown for features that are recognised but not supported yet.
/// what() returns the given message prefixed with "NotImplemented: ".
struct not_implemented_exception : public std::runtime_error {
    not_implemented_exception(const std::string & msg)
        : runtime_error(std::string("NotImplemented: ").append(msg)) {}
};
183
184struct value_hasher {
185    size_t operator()(const value & val) const noexcept {
186        return val->unique_hash().digest();
187    }
188};
189
190struct value_equivalence {
191    bool operator()(const value & lhs, const value & rhs) const {
192        return *lhs == *rhs;
193    }
194    bool operator()(const std::pair<value, value> & lhs, const std::pair<value, value> & rhs) const {
195        return *(lhs.first) == *(rhs.first) && *(lhs.second) == *(rhs.second);
196    }
197};
198
199struct value_equality {
200    bool operator()(const value & lhs, const value & rhs) const {
201        return !(*lhs != *rhs);
202    }
203};
204
205//
206// primitive value types
207//
208
209struct value_int_t : public value_t {
210    value_int_t(int64_t v) {
211        val_int = v;
212        val_flt = static_cast<double>(v);
213        if (static_cast<int64_t>(val_flt) != v) {
214            val_flt = v < 0 ? -INFINITY : INFINITY;
215        }
216    }
217    virtual std::string type() const override { return "Integer"; }
218    virtual int64_t as_int() const override { return val_int; }
219    virtual double as_float() const override { return val_flt; }
220    virtual string as_string() const override { return std::to_string(val_int); }
221    virtual bool as_bool() const override {
222        return val_int != 0;
223    }
224    virtual const func_builtins & get_builtins() const override;
225    virtual bool is_numeric() const override { return true; }
226    virtual bool is_hashable() const override { return true; }
227    virtual hasher unique_hash() const noexcept override {
228        return hasher(typeid(*this))
229            .update(&val_int, sizeof(val_int))
230            .update(&val_flt, sizeof(val_flt));
231    }
232protected:
233    virtual bool equivalent(const value_t & other) const override {
234        return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
235    }
236    virtual bool nonequal(const value_t & other) const override {
237        return !(typeid(*this) == typeid(other) && val_int == other.val_int);
238    }
239};
240using value_int = std::shared_ptr<value_int_t>;
241
242
243struct value_float_t : public value_t {
244    value val;
245    value_float_t(double v) {
246        val_flt = v;
247        val_int = std::isfinite(v) ? static_cast<int64_t>(v) : 0;
248        val = mk_val<value_int>(val_int);
249    }
250    virtual std::string type() const override { return "Float"; }
251    virtual double as_float() const override { return val_flt; }
252    virtual int64_t as_int() const override { return val_int; }
253    virtual string as_string() const override {
254        std::string out = std::to_string(val_flt);
255        out.erase(out.find_last_not_of('0') + 1, std::string::npos); // remove trailing zeros
256        if (out.back() == '.') out.push_back('0'); // leave one zero if no decimals
257        return out;
258    }
259    virtual bool as_bool() const override {
260        return val_flt != 0.0;
261    }
262    virtual const func_builtins & get_builtins() const override;
263    virtual bool is_numeric() const override { return true; }
264    virtual bool is_hashable() const override { return true; }
265    virtual hasher unique_hash() const noexcept override {
266        if (static_cast<double>(val_int) == val_flt) {
267            return val->unique_hash();
268        } else {
269            return hasher(typeid(*this))
270                .update(&val_int, sizeof(val_int))
271                .update(&val_flt, sizeof(val_flt));
272        }
273    }
274protected:
275    virtual bool equivalent(const value_t & other) const override {
276        return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
277    }
278    virtual bool nonequal(const value_t & other) const override {
279        return !(typeid(*this) == typeid(other) && val_flt == other.val_flt);
280    }
281};
282using value_float = std::shared_ptr<value_float_t>;
283
284
285struct value_string_t : public value_t {
286    value_string_t() { val_str = string(); }
287    value_string_t(const std::string & v) { val_str = string(v); }
288    value_string_t(const string & v) { val_str = v; }
289    virtual std::string type() const override { return "String"; }
290    virtual string as_string() const override { return val_str; }
291    virtual std::string as_repr() const override {
292        std::ostringstream ss;
293        for (const auto & part : val_str.parts) {
294            ss << (part.is_input ? "INPUT: " : "TMPL:  ") << part.val << "\n";
295        }
296        return ss.str();
297    }
298    virtual bool as_bool() const override {
299        return val_str.length() > 0;
300    }
301    virtual const func_builtins & get_builtins() const override;
302    virtual bool is_hashable() const override { return true; }
303    virtual hasher unique_hash() const noexcept override {
304        const auto type_hash = typeid(*this).hash_code();
305        auto hash = hasher();
306        hash.update(&type_hash, sizeof(type_hash));
307        val_str.hash_update(hash);
308        return hash;
309    }
310    void mark_input() {
311        val_str.mark_input();
312    }
313protected:
314    virtual bool equivalent(const value_t & other) const override {
315        return typeid(*this) == typeid(other) && val_str.str() == other.val_str.str();
316    }
317};
318using value_string = std::shared_ptr<value_string_t>;
319
320
321struct value_bool_t : public value_t {
322    value val;
323    value_bool_t(bool v) {
324        val_int = static_cast<int64_t>(v);
325        val_flt = static_cast<double>(v);
326        val = mk_val<value_int>(val_int);
327    }
328    virtual std::string type() const override { return "Boolean"; }
329    virtual int64_t as_int() const override { return val_int; }
330    virtual bool as_bool() const override { return val_int; }
331    virtual string as_string() const override { return std::string(val_int ? "True" : "False"); }
332    virtual const func_builtins & get_builtins() const override;
333    virtual bool is_numeric() const override { return true; }
334    virtual bool is_hashable() const override { return true; }
335    virtual hasher unique_hash() const noexcept override {
336        return val->unique_hash();
337    }
338protected:
339    virtual bool equivalent(const value_t & other) const override {
340        return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
341    }
342    virtual bool nonequal(const value_t & other) const override {
343        return !(typeid(*this) == typeid(other) && val_int == other.val_int);
344    }
345};
346using value_bool = std::shared_ptr<value_bool_t>;
347
348
349struct value_array_t : public value_t {
350    value_array_t() = default;
351    value_array_t(value & v) {
352        val_arr = v->val_arr;
353    }
354    value_array_t(std::vector<value> && arr) {
355        val_arr = arr;
356    }
357    value_array_t(const std::vector<value> & arr) {
358        val_arr = arr;
359    }
360    void reverse() {
361        if (is_immutable()) {
362            throw std::runtime_error("Attempting to modify immutable type");
363        }
364        std::reverse(val_arr.begin(), val_arr.end());
365    }
366    void push_back(const value & val) {
367        if (is_immutable()) {
368            throw std::runtime_error("Attempting to modify immutable type");
369        }
370        val_arr.push_back(val);
371    }
372    void push_back(value && val) {
373        if (is_immutable()) {
374            throw std::runtime_error("Attempting to modify immutable type");
375        }
376        val_arr.push_back(std::move(val));
377    }
378    value pop_at(int64_t index) {
379        if (is_immutable()) {
380            throw std::runtime_error("Attempting to modify immutable type");
381        }
382        if (index < 0) {
383            index = static_cast<int64_t>(val_arr.size()) + index;
384        }
385        if (index < 0 || index >= static_cast<int64_t>(val_arr.size())) {
386            throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
387        }
388        value val = val_arr.at(static_cast<size_t>(index));
389        val_arr.erase(val_arr.begin() + index);
390        return val;
391    }
392    virtual std::string type() const override { return "Array"; }
393    virtual bool is_immutable() const override { return false; }
394    virtual const std::vector<value> & as_array() const override { return val_arr; }
395    virtual string as_string() const override {
396        const bool immutable = is_immutable();
397        std::ostringstream ss;
398        ss << (immutable ? "(" : "[");
399        for (size_t i = 0; i < val_arr.size(); i++) {
400            if (i > 0) ss << ", ";
401            value val = val_arr.at(i);
402            ss << value_to_string_repr(val);
403        }
404        if (immutable && val_arr.size() == 1) {
405            ss << ",";
406        }
407        ss << (immutable ? ")" : "]");
408        return ss.str();
409    }
410    virtual bool as_bool() const override {
411        return !val_arr.empty();
412    }
413    virtual value & at(int64_t index, value & default_val) override {
414        if (index < 0) {
415            index += val_arr.size();
416        }
417        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
418            return default_val;
419        }
420        return val_arr[index];
421    }
422    virtual value & at(int64_t index) override {
423        if (index < 0) {
424            index += val_arr.size();
425        }
426        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
427            throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
428        }
429        return val_arr[index];
430    }
431    virtual const func_builtins & get_builtins() const override;
432    virtual bool is_hashable() const override {
433        if (std::all_of(val_arr.begin(), val_arr.end(), [&](auto & val) -> bool {
434            return val->is_immutable() && val->is_hashable();
435        })) {
436            return true;
437        }
438        return false;
439    }
440    virtual hasher unique_hash() const noexcept override {
441        auto hash = hasher(typeid(*this));
442        for (const auto & val : val_arr) {
443            // must use digest to prevent problems from "concatenation" property of hasher
444            // for ex. hash of [ "ab", "c" ] should be different from [ "a", "bc" ]
445            const size_t val_hash = val->unique_hash().digest();
446            hash.update(&val_hash, sizeof(size_t));
447        }
448        return hash;
449    }
450protected:
451    virtual bool equivalent(const value_t & other) const override {
452        return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
453    }
454};
455using value_array = std::shared_ptr<value_array_t>;
456
457
458struct value_tuple_t : public value_array_t {
459    value_tuple_t(value & v) {
460        val_arr = v->val_arr;
461    }
462    value_tuple_t(std::vector<value> && arr) {
463        val_arr = arr;
464    }
465    value_tuple_t(const std::vector<value> & arr) {
466        val_arr = arr;
467    }
468    value_tuple_t(const std::pair<value, value> & pair) {
469        val_arr.push_back(pair.first);
470        val_arr.push_back(pair.second);
471    }
472    virtual std::string type() const override { return "Tuple"; }
473    virtual bool is_immutable() const override { return true; }
474};
475using value_tuple = std::shared_ptr<value_tuple_t>;
476
477
478struct value_object_t : public value_t {
479    std::unordered_map<value, value, value_hasher, value_equivalence> unordered;
480    bool has_builtins = true; // context and loop objects do not have builtins
481    value_object_t() = default;
482    value_object_t(value & v) {
483        val_obj = v->val_obj;
484        for (const auto & pair : val_obj) {
485            unordered[pair.first] = pair.second;
486        }
487    }
488    value_object_t(const std::map<value, value> & obj) {
489        for (const auto & pair : obj) {
490            insert(pair.first, pair.second);
491        }
492    }
493    value_object_t(const std::vector<std::pair<value, value>> & obj) {
494        for (const auto & pair : obj) {
495            insert(pair.first, pair.second);
496        }
497    }
498    void insert(const std::string & key, const value & val) {
499        insert(mk_val<value_string>(key), val);
500    }
501    virtual std::string type() const override { return "Object"; }
502    virtual bool is_immutable() const override { return false; }
503    virtual const std::vector<std::pair<value, value>> & as_ordered_object() const override { return val_obj; }
504    virtual string as_string() const override {
505        std::ostringstream ss;
506        ss << "{";
507        for (size_t i = 0; i < val_obj.size(); i++) {
508            if (i > 0) ss << ", ";
509            auto & [key, val] = val_obj.at(i);
510            ss << value_to_string_repr(key) << ": " << value_to_string_repr(val);
511        }
512        ss << "}";
513        return ss.str();
514    }
515    virtual bool as_bool() const override {
516        return !unordered.empty();
517    }
518    virtual bool has_key(const value & key) override {
519        if (!key->is_immutable() || !key->is_hashable()) {
520            throw std::runtime_error("Object key of unhashable type: " + key->type());
521        }
522        return unordered.find(key) != unordered.end();
523    }
524    virtual void insert(const value & key, const value & val) override {
525        bool replaced = false;
526        if (is_immutable()) {
527            throw std::runtime_error("Attempting to modify immutable type");
528        }
529        if (has_key(key)) {
530            // if key exists, replace value in ordered list instead of appending
531            for (auto & pair : val_obj) {
532                if (*(pair.first) == *key) {
533                    pair.second = val;
534                    replaced = true;
535                    break;
536                }
537            }
538        }
539        unordered[key] = val;
540        if (!replaced) {
541            val_obj.push_back({key, val});
542        }
543    }
544    virtual value & at(const value & key, value & default_val) override {
545        if (!has_key(key)) {
546            return default_val;
547        }
548        return unordered.at(key);
549    }
550    virtual value & at(const value & key) override {
551        if (!has_key(key)) {
552            throw std::runtime_error("Key '" + key->as_string().str() + "' not found in value of type " + type());
553        }
554        return unordered.at(key);
555    }
556    virtual value & at(const std::string & key, value & default_val) override {
557        value key_val = mk_val<value_string>(key);
558        return at(key_val, default_val);
559    }
560    virtual value & at(const std::string & key) override {
561        value key_val = mk_val<value_string>(key);
562        return at(key_val);
563    }
564    virtual const func_builtins & get_builtins() const override;
565    virtual bool is_hashable() const override {
566        if (std::all_of(val_obj.begin(), val_obj.end(), [&](auto & pair) -> bool {
567            const auto & val = pair.second;
568            return val->is_immutable() && val->is_hashable();
569        })) {
570            return true;
571        }
572        return false;
573    }
574    virtual hasher unique_hash() const noexcept override {
575        auto hash = hasher(typeid(*this));
576        for (const auto & [key, val] : val_obj) {
577            // must use digest to prevent problems from "concatenation" property of hasher
578            // for ex. hash of key="ab", value="c" should be different from key="a", value="bc"
579            const size_t key_hash = key->unique_hash().digest();
580            const size_t val_hash = val->unique_hash().digest();
581            hash.update(&key_hash, sizeof(key_hash));
582            hash.update(&val_hash, sizeof(val_hash));
583        }
584        return hash;
585    }
586protected:
587    virtual bool equivalent(const value_t & other) const override {
588        return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
589    }
590};
591using value_object = std::shared_ptr<value_object_t>;
592
593//
594// none and undefined types
595//
596
597struct value_none_t : public value_t {
598    virtual std::string type() const override { return "None"; }
599    virtual bool is_none() const override { return true; }
600    virtual bool as_bool() const override { return false; }
601    virtual string as_string() const override { return string(type()); }
602    virtual std::string as_repr() const override { return type(); }
603    virtual const func_builtins & get_builtins() const override;
604    virtual bool is_hashable() const override { return true; }
605    virtual hasher unique_hash() const noexcept override {
606        return hasher(typeid(*this));
607    }
608protected:
609    virtual bool equivalent(const value_t & other) const override {
610        return typeid(*this) == typeid(other);
611    }
612};
613using value_none = std::shared_ptr<value_none_t>;
614
615struct value_undefined_t : public value_t {
616    std::string hint; // for debugging, to indicate where undefined came from
617    value_undefined_t(const std::string & h = "") : hint(h) {}
618    virtual std::string type() const override { return hint.empty() ? "Undefined" : "Undefined (hint: '" + hint + "')"; }
619    virtual bool is_undefined() const override { return true; }
620    virtual bool as_bool() const override { return false; }
621    virtual std::string as_repr() const override { return type(); }
622    virtual const func_builtins & get_builtins() const override;
623    virtual hasher unique_hash() const noexcept override {
624        return hasher(typeid(*this));
625    }
626protected:
627    virtual bool equivalent(const value_t & other) const override {
628        return is_undefined() == other.is_undefined();
629    }
630};
631using value_undefined = std::shared_ptr<value_undefined_t>;
632
633//
634// function type
635//
636
637struct func_args {
638public:
639    std::string func_name; // for error messages
640    context & ctx;
641    func_args(context & ctx) : ctx(ctx) {}
642    value get_kwarg(const std::string & key, value default_val) const;
643    value get_kwarg_or_pos(const std::string & key, size_t pos) const;
644    value get_pos(size_t pos) const;
645    value get_pos(size_t pos, value default_val) const;
646    const std::vector<value> & get_args() const;
647    size_t count() const { return args.size(); }
648    void push_back(const value & val);
649    void push_front(const value & val);
650    void ensure_count(size_t min, size_t max = 999) const {
651        size_t n = args.size();
652        if (n < min || n > max) {
653            throw std::runtime_error("Function '" + func_name + "' expected between " + std::to_string(min) + " and " + std::to_string(max) + " arguments, got " + std::to_string(n));
654        }
655    }
656    template<typename T> void ensure_val(const value & ptr) const {
657        if (!is_val<T>(ptr)) {
658            throw std::runtime_error("Function '" + func_name + "' expected value of type " + std::string(typeid(T).name()) + ", got " + ptr->type());
659        }
660    }
661    void ensure_count(bool require0, bool require1, bool require2, bool require3) const {
662        static auto bool_to_int = [](bool b) { return b ? 1 : 0; };
663        size_t required = bool_to_int(require0) + bool_to_int(require1) + bool_to_int(require2) + bool_to_int(require3);
664        ensure_count(required);
665    }
666    template<typename T0> void ensure_vals(bool required0 = true) const {
667        ensure_count(required0, false, false, false);
668        if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
669    }
670    template<typename T0, typename T1> void ensure_vals(bool required0 = true, bool required1 = true) const {
671        ensure_count(required0, required1, false, false);
672        if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
673        if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
674    }
675    template<typename T0, typename T1, typename T2> void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true) const {
676        ensure_count(required0, required1, required2, false);
677        if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
678        if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
679        if (required2 && args.size() > 2) ensure_val<T2>(args[2]);
680    }
681    template<typename T0, typename T1, typename T2, typename T3> void ensure_vals(bool required0 = true, bool required1 = true, bool required2 = true, bool required3 = true) const {
682        ensure_count(required0, required1, required2, required3);
683        if (required0 && args.size() > 0) ensure_val<T0>(args[0]);
684        if (required1 && args.size() > 1) ensure_val<T1>(args[1]);
685        if (required2 && args.size() > 2) ensure_val<T2>(args[2]);
686        if (required3 && args.size() > 3) ensure_val<T3>(args[3]);
687    }
688private:
689    std::vector<value> args;
690};
691
692struct value_func_t : public value_t {
693    std::string name;
694    value arg0; // bound "this" argument, if any
695    value_func_t(const std::string & name, const func_handler & func) : name(name) {
696        val_func = func;
697    }
698    value_func_t(const std::string & name, const func_handler & func, const value & arg_this) : name(name), arg0(arg_this) {
699        val_func = func;
700    }
701    virtual value invoke(const func_args & args) const override {
702        func_args new_args(args); // copy
703        new_args.func_name = name;
704        if (arg0) {
705            new_args.push_front(arg0);
706        }
707        return val_func(new_args);
708    }
709    virtual std::string type() const override { return "Function"; }
710    virtual std::string as_repr() const override { return type() + "<" + name + ">(" + (arg0 ? arg0->as_repr() : "") + ")"; }
711    virtual bool is_hashable() const override { return false; }
712    virtual hasher unique_hash() const noexcept override {
713        // Note: this is unused for now, we don't support function as object keys
714        // use function pointer as unique identifier
715        const auto target = val_func.target<func_hptr>();
716        return hasher(typeid(*this)).update(&target, sizeof(target));
717    }
718protected:
719    virtual bool equivalent(const value_t & other) const override {
720        // Note: this is unused for now, we don't support function as object keys
721        // compare function pointers
722        // (val_func == other.val_func does not work as std::function::operator== is only used for nullptr check)
723        const auto target_this  = this->val_func.target<func_hptr>();
724        const auto target_other = other.val_func.target<func_hptr>();
725        return typeid(*this) == typeid(other) && target_this == target_other;
726    }
727};
728using value_func = std::shared_ptr<value_func_t>;
729
730// special value for kwarg
731struct value_kwarg_t : public value_t {
732    std::string key;
733    value val;
734    value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {}
735    virtual std::string type() const override { return "KwArg"; }
736    virtual std::string as_repr() const override { return type(); }
737    virtual bool is_hashable() const override { return true; }
738    virtual hasher unique_hash() const noexcept override {
739        const auto type_hash = typeid(*this).hash_code();
740        auto hash = val->unique_hash();
741        hash.update(&type_hash, sizeof(type_hash))
742            .update(key.data(), key.size());
743        return hash;
744    }
745protected:
746    virtual bool equivalent(const value_t & other) const override {
747        const value_kwarg_t & other_val = static_cast<const value_kwarg_t &>(other);
748        return typeid(*this) == typeid(other) && key == other_val.key && val == other_val.val;
749    }
750};
751using value_kwarg = std::shared_ptr<value_kwarg_t>;
752
753
754} // namespace jinja