1#include "value.h"
  2#include "runtime.h"
  3#include "caps.h"
  4
// Note: the JSON dependency exists only to define the probe inputs conveniently.
// It can be removed once there is a better way to build inputs directly with jinja::value.
  7#include <nlohmann/json.hpp>
  8
  9#include <functional>
 10#include <sstream>
 11
 12#define FILENAME "jinja-caps"
 13
 14using json = nlohmann::ordered_json;
 15
 16namespace jinja {
 17
 18using caps_json_fn = std::function<json()>;
 19using caps_analyze_fn = std::function<void(bool, value &, value &)>;
 20
 21static void caps_try_execute(jinja::program & prog,
 22                             const caps_json_fn & messages_fn,
 23                             const caps_json_fn & tools_fn,
 24                             const caps_analyze_fn & analyze_fn) {
 25    context ctx;
 26    ctx.is_get_stats = true;
 27    jinja::global_from_json(ctx, json{
 28        {"messages", messages_fn()},
 29        {"tools", tools_fn()},
 30        {"bos_token", ""},
 31        {"eos_token", ""},
 32        {"add_generation_prompt", true}
 33    }, true);
 34
 35    auto messages = ctx.get_val("messages");
 36    auto tools = ctx.get_val("tools");
 37
 38    bool success = false;
 39    try {
 40        jinja::runtime runtime(ctx);
 41        runtime.execute(prog);
 42        success = true;
 43    } catch (const std::exception & e) {
 44        JJ_DEBUG("Exception during execution: %s", e.what());
 45        // ignore exceptions during capability analysis
 46    }
 47
 48    analyze_fn(success, messages, tools);
 49}
 50
 51// for debugging only
 52static void caps_print_stats(value & v, const std::string & path) {
 53    std::string ops;
 54    for (const auto & name : v->stats.ops) {
 55        ops += name + " ";
 56    }
 57    JJ_DEBUG("Value %s, type: %s %s, ops: %s",
 58                path.c_str(),
 59                v->type().c_str(),
 60                v->stats.used ? "(used)" : "",
 61                ops.c_str());
 62}
 63
 64std::map<std::string, bool> caps::to_map() const {
 65    return {
 66        {"supports_string_content", supports_string_content},
 67        {"supports_typed_content", supports_typed_content},
 68        {"supports_tools", supports_tools},
 69        {"supports_tool_calls", supports_tool_calls},
 70        {"supports_parallel_tool_calls", supports_parallel_tool_calls},
 71        {"supports_system_role", supports_system_role},
 72        {"supports_preserve_reasoning", supports_preserve_reasoning},
 73    };
 74}
 75
 76std::string caps::to_string() const {
 77    std::ostringstream ss;
 78    ss << "Caps(\n";
 79    for (const auto & [key, value] : to_map()) {
 80        ss << "  " << key << "=" << (value ? "true" : "false") << "\n";
 81    }
 82    ss << ")";
 83    return ss.str();
 84}
 85
 86caps caps_get(jinja::program & prog) {
 87    caps result;
 88
 89    static const auto has_op = [](value & v, const std::string & op_name) {
 90        return v->stats.ops.find(op_name) != v->stats.ops.end();
 91    };
 92
 93    // case: typed content support
 94    caps_try_execute(
 95        prog,
 96        [&]() {
 97            // messages
 98            return json::array({
 99                {
100                    {"role", "user"},
101                    {"content", "content"}
102                }
103            });
104        },
105        [&]() {
106            // tools
107            return json{nullptr};
108        },
109        [&](bool success, value & messages, value &) {
110            auto & content = messages->at(0)->at("content");
111            caps_print_stats(content, "messages[0].content");
112            if (has_op(content, "selectattr") || has_op(content, "array_access")) {
113                // accessed as an array
114                result.supports_typed_content = true;
115            }
116            if (!success) {
117                // failed to execute with content as string
118                result.supports_string_content = false;
119            }
120        }
121    );
122
123
124    // case: system prompt support
125    caps_try_execute(
126        prog,
127        [&]() {
128            // messages
129            return json::array({
130                {
131                    {"role", "system"},
132                    {"content", "System message"}
133                },
134                {
135                    {"role", "user"},
136                    {"content", "User message"}
137                },
138            });
139        },
140        [&]() {
141            // tools
142            return json::array();
143        },
144        [&](bool, value & messages, value &) {
145            auto & content = messages->at(0)->at("content");
146            caps_print_stats(content, "messages[0].content");
147            if (!content->stats.used) {
148                result.supports_system_role = false;
149            }
150        }
151    );
152
153    // case: tools support
154    caps_try_execute(
155        prog,
156        [&]() {
157            // messages
158            return json::array({
159                {
160                    {"role", "user"},
161                    {"content", "User message"},
162                },
163                {
164                    {"role", "assistant"},
165                    {"content", "Assistant message"},
166                    {"tool_calls", json::array({
167                        {
168                            {"id", "call1"},
169                            {"type", "function"},
170                            {"function", {
171                                {"name", "tool1"},
172                                {"arguments", {
173                                    {"arg", "value"}
174                                }}
175                            }}
176                        },
177                        {
178                            {"id", "call2"},
179                            {"type", "function"},
180                            {"function", {
181                                {"name", "tool2"},
182                                {"arguments", {
183                                    {"arg", "value"}
184                                }}
185                            }}
186                        }
187                    })}
188                },
189                {
190                    {"role", "user"},
191                    {"content", "User message"},
192                },
193            });
194        },
195        [&]() {
196            // tools
197            return json::array({
198                {
199                    {"name", "tool"},
200                    {"type", "function"},
201                    {"function", {
202                        {"name", "tool"},
203                        {"description", "Tool description"},
204                        {"parameters", {
205                            {"type", "object"},
206                            {"properties", {
207                                {"arg", {
208                                    {"type", "string"},
209                                    {"description", "Arg description"},
210                                }},
211                            }},
212                            {"required", json::array({ "arg" })},
213                        }},
214                    }},
215                },
216            });
217        },
218        [&](bool success, value & messages, value & tools) {
219            if (!success) {
220                result.supports_tool_calls = false;
221                result.supports_tools = false;
222                return;
223            }
224
225            auto & tool_name = tools->at(0)->at("function")->at("name");
226            caps_print_stats(tool_name, "tools[0].function.name");
227            if (!tool_name->stats.used) {
228                result.supports_tools = false;
229            }
230
231            auto & tool_calls = messages->at(1)->at("tool_calls");;
232            caps_print_stats(tool_calls, "messages[1].tool_calls");
233            if (!tool_calls->stats.used) {
234                result.supports_tool_calls = false;
235            }
236
237            // check for second tool call usage
238            auto & tool_call_1 = tool_calls->at(1)->at("function");
239            caps_print_stats(tool_call_1, "messages[1].tool_calls[1].function");
240            if (!tool_call_1->stats.used) {
241                result.supports_parallel_tool_calls = false;
242            }
243        }
244    );
245
246    // case: preserve reasoning content in chat history
247    caps_try_execute(
248        prog,
249        [&]() {
250            // messages
251            return json::array({
252                {
253                    {"role", "user"},
254                    {"content", "User message"}
255                },
256                {
257                    {"role", "assistant"},
258                    {"content", "Assistant message"},
259                    {"reasoning_content", "Reasoning content"}
260                },
261                {
262                    {"role", "user"},
263                    {"content", "User message"}
264                },
265            });
266        },
267        [&]() {
268            // tools
269            return json::array();
270        },
271        [&](bool, value & messages, value &) {
272            auto & content = messages->at(1)->at("reasoning_content");
273            caps_print_stats(content, "messages[1].reasoning_content");
274            if (content->stats.used) {
275                result.supports_preserve_reasoning = true;
276            }
277        }
278    );
279
280    JJ_DEBUG("%s\n", result.to_string().c_str());
281
282    return result;
283}
284
285} // namespace jinja