From 91523d3c56fd574cac5828809dfaf28da593a806 Mon Sep 17 00:00:00 2001 From: Leif Hedstrom Date: Sat, 28 Feb 2026 17:03:23 -0700 Subject: [PATCH 1/2] hrw4u: Adds procedures (macros) and libraries Extends the hrw4u grammar and compiler with a procedure system for defining reusable, parameterized blocks of rules. Procedures use namespaced names (e.g. `local::set-cache`) with `$param` substitution, and can be defined inline or loaded from external `.hrw4u` library files via `use` directives. A `--output=hrw4u` flatten mode expands all procedure calls into a self-contained source file. This commit also adds LSP hover/completion support for procedures, comprehensive test coverage, validation for circular imports and arity mismatches, and documentation in the admin guide. --- ci/rat-regex.txt | 1 + doc/admin-guide/configuration/hrw4u.en.rst | 95 +++ tools/hrw4u/Makefile | 3 +- tools/hrw4u/grammar/hrw4u.g4 | 41 +- tools/hrw4u/pyproject.toml | 3 +- tools/hrw4u/scripts/hrw4u | 30 +- tools/hrw4u/scripts/hrw4u-lsp | 221 +++++-- tools/hrw4u/scripts/testcase.py | 18 +- tools/hrw4u/src/common.py | 87 ++- tools/hrw4u/src/errors.py | 72 ++- tools/hrw4u/src/hrw_symbols.py | 5 +- tools/hrw4u/src/hrw_visitor.py | 4 +- tools/hrw4u/src/lsp/hover.py | 214 +++---- tools/hrw4u/src/lsp/strings.py | 70 +-- tools/hrw4u/src/procedures.py | 29 + tools/hrw4u/src/symbols.py | 5 +- tools/hrw4u/src/symbols_base.py | 4 +- tools/hrw4u/src/visitor.py | 562 ++++++++++++++++-- tools/hrw4u/src/visitor_base.py | 234 +------- .../tests/data/procedures/basic-call.ast.txt | 1 + .../data/procedures/basic-call.input.txt | 5 + .../data/procedures/basic-call.output.txt | 2 + .../procedures/circular-use.fail.error.txt | 1 + .../procedures/circular-use.fail.input.txt | 5 + .../data/procedures/default-param.ast.txt | 1 + .../data/procedures/default-param.input.txt | 5 + .../data/procedures/default-param.output.txt | 2 + .../procedures/duplicate-proc.fail.error.txt | 1 + .../procedures/duplicate-proc.fail.input.txt | 11 + .../data/procedures/elif-in-proc.ast.txt | 1 + .../data/procedures/elif-in-proc.flatten.txt | 20 + .../data/procedures/elif-in-proc.input.txt | 8 + .../data/procedures/elif-in-proc.output.txt | 24 + .../data/procedures/in-conditional.ast.txt | 1 + .../procedures/in-conditional.flatten.txt | 5 + .../data/procedures/in-conditional.input.txt | 7 + .../data/procedures/in-conditional.output.txt | 3 + .../data/procedures/local-and-use.ast.txt | 1 + .../data/procedures/local-and-use.input.txt | 10 + .../data/procedures/local-and-use.output.txt | 3 + .../data/procedures/local-mixed-body.ast.txt | 1 + .../procedures/local-mixed-body.flatten.txt | 6 + .../procedures/local-mixed-body.input.txt | 11 + .../procedures/local-mixed-body.output.txt | 6 + .../procedures/local-multi-section.ast.txt | 1 + .../procedures/local-multi-section.input.txt | 11 + .../procedures/local-multi-section.output.txt | 5 + .../tests/data/procedures/local-proc.ast.txt | 1 + .../data/procedures/local-proc.input.txt | 7 + .../data/procedures/local-proc.output.txt | 2 + .../data/procedures/local-with-params.ast.txt | 1 + .../procedures/local-with-params.input.txt | 7 + .../procedures/local-with-params.output.txt | 2 + .../tests/data/procedures/mixed-body.ast.txt | 1 + .../data/procedures/mixed-body.flatten.txt | 6 + .../data/procedures/mixed-body.input.txt | 5 + .../data/procedures/mixed-body.output.txt | 6 + .../tests/data/procedures/multi-proc.ast.txt | 1 + .../data/procedures/multi-proc.input.txt | 6 + .../data/procedures/multi-proc.output.txt | 3 + .../procedures/multi-section-mixed.ast.txt | 1 + .../multi-section-mixed.flatten.txt | 13 + .../procedures/multi-section-mixed.input.txt | 15 + .../procedures/multi-section-mixed.output.txt | 13 + .../tests/data/procedures/multi-use.ast.txt | 1 + .../tests/data/procedures/multi-use.input.txt | 7 + .../data/procedures/multi-use.output.txt | 3 + .../data/procedures/override-param.ast.txt | 1 + .../data/procedures/override-param.input.txt | 5 + .../data/procedures/override-param.output.txt | 2 + .../proc-after-section.fail.error.txt | 1 + .../proc-after-section.fail.input.txt | 7 + .../data/procedures/procs/base/Stamp.hrw4u | 3 + .../data/procedures/procs/caller/Wrap.hrw4u | 6 + .../data/procedures/procs/circular/A.hrw4u | 5 + .../data/procedures/procs/circular/B.hrw4u | 5 + .../procedures/procs/reexport/debug.hrw4u | 1 + .../procedures/procs/test/TagAndOrigin.hrw4u | 7 + .../procs/test/add-debug-header.hrw4u | 3 + .../procs/test/classify-request.hrw4u | 10 + .../procedures/procs/test/mixed-body.hrw4u | 7 + .../procedures/procs/test/set-cache.hrw4u | 3 + .../procedures/procs/test/set-origin.hrw4u | 3 + .../procs/test/wrong-namespace.hrw4u | 3 + .../tests/data/procedures/reexport.ast.txt | 1 + .../tests/data/procedures/reexport.input.txt | 5 + .../tests/data/procedures/reexport.output.txt | 2 + .../data/procedures/string-param.ast.txt | 1 + .../data/procedures/string-param.input.txt | 5 + .../data/procedures/string-param.output.txt | 2 + .../procedures/top-level-only.fail.error.txt | 1 + .../procedures/top-level-only.fail.input.txt | 5 + .../tests/data/procedures/transitive.ast.txt | 1 + .../data/procedures/transitive.input.txt | 5 + .../data/procedures/transitive.output.txt | 3 + .../procedures/unknown-proc.fail.error.txt | 1 + .../procedures/unknown-proc.fail.input.txt | 3 + .../procedures/wrong-arity.fail.error.txt | 1 + .../procedures/wrong-arity.fail.input.txt | 5 + .../procedures/wrong-namespace.fail.error.txt | 1 + .../procedures/wrong-namespace.fail.input.txt | 5 + tools/hrw4u/tests/test_procedures.py | 51 ++ tools/hrw4u/tests/test_units.py | 226 ++++--- tools/hrw4u/tests/utils.py | 148 +++-- 104 files changed, 1761 insertions(+), 740 deletions(-) create mode 100644 tools/hrw4u/src/procedures.py create mode 100644 tools/hrw4u/tests/data/procedures/basic-call.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/basic-call.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/basic-call.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/circular-use.fail.error.txt create mode 100644 tools/hrw4u/tests/data/procedures/circular-use.fail.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/default-param.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/default-param.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/default-param.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/duplicate-proc.fail.error.txt create mode 100644 tools/hrw4u/tests/data/procedures/duplicate-proc.fail.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/elif-in-proc.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/elif-in-proc.flatten.txt create mode 100644 tools/hrw4u/tests/data/procedures/elif-in-proc.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/elif-in-proc.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/in-conditional.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/in-conditional.flatten.txt create mode 100644 tools/hrw4u/tests/data/procedures/in-conditional.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/in-conditional.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-and-use.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-and-use.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-and-use.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-mixed-body.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-mixed-body.flatten.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-mixed-body.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-mixed-body.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-multi-section.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-multi-section.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-multi-section.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-proc.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-proc.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-proc.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-with-params.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-with-params.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/local-with-params.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/mixed-body.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/mixed-body.flatten.txt create mode 100644 tools/hrw4u/tests/data/procedures/mixed-body.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/mixed-body.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-proc.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-proc.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-proc.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-section-mixed.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-section-mixed.flatten.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-section-mixed.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-section-mixed.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-use.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-use.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/multi-use.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/override-param.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/override-param.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/override-param.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/proc-after-section.fail.error.txt create mode 100644 tools/hrw4u/tests/data/procedures/proc-after-section.fail.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/procs/base/Stamp.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/caller/Wrap.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/circular/A.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/circular/B.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/reexport/debug.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/test/TagAndOrigin.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/test/add-debug-header.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/test/classify-request.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/test/mixed-body.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/test/set-cache.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/test/set-origin.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/procs/test/wrong-namespace.hrw4u create mode 100644 tools/hrw4u/tests/data/procedures/reexport.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/reexport.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/reexport.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/string-param.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/string-param.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/string-param.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/top-level-only.fail.error.txt create mode 100644 tools/hrw4u/tests/data/procedures/top-level-only.fail.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/transitive.ast.txt create mode 100644 tools/hrw4u/tests/data/procedures/transitive.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/transitive.output.txt create mode 100644 tools/hrw4u/tests/data/procedures/unknown-proc.fail.error.txt create mode 100644 tools/hrw4u/tests/data/procedures/unknown-proc.fail.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/wrong-arity.fail.error.txt create mode 100644 tools/hrw4u/tests/data/procedures/wrong-arity.fail.input.txt create mode 100644 tools/hrw4u/tests/data/procedures/wrong-namespace.fail.error.txt create mode 100644 tools/hrw4u/tests/data/procedures/wrong-namespace.fail.input.txt create mode 100644 tools/hrw4u/tests/test_procedures.py diff --git a/ci/rat-regex.txt b/ci/rat-regex.txt index 10cff3979c7..2ddc42fa952 100644 --- a/ci/rat-regex.txt +++ b/ci/rat-regex.txt @@ -24,6 +24,7 @@ .*\.config$ .*\.yaml$ .*\.gold$ +.*\.hrw4u$ ^\.gitignore$ ^\.gitmodules$ ^\.perltidyrc$ diff --git a/doc/admin-guide/configuration/hrw4u.en.rst b/doc/admin-guide/configuration/hrw4u.en.rst index c401dafc4e6..1b7d2d2fcd6 100644 --- a/doc/admin-guide/configuration/hrw4u.en.rst +++ b/doc/admin-guide/configuration/hrw4u.en.rst @@ -416,6 +416,101 @@ or when integrating with existing header_rewrite rules that reference specific s addition, a remap configuration can use ``@PPARAM`` to set one of these slot variables explicitly as part of the configuration. +Procedures +---------- + +Procedures allow you to define reusable blocks of rules that can be called from +multiple sections or files. A procedure is a named, parameterized block of +conditions and operators that expands inline at the call site. + +Defining Procedures +^^^^^^^^^^^^^^^^^^^ + +Procedures are declared with the ``procedure`` keyword and must use a qualified +name with the ``::`` namespace separator:: + + procedure local::add-debug-header($tag) { + inbound.req.X-Debug = "$tag"; + } + +The namespace prefix (``local::`` in this example) groups related procedures. +Parameters are prefixed with ``$`` and substituted at the call site. + +Procedures may be defined in the same file as the sections that use them, or in +separate ``.hrw4u`` files loaded with the ``use`` directive. Procedure declarations +must appear before any section blocks. + +Using Procedures +^^^^^^^^^^^^^^^^ + +Call a procedure from any section by its qualified name:: + + procedure local::set-cache-headers($ttl) { + outbound.resp.Cache-Control = "max-age=$ttl"; + outbound.resp.X-Cache-TTL = "$ttl"; + } + + READ_RESPONSE { + local::set-cache-headers("3600"); + } + + SEND_RESPONSE { + local::set-cache-headers("0"); + } + +The procedure body is expanded inline — each section gets its own copy with +the correct hook context. + +Procedure Files and ``use`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For larger projects, procedures can be organized into separate files and loaded +with the ``use`` directive. The ``use`` spec maps to a file path: ``use Acme::Common`` +loads ``Acme/Common.hrw4u`` from the procedures search path. + +The ``--procedures-path`` flag specifies where to search:: + + hrw4u --procedures-path /etc/trafficserver/procedures rules.hrw4u + +Given this file structure:: + + /etc/trafficserver/procedures/ + └── Acme/ + └── Common.hrw4u + +Where ``Acme/Common.hrw4u`` contains:: + + procedure Acme::add-security-headers() { + outbound.resp.X-Frame-Options = "DENY"; + outbound.resp.X-Content-Type-Options = "nosniff"; + } + +Then in ``rules.hrw4u``:: + + use Acme::Common + + READ_RESPONSE { + Acme::add-security-headers(); + } + +The ``use`` directive enforces namespace consistency: all procedures in a file +loaded via ``use Acme::Common`` must use the ``Acme::`` namespace prefix. + +Parameters and Defaults +^^^^^^^^^^^^^^^^^^^^^^^ + +Procedures support positional parameters with optional defaults:: + + procedure local::tag-request($env, $version = "v1") { + inbound.req.X-Env = "$env"; + inbound.req.X-Version = "$version"; + } + + REMAP { + local::tag-request("prod"); + # $version defaults to "v1" + } + Groups ------ diff --git a/tools/hrw4u/Makefile b/tools/hrw4u/Makefile index 33ab62873c2..826025451cb 100644 --- a/tools/hrw4u/Makefile +++ b/tools/hrw4u/Makefile @@ -51,7 +51,8 @@ UTILS_FILES=src/symbols_base.py \ SRC_FILES_HRW4U=src/visitor.py \ src/symbols.py \ src/suggestions.py \ - src/kg_visitor.py + src/kg_visitor.py \ + src/procedures.py ALL_HRW4U_FILES=$(SHARED_FILES) $(UTILS_FILES) $(SRC_FILES_HRW4U) diff --git a/tools/hrw4u/grammar/hrw4u.g4 b/tools/hrw4u/grammar/hrw4u.g4 index 889bbddc569..48fe316c5ca 100644 --- a/tools/hrw4u/grammar/hrw4u.g4 +++ b/tools/hrw4u/grammar/hrw4u.g4 @@ -29,9 +29,14 @@ TRUE : [tT][rR][uU][eE]; FALSE : [fF][aA][lL][sS][eE]; WITH : 'with'; BREAK : 'break'; +USE : 'use'; +PROCEDURE : 'procedure'; REGEX : '/' ( '\\/' | ~[/\r\n] )* '/' ; -STRING : '"' ( '\\' . | ~["\\\r\n] )* '"' ; +STRING : '"' ( ESCAPED_BLOCK | '\\' . | ~["\\\r\n] )* '"' ; + +// {{ ... }} is an escape hatch — contents are passed through verbatim, inner quotes allowed +fragment ESCAPED_BLOCK : '{{' ( ~'}' | '}' ~'}' )* '}}'; IPV4_LITERAL : (OCTET '.' OCTET '.' OCTET '.' OCTET ('/' IPV4_CIDR)?) @@ -59,8 +64,13 @@ fragment IPV6_CIDR : '3'[3-9] | '12'[0-8] ; +// Qualified identifier: Namespace::Name (one or more :: segments). +QUALIFIED_IDENT : [a-zA-Z_][a-zA-Z0-9_-]* ('::' [a-zA-Z_][a-zA-Z0-9_-]*)+ + ; + IDENT : [a-zA-Z_][a-zA-Z0-9_@.-]* ; NUMBER : [0-9]+ ; +DOLLAR : '$'; LPAREN : '('; RPAREN : ')'; LBRACE : '{'; @@ -89,14 +99,36 @@ WS : [ \t\r\n]+ -> skip ; // Parser Rules // ----------------------------- program - : programItem+ EOF + : programItem* EOF ; programItem - : section + : useDirective + | procedureDecl + | section | commentLine ; +useDirective + : USE QUALIFIED_IDENT + ; + +procedureDecl + : PROCEDURE QUALIFIED_IDENT LPAREN paramList? RPAREN block + ; + +paramList + : param (COMMA param)* + ; + +param + : DOLLAR IDENT (EQUAL value)? + ; + +paramRef + : DOLLAR IDENT + ; + section : varSection | name=IDENT LBRACE sectionBody+ RBRACE @@ -211,7 +243,7 @@ comparable ; functionCall - : funcName=IDENT LPAREN argumentList? RPAREN + : funcName=(IDENT | QUALIFIED_IDENT) LPAREN argumentList? RPAREN ; argumentList @@ -251,6 +283,7 @@ value | ident=IDENT | ip | iprange + | paramRef ; commentLine diff --git a/tools/hrw4u/pyproject.toml b/tools/hrw4u/pyproject.toml index ce607fd5b69..4398e35c7df 100644 --- a/tools/hrw4u/pyproject.toml +++ b/tools/hrw4u/pyproject.toml @@ -20,7 +20,7 @@ build-backend = "setuptools.build_meta" [project] name = "hrw4u" -version = "1.4.1" +version = "1.5.0" description = "HRW4U CLI tool for Apache Traffic Server header rewrite rules" authors = [ {name = "Leif Hedstrom", email = "leif@apache.org"} @@ -76,6 +76,7 @@ markers = [ "examples: marks tests for all header_rewrite docs examples", "reverse: marks tests for reverse conversion (header_rewrite -> hrw4u)", "ast: marks tests for AST validation", + "procedures: marks tests for procedure expansion", ] [dependency-groups] diff --git a/tools/hrw4u/scripts/hrw4u b/tools/hrw4u/scripts/hrw4u index 72dfc82e571..2940a4c970a 100755 --- a/tools/hrw4u/scripts/hrw4u +++ b/tools/hrw4u/scripts/hrw4u @@ -19,12 +19,38 @@ from __future__ import annotations +import argparse +import os +from pathlib import Path +from typing import Any + from hrw4u.hrw4uLexer import hrw4uLexer from hrw4u.hrw4uParser import hrw4uParser from hrw4u.visitor import HRW4UVisitor from hrw4u.common import run_main +def _add_args(parser: argparse.ArgumentParser, output_group: argparse._MutuallyExclusiveGroup) -> None: + output_group.add_argument( + "--output", + choices=["hrw", "hrw4u"], + default="hrw", + help="Output format: hrw (header_rewrite, default) or hrw4u (expand procedures inline)") + parser.add_argument( + "--procedures-path", + metavar="DIR[:DIR...]", + dest="procedures_path", + default="", + help="Colon-separated list of directories to search for procedure files") + + +def _visitor_kwargs(args: argparse.Namespace) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + if args.procedures_path: + kwargs['proc_search_paths'] = [Path(p) for p in args.procedures_path.split(os.pathsep) if p] + return kwargs + + def main() -> None: """Main entry point for the hrw4u script.""" run_main( @@ -34,7 +60,9 @@ def main() -> None: visitor_class=HRW4UVisitor, error_prefix="hrw4u", output_flag_name="hrw", - output_flag_help="Produce the HRW output (default)") + output_flag_help="Produce the HRW output (default)", + add_args=_add_args, + visitor_kwargs=_visitor_kwargs) if __name__ == "__main__": diff --git a/tools/hrw4u/scripts/hrw4u-lsp b/tools/hrw4u/scripts/hrw4u-lsp index 9b16f75eeef..ce78da8e4c0 100755 --- a/tools/hrw4u/scripts/hrw4u-lsp +++ b/tools/hrw4u/scripts/hrw4u-lsp @@ -23,54 +23,22 @@ import json import os import sys import urllib.parse -from functools import lru_cache +from pathlib import Path from typing import Any -from antlr4.error.ErrorListener import ErrorListener - from hrw4u.hrw4uLexer import hrw4uLexer from hrw4u.hrw4uParser import hrw4uParser -from hrw4u.visitor import HRW4UVisitor +from hrw4u.visitor import HRW4UVisitor, ProcSig from hrw4u.common import create_parse_tree -from hrw4u.tables import FUNCTION_MAP, STATEMENT_FUNCTION_MAP -from hrw4u.states import SectionType from hrw4u.types import VarType, LanguageKeyword +from hrw4u.procedures import resolve_use_path -from hrw4u_lsp.lsp.strings import ( - StringLiteralHandler, ContextAnalyzer, ExpressionParser, DocumentAnalyzer, CompletionContext, LSPDiagnostic, - VariableDeclaration) +from hrw4u_lsp.lsp.strings import (StringLiteralHandler, ContextAnalyzer, ExpressionParser, DocumentAnalyzer) +from hrw4u_lsp.lsp.types import LSPDiagnostic, CompletionContext, VariableDeclaration from hrw4u_lsp.lsp.hover import ( HoverInfoProvider, FunctionHoverProvider, VariableHoverProvider, SectionHoverProvider, RegexHoverProvider, ModifierHoverProvider, OperatorHoverProvider) from hrw4u_lsp.lsp.completions import CompletionProvider -from hrw4u_lsp.lsp.documentation import LSP_FUNCTION_DOCUMENTATION - - -class LSPErrorListener(ErrorListener): - - def __init__(self) -> None: - super().__init__() - self.errors: list[dict[str, Any]] = [] - - def syntaxError(self, _, offendingSymbol, line, column, msg, e) -> None: - self.errors.append( - { - "range": - { - "start": { - "line": line - 1, - "character": column - }, - "end": - { - "line": line - 1, - "character": column + (len(str(offendingSymbol.text)) if offendingSymbol.text else 1) - } - }, - "severity": 1, - "message": msg, - "source": "hrw4u-parser" - }) class DocumentManager: @@ -80,21 +48,20 @@ class DocumentManager: self.documents: dict[str, str] = {} self.diagnostics: dict[str, list[LSPDiagnostic]] = {} self.variable_declarations: dict[str, dict[str, VariableDeclaration]] = {} + self.proc_registries: dict[str, dict[str, ProcSig]] = {} self._completion_provider = CompletionProvider() self._uri_path_cache: dict[str, str] = {} + self.proc_search_paths: list[Path] = [] def _add_operator_completions(self, completions: list, base_prefix: str, current_section, context: CompletionContext) -> None: - """Add operator and condition completions.""" operator_completions = self._completion_provider.get_operator_completions( base_prefix, current_section, context["replacement_range"]) completions.extend(operator_completions) - def _add_function_completions(self, completions: list, function_map: dict, function_type: str) -> None: - """Add function completions using centralized provider.""" + def _add_function_completions(self, completions: list) -> None: completions.extend(self._completion_provider.get_function_completions()) def _add_keyword_completions(self, completions: list) -> None: - """Add keyword completions using centralized provider.""" completions.extend(self._completion_provider.get_keyword_completions()) def open_document(self, uri: str, text: str) -> None: @@ -115,6 +82,7 @@ class DocumentManager: del self.diagnostics[uri] if uri in self.variable_declarations: del self.variable_declarations[uri] + self.proc_registries.pop(uri, None) def _uri_to_path(self, uri: str) -> str: """Convert a URI to a file path with caching.""" @@ -129,13 +97,16 @@ class DocumentManager: self._uri_path_cache[uri] = path return path + @staticmethod + def _path_to_uri(file_path: str | Path) -> str: + return "file://" + urllib.parse.quote(str(file_path), safe="/:@") + def _analyze_document(self, uri: str) -> None: """Analyze document and collect diagnostics.""" text = self.documents.get(uri, "") diagnostics = [] self.variable_declarations[uri] = DocumentAnalyzer.parse_variable_declarations(text) - diagnostics.extend(DocumentAnalyzer.validate_section_names(text)) try: filename = self._uri_to_path(uri) @@ -147,9 +118,11 @@ class DocumentManager: parser_errors = list(parser._syntax_errors) if tree is not None: - visitor = HRW4UVisitor(filename=filename, error_collector=error_collector) + visitor = HRW4UVisitor( + filename=filename, error_collector=error_collector, proc_search_paths=self.proc_search_paths or None) try: visitor.visit(tree) + self.proc_registries[uri] = dict(visitor._proc_registry) except Exception as e: diagnostics.append( { @@ -271,15 +244,7 @@ class DocumentManager: "source": "hrw4u-parser" }) - # Convert generic diagnostics to typed diagnostics where possible - typed_diagnostics: list[LSPDiagnostic] = [] - for diag in diagnostics: - if isinstance(diag, dict): - # For backward compatibility with generic dict diagnostics - typed_diagnostics.append(diag) # type: ignore - else: - typed_diagnostics.append(diag) - self.diagnostics[uri] = typed_diagnostics + self.diagnostics[uri] = diagnostics def get_diagnostics(self, uri: str) -> list[LSPDiagnostic]: return self.diagnostics.get(uri, []) @@ -304,8 +269,7 @@ class DocumentManager: self._add_operator_completions(completions, base_prefix, current_section, context) elif context["is_function_context"]: - self._add_function_completions(completions, FUNCTION_MAP, "Function") - self._add_function_completions(completions, STATEMENT_FUNCTION_MAP, "Statement") + self._add_function_completions(completions) # Variable type completions (in VARS context) if context["current_section"] and context["current_section"].value == "VARS": @@ -316,6 +280,53 @@ class DocumentManager: return completions + def _format_proc_signature(self, sig: ProcSig) -> str: + parts = [] + for p in sig.params: + if p.default_ctx is not None: + default_text = p.default_ctx.getText() + parts.append(f"${p.name} = {default_text}") + else: + parts.append(f"${p.name}") + params = ", ".join(parts) + return f"**procedure {sig.qualified_name}**({params})\n\nSource: `{sig.source_file}`" + + def _get_procedure_hover(self, uri: str, current_line: str, line: int, character: int) -> dict[str, Any] | None: + stripped = current_line.strip() + registry = self.proc_registries.get(uri, {}) + + # `use Namespace::Name` directive + if stripped.startswith("use "): + spec = stripped[4:].strip() + if '::' not in spec: + return None + if not self.proc_search_paths: + return HoverInfoProvider.create_hover_info(f"**use {spec}**\n\nNo procedures path configured") + resolved = resolve_use_path(spec, self.proc_search_paths) + if resolved: + return HoverInfoProvider.create_hover_info(f"**use {spec}** → `{resolved}`") + return HoverInfoProvider.create_hover_info(f"**use {spec}**\n\nFile not found in procedures path") + + # `procedure Namespace::Name(...)` declaration + if stripped.startswith("procedure "): + rest = stripped[10:].strip() + paren = rest.find('(') + name = rest[:paren].strip() if paren != -1 else rest.split()[0] if rest else "" + if name and '::' in name: + sig = registry.get(name) + if sig: + return HoverInfoProvider.create_hover_info(self._format_proc_signature(sig)) + return HoverInfoProvider.create_hover_info(f"**procedure {name}**") + + # Qualified name at cursor (Namespace::Name call or reference) + word = self._extract_qualified_name(current_line, character) + if word: + sig = registry.get(word) + if sig: + return HoverInfoProvider.create_hover_info(self._format_proc_signature(sig)) + + return None + def get_hover_info(self, uri: str, line: int, character: int) -> dict[str, Any] | None: """Get hover information for the symbol at the given position.""" text = self.documents.get(uri, "") @@ -332,6 +343,10 @@ class DocumentManager: if comment_start != -1 and character >= comment_start: return None + proc_hover = self._get_procedure_hover(uri, current_line, line, character) + if proc_hover: + return proc_hover + string_info = StringLiteralHandler.check_string_literal(current_line, character) if string_info: return string_info @@ -396,6 +411,84 @@ class DocumentManager: return None + def _extract_qualified_name(self, current_line: str, character: int) -> str | None: + start = character + end = character + + while start > 0 and (current_line[start - 1].isalnum() or current_line[start - 1] in ':_-'): + start -= 1 + while end < len(current_line) and (current_line[end].isalnum() or current_line[end] in ':_-'): + end += 1 + + word = current_line[start:end] + return word if '::' in word else None + + def get_definition(self, uri: str, line: int, character: int) -> dict[str, Any] | None: + text = self.documents.get(uri, "") + lines = text.split('\n') + + if line >= len(lines): + return None + + current_line = lines[line] + if character >= len(current_line): + return None + + stripped = current_line.strip() + registry = self.proc_registries.get(uri, {}) + + # `use Namespace::Name` → open the resolved file + if stripped.startswith("use "): + spec = stripped[4:].strip() + if '::' not in spec or not self.proc_search_paths: + return None + resolved = resolve_use_path(spec, self.proc_search_paths) + if resolved: + return { + "uri": self._path_to_uri(resolved), + "range": { + "start": { + "line": 0, + "character": 0 + }, + "end": { + "line": 0, + "character": 0 + } + } + } + return None + + # `procedure` declaration or qualified name call → jump to definition + name = None + if stripped.startswith("procedure "): + rest = stripped[10:].strip() + paren = rest.find('(') + name = rest[:paren].strip() if paren != -1 else rest.split()[0] if rest else "" + else: + name = self._extract_qualified_name(current_line, character) + + if name: + sig = registry.get(name) + if sig: + def_line = getattr(sig.body_ctx, 'start', None) + target_line = (def_line.line - 1) if def_line else 0 + return { + "uri": self._path_to_uri(sig.source_file), + "range": { + "start": { + "line": target_line, + "character": 0 + }, + "end": { + "line": target_line, + "character": 0 + } + } + } + + return None + class HRW4ULanguageServer: """Main LSP server implementation.""" @@ -457,6 +550,8 @@ class HRW4ULanguageServer: self._handle_completion(message) elif method == "textDocument/hover": self._handle_hover(message) + elif method == "textDocument/definition": + self._handle_definition(message) elif method == "textDocument/codeAction": self._handle_code_action(message) elif method == "shutdown": @@ -465,6 +560,13 @@ class HRW4ULanguageServer: self.running = False def _handle_initialize(self, message: dict[str, Any]) -> None: + params = message.get("params", {}) + init_options = params.get("initializationOptions", {}) + if init_options: + procedures_path = init_options.get("proceduresPath", "") + if procedures_path: + self.document_manager.proc_search_paths = [Path(p) for p in procedures_path.split(os.pathsep) if p] + response = { "jsonrpc": "2.0", "id": message["id"], @@ -477,6 +579,7 @@ class HRW4ULanguageServer: "triggerCharacters": ["."] }, "hoverProvider": True, + "definitionProvider": True, "codeActionProvider": True } } @@ -536,6 +639,16 @@ class HRW4ULanguageServer: response = {"jsonrpc": "2.0", "id": message["id"], "result": hover_info} self._send_message(response) + def _handle_definition(self, message: dict[str, Any]) -> None: + params = message["params"] + uri = params["textDocument"]["uri"] + position = params["position"] + + location = self.document_manager.get_definition(uri, position["line"], position["character"]) + + response = {"jsonrpc": "2.0", "id": message["id"], "result": location} + self._send_message(response) + def _handle_code_action(self, message: dict[str, Any]) -> None: """Handle code action requests for quick fixes.""" params = message["params"] diff --git a/tools/hrw4u/scripts/testcase.py b/tools/hrw4u/scripts/testcase.py index 802e32f2166..19b75fed6bd 100755 --- a/tools/hrw4u/scripts/testcase.py +++ b/tools/hrw4u/scripts/testcase.py @@ -26,12 +26,10 @@ from hrw4u.hrw4uParser import hrw4uParser from hrw4u.visitor import HRW4UVisitor -KNOWN_MARKS = {"hooks", "conds", "ops", "vars", "examples", "invalid"} +KNOWN_MARKS = {"hooks", "conds", "ops", "vars", "examples", "invalid", "procedures"} def load_exceptions(test_dir: Path) -> dict[str, str]: - """Load exceptions from exceptions.txt in the test directory. - Returns a dict mapping test filename to direction (hrw4u or u4wrh).""" exceptions_file = test_dir / "exceptions.txt" exceptions = {} @@ -52,7 +50,14 @@ def load_exceptions(test_dir: Path) -> dict[str, str]: return exceptions -def parse_tree(input_text: str) -> tuple[hrw4uParser, any]: +def _proc_search_paths(input_path: Path) -> list[Path] | None: + procs_dir = input_path.parent / 'procs' + if procs_dir.is_dir(): + return [procs_dir] + return None + + +def parse_tree(input_text: str) -> tuple[hrw4uParser, hrw4uParser.ProgramContext]: stream = InputStream(input_text) lexer = hrw4uLexer(stream) tokens = CommonTokenStream(lexer) @@ -90,7 +95,7 @@ def process_file( if input_path.name.endswith(".fail.input.txt"): try: parser, tree = parse_tree(input_text) - visitor = HRW4UVisitor(filename=str(input_path)) + visitor = HRW4UVisitor(filename=str(input_path), proc_search_paths=_proc_search_paths(input_path)) visitor.visit(tree) print(f"Unexpected success: {input_path}") return False @@ -107,7 +112,8 @@ def process_file( try: parser, tree = parse_tree(input_text) ast_text = tree.toStringTree(recog=parser).strip() - output_text = "\n".join(HRW4UVisitor(filename=str(input_path)).visit(tree)).strip() + output_text = "\n".join( + HRW4UVisitor(filename=str(input_path), proc_search_paths=_proc_search_paths(input_path)).visit(tree)).strip() if update_ast: ast_path.write_text(ast_text + "\n") diff --git a/tools/hrw4u/src/common.py b/tools/hrw4u/src/common.py index d694f0d58de..7ca9c92ed46 100644 --- a/tools/hrw4u/src/common.py +++ b/tools/hrw4u/src/common.py @@ -20,7 +20,7 @@ import argparse import re import sys -from typing import Final, NoReturn, Protocol, TextIO, Any +from typing import Final, NoReturn, Protocol, TextIO, Any, Callable from antlr4.error.ErrorStrategy import BailErrorStrategy, DefaultErrorStrategy from antlr4 import InputStream, CommonTokenStream @@ -39,11 +39,13 @@ class RegexPatterns: PERCENT_INLINE: Final = re.compile(r"%\{([A-Z0-9_-]+)(?::(.*?))?\}") PERCENT_PATTERN: Final = re.compile(r'%\{([^}]+)\}') SUBSTITUTE_PATTERN: Final = re.compile( - r"""(?[a-zA-Z_][a-zA-Z0-9_-]*)\s*\((?P[^)]*)\)\s*\} + r"""(?P\{\{.*?\}\}) + | + (?[a-zA-Z_][a-zA-Z0-9_-]*)\s*\((?P[^)]*)\)\s*\} | (?[^{}()]+)\} """, - re.VERBOSE, + re.VERBOSE | re.DOTALL, ) # Additional performance patterns @@ -148,13 +150,14 @@ def create_parse_tree( lexer_class: type[LexerProtocol], parser_class: type[ParserProtocol], error_prefix: str, - collect_errors: bool = True) -> tuple[Any, ParserProtocol, ErrorCollector | None]: + collect_errors: bool = True, + max_errors: int = 5) -> tuple[Any, ParserProtocol, ErrorCollector | None]: """Create ANTLR parse tree from input content with optional error collection.""" input_stream = InputStream(content) error_collector = None if collect_errors: - error_collector = ErrorCollector() + error_collector = ErrorCollector(max_errors=max_errors) error_listener = CollectingErrorListener(filename=filename, error_collector=error_collector) else: error_listener = ThrowingErrorListener(filename=filename) @@ -199,7 +202,8 @@ def generate_output( visitor_class: type[VisitorProtocol], filename: str, args: Any, - error_collector: ErrorCollector | None = None) -> None: + error_collector: ErrorCollector | None = None, + visitor_kwargs: Callable[[Any], dict[str, Any]] | None = None) -> None: """Generate and print output based on mode with optional error collection.""" if args.ast: if tree is not None: @@ -209,10 +213,18 @@ def generate_output( else: if tree is not None: preserve_comments = not getattr(args, 'no_comments', False) + extra_kwargs = visitor_kwargs(args) if visitor_kwargs else {} visitor = visitor_class( - filename=filename, debug=args.debug, error_collector=error_collector, preserve_comments=preserve_comments) + filename=filename, + debug=args.debug, + error_collector=error_collector, + preserve_comments=preserve_comments, + **extra_kwargs) try: - result = visitor.visit(tree) + if getattr(args, 'output', None) == 'hrw4u': + result = visitor.flatten(tree) + else: + result = visitor.visit(tree) if result: print("\n".join(result)) except Exception as e: @@ -232,8 +244,16 @@ def generate_output( def run_main( - description: str, lexer_class: type[LexerProtocol], parser_class: type[ParserProtocol], - visitor_class: type[VisitorProtocol], error_prefix: str, output_flag_name: str, output_flag_help: str) -> None: + description: str, + lexer_class: type[LexerProtocol], + parser_class: type[ParserProtocol], + visitor_class: type[VisitorProtocol], + error_prefix: str, + output_flag_name: str, + output_flag_help: str, + add_args: Callable[[argparse.ArgumentParser, argparse._MutuallyExclusiveGroup], None] | None = None, + pre_process: Callable[[str, str, Any], str] | None = None, + visitor_kwargs: Callable[[Any], dict[str, Any]] | None = None) -> None: """ Generic main function for hrw4u and u4wrh scripts with bulk compilation support. @@ -245,6 +265,8 @@ def run_main( error_prefix: Error prefix for error messages output_flag_name: Name of output flag (e.g., "hrw", "hrw4u") output_flag_help: Help text for output flag + add_args: Optional callback to add extra arguments to the parser and output group + pre_process: Optional callback(content, filename, args) -> content run before parsing """ parser = argparse.ArgumentParser( description=description, @@ -262,20 +284,29 @@ def run_main( parser.add_argument("--debug", action="store_true", help="Enable debug output") parser.add_argument( "--stop-on-error", action="store_true", help="Stop processing on first error (default: collect and report multiple errors)") + parser.add_argument( + "--max-errors", + type=int, + default=5, + dest="max_errors", + help="Maximum number of errors to report before stopping (default: 5; ignored with --stop-on-error)") - args = parser.parse_args() - - if not hasattr(args, output_flag_name): - setattr(args, output_flag_name, False) + if add_args is not None: + add_args(parser, output_group) - if not (args.ast or getattr(args, output_flag_name)): - setattr(args, output_flag_name, True) + args = parser.parse_args() if not args.files: content, filename = process_input(sys.stdin) + if pre_process is not None: + try: + content = pre_process(content, filename, args) + except Hrw4uSyntaxError as e: + print(str(e), file=sys.stderr) + sys.exit(1) tree, parser_obj, error_collector = create_parse_tree( - content, filename, lexer_class, parser_class, error_prefix, not args.stop_on_error) - generate_output(tree, parser_obj, visitor_class, filename, args, error_collector) + content, filename, lexer_class, parser_class, error_prefix, not args.stop_on_error, args.max_errors) + generate_output(tree, parser_obj, visitor_class, filename, args, error_collector, visitor_kwargs) return if any(':' in f for f in args.files): @@ -299,15 +330,21 @@ def run_main( print(f"Error reading '{input_path}': {e}", file=sys.stderr) sys.exit(1) + if pre_process is not None: + try: + content = pre_process(content, filename, args) + except Hrw4uSyntaxError as e: + print(str(e), file=sys.stderr) + sys.exit(1) tree, parser_obj, error_collector = create_parse_tree( - content, filename, lexer_class, parser_class, error_prefix, not args.stop_on_error) + content, filename, lexer_class, parser_class, error_prefix, not args.stop_on_error, args.max_errors) try: with open(output_path, 'w', encoding='utf-8') as output_file: original_stdout = sys.stdout try: sys.stdout = output_file - generate_output(tree, parser_obj, visitor_class, filename, args, error_collector) + generate_output(tree, parser_obj, visitor_class, filename, args, error_collector, visitor_kwargs) finally: sys.stdout = original_stdout except Exception as e: @@ -329,7 +366,13 @@ def run_main( print(f"Error reading '{input_path}': {e}", file=sys.stderr) sys.exit(1) + if pre_process is not None: + try: + content = pre_process(content, filename, args) + except Hrw4uSyntaxError as e: + print(str(e), file=sys.stderr) + sys.exit(1) tree, parser_obj, error_collector = create_parse_tree( - content, filename, lexer_class, parser_class, error_prefix, not args.stop_on_error) + content, filename, lexer_class, parser_class, error_prefix, not args.stop_on_error, args.max_errors) - generate_output(tree, parser_obj, visitor_class, filename, args, error_collector) + generate_output(tree, parser_obj, visitor_class, filename, args, error_collector, visitor_kwargs) diff --git a/tools/hrw4u/src/errors.py b/tools/hrw4u/src/errors.py index a2710fb840f..7b37a939bc2 100644 --- a/tools/hrw4u/src/errors.py +++ b/tools/hrw4u/src/errors.py @@ -17,11 +17,51 @@ from __future__ import annotations +import re +from typing import Final + from antlr4.error.ErrorListener import ErrorListener +_TOKEN_NAMES: Final[dict[str, str]] = { + 'QUALIFIED_IDENT': "qualified name (e.g. 'Namespace::Name')", + 'IDENT': 'identifier', + 'LPAREN': "'('", + 'RPAREN': "')'", + 'LBRACE': "'{'", + 'RBRACE': "'}'", + 'LBRACKET': "'['", + 'RBRACKET': "']'", + 'SEMICOLON': "';'", + 'COLON': "':'", + 'COMMA': "','", + 'EQUAL': "'='", + 'EQUALS': "'=='", + 'PLUSEQUAL': "'+='", + 'NEQ': "'!='", + 'DOLLAR': "'$'", + 'STRING': 'string literal', + 'NUMBER': 'number', + 'REGEX': 'regex pattern', + 'IPV4_LITERAL': 'IPv4 address', + 'IPV6_LITERAL': 'IPv6 address', + 'COMMENT': 'comment', + 'AND': "'&&'", + 'OR': "'||'", + 'TILDE': "'~'", + 'NOT_TILDE': "'!~'", + 'GT': "'>'", + 'LT': "'<'", + 'AT': "'@'", +} + +_TOKEN_PATTERN: Final = re.compile(r'\b(' + '|'.join(re.escape(k) for k in sorted(_TOKEN_NAMES, key=len, reverse=True)) + r')\b') + + +def humanize_error_message(msg: str) -> str: + return _TOKEN_PATTERN.sub(lambda m: _TOKEN_NAMES[m.group(1)], msg) + class ThrowingErrorListener(ErrorListener): - """ANTLR error listener that throws exceptions on syntax errors.""" def __init__(self, filename: str = "") -> None: super().__init__() @@ -42,11 +82,10 @@ def syntaxError(self, recognizer: object, _: object, line: int, column: int, msg except Exception: pass - raise Hrw4uSyntaxError(self.filename, line, column, msg, code_line) + raise Hrw4uSyntaxError(self.filename, line, column, humanize_error_message(msg), code_line) class Hrw4uSyntaxError(Exception): - """Formatted syntax error with source context and Python 3.11+ exception notes.""" def __init__(self, filename: str, line: int, column: int, message: str, source_line: str) -> None: super().__init__(self._format_error(filename, line, column, message, source_line)) @@ -56,11 +95,9 @@ def __init__(self, filename: str, line: int, column: int, message: str, source_l self.source_line = source_line def add_context_note(self, context: str) -> None: - """Add contextual information using Python 3.11+ exception notes.""" self.add_note(f"Context: {context}") def add_resolution_hint(self, hint: str) -> None: - """Add resolution hint using Python 3.11+ exception notes.""" self.add_note(f"Hint: {hint}") def _format_error(self, filename: str, line: int, col: int, message: str, source_line: str) -> str: @@ -84,7 +121,6 @@ def add_symbol_suggestion(self, suggestions: list[str]) -> None: def hrw4u_error(filename: str, ctx: object, exc: Exception) -> Hrw4uSyntaxError: - """Convert exceptions to formatted syntax errors with source context.""" if isinstance(exc, Hrw4uSyntaxError): return exc @@ -107,24 +143,21 @@ def hrw4u_error(filename: str, ctx: object, exc: Exception) -> Hrw4uSyntaxError: class ErrorCollector: - """Collects multiple syntax errors for comprehensive error reporting.""" - def __init__(self) -> None: - """Initialize an empty error collector.""" + def __init__(self, max_errors: int = 5) -> None: self.errors: list[Hrw4uSyntaxError] = [] + self.max_errors = max_errors def add_error(self, error: Hrw4uSyntaxError) -> None: - """ - Add a syntax error to the collection. - """ self.errors.append(error) def has_errors(self) -> bool: - """ - Check if any errors have been collected. - """ return bool(self.errors) + @property + def at_limit(self) -> bool: + return len(self.errors) >= self.max_errors + def get_error_summary(self) -> str: if not self.errors: return "No errors found." @@ -137,11 +170,13 @@ def get_error_summary(self) -> str: if hasattr(error, '__notes__') and error.__notes__: lines.extend(error.__notes__) + if self.at_limit: + lines.append(f"(stopped after {self.max_errors} errors)") + return "\n".join(lines) class CollectingErrorListener(ErrorListener): - """ANTLR error listener that collects syntax errors for tolerant parsing.""" def __init__(self, filename: str = "", error_collector: ErrorCollector | None = None) -> None: super().__init__() @@ -162,5 +197,8 @@ def syntaxError(self, recognizer: object, _: object, line: int, column: int, msg except Exception: pass - error = Hrw4uSyntaxError(self.filename, line, column, msg, code_line) + error = Hrw4uSyntaxError(self.filename, line, column, humanize_error_message(msg), code_line) self.error_collector.add_error(error) + + if self.error_collector.at_limit: + raise error diff --git a/tools/hrw4u/src/hrw_symbols.py b/tools/hrw4u/src/hrw_symbols.py index 558827d6db3..001358ddb1a 100644 --- a/tools/hrw4u/src/hrw_symbols.py +++ b/tools/hrw4u/src/hrw_symbols.py @@ -21,6 +21,7 @@ import re from hrw4u.errors import SymbolResolutionError +from hrw4u.debugging import Dbg from hrw4u.validation import Validator import hrw4u.types as types import hrw4u.tables as tables @@ -31,8 +32,8 @@ class InverseSymbolResolver(SymbolResolverBase): """Reverse mapping utilities for hrw4u output generation.""" - def __init__(self) -> None: - super().__init__(debug=False) # Default to no debug for inverse resolver + def __init__(self, dbg: Dbg | None = None) -> None: + super().__init__(debug=False, dbg=dbg) self._state_vars: dict[tuple[types.VarType, int], str] = {} @cached_property diff --git a/tools/hrw4u/src/hrw_visitor.py b/tools/hrw4u/src/hrw_visitor.py index 453fe062b66..3dbdfd6208e 100644 --- a/tools/hrw4u/src/hrw_visitor.py +++ b/tools/hrw4u/src/hrw_visitor.py @@ -51,7 +51,7 @@ def __init__( self._in_group: bool = False self._group_terms: list[tuple[str, CondState]] = [] - self.symbol_resolver = InverseSymbolResolver() + self.symbol_resolver = InverseSymbolResolver(dbg=self._dbg) self._section_opened = False self._if_depth = 0 # Track nesting depth of if blocks @@ -132,7 +132,7 @@ def _build_expression_parts(self, terms: list[tuple[str, CondState]]) -> str: if state.not_: processed_term = self.symbol_resolver.negate_expression(term) else: - processed_term = self._normalize_empty_string_condition(term, state) + processed_term = self._normalize_empty_string_condition(term) processed_term = self._apply_with_modifiers(processed_term, state) self.debug(f"processed term {idx}: {processed_term}") diff --git a/tools/hrw4u/src/lsp/hover.py b/tools/hrw4u/src/lsp/hover.py index 6a6f6c92c51..0beffd0d2a3 100644 --- a/tools/hrw4u/src/lsp/hover.py +++ b/tools/hrw4u/src/lsp/hover.py @@ -19,8 +19,7 @@ from __future__ import annotations -import re -from typing import Any, Dict +from typing import Any from . import documentation as doc from hrw4u.tables import OPERATOR_MAP, CONDITION_MAP, FUNCTION_MAP, STATEMENT_FUNCTION_MAP, LSPPatternMatcher @@ -32,13 +31,13 @@ class HoverInfoProvider: """Centralized provider for hover information generation.""" @staticmethod - def create_hover_info(markdown_content: str) -> Dict[str, Any]: + def create_hover_info(markdown_content: str) -> dict[str, Any]: """Create a standardized hover info dictionary.""" return {"contents": {"kind": "markdown", "value": markdown_content}} @staticmethod def create_field_interpolation_hover( - expression: str, field_display: str, field_desc: str, context: str, maps_to: str, usage: str = None) -> Dict[str, Any]: + expression: str, field_display: str, field_desc: str, context: str, maps_to: str, usage: str = None) -> dict[str, Any]: """Create hover info for field interpolations with a standard format.""" usage_text = usage or "Used in string value interpolation." @@ -57,7 +56,7 @@ def create_field_hover(expression: str, field_desc: str, context: str, maps_to: str, - usage: str = None) -> Dict[str, Any]: + usage: str = None) -> dict[str, Any]: """Create hover info for field expressions with a standard format.""" usage_text = usage or "Used in expression evaluation." @@ -75,7 +74,7 @@ class CertificateHoverProvider: """Specialized hover provider for certificate expressions.""" @staticmethod - def parse_certificate_expression(expression: str, is_interpolation: bool = False) -> Dict[str, Any] | None: + def parse_certificate_expression(expression: str, is_interpolation: bool = False) -> dict[str, Any] | None: """Parse certificate expressions using table-driven approach.""" parsed_data = doc.CertificatePattern.parse_certificate_expression(expression, is_interpolation) @@ -90,7 +89,7 @@ class InterpolationHoverProvider: """Specialized hover provider for string interpolation expressions.""" @staticmethod - def get_interpolated_expression_info(expression: str) -> Dict[str, Any] | None: + def get_interpolated_expression_info(expression: str) -> dict[str, Any] | None: """Get hover info for interpolated expressions.""" # Try table-driven pattern matching first if match := LSPPatternMatcher.match_any_pattern(expression): @@ -118,7 +117,7 @@ def get_interpolated_expression_info(expression: str) -> Dict[str, Any] | None: return None @staticmethod - def _handle_pattern_match(match, expression: str, is_interpolation: bool = False) -> Dict[str, Any] | None: + def _handle_pattern_match(match, expression: str, is_interpolation: bool = False) -> dict[str, Any] | None: """Handle a matched pattern and generate appropriate hover info.""" if match.context_type == 'Certificate': return CertificateHoverProvider.parse_certificate_expression(expression, is_interpolation=is_interpolation) @@ -141,7 +140,7 @@ class DottedExpressionHoverProvider: """Specialized hover provider for dotted expressions like outbound.req.X-Field.""" @staticmethod - def parse_dotted_expression(full_expression: str, character_pos: int, expr_start: int) -> Dict[str, Any] | None: + def parse_dotted_expression(full_expression: str, character_pos: int, expr_start: int) -> dict[str, Any] | None: """Parse dotted expressions and provide appropriate hover info.""" cursor_pos = character_pos - expr_start @@ -210,7 +209,7 @@ def parse_dotted_expression(full_expression: str, character_pos: int, expr_start return None @staticmethod - def _handle_pattern_match(match, full_expression: str, cursor_pos: int, expr_start: int) -> Dict[str, Any] | None: + def _handle_pattern_match(match, full_expression: str, cursor_pos: int, expr_start: int) -> dict[str, Any] | None: """Handle a matched pattern for dotted expressions.""" pattern_len = len(match.pattern) @@ -257,7 +256,7 @@ def _handle_pattern_match(match, full_expression: str, cursor_pos: int, expr_sta return None @staticmethod - def _handle_header_suffix(match, full_expression: str) -> Dict[str, Any] | None: + def _handle_header_suffix(match, full_expression: str) -> dict[str, Any] | None: """Handle header field suffix.""" header_name = match.suffix if header_name: @@ -279,7 +278,7 @@ def _handle_header_suffix(match, full_expression: str) -> Dict[str, Any] | None: return None @staticmethod - def _handle_cookie_suffix(match, full_expression: str) -> Dict[str, Any] | None: + def _handle_cookie_suffix(match, full_expression: str) -> dict[str, Any] | None: """Handle cookie field suffix.""" cookie_name = match.suffix if cookie_name: @@ -296,7 +295,7 @@ def _handle_cookie_suffix(match, full_expression: str) -> Dict[str, Any] | None: return None @staticmethod - def _handle_connection_suffix(match, full_expression: str) -> Dict[str, Any] | None: + def _handle_connection_suffix(match, full_expression: str) -> dict[str, Any] | None: """Handle connection field suffix.""" parts = full_expression.split('.') if len(parts) == 3: @@ -316,7 +315,7 @@ def _handle_connection_suffix(match, full_expression: str) -> Dict[str, Any] | N return None @staticmethod - def _handle_field_suffix(match, full_expression: str) -> Dict[str, Any] | None: + def _handle_field_suffix(match, full_expression: str) -> dict[str, Any] | None: """Handle field suffix for now., id., geo. patterns.""" field_dict = getattr(doc, match.field_dict_key) suffix_key = match.suffix.upper() @@ -361,7 +360,7 @@ def _get_known_prefixes(cls) -> set[str]: return prefixes @staticmethod - def get_operator_hover_info(operator: str) -> Dict[str, Any]: + def get_operator_hover_info(operator: str) -> dict[str, Any]: """Get hover info for operators.""" # Handle method field with comprehensive documentation if operator == "inbound.method": @@ -389,89 +388,12 @@ def get_operator_hover_info(operator: str) -> Dict[str, Any]: # Special handling for inbound/outbound header contexts if operator in doc.LSP_SUB_NAMESPACE_DOCUMENTATION: - sub_namespace_doc = doc.LSP_SUB_NAMESPACE_DOCUMENTATION[operator] - sections = [ - f"**{operator}** - {sub_namespace_doc.name}", "", f"**Context:** {sub_namespace_doc.context}", "", - f"**Description:** {sub_namespace_doc.description}", "", - f"**Available items:** {', '.join(sub_namespace_doc.available_items)}", "", f"**Usage:** {sub_namespace_doc.usage}" - ] - if sub_namespace_doc.examples: - sections.extend(["", "**Examples:**"]) - for example in sub_namespace_doc.examples: - sections.append(f"```hrw4u\n{example}\n```") - return HoverInfoProvider.create_hover_info("\n".join(sections)) - - # Check exact matches first - if operator in OPERATOR_MAP: - params = OPERATOR_MAP[operator] - commands = params.target if params else None - if isinstance(commands, str): - cmd_str = commands - elif commands: - cmd_str = ' / '.join(commands) - else: - cmd_str = "unknown" - sections = params.sections if params else None - - section_info = "" - if sections: - section_names = [s.value for s in sections] - section_info = f"\n\n**Restricted in sections:** {', '.join(section_names)}" - - return HoverInfoProvider.create_hover_info( - f"**{operator}** - HRW4U Operator\n\n" + f"**Maps to:** `{cmd_str}`{section_info}") - - # Check prefix matches - for key, params in OPERATOR_MAP.items(): - if key.endswith('.') and operator.startswith(key): - commands = params.target if params else None - if isinstance(commands, str): - cmd_str = commands - elif commands: - cmd_str = ' / '.join(commands) - else: - cmd_str = "unknown" - suffix = operator[len(key):] - sections = params.sections if params else None - - section_info = "" - if sections: - section_names = [s.value for s in sections] - section_info = f"\n\n**Restricted in sections:** {', '.join(section_names)}" + return OperatorHoverProvider._format_namespace_doc(operator, doc.LSP_SUB_NAMESPACE_DOCUMENTATION[operator]) - return HoverInfoProvider.create_hover_info( - f"**{operator}** - HRW4U Operator\n\n" + f"**Base:** `{key}`\n" + f"**Suffix:** `{suffix}`\n" + - f"**Maps to:** `{cmd_str}`{section_info}") - - # Check condition map - if operator in CONDITION_MAP: - params = CONDITION_MAP[operator] - tag = params.target if params else None - sections = params.sections if params else None - - section_info = "" - if sections: - section_names = [s.value for s in sections] - section_info = f"\n\n**Restricted in sections:** {', '.join(section_names)}" - - return HoverInfoProvider.create_hover_info( - f"**{operator}** - HRW4U Condition\n\n" + f"**Maps to:** `{tag}`{section_info}") - - # Check condition prefix matches - for key, params in CONDITION_MAP.items(): - if key.endswith('.') and operator.startswith(key): - tag = params.target if params else None - suffix = operator[len(key):] - sections = params.sections if params else None - - section_info = "" - if sections: - section_names = [s.value for s in sections] - section_info = f"\n\n**Restricted in sections:** {', '.join(section_names)}" - - return HoverInfoProvider.create_hover_info( - f"**{operator}** - HRW4U Condition\n\n" + f"**Base:** `{key}`\n" + f"**Suffix:** `{suffix}`\n" + - f"**Maps to:** `{tag}`{section_info}") + for table, kind in ((OPERATOR_MAP, "Operator"), (CONDITION_MAP, "Condition")): + result = OperatorHoverProvider._lookup_map(operator, table, kind) + if result: + return result # Handle namespace prefixes with comprehensive documentation as fallback namespace_info = OperatorHoverProvider._get_namespace_hover_info(operator) @@ -492,60 +414,76 @@ def get_operator_hover_info(operator: str) -> Dict[str, Any]: return None @staticmethod - def _get_namespace_hover_info(operator: str) -> Dict[str, Any] | None: - """Get comprehensive hover info for namespace prefixes using centralized documentation.""" - # Strip trailing dot for namespace lookup (handles cases like "inbound." -> "inbound") - namespace_key = operator.rstrip('.') - - # First check for sub-namespace patterns (e.g., "inbound.conn", "outbound.req") - if namespace_key in doc.LSP_SUB_NAMESPACE_DOCUMENTATION: - sub_namespace_doc = doc.LSP_SUB_NAMESPACE_DOCUMENTATION[namespace_key] - - # Build the hover content from the sub-namespace documentation - sections = [ - f"**{namespace_key}** - {sub_namespace_doc.name}", "", f"**Context:** {sub_namespace_doc.context}", "", - f"**Description:** {sub_namespace_doc.description}", "", - f"**Available items:** {', '.join(sub_namespace_doc.available_items)}", "", f"**Usage:** {sub_namespace_doc.usage}" - ] + def _format_target(params) -> str: + target = params.target if params else None - if sub_namespace_doc.examples: - sections.extend(["", "**Examples:**"]) - for example in sub_namespace_doc.examples: - sections.append(f"```hrw4u\n{example}\n```") + if isinstance(target, str): + return target + if target: + return ' / '.join(target) + return "unknown" - return HoverInfoProvider.create_hover_info("\n".join(sections)) + @staticmethod + def _format_section_info(params) -> str: + sections = params.sections if params else None - # For single-part namespace documentation, show it unless it's a known condition that should - # take precedence (like standalone "now" in conditional contexts) - # This allows namespace documentation for "geo", "id", etc. while preserving condition behavior - pass # No additional filtering - let the fallback logic in get_operator_hover_info handle it + if not sections: + return "" + return f"\n\n**Restricted in sections:** {', '.join(s.value for s in sections)}" - # Fall back to single-part namespace documentation - if namespace_key not in doc.LSP_NAMESPACE_DOCUMENTATION: - return None + @classmethod + def _lookup_map(cls, operator: str, table: dict, kind: str) -> dict[str, Any] | None: + """Look up operator in a map by exact match, then prefix match.""" + if operator in table: + params = table[operator] + cmd_str = cls._format_target(params) + section_info = cls._format_section_info(params) + return HoverInfoProvider.create_hover_info(f"**{operator}** - HRW4U {kind}\n\n**Maps to:** `{cmd_str}`{section_info}") + + for key, params in table.items(): + if key.endswith('.') and operator.startswith(key): + cmd_str = cls._format_target(params) + suffix = operator[len(key):] + section_info = cls._format_section_info(params) + return HoverInfoProvider.create_hover_info( + f"**{operator}** - HRW4U {kind}\n\n" + f"**Base:** `{key}`\n**Suffix:** `{suffix}`\n" + f"**Maps to:** `{cmd_str}`{section_info}") - namespace_doc = doc.LSP_NAMESPACE_DOCUMENTATION[namespace_key] + return None - # Build the hover content from the centralized documentation + @staticmethod + def _format_namespace_doc(key: str, ns_doc) -> dict[str, Any]: sections = [ - f"**{namespace_key}** - {namespace_doc.name}", "", f"**Context:** {namespace_doc.context}", "", - f"**Description:** {namespace_doc.description}", "", f"**Available items:** {', '.join(namespace_doc.available_items)}", - "", f"**Usage:** {namespace_doc.usage}" + f"**{key}** - {ns_doc.name}", "", f"**Context:** {ns_doc.context}", "", f"**Description:** {ns_doc.description}", "", + f"**Available items:** {', '.join(ns_doc.available_items)}", "", f"**Usage:** {ns_doc.usage}" ] - if namespace_doc.examples: + if ns_doc.examples: sections.extend(["", "**Examples:**"]) - for example in namespace_doc.examples: + for example in ns_doc.examples: sections.append(f"```hrw4u\n{example}\n```") return HoverInfoProvider.create_hover_info("\n".join(sections)) + @staticmethod + def _get_namespace_hover_info(operator: str) -> dict[str, Any] | None: + namespace_key = operator.rstrip('.') + + if namespace_key in doc.LSP_SUB_NAMESPACE_DOCUMENTATION: + return OperatorHoverProvider._format_namespace_doc(namespace_key, doc.LSP_SUB_NAMESPACE_DOCUMENTATION[namespace_key]) + + if namespace_key in doc.LSP_NAMESPACE_DOCUMENTATION: + return OperatorHoverProvider._format_namespace_doc(namespace_key, doc.LSP_NAMESPACE_DOCUMENTATION[namespace_key]) + + return None + class RegexHoverProvider: """Specialized hover provider for regular expression patterns.""" @staticmethod - def get_regex_hover_info(line: str, character: int) -> Dict[str, Any] | None: + def get_regex_hover_info(line: str, character: int) -> dict[str, Any] | None: """Get hover info for regex patterns with brief LSP-appropriate documentation.""" regex_data = doc.RegexPattern.detect_regex_pattern(line, character) if regex_data: @@ -558,7 +496,7 @@ class FunctionHoverProvider: """Specialized hover provider for functions.""" @staticmethod - def get_function_hover_info(function_name: str) -> Dict[str, Any]: + def get_function_hover_info(function_name: str) -> dict[str, Any]: """Get hover info for functions with comprehensive documentation.""" # Check comprehensive documentation first if function_name in doc.LSP_FUNCTION_DOCUMENTATION: @@ -617,7 +555,7 @@ class VariableHoverProvider: """Specialized hover provider for variables and variable types.""" @staticmethod - def get_variable_type_hover_info(var_type: str) -> Dict[str, Any]: + def get_variable_type_hover_info(var_type: str) -> dict[str, Any]: """Get hover info for variable types.""" try: vt = VarType.from_str(var_type.lower()) @@ -628,8 +566,8 @@ def get_variable_type_hover_info(var_type: str) -> Dict[str, Any]: return HoverInfoProvider.create_hover_info(f"**{var_type}** - Unknown Variable Type") @staticmethod - def get_variable_hover_info(variable_declarations: Dict[str, Dict[str, Any]], uri: str, - variable_name: str) -> Dict[str, Any] | None: + def get_variable_hover_info(variable_declarations: Dict[str, dict[str, Any]], uri: str, + variable_name: str) -> dict[str, Any] | None: """Get hover info for declared variables.""" variables = variable_declarations.get(uri, {}) if variable_name in variables: @@ -648,7 +586,7 @@ class SectionHoverProvider: """Specialized hover provider for sections.""" @staticmethod - def get_section_hover_info(section_name: str) -> Dict[str, Any]: + def get_section_hover_info(section_name: str) -> dict[str, Any]: """Get hover info for section names.""" # Don't treat regex patterns as section names! if section_name.startswith('/') or section_name.startswith('~') or '(' in section_name or ')' in section_name: @@ -674,7 +612,7 @@ class ModifierHoverProvider: """Specialized hover provider for condition modifiers used with 'with' keyword.""" @staticmethod - def get_modifier_hover_info(line: str, character: int) -> Dict[str, Any] | None: + def get_modifier_hover_info(line: str, character: int) -> dict[str, Any] | None: """Get hover info for condition modifiers in 'with' clauses.""" modifier_data = doc.ModifierPattern.detect_modifier_list(line, character) if modifier_data: diff --git a/tools/hrw4u/src/lsp/strings.py b/tools/hrw4u/src/lsp/strings.py index 9c89a9df0d4..a16595da8c6 100644 --- a/tools/hrw4u/src/lsp/strings.py +++ b/tools/hrw4u/src/lsp/strings.py @@ -20,24 +20,24 @@ from __future__ import annotations import re -from typing import Any, Dict +from typing import Any from .documentation import LSP_STRING_LITERAL_INFO from .hover import HoverInfoProvider, InterpolationHoverProvider -from .types import (CompletionContext, LSPPosition, VariableDeclaration, DiagnosticRange, LSPDiagnostic) +from .types import (CompletionContext, VariableDeclaration) class StringLiteralHandler: """Handles string literal processing and hover information.""" @staticmethod - def _create_string_literal_hover() -> Dict[str, Any]: + def _create_string_literal_hover() -> dict[str, Any]: """Create standardized string literal hover info.""" return HoverInfoProvider.create_hover_info( f"**{LSP_STRING_LITERAL_INFO['name']}** - HRW4U String Literal\n\n{LSP_STRING_LITERAL_INFO['description']}") @staticmethod - def check_string_literal(line: str, character: int) -> Dict[str, Any] | None: + def check_string_literal(line: str, character: int) -> dict[str, Any] | None: """Check if the cursor is inside a string literal and parse interpolated expressions.""" in_single_quote = False in_double_quote = False @@ -83,7 +83,7 @@ class InterpolationHandler: """Handles string interpolation expression processing.""" @staticmethod - def check_interpolated_expression(string_content: str, cursor_pos: int) -> Dict[str, Any] | None: + def check_interpolated_expression(string_content: str, cursor_pos: int) -> dict[str, Any] | None: """Check if the cursor is over an interpolated expression like {geo.country}.""" # Find all interpolated expressions {expression} for match in re.finditer(r'\{([^}]+)\}', string_content): @@ -171,7 +171,7 @@ class ExpressionParser: """Parses various types of expressions for hover information.""" @staticmethod - def parse_dotted_expression(line: str, character: int) -> Dict[str, Any] | None: + def parse_dotted_expression(line: str, character: int) -> dict[str, Any] | None: """Parse dotted expressions like outbound.req.X-Fie or inbound.req.@X-foo.""" from .hover import DottedExpressionHoverProvider @@ -245,61 +245,3 @@ def parse_variable_declarations(text: str) -> Dict[str, VariableDeclaration]: continue return variable_declarations - - @staticmethod - def validate_section_names(text: str) -> list[LSPDiagnostic]: - """Pre-validate section names in the document with optimization.""" - from hrw4u.states import SectionType - from functools import lru_cache - - @lru_cache(maxsize=128) - def cached_section_validation(section_name: str) -> bool: - """Cache section name validation.""" - valid_sections = {section.value for section in SectionType} - return section_name in valid_sections - - diagnostics = [] - lines = text.split('\n') - - # Pre-compute valid sections once - valid_sections_list = sorted([section.value for section in SectionType]) - valid_sections_str = ', '.join(valid_sections_list) - - for line_num, line in enumerate(lines): - stripped = line.strip() - if not stripped or stripped.startswith(('//', '#')): - continue - - # Look for section declarations (IDENT followed by {) - if '{' in stripped and not stripped.startswith('{'): - parts = stripped.split('{', 1) - potential_section = parts[0].strip() - - # Skip if contains spaces or '=' (not a valid section declaration) - if ' ' in potential_section or '=' in potential_section: - continue - if potential_section.startswith('}'): - continue - - # Skip if it's a known language keyword - from hrw4u.types import LanguageKeyword - - language_keywords = {kw.keyword for kw in LanguageKeyword} - if potential_section.lower() in language_keywords: - continue - - # Use cached validation - if potential_section and not cached_section_validation(potential_section): - start_pos = line.find(potential_section) - end_pos = start_pos + len(potential_section) - - diagnostics.append( - LSPDiagnostic( - range=DiagnosticRange( - start=LSPPosition(line=line_num, character=start_pos), - end=LSPPosition(line=line_num, character=end_pos)), - severity=1, - message=f"Invalid section name: '{potential_section}'. Valid sections are: {valid_sections_str}", - source="hrw4u-section-validator")) - - return diagnostics diff --git a/tools/hrw4u/src/procedures.py b/tools/hrw4u/src/procedures.py new file mode 100644 index 00000000000..29e5ddc122f --- /dev/null +++ b/tools/hrw4u/src/procedures.py @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Procedure path resolution for hrw4u. +""" + +from __future__ import annotations + +from pathlib import Path + + +def resolve_use_path(spec: str, search_paths: list[Path]) -> Path | None: + """Resolve 'foo::bar::Baz' to /foo/bar/Baz.hrw4u.""" + rel = Path(*spec.split('::')).with_suffix('.hrw4u') + return next((base / rel for base in search_paths if (base / rel).is_file()), None) diff --git a/tools/hrw4u/src/symbols.py b/tools/hrw4u/src/symbols.py index bb5210923f1..979c9141ac0 100644 --- a/tools/hrw4u/src/symbols.py +++ b/tools/hrw4u/src/symbols.py @@ -22,14 +22,15 @@ import hrw4u.types as types from hrw4u.states import SectionType from hrw4u.common import SystemDefaults +from hrw4u.debugging import Dbg from hrw4u.symbols_base import SymbolResolverBase from hrw4u.suggestions import SuggestionEngine class SymbolResolver(SymbolResolverBase): - def __init__(self, debug: bool = SystemDefaults.DEFAULT_DEBUG) -> None: - super().__init__(debug) + def __init__(self, debug: bool = SystemDefaults.DEFAULT_DEBUG, dbg: Dbg | None = None) -> None: + super().__init__(debug, dbg=dbg) self._symbols: dict[str, types.Symbol] = {} self._var_counter = {vt: 0 for vt in types.VarType} self._suggestion_engine = SuggestionEngine() diff --git a/tools/hrw4u/src/symbols_base.py b/tools/hrw4u/src/symbols_base.py index 1fb95cd60c0..bce15006fba 100644 --- a/tools/hrw4u/src/symbols_base.py +++ b/tools/hrw4u/src/symbols_base.py @@ -28,8 +28,8 @@ class SymbolResolverBase: - def __init__(self, debug: bool = SystemDefaults.DEFAULT_DEBUG) -> None: - self._dbg = Dbg(debug) + def __init__(self, debug: bool = SystemDefaults.DEFAULT_DEBUG, dbg: Dbg | None = None) -> None: + self._dbg = dbg if dbg is not None else Dbg(debug) # Clear caches when debug status changes to ensure consistency if hasattr(self, '_condition_cache'): self._condition_cache.cache_clear() diff --git a/tools/hrw4u/src/visitor.py b/tools/hrw4u/src/visitor.py index e62f2220784..bd8104d2111 100644 --- a/tools/hrw4u/src/visitor.py +++ b/tools/hrw4u/src/visitor.py @@ -18,22 +18,44 @@ from __future__ import annotations import re +from contextlib import contextmanager from dataclasses import dataclass from functools import lru_cache +from pathlib import Path +from typing import Any + +from antlr4 import InputStream, CommonTokenStream +from antlr4.error.ErrorStrategy import BailErrorStrategy from hrw4u.hrw4uVisitor import hrw4uVisitor from hrw4u.hrw4uParser import hrw4uParser +from hrw4u.hrw4uLexer import hrw4uLexer from hrw4u.symbols import SymbolResolver, SymbolResolutionError -from hrw4u.errors import hrw4u_error +from hrw4u.errors import hrw4u_error, Hrw4uSyntaxError, ThrowingErrorListener from hrw4u.states import CondState, SectionType from hrw4u.common import RegexPatterns, SystemDefaults from hrw4u.visitor_base import BaseHRWVisitor from hrw4u.validation import Validator +from hrw4u.procedures import resolve_use_path -# Cache regex validator at module level for efficiency _regex_validator = Validator.regex_pattern() +@dataclass(slots=True) +class ProcParam: + name: str + default_ctx: Any # value parse tree node, or None + + +@dataclass(slots=True) +class ProcSig: + qualified_name: str + params: list[ProcParam] + body_ctx: Any # block parse tree node + source_file: str + source_text: str # full text of source file (for flatten) + + @dataclass(slots=True) class QueuedItem: text: str @@ -43,24 +65,32 @@ class QueuedItem: class HRW4UVisitor(hrw4uVisitor, BaseHRWVisitor): _SUBSTITUTE_PATTERN = RegexPatterns.SUBSTITUTE_PATTERN + _PARAM_REF_PATTERN = re.compile(r'\$([a-zA-Z_][a-zA-Z0-9_-]*)') def __init__( self, filename: str = SystemDefaults.DEFAULT_FILENAME, debug: bool = SystemDefaults.DEFAULT_DEBUG, error_collector=None, - preserve_comments: bool = True) -> None: + preserve_comments: bool = True, + proc_search_paths: list[Path] | None = None) -> None: super().__init__(filename, debug, error_collector) self._cond_state = CondState() self._queued: QueuedItem | None = None self.preserve_comments = preserve_comments - self.symbol_resolver = SymbolResolver(debug) + self.symbol_resolver = SymbolResolver(debug, dbg=self._dbg) + + self._proc_registry: dict[str, ProcSig] = {} + self._proc_loaded: set[str] = set() + self._proc_bindings: dict[str, str] = {} + self._proc_call_stack: list[str] = [] + self._proc_search_paths: list[Path] = list(proc_search_paths) if proc_search_paths else [] + self._source_text: str = "" @lru_cache(maxsize=256) def _cached_symbol_resolution(self, symbol_text: str, section_name: str) -> tuple[str, bool]: - """Cache expensive symbol resolution operations.""" try: section = SectionType(section_name) return self.symbol_resolver.resolve_condition(symbol_text, section) @@ -69,7 +99,6 @@ def _cached_symbol_resolution(self, symbol_text: str, section_name: str) -> tupl @lru_cache(maxsize=128) def _cached_hook_mapping(self, section_name: str) -> str: - """Cache hook mapping lookups.""" return self.symbol_resolver.map_hook(section_name) def _make_condition(self, cond_text: str, last: bool = False, negate: bool = False) -> str: @@ -79,17 +108,14 @@ def _make_condition(self, cond_text: str, last: bool = False, negate: bool = Fal return f"cond {cond_text}" def _queue_condition(self, text: str) -> None: - self.debug_log(f"queue cond: {text} state={self._cond_state.to_list()}") + self.debug(f"queue cond: {text} state={self._cond_state.to_list()}") self._queued = QueuedItem(text=text, state=self._cond_state.copy(), indent=self.cond_indent) self._cond_state.reset() def _flush_condition(self) -> None: - """ - Flush any queued condition to output. - """ if self._queued: mods = self._queued.state.to_list() - self.debug_log(f"flush cond: {self._queued.text} state={mods} indent={self._queued.indent}") + self.debug(f"flush cond: {self._queued.text} state={mods} indent={self._queued.indent}") mod_suffix = self._queued.state.render_suffix() self.output.append(self.format_with_indent(f"{self._queued.text}{mod_suffix}", self._queued.indent)) self._queued = None @@ -102,9 +128,6 @@ def _parse_function_call(self, ctx) -> tuple[str, list[str]]: return func, args def _parse_function_args(self, arg_str: str) -> list[str]: - """ - Parse function arguments correctly handling quotes and nested parentheses. - """ if not arg_str.strip(): return [] @@ -148,23 +171,26 @@ def _parse_function_args(self, arg_str: str) -> list[str]: return args def _substitute_strings(self, s: str, ctx) -> str: - """Optimized string substitution using string builder.""" inner = s[1:-1] + if self._proc_bindings: + inner = self._PARAM_REF_PATTERN.sub(lambda m: self._proc_bindings.get(m.group(1), m.group(0)), inner) + def repl(m: re.Match) -> str: try: + if m.group("escaped"): + return m.group("escaped") if m.group("func"): func_name = m.group("func").strip() arg_str = m.group("args").strip() args = self._parse_function_args(arg_str) if arg_str else [] replacement = self.symbol_resolver.resolve_function(func_name, args, strip_quotes=False) - self.debug_log(f"substitute: {{{func_name}({arg_str})}} -> {replacement}") + self.debug(f"substitute: {{{func_name}({arg_str})}} -> {replacement}") return replacement if m.group("var"): var_name = m.group("var").strip() - # Use resolve_condition directly to properly validate section restrictions replacement, _ = self.symbol_resolver.resolve_condition(var_name, self.current_section) - self.debug_log(f"substitute: {{{var_name}}} -> {replacement}") + self.debug(f"substitute: {{{var_name}}} -> {replacement}") return replacement raise SymbolResolutionError(m.group(0), "Unrecognized substitution format") except Exception as e: @@ -183,9 +209,6 @@ def repl(m: re.Match) -> str: return f'"{substituted}"' def _resolve_identifier_with_validation(self, name: str) -> tuple[str, bool]: - """ - Resolve an identifier with proper validation for declared variables vs system fields. - """ if not name: raise SymbolResolutionError("identifier", "Missing or empty identifier text") @@ -194,7 +217,6 @@ def _resolve_identifier_with_validation(self, name: str) -> tuple[str, bool]: symbol, default_expr = self._cached_symbol_resolution(name, self.current_section.value) - # If resolution failed (symbol == name), we need to validate if symbol == name: if '.' not in name and ':' not in name: error = SymbolResolutionError( @@ -211,17 +233,445 @@ def _resolve_identifier_with_validation(self, name: str) -> tuple[str, bool]: return symbol, default_expr + def _get_value_text(self, val_ctx) -> str: + if val_ctx.paramRef(): + name = val_ctx.paramRef().IDENT().getText() + if name not in self._proc_bindings: + raise Hrw4uSyntaxError( + self.filename, val_ctx.start.line, val_ctx.start.column, f"'${name}' used outside procedure context", "") + return self._proc_bindings[name] + return val_ctx.getText() + + def _collect_proc_params(self, param_list_ctx) -> list[ProcParam]: + return [ProcParam(name=p.IDENT().getText(), default_ctx=p.value() if p.value() else None) for p in param_list_ctx.param()] + + def _load_proc_file(self, path: Path, load_stack: list[str], use_spec: str | None = None) -> None: + """Parse a procedure file and register its declarations.""" + abs_path = str(path.resolve()) + if abs_path in self._proc_loaded: + return + if abs_path in load_stack: + cycle = ' -> '.join([*load_stack, abs_path]) + raise Hrw4uSyntaxError(str(path), 1, 0, f"circular use dependency: {cycle}", "") + + # Derive expected namespace prefix from the use spec. + # 'Apple::Common' → 'Apple::', 'Apple::Simple::All' → 'Apple::Simple::' + expected_ns = None + if use_spec and '::' in use_spec: + expected_ns = use_spec[:use_spec.rindex('::') + 2] + + text = path.read_text(encoding='utf-8') + listener = ThrowingErrorListener(filename=str(path)) + + lexer = hrw4uLexer(InputStream(text)) + lexer.removeErrorListeners() + lexer.addErrorListener(listener) + + stream = CommonTokenStream(lexer) + parser = hrw4uParser(stream) + parser.removeErrorListeners() + parser.addErrorListener(listener) + parser.errorHandler = BailErrorStrategy() + tree = parser.program() + + new_stack = [*load_stack, abs_path] + found_proc = False + + for item in tree.programItem(): + if item.useDirective(): + spec = item.useDirective().QUALIFIED_IDENT().getText() + sub_path = resolve_use_path(spec, self._proc_search_paths) + if sub_path is None: + raise Hrw4uSyntaxError( + str(path), + item.useDirective().start.line, 0, f"use '{spec}': file not found in procedures path", "") + self._load_proc_file(sub_path, new_stack, use_spec=spec) + found_proc = True + elif item.procedureDecl(): + ctx = item.procedureDecl() + name = ctx.QUALIFIED_IDENT().getText() + if '::' not in name: + raise Hrw4uSyntaxError( + str(path), ctx.start.line, ctx.start.column, f"procedure name '{name}' must be qualified (e.g. 'ns::name')", + "") + if expected_ns and not name.startswith(expected_ns): + raise Hrw4uSyntaxError( + str(path), ctx.start.line, ctx.start.column, + f"procedure '{name}' does not match namespace '{expected_ns[:-2]}' " + f"(expected from 'use {use_spec}')", "") + if name in self._proc_registry: + existing = self._proc_registry[name] + raise Hrw4uSyntaxError( + str(path), ctx.start.line, 0, f"procedure '{name}' already declared in {existing.source_file}", "") + params = self._collect_proc_params(ctx.paramList()) if ctx.paramList() else [] + self._proc_registry[name] = ProcSig(name, params, ctx.block(), str(path), text) + found_proc = True + + if not found_proc: + raise Hrw4uSyntaxError(str(path), 1, 0, f"no 'procedure' declarations found in {path.name}", "") + + self._proc_loaded.add(abs_path) + + def _visit_block_items(self, block_ctx) -> None: + """Visit a block's items at the current indent level (no extra indent added).""" + for item in block_ctx.blockItem(): + if item.statement(): + self.visit(item.statement()) + elif item.conditional(): + self.emit_statement("if") + saved = self.stmt_indent, self.cond_indent + self.stmt_indent += 1 + self.cond_indent = self.stmt_indent + self.visit(item.conditional()) + self.stmt_indent, self.cond_indent = saved + self.emit_statement("endif") + elif item.commentLine() and self.preserve_comments: + self.visit(item.commentLine()) + + def _bind_proc_args(self, sig: ProcSig, call_ctx) -> dict[str, str]: + """Resolve call arguments against a procedure signature, returning bindings.""" + call_args: list[str] = [] + if call_ctx.argumentList(): + for val_ctx in call_ctx.argumentList().value(): + text = self._get_value_text(val_ctx) + if text.startswith('"') and text.endswith('"'): + text = self._substitute_strings(text, call_ctx)[1:-1] + call_args.append(text) + + required = sum(1 for p in sig.params if p.default_ctx is None) + if not (required <= len(call_args) <= len(sig.params)): + expected = f"{required}-{len(sig.params)}" if required < len(sig.params) else str(len(sig.params)) + raise Hrw4uSyntaxError( + self.filename, call_ctx.start.line, call_ctx.start.column, + f"procedure '{sig.qualified_name}': expected {expected} arg(s), got {len(call_args)}", "") + + bindings: dict[str, str] = {} + for i, param in enumerate(sig.params): + if i < len(call_args): + bindings[param.name] = call_args[i] + else: + default = self._get_value_text(param.default_ctx) + if default.startswith('"') and default.endswith('"'): + default = default[1:-1] + bindings[param.name] = default + + return bindings + + @contextmanager + def _proc_context(self, sig: ProcSig, bindings: dict[str, str]): + """Context manager that saves/restores procedure expansion state.""" + saved_bindings = self._proc_bindings + saved_stack = self._proc_call_stack + saved_filename = self.filename + + self._proc_bindings = bindings + self._proc_call_stack = [*saved_stack, sig.qualified_name] + self.filename = sig.source_file + try: + yield + finally: + self._proc_bindings = saved_bindings + self._proc_call_stack = saved_stack + self.filename = saved_filename -# -# Visitor Methods -# + def _expand_proc_as_section_body(self, block_ctx, hook: str, in_statement_block: bool) -> bool: + """Expand a procedure body using section-body semantics (hook re-emission). + + Returns the final in_statement_block state. + """ + items = block_ctx.blockItem() + is_first = not in_statement_block + + for idx, item in enumerate(items): + is_conditional = item.conditional() is not None + is_comment = item.commentLine() is not None + proc_info = self._get_proc_call_info(item) + + if is_comment: + if self.preserve_comments: + self.visit(item.commentLine()) + elif proc_info: + _, call_ctx = proc_info + in_statement_block = self._section_expand_proc_call(call_ctx, hook, in_statement_block, is_first and idx == 0) + elif is_conditional or not in_statement_block: + if not (is_first and idx == 0): + self._flush_condition() + self.output.append("") + + self._emit_section_header(hook, []) + + if is_conditional: + self.visit(item) + in_statement_block = False + else: + in_statement_block = True + with self.stmt_indented(): + self.visit(item.statement()) + else: + with self.stmt_indented(): + self.visit(item.statement()) + + return in_statement_block + + def _section_expand_proc_call(self, call_ctx, hook: str, in_statement_block: bool, is_first_item: bool) -> bool: + """Expand a procedure call within a section body context. Returns in_statement_block.""" + name = call_ctx.funcName.text + sig = self._proc_registry[name] + bindings = self._bind_proc_args(sig, call_ctx) + + with self._proc_context(sig, bindings): + return self._expand_proc_as_section_body(sig.body_ctx, hook, in_statement_block) + + def _get_proc_call_info(self, item) -> tuple[ProcSig, Any] | None: + """If item (sectionBody or blockItem) is a procedure call, return (sig, call_ctx).""" + stmt = item.statement() + if stmt and stmt.functionCall(): + func_name = stmt.functionCall().funcName.text + sig = self._proc_registry.get(func_name) + if sig: + return sig, stmt.functionCall() + return None + + def _expand_procedure_call(self, call_ctx) -> None: + """Expand a procedure call inline at the current indent level.""" + name = call_ctx.funcName.text + sig = self._proc_registry.get(name) + + if sig is None: + raise Hrw4uSyntaxError( + self.filename, call_ctx.start.line, call_ctx.start.column, f"unknown procedure '{name}': not loaded via 'use'", "") + + if name in self._proc_call_stack: + cycle = ' -> '.join([*self._proc_call_stack, name]) + raise Hrw4uSyntaxError( + self.filename, call_ctx.start.line, call_ctx.start.column, f"circular procedure call: {cycle}", "") + + bindings = self._bind_proc_args(sig, call_ctx) + + with self._proc_context(sig, bindings): + self._visit_block_items(sig.body_ctx) + + @staticmethod + def _get_source_text(ctx, source_text: str) -> str: + """Extract original source text for a parse tree node.""" + return source_text[ctx.start.start:ctx.stop.stop + 1] + + def _flatten_substitute_params(self, text: str, bindings: dict[str, str]) -> str: + """Replace $param references in source text with bound values.""" + if not bindings: + return text + return self._PARAM_REF_PATTERN.sub(lambda m: bindings.get(m.group(1), m.group(0)), text) + + def _flatten_reindent(self, text: str, indent: str, source_indent: str | None = None) -> list[str]: + """Re-indent text: replace source indentation with target indent, preserving relative nesting. + + If source_indent is None, it is auto-detected from the first non-empty line. + """ + lines: list[str] = [] + + for line in text.splitlines(): + stripped = line.strip() + if not stripped: + lines.append("") + continue + + if source_indent is None: + source_indent = line[:len(line) - len(line.lstrip())] + + if line.startswith(source_indent): + lines.append(f"{indent}{line[len(source_indent):]}") + else: + lines.append(f"{indent}{stripped}") + + return lines + + @staticmethod + def _source_indent_at(ctx, source_text: str) -> str: + """Detect the source indentation of a parse tree node from its position in source text.""" + start = ctx.start.start + line_start = source_text.rfind('\n', 0, start) + line_start = 0 if line_start == -1 else line_start + 1 + prefix = source_text[line_start:start] + return prefix if prefix.isspace() or not prefix else "" + + def _has_proc_calls(self, ctx) -> bool: + """Check if a block or conditional contains any procedure calls (recursively).""" + if hasattr(ctx, 'blockItem'): + for item in ctx.blockItem(): + if self._get_proc_call_info(item): + return True + if item.conditional() and self._has_proc_calls(item.conditional()): + return True + return False + + if self._has_proc_calls(ctx.ifStatement().block()): + return True + for elif_ctx in ctx.elifClause(): + if self._has_proc_calls(elif_ctx.block()): + return True + if ctx.elseClause() and self._has_proc_calls(ctx.elseClause().block()): + return True + return False + + def _flatten_items(self, items, indent: str, source_text: str, bindings: dict[str, str] | None = None) -> list[str]: + """Flatten a list of sectionBody or blockItem nodes, expanding procedure calls.""" + if bindings is None: + bindings = {} + lines: list[str] = [] + + for item in items: + if item.commentLine() is not None: + if self.preserve_comments: + comment_text = self._get_source_text(item.commentLine(), source_text) + lines.extend(self._flatten_reindent(comment_text, indent)) + continue + + proc_info = self._get_proc_call_info(item) + if proc_info: + sig, call_ctx = proc_info + nested_bindings = self._bind_proc_args(sig, call_ctx) + lines.extend(self._flatten_items(sig.body_ctx.blockItem(), indent, sig.source_text, nested_bindings)) + continue + + if item.conditional() and self._has_proc_calls(item.conditional()): + lines.extend(self._flatten_conditional(item.conditional(), indent, source_text, bindings)) + continue + + item_text = self._get_source_text(item, source_text) + item_text = self._flatten_substitute_params(item_text, bindings) + source_indent = self._source_indent_at(item, source_text) + lines.extend(self._flatten_reindent(item_text, indent, source_indent)) + + return lines + + def _flatten_conditional(self, cond_ctx, indent: str, source_text: str, bindings: dict[str, str]) -> list[str]: + """Flatten a conditional block, expanding proc calls within its branches.""" + lines: list[str] = [] + inner_indent = indent + " " + + if_ctx = cond_ctx.ifStatement() + cond_text = self._get_source_text(if_ctx.condition(), source_text) + cond_text = self._flatten_substitute_params(cond_text, bindings) + lines.append(f"{indent}if {cond_text.strip()} {{") + lines.extend(self._flatten_items(if_ctx.block().blockItem(), inner_indent, source_text, bindings)) + + for elif_ctx in cond_ctx.elifClause(): + elif_cond = self._get_source_text(elif_ctx.condition(), source_text) + elif_cond = self._flatten_substitute_params(elif_cond, bindings) + lines.append(f"{indent}}} elif {elif_cond.strip()} {{") + lines.extend(self._flatten_items(elif_ctx.block().blockItem(), inner_indent, source_text, bindings)) + + if cond_ctx.elseClause(): + lines.append(f"{indent}}} else {{") + lines.extend(self._flatten_items(cond_ctx.elseClause().block().blockItem(), inner_indent, source_text, bindings)) + + lines.append(f"{indent}}}") + return lines + + def flatten(self, ctx, source_text: str = "") -> list[str]: + """Flatten procedures: expand all procedure calls inline and output self-contained HRW4U.""" + if not source_text: + source_text = ctx.start.source[1].getText(0, ctx.start.source[1].size - 1) + self._source_text = source_text + indent = " " * 4 + + # Phase 1: Load all procedures (use directives + local procedure declarations) + for item in ctx.programItem(): + if item.useDirective(): + with self.trap(item.useDirective()): + self.visitUseDirective(item.useDirective()) + elif item.procedureDecl(): + with self.trap(item.procedureDecl()): + self.visitProcedureDecl(item.procedureDecl()) + + # Phase 2: Emit flattened output + output: list[str] = [] + program_items = ctx.programItem() + + for idx, item in enumerate(program_items): + if item.useDirective() or item.procedureDecl(): + continue + + if item.commentLine() and self.preserve_comments: + comment_text = item.commentLine().COMMENT().getText() + output.append(comment_text) + continue + + if item.section(): + section_ctx = item.section() + + if section_ctx.varSection(): + var_text = self._get_source_text(section_ctx.varSection(), self._source_text) + output.append(var_text) + continue + + section_name = section_ctx.name.text + output.append(f"{section_name} {{") + + body_lines = self._flatten_items(section_ctx.sectionBody(), indent, self._source_text) + output.extend(body_lines) + output.append("}") + + remaining = program_items[idx + 1:] + if any(r.section() for r in remaining): + output.append("") + + return output + + def visitUseDirective(self, ctx) -> None: + spec = ctx.QUALIFIED_IDENT().getText() + if not self._proc_search_paths: + raise Hrw4uSyntaxError( + self.filename, ctx.start.line, ctx.start.column, "use directive requires --procedures-path to be set", "") + path = resolve_use_path(spec, self._proc_search_paths) + if path is None: + raise Hrw4uSyntaxError( + self.filename, ctx.start.line, ctx.start.column, f"use '{spec}': file not found in procedures path", "") + self._load_proc_file(path, [], use_spec=spec) + + def visitProcedureDecl(self, ctx) -> None: + name = ctx.QUALIFIED_IDENT().getText() + if '::' not in name: + raise Hrw4uSyntaxError( + self.filename, ctx.start.line, ctx.start.column, f"procedure name '{name}' must be qualified (e.g. 'ns::name')", "") + if name in self._proc_registry: + existing = self._proc_registry[name] + raise Hrw4uSyntaxError( + self.filename, ctx.start.line, ctx.start.column, f"procedure '{name}' already declared in {existing.source_file}", + "") + params = self._collect_proc_params(ctx.paramList()) if ctx.paramList() else [] + self._proc_registry[name] = ProcSig(name, params, ctx.block(), self.filename, self._source_text) def visitProgram(self, ctx) -> list[str]: with self.debug_context("visitProgram"): + seen_sections = False program_items = ctx.programItem() for idx, item in enumerate(program_items): start_length = len(self.output) - if item.section(): + if item.useDirective(): + if seen_sections: + error = hrw4u_error( + self.filename, item.useDirective(), + ValueError("'use' directives must appear before any section blocks")) + if self.error_collector: + self.error_collector.add_error(error) + continue + raise error + with self.trap(item.useDirective()): + self.visitUseDirective(item.useDirective()) + elif item.procedureDecl(): + if seen_sections: + error = hrw4u_error( + self.filename, item.procedureDecl(), + ValueError("'procedure' declarations must appear before any section blocks")) + if self.error_collector: + self.error_collector.add_error(error) + continue + raise error + with self.trap(item.procedureDecl()): + self.visitProcedureDecl(item.procedureDecl()) + elif item.section(): + seen_sections = True self.visit(item.section()) if idx < len(program_items) - 1 and len(self.output) > start_length: next_items = program_items[idx + 1:] @@ -257,7 +707,7 @@ def _prepare_section(self, ctx): raise ValueError(f"Invalid section name: '{section_name}'. Valid sections: {', '.join(valid_sections)}") hook = self._cached_hook_mapping(section_name) - self.debug_log(f"`{section_name}' -> `{hook}'") + self.debug(f"`{section_name}' -> `{hook}'") return hook def _emit_section_header(self, hook, pending_comments): @@ -275,6 +725,7 @@ def _emit_section_body(self, section_bodies, hook): for idx, body in enumerate(section_bodies): is_conditional = body.conditional() is not None is_comment = body.commentLine() is not None + proc_info = self._get_proc_call_info(body) if is_comment: if self.preserve_comments: @@ -282,6 +733,14 @@ def _emit_section_body(self, section_bodies, hook): pending_leading_comments.append(body) else: self.visit(body) + elif proc_info: + if not first_hook_emitted: + first_hook_emitted = True + for comment in pending_leading_comments: + self.visit(comment) + pending_leading_comments = [] + _, call_ctx = proc_info + in_statement_block = self._section_expand_proc_call(call_ctx, hook, in_statement_block, idx == 0) elif is_conditional or not in_statement_block: if first_hook_emitted: self._flush_condition() @@ -305,7 +764,6 @@ def _emit_section_body(self, section_bodies, hook): with self.stmt_indented(): self.visit(body) - # Handle case where section has only comments if not first_hook_emitted and pending_leading_comments: self._emit_section_header(hook, pending_leading_comments) @@ -325,7 +783,7 @@ def visitCommentLine(self, ctx) -> None: return with self.debug_context("visitCommentLine"): comment_text = ctx.COMMENT().getText() - self.debug_log(f"preserving comment: {comment_text}") + self.debug(f"preserving comment: {comment_text}") self.output.append(comment_text) def visitStatement(self, ctx) -> None: @@ -338,6 +796,14 @@ def visitStatement(self, ctx) -> None: case _ if ctx.functionCall(): func, args = self._parse_function_call(ctx.functionCall()) + if func in self._proc_registry: + self._expand_procedure_call(ctx.functionCall()) + return + if '::' in func: + raise Hrw4uSyntaxError( + self.filename, + ctx.functionCall().start.line, + ctx.functionCall().start.column, f"unknown procedure '{func}': not loaded via 'use'", "") subst_args = [ self._substitute_strings(arg, ctx) if arg.startswith('"') and arg.endswith('"') else arg for arg in args ] @@ -349,7 +815,7 @@ def visitStatement(self, ctx) -> None: if ctx.lhs is None: raise SymbolResolutionError("assignment", "Missing left-hand side in assignment") lhs = ctx.lhs.text - rhs = ctx.value().getText() + rhs = self._get_value_text(ctx.value()) if rhs.startswith('"') and rhs.endswith('"'): rhs = self._substitute_strings(rhs, ctx) self._dbg(f"assignment: {lhs} = {rhs}") @@ -361,7 +827,7 @@ def visitStatement(self, ctx) -> None: if ctx.lhs is None: raise SymbolResolutionError("assignment", "Missing left-hand side in += assignment") lhs = ctx.lhs.text - rhs = ctx.value().getText() + rhs = self._get_value_text(ctx.value()) if rhs.startswith('"') and rhs.endswith('"'): rhs = self._substitute_strings(rhs, ctx) self._dbg(f"add assignment: {lhs} += {rhs}") @@ -447,20 +913,7 @@ def visitElifClause(self, ctx) -> None: def visitBlock(self, ctx) -> None: with self.debug_context("visitBlock"): with self.stmt_indented(): - for item in ctx.blockItem(): - if item.statement(): - self.visit(item.statement()) - elif item.conditional(): - # Nested conditional - emit if/endif operators with saved state - self.emit_statement("if") - saved_indents = self.stmt_indent, self.cond_indent - self.stmt_indent += 1 - self.cond_indent = self.stmt_indent - self.visit(item.conditional()) - self.stmt_indent, self.cond_indent = saved_indents - self.emit_statement("endif") - elif item.commentLine() and self.preserve_comments: - self.visit(item.commentLine()) + self._visit_block_items(ctx) def visitCondition(self, ctx) -> None: with self.debug_context("visitCondition"): @@ -490,7 +943,7 @@ def visitComparison(self, ctx, *, last: bool = False) -> None: match ctx: case _ if ctx.value(): - rhs = ctx.value().getText() + rhs = self._get_value_text(ctx.value()) if rhs.startswith('"') and rhs.endswith('"'): rhs = self._substitute_strings(rhs, ctx) match operator.symbol.type: @@ -506,16 +959,14 @@ def visitComparison(self, ctx, *, last: bool = False) -> None: except Exception as e: with self.trap(ctx.regex()): raise e - regex_expr = "/.*/'" # return "ERROR" is for error_collector case only + regex_expr = "/.*/'" cond_txt = f"{lhs} {regex_expr}" - # IP Ranges are a bit special, we keep the {} verbatim and no quotes allowed case _ if ctx.iprange(): cond_txt = f"{lhs} {ctx.iprange().getText()}" case _ if ctx.set_(): inner = ctx.set_().getText()[1:-1] - # We no longer strip the quotes here for sets, fixed in #12256 cond_txt = f"{lhs} ({inner})" case _: @@ -553,18 +1004,13 @@ def emit_condition(self, text: str, *, final: bool = False) -> None: self._queue_condition(text) def emit_separator(self) -> None: - """Emit a blank line separator.""" self.output.append("") def emit_statement(self, line: str) -> None: - """Override base class method to handle condition flushing.""" self._flush_condition() super().emit_statement(line) def _end_lhs_then_emit_rhs(self, set_and_or: bool, rhs_emitter) -> None: - """ - Helper for expression emission: update queued state, flush, then emit RHS. - """ if self._queued: self._queued.state.and_or = set_and_or if not set_and_or: @@ -575,9 +1021,9 @@ def _end_lhs_then_emit_rhs(self, set_and_or: bool, rhs_emitter) -> None: def emit_expression(self, ctx, *, nested: bool = False, last: bool = False, grouped: bool = False) -> None: with self.debug_context("emit_expression"): if ctx.OR(): - self.debug_log("`OR' detected") + self.debug("`OR' detected") if grouped: - self.debug_log("GROUP-START") + self.debug("GROUP-START") self.emit_condition("cond %{GROUP}", final=True) with self.cond_indented(): self.emit_expression(ctx.expression(), nested=False, last=False) @@ -592,7 +1038,7 @@ def emit_expression(self, ctx, *, nested: bool = False, last: bool = False, grou def emit_term(self, ctx, *, last: bool = False) -> None: with self.debug_context("emit_term"): if ctx.AND(): - self.debug_log("`AND' detected") + self.debug("`AND' detected") self.emit_term(ctx.term(), last=False) self._end_lhs_then_emit_rhs(False, lambda: self.emit_factor(ctx.factor(), last=last)) else: @@ -653,7 +1099,7 @@ def emit_factor(self, ctx, *, last: bool = False) -> None: cond_txt = symbol negate = self._cond_state.not_ - cond_txt = self._normalize_empty_string_condition(cond_txt, self._cond_state) + cond_txt = self._normalize_empty_string_condition(cond_txt) cond_txt = self._apply_with_modifiers(cond_txt, self._cond_state) self._cond_state.not_ = False diff --git a/tools/hrw4u/src/visitor_base.py b/tools/hrw4u/src/visitor_base.py index e7d1e88d2e0..531f98ab5f8 100644 --- a/tools/hrw4u/src/visitor_base.py +++ b/tools/hrw4u/src/visitor_base.py @@ -19,53 +19,15 @@ from contextlib import contextmanager from dataclasses import dataclass from typing import Any + from hrw4u.debugging import Dbg from hrw4u.states import SectionType from hrw4u.common import SystemDefaults from hrw4u.errors import hrw4u_error -class VisitorMixin: - - def format_with_indent(self, text: str, indent_level: int) -> str: - """Format text with proper indentation""" - return " " * (indent_level * SystemDefaults.INDENT_SPACES) + text - - -class ErrorHandler: - - @staticmethod - def handle_visitor_error(filename: str, ctx: object, exc: Exception, error_collector=None, return_value: str = "") -> str: - """Standard error handling for visitor methods.""" - from hrw4u.errors import hrw4u_error - - error = hrw4u_error(filename, ctx, exc) - if error_collector: - error_collector.add_error(error) - return return_value - else: - raise error - - @staticmethod - def handle_symbol_error(filename: str, ctx: object, symbol_name: str, exc: Exception, error_collector=None) -> str | None: - """Handle symbol resolution errors with context.""" - from hrw4u.errors import hrw4u_error, SymbolResolutionError - - if isinstance(exc, SymbolResolutionError): - error = hrw4u_error(filename, ctx, exc) - else: - error = hrw4u_error(filename, ctx, f"symbol error in '{symbol_name}': {exc}") - - if error_collector: - error_collector.add_error(error) - return f"ERROR({symbol_name})" - else: - raise error - - @dataclass(slots=True) class VisitorState: - """Encapsulates visitor state that is commonly tracked across visitor implementations.""" stmt_indent: int = 0 cond_indent: int = 0 in_if_block: bool = False @@ -74,7 +36,7 @@ class VisitorState: current_section: SectionType | None = None -class BaseHRWVisitor(VisitorMixin): +class BaseHRWVisitor: def __init__( self, @@ -90,121 +52,72 @@ def __init__( self._initialize_visitor() def _initialize_visitor(self) -> None: - """Hook for subclass-specific initialization. Override as needed.""" pass - # Error handling patterns - centralized and consistent - def handle_visitor_error(self, ctx: Any, exc: Exception, return_value: str = "") -> str: - """ - Standard error handling for visitor methods. - """ - return ErrorHandler.handle_visitor_error(self.filename, ctx, exc, self.error_collector, return_value) - - def handle_symbol_error(self, ctx: Any, symbol_name: str, exc: Exception) -> str | None: - """ - Handle symbol resolution errors with context. - """ - return ErrorHandler.handle_symbol_error(self.filename, ctx, symbol_name, exc, self.error_collector) - - def safe_visit_with_error_handling(self, method_name: str, ctx: Any, visit_func, *args, **kwargs): - """ - Generic wrapper for visitor methods with consistent error handling. - """ - try: - self._dbg.enter(method_name) - return visit_func(*args, **kwargs) - except Exception as exc: - return self.handle_visitor_error(ctx, exc) - finally: - self._dbg.exit(method_name) + def format_with_indent(self, text: str, indent_level: int) -> str: + return " " * (indent_level * SystemDefaults.INDENT_SPACES) + text - # State management - common patterns @property def current_section(self) -> SectionType | None: - """Get current section being processed.""" return self._state.current_section @current_section.setter def current_section(self, section: SectionType | None) -> None: - """Set current section being processed.""" self._state.current_section = section @property def stmt_indent(self) -> int: - """Get current statement indentation level.""" return self._state.stmt_indent @stmt_indent.setter def stmt_indent(self, level: int) -> None: - """Set current statement indentation level.""" self._state.stmt_indent = level @property def cond_indent(self) -> int: - """Get current condition indentation level.""" return self._state.cond_indent @cond_indent.setter def cond_indent(self, level: int) -> None: - """Set current condition indentation level.""" self._state.cond_indent = level @property - def _stmt_indent(self) -> int: - """Backward compatibility property - use stmt_indent instead.""" - return self._state.stmt_indent - - @_stmt_indent.setter - def _stmt_indent(self, level: int) -> None: - """Backward compatibility property - use stmt_indent instead.""" - self._state.stmt_indent = level - - @property - def _cond_indent(self) -> int: - """Backward compatibility property - use cond_indent instead.""" - return self._state.cond_indent - - @_cond_indent.setter - def _cond_indent(self, level: int) -> None: - """Backward compatibility property - use cond_indent instead.""" - self._state.cond_indent = level + def current_indent(self) -> int: + return self.stmt_indent def increment_stmt_indent(self) -> None: - """Increment statement indentation level.""" self._state.stmt_indent += 1 def decrement_stmt_indent(self) -> None: - """Decrement statement indentation level (with bounds checking).""" self._state.stmt_indent = max(0, self._state.stmt_indent - 1) def increment_cond_indent(self) -> None: - """Increment condition indentation level.""" self._state.cond_indent += 1 def decrement_cond_indent(self) -> None: - """Decrement condition indentation level (with bounds checking).""" self._state.cond_indent = max(0, self._state.cond_indent - 1) - # Output management patterns def emit_line(self, text: str, indent_level: int | None = None) -> None: - """ - Emit a line of output with proper indentation. - """ if indent_level is None: indent_level = self._state.stmt_indent - formatted_line = self.format_with_indent(text, indent_level) - self.output.append(formatted_line) + self.output.append(self.format_with_indent(text, indent_level)) def emit_statement(self, statement: str) -> None: - """Emit a statement with current statement indentation.""" self.emit_line(statement, self._state.stmt_indent) def emit_condition(self, condition: str) -> None: - """Emit a condition with current condition indentation.""" self.emit_line(condition, self._state.cond_indent) + def emit(self, text: str) -> None: + self.output.append(self.format_with_indent(text, self.current_indent)) + + def increase_indent(self) -> None: + self.increment_stmt_indent() + + def decrease_indent(self) -> None: + self.decrement_stmt_indent() + def debug_enter(self, method_name: str, *args: Any) -> None: - """Standard debug entry pattern with argument logging.""" if args: arg_strs = [str(arg) for arg in args] self._dbg.enter(f"{method_name}: {', '.join(arg_strs)}") @@ -212,44 +125,19 @@ def debug_enter(self, method_name: str, *args: Any) -> None: self._dbg.enter(method_name) def debug_exit(self, method_name: str, result: Any = None) -> None: - """Standard debug exit pattern with result logging.""" if result is not None: self._dbg.exit(f"{method_name} => {result}") else: self._dbg.exit(method_name) - def debug_log(self, message: str) -> None: - """Standard debug message logging.""" + def debug(self, message: str) -> None: self._dbg(message) @property def is_debug(self) -> bool: - """Check if debug mode is enabled.""" return self._dbg.enabled - # Context managers for common patterns - def indent_context(self, increment: int = 1): - """Context manager for temporary indentation changes.""" - - class IndentContext: - - def __init__(self, visitor: BaseHRWVisitor, inc: int): - self.visitor = visitor - self.increment = inc - - def __enter__(self): - for _ in range(self.increment): - self.visitor.increment_stmt_indent() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - for _ in range(self.increment): - self.visitor.decrement_stmt_indent() - - return IndentContext(self, increment) - def debug_context(self, method_name: str, *args: Any): - """Context manager for debug entry/exit around operations.""" class DebugContext: @@ -271,9 +159,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): return DebugContext(self, method_name, args) def trap(self, ctx, *, note: str | None = None): - """ - Context manager for consistent error handling across visitor methods. - """ class _Trap: @@ -298,121 +183,58 @@ def __exit__(_, exc_type, exc, tb): @contextmanager def stmt_indented(self): - """Context manager for statement indentation - ensures proper cleanup on exceptions.""" - self._stmt_indent += 1 + self.stmt_indent += 1 try: yield finally: - self._stmt_indent -= 1 + self.stmt_indent -= 1 @contextmanager def cond_indented(self): - """Context manager for condition indentation - ensures proper cleanup on exceptions.""" - self._cond_indent += 1 + self.cond_indent += 1 try: yield finally: - self._cond_indent -= 1 + self.cond_indent -= 1 def get_final_output(self) -> list[str]: - """ - Get the final output after processing. - """ self._finalize_output() return self.output def _finalize_output(self) -> None: - """ - Finalize output processing. Override in subclasses for specific behavior. - """ pass - # Utility methods for common visitor operations - def should_continue_on_error(self) -> bool: - """Determine if processing should continue when errors occur.""" - return self.error_collector is not None - - def get_error_summary(self) -> str | None: - """Get error summary if errors were collected.""" - if self.error_collector and self.error_collector.has_errors(): - return self.error_collector.get_error_summary() - return None - - def reset_state(self) -> None: - """Reset visitor state for reuse.""" - self.output.clear() - self._state = VisitorState() - - def debug(self, message: str) -> None: - """Alias for debug_log for backward compatibility.""" - self.debug_log(message) - - def emit(self, text: str) -> None: - """Emit a line of output.""" - self.output.append(self.format_with_indent(text, self.current_indent)) - - def increase_indent(self) -> None: - """Increase indentation level.""" - self.increment_stmt_indent() - - def decrease_indent(self) -> None: - """Decrease indentation level.""" - self.decrement_stmt_indent() - - @property - def current_indent(self) -> int: - """Get current indentation level.""" - return self.stmt_indent - def handle_error(self, exc: Exception) -> None: - """Handle error without context.""" if self.error_collector: self.error_collector.add_error(exc) else: raise exc def _apply_with_modifiers(self, expr: str, state) -> str: - """Apply with modifiers to expression - shared logic for both visitors.""" if hasattr(state, 'to_with_modifiers'): with_mods = state.to_with_modifiers() return f"{expr} with {','.join(with_mods)}" if with_mods else expr return expr - def _normalize_empty_string_condition(self, term: str, state) -> str: - """Normalize empty string conditions - shared logic for both visitors.""" - if hasattr(state, 'not_') and state.not_: - # Apply negation first, then check for empty string patterns - if term.endswith(' != ""'): - return term.replace(' != ""', '') - elif term.endswith(' == ""'): - return f"!{term.replace(' == \"\"', '')}" - else: - return term - else: - # No negation, handle empty string patterns normally - if term.endswith(' != ""'): - return term.replace(' != ""', '') - elif term.endswith(' == ""'): - return f"!{term.replace(' == \"\"', '')}" - else: - return term + def _normalize_empty_string_condition(self, term: str) -> str: + if term.endswith(' != ""'): + return term.replace(' != ""', '') + elif term.endswith(' == ""'): + return f"!{term.replace(' == \"\"', '')}" + return term def _build_condition_connector(self, state, is_last_term: bool = False) -> str: - """Build connector for condition terms - shared logic.""" if hasattr(state, 'and_or') and state.and_or and not is_last_term: return "||" - else: - return "&&" + return "&&" def _reconstruct_redirect_args(self, args: list[str]) -> list[str]: - """Reconstruct set-redirect arguments by merging URL parts - shared logic for both visitors.""" if len(args) <= 1: return args url = "".join(a[1:-1] if a.startswith('"') and a.endswith('"') else a for a in args[1:]) return [args[0], url] def _parse_op_tails(self, node, ctx=None) -> tuple[list[str], object, object]: - """Parse operation tails - shared logic for modifier and argument capture.""" from hrw4u.states import CondState, OperatorState, ModifierType args: list[str] = [] diff --git a/tools/hrw4u/tests/data/procedures/basic-call.ast.txt b/tools/hrw4u/tests/data/procedures/basic-call.ast.txt new file mode 100644 index 00000000000..c692933e977 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/basic-call.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::add-debug-header)) (programItem (section REMAP { (sectionBody (statement (functionCall test::add-debug-header ( (argumentList (value "my-tag")) )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/basic-call.input.txt b/tools/hrw4u/tests/data/procedures/basic-call.input.txt new file mode 100644 index 00000000000..902a2b6aaca --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/basic-call.input.txt @@ -0,0 +1,5 @@ +use test::add-debug-header + +REMAP { + test::add-debug-header("my-tag"); +} diff --git a/tools/hrw4u/tests/data/procedures/basic-call.output.txt b/tools/hrw4u/tests/data/procedures/basic-call.output.txt new file mode 100644 index 00000000000..513702e9f96 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/basic-call.output.txt @@ -0,0 +1,2 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Debug "my-tag" diff --git a/tools/hrw4u/tests/data/procedures/circular-use.fail.error.txt b/tools/hrw4u/tests/data/procedures/circular-use.fail.error.txt new file mode 100644 index 00000000000..5e5ab489ae0 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/circular-use.fail.error.txt @@ -0,0 +1 @@ +circular use dependency \ No newline at end of file diff --git a/tools/hrw4u/tests/data/procedures/circular-use.fail.input.txt b/tools/hrw4u/tests/data/procedures/circular-use.fail.input.txt new file mode 100644 index 00000000000..1c9be4a5660 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/circular-use.fail.input.txt @@ -0,0 +1,5 @@ +use circular::A + +REMAP { + circular::A(); +} diff --git a/tools/hrw4u/tests/data/procedures/default-param.ast.txt b/tools/hrw4u/tests/data/procedures/default-param.ast.txt new file mode 100644 index 00000000000..3e1b9fab9c6 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/default-param.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::set-cache)) (programItem (section REMAP { (sectionBody (statement (functionCall test::set-cache ( )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/default-param.input.txt b/tools/hrw4u/tests/data/procedures/default-param.input.txt new file mode 100644 index 00000000000..36c64aaa6b7 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/default-param.input.txt @@ -0,0 +1,5 @@ +use test::set-cache + +REMAP { + test::set-cache(); +} diff --git a/tools/hrw4u/tests/data/procedures/default-param.output.txt b/tools/hrw4u/tests/data/procedures/default-param.output.txt new file mode 100644 index 00000000000..6b49d21fa3a --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/default-param.output.txt @@ -0,0 +1,2 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Cache-TTL 300 diff --git a/tools/hrw4u/tests/data/procedures/duplicate-proc.fail.error.txt b/tools/hrw4u/tests/data/procedures/duplicate-proc.fail.error.txt new file mode 100644 index 00000000000..2ca96b61e0a --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/duplicate-proc.fail.error.txt @@ -0,0 +1 @@ +procedure 'local::dup' already declared diff --git a/tools/hrw4u/tests/data/procedures/duplicate-proc.fail.input.txt b/tools/hrw4u/tests/data/procedures/duplicate-proc.fail.input.txt new file mode 100644 index 00000000000..67150cdadfc --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/duplicate-proc.fail.input.txt @@ -0,0 +1,11 @@ +procedure local::dup() { + inbound.req.X-A = "1"; +} + +procedure local::dup() { + inbound.req.X-B = "2"; +} + +REMAP { + local::dup(); +} diff --git a/tools/hrw4u/tests/data/procedures/elif-in-proc.ast.txt b/tools/hrw4u/tests/data/procedures/elif-in-proc.ast.txt new file mode 100644 index 00000000000..12ddaf77a85 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/elif-in-proc.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::classify-request)) (programItem (section REMAP { (sectionBody (statement (functionCall test::classify-request ( (argumentList (value "remap-classifier")) )) ;)) (sectionBody (conditional (ifStatement if (condition (expression (term (factor (comparison (comparable inbound.req.X-Special) == (value "yes")))))) (block { (blockItem (statement (functionCall test::classify-request ( (argumentList (value "special")) )) ;)) })))) })) ) diff --git a/tools/hrw4u/tests/data/procedures/elif-in-proc.flatten.txt b/tools/hrw4u/tests/data/procedures/elif-in-proc.flatten.txt new file mode 100644 index 00000000000..99b9b3e19c4 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/elif-in-proc.flatten.txt @@ -0,0 +1,20 @@ +REMAP { + if inbound.req.X-Priority == "high" { + inbound.req.X-Class = "urgent"; + } elif inbound.req.X-Priority == "medium" { + inbound.req.X-Class = "normal"; + } else { + inbound.req.X-Class = "low"; + } + inbound.req.X-Classified-By = "remap-classifier"; + if inbound.req.X-Special == "yes" { + if inbound.req.X-Priority == "high" { + inbound.req.X-Class = "urgent"; + } elif inbound.req.X-Priority == "medium" { + inbound.req.X-Class = "normal"; + } else { + inbound.req.X-Class = "low"; + } + inbound.req.X-Classified-By = "special"; + } +} diff --git a/tools/hrw4u/tests/data/procedures/elif-in-proc.input.txt b/tools/hrw4u/tests/data/procedures/elif-in-proc.input.txt new file mode 100644 index 00000000000..5ff8c71a4c7 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/elif-in-proc.input.txt @@ -0,0 +1,8 @@ +use test::classify-request + +REMAP { + test::classify-request("remap-classifier"); + if inbound.req.X-Special == "yes" { + test::classify-request("special"); + } +} diff --git a/tools/hrw4u/tests/data/procedures/elif-in-proc.output.txt b/tools/hrw4u/tests/data/procedures/elif-in-proc.output.txt new file mode 100644 index 00000000000..43b2f4c0b45 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/elif-in-proc.output.txt @@ -0,0 +1,24 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] +cond %{CLIENT-HEADER:X-Priority} ="high" + set-header X-Class "urgent" +elif + cond %{CLIENT-HEADER:X-Priority} ="medium" + set-header X-Class "normal" +else + set-header X-Class "low" + +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Classified-By "remap-classifier" + +cond %{REMAP_PSEUDO_HOOK} [AND] +cond %{CLIENT-HEADER:X-Special} ="yes" + if + cond %{CLIENT-HEADER:X-Priority} ="high" + set-header X-Class "urgent" + elif + cond %{CLIENT-HEADER:X-Priority} ="medium" + set-header X-Class "normal" + else + set-header X-Class "low" + endif + set-header X-Classified-By "special" diff --git a/tools/hrw4u/tests/data/procedures/in-conditional.ast.txt b/tools/hrw4u/tests/data/procedures/in-conditional.ast.txt new file mode 100644 index 00000000000..bcbc68f08e7 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/in-conditional.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::add-debug-header)) (programItem (section REMAP { (sectionBody (conditional (ifStatement if (condition (expression (term (factor (comparison (comparable inbound.req.X-Test) == (value "yes")))))) (block { (blockItem (statement (functionCall test::add-debug-header ( (argumentList (value "matched")) )) ;)) })))) })) ) diff --git a/tools/hrw4u/tests/data/procedures/in-conditional.flatten.txt b/tools/hrw4u/tests/data/procedures/in-conditional.flatten.txt new file mode 100644 index 00000000000..67501794a88 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/in-conditional.flatten.txt @@ -0,0 +1,5 @@ +REMAP { + if inbound.req.X-Test == "yes" { + inbound.req.X-Debug = "matched"; + } +} diff --git a/tools/hrw4u/tests/data/procedures/in-conditional.input.txt b/tools/hrw4u/tests/data/procedures/in-conditional.input.txt new file mode 100644 index 00000000000..fca99880884 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/in-conditional.input.txt @@ -0,0 +1,7 @@ +use test::add-debug-header + +REMAP { + if inbound.req.X-Test == "yes" { + test::add-debug-header("matched"); + } +} diff --git a/tools/hrw4u/tests/data/procedures/in-conditional.output.txt b/tools/hrw4u/tests/data/procedures/in-conditional.output.txt new file mode 100644 index 00000000000..ea5cf1085bd --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/in-conditional.output.txt @@ -0,0 +1,3 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] +cond %{CLIENT-HEADER:X-Test} ="yes" + set-header X-Debug "matched" diff --git a/tools/hrw4u/tests/data/procedures/local-and-use.ast.txt b/tools/hrw4u/tests/data/procedures/local-and-use.ast.txt new file mode 100644 index 00000000000..3ffe926fdad --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-and-use.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::add-debug-header)) (programItem (procedureDecl procedure local::mine ( ) (block { (blockItem (statement inbound.req.X-Mine = (value "local") ;)) }))) (programItem (section REMAP { (sectionBody (statement (functionCall test::add-debug-header ( (argumentList (value "ext")) )) ;)) (sectionBody (statement (functionCall local::mine ( )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/local-and-use.input.txt b/tools/hrw4u/tests/data/procedures/local-and-use.input.txt new file mode 100644 index 00000000000..909d77c5c28 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-and-use.input.txt @@ -0,0 +1,10 @@ +use test::add-debug-header + +procedure local::mine() { + inbound.req.X-Mine = "local"; +} + +REMAP { + test::add-debug-header("ext"); + local::mine(); +} diff --git a/tools/hrw4u/tests/data/procedures/local-and-use.output.txt b/tools/hrw4u/tests/data/procedures/local-and-use.output.txt new file mode 100644 index 00000000000..a92a616e568 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-and-use.output.txt @@ -0,0 +1,3 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Debug "ext" + set-header X-Mine "local" diff --git a/tools/hrw4u/tests/data/procedures/local-mixed-body.ast.txt b/tools/hrw4u/tests/data/procedures/local-mixed-body.ast.txt new file mode 100644 index 00000000000..8014e6906df --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-mixed-body.ast.txt @@ -0,0 +1 @@ +(program (programItem (procedureDecl procedure local::mixed ( ) (block { (blockItem (conditional (ifStatement if (condition (expression (term (factor (comparison (comparable inbound.req.X-Env) == (value "prod")))))) (block { (blockItem (statement inbound.req.X-Prod = (value "yes") ;)) })))) (blockItem (statement inbound.req.X-Always = (value "set") ;)) }))) (programItem (section REMAP { (sectionBody (statement (functionCall local::mixed ( )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/local-mixed-body.flatten.txt b/tools/hrw4u/tests/data/procedures/local-mixed-body.flatten.txt new file mode 100644 index 00000000000..71340b5e654 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-mixed-body.flatten.txt @@ -0,0 +1,6 @@ +REMAP { + if inbound.req.X-Env == "prod" { + inbound.req.X-Prod = "yes"; + } + inbound.req.X-Always = "set"; +} diff --git a/tools/hrw4u/tests/data/procedures/local-mixed-body.input.txt b/tools/hrw4u/tests/data/procedures/local-mixed-body.input.txt new file mode 100644 index 00000000000..67631621e53 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-mixed-body.input.txt @@ -0,0 +1,11 @@ +procedure local::mixed() { + if inbound.req.X-Env == "prod" { + inbound.req.X-Prod = "yes"; + } + + inbound.req.X-Always = "set"; +} + +REMAP { + local::mixed(); +} diff --git a/tools/hrw4u/tests/data/procedures/local-mixed-body.output.txt b/tools/hrw4u/tests/data/procedures/local-mixed-body.output.txt new file mode 100644 index 00000000000..0937c004f56 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-mixed-body.output.txt @@ -0,0 +1,6 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] +cond %{CLIENT-HEADER:X-Env} ="prod" + set-header X-Prod "yes" + +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Always "set" diff --git a/tools/hrw4u/tests/data/procedures/local-multi-section.ast.txt b/tools/hrw4u/tests/data/procedures/local-multi-section.ast.txt new file mode 100644 index 00000000000..98d502bca36 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-multi-section.ast.txt @@ -0,0 +1 @@ +(program (programItem (procedureDecl procedure local::stamp ( ) (block { (blockItem (statement inbound.req.X-Stamp = (value "yes") ;)) }))) (programItem (section REMAP { (sectionBody (statement (functionCall local::stamp ( )) ;)) })) (programItem (section SEND_REQUEST { (sectionBody (statement (functionCall local::stamp ( )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/local-multi-section.input.txt b/tools/hrw4u/tests/data/procedures/local-multi-section.input.txt new file mode 100644 index 00000000000..897bf86c728 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-multi-section.input.txt @@ -0,0 +1,11 @@ +procedure local::stamp() { + inbound.req.X-Stamp = "yes"; +} + +REMAP { + local::stamp(); +} + +SEND_REQUEST { + local::stamp(); +} diff --git a/tools/hrw4u/tests/data/procedures/local-multi-section.output.txt b/tools/hrw4u/tests/data/procedures/local-multi-section.output.txt new file mode 100644 index 00000000000..7ee1107f9a5 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-multi-section.output.txt @@ -0,0 +1,5 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Stamp "yes" + +cond %{SEND_REQUEST_HDR_HOOK} [AND] + set-header X-Stamp "yes" diff --git a/tools/hrw4u/tests/data/procedures/local-proc.ast.txt b/tools/hrw4u/tests/data/procedures/local-proc.ast.txt new file mode 100644 index 00000000000..044fbc75658 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-proc.ast.txt @@ -0,0 +1 @@ +(program (programItem (procedureDecl procedure local::stamp ( ) (block { (blockItem (statement inbound.req.X-Stamp = (value "tagged") ;)) }))) (programItem (section REMAP { (sectionBody (statement (functionCall local::stamp ( )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/local-proc.input.txt b/tools/hrw4u/tests/data/procedures/local-proc.input.txt new file mode 100644 index 00000000000..9465bcbecd0 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-proc.input.txt @@ -0,0 +1,7 @@ +procedure local::stamp() { + inbound.req.X-Stamp = "tagged"; +} + +REMAP { + local::stamp(); +} diff --git a/tools/hrw4u/tests/data/procedures/local-proc.output.txt b/tools/hrw4u/tests/data/procedures/local-proc.output.txt new file mode 100644 index 00000000000..b23cdab0b7f --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-proc.output.txt @@ -0,0 +1,2 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Stamp "tagged" diff --git a/tools/hrw4u/tests/data/procedures/local-with-params.ast.txt b/tools/hrw4u/tests/data/procedures/local-with-params.ast.txt new file mode 100644 index 00000000000..7da878775f2 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-with-params.ast.txt @@ -0,0 +1 @@ +(program (programItem (procedureDecl procedure local::tag ( (paramList (param $ name)) ) (block { (blockItem (statement inbound.req.X-Tag = (value "$name") ;)) }))) (programItem (section REMAP { (sectionBody (statement (functionCall local::tag ( (argumentList (value "hello")) )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/local-with-params.input.txt b/tools/hrw4u/tests/data/procedures/local-with-params.input.txt new file mode 100644 index 00000000000..909e004415e --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-with-params.input.txt @@ -0,0 +1,7 @@ +procedure local::tag($name) { + inbound.req.X-Tag = "$name"; +} + +REMAP { + local::tag("hello"); +} diff --git a/tools/hrw4u/tests/data/procedures/local-with-params.output.txt b/tools/hrw4u/tests/data/procedures/local-with-params.output.txt new file mode 100644 index 00000000000..7c52b2212a3 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/local-with-params.output.txt @@ -0,0 +1,2 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Tag "hello" diff --git a/tools/hrw4u/tests/data/procedures/mixed-body.ast.txt b/tools/hrw4u/tests/data/procedures/mixed-body.ast.txt new file mode 100644 index 00000000000..b53b86894f3 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/mixed-body.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::mixed-body)) (programItem (section REMAP { (sectionBody (statement (functionCall test::mixed-body ( )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/mixed-body.flatten.txt b/tools/hrw4u/tests/data/procedures/mixed-body.flatten.txt new file mode 100644 index 00000000000..4101dc98f7b --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/mixed-body.flatten.txt @@ -0,0 +1,6 @@ +REMAP { + if inbound.req.X-Foo == "bar" { + inbound.req.X-Bar = "yes"; + } + inbound.req.X-Baloo = "hello"; +} diff --git a/tools/hrw4u/tests/data/procedures/mixed-body.input.txt b/tools/hrw4u/tests/data/procedures/mixed-body.input.txt new file mode 100644 index 00000000000..ec727e856e8 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/mixed-body.input.txt @@ -0,0 +1,5 @@ +use test::mixed-body + +REMAP { + test::mixed-body(); +} diff --git a/tools/hrw4u/tests/data/procedures/mixed-body.output.txt b/tools/hrw4u/tests/data/procedures/mixed-body.output.txt new file mode 100644 index 00000000000..07b26899a90 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/mixed-body.output.txt @@ -0,0 +1,6 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] +cond %{CLIENT-HEADER:X-Foo} ="bar" + set-header X-Bar "yes" + +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Baloo "hello" diff --git a/tools/hrw4u/tests/data/procedures/multi-proc.ast.txt b/tools/hrw4u/tests/data/procedures/multi-proc.ast.txt new file mode 100644 index 00000000000..fcfa46597cd --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-proc.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::TagAndOrigin)) (programItem (section REMAP { (sectionBody (statement (functionCall test::add-tag ( (argumentList (value "v2")) )) ;)) (sectionBody (statement (functionCall test::add-origin ( (argumentList (value "origin.example.com")) )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/multi-proc.input.txt b/tools/hrw4u/tests/data/procedures/multi-proc.input.txt new file mode 100644 index 00000000000..3b479e8e1b6 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-proc.input.txt @@ -0,0 +1,6 @@ +use test::TagAndOrigin + +REMAP { + test::add-tag("v2"); + test::add-origin("origin.example.com"); +} diff --git a/tools/hrw4u/tests/data/procedures/multi-proc.output.txt b/tools/hrw4u/tests/data/procedures/multi-proc.output.txt new file mode 100644 index 00000000000..92da98f4692 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-proc.output.txt @@ -0,0 +1,3 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Tag "v2" + set-header X-Origin "origin.example.com" diff --git a/tools/hrw4u/tests/data/procedures/multi-section-mixed.ast.txt b/tools/hrw4u/tests/data/procedures/multi-section-mixed.ast.txt new file mode 100644 index 00000000000..435b7e42234 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-section-mixed.ast.txt @@ -0,0 +1 @@ +(program (programItem (procedureDecl procedure local::guarded-tag ( (paramList (param $ tag)) ) (block { (blockItem (conditional (ifStatement if (condition (expression (term (factor (comparison (comparable inbound.req.X-Env) == (value "prod")))))) (block { (blockItem (statement inbound.req.X-Prod = (value "$tag") ;)) })))) (blockItem (statement inbound.req.X-Tag = (value "$tag") ;)) }))) (programItem (section REMAP { (sectionBody (statement (functionCall local::guarded-tag ( (argumentList (value "remap-val")) )) ;)) })) (programItem (section SEND_RESPONSE { (sectionBody (statement (functionCall local::guarded-tag ( (argumentList (value "resp-val")) )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/multi-section-mixed.flatten.txt b/tools/hrw4u/tests/data/procedures/multi-section-mixed.flatten.txt new file mode 100644 index 00000000000..dc5fdfdb30d --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-section-mixed.flatten.txt @@ -0,0 +1,13 @@ +REMAP { + if inbound.req.X-Env == "prod" { + inbound.req.X-Prod = "remap-val"; + } + inbound.req.X-Tag = "remap-val"; +} + +SEND_RESPONSE { + if inbound.req.X-Env == "prod" { + inbound.req.X-Prod = "resp-val"; + } + inbound.req.X-Tag = "resp-val"; +} diff --git a/tools/hrw4u/tests/data/procedures/multi-section-mixed.input.txt b/tools/hrw4u/tests/data/procedures/multi-section-mixed.input.txt new file mode 100644 index 00000000000..5aa9926964a --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-section-mixed.input.txt @@ -0,0 +1,15 @@ +procedure local::guarded-tag($tag) { + if inbound.req.X-Env == "prod" { + inbound.req.X-Prod = "$tag"; + } + + inbound.req.X-Tag = "$tag"; +} + +REMAP { + local::guarded-tag("remap-val"); +} + +SEND_RESPONSE { + local::guarded-tag("resp-val"); +} diff --git a/tools/hrw4u/tests/data/procedures/multi-section-mixed.output.txt b/tools/hrw4u/tests/data/procedures/multi-section-mixed.output.txt new file mode 100644 index 00000000000..84bc6c8f219 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-section-mixed.output.txt @@ -0,0 +1,13 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] +cond %{CLIENT-HEADER:X-Env} ="prod" + set-header X-Prod "remap-val" + +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Tag "remap-val" + +cond %{SEND_RESPONSE_HDR_HOOK} [AND] +cond %{CLIENT-HEADER:X-Env} ="prod" + set-header X-Prod "resp-val" + +cond %{SEND_RESPONSE_HDR_HOOK} [AND] + set-header X-Tag "resp-val" diff --git a/tools/hrw4u/tests/data/procedures/multi-use.ast.txt b/tools/hrw4u/tests/data/procedures/multi-use.ast.txt new file mode 100644 index 00000000000..d69eadd5fe4 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-use.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::add-debug-header)) (programItem (useDirective use test::set-origin)) (programItem (section REMAP { (sectionBody (statement (functionCall test::add-debug-header ( (argumentList (value "foo")) )) ;)) (sectionBody (statement (functionCall test::set-origin ( (argumentList (value "bar.com")) )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/multi-use.input.txt b/tools/hrw4u/tests/data/procedures/multi-use.input.txt new file mode 100644 index 00000000000..57613308485 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-use.input.txt @@ -0,0 +1,7 @@ +use test::add-debug-header +use test::set-origin + +REMAP { + test::add-debug-header("foo"); + test::set-origin("bar.com"); +} diff --git a/tools/hrw4u/tests/data/procedures/multi-use.output.txt b/tools/hrw4u/tests/data/procedures/multi-use.output.txt new file mode 100644 index 00000000000..fe94333fbfe --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/multi-use.output.txt @@ -0,0 +1,3 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Debug "foo" + set-header X-Origin "bar.com" diff --git a/tools/hrw4u/tests/data/procedures/override-param.ast.txt b/tools/hrw4u/tests/data/procedures/override-param.ast.txt new file mode 100644 index 00000000000..0a45cd4aa40 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/override-param.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::set-cache)) (programItem (section REMAP { (sectionBody (statement (functionCall test::set-cache ( (argumentList (value 600)) )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/override-param.input.txt b/tools/hrw4u/tests/data/procedures/override-param.input.txt new file mode 100644 index 00000000000..ce03df3f238 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/override-param.input.txt @@ -0,0 +1,5 @@ +use test::set-cache + +REMAP { + test::set-cache(600); +} diff --git a/tools/hrw4u/tests/data/procedures/override-param.output.txt b/tools/hrw4u/tests/data/procedures/override-param.output.txt new file mode 100644 index 00000000000..32281b60b92 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/override-param.output.txt @@ -0,0 +1,2 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Cache-TTL 600 diff --git a/tools/hrw4u/tests/data/procedures/proc-after-section.fail.error.txt b/tools/hrw4u/tests/data/procedures/proc-after-section.fail.error.txt new file mode 100644 index 00000000000..0f0f2dcf7d4 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/proc-after-section.fail.error.txt @@ -0,0 +1 @@ +'procedure' declarations must appear before any section blocks diff --git a/tools/hrw4u/tests/data/procedures/proc-after-section.fail.input.txt b/tools/hrw4u/tests/data/procedures/proc-after-section.fail.input.txt new file mode 100644 index 00000000000..d488ffad66e --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/proc-after-section.fail.input.txt @@ -0,0 +1,7 @@ +REMAP { + inbound.req.X-Foo = "bar"; +} + +procedure local::late() { + inbound.req.X-Late = "yes"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/base/Stamp.hrw4u b/tools/hrw4u/tests/data/procedures/procs/base/Stamp.hrw4u new file mode 100644 index 00000000000..e635ab65c6b --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/base/Stamp.hrw4u @@ -0,0 +1,3 @@ +procedure base::stamp() { + inbound.req.X-Stamp = "base"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/caller/Wrap.hrw4u b/tools/hrw4u/tests/data/procedures/procs/caller/Wrap.hrw4u new file mode 100644 index 00000000000..cbd794dfe60 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/caller/Wrap.hrw4u @@ -0,0 +1,6 @@ +use base::Stamp + +procedure caller::wrap($tag) { + base::stamp(); + inbound.req.X-Tag = "$tag"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/circular/A.hrw4u b/tools/hrw4u/tests/data/procedures/procs/circular/A.hrw4u new file mode 100644 index 00000000000..31b6fdeae0a --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/circular/A.hrw4u @@ -0,0 +1,5 @@ +use circular::B + +procedure circular::A() { + inbound.req.X-A = "a"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/circular/B.hrw4u b/tools/hrw4u/tests/data/procedures/procs/circular/B.hrw4u new file mode 100644 index 00000000000..f18d26120ea --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/circular/B.hrw4u @@ -0,0 +1,5 @@ +use circular::A + +procedure circular::B() { + inbound.req.X-B = "b"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/reexport/debug.hrw4u b/tools/hrw4u/tests/data/procedures/procs/reexport/debug.hrw4u new file mode 100644 index 00000000000..4918d7cb125 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/reexport/debug.hrw4u @@ -0,0 +1 @@ +use test::add-debug-header diff --git a/tools/hrw4u/tests/data/procedures/procs/test/TagAndOrigin.hrw4u b/tools/hrw4u/tests/data/procedures/procs/test/TagAndOrigin.hrw4u new file mode 100644 index 00000000000..c85897dcbf9 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/test/TagAndOrigin.hrw4u @@ -0,0 +1,7 @@ +procedure test::add-tag($tag) { + inbound.req.X-Tag = "$tag"; +} + +procedure test::add-origin($host) { + inbound.req.X-Origin = "$host"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/test/add-debug-header.hrw4u b/tools/hrw4u/tests/data/procedures/procs/test/add-debug-header.hrw4u new file mode 100644 index 00000000000..fbe2e8c3aba --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/test/add-debug-header.hrw4u @@ -0,0 +1,3 @@ +procedure test::add-debug-header($tag) { + inbound.req.X-Debug = "$tag"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/test/classify-request.hrw4u b/tools/hrw4u/tests/data/procedures/procs/test/classify-request.hrw4u new file mode 100644 index 00000000000..59411ab2f76 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/test/classify-request.hrw4u @@ -0,0 +1,10 @@ +procedure test::classify-request($label) { + if inbound.req.X-Priority == "high" { + inbound.req.X-Class = "urgent"; + } elif inbound.req.X-Priority == "medium" { + inbound.req.X-Class = "normal"; + } else { + inbound.req.X-Class = "low"; + } + inbound.req.X-Classified-By = "$label"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/test/mixed-body.hrw4u b/tools/hrw4u/tests/data/procedures/procs/test/mixed-body.hrw4u new file mode 100644 index 00000000000..89652cb8bf6 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/test/mixed-body.hrw4u @@ -0,0 +1,7 @@ +procedure test::mixed-body() { + if inbound.req.X-Foo == "bar" { + inbound.req.X-Bar = "yes"; + } + + inbound.req.X-Baloo = "hello"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/test/set-cache.hrw4u b/tools/hrw4u/tests/data/procedures/procs/test/set-cache.hrw4u new file mode 100644 index 00000000000..b27bb560f0a --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/test/set-cache.hrw4u @@ -0,0 +1,3 @@ +procedure test::set-cache($ttl=300) { + inbound.req.X-Cache-TTL = $ttl; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/test/set-origin.hrw4u b/tools/hrw4u/tests/data/procedures/procs/test/set-origin.hrw4u new file mode 100644 index 00000000000..7e477bfd055 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/test/set-origin.hrw4u @@ -0,0 +1,3 @@ +procedure test::set-origin($host) { + inbound.req.X-Origin = "$host"; +} diff --git a/tools/hrw4u/tests/data/procedures/procs/test/wrong-namespace.hrw4u b/tools/hrw4u/tests/data/procedures/procs/test/wrong-namespace.hrw4u new file mode 100644 index 00000000000..f79d0d4dc38 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/procs/test/wrong-namespace.hrw4u @@ -0,0 +1,3 @@ +procedure wrong::namespace() { + inbound.req.X-Bad = "yes"; +} diff --git a/tools/hrw4u/tests/data/procedures/reexport.ast.txt b/tools/hrw4u/tests/data/procedures/reexport.ast.txt new file mode 100644 index 00000000000..9c0a9ff4dc9 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/reexport.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use reexport::debug)) (programItem (section REMAP { (sectionBody (statement (functionCall test::add-debug-header ( (argumentList (value "via-reexport")) )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/reexport.input.txt b/tools/hrw4u/tests/data/procedures/reexport.input.txt new file mode 100644 index 00000000000..b87b8259a6d --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/reexport.input.txt @@ -0,0 +1,5 @@ +use reexport::debug + +REMAP { + test::add-debug-header("via-reexport"); +} diff --git a/tools/hrw4u/tests/data/procedures/reexport.output.txt b/tools/hrw4u/tests/data/procedures/reexport.output.txt new file mode 100644 index 00000000000..8f191a6b815 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/reexport.output.txt @@ -0,0 +1,2 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Debug "via-reexport" diff --git a/tools/hrw4u/tests/data/procedures/string-param.ast.txt b/tools/hrw4u/tests/data/procedures/string-param.ast.txt new file mode 100644 index 00000000000..81475a72be8 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/string-param.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use test::set-origin)) (programItem (section REMAP { (sectionBody (statement (functionCall test::set-origin ( (argumentList (value "example.com")) )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/string-param.input.txt b/tools/hrw4u/tests/data/procedures/string-param.input.txt new file mode 100644 index 00000000000..4ebc0826dd2 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/string-param.input.txt @@ -0,0 +1,5 @@ +use test::set-origin + +REMAP { + test::set-origin("example.com"); +} diff --git a/tools/hrw4u/tests/data/procedures/string-param.output.txt b/tools/hrw4u/tests/data/procedures/string-param.output.txt new file mode 100644 index 00000000000..ef889e018f0 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/string-param.output.txt @@ -0,0 +1,2 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Origin "example.com" diff --git a/tools/hrw4u/tests/data/procedures/top-level-only.fail.error.txt b/tools/hrw4u/tests/data/procedures/top-level-only.fail.error.txt new file mode 100644 index 00000000000..fae8d02cde7 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/top-level-only.fail.error.txt @@ -0,0 +1 @@ +'use' directives must appear before \ No newline at end of file diff --git a/tools/hrw4u/tests/data/procedures/top-level-only.fail.input.txt b/tools/hrw4u/tests/data/procedures/top-level-only.fail.input.txt new file mode 100644 index 00000000000..81376f016d8 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/top-level-only.fail.input.txt @@ -0,0 +1,5 @@ +REMAP { + inbound.req.X-Foo = "bar"; +} + +use test::add-debug-header diff --git a/tools/hrw4u/tests/data/procedures/transitive.ast.txt b/tools/hrw4u/tests/data/procedures/transitive.ast.txt new file mode 100644 index 00000000000..51314addabf --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/transitive.ast.txt @@ -0,0 +1 @@ +(program (programItem (useDirective use caller::Wrap)) (programItem (section REMAP { (sectionBody (statement (functionCall caller::wrap ( (argumentList (value "hello")) )) ;)) })) ) diff --git a/tools/hrw4u/tests/data/procedures/transitive.input.txt b/tools/hrw4u/tests/data/procedures/transitive.input.txt new file mode 100644 index 00000000000..3d6ef25ecfd --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/transitive.input.txt @@ -0,0 +1,5 @@ +use caller::Wrap + +REMAP { + caller::wrap("hello"); +} diff --git a/tools/hrw4u/tests/data/procedures/transitive.output.txt b/tools/hrw4u/tests/data/procedures/transitive.output.txt new file mode 100644 index 00000000000..8a9c820dd0f --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/transitive.output.txt @@ -0,0 +1,3 @@ +cond %{REMAP_PSEUDO_HOOK} [AND] + set-header X-Stamp "base" + set-header X-Tag "hello" diff --git a/tools/hrw4u/tests/data/procedures/unknown-proc.fail.error.txt b/tools/hrw4u/tests/data/procedures/unknown-proc.fail.error.txt new file mode 100644 index 00000000000..a70895d4979 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/unknown-proc.fail.error.txt @@ -0,0 +1 @@ +unknown procedure 'test::add-debug-header': not loaded via 'use' diff --git a/tools/hrw4u/tests/data/procedures/unknown-proc.fail.input.txt b/tools/hrw4u/tests/data/procedures/unknown-proc.fail.input.txt new file mode 100644 index 00000000000..5af771b27df --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/unknown-proc.fail.input.txt @@ -0,0 +1,3 @@ +REMAP { + test::add-debug-header("foo"); +} diff --git a/tools/hrw4u/tests/data/procedures/wrong-arity.fail.error.txt b/tools/hrw4u/tests/data/procedures/wrong-arity.fail.error.txt new file mode 100644 index 00000000000..30c6c58f295 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/wrong-arity.fail.error.txt @@ -0,0 +1 @@ +procedure 'test::add-debug-header': expected 1 arg(s), got 2 diff --git a/tools/hrw4u/tests/data/procedures/wrong-arity.fail.input.txt b/tools/hrw4u/tests/data/procedures/wrong-arity.fail.input.txt new file mode 100644 index 00000000000..888ca760b48 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/wrong-arity.fail.input.txt @@ -0,0 +1,5 @@ +use test::add-debug-header + +REMAP { + test::add-debug-header("arg1", "arg2"); +} diff --git a/tools/hrw4u/tests/data/procedures/wrong-namespace.fail.error.txt b/tools/hrw4u/tests/data/procedures/wrong-namespace.fail.error.txt new file mode 100644 index 00000000000..1aca10b0238 --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/wrong-namespace.fail.error.txt @@ -0,0 +1 @@ +does not match namespace 'test' diff --git a/tools/hrw4u/tests/data/procedures/wrong-namespace.fail.input.txt b/tools/hrw4u/tests/data/procedures/wrong-namespace.fail.input.txt new file mode 100644 index 00000000000..c27adf8dfbf --- /dev/null +++ b/tools/hrw4u/tests/data/procedures/wrong-namespace.fail.input.txt @@ -0,0 +1,5 @@ +use test::wrong-namespace + +REMAP { + wrong::namespace(); +} diff --git a/tools/hrw4u/tests/test_procedures.py b/tools/hrw4u/tests/test_procedures.py new file mode 100644 index 00000000000..1a4a622cc53 --- /dev/null +++ b/tools/hrw4u/tests/test_procedures.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from pathlib import Path + +import pytest +import utils + + +@pytest.mark.procedures +@pytest.mark.parametrize("input_file,output_file", utils.collect_output_test_files("procedures", "hrw4u")) +def test_output_matches(input_file: Path, output_file: Path) -> None: + """Test that procedure expansion and hrw4u output matches expected output.""" + utils.run_procedure_output_test(input_file, output_file) + + +@pytest.mark.procedures +@pytest.mark.invalid +@pytest.mark.parametrize("input_file", utils.collect_failing_inputs("procedures")) +def test_invalid_inputs_fail(input_file: Path) -> None: + """Test that invalid procedure inputs produce expected errors.""" + utils.run_procedure_failing_test(input_file) + + +@pytest.mark.procedures +@pytest.mark.parametrize("input_file,flatten_file", utils.collect_flatten_test_files("procedures")) +def test_flatten_matches(input_file: Path, flatten_file: Path) -> None: + """Test that flatten mode produces expected self-contained hrw4u output.""" + utils.run_procedure_flatten_test(input_file, flatten_file) + + +@pytest.mark.procedures +@pytest.mark.parametrize("input_file,output_file", utils.collect_output_test_files("procedures", "hrw4u")) +def test_flatten_roundtrip(input_file: Path, output_file: Path) -> None: + """Test that flattened hrw4u compiles to the same header_rewrite as the original.""" + utils.run_procedure_flatten_roundtrip_test(input_file, output_file) diff --git a/tools/hrw4u/tests/test_units.py b/tools/hrw4u/tests/test_units.py index cf5d7d0f2eb..0cd7e6dd8b6 100644 --- a/tools/hrw4u/tests/test_units.py +++ b/tools/hrw4u/tests/test_units.py @@ -14,15 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Unit tests for internal methods and helper functions. -This module tests individual methods and components that might not be -fully exercised by the integration tests, providing focused testing -for internal implementation details. -""" - -from hrw4u.errors import ErrorCollector, Hrw4uSyntaxError, SymbolResolutionError +from hrw4u.errors import ErrorCollector, Hrw4uSyntaxError, SymbolResolutionError, humanize_error_message +from hrw4u.common import create_parse_tree +from hrw4u.hrw4uLexer import hrw4uLexer +from hrw4u.hrw4uParser import hrw4uParser from hrw4u.visitor import HRW4UVisitor from hrw4u.validation import Validator import pytest @@ -31,47 +27,38 @@ class TestHRW4UVisitorUnits: - """Unit tests for HRW4U internal methods.""" def setup_method(self): - """Set up test fixtures.""" - self.visitor = HRW4UVisitor(None, None, None) + self.visitor = HRW4UVisitor() def test_parse_function_args_empty(self): - """Test _parse_function_args with empty input.""" assert self.visitor._parse_function_args('') == [] assert self.visitor._parse_function_args(' ') == [] def test_parse_function_args_simple(self): - """Test _parse_function_args with simple arguments.""" assert self.visitor._parse_function_args('arg1') == ['arg1'] assert self.visitor._parse_function_args('arg1, arg2') == ['arg1', 'arg2'] assert self.visitor._parse_function_args('arg1, arg2, arg3') == ['arg1', 'arg2', 'arg3'] def test_parse_function_args_quoted_commas(self): - """Test _parse_function_args with quotes containing commas.""" assert self.visitor._parse_function_args('"arg1, with comma", arg2') == ['"arg1, with comma"', 'arg2'] assert self.visitor._parse_function_args('arg1, "arg2, with comma"') == ['arg1', '"arg2, with comma"'] assert self.visitor._parse_function_args('"first, comma", "second, comma"') == ['"first, comma"', '"second, comma"'] def test_parse_function_args_single_quotes(self): - """Test _parse_function_args with single quotes containing commas.""" assert self.visitor._parse_function_args("'arg1, with comma', arg2") == ["'arg1, with comma'", 'arg2'] assert self.visitor._parse_function_args("arg1, 'arg2, with comma'") == ['arg1', "'arg2, with comma'"] def test_parse_function_args_nested_functions(self): - """Test _parse_function_args with nested function calls.""" assert self.visitor._parse_function_args('func(a,b), arg2') == ['func(a,b)', 'arg2'] assert self.visitor._parse_function_args('arg1, func(a,b)') == ['arg1', 'func(a,b)'] assert self.visitor._parse_function_args('func1(a,b), func2(c,d)') == ['func1(a,b)', 'func2(c,d)'] def test_parse_function_args_deeply_nested(self): - """Test _parse_function_args with deeply nested parentheses.""" assert self.visitor._parse_function_args('func(nested(a,b),c), arg2') == ['func(nested(a,b),c)', 'arg2'] assert self.visitor._parse_function_args('outer(inner(deep(x,y),z),w), final') == ['outer(inner(deep(x,y),z),w)', 'final'] def test_parse_function_args_mixed_complex(self): - """Test _parse_function_args with complex mixed cases.""" complex_arg = 'func("arg1, with comma", nested_func(a,b), "arg3")' assert self.visitor._parse_function_args(complex_arg) == [complex_arg] assert self.visitor._parse_function_args('"quoted, arg", func(a,b), normal_arg') == [ @@ -79,33 +66,27 @@ def test_parse_function_args_mixed_complex(self): ] def test_parse_function_args_whitespace_handling(self): - """Test _parse_function_args with various whitespace patterns.""" assert self.visitor._parse_function_args(' arg1 , arg2 ') == ['arg1', 'arg2'] assert self.visitor._parse_function_args('func( a , b ), arg2') == ['func( a , b )', 'arg2'] assert self.visitor._parse_function_args('\targ1\t,\targ2\t') == ['arg1', 'arg2'] def test_parse_function_args_escaped_quotes(self): - """Test _parse_function_args with escaped quotes (basic test).""" - # Note: The current implementation doesn't handle escaped quotes perfectly, + # The current implementation doesn't handle escaped quotes perfectly, # but this documents the current behavior result = self.visitor._parse_function_args('"arg with \\" quote", arg2') def test_parse_function_args_edge_cases(self): - """Test _parse_function_args with edge cases.""" assert self.visitor._parse_function_args('func(a,b, arg2') == ['func(a,b, arg2'] assert self.visitor._parse_function_args('arg1,, arg2') == ['arg1', '', 'arg2'] assert self.visitor._parse_function_args(',,,') == ['', '', ''] class TestErrorCollectorUnits: - """Unit tests for ErrorCollector internal methods.""" def setup_method(self): - """Set up test fixtures.""" self.error_collector = ErrorCollector() def test_error_collector_basic(self): - """Test basic ErrorCollector functionality.""" assert not self.error_collector.has_errors() test_error = Hrw4uSyntaxError("test.hrw4u", 1, 0, "Test error", "test line") @@ -117,7 +98,6 @@ def test_error_collector_basic(self): assert "Found 1 error:" in error_summary def test_error_collector_multiple_errors(self): - """Test ErrorCollector with multiple errors.""" error1 = Hrw4uSyntaxError("test1.hrw4u", 1, 0, "Error 1", "line 1") error2 = Hrw4uSyntaxError("test2.hrw4u", 2, 5, "Error 2", "line 2") error3 = Hrw4uSyntaxError("test3.hrw4u", 3, 10, "Error 3", "line 3") @@ -136,10 +116,8 @@ def test_error_collector_multiple_errors(self): class TestValidationUnits: - """Unit tests for validation functions.""" def test_http_header_name_valid_standard(self): - """Test http_header_name with valid standard RFC 7230 header names.""" validator = Validator.http_header_name() valid_names = [ @@ -147,27 +125,25 @@ def test_http_header_name_valid_standard(self): "X-Custom-Header", "User-Agent", "Accept-Encoding", - "X_Custom_Header", # Underscores allowed - "X~Custom~Header", # Tildes allowed - "X^Custom^Header", # Carets allowed - "X|Custom|Header", # Pipes allowed - "X!Custom!Header", # Exclamation marks allowed - "X#Custom#Header", # Hash marks allowed - "X$Custom$Header", # Dollar signs allowed - "X%Custom%Header", # Percent signs allowed - "X&Custom&Header", # Ampersands allowed - "X'Custom'Header", # Single quotes allowed - "X*Custom*Header", # Asterisks allowed - "X+Custom+Header", # Plus signs allowed - "X`Custom`Header", # Backticks allowed + "X_Custom_Header", + "X~Custom~Header", + "X^Custom^Header", + "X|Custom|Header", + "X!Custom!Header", + "X#Custom#Header", + "X$Custom$Header", + "X%Custom%Header", + "X&Custom&Header", + "X'Custom'Header", + "X*Custom*Header", + "X+Custom+Header", + "X`Custom`Header", ] for name in valid_names: - # Should not raise an exception validator(name) def test_http_header_name_valid_ats_internal(self): - """Test http_header_name with valid ATS internal headers (@ prefix).""" validator = Validator.http_header_name() valid_ats_names = [ @@ -178,38 +154,36 @@ def test_http_header_name_valid_ats_internal(self): ] for name in valid_ats_names: - # Should not raise an exception validator(name) def test_http_header_name_invalid(self): - """Test http_header_name with invalid header names.""" validator = Validator.http_header_name() invalid_names = [ - "", # Empty name - "@", # Just @ alone - "Content Type", # Space not allowed - "Content\tType", # Tab not allowed - "Content\nType", # Newline not allowed - "Content(Type)", # Parentheses not allowed - "Content[Type]", # Brackets not allowed - "Content{Type}", # Braces not allowed - "Content", # Angle brackets not allowed - "Content@Type", # @ not allowed in middle - "Content,Type", # Comma not allowed - "Content;Type", # Semicolon not allowed - "Content:Type", # Colon not allowed - "Content=Type", # Equals not allowed - "Content?Type", # Question mark not allowed - "Content/Type", # Forward slash not allowed - "Content\\Type", # Backslash not allowed - "Content\"Type\"", # Quotes not allowed - "@Content@Type", # @ not allowed after first position - "X-@Header", # @ not allowed after first position - "headers.X-Match", # Dots not allowed (hrw4u restriction) - "X.Custom.Header", # Dots not allowed (hrw4u restriction) - "@Custom.Header", # Dots not allowed even in ATS headers - "header.X-Foo", # Dots not allowed (the specific case mentioned) + "", + "@", + "Content Type", + "Content\tType", + "Content\nType", + "Content(Type)", + "Content[Type]", + "Content{Type}", + "Content", + "Content@Type", + "Content,Type", + "Content;Type", + "Content:Type", + "Content=Type", + "Content?Type", + "Content/Type", + "Content\\Type", + "Content\"Type\"", + "@Content@Type", + "X-@Header", + "headers.X-Match", + "X.Custom.Header", + "@Custom.Header", + "header.X-Foo", ] for name in invalid_names: @@ -217,43 +191,40 @@ def test_http_header_name_invalid(self): validator(name) def test_http_token_valid(self): - """Test http_token with valid tokens.""" validator = Validator.http_token() valid_tokens = [ "Content-Type", "simple_token", "Token123", - "!#$%&'*+-.^_`|~", # All allowed special chars + "!#$%&'*+-.^_`|~", "Mixed123.Token-Name_Test", ] for token in valid_tokens: - # Should not raise an exception validator(token) def test_http_token_invalid(self): - """Test http_token with invalid tokens.""" validator = Validator.http_token() invalid_tokens = [ - "", # Empty - "token with space", # Space not allowed - "token\ttab", # Tab not allowed - "token\nnewline", # Newline not allowed - "token(paren)", # Parentheses not allowed - "token[bracket]", # Brackets not allowed - "token{brace}", # Braces not allowed - "token", # Angle brackets not allowed - "token@at", # @ not allowed in http_token - "token,comma", # Comma not allowed - "token;semicolon", # Semicolon not allowed - "token:colon", # Colon not allowed - "token=equals", # Equals not allowed - "token?question", # Question mark not allowed - "token/slash", # Forward slash not allowed - "token\\backslash", # Backslash not allowed - "token\"quote\"", # Quotes not allowed + "", + "token with space", + "token\ttab", + "token\nnewline", + "token(paren)", + "token[bracket]", + "token{brace}", + "token", + "token@at", + "token,comma", + "token;semicolon", + "token:colon", + "token=equals", + "token?question", + "token/slash", + "token\\backslash", + "token\"quote\"", ] for token in invalid_tokens: @@ -261,19 +232,15 @@ def test_http_token_invalid(self): validator(token) def test_regex_validator_factory(self): - """Test the unified regex_validator factory method.""" import re - # Create a custom validator for testing - test_pattern = re.compile(r'^[A-Z]+$') # Only uppercase letters + test_pattern = re.compile(r'^[A-Z]+$') validator = Validator.regex_validator(test_pattern, "Must be uppercase letters only") - # Valid cases validator("ABC") validator("HELLO") validator("TEST") - # Invalid cases with pytest.raises(SymbolResolutionError, match="Must be uppercase letters only"): validator("abc") @@ -284,5 +251,74 @@ def test_regex_validator_factory(self): validator("TEST123") +class TestHumanizeErrorMessage: + + def test_replaces_qualified_ident(self): + msg = "mismatched input 'Apple' expecting QUALIFIED_IDENT" + result = humanize_error_message(msg) + assert "QUALIFIED_IDENT" not in result + assert "qualified name" in result + + def test_replaces_punctuation_tokens(self): + msg = "mismatched input '{' expecting LPAREN" + result = humanize_error_message(msg) + assert "LPAREN" not in result + assert "'('" in result + + def test_replaces_multiple_tokens(self): + msg = "expecting {IDENT, QUALIFIED_IDENT, SEMICOLON}" + result = humanize_error_message(msg) + assert "IDENT" not in result + assert "QUALIFIED_IDENT" not in result + assert "SEMICOLON" not in result + + def test_preserves_normal_text(self): + msg = "unknown procedure 'test::foo': not loaded via 'use'" + result = humanize_error_message(msg) + assert result == msg + + def test_no_partial_word_replacement(self): + msg = "some IDENTIFIER_LIKE text" + result = humanize_error_message(msg) + assert result == msg + + +class TestParseErrorMessages: + """Verify that ANTLR parse errors use human-friendly token names.""" + + def _collect_errors(self, text: str) -> list[str]: + collector = ErrorCollector(max_errors=10) + try: + create_parse_tree(text, "", hrw4uLexer, hrw4uParser, "test", collect_errors=True, max_errors=10) + except Exception: + pass + return [str(e) for e in collector.errors] + + def _first_error(self, text: str) -> str: + _, _, collector = create_parse_tree(text, "", hrw4uLexer, hrw4uParser, "test", collect_errors=True, max_errors=10) + assert collector and collector.has_errors(), f"Expected parse errors for: {text!r}" + return str(collector.errors[0]) + + def test_single_colon_in_procedure(self): + error = self._first_error("procedure Apple:Roles() { }") + assert "QUALIFIED_IDENT" not in error + assert "qualified name" in error + + def test_missing_lparen(self): + error = self._first_error("procedure ns::foo { }") + assert "LPAREN" not in error + assert "'('" in error + + def test_unqualified_procedure_name(self): + error = self._first_error("procedure nocolon() { }") + assert "QUALIFIED_IDENT" not in error + assert "qualified name" in error + + def test_bad_use_directive(self): + error = self._first_error("use Foo") + assert "QUALIFIED_IDENT" not in error + assert "qualified name" in error + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tools/hrw4u/tests/utils.py b/tools/hrw4u/tests/utils.py index fac320fbcd0..eefbd7249f5 100644 --- a/tools/hrw4u/tests/utils.py +++ b/tools/hrw4u/tests/utils.py @@ -32,20 +32,7 @@ from u4wrh.u4wrhParser import u4wrhParser from u4wrh.hrw_visitor import HRWInverseVisitor -# Try to import structured error type, fall back to generic Exception if not available -try: - from src.errors import Hrw4uSyntaxError -except ImportError: - # Fallback: define a minimal interface for structured error detection - class Hrw4uSyntaxError(Exception): - - def __init__(self, filename: str, line: int, column: int, message: str, source_line: str = ""): - super().__init__(message) - self.filename = filename - self.line = line - self.column = column - self.source_line = source_line - +from hrw4u.errors import Hrw4uSyntaxError __all__: Final[list[str]] = [ "collect_output_test_files", @@ -56,11 +43,15 @@ def __init__(self, filename: str, line: int, column: int, message: str, source_l "run_failing_test", "run_reverse_test", "run_bulk_test", + "run_procedure_output_test", + "run_procedure_flatten_test", + "run_procedure_failing_test", + "run_procedure_flatten_roundtrip_test", + "collect_flatten_test_files", ] def parse_input_text(text: str) -> tuple[hrw4uParser, hrw4uParser.ProgramContext]: - """Parse hrw4u input text and return parser and AST.""" lexer = hrw4uLexer(InputStream(text)) stream = CommonTokenStream(lexer) parser = hrw4uParser(stream) @@ -69,7 +60,6 @@ def parse_input_text(text: str) -> tuple[hrw4uParser, hrw4uParser.ProgramContext def _read_exceptions(base_dir: Path) -> dict[str, str]: - """Read exceptions.txt file and return test -> direction mapping.""" exceptions_file = base_dir / "exceptions.txt" exceptions = {} @@ -89,14 +79,10 @@ def _read_exceptions(base_dir: Path) -> dict[str, str]: def collect_output_test_files(group: str, direction: str = "hrw4u") -> Iterator[pytest.param]: - """ - Collect test files for output validation tests. - """ base_dir = Path("tests/data") / group exceptions = _read_exceptions(base_dir) for input_file in base_dir.glob("*.input.txt"): - # Skip failure test cases here; those are handled separately if ".fail." in input_file.name: continue @@ -104,7 +90,6 @@ def collect_output_test_files(group: str, direction: str = "hrw4u") -> Iterator[ output_file = base.with_suffix('.output.txt') test_id = base.name - # Check if this test has direction restrictions if test_id in exceptions: test_direction = exceptions[test_id] if direction != "both" and direction != test_direction: @@ -114,12 +99,6 @@ def collect_output_test_files(group: str, direction: str = "hrw4u") -> Iterator[ def collect_ast_test_files(group: str) -> Iterator[pytest.param]: - """ - Collect test files for AST validation tests. - - AST tests always run in hrw4u direction only since they validate - the parse tree structure of the input format. - """ base_dir = Path("tests/data") / group for input_file in base_dir.glob("*.input.txt"): @@ -135,9 +114,6 @@ def collect_ast_test_files(group: str) -> Iterator[pytest.param]: def collect_failing_inputs(group: str) -> Iterator[pytest.param]: - """ - Collect test files for failure validation tests. - """ base_dir = Path("tests/data") / group for input_file in base_dir.glob("*.fail.input.txt"): test_id = input_file.stem @@ -145,7 +121,6 @@ def collect_failing_inputs(group: str) -> Iterator[pytest.param]: def run_output_test(input_file: Path, output_file: Path) -> None: - """Run output validation test comparing generated output with expected.""" input_text = input_file.read_text() parser, tree = parse_input_text(input_text) visitor = HRW4UVisitor() @@ -155,7 +130,6 @@ def run_output_test(input_file: Path, output_file: Path) -> None: def run_ast_test(input_file: Path, ast_file: Path) -> None: - """Run AST validation test comparing generated AST with expected.""" input_text = input_file.read_text() parser, tree = parse_input_text(input_text) actual_ast = tree.toStringTree(recog=parser).strip() @@ -164,7 +138,6 @@ def run_ast_test(input_file: Path, ast_file: Path) -> None: def run_failing_test(input_file: Path) -> None: - """Run failure validation test ensuring input produces expected error with structured validation.""" text = input_file.read_text() parser, tree = parse_input_text(text) visitor = HRW4UVisitor(filename=str(input_file)) @@ -182,14 +155,11 @@ def run_failing_test(input_file: Path) -> None: actual_exception = exc_info.value actual_error_str = str(actual_exception).strip() - # Parse expected error for structured validation expected_fields = _parse_error_file(expected_error_content) if expected_fields and isinstance(actual_exception, Hrw4uSyntaxError): - # Assert structured fields when available _assert_structured_error_fields(actual_exception, expected_fields, input_file) else: - # Fallback to substring matching for legacy files or non-structured exceptions assert expected_error_content in actual_error_str, ( f"Error mismatch for {input_file}\n" f"Expected error (partial match):\n{expected_error_content}\n\n" @@ -197,19 +167,11 @@ def run_failing_test(input_file: Path) -> None: def _parse_error_file(error_content: str) -> dict[str, str | int] | None: - """ - Parse structured error file content to extract filename, line, column, and message. - - Expected format: filename:line:column: error: message - Returns None if parsing fails (fallback to substring matching). - """ lines = error_content.strip().split('\n') if not lines: return None first_line = lines[0].strip() - - # Regex to parse: filename:line:column: error: message error_pattern = re.compile(r'^(.+):(\d+):(\d+):\s*error:\s*(.+)$') match = error_pattern.match(first_line) @@ -226,9 +188,6 @@ def _parse_error_file(error_content: str) -> dict[str, str | int] | None: def _assert_structured_error_fields( actual_exception: Hrw4uSyntaxError, expected_fields: dict[str, str | int], input_file: Path) -> None: - """Assert that structured exception fields match expected values.""" - - # Assert filename (normalize paths for comparison) expected_filename = str(Path(expected_fields['filename']).resolve()) actual_filename = str(Path(actual_exception.filename).resolve()) assert actual_filename == expected_filename, ( @@ -236,19 +195,16 @@ def _assert_structured_error_fields( f"Expected: {expected_filename}\n" f"Actual: {actual_filename}") - # Assert line number assert actual_exception.line == expected_fields['line'], ( f"Line number mismatch for {input_file}\n" f"Expected: {expected_fields['line']}\n" f"Actual: {actual_exception.line}") - # Assert column number assert actual_exception.column == expected_fields['column'], ( f"Column number mismatch for {input_file}\n" f"Expected: {expected_fields['column']}\n" f"Actual: {actual_exception.column}") - # Assert error message (allow partial match for flexibility) expected_message = expected_fields['message'] actual_full_error = str(actual_exception) assert expected_message in actual_full_error, ( @@ -258,7 +214,6 @@ def _assert_structured_error_fields( def run_reverse_test(input_file: Path, output_file: Path) -> None: - """Run u4wrh on output.txt and compare with input.txt (round-trip test).""" output_text = output_file.read_text() lexer = u4wrhLexer(InputStream(output_text)) stream = CommonTokenStream(lexer) @@ -271,63 +226,49 @@ def run_reverse_test(input_file: Path, output_file: Path) -> None: def create_output_test(group: str): - """Create a parametrized output test function for a specific group.""" import pytest @pytest.mark.parametrize("input_file,output_file", collect_output_test_files(group, "hrw4u")) def test_output_matches(input_file: Path, output_file: Path) -> None: - f"""Test that hrw4u output matches expected output for {group} test cases.""" run_output_test(input_file, output_file) return test_output_matches def create_ast_test(group: str): - """Create a parametrized AST test function for a specific group.""" import pytest @pytest.mark.ast @pytest.mark.parametrize("input_file,ast_file", collect_ast_test_files(group)) def test_ast_matches(input_file: Path, ast_file: Path) -> None: - f"""Test that AST structure matches expected AST for {group} test cases.""" run_ast_test(input_file, ast_file) return test_ast_matches def create_invalid_test(group: str): - """Create a parametrized invalid input test function for a specific group.""" import pytest @pytest.mark.invalid @pytest.mark.parametrize("input_file", collect_failing_inputs(group)) def test_invalid_inputs_fail(input_file: Path) -> None: - f"""Test that invalid {group} inputs produce expected errors.""" run_failing_test(input_file) return test_invalid_inputs_fail def create_reverse_test(group: str): - """Create a parametrized reverse test function for a specific group.""" import pytest @pytest.mark.reverse @pytest.mark.parametrize("input_file,output_file", collect_output_test_files(group, "u4wrh")) def test_reverse_conversion(input_file: Path, output_file: Path) -> None: - f"""Test that u4wrh reverse conversion produces original hrw4u for {group} test cases.""" run_reverse_test(input_file, output_file) return test_reverse_conversion def run_bulk_test(group: str) -> None: - """ - Run bulk compilation test for a specific test group. - - Collects all .input.txt files in the group, runs hrw4u with bulk - input:output pairs, and compares each output with expected .output.txt. - """ base_dir = Path("tests/data") / group exceptions = _read_exceptions(base_dir) @@ -385,3 +326,80 @@ def run_bulk_test(group: str) -> None: f"Bulk output mismatch for {input_file.name}\n" f"Expected:\n{expected_output}\n\n" f"Actual:\n{actual_output}") + + +def _procs_dir(input_file: Path) -> Path: + return input_file.parent / 'procs' + + +def collect_flatten_test_files(group: str) -> Iterator[pytest.param]: + base_dir = Path("tests/data") / group + + for input_file in base_dir.glob("*.input.txt"): + if ".fail." in input_file.name: + continue + + base = input_file.with_suffix('') + flatten_file = base.with_suffix('.flatten.txt') + test_id = base.name + + if flatten_file.exists(): + yield pytest.param(input_file, flatten_file, id=test_id) + + +def run_procedure_output_test(input_file: Path, output_file: Path) -> None: + procs_dir = _procs_dir(input_file) + input_text = input_file.read_text() + parser, tree = parse_input_text(input_text) + visitor = HRW4UVisitor(filename=str(input_file), proc_search_paths=[procs_dir]) + actual_output = "\n".join(visitor.visit(tree)).strip() + expected_output = output_file.read_text().strip() + assert actual_output == expected_output, f"Output mismatch in {input_file}" + + +def run_procedure_flatten_test(input_file: Path, flatten_file: Path) -> None: + procs_dir = _procs_dir(input_file) + input_text = input_file.read_text() + parser, tree = parse_input_text(input_text) + visitor = HRW4UVisitor(filename=str(input_file), proc_search_paths=[procs_dir]) + actual_output = "\n".join(visitor.flatten(tree, input_text)).strip() + expected_output = flatten_file.read_text().strip() + assert actual_output == expected_output, f"Flatten mismatch in {input_file}" + + +def run_procedure_flatten_roundtrip_test(input_file: Path, output_file: Path) -> None: + """Verify that flattened output compiles to the same header_rewrite as the original.""" + procs_dir = _procs_dir(input_file) + input_text = input_file.read_text() + parser, tree = parse_input_text(input_text) + visitor = HRW4UVisitor(filename=str(input_file), proc_search_paths=[procs_dir]) + flattened = "\n".join(visitor.flatten(tree, input_text)) + + # Compile the flattened output (no procedures needed — it's self-contained) + parser2, tree2 = parse_input_text(flattened) + visitor2 = HRW4UVisitor(filename=str(input_file)) + actual_output = "\n".join(visitor2.visit(tree2)).strip() + expected_output = output_file.read_text().strip() + assert actual_output == expected_output, ( + f"Flatten roundtrip mismatch in {input_file}\n" + f"Flattened hrw4u compiles to different output than original") + + +def run_procedure_failing_test(input_file: Path) -> None: + procs_dir = _procs_dir(input_file) + text = input_file.read_text() + + error_file = input_file.with_name(input_file.name.replace(".fail.input.txt", ".fail.error.txt")) + if not error_file.exists(): + raise RuntimeError(f"Missing expected error file: {error_file}") + + expected_error = error_file.read_text().strip() + + with pytest.raises(Exception) as exc_info: + parser, tree = parse_input_text(text) + HRW4UVisitor(filename=str(input_file), proc_search_paths=[procs_dir]).visit(tree) + + assert expected_error in str(exc_info.value), ( + f"Error mismatch for {input_file}\n" + f"Expected (substring): {expected_error!r}\n" + f"Actual: {str(exc_info.value)!r}") From fec579d297ce70e01ecb4fd9ebd1cbe3c7644793 Mon Sep 17 00:00:00 2001 From: Leif Hedstrom Date: Wed, 4 Mar 2026 18:42:39 -0700 Subject: [PATCH 2/2] Address Copilot review comments - Validate that required params precede optional params in procedure declarations; raises a clear error at declaration time. - Extract source line from input stream in _get_value_text errors so the error output includes the offending line and caret position. - Use SystemDefaults.INDENT_SPACES instead of hardcoded 4-space literals in flatten() and _flatten_conditional(). - Remove dead _collect_errors() in TestParseErrorMessages; _first_error() is the working equivalent used by all tests. --- tools/hrw4u/src/visitor.py | 19 ++++++++++++++++--- tools/hrw4u/tests/test_units.py | 8 -------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tools/hrw4u/src/visitor.py b/tools/hrw4u/src/visitor.py index bd8104d2111..f6fdf651304 100644 --- a/tools/hrw4u/src/visitor.py +++ b/tools/hrw4u/src/visitor.py @@ -237,8 +237,13 @@ def _get_value_text(self, val_ctx) -> str: if val_ctx.paramRef(): name = val_ctx.paramRef().IDENT().getText() if name not in self._proc_bindings: + try: + source_line = val_ctx.start.getInputStream().strdata.splitlines()[val_ctx.start.line - 1] + except Exception: + source_line = "" raise Hrw4uSyntaxError( - self.filename, val_ctx.start.line, val_ctx.start.column, f"'${name}' used outside procedure context", "") + self.filename, val_ctx.start.line, val_ctx.start.column, f"'${name}' used outside procedure context", + source_line) return self._proc_bindings[name] return val_ctx.getText() @@ -547,7 +552,7 @@ def _flatten_items(self, items, indent: str, source_text: str, bindings: dict[st def _flatten_conditional(self, cond_ctx, indent: str, source_text: str, bindings: dict[str, str]) -> list[str]: """Flatten a conditional block, expanding proc calls within its branches.""" lines: list[str] = [] - inner_indent = indent + " " + inner_indent = indent + (" " * SystemDefaults.INDENT_SPACES) if_ctx = cond_ctx.ifStatement() cond_text = self._get_source_text(if_ctx.condition(), source_text) @@ -573,7 +578,7 @@ def flatten(self, ctx, source_text: str = "") -> list[str]: if not source_text: source_text = ctx.start.source[1].getText(0, ctx.start.source[1].size - 1) self._source_text = source_text - indent = " " * 4 + indent = " " * SystemDefaults.INDENT_SPACES # Phase 1: Load all procedures (use directives + local procedure declarations) for item in ctx.programItem(): @@ -640,6 +645,14 @@ def visitProcedureDecl(self, ctx) -> None: self.filename, ctx.start.line, ctx.start.column, f"procedure '{name}' already declared in {existing.source_file}", "") params = self._collect_proc_params(ctx.paramList()) if ctx.paramList() else [] + seen_default = False + for p in params: + if p.default_ctx is None and seen_default: + raise Hrw4uSyntaxError( + self.filename, ctx.start.line, ctx.start.column, + f"procedure '{name}': required parameter '${p.name}' must not follow an optional parameter", "") + if p.default_ctx is not None: + seen_default = True self._proc_registry[name] = ProcSig(name, params, ctx.block(), self.filename, self._source_text) def visitProgram(self, ctx) -> list[str]: diff --git a/tools/hrw4u/tests/test_units.py b/tools/hrw4u/tests/test_units.py index 0cd7e6dd8b6..37bccd5cf73 100644 --- a/tools/hrw4u/tests/test_units.py +++ b/tools/hrw4u/tests/test_units.py @@ -286,14 +286,6 @@ def test_no_partial_word_replacement(self): class TestParseErrorMessages: """Verify that ANTLR parse errors use human-friendly token names.""" - def _collect_errors(self, text: str) -> list[str]: - collector = ErrorCollector(max_errors=10) - try: - create_parse_tree(text, "", hrw4uLexer, hrw4uParser, "test", collect_errors=True, max_errors=10) - except Exception: - pass - return [str(e) for e in collector.errors] - def _first_error(self, text: str) -> str: _, _, collector = create_parse_tree(text, "", hrw4uLexer, hrw4uParser, "test", collect_errors=True, max_errors=10) assert collector and collector.has_errors(), f"Expected parse errors for: {text!r}"