From c822e73c4d5484a03ea0c367ded40ab3f832a56d Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Sun, 9 Nov 2025 20:07:12 -0600 Subject: [PATCH 01/34] common : implement parser combinators to simplify chat parsing --- common/CMakeLists.txt | 2 + common/chat-parser-combinator.cpp | 819 ++++++++++++++++++++++++++ common/chat-parser-combinator.h | 158 +++++ tests/CMakeLists.txt | 1 + tests/test-chat-parser-combinator.cpp | 472 +++++++++++++++ 5 files changed, 1452 insertions(+) create mode 100644 common/chat-parser-combinator.cpp create mode 100644 common/chat-parser-combinator.h create mode 100644 tests/test-chat-parser-combinator.cpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 7086d08e5e5e9..7bdc9aab5995f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -48,6 +48,8 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat-parser-combinator.cpp + chat-parser-combinator.h chat-parser.cpp chat-parser.h chat.cpp diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp new file mode 100644 index 0000000000000..f2182980b3bac --- /dev/null +++ b/common/chat-parser-combinator.cpp @@ -0,0 +1,819 @@ +#include "chat-parser-combinator.h" +#include "common.h" +#include "log.h" + +#include +#include + +class parser_base { + protected: + int id_; + + void set_id(int id) { id_ = id; } + + public: + parser_base(int id) : id_(id) {} + + virtual parser_type type() const = 0; + virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; + virtual std::string dump() const = 0; + virtual void assign_ids_internal(int& next_id) { + if (id_ == -1) { + id_ = next_id++; + } + } +}; + +class literal_parser : public parser_base { + std::string literal_; + + public: + literal_parser(const std::string & literal, int id) : parser_base(id), literal_(literal) {} + + parser_type type() const override { return PARSER_LITERAL; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto pos = start; + for (auto i = 0u; i < literal_.size(); ++i) { + if (pos >= ctx.input.size()) { + if (ctx.input_is_complete) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + if (i > 0) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + return parser_result(PARSER_RESULT_FAIL, start); + } + if (ctx.input[pos] != literal_[i]) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + ++pos; + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos)); + } + + std::string dump() const override { + return "Literal(" + literal_ + ")"; + } +}; + +class sequence_parser : public parser_base { + std::vector parsers_; + + public: + sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { + for (const auto & p : parsers) { + if (p.is_sequence()) { + // Flatten sequences + for (const auto & embedded : p.to_sequence()->parsers()) { + parsers_.push_back(embedded); + } + } else { + parsers_.push_back(p); + } + } + } + + parser_type type() const override { return PARSER_SEQUENCE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + std::unordered_map groups; + + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); + + // Copy groups + groups.insert(result.groups.begin(), result.groups.end()); + + if (result.is_fail()) { + if (result.end >= ctx.input.size() && !ctx.input_is_complete) { + // If we fail because we don't have enough input, then return success + return parser_result(PARSER_RESULT_SUCCESS, start, result.end, groups); + } + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start, result.end, groups)); + } + + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end, groups); + } + + pos = result.end; + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + } + + std::string dump() const override { + std::vector parts; + parts.reserve(parsers_.size()); + for (const auto & p : parsers_) { + parts.push_back(p->dump()); + } + return "Sequence(" + string_join(parts, ", ") + ")"; + } + + const std::vector & parsers() const { return parsers_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + for (auto & p : parsers_) { + p->assign_ids_internal(next_id); + } + } +}; + +class choice_parser : public parser_base { + std::vector parsers_; + + public: + choice_parser(std::initializer_list parsers, int id) : parser_base(id) { + for (const auto & p : parsers) { + if (p.is_choice()) { + // Flatten choices + for (const auto & embedded : p.to_choice()->parsers()) { + parsers_.push_back(embedded); + } + } else { + parsers_.push_back(p); + } + } + } + + parser_type type() const override { return PARSER_CHOICE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); + + if (result.is_success()) { + return ctx.memo.set(id_, start, result); + } + + if (result.is_need_more_input()) { + return result; + } + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + std::string dump() const override { + std::vector parts; + parts.reserve(parsers_.size()); + for (const auto & p : parsers_) { + parts.push_back(p->dump()); + } + return "Choice(" + string_join(parts, ", ") + ")"; + } + + const std::vector & parsers() const { return parsers_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + for (auto & p : parsers_) { + p->assign_ids_internal(next_id); + } + } +}; + +class one_or_more_parser : public parser_base { + parser parser_; + + public: + one_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + + parser_type type() const override { return PARSER_ONE_OR_MORE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + std::unordered_map groups; + + // We can't return back the cached result, since there may be more + // repetitions since the last parsing attempt. Instead, resume parsing from + // the last successful repetition found. + auto pos = start; + if (cached != std::nullopt) { + pos = cached->end; + groups.insert(cached->groups.begin(), cached->groups.end()); + } + + if (pos == start) { + auto first_result = parser_->parse(ctx, pos); + if (!first_result.is_success()) { + return first_result; + } + + pos = first_result.end; + groups.insert(first_result.groups.begin(), first_result.groups.end()); + } + + for (;;) { + auto result = parser_->parse(ctx, pos); + groups.insert(result.groups.begin(), result.groups.end()); + + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + } + + if (result.is_fail()) { + // Done with repetitions + break; + } + + if (result.end == pos) { + break; // Prevent an infinite loop + } + + pos = result.end; + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + } + + std::string dump() const override { + return "OneOrMore(" + parser_->dump() + ")"; + } + + const parser & child() const { return parser_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class zero_or_more_parser : public parser_base { + parser parser_; + + public: + zero_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + + parser_type type() const override { return PARSER_ZERO_OR_MORE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + std::unordered_map groups; + + // We can't return back the cached result, since there may be more + // repetitions since the last parsing attempt. Instead, resume parsing from + // the last successful repetition found. + auto pos = start; + if (cached != std::nullopt) { + pos = cached->end; + groups.insert(cached->groups.begin(), cached->groups.end()); + } + + for (;;) { + auto result = parser_->parse(ctx, pos); + groups.insert(result.groups.begin(), result.groups.end()); + + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + } + + if (result.is_fail()) { + // Done with repetitions (zero or more is always valid) + break; + } + + if (result.end == pos) { + break; // Prevent an infinite loop + } + + pos = result.end; + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + } + + std::string dump() const override { + return "ZeroOrMore(" + parser_->dump() + ")"; + } + + const parser & child() const { return parser_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class optional_parser : public parser_base { + parser parser_; + + public: + optional_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + + parser_type type() const override { return PARSER_OPTIONAL; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto result = parser_->parse(ctx, start); + + if (result.is_success()) { + // Matched successfully + return ctx.memo.set(id_, start, result); + } + + if (result.is_need_more_input()) { + // Propagate - need more input to determine if optional matches + return result; + } + + // No match, but optional always succeeds with zero matches + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start)); + } + + std::string dump() const override { + return "Optional(" + parser_->dump() + ")"; + } + + const parser & child() const { return parser_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class not_parser : public parser_base { + parser parser_; + + public: + not_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + + parser_type type() const override { return PARSER_NOT; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto result = parser_->parse(ctx, start); + + if (result.is_success()) { + // Fail if the underlying parser matches + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + if (result.is_need_more_input()) { + // Propagate - need to know what child would match before negating + return result; + } + + // Child failed, so negation succeeds + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start)); + } + + std::string dump() const override { + return "Not(" + parser_->dump() + ")"; + } + + const parser & child() const { return parser_; } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class any_parser : public parser_base { + public: + any_parser(int id) : parser_base(id) {} + + parser_type type() const override { return PARSER_ANY; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + return parser_result(PARSER_RESULT_FAIL, start); + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + } + + std::string dump() const override { + return "Any"; + } +}; + +class char_class_parser : public parser_base { + struct char_range { + int start; + int end; + + bool contains(char c) const { return (int)c >= start && int(c) <= end; } + }; + + std::string pattern_; + std::vector ranges_; + + public: + char_class_parser(const std::string & classes, int id) : parser_base(id), pattern_(classes) { + std::string content = classes; + if (content.front() == '[') { + content = content.substr(1); + } + + if (content.back() == ']') { + content.pop_back(); + } + + auto parse_char = [&](size_t pos) -> std::pair { + if (content[pos] == '\\' && pos + 1 < content.length()) { + char next = content[pos + 1]; + switch (next) { + case 'n': return {'\n', 2}; + case 't': return {'\t', 2}; + case 'r': return {'\r', 2}; + case '\\': return {'\\', 2}; + case ']': return {']', 2}; + case '-': return {'-', 2}; + case '[': return {'[', 2}; + default: return {next, 2}; // Treat as literal escaped character + } + } + return {content[pos], 1}; + }; + + size_t i = 0; + while (i < content.length()) { + auto [start, start_len] = parse_char(i); + i += start_len; + + if (i + 1 < content.length() && content[i] == '-') { + // Range detected + auto [end, end_len] = parse_char(i + 1); + ranges_.push_back(char_range{start, end}); + i += 1 + end_len; + } else { + ranges_.push_back(char_range{start, start}); + } + } + } + + parser_type type() const override { return PARSER_CHAR_CLASS; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + return parser_result(PARSER_RESULT_FAIL, start); + } + + for (const auto & range : ranges_) { + if (range.contains(ctx.input[start])) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + } + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + std::string dump() const override { + return "Char(" + pattern_ + ")"; + } +}; + +class group_parser : public parser_base { + std::string name_; + parser parser_; + + public: + group_parser(const std::string & name, const parser & parser, int id) : parser_base(id), name_(name), parser_(parser) {} + + parser_type type() const override { return PARSER_GROUP; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto result = parser_->parse(ctx, start); + + // Store result + result.groups[name_] = parser_match_location{result.start, result.end}; + return ctx.memo.set(id_, start, result); + } + + std::string dump() const override { + return "Group(" + name_ + ", " + parser_->dump() + ")"; + } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + +class rule_parser : public parser_base { + std::string rule_name_; + std::shared_ptr> rules_; + + public: + rule_parser(const std::string & name, std::shared_ptr> rules, int id) + : parser_base(id), rule_name_(name), rules_(std::move(rules)) {} + + parser_type type() const override { return PARSER_RULE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + if (!rules_) { + LOG_ERR("rule_parser::parse called without rule registry\n"); + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + auto it = rules_->find(rule_name_); + if (it == rules_->end()) { + LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", rule_name_.c_str()); + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + } + + auto result = it->second->parse(ctx, start); + return ctx.memo.set(id_, start, result); + } + + std::string dump() const override { + return "Rule(" + rule_name_ + ")"; + } +}; + +std::optional parser_result::group(const std::string & name, std::string_view input) const { + auto it = groups.find(name); + if (it == groups.end()) { + return std::nullopt; + } + + return std::string(it->second.view(input)); +} + +parser_result parse_cache::set(int id, size_t start, parser_result result) { + if (id == -1) { + // Don't cache parsers with ID -1 (from operators and global factory functions) + return result; + } + results[parse_cache_key{id, start}] = result; + return result; +} + +std::optional parse_cache::get(int id, size_t start) { + if (id == -1) { + // Don't cache parsers with ID -1 (from operators and global factory functions) + return std::nullopt; + } + auto it = results.find(parse_cache_key{id, start}); + if (it != results.end()) { + return it->second; + } + return std::nullopt; +} + +void parse_cache::clear() { + results.clear(); +} + +parser::parser() {} + +parser::parser(std::shared_ptr parser) : ptr(std::move(parser)) {} + +parser parser::operator~() const { + return parser(std::make_shared(*this, -1)); +} + +parser parser::operator+(const parser & other) const { + return parser(std::shared_ptr(new sequence_parser({*this, other}, -1))); +} + +parser parser::operator|(const parser & other) const { + return parser(std::shared_ptr(new choice_parser({*this, other}, -1))); +} + +parser_base & parser::operator*() const { + return *ptr; +} + +parser_base * parser::operator->() const { + return ptr.get(); +} + +bool parser::is_sequence() const { + return ptr->type() == PARSER_SEQUENCE; +} + +std::shared_ptr parser::to_sequence() const { + return std::dynamic_pointer_cast(ptr); +} + +bool parser::is_choice() const { + return ptr->type() == PARSER_CHOICE; +} + +std::shared_ptr parser::to_choice() const { + return std::dynamic_pointer_cast(ptr); +} + +parser_type parser::type() const { + return ptr->type(); +} + +parser_result parser::parse(parser_context & ctx, size_t start) const { + return ptr->parse(ctx, start); +} + +std::string parser::dump() const { + return ptr->dump(); +} + +parser_builder::parser_builder() + : rules_(std::make_shared>()) + , next_id_(0) {} + +parser parser_builder::literal(const std::string & literal) { + return parser(std::make_shared(literal, next_id_++)); +} + +parser parser_builder::sequence(std::initializer_list parsers) { + return parser(std::shared_ptr(new sequence_parser(parsers, next_id_++))); +} + +parser parser_builder::choice(std::initializer_list parsers) { + return parser(std::shared_ptr(new choice_parser(parsers, next_id_++))); +} + +parser parser_builder::one_or_more(const parser & p) { + return parser(std::make_shared(p, next_id_++)); +} + +parser parser_builder::zero_or_more(const parser & p) { + return parser(std::make_shared(p, next_id_++)); +} + +parser parser_builder::optional(const parser & p) { + return parser(std::make_shared(p, next_id_++)); +} + +parser parser_builder::negate(const parser & p) { + return parser(std::make_shared(p, next_id_++)); +} + +parser parser_builder::any() { + return parser(std::make_shared(next_id_++)); +} + +parser parser_builder::char_class(const std::string & classes) { + return parser(std::make_shared(classes, next_id_++)); +} + +parser parser_builder::group(const std::string & name, const parser & p) { + return parser(std::make_shared(name, p, next_id_++)); +} + +parser parser_builder::rule(const std::string & name) { + return parser(std::make_shared(name, rules_, next_id_++)); +} + +parser parser_builder::space() { + return zero_or_more(char_class("[ \\t\\n\\r]")); +} + +parser parser_builder::add_rule(const std::string & name, const parser & p) { + (*rules_)[name] = p; + return rule(name); +} + +void parser_builder::assign_ids(parser & p) { + if (p.ptr) { + p.ptr->assign_ids_internal(next_id_); + } +} + +parser parser_builder::add_json_rule(const std::string & name) { + // Whitespace: space, tab, newline, carriage return + auto ws = zero_or_more(char_class("[ \\t\\n\\r]")); + + // Number components + auto digit = char_class("[0-9]"); + auto digit1_9 = char_class("[1-9]"); + auto digits = one_or_more(digit); + + // Integer part: 0 or non-zero digit followed by more digits + auto int_part = literal("0") | (digit1_9 + zero_or_more(digit)); + + // Optional fractional part + auto frac = literal(".") + digits; + + // Optional exponent part + auto exp = (literal("e") | literal("E")) + optional(char_class("[+\\-]")) + digits; + + // Complete number + auto number = optional(literal("-")) + int_part + optional(frac) + optional(exp); + + add_rule("json_number", number); + + // String components + auto hex = char_class("[0-9a-fA-F]"); + auto unicode_escape = literal("\\u") + hex + hex + hex + hex; + auto simple_escape = literal("\\") + char_class("[\"\\\\bfnrt/]"); + auto escape = simple_escape | unicode_escape; + + // String character: escape sequence or any char except quote and backslash + auto string_char = escape | (~char_class("[\"\\\\]") + any()); + auto string = literal("\"") + zero_or_more(string_char) + literal("\""); + + add_rule("json_string", string); + + // Literals + auto true_lit = literal("true"); + auto false_lit = literal("false"); + auto null_lit = literal("null"); + + // Value - uses forward references for recursive structures + add_rule("json_value", + rule("json_object") | + rule("json_array") | + rule("json_string") | + rule("json_number") | + true_lit | + false_lit | + null_lit + ); + + // Object: { "key": value, ... } + auto member = rule("json_string") + ws + literal(":") + ws + rule("json_value"); + auto members = member + zero_or_more(ws + literal(",") + ws + member); + + // Empty object or object with members + auto object = (literal("{") + ws + literal("}")) | + (literal("{") + ws + members + ws + literal("}")); + + add_rule("json_object", object); + + // Array: [ value, ... ] + auto elements = rule("json_value") + zero_or_more(ws + literal(",") + ws + rule("json_value")); + + // Empty array or array with elements + auto array = (literal("[") + ws + literal("]")) | + (literal("[") + ws + elements + ws + literal("]")); + + add_rule("json_array", array); + + // Register the main rule with the provided name + return add_rule(name, rule("json_value")); +} + +parser build_parser(const std::function & fn) { + parser_builder builder; + auto root = fn(builder); + builder.assign_ids(root); // Assign IDs to rules that were created with operators + return root; +} diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h new file mode 100644 index 0000000000000..72adf523c489a --- /dev/null +++ b/common/chat-parser-combinator.h @@ -0,0 +1,158 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +enum parser_type { + PARSER_LITERAL = 0, + PARSER_SEQUENCE = 1, + PARSER_CHOICE = 2, + PARSER_ZERO_OR_MORE = 3, + PARSER_ONE_OR_MORE = 4, + PARSER_NOT = 5, + PARSER_ANY = 6, + PARSER_CHAR_CLASS = 7, + PARSER_GROUP = 8, + PARSER_RULE = 9, + PARSER_OPTIONAL = 10, +}; + +enum parser_result_type { + PARSER_RESULT_FAIL = 0, + PARSER_RESULT_NEED_MORE_INPUT = 1, + PARSER_RESULT_SUCCESS = 2, +}; + +struct parse_cache_key { + int id; + size_t start; + + bool operator==(const parse_cache_key & other) const { + return id == other.id && start == other.start; + } +}; + +template <> +struct std::hash { + std::size_t operator()(const parse_cache_key & k) const { + return std::hash{}(((size_t)k.id << 32) | k.start); + } +}; + +struct parser_match_location { + size_t start; + size_t end; + + size_t length() const { return end - start; } + + std::string_view view(std::string_view sv) const { + return sv.substr(start, length()); + } +}; + +struct parser_result { + parser_result_type type = PARSER_RESULT_FAIL; + size_t start = 0; + size_t end = 0; + + std::unordered_map groups; + + parser_result() : type(PARSER_RESULT_FAIL) {} + parser_result(parser_result_type type, size_t start) : type(type), start(start), end(start) {} + parser_result(parser_result_type type, size_t start, size_t end) : type(type), start(start), end(end) {} + parser_result(parser_result_type type, size_t start, size_t end, const std::unordered_map & groups) : type(type), start(start), end(end), groups(groups) {} + + bool is_fail() const { return type == PARSER_RESULT_FAIL; } + bool is_need_more_input() const { return type == PARSER_RESULT_NEED_MORE_INPUT; } + bool is_success() const { return type == PARSER_RESULT_SUCCESS; } + + std::optional group(const std::string & name, std::string_view input) const; +}; + +class parse_cache { + std::unordered_map results; + + public: + parser_result set(int id, size_t start, parser_result result); + std::optional get(int id, size_t start); + void clear(); +}; + +class parser; + +struct parser_context { + std::string_view input; + parse_cache memo; + bool input_is_complete = true; +}; + +class parser_base; +class sequence_parser; +class choice_parser; +class parser_builder; + +class parser { + std::shared_ptr ptr; + + friend class parser_builder; + + public: + parser(); + parser(std::shared_ptr parser); + parser(const parser & other) = default; + parser & operator=(const parser & other) { + if (this != &other) { + ptr = other.ptr; + } + return *this; + } + + parser operator~() const; + parser operator+(const parser & other) const; + parser operator|(const parser & other) const; + + parser_base & operator*() const; + parser_base * operator->() const; + + bool is_sequence() const; + std::shared_ptr to_sequence() const; + + bool is_choice() const; + std::shared_ptr to_choice() const; + + parser_type type() const; + parser_result parse(parser_context & ctx, size_t start = 0) const; + std::string dump() const; +}; + +class parser_builder { + std::shared_ptr> rules_; + int next_id_; + + public: + parser_builder(); + + parser literal(const std::string & literal); + parser sequence(std::initializer_list parsers); + parser choice(std::initializer_list parsers); + parser one_or_more(const parser & p); + parser zero_or_more(const parser & p); + parser optional(const parser & p); + parser negate(const parser & p); + parser any(); + parser char_class(const std::string & classes); + parser group(const std::string & name, const parser & p); + parser rule(const std::string & name); + parser space(); + + parser add_rule(const std::string & name, const parser & p); + parser add_json_rule(const std::string & name); + + void assign_ids(parser & p); +}; + +parser build_parser(const std::function & fn); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f4ce..90badf62af667 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -180,6 +180,7 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) endif() llama_build_and_test(test-chat-parser.cpp) +llama_build_and_test(test-chat-parser-combinator.cpp) llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp new file mode 100644 index 0000000000000..55a443aed3e38 --- /dev/null +++ b/tests/test-chat-parser-combinator.cpp @@ -0,0 +1,472 @@ +#include +#include + +#include "chat-parser-combinator.h" + +template +static void assert_equals(const std::string_view label, const T & expected, const T & actual) { + if (expected != actual) { + std::cerr << label << "\n"; + std::cerr << "Expected: " << expected << "\n"; + std::cerr << "Actual: " << actual << "\n"; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} + +template +static void assert_equals(const T & expected, const T & actual) { + assert_equals("", expected, actual); +} + +static void assert_equals(const char * expected, const std::string & actual) { + assert_equals(expected, actual); +} + +static void test_partial_parsing() { + { + // Test literal + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context{"hello", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + } + { + // Test char class + auto parser = build_parser([](parser_builder& p) { + return p.char_class("a-z"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context{"a", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"A", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + + parser = build_parser([](parser_builder& p) { + return p.char_class("a-z-"); + }); + + ctx = parser_context{"f", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"-", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"A", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test sequences and literals + auto parser = build_parser([](parser_builder& p) { + return p.literal("") + p.literal(""); + }); + + // Partial matches + auto ctx = parser_context{"", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // No match, since it does not adhere to the grammar + ctx = parser_context{"I am parser", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test choices + auto parser = build_parser([](parser_builder& p) { + return p.literal("") | p.literal(""); + }); + + // Partial matches + auto ctx = parser_context{"", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // No match + ctx = parser_context{"", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test zero_or_more + auto parser = build_parser([](parser_builder& p) { + return p.zero_or_more(p.literal("ab")); + }); + + // Partial matches + auto ctx = parser_context{"a", parse_cache(), false}; + auto result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + ctx = parser_context{"aba", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + // Full match + ctx = parser_context{"ab", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + } + { + // Test one_or_more + auto parser = build_parser([](parser_builder& p) { + return p.one_or_more(p.literal("ab")); + }); + + // Partial matches + auto ctx = parser_context{"a", parse_cache(), false}; + auto result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + ctx = parser_context{"aba", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + + // Full match + ctx = parser_context{"ab", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // No match + ctx = parser_context{"cd", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } +} + +static void test_capture_groups() { + { + auto parser = build_parser([](parser_builder& p) { + return p.literal("") + + p.group("reasoning_content", + p.zero_or_more(~p.literal("") + p.any()) + ) + + p.literal(""); + }); + + std::string input = "I have a thought"; + auto ctx = parser_context{input, parse_cache()}; + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + + auto it = result.groups.find("reasoning_content"); + assert_equals(true, it != result.groups.end()); + assert_equals("I have a thought", std::string(it->second.view(input))); + } + { + auto parser = build_parser([](parser_builder& p) { + return p.literal("") + + p.group("reasoning_content", + p.zero_or_more(~p.literal("") + p.any()) + ) + + p.literal(""); + }); + + std::string input = "I have a "; + auto ctx = parser_context{input, parse_cache(), false}; + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + + auto it = result.groups.find("reasoning_content"); + assert_equals(true, it != result.groups.end()); + assert_equals("I have a ", std::string(it->second.view(input))); + } + { + auto parser = build_parser([](parser_builder& p) { + return p.literal("") + + p.group("reasoning_content", + p.zero_or_more(~p.literal("") + p.any()) + ) + + p.literal("") + + p.group("content", p.zero_or_more(p.any())); + }); + + std::string input = "The user said hello.Hello!"; + auto ctx = parser_context{input, parse_cache(), true}; + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + + auto it = result.groups.find("reasoning_content"); + assert_equals(true, it != result.groups.end()); + assert_equals("The user said hello.", std::string(it->second.view(input))); + + it = result.groups.find("content"); + assert_equals(true, it != result.groups.end()); + assert_equals("Hello!", std::string(it->second.view(input))); + } +} + +static void test_char_class() { + { + // Test common escape sequences + auto parser = build_parser([](parser_builder& p) { + return p.char_class("[\\n\\t\\\\]"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context{"\n", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"\t", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"\\", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{" ", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } + { + // Test escaped dash (literal dash, not a range) + auto parser = build_parser([](parser_builder& p) { + return p.char_class("[a\\-z]"); + }); + + parser_context ctx; + parser_result result; + + ctx = parser_context{"a", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"-", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + ctx = parser_context{"z", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Should NOT match 'b' since \- is a literal dash, not a range + ctx = parser_context{"b", parse_cache()}; + result = parser.parse(ctx); + assert_equals(true, result.is_fail()); + } +} + +static void test_recursive_references() { + auto value_parser = build_parser([](parser_builder& p) { + p.add_rule("number", p.one_or_more(p.char_class("0-9"))); + p.add_rule("list", p.sequence({ + p.literal("["), + p.rule("value"), + p.literal("]") + })); + return p.add_rule("value", p.rule("number") | p.rule("list")); + }); + + parser_context ctx; + parser_result result; + + // Test simple number + ctx = parser_context{"1", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test simple list + ctx = parser_context{"[1]", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test nested list + ctx = parser_context{"[[2]]", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test deeply nested list + ctx = parser_context{"[[[3]]]", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test partial match + ctx = parser_context{"[[", parse_cache(), false}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_success()); + + // Test no match + ctx = parser_context{"[a]", parse_cache(), true}; + result = value_parser.parse(ctx); + assert_equals(true, result.is_fail()); +} + +static void test_optional() { + // Test optional with a match + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + // Full match with optional part present + auto ctx = parser_context{"hello world", parse_cache()}; + auto result = parser.parse(ctx); + assert_equals(true, result.is_success()); + assert_equals((size_t)11, result.end); + + // Full match with optional part absent + ctx = parser_context{"hello", parse_cache(), true}; + result = parser.parse(ctx); + assert_equals(true, result.is_success()); + assert_equals((size_t)5, result.end); + + // Partial match - waiting for more input to determine if optional matches + ctx = parser_context{"hello ", parse_cache(), false}; + result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); +} + +static void test_json_parser() { + auto json = build_parser([](parser_builder & p) { + return p.add_json_rule("json"); + }); + + // Test parsing a simple JSON object + std::string input = R"({"name": "test", "value": 42, "flag": true})"; + parser_context ctx{input, parse_cache()}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); +} + +static void test_complete_example() { + auto parser = build_parser([](parser_builder & p) { + auto space = p.add_rule("space", p.space()); + + auto reasoning = p.add_rule("reasoning", + p.literal("") + space + + p.group("reasoning-content", + p.zero_or_more(~(space + p.literal("")) + p.any())) + + space + p.literal("")); + + auto content = p.add_rule("content", + p.group("content", + p.zero_or_more(~(space + p.literal("")) + p.any()))); + + auto ident_chars = p.add_rule("ident-chars", p.char_class("[a-zA-Z\\-_]")); + auto json = p.add_json_rule("json"); + + auto tool_call_name = p.add_rule("tool-call-name", + p.literal("") + space + + p.group("tool-name", p.one_or_more(~p.literal("") + ident_chars)) + + space + p.literal("")); + + auto tool_call_args = p.add_rule("tool-call-args", + p.literal("") + space + + p.group("tool-args", json) + + space + p.literal("")); + + auto tool_call = p.add_rule("tool-call", + p.literal("") + space + + tool_call_name + space + + tool_call_args + space + + p.literal("")); + + return p.add_rule("root", reasoning + p.optional(content) + p.optional(tool_call)); + }); + + // Test complete input + std::string input = R"(I need to call get_weather with city = New Yorkget_weather{"city": "New York"})"; + parser_context ctx{input, parse_cache()}; + + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + assert_equals(std::string("I need to call get_weather with city = New York"), *result.group("reasoning-content", ctx.input)); + assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); + assert_equals(std::string(R"({"city": "New York"})"), *result.group("tool-args", ctx.input)); + + // Test partial input + input = R"(I need to call get_weather )"; + ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); + + input = R"(I need to call get_weatherget_weather)"; + ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); + + input = R"(I need to call get_weatherget_weatherI need to call get_weatherget_weather{"cit)"; + ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); + assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); + assert_equals(std::string(R"({"cit)"), *result.group("tool-args", ctx.input)); +} + +int main() { + test_partial_parsing(); + test_char_class(); + test_capture_groups(); + test_recursive_references(); + test_optional(); + test_json_parser(); + test_complete_example(); + std::cout << "All tests passed!\n"; + return 0; +} From e6153bb14a0728ed7fcce7dadf54c3c7f670dca0 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Sun, 9 Nov 2025 22:34:24 -0600 Subject: [PATCH 02/34] add virtual destructor to parser_base --- common/chat-parser-combinator.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index f2182980b3bac..a8d8b10fe4fcd 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -13,6 +13,7 @@ class parser_base { public: parser_base(int id) : id_(id) {} + virtual ~parser_base() = default; virtual parser_type type() const = 0; virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; From 4ced9996e65817c0da27acca05d64b3acb6a6a08 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Sun, 9 Nov 2025 22:52:20 -0600 Subject: [PATCH 03/34] fix memory leak from circular references of rules --- common/chat-parser-combinator.cpp | 53 +++++++++++++++++++++++++------ common/chat-parser-combinator.h | 2 ++ 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index a8d8b10fe4fcd..cd915afcd7b88 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -555,11 +555,11 @@ class group_parser : public parser_base { class rule_parser : public parser_base { std::string rule_name_; - std::shared_ptr> rules_; + std::weak_ptr> rules_; public: rule_parser(const std::string & name, std::shared_ptr> rules, int id) - : parser_base(id), rule_name_(name), rules_(std::move(rules)) {} + : parser_base(id), rule_name_(name), rules_(rules) {} parser_type type() const override { return PARSER_RULE; } @@ -569,13 +569,14 @@ class rule_parser : public parser_base { return *cached; } - if (!rules_) { - LOG_ERR("rule_parser::parse called without rule registry\n"); + auto rules = rules_.lock(); + if (!rules) { + LOG_ERR("rule_parser::parse called with expired rule registry\n"); return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } - auto it = rules_->find(rule_name_); - if (it == rules_->end()) { + auto it = rules->find(rule_name_); + if (it == rules->end()) { LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", rule_name_.c_str()); return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } @@ -589,6 +590,32 @@ class rule_parser : public parser_base { } }; +class root_parser : public parser_base { + parser root_; + std::shared_ptr> rules_; + + public: + root_parser(const parser & root, std::shared_ptr> rules, int id) + : parser_base(id), root_(root), rules_(std::move(rules)) {} + + parser_type type() const override { return root_->type(); } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + return root_->parse(ctx, start); + } + + std::string dump() const override { + return root_->dump(); + } + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + root_->assign_ids_internal(next_id); + } +}; + std::optional parser_result::group(const std::string & name, std::string_view input) const { auto it = groups.find(name); if (it == groups.end()) { @@ -632,11 +659,11 @@ parser parser::operator~() const { } parser parser::operator+(const parser & other) const { - return parser(std::shared_ptr(new sequence_parser({*this, other}, -1))); + return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } parser parser::operator|(const parser & other) const { - return parser(std::shared_ptr(new choice_parser({*this, other}, -1))); + return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } parser_base & parser::operator*() const { @@ -684,11 +711,11 @@ parser parser_builder::literal(const std::string & literal) { } parser parser_builder::sequence(std::initializer_list parsers) { - return parser(std::shared_ptr(new sequence_parser(parsers, next_id_++))); + return parser(std::make_shared(parsers, next_id_++)); } parser parser_builder::choice(std::initializer_list parsers) { - return parser(std::shared_ptr(new choice_parser(parsers, next_id_++))); + return parser(std::make_shared(parsers, next_id_++)); } parser parser_builder::one_or_more(const parser & p) { @@ -816,5 +843,11 @@ parser build_parser(const std::function & fn) { parser_builder builder; auto root = fn(builder); builder.assign_ids(root); // Assign IDs to rules that were created with operators + + // Wrap the root parser in a root_parser to own the rules and break circular references + auto rules = builder.rules(); + if (rules && !rules->empty()) { + return parser(std::make_shared(root, rules, -1)); + } return root; } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 72adf523c489a..edebd0bef75db 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -153,6 +153,8 @@ class parser_builder { parser add_json_rule(const std::string & name); void assign_ids(parser & p); + + std::shared_ptr> rules() const { return rules_; } }; parser build_parser(const std::function & fn); From 2a9a13de753dafa7e3caa64b9cd5dafcdf8bda49 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 03:44:21 -0600 Subject: [PATCH 04/34] implement gbnf grammar building --- common/chat-parser-combinator.cpp | 570 +++++++++++++++++++++++--- common/chat-parser-combinator.h | 14 +- tests/test-chat-parser-combinator.cpp | 276 +++++++++++-- 3 files changed, 782 insertions(+), 78 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index cd915afcd7b88..897c4f6f75b4b 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -1,10 +1,17 @@ #include "chat-parser-combinator.h" +#include "json-schema-to-grammar.h" #include "common.h" #include "log.h" +#include + #include #include +class gbnf_visitor; + +static parser json_parser(); + class parser_base { protected: int id_; @@ -18,6 +25,7 @@ class parser_base { virtual parser_type type() const = 0; virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; virtual std::string dump() const = 0; + virtual std::string accept(gbnf_visitor & visitor) const = 0; virtual void assign_ids_internal(int& next_id) { if (id_ == -1) { id_ = next_id++; @@ -28,6 +36,8 @@ class parser_base { class literal_parser : public parser_base { std::string literal_; + friend class gbnf_visitor; + public: literal_parser(const std::string & literal, int id) : parser_base(id), literal_(literal) {} @@ -62,11 +72,15 @@ class literal_parser : public parser_base { std::string dump() const override { return "Literal(" + literal_ + ")"; } + + std::string accept(gbnf_visitor & visitor) const override; }; class sequence_parser : public parser_base { std::vector parsers_; + friend class gbnf_visitor; + public: sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { @@ -125,6 +139,8 @@ class sequence_parser : public parser_base { return "Sequence(" + string_join(parts, ", ") + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const std::vector & parsers() const { return parsers_; } void assign_ids_internal(int& next_id) override { @@ -140,6 +156,8 @@ class sequence_parser : public parser_base { class choice_parser : public parser_base { std::vector parsers_; + friend class gbnf_visitor; + public: choice_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { @@ -187,6 +205,8 @@ class choice_parser : public parser_base { return "Choice(" + string_join(parts, ", ") + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const std::vector & parsers() const { return parsers_; } void assign_ids_internal(int& next_id) override { @@ -202,6 +222,8 @@ class choice_parser : public parser_base { class one_or_more_parser : public parser_base { parser parser_; + friend class gbnf_visitor; + public: one_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -257,6 +279,8 @@ class one_or_more_parser : public parser_base { return "OneOrMore(" + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const parser & child() const { return parser_; } void assign_ids_internal(int& next_id) override { @@ -270,6 +294,8 @@ class one_or_more_parser : public parser_base { class zero_or_more_parser : public parser_base { parser parser_; + friend class gbnf_visitor; + public: zero_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -315,6 +341,8 @@ class zero_or_more_parser : public parser_base { return "ZeroOrMore(" + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const parser & child() const { return parser_; } void assign_ids_internal(int& next_id) override { @@ -328,6 +356,8 @@ class zero_or_more_parser : public parser_base { class optional_parser : public parser_base { parser parser_; + friend class gbnf_visitor; + public: optional_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -359,6 +389,8 @@ class optional_parser : public parser_base { return "Optional(" + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const parser & child() const { return parser_; } void assign_ids_internal(int& next_id) override { @@ -369,9 +401,55 @@ class optional_parser : public parser_base { } }; +class until_parser : public parser_base { + std::string delimiter_; + bool include_spaces_; + parser parser_; + + friend class gbnf_visitor; + + public: + until_parser(const std::string & delimiter, bool include_spaces, int id, parser_builder & builder) + : parser_base(id), delimiter_(delimiter), include_spaces_(include_spaces) { + if (include_spaces) { + auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); + parser_ = builder.zero_or_more(builder.negate(ws + builder.literal(delimiter)) + builder.any()); + } else { + parser_ = builder.zero_or_more(builder.negate(builder.literal(delimiter)) + builder.any()); + } + } + + parser_type type() const override { return PARSER_UNTIL; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto result = parser_->parse(ctx, start); + return ctx.memo.set(id_, start, result); + } + + std::string dump() const override { + return "Until(" + delimiter_ + ")"; + } + + std::string accept(gbnf_visitor & visitor) const override; + + void assign_ids_internal(int& next_id) override { + if (id_ == -1) { + id_ = next_id++; + } + parser_->assign_ids_internal(next_id); + } +}; + class not_parser : public parser_base { parser parser_; + friend class gbnf_visitor; + public: not_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -403,6 +481,8 @@ class not_parser : public parser_base { return "Not(" + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + const parser & child() const { return parser_; } void assign_ids_internal(int& next_id) override { @@ -414,6 +494,8 @@ class not_parser : public parser_base { }; class any_parser : public parser_base { + friend class gbnf_visitor; + public: any_parser(int id) : parser_base(id) {} @@ -438,6 +520,42 @@ class any_parser : public parser_base { std::string dump() const override { return "Any"; } + + std::string accept(gbnf_visitor & visitor) const override; +}; + +class space_parser : public parser_base { + friend class gbnf_visitor; + + public: + space_parser(int id) : parser_base(id) {} + + parser_type type() const override { return PARSER_SPACE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto pos = start; + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + if (c == ' ' || c == '\t' || c == '\n') { + ++pos; + } else { + break; + } + } + + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos)); + } + + std::string dump() const override { + return "Space"; + } + + std::string accept(gbnf_visitor & visitor) const override; }; class char_class_parser : public parser_base { @@ -450,9 +568,12 @@ class char_class_parser : public parser_base { std::string pattern_; std::vector ranges_; + bool negated_; + + friend class gbnf_visitor; public: - char_class_parser(const std::string & classes, int id) : parser_base(id), pattern_(classes) { + char_class_parser(const std::string & classes, int id) : parser_base(id), pattern_(classes), negated_(false) { std::string content = classes; if (content.front() == '[') { content = content.substr(1); @@ -462,6 +583,12 @@ class char_class_parser : public parser_base { content.pop_back(); } + // Check for negation + if (!content.empty() && content.front() == '^') { + negated_ = true; + content = content.substr(1); + } + auto parse_char = [&](size_t pos) -> std::pair { if (content[pos] == '\\' && pos + 1 < content.length()) { char next = content[pos + 1]; @@ -510,24 +637,39 @@ class char_class_parser : public parser_base { return parser_result(PARSER_RESULT_FAIL, start); } + bool matches = false; for (const auto & range : ranges_) { if (range.contains(ctx.input[start])) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + matches = true; + break; } } + // If negated, invert the match result + if (negated_) { + matches = !matches; + } + + if (matches) { + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + } + return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } std::string dump() const override { return "Char(" + pattern_ + ")"; } + + std::string accept(gbnf_visitor & visitor) const override; }; class group_parser : public parser_base { std::string name_; parser parser_; + friend class gbnf_visitor; + public: group_parser(const std::string & name, const parser & parser, int id) : parser_base(id), name_(name), parser_(parser) {} @@ -545,6 +687,8 @@ class group_parser : public parser_base { return "Group(" + name_ + ", " + parser_->dump() + ")"; } + std::string accept(gbnf_visitor & visitor) const override; + void assign_ids_internal(int& next_id) override { if (id_ == -1) { id_ = next_id++; @@ -553,10 +697,36 @@ class group_parser : public parser_base { } }; +class schema_parser : public parser_base { + parser parser_; + std::string name_; + nlohmann::ordered_json schema_; + + friend class gbnf_visitor; + + public: + schema_parser(const parser & parser, const std::string & name, const nlohmann::ordered_json & schema, int id) + : parser_base(id), parser_(parser), name_(name), schema_(schema) {} + + parser_type type() const override { return PARSER_SCHEMA; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + return parser_->parse(ctx, start); + } + + std::string dump() const override { + return "Schema(" + parser_->dump() + ", " + schema_.dump() + ")"; + } + + std::string accept(gbnf_visitor & visitor) const override; +}; + class rule_parser : public parser_base { std::string rule_name_; std::weak_ptr> rules_; + friend class gbnf_visitor; + public: rule_parser(const std::string & name, std::shared_ptr> rules, int id) : parser_base(id), rule_name_(name), rules_(rules) {} @@ -588,12 +758,16 @@ class rule_parser : public parser_base { std::string dump() const override { return "Rule(" + rule_name_ + ")"; } + + std::string accept(gbnf_visitor & visitor) const override; }; class root_parser : public parser_base { parser root_; std::shared_ptr> rules_; + friend class gbnf_visitor; + public: root_parser(const parser & root, std::shared_ptr> rules, int id) : parser_base(id), root_(root), rules_(std::move(rules)) {} @@ -608,6 +782,8 @@ class root_parser : public parser_base { return root_->dump(); } + std::string accept(gbnf_visitor & visitor) const override; + void assign_ids_internal(int& next_id) override { if (id_ == -1) { id_ = next_id++; @@ -616,6 +792,269 @@ class root_parser : public parser_base { } }; +class gbnf_visitor { + common_grammar_builder& builder_; + std::unordered_map rule_name_mapping_; + + public: + gbnf_visitor(common_grammar_builder& builder) : builder_(builder) {} + + private: + // Escape special characters for GBNF literals + static std::string escape_literal(const std::string & s) { + std::string escaped; + for (char c : s) { + switch (c) { + case '\n': escaped += "\\n"; break; + case '\t': escaped += "\\t"; break; + case '\r': escaped += "\\r"; break; + case '\\': escaped += "\\\\"; break; + case '"': escaped += "\\\""; break; + default: escaped += c; break; + } + } + return escaped; + } + + // Escape a single character for use in character classes + static std::string escape_char_class(char c) { + switch (c) { + case '\n': return "\\n"; + case '\t': return "\\t"; + case '\r': return "\\r"; + case '\\': return "\\\\"; + case ']': return "\\]"; + case '-': return "\\-"; + case '^': return "\\^"; + default: return std::string(1, c); + } + } + + // Generate pattern for until() that matches prefixes but prevents full delimiter match + // For "" generates: ( [^<] | "<" [^/] | " alternatives; + + // First alternative: match any character that's not the start of the delimiter + alternatives.push_back("[^" + escape_char_class(delimiter[0]) + "]"); + + // For each prefix, match the prefix followed by a char that's not the next delimiter char + for (size_t i = 1; i < delimiter.length(); ++i) { + std::string prefix = "\"" + escape_literal(delimiter.substr(0, i)) + "\""; + std::string next_char_negated = "[^" + escape_char_class(delimiter[i]) + "]"; + alternatives.push_back(prefix + " " + next_char_negated); + } + + // Combine alternatives with | + std::string result = "("; + for (size_t i = 0; i < alternatives.size(); ++i) { + if (i > 0) { + result += " | "; + } + result += alternatives[i]; + } + result += ")"; + + return result; + } + + // Check if expression needs parentheses + static bool needs_parens(parser_type type) { + return type == PARSER_CHOICE || type == PARSER_SEQUENCE; + } + + public: + std::string visit(const literal_parser & p) { + return "\"" + escape_literal(p.literal_) + "\""; + } + + std::string visit(const sequence_parser & p) { + std::string s; + for (size_t i = 0; i < p.parsers_.size(); ++i) { + if (i > 0) s += " "; + auto child_result = p.parsers_[i]->accept(*this); + s += child_result; + } + return s; + } + + std::string visit(const choice_parser & p) { + std::string s; + for (size_t i = 0; i < p.parsers_.size(); ++i) { + if (i > 0) { + s += " | "; + } + + auto child_type = p.parsers_[i]->type(); + auto child_result = p.parsers_[i]->accept(*this); + + // Parenthesize sequences in choices + if (child_type == PARSER_SEQUENCE) { + s += "(" + child_result + ")"; + } else { + s += child_result; + } + } + return s; + } + + std::string visit(const one_or_more_parser & p) { + auto child_type = p.parser_->type(); + auto child_result = p.parser_->accept(*this); + if (needs_parens(child_type)) { + return "(" + child_result + ")+"; + } + return child_result + "+"; + } + + std::string visit(const zero_or_more_parser & p) { + auto child_type = p.parser_->type(); + auto child_result = p.parser_->accept(*this); + if (needs_parens(child_type)) { + return "(" + child_result + ")*"; + } + return child_result + "*"; + } + + std::string visit(const optional_parser & p) { + auto child_type = p.parser_->type(); + auto child_result = p.parser_->accept(*this); + if (needs_parens(child_type)) { + return "(" + child_result + ")?"; + } + return child_result + "?"; + } + + std::string visit(const until_parser & p) { + // Generate pattern that matches prefixes but prevents full delimiter match + return generate_until_pattern(p.delimiter_) + "*"; + } + + std::string visit(const not_parser &) { + // NOT is tricky in GBNF - for now, emit error + LOG_ERR("NOT operator not directly supported in GBNF generation\n"); + return ""; // This will cause compilation errors, which is intended + } + + std::string visit(const any_parser &) { + // Match any single character + return "[\\x00-\\x{10FFFF}]"; + } + + std::string visit(const space_parser &) { + // Reference the built-in space rule + return "space"; + } + + std::string visit(const char_class_parser & p) { + // Return pattern as-is (already in GBNF format) + return p.pattern_; + } + + std::string visit(const group_parser & p) { + // Groups are transparent - just visit child + return p.parser_->accept(*this); + } + + std::string visit(const schema_parser & p) { + return builder_.add_schema(p.name_, p.schema_); + } + + std::string visit(const rule_parser & p) { + // Return canonical rule reference + auto it = rule_name_mapping_.find(p.rule_name_); + if (it != rule_name_mapping_.end()) { + return it->second; + } + // Fallback to original name if not in mapping (shouldn't happen in valid usage) + return p.rule_name_; + } + + std::string visit(const root_parser & p) { + // Generate named rules first + if (p.rules_) { + for (const auto & [name, rule] : *p.rules_) { + auto rule_body = rule->accept(*this); + auto canonical_name = builder_.add_rule(name, rule_body); + rule_name_mapping_[name] = canonical_name; + } + } + + // Return root body for composition + return p.root_->accept(*this); + } +}; + +// Implement accept() methods for all parser classes +std::string literal_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string sequence_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string choice_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string one_or_more_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string zero_or_more_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string optional_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string until_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string not_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string any_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string space_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string char_class_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string group_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string schema_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string rule_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + +std::string root_parser::accept(gbnf_visitor & visitor) const { + return visitor.visit(*this); +} + std::optional parser_result::group(const std::string & name, std::string_view input) const { auto it = groups.find(name); if (it == groups.end()) { @@ -666,6 +1105,11 @@ parser parser::operator|(const parser & other) const { return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } +parser parser::operator<<(const parser & other) const { + auto ws = parser(std::make_shared(-1)); + return parser(std::make_shared(std::initializer_list{*this, ws, other}, -1)); +} + parser_base & parser::operator*() const { return *ptr; } @@ -702,6 +1146,16 @@ std::string parser::dump() const { return ptr->dump(); } +void parser::build_grammar(common_grammar_builder& builder) const { + gbnf_visitor visitor(builder); + auto result = ptr->accept(visitor); + // The visitor returns the GBNF string for this parser + // root_parser registers its named rules and returns its root body + if (!result.empty()) { + builder.add_rule("root", result); + } +} + parser_builder::parser_builder() : rules_(std::make_shared>()) , next_id_(0) {} @@ -751,7 +1205,15 @@ parser parser_builder::rule(const std::string & name) { } parser parser_builder::space() { - return zero_or_more(char_class("[ \\t\\n\\r]")); + return parser(std::make_shared(next_id_++)); +} + +parser parser_builder::until(const std::string & delimiter, bool include_spaces) { + return parser(std::make_shared(delimiter, include_spaces, next_id_++, *this)); +} + +parser parser_builder::schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema) { + return parser(std::make_shared(p, name, schema, next_id_++)); } parser parser_builder::add_rule(const std::string & name, const parser & p) { @@ -765,89 +1227,99 @@ void parser_builder::assign_ids(parser & p) { } } -parser parser_builder::add_json_rule(const std::string & name) { +parser build_parser(const std::function & fn) { + parser_builder builder; + auto root = fn(builder); + builder.assign_ids(root); // Assign IDs to rules that were created with operators + + // Wrap the root parser in a root_parser to own the rules and break circular references + auto rules = builder.rules(); + if (rules && !rules->empty()) { + return parser(std::make_shared(root, rules, -1)); + } + return root; +} + +static parser json_parser() { + parser_builder builder; + // Whitespace: space, tab, newline, carriage return - auto ws = zero_or_more(char_class("[ \\t\\n\\r]")); + auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); // Number components - auto digit = char_class("[0-9]"); - auto digit1_9 = char_class("[1-9]"); - auto digits = one_or_more(digit); + auto digit = builder.char_class("[0-9]"); + auto digit1_9 = builder.char_class("[1-9]"); + auto digits = builder.one_or_more(digit); // Integer part: 0 or non-zero digit followed by more digits - auto int_part = literal("0") | (digit1_9 + zero_or_more(digit)); + auto int_part = builder.literal("0") | (digit1_9 + builder.zero_or_more(digit)); // Optional fractional part - auto frac = literal(".") + digits; + auto frac = builder.literal(".") + digits; // Optional exponent part - auto exp = (literal("e") | literal("E")) + optional(char_class("[+\\-]")) + digits; + auto exp = (builder.literal("e") | builder.literal("E")) + builder.optional(builder.char_class("[+\\-]")) + digits; // Complete number - auto number = optional(literal("-")) + int_part + optional(frac) + optional(exp); + auto number = builder.optional(builder.literal("-")) + int_part + builder.optional(frac) + builder.optional(exp); - add_rule("json_number", number); + builder.add_rule("json_number", number); // String components - auto hex = char_class("[0-9a-fA-F]"); - auto unicode_escape = literal("\\u") + hex + hex + hex + hex; - auto simple_escape = literal("\\") + char_class("[\"\\\\bfnrt/]"); + auto hex = builder.char_class("[0-9a-fA-F]"); + auto unicode_escape = builder.literal("\\u") + hex + hex + hex + hex; + auto simple_escape = builder.literal("\\") + builder.char_class("[\"\\\\bfnrt/]"); auto escape = simple_escape | unicode_escape; // String character: escape sequence or any char except quote and backslash - auto string_char = escape | (~char_class("[\"\\\\]") + any()); - auto string = literal("\"") + zero_or_more(string_char) + literal("\""); + auto string_char = escape | builder.char_class("[^\"\\\\]"); + auto string = builder.literal("\"") + builder.zero_or_more(string_char) + builder.literal("\""); - add_rule("json_string", string); + builder.add_rule("json_string", string); // Literals - auto true_lit = literal("true"); - auto false_lit = literal("false"); - auto null_lit = literal("null"); + auto true_lit = builder.literal("true"); + auto false_lit = builder.literal("false"); + auto null_lit = builder.literal("null"); // Value - uses forward references for recursive structures - add_rule("json_value", - rule("json_object") | - rule("json_array") | - rule("json_string") | - rule("json_number") | + builder.add_rule("json_value", + builder.rule("json_object") | + builder.rule("json_array") | + builder.rule("json_string") | + builder.rule("json_number") | true_lit | false_lit | null_lit ); // Object: { "key": value, ... } - auto member = rule("json_string") + ws + literal(":") + ws + rule("json_value"); - auto members = member + zero_or_more(ws + literal(",") + ws + member); + auto member = builder.rule("json_string") + ws + builder.literal(":") + ws + builder.rule("json_value"); + auto members = member + builder.zero_or_more(ws + builder.literal(",") + ws + member); // Empty object or object with members - auto object = (literal("{") + ws + literal("}")) | - (literal("{") + ws + members + ws + literal("}")); + auto object = (builder.literal("{") + ws + builder.literal("}")) | + (builder.literal("{") + ws + members + ws + builder.literal("}")); - add_rule("json_object", object); + builder.add_rule("json_object", object); // Array: [ value, ... ] - auto elements = rule("json_value") + zero_or_more(ws + literal(",") + ws + rule("json_value")); + auto elements = builder.rule("json_value") + builder.zero_or_more(ws + builder.literal(",") + ws + builder.rule("json_value")); // Empty array or array with elements - auto array = (literal("[") + ws + literal("]")) | - (literal("[") + ws + elements + ws + literal("]")); + auto array = (builder.literal("[") + ws + builder.literal("]")) | + (builder.literal("[") + ws + elements + ws + builder.literal("]")); - add_rule("json_array", array); + builder.add_rule("json_array", array); - // Register the main rule with the provided name - return add_rule(name, rule("json_value")); -} + // Get the json_value rule as the root + auto root = builder.rule("json_value"); + builder.assign_ids(root); -parser build_parser(const std::function & fn) { - parser_builder builder; - auto root = fn(builder); - builder.assign_ids(root); // Assign IDs to rules that were created with operators + // Wrap in root_parser to own the rules + return parser(std::make_shared(root, builder.rules(), -1)); +} - // Wrap the root parser in a root_parser to own the rules and break circular references - auto rules = builder.rules(); - if (rules && !rules->empty()) { - return parser(std::make_shared(root, rules, -1)); - } - return root; +parser parser_builder::json() { + return json_parser(); } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index edebd0bef75db..1ef3996dc862a 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include @@ -7,6 +9,8 @@ #include #include +struct common_grammar_builder; + enum parser_type { PARSER_LITERAL = 0, PARSER_SEQUENCE = 1, @@ -19,6 +23,9 @@ enum parser_type { PARSER_GROUP = 8, PARSER_RULE = 9, PARSER_OPTIONAL = 10, + PARSER_UNTIL = 11, + PARSER_SPACE = 12, + PARSER_SCHEMA = 13, }; enum parser_result_type { @@ -94,6 +101,7 @@ class parser_base; class sequence_parser; class choice_parser; class parser_builder; +class gbnf_visitor; class parser { std::shared_ptr ptr; @@ -114,6 +122,7 @@ class parser { parser operator~() const; parser operator+(const parser & other) const; parser operator|(const parser & other) const; + parser operator<<(const parser & other) const; parser_base & operator*() const; parser_base * operator->() const; @@ -127,6 +136,7 @@ class parser { parser_type type() const; parser_result parse(parser_context & ctx, size_t start = 0) const; std::string dump() const; + void build_grammar(common_grammar_builder& builder) const; }; class parser_builder { @@ -148,9 +158,11 @@ class parser_builder { parser group(const std::string & name, const parser & p); parser rule(const std::string & name); parser space(); + parser until(const std::string & delimiter, bool include_spaces = true); + parser json(); + parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); parser add_rule(const std::string & name, const parser & p); - parser add_json_rule(const std::string & name); void assign_ids(parser & p); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 55a443aed3e38..1fffbfbf040a0 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -2,6 +2,9 @@ #include #include "chat-parser-combinator.h" +#include "json-schema-to-grammar.h" +#include "nlohmann/json.hpp" +#include "nlohmann/json_fwd.hpp" template static void assert_equals(const std::string_view label, const T & expected, const T & actual) { @@ -365,53 +368,110 @@ static void test_optional() { static void test_json_parser() { auto json = build_parser([](parser_builder & p) { - return p.add_json_rule("json"); + return p.json(); }); - // Test parsing a simple JSON object - std::string input = R"({"name": "test", "value": 42, "flag": true})"; - parser_context ctx{input, parse_cache()}; + { + // Test parsing a simple JSON object + std::string input = R"({"name": "test", "value": 42, "flag": true})"; + parser_context ctx{input, parse_cache()}; - auto result = json.parse(ctx); + auto result = json.parse(ctx); - assert_equals(true, result.is_success()); - assert_equals(input.size(), result.end); + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + } + { + // Test parsing a JSON array with mixed types + std::string input = R"([1, "hello", true, null, 3.14])"; + parser_context ctx{input, parse_cache()}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + } + { + // Test parsing nested JSON with objects and arrays + std::string input = R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})"; + parser_context ctx{input, parse_cache()}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + } + { + // Test partial parsing - incomplete object + std::string input = R"({"name": "test", "value": )"; + parser_context ctx{input, parse_cache(), false}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + } + { + // Test partial parsing - incomplete array + std::string input = R"([1, 2, 3, )"; + parser_context ctx{input, parse_cache(), false}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + } + { + // Test partial parsing - incomplete nested structure + std::string input = R"({"data": {"nested": )"; + parser_context ctx{input, parse_cache(), false}; + + auto result = json.parse(ctx); + + assert_equals(true, result.is_success()); + } } static void test_complete_example() { + // Parser for a fictitious model that outputs: + // + // + // ... reasoning content ... + // + // ... content ... + // + // tool_name + // { ... json args ... } + // + // auto parser = build_parser([](parser_builder & p) { - auto space = p.add_rule("space", p.space()); - auto reasoning = p.add_rule("reasoning", - p.literal("") + space + - p.group("reasoning-content", - p.zero_or_more(~(space + p.literal("")) + p.any())) + - space + p.literal("")); + p.literal("") + << p.group("reasoning-content", p.until("")) + << p.literal("")); auto content = p.add_rule("content", - p.group("content", - p.zero_or_more(~(space + p.literal("")) + p.any()))); + p.group("content", p.until(""))); - auto ident_chars = p.add_rule("ident-chars", p.char_class("[a-zA-Z\\-_]")); - auto json = p.add_json_rule("json"); + auto json = p.json(); auto tool_call_name = p.add_rule("tool-call-name", - p.literal("") + space + - p.group("tool-name", p.one_or_more(~p.literal("") + ident_chars)) + - space + p.literal("")); + p.literal("") + << p.group("tool-name", p.one_or_more(p.char_class("[a-zA-Z\\-_]"))) + << p.literal("")); + + auto schema = nlohmann::ordered_json::parse(R"({"type": "object"})"); auto tool_call_args = p.add_rule("tool-call-args", - p.literal("") + space + - p.group("tool-args", json) + - space + p.literal("")); + p.literal("") + << p.group("tool-args", p.schema(json, "get_weather", schema)) + << p.literal("")); auto tool_call = p.add_rule("tool-call", - p.literal("") + space + - tool_call_name + space + - tool_call_args + space + - p.literal("")); + p.literal("") + << tool_call_name + << tool_call_args + << p.literal("")); - return p.add_rule("root", reasoning + p.optional(content) + p.optional(tool_call)); + return reasoning << p.optional(content) << p.optional(tool_call); }); // Test complete input @@ -457,6 +517,165 @@ static void test_complete_example() { assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); assert_equals(std::string(R"({"cit)"), *result.group("tool-args", ctx.input)); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + + std::cout << "Grammar:\n" << gbnf << "\n"; +} + +static void test_gbnf_generation() { + { + // Test literal + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); + assert_equals(true, gbnf.find("space ::=") != std::string::npos); + } + { + // Test char class + auto parser = build_parser([](parser_builder& p) { + return p.char_class("[a-z]"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= [a-z]") != std::string::npos); + } + { + // Test sequence + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") + p.literal(" ") + p.literal("world"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\" \" \" \"world\"") != std::string::npos); + } + { + // Test choice + auto parser = build_parser([](parser_builder& p) { + return p.literal("cat") | p.literal("dog"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"cat\" | \"dog\"") != std::string::npos); + } + { + // Test one_or_more + auto parser = build_parser([](parser_builder& p) { + return p.one_or_more(p.char_class("[0-9]")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= [0-9]+") != std::string::npos); + } + { + // Test zero_or_more + auto parser = build_parser([](parser_builder& p) { + return p.zero_or_more(p.char_class("[a-z]")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= [a-z]*") != std::string::npos); + } + { + // Test optional + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") + p.optional(p.literal(" world")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\" \" world\"?") != std::string::npos); + } + { + // Test until + auto parser = build_parser([](parser_builder& p) { + return p.until(""); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + // Should generate pattern that prevents matching the full delimiter + assert_equals(true, gbnf.find("root ::= ([^<] | \"<\" [^/] | \"])*") != std::string::npos); + } + { + // Test groups are transparent + auto parser = build_parser([](parser_builder& p) { + return p.group("test", p.literal("hello")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); + } + { + // Test complex expression with parentheses + auto parser = build_parser([](parser_builder& p) { + return p.one_or_more(p.literal("a") | p.literal("b")); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= (\"a\" | \"b\")+") != std::string::npos); + } + { + // Test rule references + auto parser = build_parser([](parser_builder& p) { + auto digit = p.add_rule("digit", p.char_class("[0-9]")); + return p.one_or_more(digit); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + // Should have digit rule defined and referenced + assert_equals(true, gbnf.find("digit ::= [0-9]") != std::string::npos); + assert_equals(true, gbnf.find("root ::= digit+") != std::string::npos); + } + { + // Test escaping in literals + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello\nworld\t!"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + assert_equals(true, gbnf.find("root ::= \"hello\\nworld\\t!\"") != std::string::npos); + } + { + // Test operator<< (whitespace insertion) + auto parser = build_parser([](parser_builder& p) { + return p.literal("hello") << p.literal("world"); + }); + + auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { + parser.build_grammar(const_cast(builder)); + }); + // Should inline the whitespace pattern + assert_equals(true, gbnf.find("\"hello\"") != std::string::npos); + assert_equals(true, gbnf.find("\"world\"") != std::string::npos); + } } int main() { @@ -467,6 +686,7 @@ int main() { test_optional(); test_json_parser(); test_complete_example(); + test_gbnf_generation(); std::cout << "All tests passed!\n"; return 0; } From 228653248e1994bbe82f99c90f7b8a607a34d461 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 04:06:59 -0600 Subject: [PATCH 05/34] remove unused private variable --- common/chat-parser-combinator.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 897c4f6f75b4b..1a340f6162656 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -403,14 +403,13 @@ class optional_parser : public parser_base { class until_parser : public parser_base { std::string delimiter_; - bool include_spaces_; parser parser_; friend class gbnf_visitor; public: until_parser(const std::string & delimiter, bool include_spaces, int id, parser_builder & builder) - : parser_base(id), delimiter_(delimiter), include_spaces_(include_spaces) { + : parser_base(id), delimiter_(delimiter) { if (include_spaces) { auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); parser_ = builder.zero_or_more(builder.negate(ws + builder.literal(delimiter)) + builder.any()); From 3e6662f66c030d2804736e515f554b4e6ed4ac11 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 20:17:09 -0600 Subject: [PATCH 06/34] create a base visitor and implement id assignment as a visitor --- common/chat-parser-combinator.cpp | 461 +++++++++++++++--------------- common/chat-parser-combinator.h | 4 +- 2 files changed, 230 insertions(+), 235 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 1a340f6162656..598fd74a93657 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -8,7 +8,7 @@ #include #include -class gbnf_visitor; +class id_assignment_visitor; static parser json_parser(); @@ -16,7 +16,7 @@ class parser_base { protected: int id_; - void set_id(int id) { id_ = id; } + friend class id_assignment_visitor; public: parser_base(int id) : id_(id) {} @@ -25,19 +25,12 @@ class parser_base { virtual parser_type type() const = 0; virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; virtual std::string dump() const = 0; - virtual std::string accept(gbnf_visitor & visitor) const = 0; - virtual void assign_ids_internal(int& next_id) { - if (id_ == -1) { - id_ = next_id++; - } - } + virtual void accept(parser_visitor & visitor) = 0; }; class literal_parser : public parser_base { std::string literal_; - friend class gbnf_visitor; - public: literal_parser(const std::string & literal, int id) : parser_base(id), literal_(literal) {} @@ -73,14 +66,14 @@ class literal_parser : public parser_base { return "Literal(" + literal_ + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; + + const std::string & literal() const { return literal_; } }; class sequence_parser : public parser_base { std::vector parsers_; - friend class gbnf_visitor; - public: sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { @@ -139,25 +132,14 @@ class sequence_parser : public parser_base { return "Sequence(" + string_join(parts, ", ") + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const std::vector & parsers() const { return parsers_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - for (auto & p : parsers_) { - p->assign_ids_internal(next_id); - } - } }; class choice_parser : public parser_base { std::vector parsers_; - friend class gbnf_visitor; - public: choice_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { @@ -205,25 +187,14 @@ class choice_parser : public parser_base { return "Choice(" + string_join(parts, ", ") + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const std::vector & parsers() const { return parsers_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - for (auto & p : parsers_) { - p->assign_ids_internal(next_id); - } - } }; class one_or_more_parser : public parser_base { parser parser_; - friend class gbnf_visitor; - public: one_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -279,23 +250,14 @@ class one_or_more_parser : public parser_base { return "OneOrMore(" + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } }; class zero_or_more_parser : public parser_base { parser parser_; - friend class gbnf_visitor; - public: zero_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -341,23 +303,14 @@ class zero_or_more_parser : public parser_base { return "ZeroOrMore(" + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } }; class optional_parser : public parser_base { parser parser_; - friend class gbnf_visitor; - public: optional_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -389,24 +342,15 @@ class optional_parser : public parser_base { return "Optional(" + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } }; class until_parser : public parser_base { std::string delimiter_; parser parser_; - friend class gbnf_visitor; - public: until_parser(const std::string & delimiter, bool include_spaces, int id, parser_builder & builder) : parser_base(id), delimiter_(delimiter) { @@ -434,21 +378,16 @@ class until_parser : public parser_base { return "Until(" + delimiter_ + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } + const std::string & delimiter() const { return delimiter_; } + + const parser & child() const { return parser_; } }; class not_parser : public parser_base { parser parser_; - friend class gbnf_visitor; - public: not_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} @@ -480,21 +419,12 @@ class not_parser : public parser_base { return "Not(" + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } - - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } }; class any_parser : public parser_base { - friend class gbnf_visitor; - public: any_parser(int id) : parser_base(id) {} @@ -520,12 +450,10 @@ class any_parser : public parser_base { return "Any"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; }; class space_parser : public parser_base { - friend class gbnf_visitor; - public: space_parser(int id) : parser_base(id) {} @@ -554,7 +482,7 @@ class space_parser : public parser_base { return "Space"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; }; class char_class_parser : public parser_base { @@ -569,8 +497,6 @@ class char_class_parser : public parser_base { std::vector ranges_; bool negated_; - friend class gbnf_visitor; - public: char_class_parser(const std::string & classes, int id) : parser_base(id), pattern_(classes), negated_(false) { std::string content = classes; @@ -660,15 +586,15 @@ class char_class_parser : public parser_base { return "Char(" + pattern_ + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; + + const std::string & pattern() const { return pattern_; } }; class group_parser : public parser_base { std::string name_; parser parser_; - friend class gbnf_visitor; - public: group_parser(const std::string & name, const parser & parser, int id) : parser_base(id), name_(name), parser_(parser) {} @@ -686,14 +612,9 @@ class group_parser : public parser_base { return "Group(" + name_ + ", " + parser_->dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - parser_->assign_ids_internal(next_id); - } + const parser & child() const { return parser_; } }; class schema_parser : public parser_base { @@ -701,8 +622,6 @@ class schema_parser : public parser_base { std::string name_; nlohmann::ordered_json schema_; - friend class gbnf_visitor; - public: schema_parser(const parser & parser, const std::string & name, const nlohmann::ordered_json & schema, int id) : parser_base(id), parser_(parser), name_(name), schema_(schema) {} @@ -717,18 +636,22 @@ class schema_parser : public parser_base { return "Schema(" + parser_->dump() + ", " + schema_.dump() + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; + + const parser & child() const { return parser_; } + + const std::string & name() const { return name_; } + + const nlohmann::ordered_json & schema() const { return schema_; } }; class rule_parser : public parser_base { - std::string rule_name_; + std::string name_; std::weak_ptr> rules_; - friend class gbnf_visitor; - public: - rule_parser(const std::string & name, std::shared_ptr> rules, int id) - : parser_base(id), rule_name_(name), rules_(rules) {} + rule_parser(const std::string & name, const std::shared_ptr> & rules, int id) + : parser_base(id), name_(name), rules_(rules) {} parser_type type() const override { return PARSER_RULE; } @@ -744,9 +667,9 @@ class rule_parser : public parser_base { return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } - auto it = rules->find(rule_name_); + auto it = rules->find(name_); if (it == rules->end()) { - LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", rule_name_.c_str()); + LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", name_.c_str()); return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); } @@ -755,17 +678,19 @@ class rule_parser : public parser_base { } std::string dump() const override { - return "Rule(" + rule_name_ + ")"; + return "Rule(" + name_ + ")"; } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; + + const std::string & name() const { return name_; } }; class root_parser : public parser_base { parser root_; std::shared_ptr> rules_; - friend class gbnf_visitor; + friend class parser_visitor; public: root_parser(const parser & root, std::shared_ptr> rules, int id) @@ -781,23 +706,45 @@ class root_parser : public parser_base { return root_->dump(); } - std::string accept(gbnf_visitor & visitor) const override; + void accept(parser_visitor & visitor) override; - void assign_ids_internal(int& next_id) override { - if (id_ == -1) { - id_ = next_id++; - } - root_->assign_ids_internal(next_id); - } + const parser & root() const { return root_; } + + std::shared_ptr> rules() const { return rules_; } +}; + +// Base visitor class for parser tree traversal +class parser_visitor { + public: + virtual ~parser_visitor() = default; + + virtual void visit(literal_parser & p) = 0; + virtual void visit(sequence_parser & p) = 0; + virtual void visit(choice_parser & p) = 0; + virtual void visit(one_or_more_parser & p) = 0; + virtual void visit(zero_or_more_parser & p) = 0; + virtual void visit(optional_parser & p) = 0; + virtual void visit(until_parser & p) = 0; + virtual void visit(not_parser & p) = 0; + virtual void visit(any_parser & p) = 0; + virtual void visit(space_parser & p) = 0; + virtual void visit(char_class_parser & p) = 0; + virtual void visit(group_parser & p) = 0; + virtual void visit(schema_parser & p) = 0; + virtual void visit(rule_parser & p) = 0; + virtual void visit(root_parser & p) = 0; }; -class gbnf_visitor { +class gbnf_visitor : public parser_visitor { common_grammar_builder& builder_; std::unordered_map rule_name_mapping_; + std::string current_result_; public: gbnf_visitor(common_grammar_builder& builder) : builder_(builder) {} + const std::string& result() const { return current_result_; } + private: // Escape special characters for GBNF literals static std::string escape_literal(const std::string & s) { @@ -872,187 +819,235 @@ class gbnf_visitor { } public: - std::string visit(const literal_parser & p) { - return "\"" + escape_literal(p.literal_) + "\""; + void visit(literal_parser & p) override { + current_result_ = "\"" + escape_literal(p.literal()) + "\""; } - std::string visit(const sequence_parser & p) { + void visit(sequence_parser & p) override { std::string s; - for (size_t i = 0; i < p.parsers_.size(); ++i) { - if (i > 0) s += " "; - auto child_result = p.parsers_[i]->accept(*this); - s += child_result; + for (const auto & child : p.parsers()) { + if (!s.empty()) { + s += " "; + } + child->accept(*this); + s += current_result_; } - return s; + current_result_ = s; } - std::string visit(const choice_parser & p) { + void visit(choice_parser & p) override { std::string s; - for (size_t i = 0; i < p.parsers_.size(); ++i) { - if (i > 0) { + for (const auto & child : p.parsers()) { + if (!s.empty()) { s += " | "; } - auto child_type = p.parsers_[i]->type(); - auto child_result = p.parsers_[i]->accept(*this); + child->accept(*this); // Parenthesize sequences in choices - if (child_type == PARSER_SEQUENCE) { - s += "(" + child_result + ")"; + if (child->type() == PARSER_SEQUENCE) { + s += "(" + current_result_ + ")"; } else { - s += child_result; + s += current_result_; } } - return s; + current_result_ = s; } - std::string visit(const one_or_more_parser & p) { - auto child_type = p.parser_->type(); - auto child_result = p.parser_->accept(*this); - if (needs_parens(child_type)) { - return "(" + child_result + ")+"; + void visit(one_or_more_parser & p) override { + p.child()->accept(*this); + if (needs_parens(p.child()->type())) { + current_result_ = "(" + current_result_ + ")+"; + } else { + current_result_ = current_result_ + "+"; } - return child_result + "+"; } - std::string visit(const zero_or_more_parser & p) { - auto child_type = p.parser_->type(); - auto child_result = p.parser_->accept(*this); - if (needs_parens(child_type)) { - return "(" + child_result + ")*"; + void visit(zero_or_more_parser & p) override { + p.child()->accept(*this); + if (needs_parens(p.child()->type())) { + current_result_ = "(" + current_result_ + ")*"; + } else { + current_result_ = current_result_ + "*"; } - return child_result + "*"; } - std::string visit(const optional_parser & p) { - auto child_type = p.parser_->type(); - auto child_result = p.parser_->accept(*this); - if (needs_parens(child_type)) { - return "(" + child_result + ")?"; + void visit(optional_parser & p) override { + p.child()->accept(*this); + if (needs_parens(p.child()->type())) { + current_result_ = "(" + current_result_ + ")?"; + } else { + current_result_ = current_result_ + "?"; } - return child_result + "?"; } - std::string visit(const until_parser & p) { + void visit(until_parser & p) override { // Generate pattern that matches prefixes but prevents full delimiter match - return generate_until_pattern(p.delimiter_) + "*"; + current_result_ = generate_until_pattern(p.delimiter()) + "*"; } - std::string visit(const not_parser &) { + void visit(not_parser &) override { // NOT is tricky in GBNF - for now, emit error LOG_ERR("NOT operator not directly supported in GBNF generation\n"); - return ""; // This will cause compilation errors, which is intended + current_result_ = ""; } - std::string visit(const any_parser &) { + void visit(any_parser &) override { // Match any single character - return "[\\x00-\\x{10FFFF}]"; + current_result_ = "[\\x00-\\x{10FFFF}]"; } - std::string visit(const space_parser &) { + void visit(space_parser &) override { // Reference the built-in space rule - return "space"; + current_result_ = "space"; } - std::string visit(const char_class_parser & p) { + void visit(char_class_parser & p) override { // Return pattern as-is (already in GBNF format) - return p.pattern_; + current_result_ = p.pattern(); } - std::string visit(const group_parser & p) { + void visit(group_parser & p) override { // Groups are transparent - just visit child - return p.parser_->accept(*this); + p.child()->accept(*this); } - std::string visit(const schema_parser & p) { - return builder_.add_schema(p.name_, p.schema_); + void visit(schema_parser & p) override { + current_result_ = builder_.add_schema(p.name(), p.schema()); } - std::string visit(const rule_parser & p) { + void visit(rule_parser & p) override { // Return canonical rule reference - auto it = rule_name_mapping_.find(p.rule_name_); + auto it = rule_name_mapping_.find(p.name()); if (it != rule_name_mapping_.end()) { - return it->second; + current_result_ = it->second; + } else { + // Fallback to original name if not in mapping (shouldn't happen in valid usage) + current_result_ = p.name(); } - // Fallback to original name if not in mapping (shouldn't happen in valid usage) - return p.rule_name_; } - std::string visit(const root_parser & p) { + void visit(root_parser & p) override { // Generate named rules first - if (p.rules_) { - for (const auto & [name, rule] : *p.rules_) { - auto rule_body = rule->accept(*this); + auto rules = p.rules(); + if (rules) { + for (const auto & [name, rule] : *rules) { + rule->accept(*this); + auto rule_body = current_result_; auto canonical_name = builder_.add_rule(name, rule_body); rule_name_mapping_[name] = canonical_name; } } // Return root body for composition - return p.root_->accept(*this); + p.root()->accept(*this); } }; -// Implement accept() methods for all parser classes -std::string literal_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} +// ID assignment visitor for assigning unique IDs to parsers +class id_assignment_visitor : public parser_visitor { + int & next_id_; -std::string sequence_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + public: + id_assignment_visitor(int & next_id) : next_id_(next_id) {} -std::string choice_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void assign_id(parser_base & p) { + if (p.id_ == -1) { + p.id_ = next_id_++; + } + } -std::string one_or_more_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(literal_parser & p) override { + assign_id(p); + } -std::string zero_or_more_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(any_parser & p) override { + assign_id(p); + } -std::string optional_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(space_parser & p) override { + assign_id(p); + } -std::string until_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(char_class_parser & p) override { + assign_id(p); + } -std::string not_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(schema_parser & p) override { + assign_id(p); + } -std::string any_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(rule_parser & p) override { + assign_id(p); + } -std::string space_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + // Composite parsers - assign ID and traverse children + void visit(sequence_parser & p) override { + assign_id(p); + for (const auto & child : p.parsers()) { + child->accept(*this); + } + } -std::string char_class_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(choice_parser & p) override { + assign_id(p); + for (const auto & child : p.parsers()) { + child->accept(*this); + } + } -std::string group_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(one_or_more_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } -std::string schema_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(zero_or_more_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } -std::string rule_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(optional_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } -std::string root_parser::accept(gbnf_visitor & visitor) const { - return visitor.visit(*this); -} + void visit(until_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } + + void visit(not_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } + + void visit(group_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } + + void visit(root_parser & p) override { + assign_id(p); + p.root()->accept(*this); + } +}; + +// Implement accept() methods for all parser classes +void literal_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void sequence_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void choice_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void one_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void zero_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void optional_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void until_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void not_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void any_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void space_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void char_class_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void group_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void schema_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void rule_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void root_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } std::optional parser_result::group(const std::string & name, std::string_view input) const { auto it = groups.find(name); @@ -1145,11 +1140,10 @@ std::string parser::dump() const { return ptr->dump(); } -void parser::build_grammar(common_grammar_builder& builder) const { +void parser::build_grammar(common_grammar_builder& builder) { gbnf_visitor visitor(builder); - auto result = ptr->accept(visitor); - // The visitor returns the GBNF string for this parser - // root_parser registers its named rules and returns its root body + ptr->accept(visitor); + auto result = visitor.result(); if (!result.empty()) { builder.add_rule("root", result); } @@ -1222,7 +1216,8 @@ parser parser_builder::add_rule(const std::string & name, const parser & p) { void parser_builder::assign_ids(parser & p) { if (p.ptr) { - p.ptr->assign_ids_internal(next_id_); + id_assignment_visitor visitor(next_id_); + p.ptr->accept(visitor); } } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 1ef3996dc862a..f0cb1d24ff7bb 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -101,7 +101,7 @@ class parser_base; class sequence_parser; class choice_parser; class parser_builder; -class gbnf_visitor; +class parser_visitor; class parser { std::shared_ptr ptr; @@ -136,7 +136,7 @@ class parser { parser_type type() const; parser_result parse(parser_context & ctx, size_t start = 0) const; std::string dump() const; - void build_grammar(common_grammar_builder& builder) const; + void build_grammar(common_grammar_builder& builder); }; class parser_builder { From 76cf0b5b6197d427e3c48aa4d24f549a3d3a4167 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 20:21:04 -0600 Subject: [PATCH 07/34] fix const ref for grammar builder --- common/chat-parser-combinator.cpp | 6 +-- common/chat-parser-combinator.h | 3 +- tests/test-chat-parser-combinator.cpp | 69 ++++++++++++++++----------- 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 598fd74a93657..aff72b67fd68d 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -736,12 +736,12 @@ class parser_visitor { }; class gbnf_visitor : public parser_visitor { - common_grammar_builder& builder_; + const common_grammar_builder & builder_; std::unordered_map rule_name_mapping_; std::string current_result_; public: - gbnf_visitor(common_grammar_builder& builder) : builder_(builder) {} + gbnf_visitor(const common_grammar_builder & builder) : builder_(builder) {} const std::string& result() const { return current_result_; } @@ -1140,7 +1140,7 @@ std::string parser::dump() const { return ptr->dump(); } -void parser::build_grammar(common_grammar_builder& builder) { +void parser::build_grammar(const common_grammar_builder & builder) { gbnf_visitor visitor(builder); ptr->accept(visitor); auto result = visitor.result(); diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index f0cb1d24ff7bb..e56a6adf24c17 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -136,7 +136,8 @@ class parser { parser_type type() const; parser_result parse(parser_context & ctx, size_t start = 0) const; std::string dump() const; - void build_grammar(common_grammar_builder& builder); + + void build_grammar(const common_grammar_builder & builder); }; class parser_builder { diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 1fffbfbf040a0..e4f637af9f797 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -518,8 +518,8 @@ static void test_complete_example() { assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); assert_equals(std::string(R"({"cit)"), *result.group("tool-args", ctx.input)); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); std::cout << "Grammar:\n" << gbnf << "\n"; @@ -532,9 +532,10 @@ static void test_gbnf_generation() { return p.literal("hello"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); assert_equals(true, gbnf.find("space ::=") != std::string::npos); } @@ -544,9 +545,10 @@ static void test_gbnf_generation() { return p.char_class("[a-z]"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= [a-z]") != std::string::npos); } { @@ -555,9 +557,10 @@ static void test_gbnf_generation() { return p.literal("hello") + p.literal(" ") + p.literal("world"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\" \" \" \"world\"") != std::string::npos); } { @@ -566,9 +569,10 @@ static void test_gbnf_generation() { return p.literal("cat") | p.literal("dog"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"cat\" | \"dog\"") != std::string::npos); } { @@ -577,9 +581,10 @@ static void test_gbnf_generation() { return p.one_or_more(p.char_class("[0-9]")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= [0-9]+") != std::string::npos); } { @@ -588,9 +593,10 @@ static void test_gbnf_generation() { return p.zero_or_more(p.char_class("[a-z]")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= [a-z]*") != std::string::npos); } { @@ -599,9 +605,10 @@ static void test_gbnf_generation() { return p.literal("hello") + p.optional(p.literal(" world")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\" \" world\"?") != std::string::npos); } { @@ -610,9 +617,10 @@ static void test_gbnf_generation() { return p.until(""); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + // Should generate pattern that prevents matching the full delimiter assert_equals(true, gbnf.find("root ::= ([^<] | \"<\" [^/] | \"])*") != std::string::npos); } @@ -622,9 +630,10 @@ static void test_gbnf_generation() { return p.group("test", p.literal("hello")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); } { @@ -633,9 +642,10 @@ static void test_gbnf_generation() { return p.one_or_more(p.literal("a") | p.literal("b")); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= (\"a\" | \"b\")+") != std::string::npos); } { @@ -645,9 +655,10 @@ static void test_gbnf_generation() { return p.one_or_more(digit); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + // Should have digit rule defined and referenced assert_equals(true, gbnf.find("digit ::= [0-9]") != std::string::npos); assert_equals(true, gbnf.find("root ::= digit+") != std::string::npos); @@ -658,9 +669,10 @@ static void test_gbnf_generation() { return p.literal("hello\nworld\t!"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + assert_equals(true, gbnf.find("root ::= \"hello\\nworld\\t!\"") != std::string::npos); } { @@ -669,9 +681,10 @@ static void test_gbnf_generation() { return p.literal("hello") << p.literal("world"); }); - auto gbnf = ::build_grammar([&](const common_grammar_builder& builder) { - parser.build_grammar(const_cast(builder)); + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); }); + // Should inline the whitespace pattern assert_equals(true, gbnf.find("\"hello\"") != std::string::npos); assert_equals(true, gbnf.find("\"world\"") != std::string::npos); From 9c7b3e8bcf57ea416d21a90d219c39d06e16b426 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 20:33:26 -0600 Subject: [PATCH 08/34] clean up types, friend classes, and class declarations --- common/chat-parser-combinator.cpp | 76 +++++++++++++++---------------- common/chat-parser-combinator.h | 39 ++-------------- 2 files changed, 43 insertions(+), 72 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index aff72b67fd68d..56215302df99e 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -8,7 +8,24 @@ #include #include -class id_assignment_visitor; +enum parser_type { + PARSER_LITERAL = 0, + PARSER_SEQUENCE = 1, + PARSER_CHOICE = 2, + PARSER_ZERO_OR_MORE = 3, + PARSER_ONE_OR_MORE = 4, + PARSER_NOT = 5, + PARSER_ANY = 6, + PARSER_CHAR_CLASS = 7, + PARSER_GROUP = 8, + PARSER_RULE = 9, + PARSER_OPTIONAL = 10, + PARSER_UNTIL = 11, + PARSER_SPACE = 12, + PARSER_SCHEMA = 13, +}; + +class parser_visitor; static parser json_parser(); @@ -16,12 +33,13 @@ class parser_base { protected: int id_; - friend class id_assignment_visitor; - public: parser_base(int id) : id_(id) {} virtual ~parser_base() = default; + int id() const { return id_; } + void set_id(int id) { id_ = id; } + virtual parser_type type() const = 0; virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; virtual std::string dump() const = 0; @@ -77,9 +95,10 @@ class sequence_parser : public parser_base { public: sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { - if (p.is_sequence()) { + if (p->type() == PARSER_SEQUENCE) { // Flatten sequences - for (const auto & embedded : p.to_sequence()->parsers()) { + auto seq = std::static_pointer_cast(p.ptr()); + for (const auto & embedded : seq->parsers()) { parsers_.push_back(embedded); } } else { @@ -143,9 +162,10 @@ class choice_parser : public parser_base { public: choice_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { - if (p.is_choice()) { + if (p->type() == PARSER_CHOICE) { // Flatten choices - for (const auto & embedded : p.to_choice()->parsers()) { + auto choice = std::static_pointer_cast(p.ptr()); + for (const auto & embedded : choice->parsers()) { parsers_.push_back(embedded); } } else { @@ -952,8 +972,8 @@ class id_assignment_visitor : public parser_visitor { id_assignment_visitor(int & next_id) : next_id_(next_id) {} void assign_id(parser_base & p) { - if (p.id_ == -1) { - p.id_ = next_id_++; + if (p.id() == -1) { + p.set_id(next_id_++); } } @@ -1085,7 +1105,7 @@ void parse_cache::clear() { parser::parser() {} -parser::parser(std::shared_ptr parser) : ptr(std::move(parser)) {} +parser::parser(std::shared_ptr parser) : ptr_(std::move(parser)) {} parser parser::operator~() const { return parser(std::make_shared(*this, -1)); @@ -1105,44 +1125,24 @@ parser parser::operator<<(const parser & other) const { } parser_base & parser::operator*() const { - return *ptr; + return *ptr_; } parser_base * parser::operator->() const { - return ptr.get(); -} - -bool parser::is_sequence() const { - return ptr->type() == PARSER_SEQUENCE; -} - -std::shared_ptr parser::to_sequence() const { - return std::dynamic_pointer_cast(ptr); -} - -bool parser::is_choice() const { - return ptr->type() == PARSER_CHOICE; -} - -std::shared_ptr parser::to_choice() const { - return std::dynamic_pointer_cast(ptr); -} - -parser_type parser::type() const { - return ptr->type(); + return ptr_.get(); } parser_result parser::parse(parser_context & ctx, size_t start) const { - return ptr->parse(ctx, start); + return ptr_->parse(ctx, start); } std::string parser::dump() const { - return ptr->dump(); + return ptr_->dump(); } -void parser::build_grammar(const common_grammar_builder & builder) { +void parser::build_grammar(const common_grammar_builder & builder) const { gbnf_visitor visitor(builder); - ptr->accept(visitor); + ptr_->accept(visitor); auto result = visitor.result(); if (!result.empty()) { builder.add_rule("root", result); @@ -1215,9 +1215,9 @@ parser parser_builder::add_rule(const std::string & name, const parser & p) { } void parser_builder::assign_ids(parser & p) { - if (p.ptr) { + if (p.ptr()) { id_assignment_visitor visitor(next_id_); - p.ptr->accept(visitor); + p.ptr()->accept(visitor); } } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index e56a6adf24c17..6c7b86d4e04e9 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -11,23 +11,6 @@ struct common_grammar_builder; -enum parser_type { - PARSER_LITERAL = 0, - PARSER_SEQUENCE = 1, - PARSER_CHOICE = 2, - PARSER_ZERO_OR_MORE = 3, - PARSER_ONE_OR_MORE = 4, - PARSER_NOT = 5, - PARSER_ANY = 6, - PARSER_CHAR_CLASS = 7, - PARSER_GROUP = 8, - PARSER_RULE = 9, - PARSER_OPTIONAL = 10, - PARSER_UNTIL = 11, - PARSER_SPACE = 12, - PARSER_SCHEMA = 13, -}; - enum parser_result_type { PARSER_RESULT_FAIL = 0, PARSER_RESULT_NEED_MORE_INPUT = 1, @@ -89,8 +72,6 @@ class parse_cache { void clear(); }; -class parser; - struct parser_context { std::string_view input; parse_cache memo; @@ -98,15 +79,9 @@ struct parser_context { }; class parser_base; -class sequence_parser; -class choice_parser; -class parser_builder; -class parser_visitor; class parser { - std::shared_ptr ptr; - - friend class parser_builder; + std::shared_ptr ptr_; public: parser(); @@ -114,7 +89,7 @@ class parser { parser(const parser & other) = default; parser & operator=(const parser & other) { if (this != &other) { - ptr = other.ptr; + ptr_ = other.ptr_; } return *this; } @@ -127,17 +102,13 @@ class parser { parser_base & operator*() const; parser_base * operator->() const; - bool is_sequence() const; - std::shared_ptr to_sequence() const; - - bool is_choice() const; - std::shared_ptr to_choice() const; + std::shared_ptr ptr() const { return ptr_; } - parser_type type() const; parser_result parse(parser_context & ctx, size_t start = 0) const; + std::string dump() const; - void build_grammar(const common_grammar_builder & builder); + void build_grammar(const common_grammar_builder & builder) const; }; class parser_builder { From f02e2b06fa0ef29aa647060ec74f7f0c6224606b Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 20:47:43 -0600 Subject: [PATCH 09/34] remove builder usage from until_parser --- common/chat-parser-combinator.cpp | 84 ++++++++++++++++--------------- common/chat-parser-combinator.h | 2 +- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 56215302df99e..0081606516d40 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -367,44 +367,6 @@ class optional_parser : public parser_base { const parser & child() const { return parser_; } }; -class until_parser : public parser_base { - std::string delimiter_; - parser parser_; - - public: - until_parser(const std::string & delimiter, bool include_spaces, int id, parser_builder & builder) - : parser_base(id), delimiter_(delimiter) { - if (include_spaces) { - auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); - parser_ = builder.zero_or_more(builder.negate(ws + builder.literal(delimiter)) + builder.any()); - } else { - parser_ = builder.zero_or_more(builder.negate(builder.literal(delimiter)) + builder.any()); - } - } - - parser_type type() const override { return PARSER_UNTIL; } - - parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto result = parser_->parse(ctx, start); - return ctx.memo.set(id_, start, result); - } - - std::string dump() const override { - return "Until(" + delimiter_ + ")"; - } - - void accept(parser_visitor & visitor) override; - - const std::string & delimiter() const { return delimiter_; } - - const parser & child() const { return parser_; } -}; - class not_parser : public parser_base { parser parser_; @@ -637,6 +599,48 @@ class group_parser : public parser_base { const parser & child() const { return parser_; } }; +class until_parser : public parser_base { + std::string delimiter_; + parser parser_; + + public: + until_parser(const std::string & delimiter, bool consume_spaces, int id) + : parser_base(id), delimiter_(delimiter) { + + auto delim = parser(std::make_shared(delimiter, -1)); + auto any = parser(std::make_shared(-1)); + + if (consume_spaces) { + auto ws = parser(std::make_shared(-1)); + parser_ = parser(std::make_shared(~(ws + delim) + any, -1)); + } else { + parser_ = parser(std::make_shared(~delim + any, -1)); + } + } + + parser_type type() const override { return PARSER_UNTIL; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + auto cached = ctx.memo.get(id_, start); + if (cached != std::nullopt) { + return *cached; + } + + auto result = parser_->parse(ctx, start); + return ctx.memo.set(id_, start, result); + } + + std::string dump() const override { + return "Until(" + delimiter_ + ")"; + } + + void accept(parser_visitor & visitor) override; + + const std::string & delimiter() const { return delimiter_; } + + const parser & child() const { return parser_; } +}; + class schema_parser : public parser_base { parser parser_; std::string name_; @@ -1201,8 +1205,8 @@ parser parser_builder::space() { return parser(std::make_shared(next_id_++)); } -parser parser_builder::until(const std::string & delimiter, bool include_spaces) { - return parser(std::make_shared(delimiter, include_spaces, next_id_++, *this)); +parser parser_builder::until(const std::string & delimiter, bool consume_spaces) { + return parser(std::make_shared(delimiter, consume_spaces, next_id_++)); } parser parser_builder::schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema) { diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 6c7b86d4e04e9..e5f508d963011 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -130,7 +130,7 @@ class parser_builder { parser group(const std::string & name, const parser & p); parser rule(const std::string & name); parser space(); - parser until(const std::string & delimiter, bool include_spaces = true); + parser until(const std::string & delimiter, bool consume_spaces = true); parser json(); parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); From 66cf038a37596bee9771142d9900c7aea35c0b32 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 21:08:17 -0600 Subject: [PATCH 10/34] Use a counter class to help assign rule ids --- common/chat-parser-combinator.cpp | 74 +++++++++++++++---------------- common/chat-parser-combinator.h | 10 ++++- 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 0081606516d40..998e5511d857a 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -27,8 +27,6 @@ enum parser_type { class parser_visitor; -static parser json_parser(); - class parser_base { protected: int id_; @@ -970,14 +968,14 @@ class gbnf_visitor : public parser_visitor { // ID assignment visitor for assigning unique IDs to parsers class id_assignment_visitor : public parser_visitor { - int & next_id_; + std::shared_ptr counter_; public: - id_assignment_visitor(int & next_id) : next_id_(next_id) {} + id_assignment_visitor(const std::shared_ptr & counter) : counter_(counter) {} void assign_id(parser_base & p) { if (p.id() == -1) { - p.set_id(next_id_++); + p.set_id(counter_->next()); } } @@ -1155,62 +1153,66 @@ void parser::build_grammar(const common_grammar_builder & builder) const { parser_builder::parser_builder() : rules_(std::make_shared>()) - , next_id_(0) {} + , counter_(std::make_shared(0)) {} + +parser_builder::parser_builder(std::shared_ptr counter) + : rules_(std::make_shared>()) + , counter_(std::move(counter)) {} parser parser_builder::literal(const std::string & literal) { - return parser(std::make_shared(literal, next_id_++)); + return parser(std::make_shared(literal, counter_->next())); } parser parser_builder::sequence(std::initializer_list parsers) { - return parser(std::make_shared(parsers, next_id_++)); + return parser(std::make_shared(parsers, counter_->next())); } parser parser_builder::choice(std::initializer_list parsers) { - return parser(std::make_shared(parsers, next_id_++)); + return parser(std::make_shared(parsers, counter_->next())); } parser parser_builder::one_or_more(const parser & p) { - return parser(std::make_shared(p, next_id_++)); + return parser(std::make_shared(p, counter_->next())); } parser parser_builder::zero_or_more(const parser & p) { - return parser(std::make_shared(p, next_id_++)); + return parser(std::make_shared(p, counter_->next())); } parser parser_builder::optional(const parser & p) { - return parser(std::make_shared(p, next_id_++)); + return parser(std::make_shared(p, counter_->next())); } parser parser_builder::negate(const parser & p) { - return parser(std::make_shared(p, next_id_++)); + return parser(std::make_shared(p, counter_->next())); } parser parser_builder::any() { - return parser(std::make_shared(next_id_++)); + return parser(std::make_shared(counter_->next())); } parser parser_builder::char_class(const std::string & classes) { - return parser(std::make_shared(classes, next_id_++)); + return parser(std::make_shared(classes, counter_->next())); } parser parser_builder::group(const std::string & name, const parser & p) { - return parser(std::make_shared(name, p, next_id_++)); + return parser(std::make_shared(name, p, counter_->next())); } parser parser_builder::rule(const std::string & name) { - return parser(std::make_shared(name, rules_, next_id_++)); + return parser(std::make_shared(name, rules_, counter_->next())); } parser parser_builder::space() { - return parser(std::make_shared(next_id_++)); + return parser(std::make_shared(counter_->next())); } parser parser_builder::until(const std::string & delimiter, bool consume_spaces) { - return parser(std::make_shared(delimiter, consume_spaces, next_id_++)); + return parser(std::make_shared(delimiter, consume_spaces, counter_->next())); } parser parser_builder::schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema) { - return parser(std::make_shared(p, name, schema, next_id_++)); + return parser(std::make_shared(p, name, schema, counter_->next())); } parser parser_builder::add_rule(const std::string & name, const parser & p) { @@ -1220,7 +1222,7 @@ parser parser_builder::add_rule(const std::string & name, const parser & p) { void parser_builder::assign_ids(parser & p) { if (p.ptr()) { - id_assignment_visitor visitor(next_id_); + id_assignment_visitor visitor(counter_); p.ptr()->accept(visitor); } } @@ -1238,8 +1240,8 @@ parser build_parser(const std::function & fn) { return root; } -static parser json_parser() { - parser_builder builder; +static parser json_parser(std::shared_ptr counter) { + parser_builder builder(std::move(counter)); // Whitespace: space, tab, newline, carriage return auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); @@ -1280,17 +1282,6 @@ static parser json_parser() { auto false_lit = builder.literal("false"); auto null_lit = builder.literal("null"); - // Value - uses forward references for recursive structures - builder.add_rule("json_value", - builder.rule("json_object") | - builder.rule("json_array") | - builder.rule("json_string") | - builder.rule("json_number") | - true_lit | - false_lit | - null_lit - ); - // Object: { "key": value, ... } auto member = builder.rule("json_string") + ws + builder.literal(":") + ws + builder.rule("json_value"); auto members = member + builder.zero_or_more(ws + builder.literal(",") + ws + member); @@ -1310,14 +1301,21 @@ static parser json_parser() { builder.add_rule("json_array", array); - // Get the json_value rule as the root - auto root = builder.rule("json_value"); - builder.assign_ids(root); + // Value - uses forward references for recursive structures + auto root = builder.add_rule("json_value", + builder.rule("json_object") | + builder.rule("json_array") | + builder.rule("json_string") | + builder.rule("json_number") | + true_lit | + false_lit | + null_lit + ); // Wrap in root_parser to own the rules return parser(std::make_shared(root, builder.rules(), -1)); } parser parser_builder::json() { - return json_parser(); + return json_parser(counter_); } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index e5f508d963011..56a522394bf75 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -111,12 +111,20 @@ class parser { void build_grammar(const common_grammar_builder & builder) const; }; +class parser_id_counter { + int next_id_; + public: + parser_id_counter(int start) : next_id_(start) {} + int next() { return next_id_++; } +}; + class parser_builder { std::shared_ptr> rules_; - int next_id_; + std::shared_ptr counter_; public: parser_builder(); + parser_builder(std::shared_ptr counter); parser literal(const std::string & literal); parser sequence(std::initializer_list parsers); From 2b3caefde82282c5933c9ad53014c336184cedfb Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 21:24:05 -0600 Subject: [PATCH 11/34] cache everything --- common/chat-parser-combinator.cpp | 419 ++++++++++++++---------------- common/chat-parser-combinator.h | 2 + 2 files changed, 194 insertions(+), 227 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 998e5511d857a..fb3b0fb083a0a 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -53,29 +53,26 @@ class literal_parser : public parser_base { parser_type type() const override { return PARSER_LITERAL; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto pos = start; - for (auto i = 0u; i < literal_.size(); ++i) { - if (pos >= ctx.input.size()) { - if (ctx.input_is_complete) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return ctx.memo.cached(id_, start, [&]() { + auto pos = start; + for (auto i = 0u; i < literal_.size(); ++i) { + if (pos >= ctx.input.size()) { + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + if (i > 0) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + return parser_result(PARSER_RESULT_FAIL, start); } - if (i > 0) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + if (ctx.input[pos] != literal_[i]) { + return parser_result(PARSER_RESULT_FAIL, start); } - return parser_result(PARSER_RESULT_FAIL, start); - } - if (ctx.input[pos] != literal_[i]) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + ++pos; } - ++pos; - } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + }); } std::string dump() const override { @@ -108,36 +105,33 @@ class sequence_parser : public parser_base { parser_type type() const override { return PARSER_SEQUENCE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - std::unordered_map groups; - - auto pos = start; - for (const auto & p : parsers_) { - auto result = p->parse(ctx, pos); - - // Copy groups - groups.insert(result.groups.begin(), result.groups.end()); + return ctx.memo.cached(id_, start, [&]() { + std::unordered_map groups; + + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); + + // Copy groups + groups.insert(result.groups.begin(), result.groups.end()); + + if (result.is_fail()) { + if (result.end >= ctx.input.size() && !ctx.input_is_complete) { + // If we fail because we don't have enough input, then return success + return parser_result(PARSER_RESULT_SUCCESS, start, result.end, groups); + } + return parser_result(PARSER_RESULT_FAIL, start, result.end, groups); + } - if (result.is_fail()) { - if (result.end >= ctx.input.size() && !ctx.input_is_complete) { - // If we fail because we don't have enough input, then return success - return parser_result(PARSER_RESULT_SUCCESS, start, result.end, groups); + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end, groups); } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start, result.end, groups)); - } - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end, groups); + pos = result.end; } - pos = result.end; - } - - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + }); } std::string dump() const override { @@ -175,25 +169,22 @@ class choice_parser : public parser_base { parser_type type() const override { return PARSER_CHOICE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } + return ctx.memo.cached(id_, start, [&]() { + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); - auto pos = start; - for (const auto & p : parsers_) { - auto result = p->parse(ctx, pos); - - if (result.is_success()) { - return ctx.memo.set(id_, start, result); - } + if (result.is_success()) { + return result; + } - if (result.is_need_more_input()) { - return result; + if (result.is_need_more_input()) { + return result; + } } - } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return parser_result(PARSER_RESULT_FAIL, start); + }); } std::string dump() const override { @@ -219,49 +210,41 @@ class one_or_more_parser : public parser_base { parser_type type() const override { return PARSER_ONE_OR_MORE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - std::unordered_map groups; - - // We can't return back the cached result, since there may be more - // repetitions since the last parsing attempt. Instead, resume parsing from - // the last successful repetition found. - auto pos = start; - if (cached != std::nullopt) { - pos = cached->end; - groups.insert(cached->groups.begin(), cached->groups.end()); - } + return ctx.memo.cached(id_, start, [&]() { + std::unordered_map groups; - if (pos == start) { - auto first_result = parser_->parse(ctx, pos); + // Parse at least once + auto first_result = parser_->parse(ctx, start); if (!first_result.is_success()) { return first_result; } - pos = first_result.end; + auto pos = first_result.end; groups.insert(first_result.groups.begin(), first_result.groups.end()); - } - for (;;) { - auto result = parser_->parse(ctx, pos); - groups.insert(result.groups.begin(), result.groups.end()); + // Parse zero or more additional times + for (;;) { + auto result = parser_->parse(ctx, pos); + groups.insert(result.groups.begin(), result.groups.end()); - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); - } + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + } - if (result.is_fail()) { - // Done with repetitions - break; - } + if (result.is_fail()) { + // Done with repetitions + break; + } - if (result.end == pos) { - break; // Prevent an infinite loop - } + if (result.end == pos) { + break; // Prevent an infinite loop + } - pos = result.end; - } + pos = result.end; + } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + }); } std::string dump() const override { @@ -282,39 +265,32 @@ class zero_or_more_parser : public parser_base { parser_type type() const override { return PARSER_ZERO_OR_MORE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - std::unordered_map groups; - - // We can't return back the cached result, since there may be more - // repetitions since the last parsing attempt. Instead, resume parsing from - // the last successful repetition found. - auto pos = start; - if (cached != std::nullopt) { - pos = cached->end; - groups.insert(cached->groups.begin(), cached->groups.end()); - } + return ctx.memo.cached(id_, start, [&]() { + std::unordered_map groups; + auto pos = start; - for (;;) { - auto result = parser_->parse(ctx, pos); - groups.insert(result.groups.begin(), result.groups.end()); + for (;;) { + auto result = parser_->parse(ctx, pos); + groups.insert(result.groups.begin(), result.groups.end()); - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); - } + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + } - if (result.is_fail()) { - // Done with repetitions (zero or more is always valid) - break; - } + if (result.is_fail()) { + // Done with repetitions (zero or more is always valid) + break; + } - if (result.end == pos) { - break; // Prevent an infinite loop - } + if (result.end == pos) { + break; // Prevent an infinite loop + } - pos = result.end; - } + pos = result.end; + } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos, groups)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + }); } std::string dump() const override { @@ -335,25 +311,22 @@ class optional_parser : public parser_base { parser_type type() const override { return PARSER_OPTIONAL; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto result = parser_->parse(ctx, start); + return ctx.memo.cached(id_, start, [&]() { + auto result = parser_->parse(ctx, start); - if (result.is_success()) { - // Matched successfully - return ctx.memo.set(id_, start, result); - } + if (result.is_success()) { + // Matched successfully + return result; + } - if (result.is_need_more_input()) { - // Propagate - need more input to determine if optional matches - return result; - } + if (result.is_need_more_input()) { + // Propagate - need more input to determine if optional matches + return result; + } - // No match, but optional always succeeds with zero matches - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start)); + // No match, but optional always succeeds with zero matches + return parser_result(PARSER_RESULT_SUCCESS, start, start); + }); } std::string dump() const override { @@ -374,25 +347,22 @@ class not_parser : public parser_base { parser_type type() const override { return PARSER_NOT; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto result = parser_->parse(ctx, start); + return ctx.memo.cached(id_, start, [&]() { + auto result = parser_->parse(ctx, start); - if (result.is_success()) { - // Fail if the underlying parser matches - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); - } + if (result.is_success()) { + // Fail if the underlying parser matches + return parser_result(PARSER_RESULT_FAIL, start); + } - if (result.is_need_more_input()) { - // Propagate - need to know what child would match before negating - return result; - } + if (result.is_need_more_input()) { + // Propagate - need to know what child would match before negating + return result; + } - // Child failed, so negation succeeds - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start)); + // Child failed, so negation succeeds + return parser_result(PARSER_RESULT_SUCCESS, start); + }); } std::string dump() const override { @@ -411,19 +381,16 @@ class any_parser : public parser_base { parser_type type() const override { return PARSER_ANY; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - if (start >= ctx.input.size()) { - if (ctx.input_is_complete) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return ctx.memo.cached(id_, start, [&]() { + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + return parser_result(PARSER_RESULT_FAIL, start); } - return parser_result(PARSER_RESULT_FAIL, start); - } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); + return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); + }); } std::string dump() const override { @@ -440,22 +407,19 @@ class space_parser : public parser_base { parser_type type() const override { return PARSER_SPACE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto pos = start; - while (pos < ctx.input.size()) { - char c = ctx.input[pos]; - if (c == ' ' || c == '\t' || c == '\n') { - ++pos; - } else { - break; + return ctx.memo.cached(id_, start, [&]() { + auto pos = start; + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + if (c == ' ' || c == '\t' || c == '\n') { + ++pos; + } else { + break; + } } - } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, pos)); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + }); } std::string dump() const override { @@ -530,36 +494,33 @@ class char_class_parser : public parser_base { parser_type type() const override { return PARSER_CHAR_CLASS; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - if (start >= ctx.input.size()) { - if (ctx.input_is_complete) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return ctx.memo.cached(id_, start, [&]() { + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + return parser_result(PARSER_RESULT_FAIL, start); } - return parser_result(PARSER_RESULT_FAIL, start); - } - bool matches = false; - for (const auto & range : ranges_) { - if (range.contains(ctx.input[start])) { - matches = true; - break; + bool matches = false; + for (const auto & range : ranges_) { + if (range.contains(ctx.input[start])) { + matches = true; + break; + } } - } - // If negated, invert the match result - if (negated_) { - matches = !matches; - } + // If negated, invert the match result + if (negated_) { + matches = !matches; + } - if (matches) { - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_SUCCESS, start, start + 1)); - } + if (matches) { + return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); + } - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); + return parser_result(PARSER_RESULT_FAIL, start); + }); } std::string dump() const override { @@ -581,11 +542,13 @@ class group_parser : public parser_base { parser_type type() const override { return PARSER_GROUP; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto result = parser_->parse(ctx, start); + return ctx.memo.cached(id_, start, [&]() { + auto result = parser_->parse(ctx, start); - // Store result - result.groups[name_] = parser_match_location{result.start, result.end}; - return ctx.memo.set(id_, start, result); + // Store result + result.groups[name_] = parser_match_location{result.start, result.end}; + return result; + }); } std::string dump() const override { @@ -619,13 +582,9 @@ class until_parser : public parser_base { parser_type type() const override { return PARSER_UNTIL; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto result = parser_->parse(ctx, start); - return ctx.memo.set(id_, start, result); + return ctx.memo.cached(id_, start, [&]() { + return parser_->parse(ctx, start); + }); } std::string dump() const override { @@ -651,7 +610,9 @@ class schema_parser : public parser_base { parser_type type() const override { return PARSER_SCHEMA; } parser_result parse(parser_context & ctx, size_t start = 0) override { - return parser_->parse(ctx, start); + return ctx.memo.cached(id_, start, [&]() { + return parser_->parse(ctx, start); + }); } std::string dump() const override { @@ -678,25 +639,21 @@ class rule_parser : public parser_base { parser_type type() const override { return PARSER_RULE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - auto cached = ctx.memo.get(id_, start); - if (cached != std::nullopt) { - return *cached; - } - - auto rules = rules_.lock(); - if (!rules) { - LOG_ERR("rule_parser::parse called with expired rule registry\n"); - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); - } + return ctx.memo.cached(id_, start, [&]() { + auto rules = rules_.lock(); + if (!rules) { + LOG_ERR("rule_parser::parse called with expired rule registry\n"); + return parser_result(PARSER_RESULT_FAIL, start); + } - auto it = rules->find(name_); - if (it == rules->end()) { - LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", name_.c_str()); - return ctx.memo.set(id_, start, parser_result(PARSER_RESULT_FAIL, start)); - } + auto it = rules->find(name_); + if (it == rules->end()) { + LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", name_.c_str()); + return parser_result(PARSER_RESULT_FAIL, start); + } - auto result = it->second->parse(ctx, start); - return ctx.memo.set(id_, start, result); + return it->second->parse(ctx, start); + }); } std::string dump() const override { @@ -1105,6 +1062,14 @@ void parse_cache::clear() { results.clear(); } +parser_result parse_cache::cached(int id, size_t start, const std::function & fn) { + auto result = get(id, start); + if (result) { + return *result; + } + return set(id, start, fn()); +} + parser::parser() {} parser::parser(std::shared_ptr parser) : ptr_(std::move(parser)) {} diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 56a522394bf75..4dbbca3f6d1c5 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -70,6 +70,8 @@ class parse_cache { parser_result set(int id, size_t start, parser_result result); std::optional get(int id, size_t start); void clear(); + + parser_result cached(int id, size_t start, const std::function & fn); }; struct parser_context { From adac6bae7f8a53493a192f72cdb7302ef1cf7f62 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 21:32:29 -0600 Subject: [PATCH 12/34] add short description for each parser --- common/chat-parser-combinator.cpp | 30 +++++++++++++++++++++ common/chat-parser-combinator.h | 44 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index fb3b0fb083a0a..2b46a097d7a23 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -44,6 +44,8 @@ class parser_base { virtual void accept(parser_visitor & visitor) = 0; }; +// Matches an exact literal string. +// S -> "hello" class literal_parser : public parser_base { std::string literal_; @@ -84,6 +86,8 @@ class literal_parser : public parser_base { const std::string & literal() const { return literal_; } }; +// Matches a sequence of parsers in order, all must succeed. +// S -> A B C class sequence_parser : public parser_base { std::vector parsers_; @@ -148,6 +152,8 @@ class sequence_parser : public parser_base { const std::vector & parsers() const { return parsers_; } }; +// Matches the first parser that succeeds from a list of alternatives. +// S -> A | B | C class choice_parser : public parser_base { std::vector parsers_; @@ -201,6 +207,8 @@ class choice_parser : public parser_base { const std::vector & parsers() const { return parsers_; } }; +// Matches one or more repetitions of a parser. +// S -> A+ class one_or_more_parser : public parser_base { parser parser_; @@ -256,6 +264,8 @@ class one_or_more_parser : public parser_base { const parser & child() const { return parser_; } }; +// Matches zero or more repetitions of a parser, always succeeds. +// S -> A* class zero_or_more_parser : public parser_base { parser parser_; @@ -302,6 +312,8 @@ class zero_or_more_parser : public parser_base { const parser & child() const { return parser_; } }; +// Matches zero or one occurrence of a parser, always succeeds. +// S -> A? class optional_parser : public parser_base { parser parser_; @@ -338,6 +350,8 @@ class optional_parser : public parser_base { const parser & child() const { return parser_; } }; +// Negative lookahead: succeeds if child parser fails, consumes no input. +// S -> !A class not_parser : public parser_base { parser parser_; @@ -374,6 +388,8 @@ class not_parser : public parser_base { const parser & child() const { return parser_; } }; +// Matches any single character. +// S -> . class any_parser : public parser_base { public: any_parser(int id) : parser_base(id) {} @@ -400,6 +416,8 @@ class any_parser : public parser_base { void accept(parser_visitor & visitor) override; }; +// Matches zero or more whitespace characters (space, tab, newline). +// S -> [ \t\n]* class space_parser : public parser_base { public: space_parser(int id) : parser_base(id) {} @@ -429,6 +447,8 @@ class space_parser : public parser_base { void accept(parser_visitor & visitor) override; }; +// Matches a single character from a character class or range. +// S -> [a-z] or S -> [^0-9] class char_class_parser : public parser_base { struct char_range { int start; @@ -532,6 +552,8 @@ class char_class_parser : public parser_base { const std::string & pattern() const { return pattern_; } }; +// Captures the matched text from a parser and stores it with a name. +// S -> class group_parser : public parser_base { std::string name_; parser parser_; @@ -560,6 +582,8 @@ class group_parser : public parser_base { const parser & child() const { return parser_; } }; +// Matches all characters until a delimiter is found (delimiter not consumed). +// S -> (!delim .)* class until_parser : public parser_base { std::string delimiter_; parser parser_; @@ -598,6 +622,8 @@ class until_parser : public parser_base { const parser & child() const { return parser_; } }; +// Wraps a parser with JSON schema metadata for grammar generation. +// Used internally to convert JSON schemas to GBNF grammar rules. class schema_parser : public parser_base { parser parser_; std::string name_; @@ -628,6 +654,8 @@ class schema_parser : public parser_base { const nlohmann::ordered_json & schema() const { return schema_; } }; +// References a named rule for recursive or reusable grammar definitions. +// expr -> term | expr "+" term class rule_parser : public parser_base { std::string name_; std::weak_ptr> rules_; @@ -665,6 +693,8 @@ class rule_parser : public parser_base { const std::string & name() const { return name_; } }; +// Container for the root parser and all named rules in the grammar. +// Manages ownership of rule registry to enable recursive grammar definitions. class root_parser : public parser_base { parser root_; std::shared_ptr> rules_; diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 4dbbca3f6d1c5..25ce7f7c11cb0 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -128,20 +128,64 @@ class parser_builder { parser_builder(); parser_builder(std::shared_ptr counter); + // Matches an exact literal string. + // S -> "hello" parser literal(const std::string & literal); + + // Matches a sequence of parsers in order, all must succeed. + // S -> A B C parser sequence(std::initializer_list parsers); + + // Matches the first parser that succeeds from a list of alternatives. + // S -> A | B | C parser choice(std::initializer_list parsers); + + // Matches one or more repetitions of a parser. + // S -> A+ parser one_or_more(const parser & p); + + // Matches zero or more repetitions of a parser, always succeeds. + // S -> A* parser zero_or_more(const parser & p); + + // Matches zero or one occurrence of a parser, always succeeds. + // S -> A? parser optional(const parser & p); + + // Negative lookahead: succeeds if child parser fails, consumes no input. + // S -> !A parser negate(const parser & p); + + // Matches any single character. + // S -> . parser any(); + + // Matches a single character from a character class or range. + // S -> [a-z] or S -> [^0-9] parser char_class(const std::string & classes); + + // Captures the matched text from a parser and stores it with a name. + // S -> parser group(const std::string & name, const parser & p); + + // References a named rule for recursive or reusable grammar definitions. + // expr -> term | expr "+" term parser rule(const std::string & name); + + // Matches zero or more whitespace characters (space, tab, newline). + // S -> [ \t\n]* parser space(); + + // Matches all characters until a delimiter is found (delimiter not consumed). + // S -> (!delim .)* parser until(const std::string & delimiter, bool consume_spaces = true); + + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. + // value -> object | array | string | number | true | false | null parser json(); + + // Wraps a parser with JSON schema metadata for grammar generation. + // Used internally to convert JSON schemas to GBNF grammar rules. parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); parser add_rule(const std::string & name, const parser & p); From 0be2a93eb7ea86d3836d986007e8ba0e451dd834 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 21:49:24 -0600 Subject: [PATCH 13/34] create a type for the root parser --- common/chat-parser-combinator.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 2b46a097d7a23..d6aee652d2ea0 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -23,6 +23,7 @@ enum parser_type { PARSER_UNTIL = 11, PARSER_SPACE = 12, PARSER_SCHEMA = 13, + PARSER_ROOT = 14, }; class parser_visitor; @@ -705,7 +706,7 @@ class root_parser : public parser_base { root_parser(const parser & root, std::shared_ptr> rules, int id) : parser_base(id), root_(root), rules_(std::move(rules)) {} - parser_type type() const override { return root_->type(); } + parser_type type() const override { return PARSER_ROOT; } parser_result parse(parser_context & ctx, size_t start = 0) override { return root_->parse(ctx, start); From 31b386f6ef431220840e869da5b54c34804f058b Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 22:16:30 -0600 Subject: [PATCH 14/34] implement repetition parser --- common/chat-parser-combinator.cpp | 194 ++++++++++++++++++------------ common/chat-parser-combinator.h | 9 ++ 2 files changed, 129 insertions(+), 74 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index d6aee652d2ea0..4121bfcde1c2e 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -24,6 +24,7 @@ enum parser_type { PARSER_SPACE = 12, PARSER_SCHEMA = 13, PARSER_ROOT = 14, + PARSER_REPETITION = 15, }; class parser_visitor; @@ -208,48 +209,52 @@ class choice_parser : public parser_base { const std::vector & parsers() const { return parsers_; } }; -// Matches one or more repetitions of a parser. -// S -> A+ -class one_or_more_parser : public parser_base { +// Matches between min and max repetitions of a parser (inclusive). +// S -> A{m,n} +// Use -1 for max_count to represent unbounded repetition (equivalent to {m,}) +class repetition_parser : public parser_base { parser parser_; + int min_count_; + int max_count_; public: - one_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + repetition_parser(const parser & parser, int min_count, int max_count, int id) + : parser_base(id), parser_(parser), min_count_(min_count), max_count_(max_count) {} - parser_type type() const override { return PARSER_ONE_OR_MORE; } + parser_type type() const override { return PARSER_REPETITION; } parser_result parse(parser_context & ctx, size_t start = 0) override { return ctx.memo.cached(id_, start, [&]() { std::unordered_map groups; + auto pos = start; + int match_count = 0; - // Parse at least once - auto first_result = parser_->parse(ctx, start); - if (!first_result.is_success()) { - return first_result; - } - - auto pos = first_result.end; - groups.insert(first_result.groups.begin(), first_result.groups.end()); - - // Parse zero or more additional times - for (;;) { + // Try to match up to max_count times (or unlimited if max_count is -1) + while (max_count_ == -1 || match_count < max_count_) { auto result = parser_->parse(ctx, pos); groups.insert(result.groups.begin(), result.groups.end()); - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + if (result.is_success()) { + // Prevent infinite loop on empty matches + if (result.end == pos) { + break; + } + pos = result.end; + match_count++; + continue; } - if (result.is_fail()) { - // Done with repetitions - break; + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); } - if (result.end == pos) { - break; // Prevent an infinite loop - } + // Child failed - stop trying + break; + } - pos = result.end; + // Check if we got enough matches + if (match_count < min_count_) { + return parser_result(PARSER_RESULT_FAIL, start, pos, groups); } return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); @@ -257,98 +262,106 @@ class one_or_more_parser : public parser_base { } std::string dump() const override { - return "OneOrMore(" + parser_->dump() + ")"; + if (max_count_ == -1) { + return "Repetition(" + parser_->dump() + ", " + std::to_string(min_count_) + ", unbounded)"; + } + return "Repetition(" + parser_->dump() + ", " + std::to_string(min_count_) + ", " + std::to_string(max_count_) + ")"; } void accept(parser_visitor & visitor) override; const parser & child() const { return parser_; } + + int min_count() const { return min_count_; } + + int max_count() const { return max_count_; } }; -// Matches zero or more repetitions of a parser, always succeeds. -// S -> A* -class zero_or_more_parser : public parser_base { - parser parser_; +// Matches one or more repetitions of a parser. +// S -> A+ +class one_or_more_parser : public parser_base { + parser delegate_; public: - zero_or_more_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + one_or_more_parser(const parser & p, int id) : parser_base(id) { + delegate_ = parser(std::make_shared(p, 1, -1, id)); + } - parser_type type() const override { return PARSER_ZERO_OR_MORE; } + parser_type type() const override { return PARSER_ONE_OR_MORE; } parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - std::unordered_map groups; - auto pos = start; + return delegate_->parse(ctx, start); + } - for (;;) { - auto result = parser_->parse(ctx, pos); - groups.insert(result.groups.begin(), result.groups.end()); + std::string dump() const override { + auto rep = std::static_pointer_cast(delegate_.ptr()); + return "OneOrMore(" + rep->child()->dump() + ")"; + } - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); - } + void accept(parser_visitor & visitor) override; - if (result.is_fail()) { - // Done with repetitions (zero or more is always valid) - break; - } + const parser & child() const { + auto rep = std::static_pointer_cast(delegate_.ptr()); + return rep->child(); + } +}; - if (result.end == pos) { - break; // Prevent an infinite loop - } +// Matches zero or more repetitions of a parser, always succeeds. +// S -> A* +class zero_or_more_parser : public parser_base { + parser delegate_; - pos = result.end; - } + public: + zero_or_more_parser(const parser & p, int id) : parser_base(id) { + delegate_ = parser(std::make_shared(p, 0, -1, id)); + } - return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); - }); + parser_type type() const override { return PARSER_ZERO_OR_MORE; } + + parser_result parse(parser_context & ctx, size_t start = 0) override { + return delegate_->parse(ctx, start); } std::string dump() const override { - return "ZeroOrMore(" + parser_->dump() + ")"; + auto rep = std::static_pointer_cast(delegate_.ptr()); + return "ZeroOrMore(" + rep->child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - const parser & child() const { return parser_; } + const parser & child() const { + auto rep = std::static_pointer_cast(delegate_.ptr()); + return rep->child(); + } }; // Matches zero or one occurrence of a parser, always succeeds. // S -> A? class optional_parser : public parser_base { - parser parser_; + parser delegate_; public: - optional_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} + optional_parser(const parser & p, int id) : parser_base(id) { + delegate_ = parser(std::make_shared(p, 0, 1, id)); + } parser_type type() const override { return PARSER_OPTIONAL; } parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - auto result = parser_->parse(ctx, start); - - if (result.is_success()) { - // Matched successfully - return result; - } - - if (result.is_need_more_input()) { - // Propagate - need more input to determine if optional matches - return result; - } - - // No match, but optional always succeeds with zero matches - return parser_result(PARSER_RESULT_SUCCESS, start, start); - }); + return delegate_->parse(ctx, start); } std::string dump() const override { - return "Optional(" + parser_->dump() + ")"; + auto rep = std::static_pointer_cast(delegate_.ptr()); + return "Optional(" + rep->child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - const parser & child() const { return parser_; } + const parser & child() const { + auto rep = std::static_pointer_cast(delegate_.ptr()); + return rep->child(); + } }; // Negative lookahead: succeeds if child parser fails, consumes no input. @@ -734,6 +747,7 @@ class parser_visitor { virtual void visit(one_or_more_parser & p) = 0; virtual void visit(zero_or_more_parser & p) = 0; virtual void visit(optional_parser & p) = 0; + virtual void visit(repetition_parser & p) = 0; virtual void visit(until_parser & p) = 0; virtual void visit(not_parser & p) = 0; virtual void visit(any_parser & p) = 0; @@ -891,6 +905,24 @@ class gbnf_visitor : public parser_visitor { } } + void visit(repetition_parser & p) override { + p.child()->accept(*this); + std::string child_result = current_result_; + + if (needs_parens(p.child()->type())) { + child_result = "(" + child_result + ")"; + } + + if (p.max_count() == -1) { + // Unbounded: {n,} + current_result_ = child_result + "{" + std::to_string(p.min_count()) + ",}"; + } else { + // Bounded: {n,m} + current_result_ = child_result + "{" + std::to_string(p.min_count()) + "," + + std::to_string(p.max_count()) + "}"; + } + } + void visit(until_parser & p) override { // Generate pattern that matches prefixes but prevents full delimiter match current_result_ = generate_until_pattern(p.delimiter()) + "*"; @@ -1021,6 +1053,11 @@ class id_assignment_visitor : public parser_visitor { p.child()->accept(*this); } + void visit(repetition_parser & p) override { + assign_id(p); + p.child()->accept(*this); + } + void visit(until_parser & p) override { assign_id(p); p.child()->accept(*this); @@ -1049,6 +1086,7 @@ void choice_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void one_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void zero_or_more_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void optional_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void repetition_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void until_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void not_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void any_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } @@ -1207,6 +1245,14 @@ parser parser_builder::until(const std::string & delimiter, bool consume_spaces) return parser(std::make_shared(delimiter, consume_spaces, counter_->next())); } +parser parser_builder::repeat(const parser & p, int min, int max) { + return parser(std::make_shared(p, min, max, counter_->next())); +} + +parser parser_builder::repeat(const parser & p, int n) { + return repeat(p, n, n); +} + parser parser_builder::schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema) { return parser(std::make_shared(p, name, schema, counter_->next())); } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 25ce7f7c11cb0..b295b4b520498 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -180,6 +180,15 @@ class parser_builder { // S -> (!delim .)* parser until(const std::string & delimiter, bool consume_spaces = true); + // Matches between min and max repetitions of a parser (inclusive). + // S -> A{m,n} + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + parser repeat(const parser & p, int min, int max); + + // Matches exactly n repetitions of a parser. + // S -> A{n} + parser repeat(const parser & p, int n); + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. // value -> object | array | string | number | true | false | null parser json(); From ffb7a6f77db113c16ecf3cd2de4db9fc2f5344fb Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 22:27:54 -0600 Subject: [PATCH 15/34] Make optional, one_or_more, and zero_or_more subclasses of repetition --- common/chat-parser-combinator.cpp | 86 ++++++++----------------------- 1 file changed, 22 insertions(+), 64 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 4121bfcde1c2e..23bce3b3a50d6 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -12,19 +12,19 @@ enum parser_type { PARSER_LITERAL = 0, PARSER_SEQUENCE = 1, PARSER_CHOICE = 2, - PARSER_ZERO_OR_MORE = 3, - PARSER_ONE_OR_MORE = 4, - PARSER_NOT = 5, - PARSER_ANY = 6, - PARSER_CHAR_CLASS = 7, - PARSER_GROUP = 8, - PARSER_RULE = 9, - PARSER_OPTIONAL = 10, - PARSER_UNTIL = 11, - PARSER_SPACE = 12, - PARSER_SCHEMA = 13, - PARSER_ROOT = 14, - PARSER_REPETITION = 15, + PARSER_REPETITION = 3, + PARSER_OPTIONAL = 4, + PARSER_ZERO_OR_MORE = 5, + PARSER_ONE_OR_MORE = 6, + PARSER_NOT = 7, + PARSER_ANY = 8, + PARSER_CHAR_CLASS = 9, + PARSER_GROUP = 10, + PARSER_RULE = 11, + PARSER_UNTIL = 12, + PARSER_SPACE = 13, + PARSER_SCHEMA = 14, + PARSER_ROOT = 15, }; class parser_visitor; @@ -279,89 +279,47 @@ class repetition_parser : public parser_base { // Matches one or more repetitions of a parser. // S -> A+ -class one_or_more_parser : public parser_base { - parser delegate_; - +class one_or_more_parser : public repetition_parser { public: - one_or_more_parser(const parser & p, int id) : parser_base(id) { - delegate_ = parser(std::make_shared(p, 1, -1, id)); - } + one_or_more_parser(const parser & p, int id) : repetition_parser(p, 1, -1, id) {} parser_type type() const override { return PARSER_ONE_OR_MORE; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return delegate_->parse(ctx, start); - } - std::string dump() const override { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return "OneOrMore(" + rep->child()->dump() + ")"; + return "OneOrMore(" + child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - - const parser & child() const { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return rep->child(); - } }; // Matches zero or more repetitions of a parser, always succeeds. // S -> A* -class zero_or_more_parser : public parser_base { - parser delegate_; - +class zero_or_more_parser : public repetition_parser { public: - zero_or_more_parser(const parser & p, int id) : parser_base(id) { - delegate_ = parser(std::make_shared(p, 0, -1, id)); - } + zero_or_more_parser(const parser & p, int id) : repetition_parser(p, 0, -1, id) {} parser_type type() const override { return PARSER_ZERO_OR_MORE; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return delegate_->parse(ctx, start); - } - std::string dump() const override { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return "ZeroOrMore(" + rep->child()->dump() + ")"; + return "ZeroOrMore(" + child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - - const parser & child() const { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return rep->child(); - } }; // Matches zero or one occurrence of a parser, always succeeds. // S -> A? -class optional_parser : public parser_base { - parser delegate_; - +class optional_parser : public repetition_parser { public: - optional_parser(const parser & p, int id) : parser_base(id) { - delegate_ = parser(std::make_shared(p, 0, 1, id)); - } + optional_parser(const parser & p, int id) : repetition_parser(p, 0, 1, id) {} parser_type type() const override { return PARSER_OPTIONAL; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return delegate_->parse(ctx, start); - } - std::string dump() const override { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return "Optional(" + rep->child()->dump() + ")"; + return "Optional(" + child()->dump() + ")"; } void accept(parser_visitor & visitor) override; - - const parser & child() const { - auto rep = std::static_pointer_cast(delegate_.ptr()); - return rep->child(); - } }; // Negative lookahead: succeeds if child parser fails, consumes no input. From 085404a326000a195efb2ca550cec2dccf684273 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Mon, 10 Nov 2025 22:44:55 -0600 Subject: [PATCH 16/34] improve context constructor --- common/chat-parser-combinator.h | 26 ++++++- tests/test-chat-parser-combinator.cpp | 108 +++++++++++++------------- 2 files changed, 76 insertions(+), 58 deletions(-) diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index b295b4b520498..ab839971b725c 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -52,9 +52,15 @@ struct parser_result { std::unordered_map groups; parser_result() : type(PARSER_RESULT_FAIL) {} - parser_result(parser_result_type type, size_t start) : type(type), start(start), end(start) {} - parser_result(parser_result_type type, size_t start, size_t end) : type(type), start(start), end(end) {} - parser_result(parser_result_type type, size_t start, size_t end, const std::unordered_map & groups) : type(type), start(start), end(end), groups(groups) {} + + parser_result(parser_result_type type, size_t start) + : type(type), start(start), end(start) {} + + parser_result(parser_result_type type, size_t start, size_t end) + : type(type), start(start), end(end) {} + + parser_result(parser_result_type type, size_t start, size_t end, const std::unordered_map & groups) + : type(type), start(start), end(end), groups(groups) {} bool is_fail() const { return type == PARSER_RESULT_FAIL; } bool is_need_more_input() const { return type == PARSER_RESULT_NEED_MORE_INPUT; } @@ -77,7 +83,19 @@ class parse_cache { struct parser_context { std::string_view input; parse_cache memo; - bool input_is_complete = true; + bool input_is_complete; + + parser_context() + : memo(), input_is_complete(true) {} + + parser_context(std::string_view input) + : input(input), memo(), input_is_complete(true) {} + + parser_context(std::string_view input, bool complete) + : input(input), memo(), input_is_complete(complete) {} + + parser_context(std::string_view input, parse_cache memo, bool complete = true) + : input(input), memo(std::move(memo)), input_is_complete(complete) {} }; class parser_base; diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index e4f637af9f797..83ff2ba4a67fd 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -36,7 +36,7 @@ static void test_partial_parsing() { parser_context ctx; parser_result result; - ctx = parser_context{"hello", parse_cache()}; + ctx = parser_context("hello"); result = parser.parse(ctx); assert_equals(true, result.is_success()); } @@ -49,11 +49,11 @@ static void test_partial_parsing() { parser_context ctx; parser_result result; - ctx = parser_context{"a", parse_cache()}; + ctx = parser_context("a"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"A", parse_cache()}; + ctx = parser_context("A"); result = parser.parse(ctx); assert_equals(true, result.is_fail()); @@ -61,15 +61,15 @@ static void test_partial_parsing() { return p.char_class("a-z-"); }); - ctx = parser_context{"f", parse_cache()}; + ctx = parser_context("f"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"-", parse_cache()}; + ctx = parser_context("-"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"A", parse_cache()}; + ctx = parser_context("A"); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -80,25 +80,25 @@ static void test_partial_parsing() { }); // Partial matches - auto ctx = parser_context{"", parse_cache(), false}; + ctx = parser_context("", false); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"", parse_cache(), true}; + ctx = parser_context("", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); // No match, since it does not adhere to the grammar - ctx = parser_context{"I am parser", parse_cache(), false}; + ctx = parser_context("I am parser", false); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -109,25 +109,25 @@ static void test_partial_parsing() { }); // Partial matches - auto ctx = parser_context{"", parse_cache(), true}; + ctx = parser_context("", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"", parse_cache(), true}; + ctx = parser_context("", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); // No match - ctx = parser_context{"", parse_cache(), true}; + ctx = parser_context("", true); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -138,16 +138,16 @@ static void test_partial_parsing() { }); // Partial matches - auto ctx = parser_context{"a", parse_cache(), false}; + auto ctx = parser_context("a", false); auto result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); - ctx = parser_context{"aba", parse_cache(), false}; + ctx = parser_context("aba", false); result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); // Full match - ctx = parser_context{"ab", parse_cache(), true}; + ctx = parser_context("ab", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); } @@ -158,21 +158,21 @@ static void test_partial_parsing() { }); // Partial matches - auto ctx = parser_context{"a", parse_cache(), false}; + auto ctx = parser_context("a", false); auto result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); - ctx = parser_context{"aba", parse_cache(), false}; + ctx = parser_context("aba", false); result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); // Full match - ctx = parser_context{"ab", parse_cache(), true}; + ctx = parser_context("ab", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); // No match - ctx = parser_context{"cd", parse_cache(), true}; + ctx = parser_context("cd", true); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -189,7 +189,7 @@ static void test_capture_groups() { }); std::string input = "I have a thought"; - auto ctx = parser_context{input, parse_cache()}; + auto ctx = parser_context(input); auto result = parser.parse(ctx); assert_equals(true, result.is_success()); @@ -208,7 +208,7 @@ static void test_capture_groups() { }); std::string input = "I have a "; - auto ctx = parser_context{input, parse_cache(), false}; + auto ctx = parser_context(input, false); auto result = parser.parse(ctx); assert_equals(true, result.is_success()); @@ -228,7 +228,7 @@ static void test_capture_groups() { }); std::string input = "The user said hello.Hello!"; - auto ctx = parser_context{input, parse_cache(), true}; + auto ctx = parser_context(input, true); auto result = parser.parse(ctx); assert_equals(true, result.is_success()); @@ -253,19 +253,19 @@ static void test_char_class() { parser_context ctx; parser_result result; - ctx = parser_context{"\n", parse_cache()}; + ctx = parser_context("\n"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"\t", parse_cache()}; + ctx = parser_context("\t"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"\\", parse_cache()}; + ctx = parser_context("\\"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{" ", parse_cache()}; + ctx = parser_context(" "); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -278,20 +278,20 @@ static void test_char_class() { parser_context ctx; parser_result result; - ctx = parser_context{"a", parse_cache()}; + ctx = parser_context("a"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"-", parse_cache()}; + ctx = parser_context("-"); result = parser.parse(ctx); assert_equals(true, result.is_success()); - ctx = parser_context{"z", parse_cache()}; + ctx = parser_context("z"); result = parser.parse(ctx); assert_equals(true, result.is_success()); // Should NOT match 'b' since \- is a literal dash, not a range - ctx = parser_context{"b", parse_cache()}; + ctx = parser_context("b"); result = parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -312,32 +312,32 @@ static void test_recursive_references() { parser_result result; // Test simple number - ctx = parser_context{"1", parse_cache(), true}; + ctx = parser_context("1", true); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test simple list - ctx = parser_context{"[1]", parse_cache(), true}; + ctx = parser_context("[1]", true); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test nested list - ctx = parser_context{"[[2]]", parse_cache(), true}; + ctx = parser_context("[[2]]", true); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test deeply nested list - ctx = parser_context{"[[[3]]]", parse_cache(), true}; + ctx = parser_context("[[[3]]]", true); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test partial match - ctx = parser_context{"[[", parse_cache(), false}; + ctx = parser_context("[[", false); result = value_parser.parse(ctx); assert_equals(true, result.is_success()); // Test no match - ctx = parser_context{"[a]", parse_cache(), true}; + ctx = parser_context("[a]", true); result = value_parser.parse(ctx); assert_equals(true, result.is_fail()); } @@ -349,19 +349,19 @@ static void test_optional() { }); // Full match with optional part present - auto ctx = parser_context{"hello world", parse_cache()}; + auto ctx = parser_context("hello world"); auto result = parser.parse(ctx); assert_equals(true, result.is_success()); assert_equals((size_t)11, result.end); // Full match with optional part absent - ctx = parser_context{"hello", parse_cache(), true}; + ctx = parser_context("hello", true); result = parser.parse(ctx); assert_equals(true, result.is_success()); assert_equals((size_t)5, result.end); // Partial match - waiting for more input to determine if optional matches - ctx = parser_context{"hello ", parse_cache(), false}; + ctx = parser_context("hello ", false); result = parser.parse(ctx); assert_equals(true, result.is_need_more_input()); } @@ -374,7 +374,7 @@ static void test_json_parser() { { // Test parsing a simple JSON object std::string input = R"({"name": "test", "value": 42, "flag": true})"; - parser_context ctx{input, parse_cache()}; + parser_context ctx(input); auto result = json.parse(ctx); @@ -384,7 +384,7 @@ static void test_json_parser() { { // Test parsing a JSON array with mixed types std::string input = R"([1, "hello", true, null, 3.14])"; - parser_context ctx{input, parse_cache()}; + parser_context ctx(input); auto result = json.parse(ctx); @@ -394,7 +394,7 @@ static void test_json_parser() { { // Test parsing nested JSON with objects and arrays std::string input = R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})"; - parser_context ctx{input, parse_cache()}; + parser_context ctx(input); auto result = json.parse(ctx); @@ -404,7 +404,7 @@ static void test_json_parser() { { // Test partial parsing - incomplete object std::string input = R"({"name": "test", "value": )"; - parser_context ctx{input, parse_cache(), false}; + parser_context ctx(input, false); auto result = json.parse(ctx); @@ -413,7 +413,7 @@ static void test_json_parser() { { // Test partial parsing - incomplete array std::string input = R"([1, 2, 3, )"; - parser_context ctx{input, parse_cache(), false}; + parser_context ctx(input, false); auto result = json.parse(ctx); @@ -422,7 +422,7 @@ static void test_json_parser() { { // Test partial parsing - incomplete nested structure std::string input = R"({"data": {"nested": )"; - parser_context ctx{input, parse_cache(), false}; + parser_context ctx(input, false); auto result = json.parse(ctx); @@ -476,7 +476,7 @@ static void test_complete_example() { // Test complete input std::string input = R"(I need to call get_weather with city = New Yorkget_weather{"city": "New York"})"; - parser_context ctx{input, parse_cache()}; + parser_context ctx(input); auto result = parser.parse(ctx); @@ -488,21 +488,21 @@ static void test_complete_example() { // Test partial input input = R"(I need to call get_weather )"; - ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + ctx = parser_context(input, /* .is_input_complete = */ false); result = parser.parse(ctx); assert_equals(true, result.is_success()); assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); input = R"(I need to call get_weatherget_weather)"; - ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + ctx = parser_context(input, /* .is_input_complete = */ false); result = parser.parse(ctx); assert_equals(true, result.is_success()); assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); input = R"(I need to call get_weatherget_weatherI need to call get_weatherget_weather{"cit)"; - ctx = parser_context{input, parse_cache(), /* .is_input_complete = */ false}; + ctx = parser_context(input, /* .is_input_complete = */ false); result = parser.parse(ctx); assert_equals(true, result.is_success()); From 6bd9a9502679080641e31db2a1486d3ebd081d02 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Tue, 11 Nov 2025 22:22:53 -0600 Subject: [PATCH 17/34] improve until parsing and add benchmarks --- common/chat-parser-combinator.cpp | 71 ++++++-- common/chat-parser-combinator.h | 5 + tests/test-chat-parser-combinator.cpp | 237 +++++++++++++++++++++++++- 3 files changed, 293 insertions(+), 20 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 23bce3b3a50d6..0f5adfb662e27 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -46,6 +46,12 @@ class parser_base { virtual void accept(parser_visitor & visitor) = 0; }; +// We define our own space function because MSVC's std::isspace() +// crashes for non-printable characters in Debug builds. +static bool is_space(const char c) { + return (c == ' ' || c == '\t' || c == '\n'); +} + // Matches an exact literal string. // S -> "hello" class literal_parser : public parser_base { @@ -401,7 +407,7 @@ class space_parser : public parser_base { auto pos = start; while (pos < ctx.input.size()) { char c = ctx.input[pos]; - if (c == ' ' || c == '\t' || c == '\n') { + if (is_space(c)) { ++pos; } else { break; @@ -558,28 +564,46 @@ class group_parser : public parser_base { // S -> (!delim .)* class until_parser : public parser_base { std::string delimiter_; - parser parser_; + bool consume_spaces_; + + std::boyer_moore_searcher searcher_; public: until_parser(const std::string & delimiter, bool consume_spaces, int id) - : parser_base(id), delimiter_(delimiter) { - - auto delim = parser(std::make_shared(delimiter, -1)); - auto any = parser(std::make_shared(-1)); - - if (consume_spaces) { - auto ws = parser(std::make_shared(-1)); - parser_ = parser(std::make_shared(~(ws + delim) + any, -1)); - } else { - parser_ = parser(std::make_shared(~delim + any, -1)); - } + : parser_base(id), delimiter_(delimiter), consume_spaces_(consume_spaces), searcher_(delimiter_.begin(), delimiter_.end()) { } parser_type type() const override { return PARSER_UNTIL; } parser_result parse(parser_context & ctx, size_t start = 0) override { return ctx.memo.cached(id_, start, [&]() { - return parser_->parse(ctx, start); + parser_result result(PARSER_RESULT_SUCCESS, start, ctx.input.size()); + + // Search for the delimiter + const auto * it = std::search(ctx.input.begin(), ctx.input.end(), searcher_); + + if (it != ctx.input.end()) { + result.type = PARSER_RESULT_SUCCESS; + result.end = std::distance(ctx.input.begin(), it); + } else { + // If not found, check if the input ends with a prefix of the delimiter + size_t max_overlap = std::min(ctx.input.size(), delimiter_.size() - 1); + for (size_t overlap = max_overlap; overlap > 0; --overlap) { + if (std::equal(ctx.input.end() - overlap, ctx.input.end(), delimiter_.begin())) { + result.type = PARSER_RESULT_NEED_MORE_INPUT; + result.end = ctx.input.size() - overlap; + } + } + } + + if (consume_spaces_) { + // Remove trailing spaces + while (result.end > start && is_space(ctx.input[result.end - 1])) { + result.end--; + } + } + + return result; }); } @@ -590,8 +614,6 @@ class until_parser : public parser_base { void accept(parser_visitor & visitor) override; const std::string & delimiter() const { return delimiter_; } - - const parser & child() const { return parser_; } }; // Wraps a parser with JSON schema metadata for grammar generation. @@ -1018,7 +1040,6 @@ class id_assignment_visitor : public parser_visitor { void visit(until_parser & p) override { assign_id(p); - p.child()->accept(*this); } void visit(not_parser & p) override { @@ -1215,6 +1236,22 @@ parser parser_builder::schema(const parser & p, const std::string & name, const return parser(std::make_shared(p, name, schema, counter_->next())); } +parser parser_builder::json_key(const std::string & name, const parser & p) { + return literal("\"" + name + "\"") << literal(":") << p; +} + +parser parser_builder::json_string(const parser & p) { + auto quote = literal("\""); + return quote + p + quote; +} + +parser parser_builder::between(const std::string & left, const parser & p, const std::string & right, bool allow_spaces) { + if (allow_spaces) { + return literal(left) << p << literal(right); + } + return literal(left) + p + literal(right); +} + parser parser_builder::add_rule(const std::string & name, const parser & p) { (*rules_)[name] = p; return rule(name); diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index ab839971b725c..034aa773d6a37 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -211,6 +211,11 @@ class parser_builder { // value -> object | array | string | number | true | false | null parser json(); + parser json_key(const std::string & name, const parser & p); + parser json_string(const parser & p); + + parser between(const std::string & left, const parser & p, const std::string & right, bool allow_spaces = true); + // Wraps a parser with JSON schema metadata for grammar generation. // Used internally to convert JSON schemas to GBNF grammar rules. parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 83ff2ba4a67fd..a0e007c79c879 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -1,10 +1,13 @@ #include #include +#include +#include "nlohmann/json.hpp" + +#include "chat-parser.h" #include "chat-parser-combinator.h" +#include "common.h" #include "json-schema-to-grammar.h" -#include "nlohmann/json.hpp" -#include "nlohmann/json_fwd.hpp" template static void assert_equals(const std::string_view label, const T & expected, const T & actual) { @@ -455,7 +458,7 @@ static void test_complete_example() { auto tool_call_name = p.add_rule("tool-call-name", p.literal("") - << p.group("tool-name", p.one_or_more(p.char_class("[a-zA-Z\\-_]"))) + << p.group("tool-name", p.until("")) << p.literal("")); auto schema = nlohmann::ordered_json::parse(R"({"type": "object"})"); @@ -494,6 +497,20 @@ static void test_complete_example() { assert_equals(true, result.is_success()); assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); + input = R"(I need to call I need to call get_weatherI need to call get_weatherget_weather)"; ctx = parser_context(input, /* .is_input_complete = */ false); result = parser.parse(ctx); @@ -691,6 +708,184 @@ static void test_gbnf_generation() { } } +static parser create_command_r7b_parser() { + return build_parser([](parser_builder & p) { + auto thinking = p.literal("<|START_THINKING|>") + << p.until("<|END_THINKING|>") + << p.literal("<|END_THINKING|>"); + + auto response = p.literal("<|START_RESPONSE|>") + << p.until("<|END_RESPONSE|>") + << p.literal("<|END_RESPONSE|>"); + + auto json = p.json(); + auto tool_call_id = p.json_key("tool_call_id", p.json_string(p.until("\""))); + auto tool_call_name = p.json_key("tool_name", p.json_string(p.until("\""))); + auto tool_call_args = p.json_key("parameters", json); + auto tool_call_fields = tool_call_id | tool_call_name | tool_call_args; + + auto tool_call = p.between("{", + tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields), + "}"); + + auto tool_calls = p.literal("<|START_ACTION|>") + << p.literal("[") + << tool_call + << p.zero_or_more(p.literal(",") << tool_call) + << p.literal("]") + << p.literal("<|END_ACTION|>"); + + return p.optional(thinking) << (tool_calls | response); + }); +} + +static void test_command_r7b_parser(const parser & p, const std::string & input, bool partial) { + parser_context ctx(input, !partial); + p.parse(ctx); +} + +static void test_command_r7b_legacy_parser(const std::string & input, bool partial) { + // Original parser taken from chat.cpp + common_chat_msg_parser builder(input, + /* is_partial= */ partial, { + /* .format = */ COMMON_CHAT_FORMAT_GENERIC, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + }); + + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +struct bench_tool_call { + std::string id; + std::string name; + nlohmann::ordered_json args; +}; + +// Simple tokenize function that splits by space +static std::vector simple_tokenize(const std::string & input) { + std::vector result; + std::string current; + + for (size_t i = 0; i < input.size(); i++) { + if (input[i] == ' ') { + if (!current.empty()) { + result.push_back(current); + current.clear(); + } + current += ' '; + } else { + current += input[i]; + } + } + + if (!current.empty()) { + result.push_back(current); + } + + return result; +} + +static void benchmark_compare( + const std::string & reasoning, + const std::string & content, + const std::vector & tool_calls, + int iterations) { + + // Build response + std::vector tokens; // Since we don't have a command r7b tokenizer, we're going to "simulate" them. + + if (!reasoning.empty()) { + auto tokenized = simple_tokenize(reasoning); + tokens.emplace_back("<|START_THINKING|>"); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + tokens.emplace_back("<|END_THINKING|>"); + } + + if (!content.empty()) { + auto tokenized = simple_tokenize(content); + tokens.emplace_back("<|START_RESPONSE|>"); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + tokens.emplace_back("<|END_RESPONSE|>"); + } + + if (!tool_calls.empty()) { + tokens.emplace_back("<|START_ACTION|>"); + + auto json = nlohmann::json::array(); + for (const auto & tc : tool_calls) { + auto tc_json = nlohmann::json::object(); + tc_json["tool_call_id"] = tc.id; + tc_json["tool_name"] = tc.name; + tc_json["parameters"] = tc.args; + json.push_back(tc_json); + } + + auto tokenized = simple_tokenize(json.dump(-1, ' ', true)); + tokens.insert(tokens.end(), tokenized.begin(), tokenized.end()); + + tokens.emplace_back("<|END_ACTION|>"); + } + + auto run = [&](const std::function & fn) { + std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string()); + + std::chrono::microseconds duration(0); + for (int i = 0; i < iterations; i++) { + auto start = std::chrono::high_resolution_clock::now(); + fn(input, false); + auto end = std::chrono::high_resolution_clock::now(); + duration += std::chrono::duration_cast(end - start); + } + return duration.count() / iterations; + }; + + auto parser = create_command_r7b_parser(); + + auto duration_new = run([&](const std::string & input, bool partial) { + test_command_r7b_parser(parser, input, partial); + }); + + auto duration_legacy = run([&](const std::string & input, bool partial) { + try { + test_command_r7b_legacy_parser(input, partial); + } catch (const common_chat_msg_partial_exception &) { } + }); + + std::cout << " New parser avg: " << duration_new << " us\n"; + std::cout << "Legacy parser avg: " << duration_legacy << " us\n"; +} + int main() { test_partial_parsing(); test_char_class(); @@ -701,5 +896,41 @@ int main() { test_complete_example(); test_gbnf_generation(); std::cout << "All tests passed!\n"; + + std::cout << "\n== Benchmarks ==\n"; + std::string example_reasoning = + "To plan an effective trip to Japan that includes both historical sites and modern attractions within a budget of $4000 for a two-week stay, we need to:\n\n" + "1. Identify key historical sites and modern attractions in Japan.\n" + "2. Find affordable accommodation options that provide a balance between comfort and cost.\n" + "3. Determine the best modes of transportation for getting around Japan.\n" + "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without overspending.\n" + "5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees to attractions."; + + std::string example_content = + "For a two-week trip to Japan with a $4,000 budget, I recommend planning an itinerary that balances historical sites with modern attractions. The destination will be Japan, with a duration of 14 days.\n\n" + "Given your interests in both historical sites and modern attractions, you'll want to focus on cities like Kyoto for its temples and traditional culture, Tokyo for its cutting-edge technology and entertainment districts, and possibly Hiroshima or Nara for additional historical significance.\n\n" + "For accommodation, I suggest looking for affordable options such as budget hotels, hostels, or guesthouses that offer good value without sacrificing too much comfort. Japan has excellent mid-range accommodation options that can keep your lodging costs manageable.\n\n" + "Transportation should prioritize efficiency—consider getting a JR Rail Pass for intercity travel, which allows unlimited rides on most JR trains including the Shinkansen (bullet train). Within cities, use local trains and subways, which are both affordable and highly reliable.\n\n" + "For meals, embrace local cuisine by eating at neighborhood restaurants, ramen shops, and izakayas rather than touristy establishments. This will give you an authentic experience while keeping costs reasonable—you can enjoy excellent meals for $10-20 per person at local spots.\n\n"; + + std::vector example_tool_calls = {{ + "call_0", + "plan_trip", + nlohmann::json::parse(R"({ + "destination": "Japan", + "duration": 14, + "budget": 4000, + "interests": ["historical sites", "modern attractions"], + "accommodation_preferences": "affordable", + "transportation_preferences": "efficient", + "meal_preferences": "local cuisine" + })") + }}; + + std::cout << "\nReasoning + Content:\n"; + benchmark_compare(example_reasoning, example_content, std::vector(), 100); + + std::cout << "\nReasoning + Tool Call:\n"; + benchmark_compare(example_reasoning, "", example_tool_calls, 100); return 0; } From 62656db20a67340f81de99eaf5120e97229fd66e Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Tue, 11 Nov 2025 22:36:07 -0600 Subject: [PATCH 18/34] remove cached() pattern, cache in parser_base with specialized parsing functions for each parser --- common/chat-parser-combinator.cpp | 375 +++++++++++++++--------------- common/chat-parser-combinator.h | 2 - 2 files changed, 182 insertions(+), 195 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 0f5adfb662e27..871a12cde4be9 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -41,7 +41,28 @@ class parser_base { void set_id(int id) { id_ = id; } virtual parser_type type() const = 0; - virtual parser_result parse(parser_context & ctx, size_t start = 0) = 0; + + // Template Method: handles caching, delegates to parse_uncached() + virtual parser_result parse(parser_context & ctx, size_t start = 0) { + if (id_ == -1) { + // Don't cache parsers with ID -1 (from operators) + return parse_uncached(ctx, start); + } + + // Check cache + auto cached = ctx.memo.get(id_, start); + if (cached) { + return *cached; + } + + // Execute and cache + auto result = parse_uncached(ctx, start); + return ctx.memo.set(id_, start, result); + } + + // Actual parsing implementation (to be overridden by subclasses) + virtual parser_result parse_uncached(parser_context & ctx, size_t start = 0) = 0; + virtual std::string dump() const = 0; virtual void accept(parser_visitor & visitor) = 0; }; @@ -62,27 +83,25 @@ class literal_parser : public parser_base { parser_type type() const override { return PARSER_LITERAL; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - auto pos = start; - for (auto i = 0u; i < literal_.size(); ++i) { - if (pos >= ctx.input.size()) { - if (ctx.input_is_complete) { - return parser_result(PARSER_RESULT_FAIL, start); - } - if (i > 0) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); - } + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + for (auto i = 0u; i < literal_.size(); ++i) { + if (pos >= ctx.input.size()) { + if (ctx.input_is_complete) { return parser_result(PARSER_RESULT_FAIL, start); } - if (ctx.input[pos] != literal_[i]) { - return parser_result(PARSER_RESULT_FAIL, start); + if (i > 0) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); } - ++pos; + return parser_result(PARSER_RESULT_FAIL, start); + } + if (ctx.input[pos] != literal_[i]) { + return parser_result(PARSER_RESULT_FAIL, start); } + ++pos; + } - return parser_result(PARSER_RESULT_SUCCESS, start, pos); - }); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); } std::string dump() const override { @@ -116,34 +135,32 @@ class sequence_parser : public parser_base { parser_type type() const override { return PARSER_SEQUENCE; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - std::unordered_map groups; + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + std::unordered_map groups; - auto pos = start; - for (const auto & p : parsers_) { - auto result = p->parse(ctx, pos); + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); - // Copy groups - groups.insert(result.groups.begin(), result.groups.end()); + // Copy groups + groups.insert(result.groups.begin(), result.groups.end()); - if (result.is_fail()) { - if (result.end >= ctx.input.size() && !ctx.input_is_complete) { - // If we fail because we don't have enough input, then return success - return parser_result(PARSER_RESULT_SUCCESS, start, result.end, groups); - } - return parser_result(PARSER_RESULT_FAIL, start, result.end, groups); - } - - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end, groups); + if (result.is_fail()) { + if (result.end >= ctx.input.size() && !ctx.input_is_complete) { + // If we fail because we don't have enough input, then return success + return parser_result(PARSER_RESULT_SUCCESS, start, result.end, groups); } + return parser_result(PARSER_RESULT_FAIL, start, result.end, groups); + } - pos = result.end; + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end, groups); } - return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); - }); + pos = result.end; + } + + return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); } std::string dump() const override { @@ -182,23 +199,21 @@ class choice_parser : public parser_base { parser_type type() const override { return PARSER_CHOICE; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - auto pos = start; - for (const auto & p : parsers_) { - auto result = p->parse(ctx, pos); + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + for (const auto & p : parsers_) { + auto result = p->parse(ctx, pos); - if (result.is_success()) { - return result; - } + if (result.is_success()) { + return result; + } - if (result.is_need_more_input()) { - return result; - } + if (result.is_need_more_input()) { + return result; } + } - return parser_result(PARSER_RESULT_FAIL, start); - }); + return parser_result(PARSER_RESULT_FAIL, start); } std::string dump() const override { @@ -229,42 +244,40 @@ class repetition_parser : public parser_base { parser_type type() const override { return PARSER_REPETITION; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - std::unordered_map groups; - auto pos = start; - int match_count = 0; - - // Try to match up to max_count times (or unlimited if max_count is -1) - while (max_count_ == -1 || match_count < max_count_) { - auto result = parser_->parse(ctx, pos); - groups.insert(result.groups.begin(), result.groups.end()); - - if (result.is_success()) { - // Prevent infinite loop on empty matches - if (result.end == pos) { - break; - } - pos = result.end; - match_count++; - continue; - } + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + std::unordered_map groups; + auto pos = start; + int match_count = 0; - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); - } + // Try to match up to max_count times (or unlimited if max_count is -1) + while (max_count_ == -1 || match_count < max_count_) { + auto result = parser_->parse(ctx, pos); + groups.insert(result.groups.begin(), result.groups.end()); - // Child failed - stop trying - break; + if (result.is_success()) { + // Prevent infinite loop on empty matches + if (result.end == pos) { + break; + } + pos = result.end; + match_count++; + continue; } - // Check if we got enough matches - if (match_count < min_count_) { - return parser_result(PARSER_RESULT_FAIL, start, pos, groups); + if (result.is_need_more_input()) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); } - return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); - }); + // Child failed - stop trying + break; + } + + // Check if we got enough matches + if (match_count < min_count_) { + return parser_result(PARSER_RESULT_FAIL, start, pos, groups); + } + + return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); } std::string dump() const override { @@ -338,23 +351,21 @@ class not_parser : public parser_base { parser_type type() const override { return PARSER_NOT; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - auto result = parser_->parse(ctx, start); + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto result = parser_->parse(ctx, start); - if (result.is_success()) { - // Fail if the underlying parser matches - return parser_result(PARSER_RESULT_FAIL, start); - } + if (result.is_success()) { + // Fail if the underlying parser matches + return parser_result(PARSER_RESULT_FAIL, start); + } - if (result.is_need_more_input()) { - // Propagate - need to know what child would match before negating - return result; - } + if (result.is_need_more_input()) { + // Propagate - need to know what child would match before negating + return result; + } - // Child failed, so negation succeeds - return parser_result(PARSER_RESULT_SUCCESS, start); - }); + // Child failed, so negation succeeds + return parser_result(PARSER_RESULT_SUCCESS, start); } std::string dump() const override { @@ -374,17 +385,15 @@ class any_parser : public parser_base { parser_type type() const override { return PARSER_ANY; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - if (start >= ctx.input.size()) { - if (ctx.input_is_complete) { - return parser_result(PARSER_RESULT_FAIL, start); - } + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { return parser_result(PARSER_RESULT_FAIL, start); } + return parser_result(PARSER_RESULT_FAIL, start); + } - return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); - }); + return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); } std::string dump() const override { @@ -402,20 +411,18 @@ class space_parser : public parser_base { parser_type type() const override { return PARSER_SPACE; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - auto pos = start; - while (pos < ctx.input.size()) { - char c = ctx.input[pos]; - if (is_space(c)) { - ++pos; - } else { - break; - } + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto pos = start; + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + if (is_space(c)) { + ++pos; + } else { + break; } + } - return parser_result(PARSER_RESULT_SUCCESS, start, pos); - }); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); } std::string dump() const override { @@ -491,34 +498,32 @@ class char_class_parser : public parser_base { parser_type type() const override { return PARSER_CHAR_CLASS; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - if (start >= ctx.input.size()) { - if (ctx.input_is_complete) { - return parser_result(PARSER_RESULT_FAIL, start); - } + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + if (start >= ctx.input.size()) { + if (ctx.input_is_complete) { return parser_result(PARSER_RESULT_FAIL, start); } + return parser_result(PARSER_RESULT_FAIL, start); + } - bool matches = false; - for (const auto & range : ranges_) { - if (range.contains(ctx.input[start])) { - matches = true; - break; - } + bool matches = false; + for (const auto & range : ranges_) { + if (range.contains(ctx.input[start])) { + matches = true; + break; } + } - // If negated, invert the match result - if (negated_) { - matches = !matches; - } + // If negated, invert the match result + if (negated_) { + matches = !matches; + } - if (matches) { - return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); - } + if (matches) { + return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); + } - return parser_result(PARSER_RESULT_FAIL, start); - }); + return parser_result(PARSER_RESULT_FAIL, start); } std::string dump() const override { @@ -541,14 +546,12 @@ class group_parser : public parser_base { parser_type type() const override { return PARSER_GROUP; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - auto result = parser_->parse(ctx, start); + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto result = parser_->parse(ctx, start); - // Store result - result.groups[name_] = parser_match_location{result.start, result.end}; - return result; - }); + // Store result + result.groups[name_] = parser_match_location{result.start, result.end}; + return result; } std::string dump() const override { @@ -575,36 +578,34 @@ class until_parser : public parser_base { parser_type type() const override { return PARSER_UNTIL; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - parser_result result(PARSER_RESULT_SUCCESS, start, ctx.input.size()); + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + parser_result result(PARSER_RESULT_SUCCESS, start, ctx.input.size()); - // Search for the delimiter - const auto * it = std::search(ctx.input.begin(), ctx.input.end(), searcher_); + // Search for the delimiter + const auto * it = std::search(ctx.input.begin(), ctx.input.end(), searcher_); - if (it != ctx.input.end()) { - result.type = PARSER_RESULT_SUCCESS; - result.end = std::distance(ctx.input.begin(), it); - } else { - // If not found, check if the input ends with a prefix of the delimiter - size_t max_overlap = std::min(ctx.input.size(), delimiter_.size() - 1); - for (size_t overlap = max_overlap; overlap > 0; --overlap) { - if (std::equal(ctx.input.end() - overlap, ctx.input.end(), delimiter_.begin())) { - result.type = PARSER_RESULT_NEED_MORE_INPUT; - result.end = ctx.input.size() - overlap; - } + if (it != ctx.input.end()) { + result.type = PARSER_RESULT_SUCCESS; + result.end = std::distance(ctx.input.begin(), it); + } else { + // If not found, check if the input ends with a prefix of the delimiter + size_t max_overlap = std::min(ctx.input.size(), delimiter_.size() - 1); + for (size_t overlap = max_overlap; overlap > 0; --overlap) { + if (std::equal(ctx.input.end() - overlap, ctx.input.end(), delimiter_.begin())) { + result.type = (ctx.input_is_complete) ? PARSER_RESULT_FAIL : PARSER_RESULT_NEED_MORE_INPUT; + result.end = ctx.input.size() - overlap; } } + } - if (consume_spaces_) { - // Remove trailing spaces - while (result.end > start && is_space(ctx.input[result.end - 1])) { - result.end--; - } + if (consume_spaces_) { + // Remove trailing spaces + while (result.end > start && is_space(ctx.input[result.end - 1])) { + result.end--; } + } - return result; - }); + return result; } std::string dump() const override { @@ -629,10 +630,8 @@ class schema_parser : public parser_base { parser_type type() const override { return PARSER_SCHEMA; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - return parser_->parse(ctx, start); - }); + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + return parser_->parse(ctx, start); } std::string dump() const override { @@ -660,22 +659,20 @@ class rule_parser : public parser_base { parser_type type() const override { return PARSER_RULE; } - parser_result parse(parser_context & ctx, size_t start = 0) override { - return ctx.memo.cached(id_, start, [&]() { - auto rules = rules_.lock(); - if (!rules) { - LOG_ERR("rule_parser::parse called with expired rule registry\n"); - return parser_result(PARSER_RESULT_FAIL, start); - } + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto rules = rules_.lock(); + if (!rules) { + LOG_ERR("rule_parser::parse called with expired rule registry\n"); + return parser_result(PARSER_RESULT_FAIL, start); + } - auto it = rules->find(name_); - if (it == rules->end()) { - LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", name_.c_str()); - return parser_result(PARSER_RESULT_FAIL, start); - } + auto it = rules->find(name_); + if (it == rules->end()) { + LOG_ERR("rule_parser::parse rule '%s' not found in registry\n", name_.c_str()); + return parser_result(PARSER_RESULT_FAIL, start); + } - return it->second->parse(ctx, start); - }); + return it->second->parse(ctx, start); } std::string dump() const override { @@ -701,7 +698,7 @@ class root_parser : public parser_base { parser_type type() const override { return PARSER_ROOT; } - parser_result parse(parser_context & ctx, size_t start = 0) override { + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { return root_->parse(ctx, start); } @@ -1110,14 +1107,6 @@ void parse_cache::clear() { results.clear(); } -parser_result parse_cache::cached(int id, size_t start, const std::function & fn) { - auto result = get(id, start); - if (result) { - return *result; - } - return set(id, start, fn()); -} - parser::parser() {} parser::parser(std::shared_ptr parser) : ptr_(std::move(parser)) {} diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 034aa773d6a37..4a1242a718362 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -76,8 +76,6 @@ class parse_cache { parser_result set(int id, size_t start, parser_result result); std::optional get(int id, size_t start); void clear(); - - parser_result cached(int id, size_t start, const std::function & fn); }; struct parser_context { From 18557f3a5ebe37bbe1502480c84ba38e40501854 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Tue, 11 Nov 2025 23:54:43 -0600 Subject: [PATCH 19/34] improve json parsing performance to better match legacy parsing --- common/chat-parser-combinator.cpp | 257 +++++++++++++++++++++----- common/chat-parser-combinator.h | 14 +- tests/test-chat-parser-combinator.cpp | 30 +-- 3 files changed, 241 insertions(+), 60 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 871a12cde4be9..de825a8e6a2ee 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -18,13 +18,14 @@ enum parser_type { PARSER_ONE_OR_MORE = 6, PARSER_NOT = 7, PARSER_ANY = 8, - PARSER_CHAR_CLASS = 9, + PARSER_CHARS = 9, PARSER_GROUP = 10, PARSER_RULE = 11, PARSER_UNTIL = 12, PARSER_SPACE = 13, PARSER_SCHEMA = 14, PARSER_ROOT = 15, + PARSER_JSON_STRING = 16, }; class parser_visitor; @@ -93,7 +94,7 @@ class literal_parser : public parser_base { if (i > 0) { return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); } - return parser_result(PARSER_RESULT_FAIL, start); + return parser_result(PARSER_RESULT_FAIL, start, pos); } if (ctx.input[pos] != literal_[i]) { return parser_result(PARSER_RESULT_FAIL, start); @@ -432,9 +433,9 @@ class space_parser : public parser_base { void accept(parser_visitor & visitor) override; }; -// Matches a single character from a character class or range. -// S -> [a-z] or S -> [^0-9] -class char_class_parser : public parser_base { +// Matches between min and max repetitions of characters from a character class. +// S -> [a-z]{m,n} +class chars_parser : public parser_base { struct char_range { int start; int end; @@ -445,9 +446,13 @@ class char_class_parser : public parser_base { std::string pattern_; std::vector ranges_; bool negated_; + int min_count_; + int max_count_; public: - char_class_parser(const std::string & classes, int id) : parser_base(id), pattern_(classes), negated_(false) { + chars_parser(const std::string & classes, int min_count, int max_count, int id) + : parser_base(id), pattern_(classes), negated_(false), min_count_(min_count), max_count_(max_count) { + std::string content = classes; if (content.front() == '[') { content = content.substr(1); @@ -496,43 +501,164 @@ class char_class_parser : public parser_base { } } - parser_type type() const override { return PARSER_CHAR_CLASS; } + parser_type type() const override { return PARSER_CHARS; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { - if (start >= ctx.input.size()) { - if (ctx.input_is_complete) { - return parser_result(PARSER_RESULT_FAIL, start); + auto pos = start; + int match_count = 0; + + // Try to match up to max_count times (or unlimited if max_count is -1) + while (max_count_ == -1 || match_count < max_count_) { + if (pos >= ctx.input.size()) { + break; } - return parser_result(PARSER_RESULT_FAIL, start); - } - bool matches = false; - for (const auto & range : ranges_) { - if (range.contains(ctx.input[start])) { - matches = true; + bool matches = false; + for (const auto & range : ranges_) { + if (range.contains(ctx.input[pos])) { + matches = true; + break; + } + } + + // If negated, invert the match result + if (negated_) { + matches = !matches; + } + + if (matches) { + ++pos; + ++match_count; + } else { break; } } - // If negated, invert the match result - if (negated_) { - matches = !matches; + // Check if we got enough matches + if (match_count < min_count_) { + return parser_result(PARSER_RESULT_FAIL, start); } - if (matches) { - return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); + } + + std::string dump() const override { + if (max_count_ == -1) { + return "CharRepeat(" + pattern_ + ", " + std::to_string(min_count_) + ", unbounded)"; } + return "CharRepeat(" + pattern_ + ", " + std::to_string(min_count_) + ", " + std::to_string(max_count_) + ")"; + } - return parser_result(PARSER_RESULT_FAIL, start); + void accept(parser_visitor & visitor) override; + + const std::string & pattern() const { return pattern_; } + + int min_count() const { return min_count_; } + + int max_count() const { return max_count_; } +}; + +// Specialized parser for JSON string content (without quotes). +// Parses the content between quotes with single-pass streaming support. +// Stops before the closing quote (doesn't consume it). +// Handles escape sequences and emits NEED_MORE_INPUT for incomplete input. +// S -> (regular chars and escape sequences)* until closing " +class json_string_parser : public parser_base { + std::optional capture_name_; + + static bool is_hex_digit(char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); + } + + public: + json_string_parser(std::optional capture_name, int id) + : parser_base(id), capture_name_(std::move(capture_name)) {} + + parser_type type() const override { return PARSER_JSON_STRING; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + std::unordered_map groups; + auto pos = start; + + // Parse string content (without quotes) + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + + if (c == '"') { + // Found closing quote - success (don't consume it) + if (capture_name_) { + groups[*capture_name_] = parser_match_location{start, pos}; + } + return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + } + + if (c == '\\') { + // Handle escape sequence + ++pos; + if (pos >= ctx.input.size()) { + // Mid-escape sequence + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + + char escape = ctx.input[pos]; + switch (escape) { + case '"': + case '\\': + case '/': + case 'b': + case 'f': + case 'n': + case 'r': + case 't': + // Valid escape + ++pos; + break; + + case 'u': + // Unicode escape: must be followed by 4 hex digits + ++pos; + for (int i = 0; i < 4; ++i) { + if (pos >= ctx.input.size()) { + // Incomplete unicode escape + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start); + } + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + if (!is_hex_digit(ctx.input[pos])) { + return parser_result(PARSER_RESULT_FAIL, start); + } + ++pos; + } + break; + + default: + // Invalid escape sequence + return parser_result(PARSER_RESULT_FAIL, start); + } + } else { + // Regular character + ++pos; + } + } + + // Reached end without finding closing quote + return parser_result(PARSER_RESULT_FAIL, start, pos); } std::string dump() const override { - return "Char(" + pattern_ + ")"; + if (capture_name_) { + return "JsonString(" + *capture_name_ + ")"; + } + return "JsonString()"; } void accept(parser_visitor & visitor) override; - const std::string & pattern() const { return pattern_; } + const std::optional & capture_name() const { return capture_name_; } }; // Captures the matched text from a parser and stores it with a name. @@ -729,7 +855,8 @@ class parser_visitor { virtual void visit(not_parser & p) = 0; virtual void visit(any_parser & p) = 0; virtual void visit(space_parser & p) = 0; - virtual void visit(char_class_parser & p) = 0; + virtual void visit(chars_parser & p) = 0; + virtual void visit(json_string_parser & p) = 0; virtual void visit(group_parser & p) = 0; virtual void visit(schema_parser & p) = 0; virtual void visit(rule_parser & p) = 0; @@ -921,9 +1048,36 @@ class gbnf_visitor : public parser_visitor { current_result_ = "space"; } - void visit(char_class_parser & p) override { - // Return pattern as-is (already in GBNF format) - current_result_ = p.pattern(); + void visit(chars_parser & p) override { + const std::string & pattern = p.pattern(); + + if (p.min_count() == 0 && p.max_count() == -1) { + // Zero or more: * + current_result_ = pattern + "*"; + } else if (p.min_count() == 1 && p.max_count() == -1) { + // One or more: + + current_result_ = pattern + "+"; + } else if (p.max_count() == -1) { + // Unbounded: {n,} + current_result_ = pattern + "{" + std::to_string(p.min_count()) + ",}"; + } else if (p.min_count() == p.max_count()) { + // Exact count: {n} or just pattern for n=1 + if (p.min_count() == 1) { + current_result_ = pattern; + } else { + current_result_ = pattern + "{" + std::to_string(p.min_count()) + "}"; + } + } else { + // Bounded: {n,m} + current_result_ = pattern + "{" + std::to_string(p.min_count()) + "," + + std::to_string(p.max_count()) + "}"; + } + } + + void visit(json_string_parser &) override { + // JSON string content (without quotes) + // Pattern: (any non-quote/backslash OR escape sequences)* until closing quote + current_result_ = R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; } void visit(group_parser & p) override { @@ -988,7 +1142,11 @@ class id_assignment_visitor : public parser_visitor { assign_id(p); } - void visit(char_class_parser & p) override { + void visit(chars_parser & p) override { + assign_id(p); + } + + void visit(json_string_parser & p) override { assign_id(p); } @@ -1067,7 +1225,8 @@ void until_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void not_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void any_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void space_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } -void char_class_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void chars_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void json_string_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void group_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void schema_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void rule_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } @@ -1193,8 +1352,20 @@ parser parser_builder::any() { return parser(std::make_shared(counter_->next())); } -parser parser_builder::char_class(const std::string & classes) { - return parser(std::make_shared(classes, counter_->next())); +parser parser_builder::chars(const std::string & classes, int min, int max) { + return parser(std::make_shared(classes, min, max, counter_->next())); +} + +parser parser_builder::one(const std::string & classes) { + return chars(classes, 1, 1); +} + +parser parser_builder::json_string() { + return parser(std::make_shared(std::nullopt, counter_->next())); +} + +parser parser_builder::json_string(const std::string & name) { + return parser(std::make_shared(name, counter_->next())); } parser parser_builder::group(const std::string & name, const parser & p) { @@ -1270,36 +1441,28 @@ static parser json_parser(std::shared_ptr counter) { parser_builder builder(std::move(counter)); // Whitespace: space, tab, newline, carriage return - auto ws = builder.zero_or_more(builder.char_class("[ \\t\\n\\r]")); + auto ws = builder.chars("[ \\t\\n\\r]", 0, -1); // Number components - auto digit = builder.char_class("[0-9]"); - auto digit1_9 = builder.char_class("[1-9]"); - auto digits = builder.one_or_more(digit); + auto digit1_9 = builder.chars("[1-9]", 1, 1); + auto digits = builder.chars("[0-9]"); // Integer part: 0 or non-zero digit followed by more digits - auto int_part = builder.literal("0") | (digit1_9 + builder.zero_or_more(digit)); + auto int_part = builder.literal("0") | (digit1_9 + builder.chars("[0-9]", 0, -1)); // Optional fractional part auto frac = builder.literal(".") + digits; // Optional exponent part - auto exp = (builder.literal("e") | builder.literal("E")) + builder.optional(builder.char_class("[+\\-]")) + digits; + auto exp = (builder.literal("e") | builder.literal("E")) + builder.optional(builder.chars("[+\\-]", 1, 1)) + digits; // Complete number auto number = builder.optional(builder.literal("-")) + int_part + builder.optional(frac) + builder.optional(exp); builder.add_rule("json_number", number); - // String components - auto hex = builder.char_class("[0-9a-fA-F]"); - auto unicode_escape = builder.literal("\\u") + hex + hex + hex + hex; - auto simple_escape = builder.literal("\\") + builder.char_class("[\"\\\\bfnrt/]"); - auto escape = simple_escape | unicode_escape; - - // String character: escape sequence or any char except quote and backslash - auto string_char = escape | builder.char_class("[^\"\\\\]"); - auto string = builder.literal("\"") + builder.zero_or_more(string_char) + builder.literal("\""); + // String: specialized single-pass parser (content only, wrapped with quotes) + auto string = builder.literal("\"") + builder.json_string() + builder.literal("\""); builder.add_rule("json_string", string); diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 4a1242a718362..6275de1fcde62 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -176,9 +176,17 @@ class parser_builder { // S -> . parser any(); + // Matches between min and max repetitions of characters from a character class. + // S -> [a-z]{m,n} + // + // Use -1 for max to represent unbounded repetition (equivalent to {m,}) + parser chars(const std::string & classes, int min = 1, int max = -1); + // Matches a single character from a character class or range. // S -> [a-z] or S -> [^0-9] - parser char_class(const std::string & classes); + // + // Equivalent to chars(classes, 1, 1) + parser one(const std::string & classes); // Captures the matched text from a parser and stores it with a name. // S -> @@ -209,6 +217,10 @@ class parser_builder { // value -> object | array | string | number | true | false | null parser json(); + // Specialized single-pass JSON string parser with escape sequence handling + parser json_string(); + parser json_string(const std::string & name); + parser json_key(const std::string & name, const parser & p); parser json_string(const parser & p); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index a0e007c79c879..6264d72f76fbb 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -46,7 +46,7 @@ static void test_partial_parsing() { { // Test char class auto parser = build_parser([](parser_builder& p) { - return p.char_class("a-z"); + return p.one("a-z"); }); parser_context ctx; @@ -61,7 +61,7 @@ static void test_partial_parsing() { assert_equals(true, result.is_fail()); parser = build_parser([](parser_builder& p) { - return p.char_class("a-z-"); + return p.one("a-z-"); }); ctx = parser_context("f"); @@ -246,11 +246,11 @@ static void test_capture_groups() { } } -static void test_char_class() { +static void test_one() { { // Test common escape sequences auto parser = build_parser([](parser_builder& p) { - return p.char_class("[\\n\\t\\\\]"); + return p.one("[\\n\\t\\\\]"); }); parser_context ctx; @@ -275,7 +275,7 @@ static void test_char_class() { { // Test escaped dash (literal dash, not a range) auto parser = build_parser([](parser_builder& p) { - return p.char_class("[a\\-z]"); + return p.one("[a\\-z]"); }); parser_context ctx; @@ -302,7 +302,7 @@ static void test_char_class() { static void test_recursive_references() { auto value_parser = build_parser([](parser_builder& p) { - p.add_rule("number", p.one_or_more(p.char_class("0-9"))); + p.add_rule("number", p.one_or_more(p.one("0-9"))); p.add_rule("list", p.sequence({ p.literal("["), p.rule("value"), @@ -559,7 +559,7 @@ static void test_gbnf_generation() { { // Test char class auto parser = build_parser([](parser_builder& p) { - return p.char_class("[a-z]"); + return p.one("[a-z]"); }); auto gbnf = build_grammar([&](const common_grammar_builder & builder) { @@ -595,7 +595,7 @@ static void test_gbnf_generation() { { // Test one_or_more auto parser = build_parser([](parser_builder& p) { - return p.one_or_more(p.char_class("[0-9]")); + return p.one_or_more(p.one("[0-9]")); }); auto gbnf = build_grammar([&](const common_grammar_builder & builder) { @@ -607,7 +607,7 @@ static void test_gbnf_generation() { { // Test zero_or_more auto parser = build_parser([](parser_builder& p) { - return p.zero_or_more(p.char_class("[a-z]")); + return p.zero_or_more(p.one("[a-z]")); }); auto gbnf = build_grammar([&](const common_grammar_builder & builder) { @@ -668,7 +668,7 @@ static void test_gbnf_generation() { { // Test rule references auto parser = build_parser([](parser_builder& p) { - auto digit = p.add_rule("digit", p.char_class("[0-9]")); + auto digit = p.add_rule("digit", p.one("[0-9]")); return p.one_or_more(digit); }); @@ -709,7 +709,7 @@ static void test_gbnf_generation() { } static parser create_command_r7b_parser() { - return build_parser([](parser_builder & p) { + auto parser = build_parser([](parser_builder & p) { auto thinking = p.literal("<|START_THINKING|>") << p.until("<|END_THINKING|>") << p.literal("<|END_THINKING|>"); @@ -737,6 +737,12 @@ static parser create_command_r7b_parser() { return p.optional(thinking) << (tool_calls | response); }); + + auto grammar = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + return parser; } static void test_command_r7b_parser(const parser & p, const std::string & input, bool partial) { @@ -888,7 +894,7 @@ static void benchmark_compare( int main() { test_partial_parsing(); - test_char_class(); + test_one(); test_capture_groups(); test_recursive_references(); test_optional(); From f6aa60857a72120aecf40775d68cb34c94124e03 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 00:20:21 -0600 Subject: [PATCH 20/34] fix const auto * it for windows --- common/chat-parser-combinator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index de825a8e6a2ee..2481443be9c44 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -708,7 +708,7 @@ class until_parser : public parser_base { parser_result result(PARSER_RESULT_SUCCESS, start, ctx.input.size()); // Search for the delimiter - const auto * it = std::search(ctx.input.begin(), ctx.input.end(), searcher_); + const auto it = std::search(ctx.input.begin(), ctx.input.end(), searcher_); if (it != ctx.input.end()) { result.type = PARSER_RESULT_SUCCESS; From d58dacea18127adfc37272c018d8f50a0c5d9b25 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 00:35:33 -0600 Subject: [PATCH 21/34] move id assignment to classes instead of using a visitor --- common/chat-parser-combinator.cpp | 139 +++++++++--------------------- 1 file changed, 41 insertions(+), 98 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 2481443be9c44..064afd80f1686 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -64,6 +64,12 @@ class parser_base { // Actual parsing implementation (to be overridden by subclasses) virtual parser_result parse_uncached(parser_context & ctx, size_t start = 0) = 0; + virtual void assign_id(std::shared_ptr counter) { + if (id_ == -1) { + id_ = counter->next(); + } + } + virtual std::string dump() const = 0; virtual void accept(parser_visitor & visitor) = 0; }; @@ -164,6 +170,13 @@ class sequence_parser : public parser_base { return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); } + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + for (auto & p : parsers_) { + p->assign_id(counter); + } + } + std::string dump() const override { std::vector parts; parts.reserve(parsers_.size()); @@ -217,6 +230,13 @@ class choice_parser : public parser_base { return parser_result(PARSER_RESULT_FAIL, start); } + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + for (auto & p : parsers_) { + p->assign_id(counter); + } + } + std::string dump() const override { std::vector parts; parts.reserve(parsers_.size()); @@ -281,6 +301,11 @@ class repetition_parser : public parser_base { return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); } + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + parser_->assign_id(counter); + } + std::string dump() const override { if (max_count_ == -1) { return "Repetition(" + parser_->dump() + ", " + std::to_string(min_count_) + ", unbounded)"; @@ -369,6 +394,11 @@ class not_parser : public parser_base { return parser_result(PARSER_RESULT_SUCCESS, start); } + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + parser_->assign_id(counter); + } + std::string dump() const override { return "Not(" + parser_->dump() + ")"; } @@ -680,6 +710,11 @@ class group_parser : public parser_base { return result; } + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + parser_->assign_id(counter); + } + std::string dump() const override { return "Group(" + name_ + ", " + parser_->dump() + ")"; } @@ -828,6 +863,11 @@ class root_parser : public parser_base { return root_->parse(ctx, start); } + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + root_->assign_id(counter); + } + std::string dump() const override { return root_->dump(); } @@ -1117,102 +1157,6 @@ class gbnf_visitor : public parser_visitor { } }; -// ID assignment visitor for assigning unique IDs to parsers -class id_assignment_visitor : public parser_visitor { - std::shared_ptr counter_; - - public: - id_assignment_visitor(const std::shared_ptr & counter) : counter_(counter) {} - - void assign_id(parser_base & p) { - if (p.id() == -1) { - p.set_id(counter_->next()); - } - } - - void visit(literal_parser & p) override { - assign_id(p); - } - - void visit(any_parser & p) override { - assign_id(p); - } - - void visit(space_parser & p) override { - assign_id(p); - } - - void visit(chars_parser & p) override { - assign_id(p); - } - - void visit(json_string_parser & p) override { - assign_id(p); - } - - void visit(schema_parser & p) override { - assign_id(p); - } - - void visit(rule_parser & p) override { - assign_id(p); - } - - // Composite parsers - assign ID and traverse children - void visit(sequence_parser & p) override { - assign_id(p); - for (const auto & child : p.parsers()) { - child->accept(*this); - } - } - - void visit(choice_parser & p) override { - assign_id(p); - for (const auto & child : p.parsers()) { - child->accept(*this); - } - } - - void visit(one_or_more_parser & p) override { - assign_id(p); - p.child()->accept(*this); - } - - void visit(zero_or_more_parser & p) override { - assign_id(p); - p.child()->accept(*this); - } - - void visit(optional_parser & p) override { - assign_id(p); - p.child()->accept(*this); - } - - void visit(repetition_parser & p) override { - assign_id(p); - p.child()->accept(*this); - } - - void visit(until_parser & p) override { - assign_id(p); - } - - void visit(not_parser & p) override { - assign_id(p); - p.child()->accept(*this); - } - - void visit(group_parser & p) override { - assign_id(p); - p.child()->accept(*this); - } - - void visit(root_parser & p) override { - assign_id(p); - p.root()->accept(*this); - } -}; - // Implement accept() methods for all parser classes void literal_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void sequence_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } @@ -1419,8 +1363,7 @@ parser parser_builder::add_rule(const std::string & name, const parser & p) { void parser_builder::assign_ids(parser & p) { if (p.ptr()) { - id_assignment_visitor visitor(counter_); - p.ptr()->accept(visitor); + p.ptr()->assign_id(counter_); } } From 20f9a1b83b3d7516a44e1ba726930f57395a63c3 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 00:48:05 -0600 Subject: [PATCH 22/34] create named rules in the command r7b example --- common/chat-parser-combinator.cpp | 2 +- tests/test-chat-parser-combinator.cpp | 30 ++++++++++++++------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 064afd80f1686..e1a1b8f082a52 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -1384,7 +1384,7 @@ static parser json_parser(std::shared_ptr counter) { parser_builder builder(std::move(counter)); // Whitespace: space, tab, newline, carriage return - auto ws = builder.chars("[ \\t\\n\\r]", 0, -1); + auto ws = builder.space(); // Number components auto digit1_9 = builder.chars("[1-9]", 1, 1); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 6264d72f76fbb..032acba9b8d3b 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -710,38 +710,40 @@ static void test_gbnf_generation() { static parser create_command_r7b_parser() { auto parser = build_parser([](parser_builder & p) { - auto thinking = p.literal("<|START_THINKING|>") + auto thinking = p.add_rule("thinking", p.literal("<|START_THINKING|>") << p.until("<|END_THINKING|>") - << p.literal("<|END_THINKING|>"); + << p.literal("<|END_THINKING|>")); - auto response = p.literal("<|START_RESPONSE|>") + auto response = p.add_rule("response", p.literal("<|START_RESPONSE|>") << p.until("<|END_RESPONSE|>") - << p.literal("<|END_RESPONSE|>"); + << p.literal("<|END_RESPONSE|>")); - auto json = p.json(); - auto tool_call_id = p.json_key("tool_call_id", p.json_string(p.until("\""))); - auto tool_call_name = p.json_key("tool_name", p.json_string(p.until("\""))); - auto tool_call_args = p.json_key("parameters", json); - auto tool_call_fields = tool_call_id | tool_call_name | tool_call_args; + auto json = p.add_rule("json", p.json()); + auto tool_call_id = p.add_rule("tool-call-id", p.json_key("tool_call_id", p.json_string(p.until("\"")))); + auto tool_call_name = p.add_rule("tool-name", p.json_key("tool_name", p.json_string(p.until("\"")))); + auto tool_call_args = p.add_rule("tool-args", p.json_key("parameters", json)); + auto tool_call_fields = p.add_rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args); - auto tool_call = p.between("{", + auto tool_call = p.add_rule("tool-call", p.between("{", tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields), - "}"); + "}")); - auto tool_calls = p.literal("<|START_ACTION|>") + auto tool_calls = p.add_rule("tool-calls", p.literal("<|START_ACTION|>") << p.literal("[") << tool_call << p.zero_or_more(p.literal(",") << tool_call) << p.literal("]") - << p.literal("<|END_ACTION|>"); + << p.literal("<|END_ACTION|>")); - return p.optional(thinking) << (tool_calls | response); + return p.optional(thinking) << p.add_rule("content", (tool_calls | response)); }); auto grammar = build_grammar([&](const common_grammar_builder & builder) { parser.build_grammar(builder); }); + std::cout << "=== Grammar ===\n\n" << grammar << "\n\n"; + return parser; } From 35b164037e19be2d971514be636a62de69ae92a2 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 01:09:39 -0600 Subject: [PATCH 23/34] use '.' for any in GBNF --- common/chat-parser-combinator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index e1a1b8f082a52..f4633de06f962 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -1080,7 +1080,7 @@ class gbnf_visitor : public parser_visitor { void visit(any_parser &) override { // Match any single character - current_result_ = "[\\x00-\\x{10FFFF}]"; + current_result_ = "."; } void visit(space_parser &) override { From bcb1c03c02cbcf03ea0c9e59dc2c6423d10c3d79 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 01:16:24 -0600 Subject: [PATCH 24/34] fix parens around choices in gbnf grammar --- common/chat-parser-combinator.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index f4633de06f962..4ab0835abcc2e 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -998,7 +998,13 @@ class gbnf_visitor : public parser_visitor { s += " "; } child->accept(*this); - s += current_result_; + + // Parenthesize choices + if (needs_parens(child->type())) { + s += "(" + current_result_ + ")"; + } else { + s += current_result_; + } } current_result_ = s; } @@ -1012,8 +1018,8 @@ class gbnf_visitor : public parser_visitor { child->accept(*this); - // Parenthesize sequences in choices - if (child->type() == PARSER_SEQUENCE) { + // Parenthesize choices + if (child->type() == PARSER_CHOICE) { s += "(" + current_result_ + ")"; } else { s += current_result_; From 4bed84dfe0341e46305d3a7f5f575b7853b59267 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 01:29:44 -0600 Subject: [PATCH 25/34] add convenience operators to turn strings to literals --- common/chat-parser-combinator.cpp | 18 ++++++++++++++++++ common/chat-parser-combinator.h | 7 +++++++ tests/test-chat-parser-combinator.cpp | 25 ++++++++++--------------- 3 files changed, 35 insertions(+), 15 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 4ab0835abcc2e..a096745130b57 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -1220,6 +1220,8 @@ parser::parser() {} parser::parser(std::shared_ptr parser) : ptr_(std::move(parser)) {} +parser::parser(const std::string & literal) : ptr_(std::make_shared(literal, -1)) {} + parser parser::operator~() const { return parser(std::make_shared(*this, -1)); } @@ -1228,15 +1230,31 @@ parser parser::operator+(const parser & other) const { return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } +parser parser::operator+(const std::string & literal) const { + auto lit = parser(std::make_shared(literal, -1)); + return parser(std::make_shared(std::initializer_list{*this, lit}, -1)); +} + parser parser::operator|(const parser & other) const { return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } +parser parser::operator|(const std::string & literal) const { + auto lit = parser(std::make_shared(literal, -1)); + return parser(std::make_shared(std::initializer_list{*this, lit}, -1)); +} + parser parser::operator<<(const parser & other) const { auto ws = parser(std::make_shared(-1)); return parser(std::make_shared(std::initializer_list{*this, ws, other}, -1)); } +parser parser::operator<<(const std::string & literal) const { + auto ws = parser(std::make_shared(-1)); + auto lit = parser(std::make_shared(literal, -1)); + return parser(std::make_shared(std::initializer_list{*this, ws, lit}, -1)); +} + parser_base & parser::operator*() const { return *ptr_; } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 6275de1fcde62..9e5270150bb48 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -105,6 +105,7 @@ class parser { parser(); parser(std::shared_ptr parser); parser(const parser & other) = default; + parser(const std::string & literal); parser & operator=(const parser & other) { if (this != &other) { ptr_ = other.ptr_; @@ -113,9 +114,15 @@ class parser { } parser operator~() const; + parser operator+(const parser & other) const; + parser operator+(const std::string & literal) const; + parser operator|(const parser & other) const; + parser operator|(const std::string & literal) const; + parser operator<<(const parser & other) const; + parser operator<<(const std::string & literal) const; parser_base & operator*() const; parser_base * operator->() const; diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 032acba9b8d3b..752afee4349d1 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -710,13 +710,11 @@ static void test_gbnf_generation() { static parser create_command_r7b_parser() { auto parser = build_parser([](parser_builder & p) { - auto thinking = p.add_rule("thinking", p.literal("<|START_THINKING|>") - << p.until("<|END_THINKING|>") - << p.literal("<|END_THINKING|>")); + auto thinking = p.add_rule("thinking", + p.literal("<|START_THINKING|>") << p.until("<|END_THINKING|>") << "<|END_THINKING|>"); - auto response = p.add_rule("response", p.literal("<|START_RESPONSE|>") - << p.until("<|END_RESPONSE|>") - << p.literal("<|END_RESPONSE|>")); + auto response = p.add_rule("response", + p.literal("<|START_RESPONSE|>") << p.until("<|END_RESPONSE|>") << "<|END_RESPONSE|>"); auto json = p.add_rule("json", p.json()); auto tool_call_id = p.add_rule("tool-call-id", p.json_key("tool_call_id", p.json_string(p.until("\"")))); @@ -724,16 +722,13 @@ static parser create_command_r7b_parser() { auto tool_call_args = p.add_rule("tool-args", p.json_key("parameters", json)); auto tool_call_fields = p.add_rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args); - auto tool_call = p.add_rule("tool-call", p.between("{", - tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields), - "}")); + auto tool_call = p.add_rule("tool-call", + p.between("{", tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields), "}")); - auto tool_calls = p.add_rule("tool-calls", p.literal("<|START_ACTION|>") - << p.literal("[") - << tool_call - << p.zero_or_more(p.literal(",") << tool_call) - << p.literal("]") - << p.literal("<|END_ACTION|>")); + auto tool_calls = p.add_rule("tool-calls", + p.literal("<|START_ACTION|>") + << "[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]" + << "<|END_ACTION|>"); return p.optional(thinking) << p.add_rule("content", (tool_calls | response)); }); From c02aaa61b53fc1f99dd7ca236ffabeaf9540a1cd Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 01:44:20 -0600 Subject: [PATCH 26/34] add free-form operators for const char * to simplify defining literals --- common/chat-parser-combinator.cpp | 27 +++++---------------------- common/chat-parser-combinator.h | 14 ++++++-------- tests/test-chat-parser-combinator.cpp | 12 ++++++------ 3 files changed, 17 insertions(+), 36 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index a096745130b57..81b55240a842f 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -1222,6 +1222,8 @@ parser::parser(std::shared_ptr parser) : ptr_(std::move(parser)) {} parser::parser(const std::string & literal) : ptr_(std::make_shared(literal, -1)) {} +parser::parser(const char * literal) : ptr_(std::make_shared(literal, -1)) {} + parser parser::operator~() const { return parser(std::make_shared(*this, -1)); } @@ -1230,30 +1232,18 @@ parser parser::operator+(const parser & other) const { return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } -parser parser::operator+(const std::string & literal) const { - auto lit = parser(std::make_shared(literal, -1)); - return parser(std::make_shared(std::initializer_list{*this, lit}, -1)); -} - parser parser::operator|(const parser & other) const { return parser(std::make_shared(std::initializer_list{*this, other}, -1)); } -parser parser::operator|(const std::string & literal) const { - auto lit = parser(std::make_shared(literal, -1)); - return parser(std::make_shared(std::initializer_list{*this, lit}, -1)); -} - parser parser::operator<<(const parser & other) const { auto ws = parser(std::make_shared(-1)); return parser(std::make_shared(std::initializer_list{*this, ws, other}, -1)); } -parser parser::operator<<(const std::string & literal) const { - auto ws = parser(std::make_shared(-1)); - auto lit = parser(std::make_shared(literal, -1)); - return parser(std::make_shared(std::initializer_list{*this, ws, lit}, -1)); -} +parser operator+(const char * lhs, const parser & rhs) { return parser(lhs) + rhs; } +parser operator|(const char * lhs, const parser & rhs) { return parser(lhs) | rhs; } +parser operator<<(const char * lhs, const parser & rhs) { return parser(lhs) << rhs; } parser_base & parser::operator*() const { return *ptr_; @@ -1373,13 +1363,6 @@ parser parser_builder::json_string(const parser & p) { return quote + p + quote; } -parser parser_builder::between(const std::string & left, const parser & p, const std::string & right, bool allow_spaces) { - if (allow_spaces) { - return literal(left) << p << literal(right); - } - return literal(left) + p + literal(right); -} - parser parser_builder::add_rule(const std::string & name, const parser & p) { (*rules_)[name] = p; return rule(name); diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 9e5270150bb48..5b8b08aeb528f 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -106,6 +106,8 @@ class parser { parser(std::shared_ptr parser); parser(const parser & other) = default; parser(const std::string & literal); + parser(const char * literal); + parser & operator=(const parser & other) { if (this != &other) { ptr_ = other.ptr_; @@ -114,15 +116,9 @@ class parser { } parser operator~() const; - parser operator+(const parser & other) const; - parser operator+(const std::string & literal) const; - parser operator|(const parser & other) const; - parser operator|(const std::string & literal) const; - parser operator<<(const parser & other) const; - parser operator<<(const std::string & literal) const; parser_base & operator*() const; parser_base * operator->() const; @@ -136,6 +132,10 @@ class parser { void build_grammar(const common_grammar_builder & builder) const; }; +parser operator+(const char * lhs, const parser & rhs); +parser operator|(const char * lhs, const parser & rhs); +parser operator<<(const char * lhs, const parser & rhs); + class parser_id_counter { int next_id_; public: @@ -231,8 +231,6 @@ class parser_builder { parser json_key(const std::string & name, const parser & p); parser json_string(const parser & p); - parser between(const std::string & left, const parser & p, const std::string & right, bool allow_spaces = true); - // Wraps a parser with JSON schema metadata for grammar generation. // Used internally to convert JSON schemas to GBNF grammar rules. parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 752afee4349d1..c20cf5417af2c 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -711,10 +711,10 @@ static void test_gbnf_generation() { static parser create_command_r7b_parser() { auto parser = build_parser([](parser_builder & p) { auto thinking = p.add_rule("thinking", - p.literal("<|START_THINKING|>") << p.until("<|END_THINKING|>") << "<|END_THINKING|>"); + "<|START_THINKING|>" << p.until("<|END_THINKING|>") << "<|END_THINKING|>"); auto response = p.add_rule("response", - p.literal("<|START_RESPONSE|>") << p.until("<|END_RESPONSE|>") << "<|END_RESPONSE|>"); + "<|START_RESPONSE|>" << p.until("<|END_RESPONSE|>") << "<|END_RESPONSE|>"); auto json = p.add_rule("json", p.json()); auto tool_call_id = p.add_rule("tool-call-id", p.json_key("tool_call_id", p.json_string(p.until("\"")))); @@ -723,14 +723,14 @@ static parser create_command_r7b_parser() { auto tool_call_fields = p.add_rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args); auto tool_call = p.add_rule("tool-call", - p.between("{", tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields), "}")); + "{" << tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields) << "}"); auto tool_calls = p.add_rule("tool-calls", - p.literal("<|START_ACTION|>") - << "[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]" + "<|START_ACTION|>" + << ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]") << "<|END_ACTION|>"); - return p.optional(thinking) << p.add_rule("content", (tool_calls | response)); + return p.optional(thinking) << p.add_rule("content", tool_calls | response); }); auto grammar = build_grammar([&](const common_grammar_builder & builder) { From 8e821275f07b7a031ee95e9246030e43a7035222 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 01:51:43 -0600 Subject: [PATCH 27/34] simplify test case parser --- tests/test-chat-parser-combinator.cpp | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index c20cf5417af2c..93d730ee8ffe2 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -447,9 +447,7 @@ static void test_complete_example() { // auto parser = build_parser([](parser_builder & p) { auto reasoning = p.add_rule("reasoning", - p.literal("") - << p.group("reasoning-content", p.until("")) - << p.literal("")); + "" << p.group("reasoning-content", p.until("")) << ""); auto content = p.add_rule("content", p.group("content", p.until(""))); @@ -457,22 +455,15 @@ static void test_complete_example() { auto json = p.json(); auto tool_call_name = p.add_rule("tool-call-name", - p.literal("") - << p.group("tool-name", p.until("")) - << p.literal("")); + "" << p.group("tool-name", p.until("")) << ""); auto schema = nlohmann::ordered_json::parse(R"({"type": "object"})"); auto tool_call_args = p.add_rule("tool-call-args", - p.literal("") - << p.group("tool-args", p.schema(json, "get_weather", schema)) - << p.literal("")); + "" << p.group("tool-args", p.schema(json, "get_weather", schema)) << ""); auto tool_call = p.add_rule("tool-call", - p.literal("") - << tool_call_name - << tool_call_args - << p.literal("")); + "" << tool_call_name << tool_call_args << ""); return reasoning << p.optional(content) << p.optional(tool_call); }); From 9685b696584b960d04225e67a18ce4be02fe0f9a Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 02:32:01 -0600 Subject: [PATCH 28/34] implement semantic actions --- common/chat-parser-combinator.cpp | 49 +++++++++++ common/chat-parser-combinator.h | 33 +++++++- tests/test-chat-parser-combinator.cpp | 112 ++++++++++++++++++++++++++ 3 files changed, 190 insertions(+), 4 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 81b55240a842f..cf68c646ac1b8 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -26,6 +26,7 @@ enum parser_type { PARSER_SCHEMA = 14, PARSER_ROOT = 15, PARSER_JSON_STRING = 16, + PARSER_ACTION = 17, }; class parser_visitor; @@ -879,6 +880,43 @@ class root_parser : public parser_base { std::shared_ptr> rules() const { return rules_; } }; +// Wraps a parser with a semantic action callback. +class action_parser : public parser_base { + parser parser_; + std::function action_; + + public: + action_parser(const parser & parser, std::function action, int id) + : parser_base(id), parser_(parser), action_(std::move(action)) {} + + parser_type type() const override { return PARSER_ACTION; } + + parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { + auto result = parser_->parse(ctx, start); + + // Invoke action callback on success if environment is available + if (result.is_success() && ctx.env && action_) { + std::string_view matched = ctx.input.substr(result.start, result.end - result.start); + action_(result, matched, *ctx.env); + } + + return result; + } + + void assign_id(std::shared_ptr counter) override { + parser_base::assign_id(counter); + parser_->assign_id(counter); + } + + std::string dump() const override { + return "Action(" + parser_->dump() + ")"; + } + + void accept(parser_visitor & visitor) override; + + const parser & child() const { return parser_; } +}; + // Base visitor class for parser tree traversal class parser_visitor { public: @@ -901,6 +939,7 @@ class parser_visitor { virtual void visit(schema_parser & p) = 0; virtual void visit(rule_parser & p) = 0; virtual void visit(root_parser & p) = 0; + virtual void visit(action_parser & p) = 0; }; class gbnf_visitor : public parser_visitor { @@ -1161,6 +1200,11 @@ class gbnf_visitor : public parser_visitor { // Return root body for composition p.root()->accept(*this); } + + void visit(action_parser & p) override { + // Actions are transparent for grammar generation - just visit child + p.child()->accept(*this); + } }; // Implement accept() methods for all parser classes @@ -1181,6 +1225,7 @@ void group_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void schema_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void rule_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void root_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } +void action_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } std::optional parser_result::group(const std::string & name, std::string_view input) const { auto it = groups.find(name); @@ -1354,6 +1399,10 @@ parser parser_builder::schema(const parser & p, const std::string & name, const return parser(std::make_shared(p, name, schema, counter_->next())); } +parser parser_builder::action(const parser & p, std::function fn) { + return parser(std::make_shared(p, std::move(fn), counter_->next())); +} + parser parser_builder::json_key(const std::string & name, const parser & p) { return literal("\"" + name + "\"") << literal(":") << p; } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 5b8b08aeb528f..585ca44d04b69 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -8,9 +8,19 @@ #include #include #include +#include +#include +struct common_chat_tool_call; struct common_grammar_builder; +struct parser_environment { + std::string content; + std::string reasoning; + std::vector tool_calls; + std::unordered_map> scratchpad; +}; + enum parser_result_type { PARSER_RESULT_FAIL = 0, PARSER_RESULT_NEED_MORE_INPUT = 1, @@ -82,18 +92,28 @@ struct parser_context { std::string_view input; parse_cache memo; bool input_is_complete; + parser_environment * env; parser_context() - : memo(), input_is_complete(true) {} + : memo(), input_is_complete(true), env(nullptr) {} parser_context(std::string_view input) - : input(input), memo(), input_is_complete(true) {} + : input(input), memo(), input_is_complete(true), env(nullptr) {} parser_context(std::string_view input, bool complete) - : input(input), memo(), input_is_complete(complete) {} + : input(input), memo(), input_is_complete(complete), env(nullptr) {} parser_context(std::string_view input, parse_cache memo, bool complete = true) - : input(input), memo(std::move(memo)), input_is_complete(complete) {} + : input(input), memo(std::move(memo)), input_is_complete(complete), env(nullptr) {} + + parser_context(std::string_view input, parser_environment * environment) + : input(input), memo(), input_is_complete(true), env(environment) {} + + parser_context(std::string_view input, parser_environment * environment, bool complete) + : input(input), memo(), input_is_complete(complete), env(environment) {} + + parser_context(std::string_view input, parse_cache memo, parser_environment * environment, bool complete = true) + : input(input), memo(std::move(memo)), input_is_complete(complete), env(environment) {} }; class parser_base; @@ -235,6 +255,11 @@ class parser_builder { // Used internally to convert JSON schemas to GBNF grammar rules. parser schema(const parser & p, const std::string & name, const nlohmann::ordered_json & schema); + // Wraps a parser with a semantic action callback. + // The callback is invoked on successful parse with the result, matched text, and environment. + // S -> A [action] + parser action(const parser & p, std::function fn); + parser add_rule(const std::string & name, const parser & p); void assign_ids(parser & p); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 93d730ee8ffe2..57795f5403bf1 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -533,6 +533,117 @@ static void test_complete_example() { std::cout << "Grammar:\n" << gbnf << "\n"; } +static void test_actions() { + { + // Test simple action - append matched text to content + auto parser = build_parser([](parser_builder& p) { + auto word = p.chars("[a-z]+"); + return p.action(word, [](const parser_result &, std::string_view matched, parser_environment & env) { + env.content += std::string(matched); + }); + }); + + parser_environment env; + parser_context ctx("hello", &env); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello", env.content); + } + { + // Test multiple sequential actions - build a sentence + auto parser = build_parser([](parser_builder& p) { + auto greeting = p.action(p.literal("hello"), [](const parser_result &, std::string_view matched, parser_environment & env) { + env.content += std::string(matched) + " "; + }); + + auto name = p.action(p.chars("[A-Z][a-z]+"), [](const parser_result &, std::string_view matched, parser_environment & env) { + env.content += std::string(matched); + env.scratchpad["name"] = std::string(matched); + }); + + return greeting + p.literal(" ") + name; + }); + + parser_environment env; + parser_context ctx("hello Alice", &env); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello Alice", env.content); + assert_equals("Alice", std::get(env.scratchpad["name"])); + } + { + // Test using scratchpad for intermediate calculations + auto parser = build_parser([](parser_builder& p) { + auto digit = p.action(p.one("[0-9]"), [](const parser_result &, std::string_view matched, parser_environment & env) { + auto it = env.scratchpad.find("sum"); + int current_sum = it != env.scratchpad.end() ? std::get(it->second) : 0; + current_sum += (matched[0] - '0'); + env.scratchpad["sum"] = current_sum; + }); + + return p.one_or_more(digit + p.optional(p.literal("+"))); + }); + + parser_environment env; + parser_context ctx("1+2+3+4", &env); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals(10, std::get(env.scratchpad["sum"])); // 1+2+3+4 = 10 + } + { + // Test actions don't run when parse fails + auto parser = build_parser([](parser_builder& p) { + return p.action(p.literal("success"), [](const parser_result &, std::string_view, parser_environment & env) { + env.content = "action_ran"; + }); + }); + + parser_environment env; + parser_context ctx("failure", &env); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_fail()); + assert_equals("", env.content); // Action should not have run + } + { + // Test Actions work with partial parsing + auto parser = build_parser([](parser_builder& p) { + auto content = p.action(p.until(""), [](const parser_result &, std::string_view matched, parser_environment & env) { + env.content += std::string(matched); + }); + return "" << content << ""; + }); + + { + parser_environment env; + parser_context ctx("hello ", &env, false); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello", env.content); + } + { + parser_environment env; + parser_context ctx("hello world", &env, false); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello world", env.content); + } + { + parser_environment env; + parser_context ctx("hello world", &env, true); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("hello world", env.content); + } + } +} + static void test_gbnf_generation() { { // Test literal @@ -888,6 +999,7 @@ int main() { test_optional(); test_json_parser(); test_complete_example(); + test_actions(); test_gbnf_generation(); std::cout << "All tests passed!\n"; From d9a62295b8fa9e39c58fda1e286c80fb352e3714 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 03:06:47 -0600 Subject: [PATCH 29/34] remove groups in favor of actions and a scratchpad --- common/chat-parser-combinator.cpp | 118 +++----------- common/chat-parser-combinator.h | 25 +-- tests/test-chat-parser-combinator.cpp | 222 ++++++++++++-------------- 3 files changed, 120 insertions(+), 245 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index cf68c646ac1b8..04f354232f548 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -19,14 +19,13 @@ enum parser_type { PARSER_NOT = 7, PARSER_ANY = 8, PARSER_CHARS = 9, - PARSER_GROUP = 10, - PARSER_RULE = 11, - PARSER_UNTIL = 12, - PARSER_SPACE = 13, - PARSER_SCHEMA = 14, - PARSER_ROOT = 15, - PARSER_JSON_STRING = 16, - PARSER_ACTION = 17, + PARSER_RULE = 10, + PARSER_UNTIL = 11, + PARSER_SPACE = 12, + PARSER_SCHEMA = 13, + PARSER_ROOT = 14, + PARSER_JSON_STRING = 15, + PARSER_ACTION = 16, }; class parser_visitor; @@ -81,6 +80,10 @@ static bool is_space(const char c) { return (c == ' ' || c == '\t' || c == '\n'); } +static bool is_hex_digit(const char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); +} + // Matches an exact literal string. // S -> "hello" class literal_parser : public parser_base { @@ -144,31 +147,26 @@ class sequence_parser : public parser_base { parser_type type() const override { return PARSER_SEQUENCE; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { - std::unordered_map groups; - auto pos = start; for (const auto & p : parsers_) { auto result = p->parse(ctx, pos); - // Copy groups - groups.insert(result.groups.begin(), result.groups.end()); - if (result.is_fail()) { if (result.end >= ctx.input.size() && !ctx.input_is_complete) { // If we fail because we don't have enough input, then return success - return parser_result(PARSER_RESULT_SUCCESS, start, result.end, groups); + return parser_result(PARSER_RESULT_SUCCESS, start, result.end); } - return parser_result(PARSER_RESULT_FAIL, start, result.end, groups); + return parser_result(PARSER_RESULT_FAIL, start, result.end); } if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end, groups); + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end); } pos = result.end; } - return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); } void assign_id(std::shared_ptr counter) override { @@ -267,14 +265,12 @@ class repetition_parser : public parser_base { parser_type type() const override { return PARSER_REPETITION; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { - std::unordered_map groups; auto pos = start; int match_count = 0; // Try to match up to max_count times (or unlimited if max_count is -1) while (max_count_ == -1 || match_count < max_count_) { auto result = parser_->parse(ctx, pos); - groups.insert(result.groups.begin(), result.groups.end()); if (result.is_success()) { // Prevent infinite loop on empty matches @@ -287,7 +283,7 @@ class repetition_parser : public parser_base { } if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos, groups); + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); } // Child failed - stop trying @@ -296,10 +292,10 @@ class repetition_parser : public parser_base { // Check if we got enough matches if (match_count < min_count_) { - return parser_result(PARSER_RESULT_FAIL, start, pos, groups); + return parser_result(PARSER_RESULT_FAIL, start, pos); } - return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); } void assign_id(std::shared_ptr counter) override { @@ -595,20 +591,13 @@ class chars_parser : public parser_base { // Handles escape sequences and emits NEED_MORE_INPUT for incomplete input. // S -> (regular chars and escape sequences)* until closing " class json_string_parser : public parser_base { - std::optional capture_name_; - - static bool is_hex_digit(char c) { - return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); - } public: - json_string_parser(std::optional capture_name, int id) - : parser_base(id), capture_name_(std::move(capture_name)) {} + json_string_parser(int id) : parser_base(id) {} parser_type type() const override { return PARSER_JSON_STRING; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { - std::unordered_map groups; auto pos = start; // Parse string content (without quotes) @@ -617,10 +606,7 @@ class json_string_parser : public parser_base { if (c == '"') { // Found closing quote - success (don't consume it) - if (capture_name_) { - groups[*capture_name_] = parser_match_location{start, pos}; - } - return parser_result(PARSER_RESULT_SUCCESS, start, pos, groups); + return parser_result(PARSER_RESULT_SUCCESS, start, pos); } if (c == '\\') { @@ -681,48 +667,10 @@ class json_string_parser : public parser_base { } std::string dump() const override { - if (capture_name_) { - return "JsonString(" + *capture_name_ + ")"; - } return "JsonString()"; } void accept(parser_visitor & visitor) override; - - const std::optional & capture_name() const { return capture_name_; } -}; - -// Captures the matched text from a parser and stores it with a name. -// S -> -class group_parser : public parser_base { - std::string name_; - parser parser_; - - public: - group_parser(const std::string & name, const parser & parser, int id) : parser_base(id), name_(name), parser_(parser) {} - - parser_type type() const override { return PARSER_GROUP; } - - parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { - auto result = parser_->parse(ctx, start); - - // Store result - result.groups[name_] = parser_match_location{result.start, result.end}; - return result; - } - - void assign_id(std::shared_ptr counter) override { - parser_base::assign_id(counter); - parser_->assign_id(counter); - } - - std::string dump() const override { - return "Group(" + name_ + ", " + parser_->dump() + ")"; - } - - void accept(parser_visitor & visitor) override; - - const parser & child() const { return parser_; } }; // Matches all characters until a delimiter is found (delimiter not consumed). @@ -935,7 +883,6 @@ class parser_visitor { virtual void visit(space_parser & p) = 0; virtual void visit(chars_parser & p) = 0; virtual void visit(json_string_parser & p) = 0; - virtual void visit(group_parser & p) = 0; virtual void visit(schema_parser & p) = 0; virtual void visit(rule_parser & p) = 0; virtual void visit(root_parser & p) = 0; @@ -1165,11 +1112,6 @@ class gbnf_visitor : public parser_visitor { current_result_ = R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; } - void visit(group_parser & p) override { - // Groups are transparent - just visit child - p.child()->accept(*this); - } - void visit(schema_parser & p) override { current_result_ = builder_.add_schema(p.name(), p.schema()); } @@ -1221,21 +1163,11 @@ void any_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void space_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void chars_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void json_string_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } -void group_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void schema_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void rule_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void root_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } void action_parser::accept(parser_visitor & visitor) { visitor.visit(*this); } -std::optional parser_result::group(const std::string & name, std::string_view input) const { - auto it = groups.find(name); - if (it == groups.end()) { - return std::nullopt; - } - - return std::string(it->second.view(input)); -} - parser_result parse_cache::set(int id, size_t start, parser_result result) { if (id == -1) { // Don't cache parsers with ID -1 (from operators and global factory functions) @@ -1364,15 +1296,7 @@ parser parser_builder::one(const std::string & classes) { } parser parser_builder::json_string() { - return parser(std::make_shared(std::nullopt, counter_->next())); -} - -parser parser_builder::json_string(const std::string & name) { - return parser(std::make_shared(name, counter_->next())); -} - -parser parser_builder::group(const std::string & name, const parser & p) { - return parser(std::make_shared(name, p, counter_->next())); + return parser(std::make_shared(counter_->next())); } parser parser_builder::rule(const std::string & name) { diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 585ca44d04b69..cb7da41740a75 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -16,7 +16,7 @@ struct common_grammar_builder; struct parser_environment { std::string content; - std::string reasoning; + std::string reasoning_content; std::vector tool_calls; std::unordered_map> scratchpad; }; @@ -43,24 +43,11 @@ struct std::hash { } }; -struct parser_match_location { - size_t start; - size_t end; - - size_t length() const { return end - start; } - - std::string_view view(std::string_view sv) const { - return sv.substr(start, length()); - } -}; - struct parser_result { parser_result_type type = PARSER_RESULT_FAIL; size_t start = 0; size_t end = 0; - std::unordered_map groups; - parser_result() : type(PARSER_RESULT_FAIL) {} parser_result(parser_result_type type, size_t start) @@ -69,14 +56,9 @@ struct parser_result { parser_result(parser_result_type type, size_t start, size_t end) : type(type), start(start), end(end) {} - parser_result(parser_result_type type, size_t start, size_t end, const std::unordered_map & groups) - : type(type), start(start), end(end), groups(groups) {} - bool is_fail() const { return type == PARSER_RESULT_FAIL; } bool is_need_more_input() const { return type == PARSER_RESULT_NEED_MORE_INPUT; } bool is_success() const { return type == PARSER_RESULT_SUCCESS; } - - std::optional group(const std::string & name, std::string_view input) const; }; class parse_cache { @@ -215,10 +197,6 @@ class parser_builder { // Equivalent to chars(classes, 1, 1) parser one(const std::string & classes); - // Captures the matched text from a parser and stores it with a name. - // S -> - parser group(const std::string & name, const parser & p); - // References a named rule for recursive or reusable grammar definitions. // expr -> term | expr "+" term parser rule(const std::string & name); @@ -246,7 +224,6 @@ class parser_builder { // Specialized single-pass JSON string parser with escape sequence handling parser json_string(); - parser json_string(const std::string & name); parser json_key(const std::string & name, const parser & p); parser json_string(const parser & p); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 57795f5403bf1..a309f7fce5a1c 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -4,6 +4,7 @@ #include "nlohmann/json.hpp" +#include "chat.h" #include "chat-parser.h" #include "chat-parser-combinator.h" #include "common.h" @@ -181,71 +182,6 @@ static void test_partial_parsing() { } } -static void test_capture_groups() { - { - auto parser = build_parser([](parser_builder& p) { - return p.literal("") + - p.group("reasoning_content", - p.zero_or_more(~p.literal("") + p.any()) - ) + - p.literal(""); - }); - - std::string input = "I have a thought"; - auto ctx = parser_context(input); - auto result = parser.parse(ctx); - - assert_equals(true, result.is_success()); - - auto it = result.groups.find("reasoning_content"); - assert_equals(true, it != result.groups.end()); - assert_equals("I have a thought", std::string(it->second.view(input))); - } - { - auto parser = build_parser([](parser_builder& p) { - return p.literal("") + - p.group("reasoning_content", - p.zero_or_more(~p.literal("") + p.any()) - ) + - p.literal(""); - }); - - std::string input = "I have a "; - auto ctx = parser_context(input, false); - auto result = parser.parse(ctx); - - assert_equals(true, result.is_success()); - - auto it = result.groups.find("reasoning_content"); - assert_equals(true, it != result.groups.end()); - assert_equals("I have a ", std::string(it->second.view(input))); - } - { - auto parser = build_parser([](parser_builder& p) { - return p.literal("") + - p.group("reasoning_content", - p.zero_or_more(~p.literal("") + p.any()) - ) + - p.literal("") + - p.group("content", p.zero_or_more(p.any())); - }); - - std::string input = "The user said hello.Hello!"; - auto ctx = parser_context(input, true); - auto result = parser.parse(ctx); - - assert_equals(true, result.is_success()); - - auto it = result.groups.find("reasoning_content"); - assert_equals(true, it != result.groups.end()); - assert_equals("The user said hello.", std::string(it->second.view(input))); - - it = result.groups.find("content"); - assert_equals(true, it != result.groups.end()); - assert_equals("Hello!", std::string(it->second.view(input))); - } -} - static void test_one() { { // Test common escape sequences @@ -446,85 +382,136 @@ static void test_complete_example() { // // auto parser = build_parser([](parser_builder & p) { + auto handle_reasoning = [](const parser_result &, std::string_view match, parser_environment & env) { + env.reasoning_content += match; + }; + + auto handle_content = [](const parser_result &, std::string_view match, parser_environment & env) { + env.content += match; + }; + + auto handle_tool_call_name = [](const parser_result &, std::string_view match, parser_environment & env) { + env.scratchpad["tool_name"] = std::string(match); + }; + + auto handle_tool_call_args = [](const parser_result &, std::string_view match, parser_environment & env) { + env.scratchpad["tool_args"] = std::string(match); + }; + + auto handle_tool_call = [](const parser_result &, std::string_view, parser_environment & env) { + auto name = env.scratchpad.find("tool_name"); + auto args = env.scratchpad.find("tool_args"); + if (name != env.scratchpad.end() && args != env.scratchpad.end()) { + auto tool_call = common_chat_tool_call{ + std::get(name->second), + std::get(args->second), + std::string() + }; + + env.tool_calls.push_back(tool_call); + } + }; + auto reasoning = p.add_rule("reasoning", - "" << p.group("reasoning-content", p.until("")) << ""); + "" << p.action(p.until(""), handle_reasoning) << ""); auto content = p.add_rule("content", - p.group("content", p.until(""))); + p.action(p.until(""), handle_content)); auto json = p.json(); auto tool_call_name = p.add_rule("tool-call-name", - "" << p.group("tool-name", p.until("")) << ""); + "" << p.action(p.until(""), handle_tool_call_name) << ""); auto schema = nlohmann::ordered_json::parse(R"({"type": "object"})"); auto tool_call_args = p.add_rule("tool-call-args", - "" << p.group("tool-args", p.schema(json, "get_weather", schema)) << ""); + "" << p.action(p.schema(json, "get_weather", schema), handle_tool_call_args) << ""); auto tool_call = p.add_rule("tool-call", - "" << tool_call_name << tool_call_args << ""); + "" << p.action(tool_call_name << tool_call_args, handle_tool_call) << ""); return reasoning << p.optional(content) << p.optional(tool_call); }); // Test complete input - std::string input = R"(I need to call get_weather with city = New Yorkget_weather{"city": "New York"})"; - parser_context ctx(input); + { + std::string input = R"(I need to call get_weather with city = New Yorkget_weather{"city": "New York"})"; + parser_environment env; + parser_context ctx(input, &env); - auto result = parser.parse(ctx); + auto result = parser.parse(ctx); - assert_equals(true, result.is_success()); - assert_equals(input.size(), result.end); - assert_equals(std::string("I need to call get_weather with city = New York"), *result.group("reasoning-content", ctx.input)); - assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); - assert_equals(std::string(R"({"city": "New York"})"), *result.group("tool-args", ctx.input)); + assert_equals(true, result.is_success()); + assert_equals(input.size(), result.end); + assert_equals("I need to call get_weather with city = New York", env.reasoning_content); + assert_equals((size_t)1, env.tool_calls.size()); + assert_equals("", env.tool_calls[0].id); + assert_equals("get_weather", env.tool_calls[0].name); + assert_equals(R"({"city": "New York"})", env.tool_calls[0].arguments); + } // Test partial input - input = R"(I need to call get_weather )"; - ctx = parser_context(input, /* .is_input_complete = */ false); - result = parser.parse(ctx); + { + std::string input = R"(I need to call get_weather )"; + parser_environment env = parser_environment(); + parser_context ctx = parser_context(input, &env, /* .is_input_complete = */ false); - assert_equals(true, result.is_success()); - assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); + auto result = parser.parse(ctx); - input = R"(I need to call I need to call I need to call get_weatherI need to call get_weatherI need to call get_weatherget_weather)"; - ctx = parser_context(input, /* .is_input_complete = */ false); - result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + } + { + std::string input = R"(I need to call get_weatherget_weather)"; + parser_environment env = parser_environment(); + parser_context ctx = parser_context(input, &env, /* .is_input_complete = */ false); - assert_equals(true, result.is_success()); - assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); + auto result = parser.parse(ctx); - input = R"(I need to call get_weatherget_weatherI need to call get_weatherget_weatherI need to call get_weatherget_weather{"cit)"; - ctx = parser_context(input, /* .is_input_complete = */ false); - result = parser.parse(ctx); + assert_equals(true, result.is_need_more_input()); + assert_equals("I need to call get_weather", env.reasoning_content); + } + { + std::string input = R"(I need to call get_weatherget_weather{"cit)"; + parser_environment env = parser_environment(); + parser_context ctx = parser_context(input, &env, /* .is_input_complete = */ false); - assert_equals(true, result.is_success()); - assert_equals(std::string("I need to call get_weather"), *result.group("reasoning-content", ctx.input)); - assert_equals(std::string("get_weather"), *result.group("tool-name", ctx.input)); - assert_equals(std::string(R"({"cit)"), *result.group("tool-args", ctx.input)); + auto result = parser.parse(ctx); + + assert_equals(true, result.is_success()); + assert_equals("I need to call get_weather", env.reasoning_content); + assert_equals("get_weather", std::get(env.scratchpad["tool_name"])); + assert_equals(R"({"cit)", std::get(env.scratchpad["tool_args"])); + } auto gbnf = build_grammar([&](const common_grammar_builder & builder) { parser.build_grammar(builder); @@ -743,18 +730,6 @@ static void test_gbnf_generation() { // Should generate pattern that prevents matching the full delimiter assert_equals(true, gbnf.find("root ::= ([^<] | \"<\" [^/] | \"])*") != std::string::npos); } - { - // Test groups are transparent - auto parser = build_parser([](parser_builder& p) { - return p.group("test", p.literal("hello")); - }); - - auto gbnf = build_grammar([&](const common_grammar_builder & builder) { - parser.build_grammar(builder); - }); - - assert_equals(true, gbnf.find("root ::= \"hello\"") != std::string::npos); - } { // Test complex expression with parentheses auto parser = build_parser([](parser_builder& p) { @@ -994,7 +969,6 @@ static void benchmark_compare( int main() { test_partial_parsing(); test_one(); - test_capture_groups(); test_recursive_references(); test_optional(); test_json_parser(); From 117d908c6ef4c9038cc9611307b809a55d02a206 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 03:31:01 -0600 Subject: [PATCH 30/34] add built in actions for common operations --- common/chat-parser-combinator.cpp | 76 +++++++++++++++++++++++++++ common/chat-parser-combinator.h | 36 +++++++++++++ tests/test-chat-parser-combinator.cpp | 44 +++------------- 3 files changed, 119 insertions(+), 37 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 04f354232f548..60cb54f84cc1f 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -1,6 +1,7 @@ #include "chat-parser-combinator.h" #include "json-schema-to-grammar.h" #include "common.h" +#include "chat.h" #include "log.h" #include @@ -84,6 +85,22 @@ static bool is_hex_digit(const char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); } +// Unescapes a JSON string (without the surrounding quotes) +// Uses nlohmann::json::parse to handle all JSON escape sequences +static std::string unescape_json_string(std::string_view str) { + try { + // Wrap in quotes and parse as JSON string + std::string quoted = "\"" + std::string(str) + "\""; + auto parsed = nlohmann::json::parse(quoted); + if (parsed.is_string()) { + return parsed.get(); + } + } catch (...) { + // If parsing fails, return original string + } + return std::string(str); +} + // Matches an exact literal string. // S -> "hello" class literal_parser : public parser_base { @@ -1327,6 +1344,65 @@ parser parser_builder::action(const parser & p, std::function(p, std::move(fn), counter_->next())); } +parser parser_builder::append_reasoning(const parser & p) { + return action(p, [](const parser_result &, std::string_view matched, parser_environment & env) { + if (!env.reasoning_content.empty()) { + env.reasoning_content += "\n"; + } + env.reasoning_content += matched; + }); +} + +parser parser_builder::append_content(const parser & p) { + return action(p, [](const parser_result &, std::string_view matched, parser_environment & env) { + if (!env.content.empty()) { + env.content += "\n"; + } + env.content += matched; + }); +} + +parser parser_builder::capture(const parser & p, const std::string & key, bool unescape_json) { + return action(p, [key, unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + std::string value = unescape_json ? unescape_json_string(matched) : std::string(matched); + env.scratchpad[key] = std::move(value); + }); +} + +parser parser_builder::capture_tool_call_id(const parser & p, bool unescape_json) { + return action(p, [unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + env.tool_call_id = unescape_json ? unescape_json_string(matched) : std::string(matched); + }); +} + +parser parser_builder::capture_tool_call_name(const parser & p, bool unescape_json) { + return action(p, [unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + env.tool_call_name = unescape_json ? unescape_json_string(matched) : std::string(matched); + }); +} + +parser parser_builder::capture_tool_call_args(const parser & p, bool unescape_json) { + return action(p, [unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + env.tool_call_args = unescape_json ? unescape_json_string(matched) : std::string(matched); + }); +} + +parser parser_builder::add_tool_call(const parser & p) { + return action(p, [](const parser_result &, std::string_view, parser_environment & env) { + auto tool_call = common_chat_tool_call{ + env.tool_call_name, + env.tool_call_args, + env.tool_call_id + }; + env.tool_calls.push_back(tool_call); + + // Clear the fields to prevent bleeding to next tool call + env.tool_call_id.clear(); + env.tool_call_name.clear(); + env.tool_call_args.clear(); + }); +} + parser parser_builder::json_key(const std::string & name, const parser & p) { return literal("\"" + name + "\"") << literal(":") << p; } diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index cb7da41740a75..98ca83f7bfb02 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -18,6 +18,13 @@ struct parser_environment { std::string content; std::string reasoning_content; std::vector tool_calls; + + // Tool call fields for building tool calls + std::string tool_call_id; + std::string tool_call_name; + std::string tool_call_args; + + // Scratch pad for any custom logic std::unordered_map> scratchpad; }; @@ -225,6 +232,7 @@ class parser_builder { // Specialized single-pass JSON string parser with escape sequence handling parser json_string(); + // TODO: improve convenience functions to allow users to build specific JSON fields parser json_key(const std::string & name, const parser & p); parser json_string(const parser & p); @@ -237,6 +245,34 @@ class parser_builder { // S -> A [action] parser action(const parser & p, std::function fn); + // Convenience action wrappers for common patterns + + // Appends matched text to env.reasoning_content + parser append_reasoning(const parser & p); + + // Appends matched text to env.content + parser append_content(const parser & p); + + // Captures matched text to env.scratchpad[key] + // If unescape_json is true, the matched text is unescaped as a JSON string + parser capture(const parser & p, const std::string & key, bool unescape_json = false); + + // Captures matched text to env.tool_call_id + // If unescape_json is true, the matched text is unescaped as a JSON string + parser capture_tool_call_id(const parser & p, bool unescape_json = false); + + // Captures matched text to env.tool_call_name + // If unescape_json is true, the matched text is unescaped as a JSON string + parser capture_tool_call_name(const parser & p, bool unescape_json = false); + + // Captures matched text to env.tool_call_args + // If unescape_json is true, the matched text is unescaped as a JSON string + parser capture_tool_call_args(const parser & p, bool unescape_json = false); + + // Adds a tool call to env.tool_calls using env.tool_call_{id,name,args} + // Clears the tool call fields after adding + parser add_tool_call(const parser & p); + parser add_rule(const std::string & name, const parser & p); void assign_ids(parser & p); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index a309f7fce5a1c..69506ba0cc314 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -382,54 +382,24 @@ static void test_complete_example() { // // auto parser = build_parser([](parser_builder & p) { - auto handle_reasoning = [](const parser_result &, std::string_view match, parser_environment & env) { - env.reasoning_content += match; - }; - - auto handle_content = [](const parser_result &, std::string_view match, parser_environment & env) { - env.content += match; - }; - - auto handle_tool_call_name = [](const parser_result &, std::string_view match, parser_environment & env) { - env.scratchpad["tool_name"] = std::string(match); - }; - - auto handle_tool_call_args = [](const parser_result &, std::string_view match, parser_environment & env) { - env.scratchpad["tool_args"] = std::string(match); - }; - - auto handle_tool_call = [](const parser_result &, std::string_view, parser_environment & env) { - auto name = env.scratchpad.find("tool_name"); - auto args = env.scratchpad.find("tool_args"); - if (name != env.scratchpad.end() && args != env.scratchpad.end()) { - auto tool_call = common_chat_tool_call{ - std::get(name->second), - std::get(args->second), - std::string() - }; - - env.tool_calls.push_back(tool_call); - } - }; - auto reasoning = p.add_rule("reasoning", - "" << p.action(p.until(""), handle_reasoning) << ""); + "" << p.append_reasoning(p.until("")) << ""); auto content = p.add_rule("content", - p.action(p.until(""), handle_content)); + p.append_content(p.until(""))); auto json = p.json(); auto tool_call_name = p.add_rule("tool-call-name", - "" << p.action(p.until(""), handle_tool_call_name) << ""); + "" << p.capture_tool_call_name(p.until("")) << ""); auto schema = nlohmann::ordered_json::parse(R"({"type": "object"})"); auto tool_call_args = p.add_rule("tool-call-args", - "" << p.action(p.schema(json, "get_weather", schema), handle_tool_call_args) << ""); + "" << p.capture_tool_call_args(p.schema(json, "get_weather", schema)) << ""); auto tool_call = p.add_rule("tool-call", - "" << p.action(tool_call_name << tool_call_args, handle_tool_call) << ""); + "" << p.add_tool_call(tool_call_name << tool_call_args) << ""); return reasoning << p.optional(content) << p.optional(tool_call); }); @@ -509,8 +479,8 @@ static void test_complete_example() { assert_equals(true, result.is_success()); assert_equals("I need to call get_weather", env.reasoning_content); - assert_equals("get_weather", std::get(env.scratchpad["tool_name"])); - assert_equals(R"({"cit)", std::get(env.scratchpad["tool_args"])); + assert_equals("get_weather", env.tool_calls[0].name); + assert_equals(R"({"cit)", env.tool_calls[0].arguments); } auto gbnf = build_grammar([&](const common_grammar_builder & builder) { From f97abde571c5f978f704fec598ef246acdba9894 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 03:42:27 -0600 Subject: [PATCH 31/34] add actions to command r7b example --- tests/test-chat-parser-combinator.cpp | 65 ++++++++++++++++++++------- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 69506ba0cc314..0bfa0b45b6364 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -758,19 +758,25 @@ static void test_gbnf_generation() { static parser create_command_r7b_parser() { auto parser = build_parser([](parser_builder & p) { auto thinking = p.add_rule("thinking", - "<|START_THINKING|>" << p.until("<|END_THINKING|>") << "<|END_THINKING|>"); + "<|START_THINKING|>" << p.append_reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>"); auto response = p.add_rule("response", - "<|START_RESPONSE|>" << p.until("<|END_RESPONSE|>") << "<|END_RESPONSE|>"); + "<|START_RESPONSE|>" << p.append_content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>"); auto json = p.add_rule("json", p.json()); - auto tool_call_id = p.add_rule("tool-call-id", p.json_key("tool_call_id", p.json_string(p.until("\"")))); - auto tool_call_name = p.add_rule("tool-name", p.json_key("tool_name", p.json_string(p.until("\"")))); - auto tool_call_args = p.add_rule("tool-args", p.json_key("parameters", json)); + + auto tool_call_id = p.add_rule("tool-call-id", + p.json_key("tool_call_id", "\"" + p.capture_tool_call_id(p.json_string(), /* unescape_json = */ true) + "\"")); + + auto tool_call_name = p.add_rule("tool-name", + p.json_key("tool_name", "\"" + p.capture_tool_call_name(p.json_string(), /* unescape_json = */ true) + "\"")); + + auto tool_call_args = p.add_rule("tool-args", p.json_key("parameters", p.capture_tool_call_args(json))); + auto tool_call_fields = p.add_rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args); auto tool_call = p.add_rule("tool-call", - "{" << tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields) << "}"); + "{" << p.add_tool_call(tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields)) << "}"); auto tool_calls = p.add_rule("tool-calls", "<|START_ACTION|>" @@ -789,12 +795,27 @@ static parser create_command_r7b_parser() { return parser; } -static void test_command_r7b_parser(const parser & p, const std::string & input, bool partial) { - parser_context ctx(input, !partial); +static void test_command_r7b_parser(const parser & p, const std::string & input, bool partial, bool print_results = false) { + parser_environment env; + parser_context ctx(input, &env, !partial); p.parse(ctx); + + if (print_results) { + std::cout << "== Parsed (new) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << env.reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << env.content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : env.tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } } -static void test_command_r7b_legacy_parser(const std::string & input, bool partial) { +static void test_command_r7b_legacy_parser(const std::string & input, bool partial, bool print_results = false) { // Original parser taken from chat.cpp common_chat_msg_parser builder(input, /* is_partial= */ partial, { @@ -834,6 +855,20 @@ static void test_command_r7b_legacy_parser(const std::string & input, bool parti } else { builder.add_content(builder.consume_rest()); } + + if (print_results) { + std::cout << "== Parsed (legacy) ==\n"; + std::cout << "=== Reasoning ===\n"; + std::cout << builder.result().reasoning_content << "\n"; + std::cout << "\n\n=== Content ===\n"; + std::cout << builder.result().content << "\n"; + std::cout << "\n\n=== Tool Calls ===\n"; + for (const auto & tc : builder.result().tool_calls) { + std::cout << "id: " << tc.id << "\n"; + std::cout << "name: " << tc.name << "\n"; + std::cout << "args: " << tc.arguments << "\n"; + } + } } struct bench_tool_call { @@ -907,13 +942,13 @@ static void benchmark_compare( tokens.emplace_back("<|END_ACTION|>"); } - auto run = [&](const std::function & fn) { + auto run = [&](const std::function & fn) { std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string()); std::chrono::microseconds duration(0); for (int i = 0; i < iterations; i++) { auto start = std::chrono::high_resolution_clock::now(); - fn(input, false); + fn(input, false, i == 0); auto end = std::chrono::high_resolution_clock::now(); duration += std::chrono::duration_cast(end - start); } @@ -922,13 +957,13 @@ static void benchmark_compare( auto parser = create_command_r7b_parser(); - auto duration_new = run([&](const std::string & input, bool partial) { - test_command_r7b_parser(parser, input, partial); + auto duration_new = run([&](const std::string & input, bool partial, bool print_content) { + test_command_r7b_parser(parser, input, partial, print_content); }); - auto duration_legacy = run([&](const std::string & input, bool partial) { + auto duration_legacy = run([&](const std::string & input, bool partial, bool print_content) { try { - test_command_r7b_legacy_parser(input, partial); + test_command_r7b_legacy_parser(input, partial, print_content); } catch (const common_chat_msg_partial_exception &) { } }); From 3114a0e679d3447a672fb1a20c81e844afa1301d Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 03:53:41 -0600 Subject: [PATCH 32/34] use std::default_searcher for platforms that don't have bm --- common/chat-parser-combinator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 60cb54f84cc1f..331f00198ea56 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -696,7 +696,7 @@ class until_parser : public parser_base { std::string delimiter_; bool consume_spaces_; - std::boyer_moore_searcher searcher_; + std::default_searcher searcher_; public: until_parser(const std::string & delimiter, bool consume_spaces, int id) From cc4d52c05983e7146280d589e4c9540a43923bd3 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Wed, 12 Nov 2025 22:56:16 -0600 Subject: [PATCH 33/34] improve parser_type handling and add cast helper --- common/chat-parser-combinator.cpp | 124 ++++++++++++++++++++---------- 1 file changed, 84 insertions(+), 40 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index 331f00198ea56..f8b5dbaa32109 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -10,23 +10,23 @@ #include enum parser_type { - PARSER_LITERAL = 0, - PARSER_SEQUENCE = 1, - PARSER_CHOICE = 2, - PARSER_REPETITION = 3, - PARSER_OPTIONAL = 4, - PARSER_ZERO_OR_MORE = 5, - PARSER_ONE_OR_MORE = 6, - PARSER_NOT = 7, - PARSER_ANY = 8, - PARSER_CHARS = 9, - PARSER_RULE = 10, - PARSER_UNTIL = 11, - PARSER_SPACE = 12, - PARSER_SCHEMA = 13, - PARSER_ROOT = 14, - PARSER_JSON_STRING = 15, - PARSER_ACTION = 16, + PARSER_LITERAL, + PARSER_SEQUENCE, + PARSER_CHOICE, + PARSER_REPETITION, + PARSER_OPTIONAL, + PARSER_ZERO_OR_MORE, + PARSER_ONE_OR_MORE, + PARSER_NOT, + PARSER_ANY, + PARSER_CHARS, + PARSER_RULE, + PARSER_UNTIL, + PARSER_SPACE, + PARSER_SCHEMA, + PARSER_ROOT, + PARSER_JSON_STRING, + PARSER_ACTION, }; class parser_visitor; @@ -75,6 +75,20 @@ class parser_base { virtual void accept(parser_visitor & visitor) = 0; }; +// Convenience cast functions +template +static std::shared_ptr cast(const std::shared_ptr & p) { + if (p->type() != T::type_value) { + return nullptr; + } + return std::static_pointer_cast(p); +} + +template +static std::shared_ptr cast(const parser & p) { + return cast(p.ptr()); +} + // We define our own space function because MSVC's std::isspace() // crashes for non-printable characters in Debug builds. static bool is_space(const char c) { @@ -107,9 +121,11 @@ class literal_parser : public parser_base { std::string literal_; public: + static constexpr parser_type type_value = PARSER_LITERAL; + literal_parser(const std::string & literal, int id) : parser_base(id), literal_(literal) {} - parser_type type() const override { return PARSER_LITERAL; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto pos = start; @@ -147,11 +163,11 @@ class sequence_parser : public parser_base { std::vector parsers_; public: + static constexpr parser_type type_value = PARSER_SEQUENCE; + sequence_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { - if (p->type() == PARSER_SEQUENCE) { - // Flatten sequences - auto seq = std::static_pointer_cast(p.ptr()); + if (auto seq = cast(p)) { for (const auto & embedded : seq->parsers()) { parsers_.push_back(embedded); } @@ -161,7 +177,7 @@ class sequence_parser : public parser_base { } } - parser_type type() const override { return PARSER_SEQUENCE; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto pos = start; @@ -213,11 +229,11 @@ class choice_parser : public parser_base { std::vector parsers_; public: + static constexpr parser_type type_value = PARSER_CHOICE; + choice_parser(std::initializer_list parsers, int id) : parser_base(id) { for (const auto & p : parsers) { - if (p->type() == PARSER_CHOICE) { - // Flatten choices - auto choice = std::static_pointer_cast(p.ptr()); + if (auto choice = cast(p)) { for (const auto & embedded : choice->parsers()) { parsers_.push_back(embedded); } @@ -227,7 +243,7 @@ class choice_parser : public parser_base { } } - parser_type type() const override { return PARSER_CHOICE; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto pos = start; @@ -276,10 +292,12 @@ class repetition_parser : public parser_base { int max_count_; public: + static constexpr parser_type type_value = PARSER_REPETITION; + repetition_parser(const parser & parser, int min_count, int max_count, int id) : parser_base(id), parser_(parser), min_count_(min_count), max_count_(max_count) {} - parser_type type() const override { return PARSER_REPETITION; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto pos = start; @@ -340,9 +358,11 @@ class repetition_parser : public parser_base { // S -> A+ class one_or_more_parser : public repetition_parser { public: + static constexpr parser_type type_value = PARSER_ONE_OR_MORE; + one_or_more_parser(const parser & p, int id) : repetition_parser(p, 1, -1, id) {} - parser_type type() const override { return PARSER_ONE_OR_MORE; } + parser_type type() const override { return type_value; } std::string dump() const override { return "OneOrMore(" + child()->dump() + ")"; @@ -355,9 +375,11 @@ class one_or_more_parser : public repetition_parser { // S -> A* class zero_or_more_parser : public repetition_parser { public: + static constexpr parser_type type_value = PARSER_ZERO_OR_MORE; + zero_or_more_parser(const parser & p, int id) : repetition_parser(p, 0, -1, id) {} - parser_type type() const override { return PARSER_ZERO_OR_MORE; } + parser_type type() const override { return type_value; } std::string dump() const override { return "ZeroOrMore(" + child()->dump() + ")"; @@ -370,9 +392,11 @@ class zero_or_more_parser : public repetition_parser { // S -> A? class optional_parser : public repetition_parser { public: + static constexpr parser_type type_value = PARSER_OPTIONAL; + optional_parser(const parser & p, int id) : repetition_parser(p, 0, 1, id) {} - parser_type type() const override { return PARSER_OPTIONAL; } + parser_type type() const override { return type_value; } std::string dump() const override { return "Optional(" + child()->dump() + ")"; @@ -387,9 +411,11 @@ class not_parser : public parser_base { parser parser_; public: + static constexpr parser_type type_value = PARSER_NOT; + not_parser(const parser & parser, int id) : parser_base(id), parser_(parser) {} - parser_type type() const override { return PARSER_NOT; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto result = parser_->parse(ctx, start); @@ -426,9 +452,11 @@ class not_parser : public parser_base { // S -> . class any_parser : public parser_base { public: + static constexpr parser_type type_value = PARSER_ANY; + any_parser(int id) : parser_base(id) {} - parser_type type() const override { return PARSER_ANY; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { if (start >= ctx.input.size()) { @@ -452,9 +480,11 @@ class any_parser : public parser_base { // S -> [ \t\n]* class space_parser : public parser_base { public: + static constexpr parser_type type_value = PARSER_SPACE; + space_parser(int id) : parser_base(id) {} - parser_type type() const override { return PARSER_SPACE; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto pos = start; @@ -545,7 +575,9 @@ class chars_parser : public parser_base { } } - parser_type type() const override { return PARSER_CHARS; } + static constexpr parser_type type_value = PARSER_CHARS; + + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto pos = start; @@ -610,9 +642,11 @@ class chars_parser : public parser_base { class json_string_parser : public parser_base { public: + static constexpr parser_type type_value = PARSER_JSON_STRING; + json_string_parser(int id) : parser_base(id) {} - parser_type type() const override { return PARSER_JSON_STRING; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto pos = start; @@ -699,11 +733,13 @@ class until_parser : public parser_base { std::default_searcher searcher_; public: + static constexpr parser_type type_value = PARSER_UNTIL; + until_parser(const std::string & delimiter, bool consume_spaces, int id) : parser_base(id), delimiter_(delimiter), consume_spaces_(consume_spaces), searcher_(delimiter_.begin(), delimiter_.end()) { } - parser_type type() const override { return PARSER_UNTIL; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { parser_result result(PARSER_RESULT_SUCCESS, start, ctx.input.size()); @@ -752,10 +788,12 @@ class schema_parser : public parser_base { nlohmann::ordered_json schema_; public: + static constexpr parser_type type_value = PARSER_SCHEMA; + schema_parser(const parser & parser, const std::string & name, const nlohmann::ordered_json & schema, int id) : parser_base(id), parser_(parser), name_(name), schema_(schema) {} - parser_type type() const override { return PARSER_SCHEMA; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { return parser_->parse(ctx, start); @@ -781,10 +819,12 @@ class rule_parser : public parser_base { std::weak_ptr> rules_; public: + static constexpr parser_type type_value = PARSER_RULE; + rule_parser(const std::string & name, const std::shared_ptr> & rules, int id) : parser_base(id), name_(name), rules_(rules) {} - parser_type type() const override { return PARSER_RULE; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto rules = rules_.lock(); @@ -820,10 +860,12 @@ class root_parser : public parser_base { friend class parser_visitor; public: + static constexpr parser_type type_value = PARSER_ROOT; + root_parser(const parser & root, std::shared_ptr> rules, int id) : parser_base(id), root_(root), rules_(std::move(rules)) {} - parser_type type() const override { return PARSER_ROOT; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { return root_->parse(ctx, start); @@ -851,10 +893,12 @@ class action_parser : public parser_base { std::function action_; public: + static constexpr parser_type type_value = PARSER_ACTION; + action_parser(const parser & parser, std::function action, int id) : parser_base(id), parser_(parser), action_(std::move(action)) {} - parser_type type() const override { return PARSER_ACTION; } + parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto result = parser_->parse(ctx, start); From c119c1290a087aaa67cd7df8ea3efbda02213869 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Thu, 13 Nov 2025 02:14:21 -0600 Subject: [PATCH 34/34] add partial result type to better control when to run actions --- common/chat-parser-combinator.cpp | 109 ++++++++++++++------------ common/chat-parser-combinator.h | 13 ++- tests/test-chat-parser-combinator.cpp | 24 +++--- 3 files changed, 79 insertions(+), 67 deletions(-) diff --git a/common/chat-parser-combinator.cpp b/common/chat-parser-combinator.cpp index f8b5dbaa32109..0a754bc012c58 100644 --- a/common/chat-parser-combinator.cpp +++ b/common/chat-parser-combinator.cpp @@ -134,10 +134,7 @@ class literal_parser : public parser_base { if (ctx.input_is_complete) { return parser_result(PARSER_RESULT_FAIL, start); } - if (i > 0) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); - } - return parser_result(PARSER_RESULT_FAIL, start, pos); + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); } if (ctx.input[pos] != literal_[i]) { return parser_result(PARSER_RESULT_FAIL, start); @@ -183,17 +180,8 @@ class sequence_parser : public parser_base { auto pos = start; for (const auto & p : parsers_) { auto result = p->parse(ctx, pos); - - if (result.is_fail()) { - if (result.end >= ctx.input.size() && !ctx.input_is_complete) { - // If we fail because we don't have enough input, then return success - return parser_result(PARSER_RESULT_SUCCESS, start, result.end); - } - return parser_result(PARSER_RESULT_FAIL, start, result.end); - } - - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, result.end); + if (!result.is_success()) { + return parser_result(result.type, start, result.end); } pos = result.end; @@ -249,12 +237,7 @@ class choice_parser : public parser_base { auto pos = start; for (const auto & p : parsers_) { auto result = p->parse(ctx, pos); - - if (result.is_success()) { - return result; - } - - if (result.is_need_more_input()) { + if (!result.is_fail()) { return result; } } @@ -305,6 +288,10 @@ class repetition_parser : public parser_base { // Try to match up to max_count times (or unlimited if max_count is -1) while (max_count_ == -1 || match_count < max_count_) { + if (pos >= ctx.input.size()) { + break; + } + auto result = parser_->parse(ctx, pos); if (result.is_success()) { @@ -317,8 +304,8 @@ class repetition_parser : public parser_base { continue; } - if (result.is_need_more_input()) { - return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + if (result.is_need_more_input() || result.is_partial()) { + return parser_result(result.type, start, result.end); } // Child failed - stop trying @@ -420,7 +407,7 @@ class not_parser : public parser_base { parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto result = parser_->parse(ctx, start); - if (result.is_success()) { + if (result.is_success() || result.is_partial()) { // Fail if the underlying parser matches return parser_result(PARSER_RESULT_FAIL, start); } @@ -463,9 +450,8 @@ class any_parser : public parser_base { if (ctx.input_is_complete) { return parser_result(PARSER_RESULT_FAIL, start); } - return parser_result(PARSER_RESULT_FAIL, start); + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start); } - return parser_result(PARSER_RESULT_SUCCESS, start, start + 1); } @@ -612,7 +598,10 @@ class chars_parser : public parser_base { // Check if we got enough matches if (match_count < min_count_) { - return parser_result(PARSER_RESULT_FAIL, start); + if (pos >= ctx.input.size() && !ctx.input_is_complete) { + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); + } + return parser_result(PARSER_RESULT_FAIL, start, pos); } return parser_result(PARSER_RESULT_SUCCESS, start, pos); @@ -714,7 +703,10 @@ class json_string_parser : public parser_base { } // Reached end without finding closing quote - return parser_result(PARSER_RESULT_FAIL, start, pos); + if (ctx.input_is_complete) { + return parser_result(PARSER_RESULT_FAIL, start, pos); + } + return parser_result(PARSER_RESULT_NEED_MORE_INPUT, start, pos); } std::string dump() const override { @@ -748,15 +740,19 @@ class until_parser : public parser_base { const auto it = std::search(ctx.input.begin(), ctx.input.end(), searcher_); if (it != ctx.input.end()) { - result.type = PARSER_RESULT_SUCCESS; result.end = std::distance(ctx.input.begin(), it); + result.type = PARSER_RESULT_SUCCESS; } else { // If not found, check if the input ends with a prefix of the delimiter size_t max_overlap = std::min(ctx.input.size(), delimiter_.size() - 1); for (size_t overlap = max_overlap; overlap > 0; --overlap) { if (std::equal(ctx.input.end() - overlap, ctx.input.end(), delimiter_.begin())) { - result.type = (ctx.input_is_complete) ? PARSER_RESULT_FAIL : PARSER_RESULT_NEED_MORE_INPUT; result.end = ctx.input.size() - overlap; + if (ctx.input_is_complete) { + result.type = PARSER_RESULT_FAIL; + } else { + result.type = PARSER_RESULT_NEED_MORE_INPUT; + } } } } @@ -890,21 +886,25 @@ class root_parser : public parser_base { // Wraps a parser with a semantic action callback. class action_parser : public parser_base { parser parser_; - std::function action_; + std::function action_; + int when_; public: static constexpr parser_type type_value = PARSER_ACTION; - action_parser(const parser & parser, std::function action, int id) - : parser_base(id), parser_(parser), action_(std::move(action)) {} + action_parser( + const parser & parser, + std::function action, + int when, + int id + ) : parser_base(id), parser_(parser), action_(std::move(action)), when_(when) {} parser_type type() const override { return type_value; } parser_result parse_uncached(parser_context & ctx, size_t start = 0) override { auto result = parser_->parse(ctx, start); - // Invoke action callback on success if environment is available - if (result.is_success() && ctx.env && action_) { + if ((result.type & when_) && ctx.env && action_) { std::string_view matched = ctx.input.substr(result.start, result.end - result.start); action_(result, matched, *ctx.env); } @@ -918,7 +918,7 @@ class action_parser : public parser_base { } std::string dump() const override { - return "Action(" + parser_->dump() + ")"; + return "Action(" + parser_->dump() + ", when=" + std::to_string(when_) +")"; } void accept(parser_visitor & visitor) override; @@ -926,6 +926,7 @@ class action_parser : public parser_base { const parser & child() const { return parser_; } }; + // Base visitor class for parser tree traversal class parser_visitor { public: @@ -1384,51 +1385,57 @@ parser parser_builder::schema(const parser & p, const std::string & name, const return parser(std::make_shared(p, name, schema, counter_->next())); } -parser parser_builder::action(const parser & p, std::function fn) { - return parser(std::make_shared(p, std::move(fn), counter_->next())); +parser parser_builder::action(const parser & p, std::function fn, int when) { + return parser(std::make_shared(p, std::move(fn), when, counter_->next())); +} + +parser parser_builder::partial(const parser & p) { + return action(p, [](parser_result &result, std::string_view, parser_environment &) { + result.type = PARSER_RESULT_PARTIAL; + }, PARSER_RESULT_NEED_MORE_INPUT); } parser parser_builder::append_reasoning(const parser & p) { - return action(p, [](const parser_result &, std::string_view matched, parser_environment & env) { + return action(p, [](parser_result &, std::string_view matched, parser_environment & env) { if (!env.reasoning_content.empty()) { env.reasoning_content += "\n"; } env.reasoning_content += matched; - }); + }, PARSER_RESULT_SUCCESS | PARSER_RESULT_PARTIAL); } parser parser_builder::append_content(const parser & p) { - return action(p, [](const parser_result &, std::string_view matched, parser_environment & env) { + return action(p, [](parser_result &, std::string_view matched, parser_environment & env) { if (!env.content.empty()) { env.content += "\n"; } env.content += matched; - }); + }, PARSER_RESULT_SUCCESS | PARSER_RESULT_PARTIAL); } parser parser_builder::capture(const parser & p, const std::string & key, bool unescape_json) { - return action(p, [key, unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + return action(p, [key, unescape_json](parser_result &, std::string_view matched, parser_environment & env) { std::string value = unescape_json ? unescape_json_string(matched) : std::string(matched); env.scratchpad[key] = std::move(value); - }); + }, PARSER_RESULT_SUCCESS | PARSER_RESULT_PARTIAL); } parser parser_builder::capture_tool_call_id(const parser & p, bool unescape_json) { - return action(p, [unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + return action(p, [unescape_json](parser_result &, std::string_view matched, parser_environment & env) { env.tool_call_id = unescape_json ? unescape_json_string(matched) : std::string(matched); - }); + }, PARSER_RESULT_SUCCESS | PARSER_RESULT_PARTIAL); } parser parser_builder::capture_tool_call_name(const parser & p, bool unescape_json) { - return action(p, [unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + return action(p, [unescape_json](parser_result &, std::string_view matched, parser_environment & env) { env.tool_call_name = unescape_json ? unescape_json_string(matched) : std::string(matched); - }); + }, PARSER_RESULT_SUCCESS | PARSER_RESULT_PARTIAL); } parser parser_builder::capture_tool_call_args(const parser & p, bool unescape_json) { - return action(p, [unescape_json](const parser_result &, std::string_view matched, parser_environment & env) { + return action(p, [unescape_json](parser_result &, std::string_view matched, parser_environment & env) { env.tool_call_args = unescape_json ? unescape_json_string(matched) : std::string(matched); - }); + }, PARSER_RESULT_SUCCESS | PARSER_RESULT_PARTIAL); } parser parser_builder::add_tool_call(const parser & p) { @@ -1444,7 +1451,7 @@ parser parser_builder::add_tool_call(const parser & p) { env.tool_call_id.clear(); env.tool_call_name.clear(); env.tool_call_args.clear(); - }); + }, PARSER_RESULT_SUCCESS | PARSER_RESULT_PARTIAL); } parser parser_builder::json_key(const std::string & name, const parser & p) { diff --git a/common/chat-parser-combinator.h b/common/chat-parser-combinator.h index 98ca83f7bfb02..1a38fde861dcc 100644 --- a/common/chat-parser-combinator.h +++ b/common/chat-parser-combinator.h @@ -29,9 +29,10 @@ struct parser_environment { }; enum parser_result_type { - PARSER_RESULT_FAIL = 0, - PARSER_RESULT_NEED_MORE_INPUT = 1, - PARSER_RESULT_SUCCESS = 2, + PARSER_RESULT_FAIL = 1 << 0, + PARSER_RESULT_SUCCESS = 1 << 1, + PARSER_RESULT_NEED_MORE_INPUT = 1 << 2, + PARSER_RESULT_PARTIAL = 1 << 3, }; struct parse_cache_key { @@ -65,6 +66,7 @@ struct parser_result { bool is_fail() const { return type == PARSER_RESULT_FAIL; } bool is_need_more_input() const { return type == PARSER_RESULT_NEED_MORE_INPUT; } + bool is_partial() const { return type == PARSER_RESULT_PARTIAL; } bool is_success() const { return type == PARSER_RESULT_SUCCESS; } }; @@ -243,10 +245,13 @@ class parser_builder { // Wraps a parser with a semantic action callback. // The callback is invoked on successful parse with the result, matched text, and environment. // S -> A [action] - parser action(const parser & p, std::function fn); + parser action(const parser & p, std::function fn, int when = PARSER_RESULT_SUCCESS); // Convenience action wrappers for common patterns + // Converts PARSER_RESULT_NEED_MORE_INPUT to PARSER_RESULT_PARTIAL + parser partial(const parser & p); + // Appends matched text to env.reasoning_content parser append_reasoning(const parser & p); diff --git a/tests/test-chat-parser-combinator.cpp b/tests/test-chat-parser-combinator.cpp index 0bfa0b45b6364..2affdd7718a0c 100644 --- a/tests/test-chat-parser-combinator.cpp +++ b/tests/test-chat-parser-combinator.cpp @@ -90,7 +90,7 @@ static void test_partial_parsing() { ctx = parser_context("", false); result = parser.parse(ctx); - assert_equals(true, result.is_success()); + assert_equals(true, result.is_need_more_input()); ctx = parser_context("" << p.capture_tool_call_args(p.schema(json, "get_weather", schema)) << ""); + "" << p.capture_tool_call_args(p.schema(p.partial(json), "get_weather", schema)) << ""); auto tool_call = p.add_rule("tool-call", - "" << p.add_tool_call(tool_call_name << tool_call_args) << ""); + "" << p.add_tool_call(tool_call_name << p.partial(tool_call_args)) << ""); return reasoning << p.optional(content) << p.optional(tool_call); }); @@ -429,7 +429,7 @@ static void test_complete_example() { auto result = parser.parse(ctx); - assert_equals(true, result.is_success()); + assert_equals(true, result.is_need_more_input()); assert_equals("I need to call get_weather", env.reasoning_content); } { @@ -457,7 +457,7 @@ static void test_complete_example() { auto result = parser.parse(ctx); - assert_equals(true, result.is_success()); + assert_equals(true, result.is_need_more_input()); assert_equals("I need to call get_weather", env.reasoning_content); } { @@ -477,7 +477,7 @@ static void test_complete_example() { auto result = parser.parse(ctx); - assert_equals(true, result.is_success()); + assert_equals(true, result.is_partial()); assert_equals("I need to call get_weather", env.reasoning_content); assert_equals("get_weather", env.tool_calls[0].name); assert_equals(R"({"cit)", env.tool_calls[0].arguments); @@ -579,7 +579,7 @@ static void test_actions() { parser_context ctx("hello ", &env, false); auto result = parser.parse(ctx); - assert_equals(true, result.is_success()); + assert_equals(true, result.is_need_more_input()); assert_equals("hello", env.content); } { @@ -587,7 +587,7 @@ static void test_actions() { parser_context ctx("hello world", &env, false); auto result = parser.parse(ctx); - assert_equals(true, result.is_success()); + assert_equals(true, result.is_need_more_input()); assert_equals("hello world", env.content); } {