summaryrefslogtreecommitdiff
path: root/llama.cpp/common/jinja/caps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/common/jinja/caps.cpp')
-rw-r--r--llama.cpp/common/jinja/caps.cpp285
1 files changed, 285 insertions, 0 deletions
diff --git a/llama.cpp/common/jinja/caps.cpp b/llama.cpp/common/jinja/caps.cpp
new file mode 100644
index 0000000..dbaaed5
--- /dev/null
+++ b/llama.cpp/common/jinja/caps.cpp
@@ -0,0 +1,285 @@
+#include "value.h"
+#include "runtime.h"
+#include "caps.h"
+
+// note: the json dependency is only for defining input in a convenient way
+// we can remove it in the future when we figure out a better way to define inputs using jinja::value
+#include <nlohmann/json.hpp>
+
+#include <functional>
+#include <sstream>
+
+#define FILENAME "jinja-caps"
+
+using json = nlohmann::ordered_json;
+
+namespace jinja {
+
+using caps_json_fn = std::function<json()>;
+using caps_analyze_fn = std::function<void(bool, value &, value &)>;
+
+static void caps_try_execute(jinja::program & prog,
+ const caps_json_fn & messages_fn,
+ const caps_json_fn & tools_fn,
+ const caps_analyze_fn & analyze_fn) {
+ context ctx;
+ ctx.is_get_stats = true;
+ jinja::global_from_json(ctx, json{
+ {"messages", messages_fn()},
+ {"tools", tools_fn()},
+ {"bos_token", ""},
+ {"eos_token", ""},
+ {"add_generation_prompt", true}
+ }, true);
+
+ auto messages = ctx.get_val("messages");
+ auto tools = ctx.get_val("tools");
+
+ bool success = false;
+ try {
+ jinja::runtime runtime(ctx);
+ runtime.execute(prog);
+ success = true;
+ } catch (const std::exception & e) {
+ JJ_DEBUG("Exception during execution: %s", e.what());
+ // ignore exceptions during capability analysis
+ }
+
+ analyze_fn(success, messages, tools);
+}
+
+// for debugging only
+static void caps_print_stats(value & v, const std::string & path) {
+ std::string ops;
+ for (const auto & name : v->stats.ops) {
+ ops += name + " ";
+ }
+ JJ_DEBUG("Value %s, type: %s %s, ops: %s",
+ path.c_str(),
+ v->type().c_str(),
+ v->stats.used ? "(used)" : "",
+ ops.c_str());
+}
+
+std::map<std::string, bool> caps::to_map() const {
+ return {
+ {"supports_string_content", supports_string_content},
+ {"supports_typed_content", supports_typed_content},
+ {"supports_tools", supports_tools},
+ {"supports_tool_calls", supports_tool_calls},
+ {"supports_parallel_tool_calls", supports_parallel_tool_calls},
+ {"supports_system_role", supports_system_role},
+ {"supports_preserve_reasoning", supports_preserve_reasoning},
+ };
+}
+
+std::string caps::to_string() const {
+ std::ostringstream ss;
+ ss << "Caps(\n";
+ for (const auto & [key, value] : to_map()) {
+ ss << " " << key << "=" << (value ? "true" : "false") << "\n";
+ }
+ ss << ")";
+ return ss.str();
+}
+
+caps caps_get(jinja::program & prog) {
+ caps result;
+
+ static const auto has_op = [](value & v, const std::string & op_name) {
+ return v->stats.ops.find(op_name) != v->stats.ops.end();
+ };
+
+ // case: typed content support
+ caps_try_execute(
+ prog,
+ [&]() {
+ // messages
+ return json::array({
+ {
+ {"role", "user"},
+ {"content", "content"}
+ }
+ });
+ },
+ [&]() {
+ // tools
+ return json{nullptr};
+ },
+ [&](bool success, value & messages, value &) {
+ auto & content = messages->at(0)->at("content");
+ caps_print_stats(content, "messages[0].content");
+ if (has_op(content, "selectattr") || has_op(content, "array_access")) {
+ // accessed as an array
+ result.supports_typed_content = true;
+ }
+ if (!success) {
+ // failed to execute with content as string
+ result.supports_string_content = false;
+ }
+ }
+ );
+
+
+ // case: system prompt support
+ caps_try_execute(
+ prog,
+ [&]() {
+ // messages
+ return json::array({
+ {
+ {"role", "system"},
+ {"content", "System message"}
+ },
+ {
+ {"role", "user"},
+ {"content", "User message"}
+ },
+ });
+ },
+ [&]() {
+ // tools
+ return json::array();
+ },
+ [&](bool, value & messages, value &) {
+ auto & content = messages->at(0)->at("content");
+ caps_print_stats(content, "messages[0].content");
+ if (!content->stats.used) {
+ result.supports_system_role = false;
+ }
+ }
+ );
+
+ // case: tools support
+ caps_try_execute(
+ prog,
+ [&]() {
+ // messages
+ return json::array({
+ {
+ {"role", "user"},
+ {"content", "User message"},
+ },
+ {
+ {"role", "assistant"},
+ {"content", "Assistant message"},
+ {"tool_calls", json::array({
+ {
+ {"id", "call1"},
+ {"type", "function"},
+ {"function", {
+ {"name", "tool1"},
+ {"arguments", {
+ {"arg", "value"}
+ }}
+ }}
+ },
+ {
+ {"id", "call2"},
+ {"type", "function"},
+ {"function", {
+ {"name", "tool2"},
+ {"arguments", {
+ {"arg", "value"}
+ }}
+ }}
+ }
+ })}
+ },
+ {
+ {"role", "user"},
+ {"content", "User message"},
+ },
+ });
+ },
+ [&]() {
+ // tools
+ return json::array({
+ {
+ {"name", "tool"},
+ {"type", "function"},
+ {"function", {
+ {"name", "tool"},
+ {"description", "Tool description"},
+ {"parameters", {
+ {"type", "object"},
+ {"properties", {
+ {"arg", {
+ {"type", "string"},
+ {"description", "Arg description"},
+ }},
+ }},
+ {"required", json::array({ "arg" })},
+ }},
+ }},
+ },
+ });
+ },
+ [&](bool success, value & messages, value & tools) {
+ if (!success) {
+ result.supports_tool_calls = false;
+ result.supports_tools = false;
+ return;
+ }
+
+ auto & tool_name = tools->at(0)->at("function")->at("name");
+ caps_print_stats(tool_name, "tools[0].function.name");
+ if (!tool_name->stats.used) {
+ result.supports_tools = false;
+ }
+
+ auto & tool_calls = messages->at(1)->at("tool_calls");;
+ caps_print_stats(tool_calls, "messages[1].tool_calls");
+ if (!tool_calls->stats.used) {
+ result.supports_tool_calls = false;
+ }
+
+ // check for second tool call usage
+ auto & tool_call_1 = tool_calls->at(1)->at("function");
+ caps_print_stats(tool_call_1, "messages[1].tool_calls[1].function");
+ if (!tool_call_1->stats.used) {
+ result.supports_parallel_tool_calls = false;
+ }
+ }
+ );
+
+ // case: preserve reasoning content in chat history
+ caps_try_execute(
+ prog,
+ [&]() {
+ // messages
+ return json::array({
+ {
+ {"role", "user"},
+ {"content", "User message"}
+ },
+ {
+ {"role", "assistant"},
+ {"content", "Assistant message"},
+ {"reasoning_content", "Reasoning content"}
+ },
+ {
+ {"role", "user"},
+ {"content", "User message"}
+ },
+ });
+ },
+ [&]() {
+ // tools
+ return json::array();
+ },
+ [&](bool, value & messages, value &) {
+ auto & content = messages->at(1)->at("reasoning_content");
+ caps_print_stats(content, "messages[1].reasoning_content");
+ if (content->stats.used) {
+ result.supports_preserve_reasoning = true;
+ }
+ }
+ );
+
+ JJ_DEBUG("%s\n", result.to_string().c_str());
+
+ return result;
+}
+
+} // namespace jinja