diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 7086d08e5e5e9..7bdc9aab5995f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -48,6 +48,8 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat-parser-combinator.cpp + chat-parser-combinator.h chat-parser.cpp chat-parser.h chat.cpp diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp new file mode 100644 index 0000000000000..331f00198ea56 --- /dev/null +++ b/common/chat-parser-combinator.cpp @@ -0,0 +1,1509 @@ +#include "chat-parser-combinator.h" +#include "json-schema-to-grammar.h" +#include "common.h" +#include "chat.h" +#include "log.h" + +#include + +#include +#include + +enum parser_type { + PARSER_LITERAL = 0, + PARSER_SEQUENCE = 1, + PARSER_CHOICE = 2, + PARSER_REPETITION = 3, + PARSER_OPTIONAL = 4, + PARSER_ZERO_OR_MORE = 5, + PARSER_ONE_OR_MORE = 6, + PARSER_NOT = 7, + PARSER_ANY = 8, + PARSER_CHARS = 9, + PARSER_RULE = 10, + PARSER_UNTIL = 11, + PARSER_SPACE = 12, + PARSER_SCHEMA = 13, + PARSER_ROOT = 14, + PARSER_JSON_STRING = 15, + PARSER_ACTION = 16, +}; + +class parser_visitor; + +class parser_base { + protected: + int id_; + + public: + parser_base(int id) : id_(id) {} + virtual ~parser_base() = default; + + int id() const { return id_; } + void set_id(int id) { id_ = id; } + + virtual parser_type type() const = 0; + + // Template Method: handles caching, delegates to parse_uncached() + virtual parser_result parse(parser_context & ctx, size_t start = 0) { + if (id_ == -1) { + // Don't cache parsers with ID -1 (from operators) + return parse_uncached(ctx, start); + } + + // Check cache + auto cached = ctx.memo.get(id_, start); + if (cached) { + return *cached; + } + + // Execute and cache + auto result = parse_uncached(ctx, start); + return ctx.memo.set(id_, start, result); + } + + // Actual parsing implementation (to be overridden by subclasses) + virtual parser_result parse_uncached(parser_context & ctx, size_t start = 0) = 0; + + virtual void assign_id(std::shared_ptr counter) { + if (id_ == -1) { + id_ = counter->next(); + } + } + + virtual std::string dump() const = 0; + virtual void accept(parser_visitor & visitor) = 0; +}; + +// We define our own space function because MSVC's std::isspace() +// crashes for non-printable characters in Debug builds. +static bool is_space(const char c) { + return (c == ' ' || c == '\t' || c == '\n'); +} + +static bool is_hex_digit(const char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); +} + +// Unescapes a JSON string (without the surrounding quotes) +// Uses nlohmann::json::parse to handle all JSON escape sequences +static std::string unescape_json_string(std::string_view str) { + try { + // Wrap in quotes and parse as JSON string + std::string quoted = "\"" + std::string(str) + "\""; + auto parsed = nlohmann::json::parse(quoted); + if (parsed.is_string()) { + return parsed.get(); + } + } catch (...) { + // If parsing fails, return original string + } + return std::string(str); +} + +// Matches an exact literal string. +// S -> "hello" +class literal_parser : public parser_base { + std::string literal_; + + public: + literal_parser(const std::string & literal, int id) : parser_base(id), literal_(literal) {} + + parser_type type() const override { return PARSER_LITERAL; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + for (auto i = 0u; i < literal_.size(); ++i) { + if (pos >= ctx.input.size()) { + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + if (i > 0) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + return parser_result(PARSER_RESULT_FAIL, start, pos); + } + if (ctx.input[pos] != literal_[i]) { + return parser_result(PARSER_RESULT_FAIL, start); + } + ++pos; + } + + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + } + + std::string dump() const override { + return "Literal(" + literal_ + ")"; + } + + void accept(parser_visitor & visitor) override; + + const std::string & literal() const { return literal_; } +}; + +// Matches a sequence of parsers in order, all must succeed. +// S -> A B C +class sequence_parser : public parser_base { + std::vector parsers_; + + public: + sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { + for (const auto & p : parsers) { + if (p->type() == PARSER_SEQUENCE) { + // Flatten sequences + auto seq = std::static_pointer_cast(p.ptr()); + for (const auto & embedded : seq->parsers()) { + parsers_.push_back(embedded); + } + } else { + parsers_.push_back(p); + } + } + } + + parser_type type() const override { return PARSER_SEQUENCE; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); + + if (result.is_fail()) { + if (result.end >= ctx.input.size() && !ctx.input_is_complete) { + // If we fail because we don't have enough input, then return success + return parser_result(PARSER_RESULT_SUCCESS, start, result.end); + } + return parser_result(PARSER_RESULT_FAIL, start, result.end); + } + + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end); + } + + pos = result.end; + } + + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + } + + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + for (auto & p : parsers_) { + p->assign_id(counter); + } + } + + std::string dump() const override { + std::vector parts; + parts.reserve(parsers_.size()); + for (const auto & p : parsers_) { + parts.push_back(p->dump()); + } + return "Sequence(" + string_join(parts, ", ") + ")"; + } + + void accept(parser_visitor & visitor) override; + + const std::vector & parsers() const { return parsers_; } +}; + +// Matches the first parser that succeeds from a list of alternatives. +// S -> A | B | C +class choice_parser : public parser_base { + std::vector parsers_; + + public: + choice_parser(std::initializer_list parsers, int id) : parser_base(id) { + for (const auto & p : parsers) { + if (p->type() == PARSER_CHOICE) { + // Flatten choices + auto choice = std::static_pointer_cast(p.ptr()); + for (const auto & embedded : choice->parsers()) { + parsers_.push_back(embedded); + } + } else { + parsers_.push_back(p); + } + } + } + + parser_type type() const override { return PARSER_CHOICE; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); + + if (result.is_success()) { + return result; + } + + if (result.is_need_more_input()) { + return result; + } + } + + return parser_result(PARSER_RESULT_FAIL, start); + } + + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + for (auto & p : parsers_) { + p->assign_id(counter); + } + } + + std::string dump() const override { + std::vector parts; + parts.reserve(parsers_.size()); + for (const auto & p : parsers_) { + parts.push_back(p->dump()); + } + return "Choice(" + string_join(parts, ", ") + ")"; + } + + void accept(parser_visitor & visitor) override; + + const std::vector & parsers() const { return parsers_; } +}; + +// Matches between min and max repetitions of a parser (inclusive). +// S -> A{m,n} +// Use -1 for max_count to represent unbounded repetition (equivalent to {m,}) +class repetition_parser : public parser_base { + parser parser_; + int min_count_; + int max_count_; + + public: + repetition_parser(const parser & parser, int min_count, int max_count, int id) + : parser_base(id), parser_(parser), min_count_(min_count), max_count_(max_count) {} + + parser_type type() const override { return PARSER_REPETITION; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + int match_count = 0; + + // Try to match up to max_count times (or unlimited if max_count is -1) + while (max_count_ == -1 || match_count < max_count_) { + auto result = parser_->parse(ctx, pos); + + if (result.is_success()) { + // Prevent infinite loop on empty matches + if (result.end == pos) { + break; + } + pos = result.end; + match_count++; + continue; + } + + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + + // Child failed - stop trying + break; + } + + // Check if we got enough matches + if (match_count < min_count_) { + return parser_result(PARSER_RESULT_FAIL, start, pos); + } + + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + } + + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + parser_->assign_id(counter); + } + + std::string dump() const override { + if (max_count_ == -1) { + return "Repetition(" + parser_->dump() + ", " + std::to_string(min_count_) + ", unbounded)"; + } + return "Repetition(" + parser_->dump() + ", " + std::to_string(min_count_) + ", " + std::to_string(max_count_) + ")"; + } + + void accept(parser_visitor & visitor) override; + + const parser & child() const { return parser_; } + + int min_count() const { return min_count_; } + + int max_count() const { return max_count_; } +}; + +// Matches one or more repetitions of a parser. +// S -> A+ +class one_or_more_parser : public repetition_parser { + public: + one_or_more_parser(const parser & p, int id) : repetition_parser(p, 1, -1, id) {} + + parser_type type() const override { return PARSER_ONE_OR_MORE; } + + std::string dump() const override { + return "OneOrMore(" + child()->dump() + ")"; + } + + void accept(parser_visitor & visitor) override; +}; + +// Matches zero or more repetitions of a parser, always succeeds. +// S -> A* +class zero_or_more_parser : public repetition_parser { + public: + zero_or_more_parser(const parser & p, int id) : repetition_parser(p, 0, -1, id) {} + + parser_type type() const override { return PARSER_ZERO_OR_MORE; } + + std::string dump() const override { + return "ZeroOrMore(" + child()->dump() + ")"; + } + + void accept(parser_visitor & visitor) override; +}; + +// Matches zero or one occurrence of a parser, always succeeds. +// S -> A? +class optional_parser : public repetition_parser { + public: + optional_parser(const parser & p, int id) : repetition_parser(p, 0, 1, id) {} + + parser_type type() const override { return PARSER_OPTIONAL; } + + std::string dump() const override { + return "Optional(" + child()->dump() + ")"; + } + + void accept(parser_visitor & visitor) override; +}; + +// Negative lookahead: succeeds if child parser fails, consumes no input. +// S -> !A +class not_parser : public parser_base { + parser parser_; + + public: + not_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + + parser_type type() const override { return PARSER_NOT; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto result = parser_->parse(ctx, start); + + if (result.is_success()) { + // Fail if the underlying parser matches + return parser_result(PARSER_RESULT_FAIL, start); + } + + if (result.is_need_more_input()) { + // Propagate - need to know what child would match before negating + return result; + } + + // Child failed, so negation succeeds + return parser_result(PARSER_RESULT_SUCCESS, start); + } + + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + parser_->assign_id(counter); + } + + std::string dump() const override { + return "Not(" + parser_->dump() + ")"; + } + + void accept(parser_visitor & visitor) override; + + const parser & child() const { return parser_; } +}; + +// Matches any single character. +// S -> . +class any_parser : public parser_base { + public: + any_parser(int id) : parser_base(id) {} + + parser_type type() const override { return PARSER_ANY; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + return parser_result(PARSER_RESULT_FAIL, start); + } + + return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); + } + + std::string dump() const override { + return "Any"; + } + + void accept(parser_visitor & visitor) override; +}; + +// Matches zero or more whitespace characters (space, tab, newline). +// S -> [ \t\n]* +class space_parser : public parser_base { + public: + space_parser(int id) : parser_base(id) {} + + parser_type type() const override { return PARSER_SPACE; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + if (is_space(c)) { + ++pos; + } else { + break; + } + } + + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + } + + std::string dump() const override { + return "Space"; + } + + void accept(parser_visitor & visitor) override; +}; + +// Matches between min and max repetitions of characters from a character class. +// S -> [a-z]{m,n} +class chars_parser : public parser_base { + struct char_range { + int start; + int end; + + bool contains(char c) const { return (int)c >= start && int(c) <= end; } + }; + + std::string pattern_; + std::vector ranges_; + bool negated_; + int min_count_; + int max_count_; + + public: + chars_parser(const std::string & classes, int min_count, int max_count, int id) + : parser_base(id), pattern_(classes), negated_(false), min_count_(min_count), max_count_(max_count) { + + std::string content = classes; + if (content.front() == '[') { + content = content.substr(1); + } + + if (content.back() == ']') { + content.pop_back(); + } + + // Check for negation + if (!content.empty() && content.front() == '^') { + negated_ = true; + content = content.substr(1); + } + + auto parse_char = [&](size_t pos) -> std::pair { + if (content[pos] == '\\' && pos + 1 < content.length()) { + char next = content[pos + 1]; + switch (next) { + case 'n': return {'\n', 2}; + case 't': return {'\t', 2}; + case 'r': return {'\r', 2}; + case '\\': return {'\\', 2}; + case ']': return {']', 2}; + case '-': return {'-', 2}; + case '[': return {'[', 2}; + default: return {next, 2}; // Treat as literal escaped character + } + } + return {content[pos], 1}; + }; + + size_t i = 0; + while (i < content.length()) { + auto [start, start_len] = parse_char(i); + i += start_len; + + if (i + 1 < content.length() && content[i] == '-') { + // Range detected + auto [end, end_len] = parse_char(i + 1); + ranges_.push_back(char_range{start, end}); + i += 1 + end_len; + } else { + ranges_.push_back(char_range{start, start}); + } + } + } + + parser_type type() const override { return PARSER_CHARS; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + int match_count = 0; + + // Try to match up to max_count times (or unlimited if max_count is -1) + while (max_count_ == -1 || match_count < max_count_) { + if (pos >= ctx.input.size()) { + break; + } + + bool matches = false; + for (const auto & range : ranges_) { + if (range.contains(ctx.input[pos])) { + matches = true; + break; + } + } + + // If negated, invert the match result + if (negated_) { + matches = !matches; + } + + if (matches) { + ++pos; + ++match_count; + } else { + break; + } + } + + // Check if we got enough matches + if (match_count < min_count_) { + return parser_result(PARSER_RESULT_FAIL, start); + } + + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + } + + std::string dump() const override { + if (max_count_ == -1) { + return "CharRepeat(" + pattern_ + ", " + std::to_string(min_count_) + ", unbounded)"; + } + return "CharRepeat(" + pattern_ + ", " + std::to_string(min_count_) + ", " + std::to_string(max_count_) + ")"; + } + + void accept(parser_visitor & visitor) override; + + const std::string & pattern() const { return pattern_; } + + int min_count() const { return min_count_; } + + int max_count() const { return max_count_; } +}; + +// Specialized parser for JSON string content (without quotes). +// Parses the content between quotes with single-pass streaming support. +// Stops before the closing quote (doesn't consume it). +// Handles escape sequences and emits NEED_MORE_INPUT for incomplete input. +// S -> (regular chars and escape sequences)* until closing " +class json_string_parser : public parser_base { + + public: + json_string_parser(int id) : parser_base(id) {} + + parser_type type() const override { return PARSER_JSON_STRING; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + + // Parse string content (without quotes) + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + + if (c == '"') { + // Found closing quote - success (don't consume it) + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + } + + if (c == '\\') { + // Handle escape sequence + ++pos; + if (pos >= ctx.input.size()) { + // Mid-escape sequence + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + + char escape = ctx.input[pos]; + switch (escape) { + case '"': + case '\\': + case '/': + case 'b': + case 'f': + case 'n': + case 'r': + case 't': + // Valid escape + ++pos; + break; + + case 'u': + // Unicode escape: must be followed by 4 hex digits + ++pos; + for (int i = 0; i < 4; ++i) { + if (pos >= ctx.input.size()) { + // Incomplete unicode escape + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + if (!is_hex_digit(ctx.input[pos])) { + return parser_result(PARSER_RESULT_FAIL, start); + } + ++pos; + } + break; + + default: + // Invalid escape sequence + return parser_result(PARSER_RESULT_FAIL, start); + } + } else { + // Regular character + ++pos; + } + } + + // Reached end without finding closing quote + return parser_result(PARSER_RESULT_FAIL, start, pos); + } + + std::string dump() const override { + return "JsonString()"; + } + + void accept(parser_visitor & visitor) override; +}; + +// Matches all characters until a delimiter is found (delimiter not consumed). +// S -> (!delim .)* +class until_parser : public parser_base { + std::string delimiter_; + bool consume_spaces_; + + std::default_searcher searcher_; + + public: + until_parser(const std::string & delimiter, bool consume_spaces, int id) + : parser_base(id), delimiter_(delimiter), consume_spaces_(consume_spaces), searcher_(delimiter_.begin(), delimiter_.end()) { + } + + parser_type type() const override { return PARSER_UNTIL; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + parser_result result(PARSER_RESULT_SUCCESS, start, ctx.input.size()); + + // Search for the delimiter + const auto it = std::search(ctx.input.begin(), ctx.input.end(), searcher_); + + if (it != ctx.input.end()) { + result.type = PARSER_RESULT_SUCCESS; + result.end = std::distance(ctx.input.begin(), it); + } else { + // If not found, check if the input ends with a prefix of the delimiter + size_t max_overlap = std::min(ctx.input.size(), delimiter_.size() - 1); + for (size_t overlap = max_overlap; overlap > 0; --overlap) { + if (std::equal(ctx.input.end() - overlap, ctx.input.end(), delimiter_.begin())) { + result.type = (ctx.input_is_complete) ? PARSER_RESULT_FAIL : PARSER_RESULT_NEED_MORE_INPUT; + result.end = ctx.input.size() - overlap; + } + } + } + + if (consume_spaces_) { + // Remove trailing spaces + while (result.end > start && is_space(ctx.input[result.end - 1])) { + result.end--; + } + } + + return result; + } + + std::string dump() const override { + return "Until(" + delimiter_ + ")"; + } + + void accept(parser_visitor & visitor) override; + + const std::string & delimiter() const { return delimiter_; } +}; + +// Wraps a parser with JSON schema metadata for grammar generation. +// Used internally to convert JSON schemas to GBNF grammar rules. +class schema_parser : public parser_base { + parser parser_; + std::string name_; + nlohmann::ordered_json schema_; + + public: + schema_parser(const parser & parser, const std::string & name, const nlohmann::ordered_json & schema, int id) + : parser_base(id), parser_(parser), name_(name), schema_(schema) {} + + parser_type type() const override { return PARSER_SCHEMA; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + return parser_->parse(ctx, start); + } + + std::string dump() const override { + return "Schema(" + parser_->dump() + ", " + schema_.dump() + ")"; + } + + void accept(parser_visitor & visitor) override; + + const parser & child() const { return parser_; } + + const std::string & name() const { return name_; } + + const nlohmann::ordered_json & schema() const { return schema_; } +}; + +// References a named rule for recursive or reusable grammar definitions. +// expr -> term | expr "+" term +class rule_parser : public parser_base { + std::string name_; + std::weak_ptr> rules_; + + public: + rule_parser(const std::string & name, const std::shared_ptr> & rules, int id) + : parser_base(id), name_(name), rules_(rules) {} + + parser_type type() const override { return PARSER_RULE; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto rules = rules_.lock(); + if (!rules) { + LOG_ERR("rule_parser::parse called with expired rule registry\n"); + return parser_result(PARSER_RESULT_FAIL, start); + } + + auto it = rules->find(name_); + if (it == rules->end()) { + LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", name_.c_str()); + return parser_result(PARSER_RESULT_FAIL, start); + } + + return it->second->parse(ctx, start); + } + + std::string dump() const override { + return "Rule(" + name_ + ")"; + } + + void accept(parser_visitor & visitor) override; + + const std::string & name() const { return name_; } +}; + +// Container for the root parser and all named rules in the grammar. +// Manages ownership of rule registry to enable recursive grammar definitions. +class root_parser : public parser_base { + parser root_; + std::shared_ptr> rules_; + + friend class parser_visitor; + + public: + root_parser(const parser & root, std::shared_ptr> rules, int id) + : parser_base(id), root_(root), rules_(std::move(rules)) {} + + parser_type type() const override { return PARSER_ROOT; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + return root_->parse(ctx, start); + } + + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + root_->assign_id(counter); + } + + std::string dump() const override { + return root_->dump(); + } + + void accept(parser_visitor & visitor) override; + + const parser & root() const { return root_; } + + std::shared_ptr> rules() const { return rules_; } +}; + +// Wraps a parser with a semantic action callback. +class action_parser : public parser_base { + parser parser_; + std::function action_; + + public: + action_parser(const parser & parser, std::function action, int id) + : parser_base(id), parser_(parser), action_(std::move(action)) {} + + parser_type type() const override { return PARSER_ACTION; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto result = parser_->parse(ctx, start); + + // Invoke action callback on success if environment is available + if (result.is_success() && ctx.env && action_) { + std::string_view matched = ctx.input.substr(result.start, result.end - result.start); + action_(result, matched, *ctx.env); + } + + return result; + } + + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + parser_->assign_id(counter); + } + + std::string dump() const override { + return "Action(" + parser_->dump() + ")"; + } + + void accept(parser_visitor & visitor) override; + + const parser & child() const { return parser_; } +}; + +// Base visitor class for parser tree traversal +class parser_visitor { + public: + virtual ~parser_visitor() = default; + + virtual void visit(literal_parser & p) = 0; + virtual void visit(sequence_parser & p) = 0; + virtual void visit(choice_parser & p) = 0; + virtual void visit(one_or_more_parser & p) = 0; + virtual void visit(zero_or_more_parser & p) = 0; + virtual void visit(optional_parser & p) = 0; + virtual void visit(repetition_parser & p) = 0; + virtual void visit(until_parser & p) = 0; + virtual void visit(not_parser & p) = 0; + virtual void visit(any_parser & p) = 0; + virtual void visit(space_parser & p) = 0; + virtual void visit(chars_parser & p) = 0; + virtual void visit(json_string_parser & p) = 0; + virtual void visit(schema_parser & p) = 0; + virtual void visit(rule_parser & p) = 0; + virtual void visit(root_parser & p) = 0; + virtual void visit(action_parser & p) = 0; +}; + +class gbnf_visitor : public parser_visitor { + const common_grammar_builder & builder_; + std::unordered_map rule_name_mapping_; + std::string current_result_; + + public: + gbnf_visitor(const common_grammar_builder & builder) : builder_(builder) {} + + const std::string& result() const { return current_result_; } + + private: + // Escape special characters for GBNF literals + static std::string escape_literal(const std::string & s) { + std::string escaped; + for (char c : s) { + switch (c) { + case '\n': escaped += "\\n"; break; + case '\t': escaped += "\\t"; break; + case '\r': escaped += "\\r"; break; + case '\\': escaped += "\\\\"; break; + case '"': escaped += "\\\""; break; + default: escaped += c; break; + } + } + return escaped; + } + + // Escape a single character for use in character classes + static std::string escape_char_class(char c) { + switch (c) { + case '\n': return "\\n"; + case '\t': return "\\t"; + case '\r': return "\\r"; + case '\\': return "\\\\"; + case ']': return "\\]"; + case '-': return "\\-"; + case '^': return "\\^"; + default: return std::string(1, c); + } + } + + // Generate pattern for until() that matches prefixes but prevents full delimiter match + // For "" generates: ( [^<] | "<" [^/] | " alternatives; + + // First alternative: match any character that's not the start of the delimiter + alternatives.push_back("[^" + escape_char_class(delimiter[0]) + "]"); + + // For each prefix, match the prefix followed by a char that's not the next delimiter char + for (size_t i = 1; i < delimiter.length(); ++i) { + std::string prefix = "\"" + escape_literal(delimiter.substr(0, i)) + "\""; + std::string next_char_negated = "[^" + escape_char_class(delimiter[i]) + "]"; + alternatives.push_back(prefix + " " + next_char_negated); + } + + // Combine alternatives with | + std::string result = "("; + for (size_t i = 0; i < alternatives.size(); ++i) { + if (i > 0) { + result += " | "; + } + result += alternatives[i]; + } + result += ")"; + + return result; + } + + // Check if expression needs parentheses + static bool needs_parens(parser_type type) { + return type == PARSER_CHOICE || type == PARSER_SEQUENCE; + } + + public: + void visit(literal_parser & p) override { + current_result_ = "\"" + escape_literal(p.literal()) + "\""; + } + + void visit(sequence_parser & p) override { + std::string s; + for (const auto & child : p.parsers()) { + if (!s.empty()) { + s += " "; + } + child->accept(*this); + + // Parenthesize choices + if (needs_parens(child->type())) { + s += "(" + current_result_ + ")"; + } else { + s += current_result_; + } + } + current_result_ = s; + } + + void visit(choice_parser & p) override { + std::string s; + for (const auto & child : p.parsers()) { + if (!s.empty()) { + s += " | "; + } + + child->accept(*this); + + // Parenthesize choices + if (child->type() == PARSER_CHOICE) { + s += "(" + current_result_ + ")"; + } else { + s += current_result_; + } + } + current_result_ = s; + } + + void visit(one_or_more_parser & p) override { + p.child()->accept(*this); + if (needs_parens(p.child()->type())) { + current_result_ = "(" + current_result_ + ")+"; + } else { + current_result_ = current_result_ + "+"; + } + } + + void visit(zero_or_more_parser & p) override { + p.child()->accept(*this); + if (needs_parens(p.child()->type())) { + current_result_ = "(" + current_result_ + ")*"; + } else { + current_result_ = current_result_ + "*"; + } + } + + void visit(optional_parser & p) override { + p.child()->accept(*this); + if (needs_parens(p.child()->type())) { + current_result_ = "(" + current_result_ + ")?"; + } else { + current_result_ = current_result_ + "?"; + } + } + + void visit(repetition_parser & p) override { + p.child()->accept(*this); + std::string child_result = current_result_; + + if (needs_parens(p.child()->type())) { + child_result = "(" + child_result + ")"; + } + + if (p.max_count() == -1) { + // Unbounded: {n,} + current_result_ = child_result + "{" + std::to_string(p.min_count()) + ",}"; + } else { + // Bounded: {n,m} + current_result_ = child_result + "{" + std::to_string(p.min_count()) + "," + + std::to_string(p.max_count()) + "}"; + } + } + + void visit(until_parser & p) override { + // Generate pattern that matches prefixes but prevents full delimiter match + current_result_ = generate_until_pattern(p.delimiter()) + "*"; + } + + void visit(not_parser &) override { + // NOT is tricky in GBNF - for now, emit error + LOG_ERR("NOT operator not directly supported in GBNF generation\n"); + current_result_ = ""; + } + + void visit(any_parser &) override { + // Match any single character + current_result_ = "."; + } + + void visit(space_parser &) override { + // Reference the built-in space rule + current_result_ = "space"; + } + + void visit(chars_parser & p) override { + const std::string & pattern = p.pattern(); + + if (p.min_count() == 0 && p.max_count() == -1) { + // Zero or more: * + current_result_ = pattern + "*"; + } else if (p.min_count() == 1 && p.max_count() == -1) { + // One or more: + + current_result_ = pattern + "+"; + } else if (p.max_count() == -1) { + // Unbounded: {n,} + current_result_ = pattern + "{" + std::to_string(p.min_count()) + ",}"; + } else if (p.min_count() == p.max_count()) { + // Exact count: {n} or just pattern for n=1 + if (p.min_count() == 1) { + current_result_ = pattern; + } else { + current_result_ = pattern + "{" + std::to_string(p.min_count()) + "}"; + } + } else { + // Bounded: {n,m} + current_result_ = pattern + "{" + std::to_string(p.min_count()) + "," + + std::to_string(p.max_count()) + "}"; + } + } + + void visit(json_string_parser &) override { + // JSON string content (without quotes) + // Pattern: (any non-quote/backslash OR escape sequences)* until closing quote + current_result_ = R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; + } + + void visit(schema_parser & p) override { + current_result_ = builder_.add_schema(p.name(), p.schema()); + } + + void visit(rule_parser & p) override { + // Return canonical rule reference + auto it = rule_name_mapping_.find(p.name()); + if (it != rule_name_mapping_.end()) { + current_result_ = it->second; + } else { + // Fallback to original name if not in mapping (shouldn't happen in valid usage) + current_result_ = p.name(); + } + } + + void visit(root_parser & p) override { + // Generate named rules first + auto rules = p.rules(); + if (rules) { + for (const auto & [name, rule] : *rules) { + rule->accept(*this); + auto rule_body = current_result_; + auto canonical_name = builder_.add_rule(name, rule_body); + rule_name_mapping_[name] = canonical_name; + } + } + + // Return root body for composition + p.root()->accept(*this); + } + + void visit(action_parser & p) override { + // Actions are transparent for grammar generation - just visit child + p.child()->accept(*this); + } +}; + +// Implement accept() methods for all parser classes +void literal_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void sequence_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void choice_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void one_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void zero_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void optional_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void repetition_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void until_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void not_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void any_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void space_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void chars_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void json_string_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void schema_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void rule_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void root_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void action_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } + +parser_result parse_cache::set(int id, size_t start, parser_result result) { + if (id == -1) { + // Don't cache parsers with ID -1 (from operators and global factory functions) + return result; + } + results[parse_cache_key{id, start}] = result; + return result; +} + +std::optional parse_cache::get(int id, size_t start) { + if (id == -1) { + // Don't cache parsers with ID -1 (from operators and global factory functions) + return std::nullopt; + } + auto it = results.find(parse_cache_key{id, start}); + if (it != results.end()) { + return it->second; + } + return std::nullopt; +} + +void parse_cache::clear() { + results.clear(); +} + +parser::parser() {} + +parser::parser(std::shared_ptr parser) : ptr_(std::move(parser)) {} + +parser::parser(const std::string & literal) : ptr_(std::make_shared(literal, -1)) {} + +parser::parser(const char * literal) : ptr_(std::make_shared(literal, -1)) {} + +parser parser::operator~() const { + return parser(std::make_shared(*this, -1)); +} + +parser parser::operator+(const parser & other) const { + return parser(std::make_shared(std::initializer_list{*this, other}, -1)); +} + +parser parser::operator|(const parser & other) const { + return parser(std::make_shared(std::initializer_list{*this, other}, -1)); +} + +parser parser::operator<<(const parser & other) const { + auto ws = parser(std::make_shared(-1)); + return parser(std::make_shared(std::initializer_list{*this, ws, other}, -1)); +} + +parser operator+(const char * lhs, const parser & rhs) { return parser(lhs) + rhs; } +parser operator|(const char * lhs, const parser & rhs) { return parser(lhs) | rhs; } +parser operator<<(const char * lhs, const parser & rhs) { return parser(lhs) << rhs; } + +parser_base & parser::operator*() const { + return *ptr_; +} + +parser_base * parser::operator->() const { + return ptr_.get(); +} + +parser_result parser::parse(parser_context & ctx, size_t start) const { + return ptr_->parse(ctx, start); +} + +std::string parser::dump() const { + return ptr_->dump(); +} + +void parser::build_grammar(const common_grammar_builder & builder) const { + gbnf_visitor visitor(builder); + ptr_->accept(visitor); + auto result = visitor.result(); + if (!result.empty()) { + builder.add_rule("root", result); + } +} + +parser_builder::parser_builder() + : rules_(std::make_shared>()) + , counter_(std::make_shared(0)) {} + +parser_builder::parser_builder(std::shared_ptr counter) + : rules_(std::make_shared>()) + , counter_(std::move(counter)) {} + +parser parser_builder::literal(const std::string & literal) { + return parser(std::make_shared(literal, counter_->next())); +} + +parser parser_builder::sequence(std::initializer_list parsers) { + return parser(std::make_shared(parsers, counter_->next())); +} + +parser parser_builder::choice(std::initializer_list parsers) { + return parser(std::make_shared(parsers, counter_->next())); +} + +parser parser_builder::one_or_more(const parser & p) { + return parser(std::make_shared(p, counter_->next())); +} + +parser parser_builder::zero_or_more(const parser & p) { + return parser(std::make_shared(p, counter_->next())); +} + +parser parser_builder::optional(const parser & p) { + return parser(std::make_shared(p, counter_->next())); +} + +parser parser_builder::negate(const parser & p) { + return parser(std::make_shared(p, counter_->next())); +} + +parser parser_builder::any() { + return parser(std::make_shared(counter_->next())); +} + +parser parser_builder::chars(const std::string & classes, int min, int max) { + return parser(std::make_shared(classes, min, max, counter_->next())); +} + +parser parser_builder::one(const std::string & classes) { + return chars(classes, 1, 1); +} + +parser parser_builder::json_string() { + return parser(std::make_shared(counter_->next())); +} + +parser parser_builder::rule(const std::string & name) { + return parser(std::make_shared(name, rules_, counter_->next())); +} + +parser parser_builder::space() { + return parser(std::make_shared(counter_->next())); +} + +parser parser_builder::until(const std::string & delimiter, bool consume_spaces) { + return parser(std::make_shared(delimiter, consume_spaces, counter_->next())); +} + +parser parser_builder::repeat(const parser & p, int min, int max) { + return parser(std::make_shared(p, min, max, counter_->next())); +} + +parser parser_builder::repeat(const parser & p, int n) { + return repeat(p, n, n); +} + +parser parser_builder::schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema) { + return parser(std::make_shared(p, name, schema, counter_->next())); +} + +parser parser_builder::action(const parser & p, std::function fn) { + return parser(std::make_shared(p, std::move(fn), counter_->next())); +} + +parser parser_builder::append_reasoning(const parser & p) { + return action(p, [](const parser_result &, std::string_view matched, parser_environment & env) { + if (!env.reasoning_content.empty()) { + env.reasoning_content += "\n"; + } + env.reasoning_content += matched; + }); +} + +parser parser_builder::append_content(const parser & p) { + return action(p, [](const parser_result &, std::string_view matched, parser_environment & env) { + if (!env.content.empty()) { + env.content += "\n"; + } + env.content += matched; + }); +} + +parser parser_builder::capture(const parser & p, const std::string & key, bool unescape_json) { + return action(p, [key, unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + std::string value = unescape_json ? unescape_json_string(matched) : std::string(matched); + env.scratchpad[key] = std::move(value); + }); +} + +parser parser_builder::capture_tool_call_id(const parser & p, bool unescape_json) { + return action(p, [unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + env.tool_call_id = unescape_json ? unescape_json_string(matched) : std::string(matched); + }); +} + +parser parser_builder::capture_tool_call_name(const parser & p, bool unescape_json) { + return action(p, [unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + env.tool_call_name = unescape_json ? unescape_json_string(matched) : std::string(matched); + }); +} + +parser parser_builder::capture_tool_call_args(const parser & p, bool unescape_json) { + return action(p, [unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + env.tool_call_args = unescape_json ? unescape_json_string(matched) : std::string(matched); + }); +} + +parser parser_builder::add_tool_call(const parser & p) { + return action(p, [](const parser_result &, std::string_view, parser_environment & env) { + auto tool_call = common_chat_tool_call{ + env.tool_call_name, + env.tool_call_args, + env.tool_call_id + }; + env.tool_calls.push_back(tool_call); + + // Clear the fields to prevent bleeding to next tool call + env.tool_call_id.clear(); + env.tool_call_name.clear(); + env.tool_call_args.clear(); + }); +} + +parser parser_builder::json_key(const std::string & name, const parser & p) { + return literal("\"" + name + "\"") << literal(":") << p; +} + +parser parser_builder::json_string(const parser & p) { + auto quote = literal("\""); + return quote + p + quote; +} + +parser parser_builder::add_rule(const std::string & name, const parser & p) { + (*rules_)[name] = p; + return rule(name); +} + +void parser_builder::assign_ids(parser & p) { + if (p.ptr()) { + p.ptr()->assign_id(counter_); + } +} + +parser build_parser(const std::function & fn) { + parser_builder builder; + auto root = fn(builder); + builder.assign_ids(root); // Assign IDs to rules that were created with operators + + // Wrap the root parser in a root_parser to own the rules and break circular references + auto rules = builder.rules(); + if (rules && !rules->empty()) { + return parser(std::make_shared(root, rules, -1)); + } + return root; +} + +static parser json_parser(std::shared_ptr counter) { + parser_builder builder(std::move(counter)); + + // Whitespace: space, tab, newline, carriage return + auto ws = builder.space(); + + // Number components + auto digit1_9 = builder.chars("[1-9]", 1, 1); + auto digits = builder.chars("[0-9]"); + + // Integer part: 0 or non-zero digit followed by more digits + auto int_part = builder.literal("0") | (digit1_9 + builder.chars("[0-9]", 0, -1)); + + // Optional fractional part + auto frac = builder.literal(".") + digits; + + // Optional exponent part + auto exp = (builder.literal("e") | builder.literal("E")) + builder.optional(builder.chars("[+\\-]", 1, 1)) + digits; + + // Complete number + auto number = builder.optional(builder.literal("-")) + int_part + builder.optional(frac) + builder.optional(exp); + + builder.add_rule("json_number", number); + + // String: specialized single-pass parser (content only, wrapped with quotes) + auto string = builder.literal("\"") + builder.json_string() + builder.literal("\""); + + builder.add_rule("json_string", string); + + // Literals + auto true_lit = builder.literal("true"); + auto false_lit = builder.literal("false"); + auto null_lit = builder.literal("null"); + + // Object: { "key": value, ... } + auto member = builder.rule("json_string") + ws + builder.literal(":") + ws + builder.rule("json_value"); + auto members = member + builder.zero_or_more(ws + builder.literal(",") + ws + member); + + // Empty object or object with members + auto object = (builder.literal("{") + ws + builder.literal("}")) | + (builder.literal("{") + ws + members + ws + builder.literal("}")); + + builder.add_rule("json_object", object); + + // Array: [ value, ... ] + auto elements = builder.rule("json_value") + builder.zero_or_more(ws + builder.literal(",") + ws + builder.rule("json_value")); + + // Empty array or array with elements + auto array = (builder.literal("[") + ws + builder.literal("]")) | + (builder.literal("[") + ws + elements + ws + builder.literal("]")); + + builder.add_rule("json_array", array); + + // Value - uses forward references for recursive structures + auto root = builder.add_rule("json_value", + builder.rule("json_object") | + builder.rule("json_array") | + builder.rule("json_string") | + builder.rule("json_number") | + true_lit | + false_lit | + null_lit + ); + + // Wrap in root_parser to own the rules + return parser(std::make_shared(root, builder.rules(), -1)); +} + +parser parser_builder::json() { + return json_parser(counter_); +} diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h new file mode 100644 index 0000000000000..98ca83f7bfb02 --- /dev/null +++ b/common/chat-parser-combinator.h @@ -0,0 +1,283 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +struct common_chat_tool_call; +struct common_grammar_builder; + +struct parser_environment { + std::string content; + std::string reasoning_content; + std::vector tool_calls; + + // Tool call fields for building tool calls + std::string tool_call_id; + std::string tool_call_name; + std::string tool_call_args; + + // Scratch pad for any custom logic + std::unordered_map> scratchpad; +}; + +enum parser_result_type { + PARSER_RESULT_FAIL = 0, + PARSER_RESULT_NEED_MORE_INPUT = 1, + PARSER_RESULT_SUCCESS = 2, +}; + +struct parse_cache_key { + int id; + size_t start; + + bool operator==(const parse_cache_key & other) const { + return id == other.id && start == other.start; + } +}; + +template <> +struct std::hash { + std::size_t operator()(const parse_cache_key & k) const { + return std::hash{}(((size_t)k.id << 32) | k.start); + } +}; + +struct parser_result { + parser_result_type type = PARSER_RESULT_FAIL; + size_t start = 0; + size_t end = 0; + + parser_result() : type(PARSER_RESULT_FAIL) {} + + parser_result(parser_result_type type, size_t start) + : type(type), start(start), end(start) {} + + parser_result(parser_result_type type, size_t start, size_t end) + : type(type), start(start), end(end) {} + + bool is_fail() const { return type == PARSER_RESULT_FAIL; } + bool is_need_more_input() const { return type == PARSER_RESULT_NEED_MORE_INPUT; } + bool is_success() const { return type == PARSER_RESULT_SUCCESS; } +}; + +class parse_cache { + std::unordered_map results; + + public: + parser_result set(int id, size_t start, parser_result result); + std::optional get(int id, size_t start); + void clear(); +}; + +struct parser_context { + std::string_view input; + parse_cache memo; + bool input_is_complete; + parser_environment * env; + + parser_context() + : memo(), input_is_complete(true), env(nullptr) {} + + parser_context(std::string_view input) + : input(input), memo(), input_is_complete(true), env(nullptr) {} + + parser_context(std::string_view input, bool complete) + : input(input), memo(), input_is_complete(complete), env(nullptr) {} + + parser_context(std::string_view input, parse_cache memo, bool complete = true) + : input(input), memo(std::move(memo)), input_is_complete(complete), env(nullptr) {} + + parser_context(std::string_view input, parser_environment * environment) + : input(input), memo(), input_is_complete(true), env(environment) {} + + parser_context(std::string_view input, parser_environment * environment, bool complete) + : input(input), memo(), input_is_complete(complete), env(environment) {} + + parser_context(std::string_view input, parse_cache memo, parser_environment * environment, bool complete = true) + : input(input), memo(std::move(memo)), input_is_complete(complete), env(environment) {} +}; + +class parser_base; + +class parser { + std::shared_ptr ptr_; + + public: + parser(); + parser(std::shared_ptr parser); + parser(const parser & other) = default; + parser(const std::string & literal); + parser(const char * literal); + + parser & operator=(const parser & other) { + if (this != &other) { + ptr_ = other.ptr_; + } + return *this; + } + + parser operator~() const; + parser operator+(const parser & other) const; + parser operator|(const parser & other) const; + parser operator<<(const parser & other) const; + + parser_base & operator*() const; + parser_base * operator->() const; + + std::shared_ptr ptr() const { return ptr_; } + + parser_result parse(parser_context & ctx, size_t start = 0) const; + + std::string dump() const; + + void build_grammar(const common_grammar_builder & builder) const; +}; + +parser operator+(const char * lhs, const parser & rhs); +parser operator|(const char * lhs, const parser & rhs); +parser operator<<(const char * lhs, const parser & rhs); + +class parser_id_counter { + int next_id_; + public: + parser_id_counter(int start) : next_id_(start) {} + int next() { return next_id_++; } +}; + +class parser_builder { + std::shared_ptr> rules_; + std::shared_ptr counter_; + + public: + parser_builder(); + parser_builder(std::shared_ptr counter); + + // Matches an exact literal string. + // S -> "hello" + parser literal(const std::string & literal); + + // Matches a sequence of parsers in order, all must succeed. + // S -> A B C + parser sequence(std::initializer_list parsers); + + // Matches the first parser that succeeds from a list of alternatives. + // S -> A | B | C + parser choice(std::initializer_list parsers); + + // Matches one or more repetitions of a parser. + // S -> A+ + parser one_or_more(const parser & p); + + // Matches zero or more repetitions of a parser, always succeeds. + // S -> A* + parser zero_or_more(const parser & p); + + // Matches zero or one occurrence of a parser, always succeeds. + // S -> A? + parser optional(const parser & p); + + // Negative lookahead: succeeds if child parser fails, consumes no input. + // S -> !A + parser negate(const parser & p); + + // Matches any single character. + // S -> . + parser any(); + + // Matches between min and max repetitions of characters from a character class. + // S -> [a-z]{m,n} + // + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + parser chars(const std::string & classes, int min = 1, int max = -1); + + // Matches a single character from a character class or range. + // S -> [a-z] or S -> [^0-9] + // + // Equivalent to chars(classes, 1, 1) + parser one(const std::string & classes); + + // References a named rule for recursive or reusable grammar definitions. + // expr -> term | expr "+" term + parser rule(const std::string & name); + + // Matches zero or more whitespace characters (space, tab, newline). + // S -> [ \t\n]* + parser space(); + + // Matches all characters until a delimiter is found (delimiter not consumed). + // S -> (!delim .)* + parser until(const std::string & delimiter, bool consume_spaces = true); + + // Matches between min and max repetitions of a parser (inclusive). + // S -> A{m,n} + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + parser repeat(const parser & p, int min, int max); + + // Matches exactly n repetitions of a parser. + // S -> A{n} + parser repeat(const parser & p, int n); + + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. + // value -> object | array | string | number | true | false | null + parser json(); + + // Specialized single-pass JSON string parser with escape sequence handling + parser json_string(); + + // TODO: improve convenience functions to allow users to build specific JSON fields + parser json_key(const std::string & name, const parser & p); + parser json_string(const parser & p); + + // Wraps a parser with JSON schema metadata for grammar generation. + // Used internally to convert JSON schemas to GBNF grammar rules. + parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); + + // Wraps a parser with a semantic action callback. + // The callback is invoked on successful parse with the result, matched text, and environment. + // S -> A [action] + parser action(const parser & p, std::function fn); + + // Convenience action wrappers for common patterns + + // Appends matched text to env.reasoning_content + parser append_reasoning(const parser & p); + + // Appends matched text to env.content + parser append_content(const parser & p); + + // Captures matched text to env.scratchpad[key] + // If unescape_json is true, the matched text is unescaped as a JSON string + parser capture(const parser & p, const std::string & key, bool unescape_json = false); + + // Captures matched text to env.tool_call_id + // If unescape_json is true, the matched text is unescaped as a JSON string + parser capture_tool_call_id(const parser & p, bool unescape_json = false); + + // Captures matched text to env.tool_call_name + // If unescape_json is true, the matched text is unescaped as a JSON string + parser capture_tool_call_name(const parser & p, bool unescape_json = false); + + // Captures matched text to env.tool_call_args + // If unescape_json is true, the matched text is unescaped as a JSON string + parser capture_tool_call_args(const parser & p, bool unescape_json = false); + + // Adds a tool call to env.tool_calls using env.tool_call_{id,name,args} + // Clears the tool call fields after adding + parser add_tool_call(const parser & p); + + parser add_rule(const std::string & name, const parser & p); + + void assign_ids(parser & p); + + std::shared_ptr> rules() const { return rules_; } +}; + +parser build_parser(const std::function & fn); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f4ce..90badf62af667 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -180,6 +180,7 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) endif() llama_build_and_test(test-chat-parser.cpp) +llama_build_and_test(test-chat-parser-combinator.cpp) llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp new file mode 100644 index 0000000000000..0bfa0b45b6364 --- /dev/null +++ b/tests/test-chat-parser-combinator.cpp @@ -0,0 +1,1021 @@ +#include +#include +#include + +#include "nlohmann/json.hpp" + +#include "chat.h" +#include "chat-parser.h" +#include "chat-parser-combinator.h" +#include "common.h" +#include "json-schema-to-grammar.h" + +template +static void assert_equals(const std::string_view label, const T & expected, const T & actual) { + if (expected != actual) { + std::cerr << label << "\n"; + std::cerr << "Expected: " << expected << "\n"; + std::cerr << "Actual: " << actual << "\n"; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} + +template +static void assert_equals(const T & expected, const T & actual) { + assert_equals("", expected, actual); +} + +static void assert_equals(const char * expected, const std::string & actual) { + assert_equals(expected, actual); +} + +static void test_partial_parsing() { + { + // Test literal + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context("hello"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + } + { + // Test char class + auto parser = build_parser([](parser_builder& p) { + return p.one("a-z"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context("a"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context("A"); + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + + parser = build_parser([](parser_builder& p) { + return p.one("a-z-"); + }); + + ctx = parser_context("f"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context("-"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context("A"); + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test sequences and literals + auto parser = build_parser([](parser_builder& p) { + return p.literal("") + p.literal(""); + }); + + // Partial matches + auto ctx = parser_context("", false); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context("", true); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // No match, since it does not adhere to the grammar + ctx = parser_context("I am parser", false); + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test choices + auto parser = build_parser([](parser_builder& p) { + return p.literal("") | p.literal(""); + }); + + // Partial matches + auto ctx = parser_context("", true); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context("", true); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // No match + ctx = parser_context("", true); + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test zero_or_more + auto parser = build_parser([](parser_builder& p) { + return p.zero_or_more(p.literal("ab")); + }); + + // Partial matches + auto ctx = parser_context("a", false); + auto result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + ctx = parser_context("aba", false); + result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + // Full match + ctx = parser_context("ab", true); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + } + { + // Test one_or_more + auto parser = build_parser([](parser_builder& p) { + return p.one_or_more(p.literal("ab")); + }); + + // Partial matches + auto ctx = parser_context("a", false); + auto result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + ctx = parser_context("aba", false); + result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + // Full match + ctx = parser_context("ab", true); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // No match + ctx = parser_context("cd", true); + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } +} + +static void test_one() { + { + // Test common escape sequences + auto parser = build_parser([](parser_builder& p) { + return p.one("[\\n\\t\\\\]"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context("\n"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context("\t"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context("\\"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context(" "); + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test escaped dash (literal dash, not a range) + auto parser = build_parser([](parser_builder& p) { + return p.one("[a\\-z]"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context("a"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context("-"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context("z"); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Should NOT match 'b' since \- is a literal dash, not a range + ctx = parser_context("b"); + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } +} + +static void test_recursive_references() { + auto value_parser = build_parser([](parser_builder& p) { + p.add_rule("number", p.one_or_more(p.one("0-9"))); + p.add_rule("list", p.sequence({ + p.literal("["), + p.rule("value"), + p.literal("]") + })); + return p.add_rule("value", p.rule("number") | p.rule("list")); + }); + + parser_context ctx; + parser_result result; + + // Test simple number + ctx = parser_context("1", true); + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test simple list + ctx = parser_context("[1]", true); + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test nested list + ctx = parser_context("[[2]]", true); + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test deeply nested list + ctx = parser_context("[[[3]]]", true); + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test partial match + ctx = parser_context("[[", false); + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test no match + ctx = parser_context("[a]", true); + result = value_parser.parse(ctx); + assert_equals(true, result.is_fail()); +} + +static void test_optional() { + // Test optional with a match + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + // Full match with optional part present + auto ctx = parser_context("hello world"); + auto result = parser.parse(ctx); + assert_equals(true, result.is_success()); + assert_equals((size_t)11, result.end); + + // Full match with optional part absent + ctx = parser_context("hello", true); + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + assert_equals((size_t)5, result.end); + + // Partial match - waiting for more input to determine if optional matches + ctx = parser_context("hello ", false); + result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); +} + +static void test_json_parser() { + auto json = build_parser([](parser_builder & p) { + return p.json(); + }); + + { + // Test parsing a simple JSON object + std::string input = R"({"name": "test", "value": 42, "flag": true})"; + parser_context ctx(input); + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + } + { + // Test parsing a JSON array with mixed types + std::string input = R"([1, "hello", true, null, 3.14])"; + parser_context ctx(input); + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + } + { + // Test parsing nested JSON with objects and arrays + std::string input = R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})"; + parser_context ctx(input); + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + } + { + // Test partial parsing - incomplete object + std::string input = R"({"name": "test", "value": )"; + parser_context ctx(input, false); + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + } + { + // Test partial parsing - incomplete array + std::string input = R"([1, 2, 3, )"; + parser_context ctx(input, false); + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + } + { + // Test partial parsing - incomplete nested structure + std::string input = R"({"data": {"nested": )"; + parser_context ctx(input, false); + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + } +} + +static void test_complete_example() { + // Parser for a fictitious model that outputs: + // + // + // ... reasoning content ... + // + // ... content ... + // + // tool_name + // { ... json args ... } + // + // + auto parser = build_parser([](parser_builder & p) { + auto reasoning = p.add_rule("reasoning", + "" << p.append_reasoning(p.until("")) << ""); + + auto content = p.add_rule("content", + p.append_content(p.until(""))); + + auto json = p.json(); + + auto tool_call_name = p.add_rule("tool-call-name", + "" << p.capture_tool_call_name(p.until("")) << ""); + + auto schema = nlohmann::ordered_json::parse(R"({"type": "object"})"); + + auto tool_call_args = p.add_rule("tool-call-args", + "" << p.capture_tool_call_args(p.schema(json, "get_weather", schema)) << ""); + + auto tool_call = p.add_rule("tool-call", + "" << p.add_tool_call(tool_call_name << tool_call_args) << ""); + + return reasoning << p.optional(content) << p.optional(tool_call); + }); + + // Test complete input + { + std::string input = R"(I need to call get_weather with city = New Yorkget_weather{"city": "New York"})"; + parser_environment env; + parser_context ctx(input, &env); + + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + assert_equals("I need to call get_weather with city = New York", env.reasoning_content); + assert_equals((size_t)1, env.tool_calls.size()); + assert_equals("", env.tool_calls[0].id); + assert_equals("get_weather", env.tool_calls[0].name); + assert_equals(R"({"city": "New York"})", env.tool_calls[0].arguments); + } + + // Test partial input + { + std::string input = R"(I need to call get_weather )"; + parser_environment env = parser_environment(); + parser_context ctx = parser_context(input, &env, /* .is_input_complete = */ false); + + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("I need to call get_weather", env.reasoning_content); + } + { + std::string input = R"(I need to call I need to call get_weatherI need to call get_weatherget_weather)"; + parser_environment env = parser_environment(); + parser_context ctx = parser_context(input, &env, /* .is_input_complete = */ false); + + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("I need to call get_weather", env.reasoning_content); + } + { + std::string input = R"(I need to call get_weatherget_weatherI need to call get_weatherget_weather{"cit)"; + parser_environment env = parser_environment(); + parser_context ctx = parser_context(input, &env, /* .is_input_complete = */ false); + + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("I need to call get_weather", env.reasoning_content); + assert_equals("get_weather", env.tool_calls[0].name); + assert_equals(R"({"cit)", env.tool_calls[0].arguments); + } + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + std::cout << "Grammar:\n" << gbnf << "\n"; +} + +static void test_actions() { + { + // Test simple action - append matched text to content + auto parser = build_parser([](parser_builder& p) { + auto word = p.chars("[a-z]+"); + return p.action(word, [](const parser_result &, std::string_view matched, parser_environment & env) { + env.content += std::string(matched); + }); + }); + + parser_environment env; + parser_context ctx("hello", &env); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello", env.content); + } + { + // Test multiple sequential actions - build a sentence + auto parser = build_parser([](parser_builder& p) { + auto greeting = p.action(p.literal("hello"), [](const parser_result &, std::string_view matched, parser_environment & env) { + env.content += std::string(matched) + " "; + }); + + auto name = p.action(p.chars("[A-Z][a-z]+"), [](const parser_result &, std::string_view matched, parser_environment & env) { + env.content += std::string(matched); + env.scratchpad["name"] = std::string(matched); + }); + + return greeting + p.literal(" ") + name; + }); + + parser_environment env; + parser_context ctx("hello Alice", &env); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello Alice", env.content); + assert_equals("Alice", std::get(env.scratchpad["name"])); + } + { + // Test using scratchpad for intermediate calculations + auto parser = build_parser([](parser_builder& p) { + auto digit = p.action(p.one("[0-9]"), [](const parser_result &, std::string_view matched, parser_environment & env) { + auto it = env.scratchpad.find("sum"); + int current_sum = it != env.scratchpad.end() ? std::get(it->second) : 0; + current_sum += (matched[0] - '0'); + env.scratchpad["sum"] = current_sum; + }); + + return p.one_or_more(digit + p.optional(p.literal("+"))); + }); + + parser_environment env; + parser_context ctx("1+2+3+4", &env); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(10, std::get(env.scratchpad["sum"])); // 1+2+3+4 = 10 + } + { + // Test actions don't run when parse fails + auto parser = build_parser([](parser_builder& p) { + return p.action(p.literal("success"), [](const parser_result &, std::string_view, parser_environment & env) { + env.content = "action_ran"; + }); + }); + + parser_environment env; + parser_context ctx("failure", &env); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_fail()); + assert_equals("", env.content); // Action should not have run + } + { + // Test Actions work with partial parsing + auto parser = build_parser([](parser_builder& p) { + auto content = p.action(p.until(""), [](const parser_result &, std::string_view matched, parser_environment & env) { + env.content += std::string(matched); + }); + return "" << content << ""; + }); + + { + parser_environment env; + parser_context ctx("hello ", &env, false); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello", env.content); + } + { + parser_environment env; + parser_context ctx("hello world", &env, false); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello world", env.content); + } + { + parser_environment env; + parser_context ctx("hello world", &env, true); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello world", env.content); + } + } +} + +static void test_gbnf_generation() { + { + // Test literal + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); + assert_equals(true, gbnf.find("space ::=") != std::string::npos); + } + { + // Test char class + auto parser = build_parser([](parser_builder& p) { + return p.one("[a-z]"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_equals(true, gbnf.find("root ::= [a-z]") != std::string::npos); + } + { + // Test sequence + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") + p.literal(" ") + p.literal("world"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_equals(true, gbnf.find("root ::= \"hello\" \" \" \"world\"") != std::string::npos); + } + { + // Test choice + auto parser = build_parser([](parser_builder& p) { + return p.literal("cat") | p.literal("dog"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_equals(true, gbnf.find("root ::= \"cat\" | \"dog\"") != std::string::npos); + } + { + // Test one_or_more + auto parser = build_parser([](parser_builder& p) { + return p.one_or_more(p.one("[0-9]")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_equals(true, gbnf.find("root ::= [0-9]+") != std::string::npos); + } + { + // Test zero_or_more + auto parser = build_parser([](parser_builder& p) { + return p.zero_or_more(p.one("[a-z]")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_equals(true, gbnf.find("root ::= [a-z]*") != std::string::npos); + } + { + // Test optional + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_equals(true, gbnf.find("root ::= \"hello\" \" world\"?") != std::string::npos); + } + { + // Test until + auto parser = build_parser([](parser_builder& p) { + return p.until(""); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + // Should generate pattern that prevents matching the full delimiter + assert_equals(true, gbnf.find("root ::= ([^<] | \"<\" [^/] | \"])*") != std::string::npos); + } + { + // Test complex expression with parentheses + auto parser = build_parser([](parser_builder& p) { + return p.one_or_more(p.literal("a") | p.literal("b")); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_equals(true, gbnf.find("root ::= (\"a\" | \"b\")+") != std::string::npos); + } + { + // Test rule references + auto parser = build_parser([](parser_builder& p) { + auto digit = p.add_rule("digit", p.one("[0-9]")); + return p.one_or_more(digit); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + // Should have digit rule defined and referenced + assert_equals(true, gbnf.find("digit ::= [0-9]") != std::string::npos); + assert_equals(true, gbnf.find("root ::= digit+") != std::string::npos); + } + { + // Test escaping in literals + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello\nworld\t!"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_equals(true, gbnf.find("root ::= \"hello\\nworld\\t!\"") != std::string::npos); + } + { + // Test operator<< (whitespace insertion) + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") << p.literal("world"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + // Should inline the whitespace pattern + assert_equals(true, gbnf.find("\"hello\"") != std::string::npos); + assert_equals(true, gbnf.find("\"world\"") != std::string::npos); + } +} + +static parser create_command_r7b_parser() { + auto parser = build_parser([](parser_builder & p) { + auto thinking = p.add_rule("thinking", + "<|START_THINKING|>" << p.append_reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>"); + + auto response = p.add_rule("response", + "<|START_RESPONSE|>" << p.append_content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>"); + + auto json = p.add_rule("json", p.json()); + + auto tool_call_id = p.add_rule("tool-call-id", + p.json_key("tool_call_id", "\"" + p.capture_tool_call_id(p.json_string(), /* unescape_json = */ true) + "\"")); + + auto tool_call_name = p.add_rule("tool-name", + p.json_key("tool_name", "\"" + p.capture_tool_call_name(p.json_string(), /* unescape_json = */ true) + "\"")); + + auto tool_call_args = p.add_rule("tool-args", p.json_key("parameters", p.capture_tool_call_args(json))); + + auto tool_call_fields = p.add_rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args); + + auto tool_call = p.add_rule("tool-call", + "{" << p.add_tool_call(tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields)) << "}"); + + auto tool_calls = p.add_rule("tool-calls", + "<|START_ACTION|>" + << ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]") + << "<|END_ACTION|>"); + + return p.optional(thinking) << p.add_rule("content", tool_calls | response); + }); + + auto grammar = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + std::cout << "=== Grammar ===\n\n" << grammar << "\n\n"; + + return parser; +} + +static void test_command_r7b_parser(const parser & p, const std::string & input, bool partial, bool print_results = false) { + parser_environment env; + parser_context ctx(input, &env, !partial); + p.parse(ctx); + + if (print_results) { + std::cout << "== Parsed (new) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << env.reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << env.content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : env.tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } +} + +static void test_command_r7b_legacy_parser(const std::string & input, bool partial, bool print_results = false) { + // Original parser taken from chat.cpp + common_chat_msg_parser builder(input, + /* is_partial= */ partial, { + /* .format = */ COMMON_CHAT_FORMAT_GENERIC, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + }); + + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } + + if (print_results) { + std::cout << "== Parsed (legacy) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << builder.result().reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << builder.result().content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : builder.result().tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } +} + +struct bench_tool_call { + std::string id; + std::string name; + nlohmann::ordered_json args; +}; + +// Simple tokenize function that splits by space +static std::vector simple_tokenize(const std::string & input) { + std::vector result; + std::string current; + + for (size_t i = 0; i < input.size(); i++) { + if (input[i] == ' ') { + if (!current.empty()) { + result.push_back(current); + current.clear(); + } + current += ' '; + } else { + current += input[i]; + } + } + + if (!current.empty()) { + result.push_back(current); + } + + return result; +} + +static void benchmark_compare( + const std::string & reasoning, + const std::string & content, + const std::vector & tool_calls, + int iterations) { + + // Build response + std::vector tokens; // Since we don't have a command r7b tokenizer, we're going to "simulate" them. + + if (!reasoning.empty()) { + auto tokenized = simple_tokenize(reasoning); + tokens.emplace_back("<|START_THINKING|>"); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + tokens.emplace_back("<|END_THINKING|>"); + } + + if (!content.empty()) { + auto tokenized = simple_tokenize(content); + tokens.emplace_back("<|START_RESPONSE|>"); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + tokens.emplace_back("<|END_RESPONSE|>"); + } + + if (!tool_calls.empty()) { + tokens.emplace_back("<|START_ACTION|>"); + + auto json = nlohmann::json::array(); + for (const auto & tc : tool_calls) { + auto tc_json = nlohmann::json::object(); + tc_json["tool_call_id"] = tc.id; + tc_json["tool_name"] = tc.name; + tc_json["parameters"] = tc.args; + json.push_back(tc_json); + } + + auto tokenized = simple_tokenize(json.dump(-1, ' ', true)); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + + tokens.emplace_back("<|END_ACTION|>"); + } + + auto run = [&](const std::function & fn) { + std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string()); + + std::chrono::microseconds duration(0); + for (int i = 0; i < iterations; i++) { + auto start = std::chrono::high_resolution_clock::now(); + fn(input, false, i == 0); + auto end = std::chrono::high_resolution_clock::now(); + duration += std::chrono::duration_cast(end - start); + } + return duration.count() / iterations; + }; + + auto parser = create_command_r7b_parser(); + + auto duration_new = run([&](const std::string & input, bool partial, bool print_content) { + test_command_r7b_parser(parser, input, partial, print_content); + }); + + auto duration_legacy = run([&](const std::string & input, bool partial, bool print_content) { + try { + test_command_r7b_legacy_parser(input, partial, print_content); + } catch (const common_chat_msg_partial_exception &) { } + }); + + std::cout << " New parser avg: " << duration_new << " us\n"; + std::cout << "Legacy parser avg: " << duration_legacy << " us\n"; +} + +int main() { + test_partial_parsing(); + test_one(); + test_recursive_references(); + test_optional(); + test_json_parser(); + test_complete_example(); + test_actions(); + test_gbnf_generation(); + std::cout << "All tests passed!\n"; + + std::cout << "\n== Benchmarks ==\n"; + std::string example_reasoning = + "To plan an effective trip to Japan that includes both historical sites and modern attractions within a budget of $4000 for a two-week stay, we need to:\n\n" + "1. Identify key historical sites and modern attractions in Japan.\n" + "2. Find affordable accommodation options that provide a balance between comfort and cost.\n" + "3. Determine the best modes of transportation for getting around Japan.\n" + "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without overspending.\n" + "5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees to attractions."; + + std::string example_content = + "For a two-week trip to Japan with a $4,000 budget, I recommend planning an itinerary that balances historical sites with modern attractions. The destination will be Japan, with a duration of 14 days.\n\n" + "Given your interests in both historical sites and modern attractions, you'll want to focus on cities like Kyoto for its temples and traditional culture, Tokyo for its cutting-edge technology and entertainment districts, and possibly Hiroshima or Nara for additional historical significance.\n\n" + "For accommodation, I suggest looking for affordable options such as budget hotels, hostels, or guesthouses that offer good value without sacrificing too much comfort. Japan has excellent mid-range accommodation options that can keep your lodging costs manageable.\n\n" + "Transportation should prioritize efficiency—consider getting a JR Rail Pass for intercity travel, which allows unlimited rides on most JR trains including the Shinkansen (bullet train). Within cities, use local trains and subways, which are both affordable and highly reliable.\n\n" + "For meals, embrace local cuisine by eating at neighborhood restaurants, ramen shops, and izakayas rather than touristy establishments. This will give you an authentic experience while keeping costs reasonable—you can enjoy excellent meals for $10-20 per person at local spots.\n\n"; + + std::vector example_tool_calls = {{ + "call_0", + "plan_trip", + nlohmann::json::parse(R"({ + "destination": "Japan", + "duration": 14, + "budget": 4000, + "interests": ["historical sites", "modern attractions"], + "accommodation_preferences": "affordable", + "transportation_preferences": "efficient", + "meal_preferences": "local cuisine" + })") + }}; + + std::cout << "\nReasoning + Content:\n"; + benchmark_compare(example_reasoning, example_content, std::vector(), 100); + + std::cout << "\nReasoning + Tool Call:\n"; + benchmark_compare(example_reasoning, "", example_tool_calls, 100); + return 0; +}