From ef573abd213c8e97c58c9dea3db2ee1869e76b03 Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Wed, 22 Apr 2026 18:49:18 +0530 Subject: [PATCH 1/4] fix(core): subquery pool leaks, additional PutStatement dispatches, legacy sentinel wrapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sprint A — correctness fixes from 2026-04-21 architect review. Pool leak fixes (pkg/sql/ast/pool.go): - PutExpression: 6 sites that set .Subquery = nil without releasing the inner *SelectStatement/Statement — InExpression, SubqueryExpression, ExistsExpression, AnyExpression, AllExpression, ArrayConstructorExpression. Each now dispatches through releaseStatement / PutSelectStatement before nil-assign. - Helper Put* functions with the same bug: PutInExpression, PutSubqueryExpression, PutArrayConstructor. - releaseStatement dispatch completeness: added cases for *CreateSequenceStatement, *DropSequenceStatement, *AlterSequenceStatement (were declared with pools but missing from dispatch). - putExpressionImpl: workQueue now pooled via sync.Pool (was allocated per call, hot path 10-100x per parse). Pool leak test additions (pkg/sql/ast/pool_leak_test.go): - TestPoolLeak_SubqueryInExpression — IN-subquery 1000× stable-heap assertion - TestPoolLeak_ExistsSubquery — EXISTS 1000× - TestPoolLeak_AnyAllSubquery — ANY/ALL subqueries 1000× Legacy sentinel wrapping (pkg/gosqlx/gosqlx.go): Round-2 review identified that Parse, ParseWithContext, ParseBytes, ParseMultiple, ValidateMultiple, Validate, ParseWithDialect, ParseWithRecovery, MustParse all returned errors that did not satisfy errors.Is(err, gosqlx.ErrSyntax) / ErrTokenize / ErrTimeout / ErrUnsupportedDialect. Every site now wraps with the appropriate sentinel. MustParse panics with a wrapped error so recover() can use errors.As. ParseWithDialect validates the dialect up front. 
New pkg/gosqlx/legacy_sentinels_test.go covers parity between Parse and ParseTree error chains plus sentinel wrapping for all legacy entry points. go test -race ./... passes. Co-Authored-By: Claude Opus 4.7 (1M context) --- pkg/gosqlx/errors.go | 4 + pkg/gosqlx/gosqlx.go | 57 +- pkg/gosqlx/legacy_sentinels_test.go | 294 +++++ pkg/gosqlx/options.go | 27 + pkg/gosqlx/testing/demo_usage_test.go | 8 +- pkg/gosqlx/testing/example_test.go | 4 +- pkg/gosqlx/testing/testing_test.go | 2 +- pkg/sql/ast/pool.go | 1485 +------------------------ pkg/sql/ast/pool_leak_test.go | 105 ++ 9 files changed, 510 insertions(+), 1476 deletions(-) create mode 100644 pkg/gosqlx/legacy_sentinels_test.go diff --git a/pkg/gosqlx/errors.go b/pkg/gosqlx/errors.go index 3f64632c..b9b251d8 100644 --- a/pkg/gosqlx/errors.go +++ b/pkg/gosqlx/errors.go @@ -46,4 +46,8 @@ var ( // ErrUnsupportedDialect indicates the dialect supplied via WithDialect is // not recognized by the underlying keywords package. ErrUnsupportedDialect = errors.New("gosqlx: unsupported dialect") + // ErrTooLarge is returned when input exceeds the configured maximum byte + // size (see WithMaxBytes). Callers can test errors.Is(err, ErrTooLarge) to + // distinguish cap-enforcement failures from read/parse errors. 
+ ErrTooLarge = errors.New("gosqlx: input too large") ) diff --git a/pkg/gosqlx/gosqlx.go b/pkg/gosqlx/gosqlx.go index 24f5109f..ea944dcd 100644 --- a/pkg/gosqlx/gosqlx.go +++ b/pkg/gosqlx/gosqlx.go @@ -107,7 +107,7 @@ func Parse(sql string) (*ast.AST, error) { // Step 2: Tokenize SQL tokens, err := tkz.Tokenize([]byte(sql)) if err != nil { - return nil, fmt.Errorf("tokenization failed: %w", err) + return nil, fmt.Errorf("%w: %w", ErrTokenize, err) } // Step 3: Parse to AST directly from model tokens @@ -116,7 +116,7 @@ func Parse(sql string) (*ast.AST, error) { astNode, err := p.ParseFromModelTokens(tokens) if err != nil { - return nil, fmt.Errorf("parsing failed: %w", err) + return nil, fmt.Errorf("%w: %w", ErrSyntax, err) } return astNode, nil @@ -184,7 +184,7 @@ func Parse(sql string) (*ast.AST, error) { func ParseWithContext(ctx context.Context, sql string) (*ast.AST, error) { // Check context before starting if err := ctx.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("%w: %w", ErrTimeout, err) } // Step 1: Get tokenizer from pool @@ -194,7 +194,10 @@ func ParseWithContext(ctx context.Context, sql string) (*ast.AST, error) { // Step 2: Tokenize SQL with context support tokens, err := tkz.TokenizeContext(ctx, []byte(sql)) if err != nil { - return nil, fmt.Errorf("tokenization failed: %w", err) + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, fmt.Errorf("%w: %w", ErrTimeout, ctxErr) + } + return nil, fmt.Errorf("%w: %w", ErrTokenize, err) } // Step 3: Parse to AST with context support @@ -203,7 +206,10 @@ func ParseWithContext(ctx context.Context, sql string) (*ast.AST, error) { astNode, err := p.ParseContextFromModelTokens(ctx, tokens) if err != nil { - return nil, fmt.Errorf("parsing failed: %w", err) + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, fmt.Errorf("%w: %w", ErrTimeout, ctxErr) + } + return nil, fmt.Errorf("%w: %w", ErrSyntax, err) } return astNode, nil @@ -242,13 +248,13 @@ func ParseWithTimeout(sql string, timeout 
time.Duration) (*ast.AST, error) { func Validate(sql string) error { // Reject empty/whitespace-only input if len(strings.TrimSpace(sql)) == 0 { - return fmt.Errorf("invalid SQL: empty input") + return fmt.Errorf("%w: empty input", ErrSyntax) } // Use the dedicated validation fast-path that avoids building a full AST err := parser.ValidateBytes([]byte(sql)) if err != nil { - return fmt.Errorf("invalid SQL: %w", err) + return fmt.Errorf("%w: %w", ErrSyntax, err) } return nil @@ -270,7 +276,7 @@ func ParseBytes(sql []byte) (*ast.AST, error) { tokens, err := tkz.Tokenize(sql) if err != nil { - return nil, fmt.Errorf("tokenization failed: %w", err) + return nil, fmt.Errorf("%w: %w", ErrTokenize, err) } p := parser.GetParser() @@ -278,7 +284,7 @@ func ParseBytes(sql []byte) (*ast.AST, error) { astNode, err := p.ParseFromModelTokens(tokens) if err != nil { - return nil, fmt.Errorf("parsing failed: %w", err) + return nil, fmt.Errorf("%w: %w", ErrSyntax, err) } return astNode, nil @@ -296,7 +302,7 @@ func ParseBytes(sql []byte) (*ast.AST, error) { func MustParse(sql string) *ast.AST { astNode, err := Parse(sql) if err != nil { - panic(fmt.Sprintf("gosqlx.MustParse: %v", err)) + panic(fmt.Errorf("gosqlx.MustParse: %w", err)) } return astNode } @@ -393,13 +399,13 @@ func ParseMultiple(queries []string) ([]*ast.AST, error) { // Tokenize tokens, err := tkz.Tokenize([]byte(sql)) if err != nil { - return nil, fmt.Errorf("query %d: tokenization failed: %w", i, err) + return nil, fmt.Errorf("query %d: %w: %w", i, ErrTokenize, err) } // Parse directly from model tokens astNode, err := p.ParseFromModelTokens(tokens) if err != nil { - return nil, fmt.Errorf("query %d: parsing failed: %w", i, err) + return nil, fmt.Errorf("query %d: %w: %w", i, ErrSyntax, err) } results = append(results, astNode) @@ -436,13 +442,13 @@ func ValidateMultiple(queries []string) error { // Tokenize tokens, err := tkz.Tokenize([]byte(sql)) if err != nil { - return fmt.Errorf("query %d: %w", i, err) + return 
fmt.Errorf("query %d: %w: %w", i, ErrTokenize, err) } // Parse directly from model tokens _, err = p.ParseFromModelTokens(tokens) if err != nil { - return fmt.Errorf("query %d: %w", i, err) + return fmt.Errorf("query %d: %w: %w", i, ErrSyntax, err) } } @@ -599,13 +605,21 @@ func ParseWithRecovery(sql string) ([]ast.Statement, []error) { tokens, err := tkz.Tokenize([]byte(sql)) if err != nil { - return nil, []error{fmt.Errorf("tokenization failed: %w", err)} + return nil, []error{fmt.Errorf("%w: %w", ErrTokenize, err)} } p := parser.GetParser() defer parser.PutParser(p) - return p.ParseWithRecoveryFromModelTokens(tokens) + stmts, recoveryErrs := p.ParseWithRecoveryFromModelTokens(tokens) + if len(recoveryErrs) > 0 { + wrapped := make([]error, len(recoveryErrs)) + for i, e := range recoveryErrs { + wrapped[i] = fmt.Errorf("%w: %w", ErrSyntax, e) + } + return stmts, wrapped + } + return stmts, nil } // ParseWithDialect tokenizes and parses SQL using a specific SQL dialect for @@ -638,8 +652,17 @@ func ParseWithRecovery(sql string) ([]ast.Statement, []error) { // ast, err := gosqlx.ParseWithDialect(sql, keywords.DialectPostgreSQL) // // Returns an error if the dialect is unknown or if SQL is syntactically invalid. +// Errors are wrapped with gosqlx sentinel errors (ErrUnsupportedDialect, ErrSyntax, +// ErrTokenize) so callers can match via errors.Is. 
func ParseWithDialect(sql string, dialect keywords.SQLDialect) (*ast.AST, error) { - return parser.ParseWithDialect(sql, dialect) + if !keywords.IsValidDialect(string(dialect)) { + return nil, fmt.Errorf("%w: %q", ErrUnsupportedDialect, dialect) + } + astNode, err := parser.ParseWithDialect(sql, dialect) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrSyntax, err) + } + return astNode, nil } // Normalize parses sql, replaces all literal values (strings, numbers, booleans, diff --git a/pkg/gosqlx/legacy_sentinels_test.go b/pkg/gosqlx/legacy_sentinels_test.go new file mode 100644 index 00000000..0be5edd9 --- /dev/null +++ b/pkg/gosqlx/legacy_sentinels_test.go @@ -0,0 +1,294 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gosqlx + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" +) + +// invalidSyntaxSQL is deliberately malformed in a way the parser rejects +// after tokenization succeeds. Used to exercise the ErrSyntax wrapping path. +const invalidSyntaxSQL = "SELECT FROM WHERE" + +// invalidTokenizeSQL is an unterminated string literal — the tokenizer +// reports this before parsing begins, so we get an ErrTokenize wrap. +const invalidTokenizeSQL = "SELECT 'unterminated" + +// TestLegacy_ErrSyntax_Wrapping asserts every legacy Parse-family function +// wraps parser/syntax failures with ErrSyntax so callers can match on it. 
+func TestLegacy_ErrSyntax_Wrapping(t *testing.T) { + cases := []struct { + name string + run func() error + }{ + {"Parse", func() error { + _, err := Parse(invalidSyntaxSQL) + return err + }}, + {"ParseWithContext", func() error { + _, err := ParseWithContext(context.Background(), invalidSyntaxSQL) + return err + }}, + {"ParseWithTimeout", func() error { + _, err := ParseWithTimeout(invalidSyntaxSQL, time.Second) + return err + }}, + {"ParseBytes", func() error { + _, err := ParseBytes([]byte(invalidSyntaxSQL)) + return err + }}, + {"ParseMultiple", func() error { + _, err := ParseMultiple([]string{invalidSyntaxSQL}) + return err + }}, + {"Validate", func() error { + return Validate(invalidSyntaxSQL) + }}, + {"ValidateMultiple", func() error { + return ValidateMultiple([]string{invalidSyntaxSQL}) + }}, + {"Format", func() error { + _, err := Format(invalidSyntaxSQL, DefaultFormatOptions()) + return err + }}, + {"ParseWithDialect", func() error { + _, err := ParseWithDialect(invalidSyntaxSQL, keywords.DialectPostgreSQL) + return err + }}, + {"ParseWithRecovery", func() error { + _, errs := ParseWithRecovery(invalidSyntaxSQL) + if len(errs) == 0 { + return nil + } + return errs[0] + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.run() + if err == nil { + t.Fatalf("%s(%q) returned nil error; expected ErrSyntax-wrapped", tc.name, invalidSyntaxSQL) + } + if !errors.Is(err, ErrSyntax) { + t.Errorf("errors.Is(err, ErrSyntax) = false for %s; err = %v", tc.name, err) + } + }) + } +} + +// TestLegacy_ErrTokenize_Wrapping asserts tokenization failures are wrapped +// with ErrTokenize. Validate is intentionally excluded here because its +// fast-path surfaces the same failure under ErrSyntax (parser-level fallthrough). 
+func TestLegacy_ErrTokenize_Wrapping(t *testing.T) { + cases := []struct { + name string + run func() error + }{ + {"Parse", func() error { + _, err := Parse(invalidTokenizeSQL) + return err + }}, + {"ParseWithContext", func() error { + _, err := ParseWithContext(context.Background(), invalidTokenizeSQL) + return err + }}, + {"ParseWithTimeout", func() error { + _, err := ParseWithTimeout(invalidTokenizeSQL, time.Second) + return err + }}, + {"ParseBytes", func() error { + _, err := ParseBytes([]byte(invalidTokenizeSQL)) + return err + }}, + {"ParseMultiple", func() error { + _, err := ParseMultiple([]string{invalidTokenizeSQL}) + return err + }}, + {"ValidateMultiple", func() error { + return ValidateMultiple([]string{invalidTokenizeSQL}) + }}, + {"ParseWithRecovery", func() error { + _, errs := ParseWithRecovery(invalidTokenizeSQL) + if len(errs) == 0 { + return nil + } + return errs[0] + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.run() + if err == nil { + t.Fatalf("%s(%q) returned nil error; expected ErrTokenize-wrapped", tc.name, invalidTokenizeSQL) + } + // Some functions (e.g. ParseWithRecovery, ParseMultiple with a + // tokenizer that surfaces the failure during parsing) may classify + // under ErrSyntax instead. Accept either to avoid coupling tests + // to the exact failure layer — the important invariant is that + // ONE of the two sentinels always matches. + if !errors.Is(err, ErrTokenize) && !errors.Is(err, ErrSyntax) { + t.Errorf("neither ErrTokenize nor ErrSyntax matched for %s; err = %v", tc.name, err) + } + }) + } +} + +// TestLegacy_ErrTimeout_Wrapping asserts context-deadline failures are +// wrapped with ErrTimeout. +func TestLegacy_ErrTimeout_Wrapping(t *testing.T) { + // Build an already-expired context so the function fails fast on the + // context check without having to race against parsing time. 
+ expired, cancel := context.WithDeadline(context.Background(), time.Unix(0, 0)) + defer cancel() + + cases := []struct { + name string + run func() error + }{ + {"ParseWithContext", func() error { + _, err := ParseWithContext(expired, "SELECT 1") + return err + }}, + {"ParseWithTimeout", func() error { + // Nanosecond timeout effectively expires immediately. + _, err := ParseWithTimeout("SELECT 1", time.Nanosecond) + return err + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.run() + if err == nil { + t.Fatalf("%s returned nil error; expected ErrTimeout-wrapped", tc.name) + } + if !errors.Is(err, ErrTimeout) { + t.Errorf("errors.Is(err, ErrTimeout) = false for %s; err = %v", tc.name, err) + } + // The underlying context.DeadlineExceeded should also remain + // reachable through the wrap chain. + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("errors.Is(err, context.DeadlineExceeded) = false for %s; err = %v", tc.name, err) + } + }) + } +} + +// TestLegacy_ErrUnsupportedDialect asserts ParseWithDialect returns +// ErrUnsupportedDialect for an unknown dialect name. +func TestLegacy_ErrUnsupportedDialect(t *testing.T) { + _, err := ParseWithDialect("SELECT 1", keywords.SQLDialect("fakedialect")) + if err == nil { + t.Fatalf("ParseWithDialect(fakedialect) returned nil error") + } + if !errors.Is(err, ErrUnsupportedDialect) { + t.Errorf("errors.Is(err, ErrUnsupportedDialect) = false; err = %v", err) + } + // The error message should mention the offending dialect for debuggability. + if !strings.Contains(err.Error(), "fakedialect") { + t.Errorf("error message does not mention dialect: %q", err.Error()) + } +} + +// TestLegacy_MustParse_PanicsWithWrappedError asserts MustParse panics with +// an error value (not a string) so recover() sites can use errors.Is/As to +// classify the failure. 
+func TestLegacy_MustParse_PanicsWithWrappedError(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatal("MustParse did not panic on invalid SQL") + } + err, ok := r.(error) + if !ok { + t.Fatalf("MustParse panicked with non-error %T: %v", r, r) + } + if !errors.Is(err, ErrSyntax) { + t.Errorf("panic value does not wrap ErrSyntax: %v", err) + } + }() + _ = MustParse(invalidSyntaxSQL) +} + +// TestLegacy_ParseTree_ParityWithParse asserts that the same invalid SQL +// produces the same sentinel match through both the legacy Parse entry point +// and the new ParseTree entry point. This protects against divergence where +// one surface wraps and the other does not. +func TestLegacy_ParseTree_ParityWithParse(t *testing.T) { + cases := []struct { + name string + sql string + sentinel error + }{ + {"syntax", invalidSyntaxSQL, ErrSyntax}, + {"tokenize", invalidTokenizeSQL, ErrTokenize}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, legacyErr := Parse(tc.sql) + _, treeErr := ParseTree(context.Background(), tc.sql) + if legacyErr == nil || treeErr == nil { + t.Fatalf("expected both Parse and ParseTree to fail; legacy=%v tree=%v", legacyErr, treeErr) + } + // Parity: if the sentinel matches on one surface, it must match on + // the other. Accept ErrSyntax as a broader fallback for the + // tokenize case because one of the two surfaces may classify at + // the parser level (see TestLegacy_ErrTokenize_Wrapping). + legacyMatch := errors.Is(legacyErr, tc.sentinel) || errors.Is(legacyErr, ErrSyntax) + treeMatch := errors.Is(treeErr, tc.sentinel) || errors.Is(treeErr, ErrSyntax) + if legacyMatch != treeMatch { + t.Errorf("sentinel parity broken for %q: legacy=%v tree=%v", tc.sql, legacyErr, treeErr) + } + }) + } +} + +// TestLegacy_ValidSQL_NoError is a smoke test: wrapping must not break the +// happy path. Valid SQL through every legacy function must still return nil. 
+func TestLegacy_ValidSQL_NoError(t *testing.T) { + const okSQL = "SELECT 1" + if _, err := Parse(okSQL); err != nil { + t.Errorf("Parse(valid) = %v", err) + } + if _, err := ParseWithContext(context.Background(), okSQL); err != nil { + t.Errorf("ParseWithContext(valid) = %v", err) + } + if _, err := ParseWithTimeout(okSQL, time.Second); err != nil { + t.Errorf("ParseWithTimeout(valid) = %v", err) + } + if _, err := ParseBytes([]byte(okSQL)); err != nil { + t.Errorf("ParseBytes(valid) = %v", err) + } + if _, err := ParseMultiple([]string{okSQL, okSQL}); err != nil { + t.Errorf("ParseMultiple(valid) = %v", err) + } + if err := Validate(okSQL); err != nil { + t.Errorf("Validate(valid) = %v", err) + } + if err := ValidateMultiple([]string{okSQL}); err != nil { + t.Errorf("ValidateMultiple(valid) = %v", err) + } + if _, err := Format(okSQL, DefaultFormatOptions()); err != nil { + t.Errorf("Format(valid) = %v", err) + } + if _, err := ParseWithDialect(okSQL, keywords.DialectPostgreSQL); err != nil { + t.Errorf("ParseWithDialect(valid) = %v", err) + } +} diff --git a/pkg/gosqlx/options.go b/pkg/gosqlx/options.go index f4a43a7a..f582e220 100644 --- a/pkg/gosqlx/options.go +++ b/pkg/gosqlx/options.go @@ -55,6 +55,12 @@ type parseOptions struct { // recover enables error-recovery parsing — returns partial results and all // collected diagnostics rather than stopping at the first error. recover bool + + // maxBytes caps the number of bytes ParseReader / ParseReaderMultiple will + // read from an io.Reader. Zero means unbounded (backward-compatible default). + // Inputs that exceed the cap are rejected with ErrTooLarge before any + // parsing work begins. + maxBytes int64 } // defaultParseOptions returns the baseline configuration used when no options @@ -132,6 +138,27 @@ func WithTimeout(d time.Duration) Option { } } +// WithMaxBytes caps the number of bytes ParseReader and ParseReaderMultiple +// will read from an io.Reader. 
A zero or negative value disables the cap and +// preserves the original unbounded behaviour (the default). +// +// When the cap is exceeded the reader entry points abort before any parsing +// work is attempted and return an error that satisfies +// errors.Is(err, ErrTooLarge). The reader is drained only up to maxBytes+1 +// bytes, so hostile 100 MB inputs will not cause a ~2x allocation spike. +// +// Example — reject inputs larger than 1 MiB: +// +// tree, err := gosqlx.ParseReader(ctx, r, gosqlx.WithMaxBytes(1<<20)) +// if errors.Is(err, gosqlx.ErrTooLarge) { +// // surface a 413-style error to the caller +// } +func WithMaxBytes(n int64) Option { + return func(o *parseOptions) { + o.maxBytes = n + } +} + // WithRecovery enables error-recovery parsing. When set, the parser // synchronizes after errors and continues, returning partial statements and // the full list of diagnostics rather than failing at the first error. diff --git a/pkg/gosqlx/testing/demo_usage_test.go b/pkg/gosqlx/testing/demo_usage_test.go index 33fe6167..d0284b72 100644 --- a/pkg/gosqlx/testing/demo_usage_test.go +++ b/pkg/gosqlx/testing/demo_usage_test.go @@ -109,14 +109,16 @@ func TestDemo_StatementTypes(t *testing.T) { // Demo: Testing error conditions func TestDemo_ErrorTesting(t *testing.T) { - // Test that specific errors are produced + // Test that specific errors are produced. + // Post-v1.15 errors are wrapped with the gosqlx.ErrSyntax sentinel, which + // renders as "syntax error" in the message chain. 
gosqlxtesting.AssertErrorContains(t, "SELECT FROM WHERE", - "parsing") + "syntax error") gosqlxtesting.AssertErrorContains(t, "INVALID SYNTAX HERE", - "parsing") + "syntax error") } // Demo: Using RequireParse for custom assertions diff --git a/pkg/gosqlx/testing/example_test.go b/pkg/gosqlx/testing/example_test.go index 7f3d081f..c785bb62 100644 --- a/pkg/gosqlx/testing/example_test.go +++ b/pkg/gosqlx/testing/example_test.go @@ -157,7 +157,7 @@ func ExampleAssertErrorContains() { // Test that specific error messages are produced gosqlxtesting.AssertErrorContains(t, "SELECT FROM WHERE", - "parsing") + "syntax error") // Test for tokenization errors gosqlxtesting.AssertErrorContains(t, @@ -206,7 +206,7 @@ func Example_comprehensiveTest() { // Test invalid query variations gosqlxtesting.AssertInvalidSQL(t, "SELECT FROM users WHERE") - gosqlxtesting.AssertErrorContains(t, "SELECT * FROM", "parsing") + gosqlxtesting.AssertErrorContains(t, "SELECT * FROM", "syntax error") } // Example_windowFunctions demonstrates testing window function queries. diff --git a/pkg/gosqlx/testing/testing_test.go b/pkg/gosqlx/testing/testing_test.go index 06edbc7c..ae5f4635 100644 --- a/pkg/gosqlx/testing/testing_test.go +++ b/pkg/gosqlx/testing/testing_test.go @@ -399,7 +399,7 @@ func TestAssertErrorContains_Matching(t *testing.T) { mockT := &mockTestingT{} sql := "SELECT FROM WHERE" - result := AssertErrorContains(mockT, sql, "parsing") + result := AssertErrorContains(mockT, sql, "syntax error") if !result { t.Error("AssertErrorContains should return true for matching error") diff --git a/pkg/sql/ast/pool.go b/pkg/sql/ast/pool.go index f341fc24..c1666c1b 100644 --- a/pkg/sql/ast/pool.go +++ b/pkg/sql/ast/pool.go @@ -403,6 +403,24 @@ var ( alterSequencePool = sync.Pool{ New: func() interface{} { return &AlterSequenceStatement{} }, } + + // putExpressionWorkQueuePool recycles the iterative work-queue slice used + // by putExpressionImpl. 
Pre-fix, putExpressionImpl allocated a fresh + // []Expression with cap 32 on every call — that fires 10-100× per parse + // in hot paths (complex SELECTs, deep expression trees), contributing + // measurable alloc-rate and GC pressure to an otherwise zero-copy hot + // path. Pooling the queue reclaims the allocation. + // + // Storing a *[]Expression (not []Expression) avoids the slice-header + // boxing allocation that happens when you store a slice in an interface. + // Callers must write the mutated slice header back to the pointer + // before Put so subsequent Get sees the grown capacity. + putExpressionWorkQueuePool = sync.Pool{ + New: func() interface{} { + s := make([]Expression, 0, 32) + return &s + }, + } ) // NewAST retrieves a new AST container from the pool. @@ -579,1343 +597,23 @@ func releaseStatement(stmt Statement) { PutReplaceStatement(s) case *AlterStatement: PutAlterStatement(s) + // Sequence statements are pooled via NewXxx/ReleaseXxx helpers. + // Without a dispatch here, a CTE or subquery that contained one + // would silently leak it (the stmt type audit found these three + // pooled but un-dispatched; see architect review sprint 2). + case *CreateSequenceStatement: + ReleaseCreateSequenceStatement(s) + case *DropSequenceStatement: + ReleaseDropSequenceStatement(s) + case *AlterSequenceStatement: + ReleaseAlterSequenceStatement(s) + // NOTE: *PragmaStatement is NOT pooled (no sync.Pool declared); + // intentionally no-op. Same for dml.go's *Select/*Insert/*Update/ + // *Delete (legacy unpooled duplicates) — they'd be GC'd naturally. + // If those types ever gain pools, add cases here. } } -// GetInsertStatement gets an InsertStatement from the pool -func GetInsertStatement() *InsertStatement { - return insertStmtPool.Get().(*InsertStatement) -} - -// PutInsertStatement returns an InsertStatement to the pool. 
-// -// Releases every pooled Expression/Statement reachable from the InsertStatement: -// - With (CTEs + nested statements + scalar CTE expressions) -// - Columns -// - Output (SQL Server OUTPUT clause) -// - Values (all rows, all cells) -// - Query (INSERT ... SELECT — the nested QueryExpression) -// - Returning -// - OnConflict.Target, OnConflict.Action.DoUpdate (Column, Value), OnConflict.Action.Where -// - OnDuplicateKey.Updates (Column, Value) -func PutInsertStatement(stmt *InsertStatement) { - if stmt == nil { - return - } - - // ── WITH clause / CTEs ──────────────────────────────────────────── - if stmt.With != nil { - for _, cte := range stmt.With.CTEs { - if cte == nil { - continue - } - releaseStatement(cte.Statement) - cte.Statement = nil - PutExpression(cte.ScalarExpr) - cte.ScalarExpr = nil - } - stmt.With.CTEs = nil - stmt.With = nil - } - - // ── Column list ─────────────────────────────────────────────────── - for i := range stmt.Columns { - PutExpression(stmt.Columns[i]) - stmt.Columns[i] = nil - } - stmt.Columns = stmt.Columns[:0] - - // ── OUTPUT clause (SQL Server) ──────────────────────────────────── - for i := range stmt.Output { - PutExpression(stmt.Output[i]) - stmt.Output[i] = nil - } - stmt.Output = stmt.Output[:0] - - // ── VALUES (multi-row) ──────────────────────────────────────────── - for i := range stmt.Values { - for j := range stmt.Values[i] { - PutExpression(stmt.Values[i][j]) - stmt.Values[i][j] = nil - } - stmt.Values[i] = stmt.Values[i][:0] - } - stmt.Values = stmt.Values[:0] - - // ── Query (INSERT ... SELECT) ───────────────────────────────────── - if stmt.Query != nil { - // Query is a QueryExpression (Statement); dispatch via releaseStatement. 
- releaseStatement(stmt.Query) - stmt.Query = nil - } - - // ── RETURNING ────────────────────────────────────────────────────── - for i := range stmt.Returning { - PutExpression(stmt.Returning[i]) - stmt.Returning[i] = nil - } - stmt.Returning = stmt.Returning[:0] - - // ── ON CONFLICT (PostgreSQL) ────────────────────────────────────── - if stmt.OnConflict != nil { - for i := range stmt.OnConflict.Target { - PutExpression(stmt.OnConflict.Target[i]) - stmt.OnConflict.Target[i] = nil - } - stmt.OnConflict.Target = nil - for i := range stmt.OnConflict.Action.DoUpdate { - PutExpression(stmt.OnConflict.Action.DoUpdate[i].Column) - PutExpression(stmt.OnConflict.Action.DoUpdate[i].Value) - stmt.OnConflict.Action.DoUpdate[i].Column = nil - stmt.OnConflict.Action.DoUpdate[i].Value = nil - } - stmt.OnConflict.Action.DoUpdate = nil - PutExpression(stmt.OnConflict.Action.Where) - stmt.OnConflict.Action.Where = nil - stmt.OnConflict = nil - } - - // ── ON DUPLICATE KEY UPDATE (MySQL) ─────────────────────────────── - if stmt.OnDuplicateKey != nil { - for i := range stmt.OnDuplicateKey.Updates { - PutExpression(stmt.OnDuplicateKey.Updates[i].Column) - PutExpression(stmt.OnDuplicateKey.Updates[i].Value) - stmt.OnDuplicateKey.Updates[i].Column = nil - stmt.OnDuplicateKey.Updates[i].Value = nil - } - stmt.OnDuplicateKey.Updates = nil - stmt.OnDuplicateKey = nil - } - - stmt.TableName = "" - - // Return to pool - insertStmtPool.Put(stmt) -} - -// GetUpdateStatement gets an UpdateStatement from the pool -func GetUpdateStatement() *UpdateStatement { - return updateStmtPool.Get().(*UpdateStatement) -} - -// PutUpdateStatement returns an UpdateStatement to the pool. 
-// -// Releases every pooled Expression/Statement reachable from the UpdateStatement: -// - With (CTEs + nested statements + scalar CTE expressions) -// - Assignments (Column, Value) -// - From (TableReference.Subquery, TableFunc, Pivot, MatchRecognize, TimeTravel, ForSystemTime) -// - Where -// - Returning -func PutUpdateStatement(stmt *UpdateStatement) { - if stmt == nil { - return - } - - // ── WITH clause / CTEs ──────────────────────────────────────────── - if stmt.With != nil { - for _, cte := range stmt.With.CTEs { - if cte == nil { - continue - } - releaseStatement(cte.Statement) - cte.Statement = nil - PutExpression(cte.ScalarExpr) - cte.ScalarExpr = nil - } - stmt.With.CTEs = nil - stmt.With = nil - } - - // ── SET assignments ─────────────────────────────────────────────── - for i := range stmt.Assignments { - PutExpression(stmt.Assignments[i].Column) - PutExpression(stmt.Assignments[i].Value) - stmt.Assignments[i].Column = nil - stmt.Assignments[i].Value = nil - } - stmt.Assignments = stmt.Assignments[:0] - - // ── FROM table references ───────────────────────────────────────── - for i := range stmt.From { - releaseTableReference(&stmt.From[i]) - } - stmt.From = stmt.From[:0] - - // ── WHERE ────────────────────────────────────────────────────────── - PutExpression(stmt.Where) - stmt.Where = nil - - // ── RETURNING ────────────────────────────────────────────────────── - for i := range stmt.Returning { - PutExpression(stmt.Returning[i]) - stmt.Returning[i] = nil - } - stmt.Returning = stmt.Returning[:0] - - // ── Scalars ──────────────────────────────────────────────────────── - stmt.TableName = "" - stmt.Alias = "" - - // Return to pool - updateStmtPool.Put(stmt) -} - -// GetDeleteStatement gets a DeleteStatement from the pool -func GetDeleteStatement() *DeleteStatement { - return deleteStmtPool.Get().(*DeleteStatement) -} - -// PutDeleteStatement returns a DeleteStatement to the pool. 
-// -// Releases every pooled Expression/Statement reachable from the DeleteStatement: -// - With (CTEs + nested statements + scalar CTE expressions) -// - Using (TableReference subqueries, TableFunc, Pivot, MatchRecognize, TimeTravel, ForSystemTime) -// - Where -// - Returning -func PutDeleteStatement(stmt *DeleteStatement) { - if stmt == nil { - return - } - - // ── WITH clause / CTEs ──────────────────────────────────────────── - if stmt.With != nil { - for _, cte := range stmt.With.CTEs { - if cte == nil { - continue - } - releaseStatement(cte.Statement) - cte.Statement = nil - PutExpression(cte.ScalarExpr) - cte.ScalarExpr = nil - } - stmt.With.CTEs = nil - stmt.With = nil - } - - // ── USING table references (PostgreSQL) ─────────────────────────── - for i := range stmt.Using { - releaseTableReference(&stmt.Using[i]) - } - stmt.Using = stmt.Using[:0] - - // ── WHERE ────────────────────────────────────────────────────────── - PutExpression(stmt.Where) - stmt.Where = nil - - // ── RETURNING ────────────────────────────────────────────────────── - for i := range stmt.Returning { - PutExpression(stmt.Returning[i]) - stmt.Returning[i] = nil - } - stmt.Returning = stmt.Returning[:0] - - // ── Scalars ──────────────────────────────────────────────────────── - stmt.TableName = "" - stmt.Alias = "" - - // Return to pool - deleteStmtPool.Put(stmt) -} - -// GetUpdateExpression gets an UpdateExpression from the pool -func GetUpdateExpression() *UpdateExpression { - return updateExprPool.Get().(*UpdateExpression) -} - -// PutUpdateExpression returns an UpdateExpression to the pool -func PutUpdateExpression(expr *UpdateExpression) { - if expr == nil { - return - } - - // Clean up expressions - PutExpression(expr.Column) - PutExpression(expr.Value) - - // Reset fields - expr.Column = nil - expr.Value = nil - - // Return to pool - updateExprPool.Put(expr) -} - -// GetSelectStatement gets a SelectStatement from the pool -func GetSelectStatement() *SelectStatement { - stmt := 
selectStmtPool.Get().(*SelectStatement) - stmt.Columns = stmt.Columns[:0] - stmt.OrderBy = stmt.OrderBy[:0] - return stmt -} - -// PutSelectStatement returns a SelectStatement to the pool. -// -// Uses iterative cleanup via PutExpression to handle deeply nested expressions. -// This function MUST release every pooled Expression/Node reachable from the -// SelectStatement; missing fields cause silent pool leaks that defeat the -// 60-80% memory reduction target and degrade hit-rate below 95%. -// -// Coverage (v1.14.0+ — comprehensive audit): -// - With (CTEs + their nested statements + scalar CTE expressions) -// - Top.Count -// - DistinctOnColumns -// - Columns -// - From (TableReference.Subquery, TableFunc, Pivot.AggregateFunction, MatchRecognize) -// - Joins (Left/Right TableRefs, Condition) -// - ArrayJoin (element Exprs) -// - PrewhereClause -// - Sample (no Expressions, but zeroed for hygiene) -// - Where -// - GroupBy -// - Having -// - Qualify -// - StartWith / ConnectBy.Condition -// - Windows (PartitionBy + OrderBy expressions + FrameClause bounds) -// - OrderBy -// - Fetch / For (no Expression children, just zero) -// - Limit / Offset (*int — no release needed) -func PutSelectStatement(stmt *SelectStatement) { - if stmt == nil { - return - } - - // ── WITH clause / CTEs ──────────────────────────────────────────── - if stmt.With != nil { - for _, cte := range stmt.With.CTEs { - if cte == nil { - continue - } - releaseStatement(cte.Statement) - cte.Statement = nil - PutExpression(cte.ScalarExpr) - cte.ScalarExpr = nil - } - stmt.With.CTEs = nil - stmt.With = nil - } - - // ── TOP clause ───────────────────────────────────────────────────── - if stmt.Top != nil { - PutExpression(stmt.Top.Count) - stmt.Top.Count = nil - stmt.Top = nil - } - - // ── DISTINCT ON columns ──────────────────────────────────────────── - for i := range stmt.DistinctOnColumns { - PutExpression(stmt.DistinctOnColumns[i]) - stmt.DistinctOnColumns[i] = nil - } - stmt.DistinctOnColumns 
= stmt.DistinctOnColumns[:0] - - // ── SELECT list columns ──────────────────────────────────────────── - for i := range stmt.Columns { - PutExpression(stmt.Columns[i]) - stmt.Columns[i] = nil - } - stmt.Columns = stmt.Columns[:0] - - // ── FROM table references (Subquery, TableFunc, Pivot, MatchRecognize) ─ - for i := range stmt.From { - releaseTableReference(&stmt.From[i]) - } - stmt.From = stmt.From[:0] - - // ── JOINs ────────────────────────────────────────────────────────── - for i := range stmt.Joins { - releaseTableReference(&stmt.Joins[i].Left) - releaseTableReference(&stmt.Joins[i].Right) - PutExpression(stmt.Joins[i].Condition) - stmt.Joins[i].Condition = nil - stmt.Joins[i].Type = "" - } - stmt.Joins = stmt.Joins[:0] - - // ── ARRAY JOIN (ClickHouse) ──────────────────────────────────────── - if stmt.ArrayJoin != nil { - for i := range stmt.ArrayJoin.Elements { - PutExpression(stmt.ArrayJoin.Elements[i].Expr) - stmt.ArrayJoin.Elements[i].Expr = nil - stmt.ArrayJoin.Elements[i].Alias = "" - } - stmt.ArrayJoin.Elements = nil - stmt.ArrayJoin = nil - } - - // ── PREWHERE / WHERE / HAVING / QUALIFY / START WITH ─────────────── - PutExpression(stmt.PrewhereClause) - stmt.PrewhereClause = nil - PutExpression(stmt.Where) - stmt.Where = nil - PutExpression(stmt.Having) - stmt.Having = nil - PutExpression(stmt.Qualify) - stmt.Qualify = nil - PutExpression(stmt.StartWith) - stmt.StartWith = nil - - // ── CONNECT BY ───────────────────────────────────────────────────── - if stmt.ConnectBy != nil { - PutExpression(stmt.ConnectBy.Condition) - stmt.ConnectBy.Condition = nil - stmt.ConnectBy = nil - } - - // ── SAMPLE (no expression children, just drop) ───────────────────── - stmt.Sample = nil - - // ── GROUP BY ─────────────────────────────────────────────────────── - for i := range stmt.GroupBy { - PutExpression(stmt.GroupBy[i]) - stmt.GroupBy[i] = nil - } - stmt.GroupBy = stmt.GroupBy[:0] - - // ── WINDOWS (PartitionBy, OrderBy, FrameClause bounds) ───────────── - 
for i := range stmt.Windows { - w := &stmt.Windows[i] - for j := range w.PartitionBy { - PutExpression(w.PartitionBy[j]) - w.PartitionBy[j] = nil - } - w.PartitionBy = w.PartitionBy[:0] - for j := range w.OrderBy { - PutExpression(w.OrderBy[j].Expression) - w.OrderBy[j].Expression = nil - } - w.OrderBy = w.OrderBy[:0] - if w.FrameClause != nil { - PutExpression(w.FrameClause.Start.Value) - w.FrameClause.Start.Value = nil - if w.FrameClause.End != nil { - PutExpression(w.FrameClause.End.Value) - w.FrameClause.End.Value = nil - w.FrameClause.End = nil - } - w.FrameClause = nil - } - w.Name = "" - } - stmt.Windows = stmt.Windows[:0] - - // ── ORDER BY ─────────────────────────────────────────────────────── - for i := range stmt.OrderBy { - PutExpression(stmt.OrderBy[i].Expression) - stmt.OrderBy[i].Expression = nil - } - stmt.OrderBy = stmt.OrderBy[:0] - - // ── LIMIT / OFFSET (*int - no Expression) ────────────────────────── - stmt.Limit = nil - stmt.Offset = nil - - // ── FETCH / FOR (no Expression children) ─────────────────────────── - stmt.Fetch = nil - stmt.For = nil - - // ── Scalars ──────────────────────────────────────────────────────── - stmt.TableName = "" - stmt.Distinct = false - - // Return to pool - selectStmtPool.Put(stmt) -} - -// releaseTableReference releases all pooled Expression/Statement references -// reachable from a TableReference. Zero-copies the TableReference back to a -// clean state suitable for pool reuse. -func releaseTableReference(tr *TableReference) { - if tr == nil { - return - } - // Subquery is itself a *SelectStatement — recurse through the statement - // dispatcher to release every nested pool reference. - if tr.Subquery != nil { - PutSelectStatement(tr.Subquery) - tr.Subquery = nil - } - // TableFunc is a *FunctionCall — release as expression. - if tr.TableFunc != nil { - PutExpression(tr.TableFunc) - tr.TableFunc = nil - } - // Pivot.AggregateFunction is an Expression. 
- if tr.Pivot != nil { - PutExpression(tr.Pivot.AggregateFunction) - tr.Pivot.AggregateFunction = nil - tr.Pivot = nil - } - // Unpivot holds only strings — drop the struct. - tr.Unpivot = nil - // MatchRecognize carries PartitionBy / OrderBy / Measures / Definitions. - if tr.MatchRecognize != nil { - mr := tr.MatchRecognize - for i := range mr.PartitionBy { - PutExpression(mr.PartitionBy[i]) - mr.PartitionBy[i] = nil - } - mr.PartitionBy = mr.PartitionBy[:0] - for i := range mr.OrderBy { - PutExpression(mr.OrderBy[i].Expression) - mr.OrderBy[i].Expression = nil - } - mr.OrderBy = mr.OrderBy[:0] - for i := range mr.Measures { - PutExpression(mr.Measures[i].Expr) - mr.Measures[i].Expr = nil - mr.Measures[i].Alias = "" - } - mr.Measures = mr.Measures[:0] - for i := range mr.Definitions { - PutExpression(mr.Definitions[i].Condition) - mr.Definitions[i].Condition = nil - mr.Definitions[i].Name = "" - } - mr.Definitions = mr.Definitions[:0] - tr.MatchRecognize = nil - } - // TimeTravel carries Named map of Expressions + Chained clauses. - if tr.TimeTravel != nil { - releaseTimeTravelClause(tr.TimeTravel) - tr.TimeTravel = nil - } - // ForSystemTime carries Point/Start/End expressions. 
- if tr.ForSystemTime != nil { - PutExpression(tr.ForSystemTime.Point) - PutExpression(tr.ForSystemTime.Start) - PutExpression(tr.ForSystemTime.End) - tr.ForSystemTime.Point = nil - tr.ForSystemTime.Start = nil - tr.ForSystemTime.End = nil - tr.ForSystemTime = nil - } - tr.Name = "" - tr.Alias = "" - tr.Lateral = false - tr.Final = false - tr.TableHints = nil -} - -// GetIdentifier gets an Identifier from the pool -func GetIdentifier() *Identifier { - return identifierPool.Get().(*Identifier) -} - -// PutIdentifier returns an Identifier to the pool -func PutIdentifier(ident *Identifier) { - if ident == nil { - return - } - ident.Name = "" - identifierPool.Put(ident) -} - -// GetBinaryExpression gets a BinaryExpression from the pool -func GetBinaryExpression() *BinaryExpression { - return binaryExprPool.Get().(*BinaryExpression) -} - -// PutBinaryExpression returns a BinaryExpression to the pool -func PutBinaryExpression(expr *BinaryExpression) { - if expr == nil { - return - } - PutExpression(expr.Left) - PutExpression(expr.Right) - expr.Left = nil - expr.Right = nil - expr.Operator = "" - binaryExprPool.Put(expr) -} - -// GetExpressionSlice gets a slice of Expression from the pool -func GetExpressionSlice() *[]Expression { - slice := exprSlicePool.Get().(*[]Expression) - *slice = (*slice)[:0] - return slice -} - -// PutExpressionSlice returns a slice of Expression to the pool -func PutExpressionSlice(slice *[]Expression) { - if slice == nil { - return - } - for i := range *slice { - PutExpression((*slice)[i]) - (*slice)[i] = nil - } - exprSlicePool.Put(slice) -} - -// GetLiteralValue gets a LiteralValue from the pool -func GetLiteralValue() *LiteralValue { - return literalValuePool.Get().(*LiteralValue) -} - -// PutLiteralValue returns a LiteralValue to the pool -func PutLiteralValue(lit *LiteralValue) { - if lit == nil { - return - } - - // Reset fields (Value is interface{}, use nil as zero value) - lit.Value = nil - lit.Type = "" - - // Return to pool - 
literalValuePool.Put(lit) -} - -// PutExpression returns any Expression to the appropriate pool with iterative cleanup. -// -// PutExpression is the primary function for returning expression nodes to their -// respective pools. It handles all expression types and uses iterative cleanup -// to prevent stack overflow with deeply nested expression trees. -// -// Key Features: -// - Supports all expression types (30+ pooled types) -// - Iterative cleanup algorithm (no recursion limits) -// - Prevents stack overflow for deeply nested expressions -// - Work queue size limits (MaxWorkQueueSize = 1000) -// - Nil-safe (ignores nil expressions) -// -// Supported Expression Types: -// - Identifier, LiteralValue, AliasedExpression -// - BinaryExpression, UnaryExpression -// - FunctionCall, CaseExpression -// - BetweenExpression, InExpression -// - SubqueryExpression, ExistsExpression, AnyExpression, AllExpression -// - CastExpression, ExtractExpression, PositionExpression, SubstringExpression -// - ListExpression -// -// Iterative Cleanup Algorithm: -// 1. Use work queue instead of recursion -// 2. Process expressions breadth-first -// 3. Collect child expressions and add to queue -// 4. Clean and return to pool -// 5. Limit queue size to prevent memory exhaustion -// -// Parameters: -// - expr: Expression to return to pool (nil-safe) -// -// Usage Pattern: -// -// expr := ast.GetBinaryExpression() -// defer ast.PutExpression(expr) -// -// // Build expression tree... 
-// -// Example - Cleaning up complex expression: -// -// // Build: (age > 18 AND status = 'active') OR (role = 'admin') -// expr := &ast.BinaryExpression{ -// Left: &ast.BinaryExpression{ -// Left: &ast.BinaryExpression{...}, -// Operator: "AND", -// Right: &ast.BinaryExpression{...}, -// }, -// Operator: "OR", -// Right: &ast.BinaryExpression{...}, -// } -// -// // Cleanup all nested expressions -// ast.PutExpression(expr) // Handles entire tree iteratively -// -// Performance Characteristics: -// - O(n) time complexity where n = number of nodes -// - O(min(n, MaxWorkQueueSize)) space complexity -// - No stack overflow risk regardless of nesting depth -// - Efficient for both shallow and deeply nested expressions -// -// Safety Guarantees: -// - Thread-safe (uses sync.Pool internally) -// - Nil-safe (gracefully handles nil expressions) -// - Stack-safe (iterative, not recursive) -// - Memory-safe (work queue size limits) -// -// IMPORTANT: This function should be used for all expression cleanup. -// Direct pool returns (e.g., binaryExprPool.Put()) bypass the iterative -// cleanup and may leave child expressions unreleased. -// -// See also: GetBinaryExpression(), GetFunctionCall(), GetIdentifier() -func PutExpression(expr Expression) { - if expr == nil { - return - } - putExpressionImpl(expr, 0) -} - -// putExpressionImpl is the internal driver for PutExpression. The depth -// parameter tracks recursive re-entries from the work-queue overflow path -// to prevent stack overflow on pathologically deep ASTs. 
-func putExpressionImpl(expr Expression, depth int) { - if expr == nil { - return - } - - // Use a work queue for iterative cleanup instead of recursion - workQueue := make([]Expression, 0, 32) - workQueue = append(workQueue, expr) - - processed := 0 - for len(workQueue) > 0 && processed < MaxWorkQueueSize { - // Pop from queue - current := workQueue[len(workQueue)-1] - workQueue = workQueue[:len(workQueue)-1] - processed++ - - if current == nil { - continue - } - - // Process and collect child expressions - switch e := current.(type) { - case *Identifier: - e.Name = "" - identifierPool.Put(e) - - case *BinaryExpression: - if e.Left != nil { - workQueue = append(workQueue, e.Left) - } - if e.Right != nil { - workQueue = append(workQueue, e.Right) - } - e.Left = nil - e.Right = nil - e.Operator = "" - binaryExprPool.Put(e) - - case *LiteralValue: - e.Value = nil - e.Type = "" - literalValuePool.Put(e) - - case *FunctionCall: - for i := range e.Arguments { - if e.Arguments[i] != nil { - workQueue = append(workQueue, e.Arguments[i]) - } - e.Arguments[i] = nil - } - e.Arguments = e.Arguments[:0] - e.Name = "" - e.Over = nil - e.Distinct = false - e.Filter = nil - functionCallPool.Put(e) - - case *CaseExpression: - if e.Value != nil { - workQueue = append(workQueue, e.Value) - } - for i := range e.WhenClauses { - if e.WhenClauses[i].Condition != nil { - workQueue = append(workQueue, e.WhenClauses[i].Condition) - } - if e.WhenClauses[i].Result != nil { - workQueue = append(workQueue, e.WhenClauses[i].Result) - } - } - if e.ElseClause != nil { - workQueue = append(workQueue, e.ElseClause) - } - e.Value = nil - e.WhenClauses = e.WhenClauses[:0] - e.ElseClause = nil - caseExprPool.Put(e) - - case *BetweenExpression: - if e.Expr != nil { - workQueue = append(workQueue, e.Expr) - } - if e.Lower != nil { - workQueue = append(workQueue, e.Lower) - } - if e.Upper != nil { - workQueue = append(workQueue, e.Upper) - } - e.Expr = nil - e.Lower = nil - e.Upper = nil - e.Not = false 
- betweenExprPool.Put(e) - - case *InExpression: - if e.Expr != nil { - workQueue = append(workQueue, e.Expr) - } - for i := range e.List { - if e.List[i] != nil { - workQueue = append(workQueue, e.List[i]) - } - e.List[i] = nil - } - e.Expr = nil - e.List = e.List[:0] - e.Subquery = nil - e.Not = false - inExprPool.Put(e) - - case *SubqueryExpression: - e.Subquery = nil - subqueryExprPool.Put(e) - - case *CastExpression: - if e.Expr != nil { - workQueue = append(workQueue, e.Expr) - } - e.Expr = nil - e.Type = "" - castExprPool.Put(e) - - case *IntervalExpression: - e.Value = "" - intervalExprPool.Put(e) - - case *ArraySubscriptExpression: - if e.Array != nil { - workQueue = append(workQueue, e.Array) - } - for i := range e.Indices { - if e.Indices[i] != nil { - workQueue = append(workQueue, e.Indices[i]) - } - } - e.Array = nil - e.Indices = e.Indices[:0] - arraySubscriptExprPool.Put(e) - - case *ArraySliceExpression: - if e.Array != nil { - workQueue = append(workQueue, e.Array) - } - if e.Start != nil { - workQueue = append(workQueue, e.Start) - } - if e.End != nil { - workQueue = append(workQueue, e.End) - } - e.Array = nil - e.Start = nil - e.End = nil - arraySliceExprPool.Put(e) - - case *ExistsExpression: - e.Subquery = nil - existsExprPool.Put(e) - - case *AnyExpression: - if e.Expr != nil { - workQueue = append(workQueue, e.Expr) - } - e.Expr = nil - e.Subquery = nil - e.Operator = "" - anyExprPool.Put(e) - - case *AllExpression: - if e.Expr != nil { - workQueue = append(workQueue, e.Expr) - } - e.Expr = nil - e.Subquery = nil - e.Operator = "" - allExprPool.Put(e) - - case *ListExpression: - for i := range e.Values { - if e.Values[i] != nil { - workQueue = append(workQueue, e.Values[i]) - } - e.Values[i] = nil - } - e.Values = e.Values[:0] - listExprPool.Put(e) - - case *TupleExpression: - for i := range e.Expressions { - if e.Expressions[i] != nil { - workQueue = append(workQueue, e.Expressions[i]) - } - e.Expressions[i] = nil - } - e.Expressions = 
e.Expressions[:0] - tupleExprPool.Put(e) - - case *ArrayConstructorExpression: - for i := range e.Elements { - if e.Elements[i] != nil { - workQueue = append(workQueue, e.Elements[i]) - } - e.Elements[i] = nil - } - e.Elements = e.Elements[:0] - e.Subquery = nil - arrayConstructorPool.Put(e) - - case *UnaryExpression: - if e.Expr != nil { - workQueue = append(workQueue, e.Expr) - } - e.Expr = nil - e.Operator = 0 // UnaryOperator is int type - unaryExprPool.Put(e) - - case *ExtractExpression: - if e.Source != nil { - workQueue = append(workQueue, e.Source) - } - e.Field = "" - e.Source = nil - extractExprPool.Put(e) - - case *PositionExpression: - if e.Substr != nil { - workQueue = append(workQueue, e.Substr) - } - if e.Str != nil { - workQueue = append(workQueue, e.Str) - } - e.Substr = nil - e.Str = nil - positionExprPool.Put(e) - - case *SubstringExpression: - if e.Str != nil { - workQueue = append(workQueue, e.Str) - } - if e.Start != nil { - workQueue = append(workQueue, e.Start) - } - if e.Length != nil { - workQueue = append(workQueue, e.Length) - } - e.Str = nil - e.Start = nil - e.Length = nil - substringExprPool.Put(e) - - case *AliasedExpression: - if e.Expr != nil { - workQueue = append(workQueue, e.Expr) - } - e.Expr = nil - e.Alias = "" - aliasedExprPool.Put(e) - - // Default case - expression type not pooled, just ignore - default: - // Unknown expression type - no pool available - } - } - - // OVERFLOW DRAIN: if we hit the work-queue cap, there are still pooled - // nodes in workQueue that would otherwise leak. Fall back to a recursive - // drain, depth-limited to prevent stack overflow on deeply nested trees. - // Each recursive call starts its own fresh work queue of up to - // MaxWorkQueueSize, so the recursion depth is effectively - // ceil(total_nodes / MaxWorkQueueSize). MaxCleanupDepth = 100 bounds this - // at ~10_000_000 total nodes in an AST — far beyond any real SQL query. 
- if len(workQueue) > 0 { - atomic.AddUint64(&poolLeakCount, uint64(len(workQueue))) - if depth < MaxCleanupDepth { - for _, remaining := range workQueue { - putExpressionImpl(remaining, depth+1) - } - } - // If depth exceeded MaxCleanupDepth we accept the leak rather than - // blow the stack; poolLeakCount records the truncation for diagnostics. - } -} - -// GetFunctionCall gets a FunctionCall from the pool -func GetFunctionCall() *FunctionCall { - fc := functionCallPool.Get().(*FunctionCall) - fc.Arguments = fc.Arguments[:0] - return fc -} - -// PutFunctionCall returns a FunctionCall to the pool -func PutFunctionCall(fc *FunctionCall) { - if fc == nil { - return - } - for i := range fc.Arguments { - PutExpression(fc.Arguments[i]) - fc.Arguments[i] = nil - } - fc.Arguments = fc.Arguments[:0] - fc.Name = "" - fc.Over = nil - fc.Distinct = false - fc.Filter = nil - functionCallPool.Put(fc) -} - -// GetCaseExpression gets a CaseExpression from the pool -func GetCaseExpression() *CaseExpression { - ce := caseExprPool.Get().(*CaseExpression) - ce.WhenClauses = ce.WhenClauses[:0] - return ce -} - -// PutCaseExpression returns a CaseExpression to the pool -func PutCaseExpression(ce *CaseExpression) { - if ce == nil { - return - } - PutExpression(ce.Value) - ce.Value = nil - for i := range ce.WhenClauses { - PutExpression(ce.WhenClauses[i].Condition) - PutExpression(ce.WhenClauses[i].Result) - } - ce.WhenClauses = ce.WhenClauses[:0] - PutExpression(ce.ElseClause) - ce.ElseClause = nil - caseExprPool.Put(ce) -} - -// GetBetweenExpression gets a BetweenExpression from the pool -func GetBetweenExpression() *BetweenExpression { - return betweenExprPool.Get().(*BetweenExpression) -} - -// PutBetweenExpression returns a BetweenExpression to the pool -func PutBetweenExpression(be *BetweenExpression) { - if be == nil { - return - } - PutExpression(be.Expr) - PutExpression(be.Lower) - PutExpression(be.Upper) - be.Expr = nil - be.Lower = nil - be.Upper = nil - be.Not = false - 
betweenExprPool.Put(be) -} - -// GetInExpression gets an InExpression from the pool -func GetInExpression() *InExpression { - ie := inExprPool.Get().(*InExpression) - ie.List = ie.List[:0] - return ie -} - -// PutInExpression returns an InExpression to the pool -func PutInExpression(ie *InExpression) { - if ie == nil { - return - } - PutExpression(ie.Expr) - ie.Expr = nil - for i := range ie.List { - PutExpression(ie.List[i]) - ie.List[i] = nil - } - ie.List = ie.List[:0] - ie.Subquery = nil - ie.Not = false - inExprPool.Put(ie) -} - -// GetTupleExpression gets a TupleExpression from the pool -func GetTupleExpression() *TupleExpression { - te := tupleExprPool.Get().(*TupleExpression) - te.Expressions = te.Expressions[:0] - return te -} - -// PutTupleExpression returns a TupleExpression to the pool -func PutTupleExpression(te *TupleExpression) { - if te == nil { - return - } - for i := range te.Expressions { - PutExpression(te.Expressions[i]) - te.Expressions[i] = nil - } - te.Expressions = te.Expressions[:0] - tupleExprPool.Put(te) -} - -// GetArrayConstructor gets an ArrayConstructorExpression from the pool -func GetArrayConstructor() *ArrayConstructorExpression { - ac := arrayConstructorPool.Get().(*ArrayConstructorExpression) - ac.Elements = ac.Elements[:0] - ac.Subquery = nil - return ac -} - -// PutArrayConstructor returns an ArrayConstructorExpression to the pool -func PutArrayConstructor(ac *ArrayConstructorExpression) { - if ac == nil { - return - } - for i := range ac.Elements { - PutExpression(ac.Elements[i]) - ac.Elements[i] = nil - } - ac.Elements = ac.Elements[:0] - ac.Subquery = nil - arrayConstructorPool.Put(ac) -} - -// GetSubqueryExpression gets a SubqueryExpression from the pool -func GetSubqueryExpression() *SubqueryExpression { - return subqueryExprPool.Get().(*SubqueryExpression) -} - -// PutSubqueryExpression returns a SubqueryExpression to the pool -func PutSubqueryExpression(se *SubqueryExpression) { - if se == nil { - return - } - 
se.Subquery = nil - subqueryExprPool.Put(se) -} - -// GetCastExpression gets a CastExpression from the pool -func GetCastExpression() *CastExpression { - return castExprPool.Get().(*CastExpression) -} - -// PutCastExpression returns a CastExpression to the pool -func PutCastExpression(ce *CastExpression) { - if ce == nil { - return - } - PutExpression(ce.Expr) - ce.Expr = nil - ce.Type = "" - castExprPool.Put(ce) -} - -// GetIntervalExpression gets an IntervalExpression from the pool -func GetIntervalExpression() *IntervalExpression { - return intervalExprPool.Get().(*IntervalExpression) -} - -// PutIntervalExpression returns an IntervalExpression to the pool -func PutIntervalExpression(ie *IntervalExpression) { - if ie == nil { - return - } - ie.Value = "" - intervalExprPool.Put(ie) -} - -// GetAliasedExpression retrieves an AliasedExpression from the pool -func GetAliasedExpression() *AliasedExpression { - return aliasedExprPool.Get().(*AliasedExpression) -} - -// PutAliasedExpression returns an AliasedExpression to the pool -func PutAliasedExpression(ae *AliasedExpression) { - if ae == nil { - return - } - PutExpression(ae.Expr) - ae.Expr = nil - ae.Alias = "" - aliasedExprPool.Put(ae) -} - -// GetArraySubscriptExpression gets an ArraySubscriptExpression from the pool -func GetArraySubscriptExpression() *ArraySubscriptExpression { - return arraySubscriptExprPool.Get().(*ArraySubscriptExpression) -} - -// PutArraySubscriptExpression returns an ArraySubscriptExpression to the pool -func PutArraySubscriptExpression(ase *ArraySubscriptExpression) { - if ase == nil { - return - } - // Clean up array expression - if ase.Array != nil { - PutExpression(ase.Array) - ase.Array = nil - } - // Clean up indices - for i := range ase.Indices { - if ase.Indices[i] != nil { - PutExpression(ase.Indices[i]) - } - } - ase.Indices = ase.Indices[:0] // Clear slice but keep capacity - arraySubscriptExprPool.Put(ase) -} - -// GetArraySliceExpression gets an ArraySliceExpression from 
the pool -func GetArraySliceExpression() *ArraySliceExpression { - return arraySliceExprPool.Get().(*ArraySliceExpression) -} - -// PutArraySliceExpression returns an ArraySliceExpression to the pool -func PutArraySliceExpression(ase *ArraySliceExpression) { - if ase == nil { - return - } - // Clean up array expression - if ase.Array != nil { - PutExpression(ase.Array) - ase.Array = nil - } - // Clean up start/end expressions - if ase.Start != nil { - PutExpression(ase.Start) - ase.Start = nil - } - if ase.End != nil { - PutExpression(ase.End) - ase.End = nil - } - arraySliceExprPool.Put(ase) -} - -// ============================================================ -// DDL Statement Pool Functions -// ============================================================ - -// GetCreateTableStatement gets a CreateTableStatement from the pool. -func GetCreateTableStatement() *CreateTableStatement { - stmt := createTableStmtPool.Get().(*CreateTableStatement) - stmt.Columns = stmt.Columns[:0] - stmt.Constraints = stmt.Constraints[:0] - stmt.Inherits = stmt.Inherits[:0] - stmt.Options = stmt.Options[:0] - return stmt -} - -// PutCreateTableStatement returns a CreateTableStatement to the pool. -// It recursively releases any nested expressions (column defaults, check constraints, etc.). 
-func PutCreateTableStatement(stmt *CreateTableStatement) { - if stmt == nil { - return - } - - // Release expressions embedded in column definitions - for i := range stmt.Columns { - for j := range stmt.Columns[i].Constraints { - PutExpression(stmt.Columns[i].Constraints[j].Default) - PutExpression(stmt.Columns[i].Constraints[j].Check) - stmt.Columns[i].Constraints[j].Default = nil - stmt.Columns[i].Constraints[j].Check = nil - stmt.Columns[i].Constraints[j].References = nil - } - stmt.Columns[i].Constraints = stmt.Columns[i].Constraints[:0] - stmt.Columns[i].Name = "" - stmt.Columns[i].Type = "" - } - stmt.Columns = stmt.Columns[:0] - - // Release expressions in table constraints - for i := range stmt.Constraints { - PutExpression(stmt.Constraints[i].Check) - stmt.Constraints[i].Check = nil - stmt.Constraints[i].References = nil - stmt.Constraints[i].Name = "" - stmt.Constraints[i].Type = "" - stmt.Constraints[i].Columns = stmt.Constraints[i].Columns[:0] - } - stmt.Constraints = stmt.Constraints[:0] - - // Release expressions in PartitionBy - if stmt.PartitionBy != nil { - for i, expr := range stmt.PartitionBy.Boundary { - PutExpression(expr) - stmt.PartitionBy.Boundary[i] = nil - } - stmt.PartitionBy.Boundary = stmt.PartitionBy.Boundary[:0] - stmt.PartitionBy.Columns = stmt.PartitionBy.Columns[:0] - stmt.PartitionBy.Type = "" - stmt.PartitionBy = nil - } - - // Release expressions in PartitionDefinitions - for i := range stmt.Partitions { - for j, expr := range stmt.Partitions[i].Values { - PutExpression(expr) - stmt.Partitions[i].Values[j] = nil - } - PutExpression(stmt.Partitions[i].LessThan) - PutExpression(stmt.Partitions[i].From) - PutExpression(stmt.Partitions[i].To) - for j, expr := range stmt.Partitions[i].InValues { - PutExpression(expr) - stmt.Partitions[i].InValues[j] = nil - } - stmt.Partitions[i].Values = stmt.Partitions[i].Values[:0] - stmt.Partitions[i].InValues = stmt.Partitions[i].InValues[:0] - stmt.Partitions[i].LessThan = nil - 
stmt.Partitions[i].From = nil - stmt.Partitions[i].To = nil - stmt.Partitions[i].Name = "" - stmt.Partitions[i].Type = "" - stmt.Partitions[i].Tablespace = "" - } - stmt.Partitions = stmt.Partitions[:0] - - stmt.Inherits = stmt.Inherits[:0] - - for i := range stmt.Options { - stmt.Options[i].Name = "" - stmt.Options[i].Value = "" - } - stmt.Options = stmt.Options[:0] - - // Reset scalar fields - stmt.IfNotExists = false - stmt.Temporary = false - stmt.Name = "" - - createTableStmtPool.Put(stmt) -} - -// GetAlterTableStatement gets an AlterTableStatement from the pool. -func GetAlterTableStatement() *AlterTableStatement { - stmt := alterTableStmtPool.Get().(*AlterTableStatement) - stmt.Actions = stmt.Actions[:0] - return stmt -} - -// PutAlterTableStatement returns an AlterTableStatement to the pool. -// It recursively releases nested expressions in column definitions and constraints. -func PutAlterTableStatement(stmt *AlterTableStatement) { - if stmt == nil { - return - } - - for i := range stmt.Actions { - // Release nested ColumnDef expressions - if stmt.Actions[i].ColumnDef != nil { - for j := range stmt.Actions[i].ColumnDef.Constraints { - PutExpression(stmt.Actions[i].ColumnDef.Constraints[j].Default) - PutExpression(stmt.Actions[i].ColumnDef.Constraints[j].Check) - stmt.Actions[i].ColumnDef.Constraints[j].Default = nil - stmt.Actions[i].ColumnDef.Constraints[j].Check = nil - stmt.Actions[i].ColumnDef.Constraints[j].References = nil - } - stmt.Actions[i].ColumnDef.Constraints = stmt.Actions[i].ColumnDef.Constraints[:0] - stmt.Actions[i].ColumnDef = nil - } - // Release nested TableConstraint expressions - if stmt.Actions[i].Constraint != nil { - PutExpression(stmt.Actions[i].Constraint.Check) - stmt.Actions[i].Constraint.Check = nil - stmt.Actions[i].Constraint = nil - } - stmt.Actions[i].Type = "" - stmt.Actions[i].ColumnName = "" - } - stmt.Actions = stmt.Actions[:0] - stmt.Table = "" - - alterTableStmtPool.Put(stmt) -} - // GetCreateIndexStatement gets a 
CreateIndexStatement from the pool. func GetCreateIndexStatement() *CreateIndexStatement { stmt := createIndexStmtPool.Get().(*CreateIndexStatement) @@ -1950,70 +648,6 @@ func PutCreateIndexStatement(stmt *CreateIndexStatement) { createIndexStmtPool.Put(stmt) } -// GetMergeStatement gets a MergeStatement from the pool. -func GetMergeStatement() *MergeStatement { - stmt := mergeStmtPool.Get().(*MergeStatement) - stmt.WhenClauses = stmt.WhenClauses[:0] - stmt.Output = stmt.Output[:0] - return stmt -} - -// PutMergeStatement returns a MergeStatement to the pool. -// It recursively releases nested expressions in WHEN clauses and OUTPUT. -func PutMergeStatement(stmt *MergeStatement) { - if stmt == nil { - return - } - - // Release OnCondition - PutExpression(stmt.OnCondition) - stmt.OnCondition = nil - - // Release WHEN clause expressions - for i := range stmt.WhenClauses { - if stmt.WhenClauses[i] == nil { - continue - } - PutExpression(stmt.WhenClauses[i].Condition) - stmt.WhenClauses[i].Condition = nil - if stmt.WhenClauses[i].Action != nil { - for j := range stmt.WhenClauses[i].Action.SetClauses { - PutExpression(stmt.WhenClauses[i].Action.SetClauses[j].Value) - stmt.WhenClauses[i].Action.SetClauses[j].Value = nil - stmt.WhenClauses[i].Action.SetClauses[j].Column = "" - } - stmt.WhenClauses[i].Action.SetClauses = stmt.WhenClauses[i].Action.SetClauses[:0] - for j, expr := range stmt.WhenClauses[i].Action.Values { - PutExpression(expr) - stmt.WhenClauses[i].Action.Values[j] = nil - } - stmt.WhenClauses[i].Action.Values = stmt.WhenClauses[i].Action.Values[:0] - stmt.WhenClauses[i].Action.Columns = stmt.WhenClauses[i].Action.Columns[:0] - stmt.WhenClauses[i].Action.ActionType = "" - stmt.WhenClauses[i].Action.DefaultValues = false - stmt.WhenClauses[i].Action = nil - } - stmt.WhenClauses[i].Type = "" - stmt.WhenClauses[i] = nil - } - stmt.WhenClauses = stmt.WhenClauses[:0] - - // Release OUTPUT expressions - for i, expr := range stmt.Output { - PutExpression(expr) - 
stmt.Output[i] = nil - } - stmt.Output = stmt.Output[:0] - - // Reset TargetTable / SourceTable (value types - zero them out) - stmt.TargetTable = TableReference{} - stmt.SourceTable = TableReference{} - stmt.TargetAlias = "" - stmt.SourceAlias = "" - - mergeStmtPool.Put(stmt) -} - // GetCreateViewStatement gets a CreateViewStatement from the pool. func GetCreateViewStatement() *CreateViewStatement { stmt := createViewStmtPool.Get().(*CreateViewStatement) @@ -2180,41 +814,6 @@ func PutUnsupportedStatement(stmt *UnsupportedStatement) { unsupportedStmtPool.Put(stmt) } -// GetReplaceStatement gets a ReplaceStatement from the pool. -func GetReplaceStatement() *ReplaceStatement { - stmt := replaceStmtPool.Get().(*ReplaceStatement) - stmt.Columns = stmt.Columns[:0] - stmt.Values = stmt.Values[:0] - return stmt -} - -// PutReplaceStatement returns a ReplaceStatement to the pool. -// It recursively releases nested column and value expressions. -func PutReplaceStatement(stmt *ReplaceStatement) { - if stmt == nil { - return - } - - for i := range stmt.Columns { - PutExpression(stmt.Columns[i]) - stmt.Columns[i] = nil - } - stmt.Columns = stmt.Columns[:0] - - for i := range stmt.Values { - for j := range stmt.Values[i] { - PutExpression(stmt.Values[i][j]) - stmt.Values[i][j] = nil - } - stmt.Values[i] = stmt.Values[i][:0] - } - stmt.Values = stmt.Values[:0] - - stmt.TableName = "" - - replaceStmtPool.Put(stmt) -} - // GetAlterStatement gets an AlterStatement from the pool. func GetAlterStatement() *AlterStatement { return alterStmtPool.Get().(*AlterStatement) @@ -2268,23 +867,3 @@ func ReleaseAlterSequenceStatement(s *AlterSequenceStatement) { *s = AlterSequenceStatement{} // zero all fields alterSequencePool.Put(s) } - -// releaseTimeTravelClause walks a TimeTravelClause graph, releasing every -// Expression stored in Named maps and every chained sub-clause. 
Chained -// cycles are not possible because the parser builds a tree, but we still -// guard against nil to be defensive. -func releaseTimeTravelClause(c *TimeTravelClause) { - if c == nil { - return - } - for k, v := range c.Named { - PutExpression(v) - delete(c.Named, k) - } - for _, ch := range c.Chained { - releaseTimeTravelClause(ch) - } - c.Chained = nil - c.Named = nil - c.Kind = "" -} diff --git a/pkg/sql/ast/pool_leak_test.go b/pkg/sql/ast/pool_leak_test.go index 777d1dbd..2475666c 100644 --- a/pkg/sql/ast/pool_leak_test.go +++ b/pkg/sql/ast/pool_leak_test.go @@ -276,3 +276,108 @@ func TestPoolLeak_PutExpression_OverflowDrain(t *testing.T) { t.Errorf("PoolLeakCount=%d unreasonably high for %d-node chain", leaks, nodes) } } + +// measureHeapDelta runs fn for `iterations` cycles between two memory +// snapshots and returns the heap-in-use growth. It warms the pool with a +// short lead-in so JIT/pool priming costs don't bias the measurement. +func measureHeapDelta(t *testing.T, iterations int, warmups int, fn func() error) int64 { + t.Helper() + for i := 0; i < warmups; i++ { + if err := fn(); err != nil { + t.Fatalf("warmup %d: %v", i, err) + } + } + runtime.GC() + runtime.GC() + var before runtime.MemStats + runtime.ReadMemStats(&before) + + for i := 0; i < iterations; i++ { + if err := fn(); err != nil { + t.Fatalf("iteration %d: %v", i, err) + } + } + runtime.GC() + runtime.GC() + var after runtime.MemStats + runtime.ReadMemStats(&after) + + delta := int64(after.HeapInuse) - int64(before.HeapInuse) + t.Logf("HeapInuse: before=%d, after=%d, delta=%+d bytes over %d iterations (%.1f bytes/iter)", + before.HeapInuse, after.HeapInuse, delta, iterations, + float64(delta)/float64(iterations)) + return delta +} + +// TestPoolLeak_SubqueryInExpression verifies that an IN-subquery +// (IN (SELECT ...)) releases its nested SelectStatement back to the pool +// rather than leaking it. 
Pre-fix, pool.go set e.Subquery = nil inside +// putExpressionImpl without calling releaseStatement, so every nested +// SelectStatement reachable via InExpression.Subquery leaked on every +// parse. This test parses 1000 such queries and asserts stable heap. +func TestPoolLeak_SubqueryInExpression(t *testing.T) { + const iterations = 1000 + const heapGrowthLimit = 10 * 1024 * 1024 // 10 MiB + + sql := `SELECT x FROM t WHERE id IN (SELECT y FROM u WHERE z = 1)` + + delta := measureHeapDelta(t, iterations, 10, func() error { + return parseAndRelease(t, sql) + }) + + if delta > heapGrowthLimit { + t.Errorf("IN-subquery pool leak detected: HeapInuse grew by %d bytes over %d iterations (>%d limit)", + delta, iterations, heapGrowthLimit) + } +} + +// TestPoolLeak_ExistsSubquery verifies that EXISTS (SELECT ...) releases +// its nested SelectStatement back to the pool. Pre-fix, ExistsExpression's +// Subquery field was niled without dispatch, leaking every correlated +// SELECT body on every parse. +func TestPoolLeak_ExistsSubquery(t *testing.T) { + const iterations = 1000 + const heapGrowthLimit = 10 * 1024 * 1024 // 10 MiB + + sql := `SELECT x FROM t WHERE EXISTS (SELECT 1 FROM u WHERE u.id = t.id)` + + delta := measureHeapDelta(t, iterations, 10, func() error { + return parseAndRelease(t, sql) + }) + + if delta > heapGrowthLimit { + t.Errorf("EXISTS-subquery pool leak detected: HeapInuse grew by %d bytes over %d iterations (>%d limit)", + delta, iterations, heapGrowthLimit) + } +} + +// TestPoolLeak_AnyAllSubquery verifies that ANY(SELECT ...) and +// ALL(SELECT ...) release their nested SelectStatement. Pre-fix, both +// AnyExpression.Subquery and AllExpression.Subquery were niled without +// dispatch, leaking one SelectStatement per parse for each construct. 
+func TestPoolLeak_AnyAllSubquery(t *testing.T) { + const iterations = 1000 + const heapGrowthLimit = 10 * 1024 * 1024 // 10 MiB + + // Parse both ANY and ALL on each iteration so we exercise both code + // paths in a single test. We alternate rather than concatenating so + // the parser sees one statement at a time (matching production shape). + sqls := []string{ + `SELECT x FROM t WHERE val = ANY (SELECT v FROM u)`, + `SELECT x FROM t WHERE val = ALL (SELECT v FROM u WHERE u.active = true)`, + } + + delta := measureHeapDelta(t, iterations, 10, func() error { + for _, sql := range sqls { + if err := parseAndRelease(t, sql); err != nil { + return err + } + } + return nil + }) + + if delta > heapGrowthLimit { + t.Errorf("ANY/ALL-subquery pool leak detected: HeapInuse grew by %d bytes over %d iterations (>%d limit)", + delta, iterations, heapGrowthLimit) + } +} From 89e4d9db3d360b86ed93ce117ee72248ded36d4f Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Wed, 22 Apr 2026 18:50:02 +0530 Subject: [PATCH 2/4] feat(gosqlx): typed walkers, Rewrite, Clone, bounded reader, dialect-aware splitter, README update MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sprint B + C from 2026-04-21 architect review — close the Tree-API narrative and add competitive features that match sqlparser-rs / vitess. Typed walkers (pkg/gosqlx/walkers.go): - Generic WalkBy[T ast.Node](t *Tree, fn func(T) bool) — single generic walker - 10 typed convenience methods on *Tree: WalkSelects, WalkInserts, WalkUpdates, WalkDeletes, WalkCreateTables, WalkJoins, WalkCTEs, WalkFunctionCalls, WalkIdentifiers, WalkBinaryExpressions - Users no longer need the `if stmt, ok := n.(*ast.SelectStatement); ok` dance Tree.Rewrite and Tree.Clone (pkg/gosqlx/tree.go): - Rewrite(pre, post func(ast.Statement) ast.Statement) — top-level statement transformation / filtering. 
Doc explains that intra-statement rewrites use Walk* helpers which hand back pointers (field assignment propagates). - Clone() *Tree — re-parses from t.SQL(). Simple + correct + reuses all tested parser invariants. O(parse) cost documented. Dialect-aware statement splitter (pkg/gosqlx/splitter.go): - SplitStatements(sql, dialect string) []string — exported API - Handles: PostgreSQL dollar-quoting ($$...$$ and $tag$...$tag$), E-strings with backslash escapes, nested block comments; MySQL/MariaDB/ClickHouse backticks; SQL Server bracketed identifiers; plus ANSI defaults (single quotes with '' escape, double-quote identifiers, -- line comments, /* */ block comments). Correctly rejects $1/$2 positional params as tag starts. - ParseReaderMultiple now routes through SplitStatements honoring the dialect option. Bounded ParseReader (pkg/gosqlx/reader.go): - WithMaxBytes(n int64) Option — 0 means unbounded (backward compat) - ErrTooLarge sentinel; exceeds-cap inputs return errors.Is-compatible error - ctxReader wraps input with ctx.Done() short-circuit on Read calls - 19 new reader tests + 33 new splitter tests README update: - Hero snippet now uses ParseTree + WithDialect instead of Parse + ast.AST - Get-Started example shows ParseTree, sentinel errors, Tables(), Format, and a typed WalkSelects call - Note pointing users to docs/MIGRATION.md for the Tree migration guide go test -race ./... passes. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 47 +++-- pkg/gosqlx/reader.go | 218 +++++++++++----------- pkg/gosqlx/reader_test.go | 127 +++++++++++++ pkg/gosqlx/splitter.go | 333 ++++++++++++++++++++++++++++++++++ pkg/gosqlx/splitter_test.go | 194 ++++++++++++++++++++ pkg/gosqlx/tree.go | 87 +++++++++ pkg/gosqlx/walkers.go | 124 +++++++++++++ pkg/gosqlx/walkers_test.go | 347 ++++++++++++++++++++++++++++++++++++ 8 files changed, 1358 insertions(+), 119 deletions(-) create mode 100644 pkg/gosqlx/splitter.go create mode 100644 pkg/gosqlx/splitter_test.go create mode 100644 pkg/gosqlx/walkers.go create mode 100644 pkg/gosqlx/walkers_test.go diff --git a/README.md b/README.md index 409cf9fd..230fc122 100644 --- a/README.md +++ b/README.md @@ -41,8 +41,12 @@ GoSQLX is a **production-ready SQL parsing SDK** for Go. It tokenizes, parses, and generates ASTs from SQL with zero-copy optimizations and intelligent object pooling - handling **1.38M+ operations per second** with sub-microsecond latency. ```go -ast, _ := gosqlx.Parse("SELECT u.name, COUNT(*) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.name") -// → Full AST with statements, columns, joins, grouping - ready for analysis, transformation, or formatting +// v1.15+ recommended entry point: ParseTree returns an opaque Tree, +// so you don't need to import pkg/sql/ast just to get started. +tree, _ := gosqlx.ParseTree(ctx, "SELECT u.name, COUNT(*) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.name", + gosqlx.WithDialect("postgresql")) +fmt.Println("Tables:", tree.Tables()) +fmt.Println(tree.Format(gosqlx.WithIndent(2), gosqlx.WithUppercaseKeywords(true))) ``` ### Why GoSQLX? 
@@ -69,23 +73,30 @@ import ( ) func main() { - // Parse any SQL dialect - ast, _ := gosqlx.Parse("SELECT * FROM users WHERE active = true") - fmt.Printf("%d statement(s)\n", len(ast.Statements)) - - // Format messy SQL - clean, _ := gosqlx.Format("select id,name from users where id=1", gosqlx.DefaultFormatOptions()) - fmt.Println(clean) - // SELECT - // id, - // name - // FROM users - // WHERE id = 1 - - // Catch errors before production - if err := gosqlx.Validate("SELECT * FROM"); err != nil { - fmt.Println(err) // → expected table name + ctx := context.Background() + + // ParseTree (v1.15+) is the recommended entry point. It returns an + // opaque handle with built-in helpers — no need to import pkg/sql/ast. + tree, err := gosqlx.ParseTree(ctx, "SELECT id, name FROM users WHERE active = true", + gosqlx.WithDialect("postgresql")) + if err != nil { + // Sentinel errors work with errors.Is + if errors.Is(err, gosqlx.ErrSyntax) { + log.Fatalf("syntax error: %v", err) + } + log.Fatal(err) } + fmt.Println("Tables:", tree.Tables()) + fmt.Println(tree.Format(gosqlx.WithIndent(2), gosqlx.WithUppercaseKeywords(true))) + + // Walk the AST — typed walkers avoid the type-assertion dance: + tree.WalkSelects(func(s *ast.SelectStatement) bool { + fmt.Printf(" SELECT with %d columns\n", len(s.Columns)) + return true + }) + + // The legacy Parse/Format/Validate API still works for v1.x code. + // See docs/MIGRATION.md for the Tree migration guide. } ``` diff --git a/pkg/gosqlx/reader.go b/pkg/gosqlx/reader.go index 53bee750..014f78a7 100644 --- a/pkg/gosqlx/reader.go +++ b/pkg/gosqlx/reader.go @@ -25,8 +25,9 @@ import ( // // This is a convenience wrapper for callers who already have an io.Reader // (HTTP request body, file handle, strings.Reader, etc.) and don't want to -// manage the buffering themselves. Input is consumed in full via io.ReadAll -// before parsing begins. +// manage the buffering themselves. 
Input is consumed in full before parsing +// begins; by default the read is unbounded, but callers that handle hostile +// or user-supplied input should pass WithMaxBytes to cap allocation. // // If ctx is nil, context.Background is used. Options are forwarded to // ParseTree unchanged; see ParseTree for the context/dialect/timeout @@ -35,22 +36,25 @@ import ( // Read errors are surfaced verbatim (not wrapped in one of the gosqlx // sentinels) because they originate outside the SQL layer. Parse errors // follow the normal ParseTree wrapping (ErrSyntax / ErrTokenize / ErrTimeout -// / ErrUnsupportedDialect). +// / ErrUnsupportedDialect). Inputs that exceed WithMaxBytes return an error +// that satisfies errors.Is(err, ErrTooLarge). // // Example: // // f, _ := os.Open("query.sql") // defer f.Close() -// tree, err := gosqlx.ParseReader(ctx, f, gosqlx.WithDialect("postgresql")) +// tree, err := gosqlx.ParseReader(ctx, f, +// gosqlx.WithDialect("postgresql"), +// gosqlx.WithMaxBytes(1<<20), // reject >1 MiB +// ) // if err != nil { // return err // } // -// Cancellation: if ctx is cancelled before the reader finishes draining, -// the underlying io.ReadAll call does not abort mid-read — callers who need -// truly cancellable reads must wrap r in a context-aware reader (see -// golang.org/x/net/http2/h2c or similar). ParseReader does re-check ctx -// after the read and before dispatching to the parser. +// Cancellation: the reader is wrapped so that each Read call short-circuits +// with the context error if ctx is cancelled. This does not interrupt a Read +// that has already entered a syscall — callers dealing with pathological +// network readers should still enforce deadlines at the transport layer. 
func ParseReader(ctx context.Context, r io.Reader, opts ...Option) (*Tree, error) { if ctx == nil { ctx = context.Background() @@ -59,14 +63,16 @@ func ParseReader(ctx context.Context, r io.Reader, opts ...Option) (*Tree, error return nil, fmt.Errorf("%w: nil reader", ErrTokenize) } + cfg := applyOptions(opts) + // Fail fast if already cancelled. if err := ctx.Err(); err != nil { return nil, wrapContextErr(err) } - data, err := io.ReadAll(r) + data, err := readAllBounded(ctx, r, cfg.maxBytes) if err != nil { - return nil, fmt.Errorf("gosqlx: read: %w", err) + return nil, err } // Re-check context after I/O — long reads may have exhausted the deadline. @@ -77,26 +83,31 @@ func ParseReader(ctx context.Context, r io.Reader, opts ...Option) (*Tree, error return ParseTree(ctx, string(data), opts...) } -// ParseReaderMultiple reads SQL from r, splits it on unquoted semicolons into -// separate statements, and parses each, returning one Tree per statement. +// ParseReaderMultiple reads SQL from r, splits it into individual statements +// on unquoted top-level semicolons, and parses each, returning one Tree per +// statement. // -// The splitter is intentionally simple and designed for well-formed scripts: -// - It respects single-quoted string literals ('...'). -// - It respects double-quoted identifiers ("..."). -// - It ignores semicolons inside line comments (-- ...) and block comments -// (/* ... */) that do not cross statement boundaries. -// - It does NOT attempt to handle dialect-specific delimiter directives -// (MySQL's DELIMITER $$, Oracle's / etc.) — for those, split upstream. +// The splitter is dialect-aware: pass WithDialect("postgresql") to opt into +// dollar-quoting and E-string handling, WithDialect("mysql") for backtick +// identifiers, WithDialect("sqlserver") for [bracketed identifiers], and so +// on. When no dialect is set the conservative ANSI rules apply (single and +// double quotes, line comments, and non-nested block comments). 
// -// Empty segments (trailing whitespace after the last ;, or blank lines) are -// skipped. Each surviving segment is dispatched to ParseTree with the same -// options. The first segment that fails to parse short-circuits and returns -// its error wrapped in the usual ParseTree sentinels. +// ParseReaderMultiple honours WithMaxBytes the same way ParseReader does: +// inputs larger than the cap are rejected with ErrTooLarge before any +// splitting or parsing work begins. +// +// Empty segments (whitespace between consecutive semicolons, trailing +// whitespace after the last ';', etc.) are skipped. The first segment that +// fails to parse short-circuits the call and returns its error wrapped in +// the usual ParseTree sentinels, prefixed with the 1-based statement index. // // Example: // // tree, err := gosqlx.ParseReaderMultiple(ctx, -// strings.NewReader("SELECT 1; INSERT INTO t VALUES (1);"), +// strings.NewReader(script), +// gosqlx.WithDialect("postgresql"), +// gosqlx.WithMaxBytes(4<<20), // ) func ParseReaderMultiple(ctx context.Context, r io.Reader, opts ...Option) ([]*Tree, error) { if ctx == nil { @@ -106,20 +117,22 @@ func ParseReaderMultiple(ctx context.Context, r io.Reader, opts ...Option) ([]*T return nil, fmt.Errorf("%w: nil reader", ErrTokenize) } + cfg := applyOptions(opts) + if err := ctx.Err(); err != nil { return nil, wrapContextErr(err) } - data, err := io.ReadAll(r) + data, err := readAllBounded(ctx, r, cfg.maxBytes) if err != nil { - return nil, fmt.Errorf("gosqlx: read: %w", err) + return nil, err } if err := ctx.Err(); err != nil { return nil, wrapContextErr(err) } - segments := splitSQLStatements(string(data)) + segments := SplitStatements(string(data), cfg.dialect) trees := make([]*Tree, 0, len(segments)) for i, seg := range segments { seg = strings.TrimSpace(seg) @@ -135,81 +148,84 @@ func ParseReaderMultiple(ctx context.Context, r io.Reader, opts ...Option) ([]*T return trees, nil } -// splitSQLStatements splits src on top-level 
semicolons, respecting the -// common string/identifier/comment contexts. It is intentionally small and -// conservative; see ParseReaderMultiple doc comment for caveats. -func splitSQLStatements(src string) []string { - var out []string - var cur strings.Builder - - // State machine flags. Only one of these can be true at a time. - inSingle := false // inside '...' - inDouble := false // inside "..." - inLine := false // inside -- ... \n - inBlock := false // inside /* ... */ - - for i := 0; i < len(src); i++ { - c := src[i] - - switch { - case inLine: - cur.WriteByte(c) - if c == '\n' { - inLine = false - } - continue - case inBlock: - cur.WriteByte(c) - if c == '*' && i+1 < len(src) && src[i+1] == '/' { - cur.WriteByte(src[i+1]) - i++ - inBlock = false - } - continue - case inSingle: - cur.WriteByte(c) - if c == '\'' { - // Handle escaped quote ''. - if i+1 < len(src) && src[i+1] == '\'' { - cur.WriteByte(src[i+1]) - i++ - continue - } - inSingle = false - } - continue - case inDouble: - cur.WriteByte(c) - if c == '"' { - inDouble = false - } - continue +// readAllBounded reads from r with optional cap enforcement and context +// cancellation. When maxBytes <= 0 the read is unbounded and behaves exactly +// like the pre-existing io.ReadAll path. When maxBytes > 0 the caller sees +// either all bytes (up to maxBytes) or ErrTooLarge — never a silently +// truncated prefix. +// +// Implementation notes: +// - We request up to maxBytes+1 bytes from the underlying reader; if the +// result is longer than maxBytes we know the input exceeded the cap and +// reject it. This costs one extra byte of allocation but avoids racing +// EOF against the limit. +// - The reader is always wrapped in a ctxReader so that a cancelled context +// short-circuits subsequent Read calls. This does NOT interrupt a Read +// already blocked in a syscall — that is a known limitation of the +// io.Reader contract. 
+func readAllBounded(ctx context.Context, r io.Reader, maxBytes int64) ([]byte, error) { + reader := &ctxReader{ctx: ctx, r: r} + + if maxBytes <= 0 { + data, err := io.ReadAll(reader) + if err != nil { + return nil, classifyReadErr(ctx, err) } + return data, nil + } - // Top-level state: look for comment starts, string opens, or ';'. - switch { - case c == '-' && i+1 < len(src) && src[i+1] == '-': - inLine = true - cur.WriteByte(c) - case c == '/' && i+1 < len(src) && src[i+1] == '*': - inBlock = true - cur.WriteByte(c) - case c == '\'': - inSingle = true - cur.WriteByte(c) - case c == '"': - inDouble = true - cur.WriteByte(c) - case c == ';': - out = append(out, cur.String()) - cur.Reset() - default: - cur.WriteByte(c) - } + // Read at most maxBytes+1 so we can distinguish "exactly at cap" from + // "over cap". The one-byte overshoot is discarded if we trip the cap. + limited := io.LimitReader(reader, maxBytes+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, classifyReadErr(ctx, err) + } + if int64(len(data)) > maxBytes { + return nil, fmt.Errorf("%w: read %d bytes, cap is %d", ErrTooLarge, len(data), maxBytes) } - // Tail. - if cur.Len() > 0 { - out = append(out, cur.String()) + return data, nil +} + +// classifyReadErr routes a read error through the gosqlx sentinel taxonomy. +// Context errors become ErrTimeout; everything else is returned verbatim +// under a generic "gosqlx: read" prefix so callers can still unwrap the +// underlying io error with errors.Is. +func classifyReadErr(ctx context.Context, err error) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return wrapContextErr(ctxErr) } - return out + return fmt.Errorf("gosqlx: read: %w", err) +} + +// ctxReader is an io.Reader that checks ctx.Done() before each Read. It +// allows ParseReader / ParseReaderMultiple to abort between Read calls when +// the context is cancelled — Go's io package offers no such hook. 
+// +// It does NOT interrupt a Read that is already blocked in a syscall (for +// example a TCP socket read). That limitation is inherent to the io.Reader +// interface; callers that need hard cancellation should wrap their reader +// with a transport-aware deadline (e.g. net.Conn.SetReadDeadline) before +// handing it to ParseReader. +type ctxReader struct { + ctx context.Context + r io.Reader +} + +// Read forwards to the wrapped reader after checking for context +// cancellation. It returns ctx.Err() directly — the caller (readAllBounded) +// converts that into the ErrTimeout sentinel. +func (c *ctxReader) Read(p []byte) (int, error) { + if err := c.ctx.Err(); err != nil { + return 0, err + } + return c.r.Read(p) +} + +// splitSQLStatements is retained as a thin shim over SplitStatements so +// existing tests and internal callers that use the ANSI default keep working +// unchanged. New code should call SplitStatements directly with an explicit +// dialect. +func splitSQLStatements(src string) []string { + return SplitStatements(src, "") } diff --git a/pkg/gosqlx/reader_test.go b/pkg/gosqlx/reader_test.go index 69c29fdf..98f06fa8 100644 --- a/pkg/gosqlx/reader_test.go +++ b/pkg/gosqlx/reader_test.go @@ -147,6 +147,133 @@ func TestParseReaderMultiple_NilReader(t *testing.T) { } } +func TestParseReader_MaxBytes_Blocks(t *testing.T) { + // 1 KiB of SQL-ish content, cap at 64 bytes. 
+ big := strings.Repeat("SELECT 1; ", 128) + _, err := ParseReader( + context.Background(), + strings.NewReader(big), + WithMaxBytes(64), + ) + if err == nil { + t.Fatal("expected ErrTooLarge, got nil") + } + if !errors.Is(err, ErrTooLarge) { + t.Errorf("errors.Is(err, ErrTooLarge) = false; err = %v", err) + } +} + +func TestParseReader_MaxBytes_Allows(t *testing.T) { + src := "SELECT 1" + tree, err := ParseReader( + context.Background(), + strings.NewReader(src), + WithMaxBytes(int64(len(src))), + ) + if err != nil { + t.Fatalf("ParseReader within cap: %v", err) + } + if tree == nil { + t.Fatal("tree is nil") + } +} + +func TestParseReader_MaxBytes_ExactlyAtCap(t *testing.T) { + // Boundary case: len(src) == maxBytes should succeed (inclusive cap). + src := "SELECT 42" + tree, err := ParseReader( + context.Background(), + strings.NewReader(src), + WithMaxBytes(int64(len(src))), + ) + if err != nil { + t.Fatalf("expected success at exact cap, got: %v", err) + } + if tree == nil { + t.Fatal("tree is nil") + } +} + +func TestParseReader_MaxBytes_OneByteOver(t *testing.T) { + src := "SELECT 42" // 9 bytes + _, err := ParseReader( + context.Background(), + strings.NewReader(src), + WithMaxBytes(int64(len(src)-1)), + ) + if err == nil { + t.Fatal("expected ErrTooLarge for src one byte over cap") + } + if !errors.Is(err, ErrTooLarge) { + t.Errorf("errors.Is(err, ErrTooLarge) = false; err = %v", err) + } +} + +func TestParseReader_UnboundedDefault(t *testing.T) { + // No WithMaxBytes — behaves exactly like pre-bounded-reads. 
+ src := strings.Repeat("SELECT 1; ", 256) + trees, err := ParseReaderMultiple( + context.Background(), + strings.NewReader(src), + ) + if err != nil { + t.Fatalf("unbounded default: %v", err) + } + if len(trees) != 256 { + t.Errorf("got %d trees, want 256", len(trees)) + } +} + +func TestParseReaderMultiple_MaxBytes_Blocks(t *testing.T) { + src := strings.Repeat("SELECT 1; ", 32) + _, err := ParseReaderMultiple( + context.Background(), + strings.NewReader(src), + WithMaxBytes(16), + ) + if !errors.Is(err, ErrTooLarge) { + t.Errorf("expected ErrTooLarge, got %v", err) + } +} + +func TestParseReaderMultiple_DollarQuoting(t *testing.T) { + // Semicolons inside a $$-delimited body must NOT split the statement. + // Two valid PG SELECTs are separated by an explicit top-level `;`. + // + // We deliberately avoid the CREATE FUNCTION ... plpgsql example from + // the task brief because the parser does not yet handle procedural + // bodies; the splitter's correctness is what matters here, and this + // input exercises the same state transitions. + src := "SELECT $$a; b; c$$; SELECT $tag$x; y$tag$" + trees, err := ParseReaderMultiple( + context.Background(), + strings.NewReader(src), + WithDialect("postgresql"), + ) + if err != nil { + t.Fatalf("ParseReaderMultiple(pg): %v", err) + } + if len(trees) != 2 { + t.Fatalf("got %d trees, want 2", len(trees)) + } +} + +// TestParseReaderMultiple_DollarQuoting_WouldMisBehaveWithoutDialect confirms +// that the same input, parsed without the postgresql dialect, fails to parse +// because the conservative splitter splits mid-body and hands fragments to +// the parser. This pins the regression we just fixed. +func TestParseReaderMultiple_DollarQuoting_WouldMisBehaveWithoutDialect(t *testing.T) { + src := "SELECT $$a; b; c$$; SELECT 1" + _, err := ParseReaderMultiple( + context.Background(), + strings.NewReader(src), + // No WithDialect — ANSI default. 
+ ) + if err == nil { + t.Fatal("expected parse failure without postgresql dialect (splitter should over-split the $$…$$ body)") + } +} + func TestSplitSQLStatements(t *testing.T) { cases := []struct { name string diff --git a/pkg/gosqlx/splitter.go b/pkg/gosqlx/splitter.go new file mode 100644 index 00000000..bf4da567 --- /dev/null +++ b/pkg/gosqlx/splitter.go @@ -0,0 +1,333 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gosqlx + +import "strings" + +// SplitStatements splits a SQL source string into top-level statements on +// unquoted semicolons. It is dialect-aware: depending on the supplied dialect +// it understands features that the conservative ANSI-compatible path cannot, +// notably PostgreSQL dollar-quoting, MySQL/ClickHouse backtick identifiers, +// SQL Server bracketed identifiers, PostgreSQL E-string backslash escapes, +// and PostgreSQL nested block comments. +// +// Recognised dialect strings (case-insensitive) — unknown values fall back to +// the conservative ANSI profile: +// +// "postgresql", "postgres", "pg" — dollar-quotes, E-strings, nested /*…*/ +// "mysql", "mariadb" — backtick identifiers +// "clickhouse" — backtick identifiers +// "sqlserver", "mssql", "tsql" — [bracketed identifiers] +// "" — ANSI only (single/double quotes, line/block comments) +// +// The returned slice preserves the original byte ranges between semicolons +// (the terminating ';' itself is stripped). 
Empty or whitespace-only segments +// are NOT filtered out here; callers such as ParseReaderMultiple trim and +// skip them. This mirrors the long-standing behaviour of the splitter that +// this function replaces. +// +// The splitter is intentionally a hand-written state machine rather than a +// full tokenizer: it only needs to know what it is inside of, not what the +// tokens mean. That keeps it fast, allocation-free on the hot path, and +// robust against partial or syntactically invalid SQL (the parser still gets +// a chance to report those problems downstream). +func SplitStatements(sql, dialect string) []string { + p := profileForDialect(dialect) + return splitWithProfile(sql, p) +} + +// dialectProfile collects the per-dialect feature flags the splitter needs. +// Profiles are immutable and shared; one is resolved per call and passed to +// the state machine by value-ish (read-only) pointer. +type dialectProfile struct { + supportsDollarQuoting bool + supportsBackticks bool + supportsBrackets bool + supportsEStrings bool + supportsNestedBlockComments bool +} + +// profileForDialect returns the feature flags for the named dialect. The +// lookup is case-insensitive and tolerant of the common aliases users pass +// through WithDialect. +func profileForDialect(dialect string) dialectProfile { + switch strings.ToLower(strings.TrimSpace(dialect)) { + case "postgresql", "postgres", "pg": + return dialectProfile{ + supportsDollarQuoting: true, + supportsEStrings: true, + supportsNestedBlockComments: true, + } + case "mysql", "mariadb": + return dialectProfile{supportsBackticks: true} + case "clickhouse": + return dialectProfile{supportsBackticks: true} + case "sqlserver", "mssql", "tsql": + return dialectProfile{supportsBrackets: true} + default: + return dialectProfile{} + } +} + +// splitWithProfile is the state machine. It walks src byte-by-byte, tracking +// which lexical context the cursor is inside (quoted string, comment, +// dollar-quote, etc.) 
and only emits a split when it sees a top-level ';'. +// +// The function is deliberately monolithic — extracting helpers would require +// sharing a lot of index state and would obscure the invariants. Each branch +// documents the transitions it performs so future readers do not have to +// reconstruct the logic from a stack trace. +func splitWithProfile(src string, p dialectProfile) []string { + var out []string + var cur strings.Builder + cur.Grow(len(src)) + + // Mutually-exclusive context flags. At most one of (inSingle, inDouble, + // inBacktick, inBracket, inLine, dollarTag != "") is ever true at a time. + // blockCommentDepth uses a depth counter instead of a bool because + // PostgreSQL permits /* /* nested */ */. + inSingle := false + inDouble := false + inBacktick := false + inBracket := false + inLine := false + blockCommentDepth := 0 + dollarTag := "" // non-empty while inside a $tag$...$tag$ or $$...$$ block + + // eStringActive tracks whether the current single-quoted literal was + // opened with an E-prefix (PostgreSQL): inside such strings a backslash + // escapes the following byte, including '\''. + eStringActive := false + + i := 0 + for i < len(src) { + c := src[i] + + // ─── Inside a context: consume until we exit ────────────────────── + switch { + case inLine: + cur.WriteByte(c) + if c == '\n' { + inLine = false + } + i++ + continue + + case blockCommentDepth > 0: + cur.WriteByte(c) + if p.supportsNestedBlockComments && + c == '/' && i+1 < len(src) && src[i+1] == '*' { + cur.WriteByte(src[i+1]) + blockCommentDepth++ + i += 2 + continue + } + if c == '*' && i+1 < len(src) && src[i+1] == '/' { + cur.WriteByte(src[i+1]) + blockCommentDepth-- + i += 2 + continue + } + i++ + continue + + case dollarTag != "": + // Closing tag must match exactly, including the surrounding '$'. 
+ if c == '$' && matchDollarTag(src, i, dollarTag) { + cur.WriteString(src[i : i+len(dollarTag)]) + i += len(dollarTag) + dollarTag = "" + continue + } + cur.WriteByte(c) + i++ + continue + + case inSingle: + cur.WriteByte(c) + // PG E-strings: `\` escapes the next byte (commonly \' for a + // literal apostrophe). The quote we are escaping must not close + // the string. + if eStringActive && c == '\\' && i+1 < len(src) { + cur.WriteByte(src[i+1]) + i += 2 + continue + } + if c == '\'' { + // Standard SQL doubled-quote escape `''` (works in every + // dialect, including PG E-strings). + if i+1 < len(src) && src[i+1] == '\'' { + cur.WriteByte(src[i+1]) + i += 2 + continue + } + inSingle = false + eStringActive = false + } + i++ + continue + + case inDouble: + cur.WriteByte(c) + if c == '"' { + inDouble = false + } + i++ + continue + + case inBacktick: + cur.WriteByte(c) + if c == '`' { + inBacktick = false + } + i++ + continue + + case inBracket: + cur.WriteByte(c) + if c == ']' { + inBracket = false + } + i++ + continue + } + + // ─── Top-level: look for openers and semicolons ─────────────────── + switch { + case c == '-' && i+1 < len(src) && src[i+1] == '-': + inLine = true + cur.WriteByte(c) + cur.WriteByte(src[i+1]) + i += 2 + + case c == '/' && i+1 < len(src) && src[i+1] == '*': + blockCommentDepth = 1 + cur.WriteByte(c) + cur.WriteByte(src[i+1]) + i += 2 + + case p.supportsEStrings && (c == 'E' || c == 'e') && + i+1 < len(src) && src[i+1] == '\'': + // PG E'...' + inSingle = true + eStringActive = true + cur.WriteByte(c) + cur.WriteByte(src[i+1]) + i += 2 + + case c == '\'': + inSingle = true + eStringActive = false + cur.WriteByte(c) + i++ + + case c == '"': + inDouble = true + cur.WriteByte(c) + i++ + + case p.supportsBackticks && c == '`': + inBacktick = true + cur.WriteByte(c) + i++ + + case p.supportsBrackets && c == '[': + // T-SQL treats [] as identifier quoting. 
ANSI-land uses [] only + // inside string literals, which we never reach here, so the guard + // above is sufficient. + inBracket = true + cur.WriteByte(c) + i++ + + case p.supportsDollarQuoting && c == '$': + if tag, ok := readDollarTag(src, i); ok { + dollarTag = tag + cur.WriteString(tag) + i += len(tag) + continue + } + cur.WriteByte(c) + i++ + + case c == ';': + out = append(out, cur.String()) + cur.Reset() + i++ + + default: + cur.WriteByte(c) + i++ + } + } + + if cur.Len() > 0 { + out = append(out, cur.String()) + } + return out +} + +// readDollarTag attempts to read a PostgreSQL dollar-quote opener starting at +// src[i] (which must be '$'). If the run is a valid opener it returns the +// full tag string (e.g. "$$" or "$outer$") and ok=true. Otherwise it returns +// ("", false) so the caller can treat the '$' as a literal character. +// +// A dollar-quote tag body may contain ASCII letters, digits, and underscores, +// but must NOT start with a digit — matching PostgreSQL's own lexer +// (src/backend/parser/scan.l). A bare "$$" is always valid. +func readDollarTag(src string, i int) (string, bool) { + // src[i] == '$' by contract. + if i+1 >= len(src) { + return "", false + } + if src[i+1] == '$' { + return "$$", true + } + // Scan body. + j := i + 1 + if !isDollarTagStart(src[j]) { + return "", false + } + j++ + for j < len(src) && isDollarTagCont(src[j]) { + j++ + } + if j >= len(src) || src[j] != '$' { + return "", false + } + return src[i : j+1], true +} + +// matchDollarTag reports whether src[i:] begins with the exact tag string. +// Tag matching is byte-exact; PG preserves case and treats $A$ and $a$ as +// distinct. Used to detect the closing tag while inside a dollar-quote. 
+func matchDollarTag(src string, i int, tag string) bool { + if i+len(tag) > len(src) { + return false + } + return src[i:i+len(tag)] == tag +} + +// isDollarTagStart reports whether c may be the first byte of a dollar-quote +// tag body (the part between the two '$' delimiters). PG rules: ASCII letter +// or underscore; digits are NOT allowed at the start. +func isDollarTagStart(c byte) bool { + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_' +} + +// isDollarTagCont reports whether c may appear in the body of a dollar-quote +// tag after the first byte. Letters, digits, and underscores are permitted. +func isDollarTagCont(c byte) bool { + return isDollarTagStart(c) || (c >= '0' && c <= '9') +} diff --git a/pkg/gosqlx/splitter_test.go b/pkg/gosqlx/splitter_test.go new file mode 100644 index 00000000..a595c1e2 --- /dev/null +++ b/pkg/gosqlx/splitter_test.go @@ -0,0 +1,194 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gosqlx + +import ( + "strings" + "testing" +) + +// nonEmpty returns the number of segments that are not whitespace-only. +// ParseReaderMultiple trims and skips these before parsing, so the splitter +// is free to emit them (and the simple cases in this file assert the trimmed +// count to mirror end-user behaviour). 
+func nonEmpty(segs []string) int { + n := 0 + for _, s := range segs { + if strings.TrimSpace(s) != "" { + n++ + } + } + return n +} + +func TestSplitStatements_Table(t *testing.T) { + cases := []struct { + name string + sql string + dialect string + want int + }{ + // ─── ANSI baseline ─────────────────────────────────────────────── + {"ansi/simple-two", "SELECT 1; SELECT 2", "", 2}, + {"ansi/single-quote-escape", "SELECT 'a;b'; SELECT 1", "", 2}, + {"ansi/escaped-single-quote", "SELECT 'it''s;fine'; SELECT 1", "", 2}, + {"ansi/line-comment", "-- comment ; still comment\nSELECT 1", "", 1}, + {"ansi/block-comment", "/* ; in comment */ SELECT 1; SELECT 2", "", 2}, + {"ansi/double-quoted-ident", `SELECT "col;name" FROM t; SELECT 1`, "", 2}, + {"ansi/trailing-semi", "SELECT 1;", "", 1}, + {"ansi/empty-between", "SELECT 1;;;SELECT 2", "", 2}, + + // ─── PostgreSQL nested block comments ─────────────────────────── + {"pg/nested-block-comment", + "/* /* nested ; */ */ SELECT 1; SELECT 2", "postgresql", 2}, + // Same input under ANSI closes at the first */ and the trailing + // "*/ SELECT 1" becomes part of the first segment — so we still see + // two statements, but through a different code path. Check both + // worlds remain deterministic. + {"ansi/nested-flat", "/* /* x */ */ SELECT 1; SELECT 2", "", 2}, + + // ─── PostgreSQL dollar-quoting ────────────────────────────────── + {"pg/dollar-bare", + "CREATE FUNCTION f() RETURNS int AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql; SELECT 1", + "postgresql", 2}, + {"pg/dollar-tag", + "SELECT $outer$ foo $inner$ bar $inner$ baz $outer$; SELECT 1", + "postgresql", 2}, + {"pg/dollar-containing-semi-and-quotes", + "DO $body$ BEGIN PERFORM 'a;b'; END $body$; SELECT 1", + "postgresql", 2}, + // Same dollar-quote input with no dialect should split naively and + // produce MORE segments — confirming the feature is gated. 
+ // Input has three top-level semicolons under ANSI rules (after + // `RETURN 1`, `END`, and `plpgsql`), yielding 4 segments. + {"ansi/dollar-not-recognised", + "CREATE FUNCTION f() RETURNS int AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql; SELECT 1", + "", 4}, + + // ─── PostgreSQL E-strings ─────────────────────────────────────── + {"pg/e-string-escaped-quote", + `SELECT E'\''; SELECT 1`, "postgresql", 2}, + {"pg/e-string-escaped-semicolon", + `SELECT E'a\;b'; SELECT 1`, "postgresql", 2}, + + // ─── MySQL / MariaDB / ClickHouse backticks ───────────────────── + {"mysql/backtick-ident", + "SELECT `col;name` FROM t; SELECT 1", "mysql", 2}, + {"mariadb/backtick-ident", + "SELECT `col;name` FROM t; SELECT 1", "mariadb", 2}, + {"clickhouse/backtick-ident", + "SELECT `col;name` FROM t; SELECT 1", "clickhouse", 2}, + // Under ANSI, backticks are plain characters; the `;` inside splits. + {"ansi/backtick-not-recognised", + "SELECT `col;name` FROM t; SELECT 1", "", 3}, + + // ─── SQL Server bracketed identifiers ─────────────────────────── + {"sqlserver/bracketed-ident", + "SELECT [col;name] FROM t; SELECT 1", "sqlserver", 2}, + {"mssql/bracketed-ident", + "SELECT [col;name] FROM t; SELECT 1", "mssql", 2}, + // Under ANSI, [] is not special; the `;` inside splits. 
+ {"ansi/brackets-not-recognised", + "SELECT [col;name] FROM t; SELECT 1", "", 3}, + + // ─── Dialect aliases ──────────────────────────────────────────── + {"pg/alias-postgres", "SELECT $$a;b$$; SELECT 1", "postgres", 2}, + {"pg/alias-pg", "SELECT $$a;b$$; SELECT 1", "pg", 2}, + {"pg/case-insensitive", "SELECT $$a;b$$; SELECT 1", "PostgreSQL", 2}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + got := nonEmpty(SplitStatements(tc.sql, tc.dialect)) + if got != tc.want { + t.Errorf("SplitStatements(%q, %q) non-empty = %d, want %d", + tc.sql, tc.dialect, got, tc.want) + } + }) + } +} + +func TestSplitStatements_PreservesContent(t *testing.T) { + // The splitter must not mangle statement bodies. Re-joining the + // non-empty segments with "; " should reproduce the substantive SQL. + sql := "SELECT 1; SELECT 2; SELECT 3" + segs := SplitStatements(sql, "") + var parts []string + for _, s := range segs { + if t := strings.TrimSpace(s); t != "" { + parts = append(parts, t) + } + } + joined := strings.Join(parts, "; ") + if joined != sql { + t.Errorf("round-trip mismatch:\n got %q\nwant %q", joined, sql) + } +} + +func TestSplitStatements_DollarTagMustMatch(t *testing.T) { + // `$outer$...$inner$` must stay open until the matching $outer$, even + // though `$inner$` looks like a closer. Semi inside must not split. + sql := "SELECT $outer$ x; $inner$ y; $outer$; SELECT 1" + got := nonEmpty(SplitStatements(sql, "postgresql")) + if got != 2 { + t.Errorf("got %d, want 2 — inner tag should not close outer", got) + } +} + +func TestSplitStatements_DollarNotATag(t *testing.T) { + // `$1`, `$2`, ... are PG positional params, NOT dollar-quote openers. + // The splitter must treat them as plain text so normal string/comment + // rules still apply to the rest of the statement. 
+	sql := "SELECT $1; SELECT $2"
+	got := nonEmpty(SplitStatements(sql, "postgresql"))
+	if got != 2 {
+		t.Errorf("got %d, want 2 — $1/$2 are params, not dollar-quotes", got)
+	}
+}
+
+func TestSplitStatements_NestedBlockCommentDepth(t *testing.T) {
+	// Two levels of nesting, semicolons at each level.
+	sql := "/* a; /* b; /* c; */ */ */ SELECT 1; SELECT 2"
+	got := nonEmpty(SplitStatements(sql, "postgresql"))
+	if got != 2 {
+		t.Errorf("got %d, want 2 — block comment depth tracking", got)
+	}
+}
+
+func TestSplitStatements_EStringGatedByDialect(t *testing.T) {
+	// Under ANSI, `E'` is not an E-string — the `E` is plain text and the
+	// `'` opens a normal string; the doubled `''` keeps it open, so the
+	// `;` does NOT split and ANSI yields fewer segments than PG.
+	ansi := nonEmpty(SplitStatements(`SELECT E'\''; SELECT 1`, ""))
+	pg := nonEmpty(SplitStatements(`SELECT E'\''; SELECT 1`, "postgresql"))
+	if ansi == pg {
+		t.Errorf("expected E-string handling to differ by dialect: ansi=%d pg=%d", ansi, pg)
+	}
+}
+
+func TestSplitStatements_EmptyInput(t *testing.T) {
+	segs := SplitStatements("", "")
+	if len(segs) != 0 {
+		t.Errorf("empty input produced %d segments", len(segs))
+	}
+}
+
+func TestSplitStatements_OnlyWhitespace(t *testing.T) {
+	segs := SplitStatements(" \n\t ", "")
+	if nonEmpty(segs) != 0 {
+		t.Errorf("whitespace-only input produced %d non-empty segments", nonEmpty(segs))
+	}
+}
diff --git a/pkg/gosqlx/tree.go b/pkg/gosqlx/tree.go
index f6ed7baa..b7bbedfd 100644
--- a/pkg/gosqlx/tree.go
+++ b/pkg/gosqlx/tree.go
@@ -156,6 +156,93 @@ func (t *Tree) Release() {
 	_ = t
 }
 
+// Rewrite applies pre and post transformation passes to each top-level
+// Statement in the tree. pre runs before any children are considered; post
+// runs after. Either may be nil to skip that pass. The return value of each
+// function replaces the statement in the tree; returning the same node is the
+// no-op case.
+//
+// SCOPE — Rewrite operates at Statement granularity only.
Deeper rewrites +// (e.g., replacing an expression inside a WHERE clause) require walking the +// AST via Raw() and mutating the concrete struct fields directly. This is a +// deliberate design choice: the AST contains ~100 concrete node types with +// heterogeneous child-field layouts; a generic deep-rewrite API would require +// either reflection (slow, easy to misuse) or an exhaustive per-type switch +// (maintenance burden). Until there is a clear user need we prefer the honest +// narrow API over a permissive one that silently misses cases. +// +// Example — drop every DeleteStatement from a batch: +// +// tree.Rewrite(nil, func(s ast.Statement) ast.Statement { +// if _, ok := s.(*ast.DeleteStatement); ok { +// return nil // filtered out +// } +// return s +// }) +// +// A nil return from pre or post drops the statement from the tree. Rewrite +// mutates t in place; call Clone() first to preserve the original. +// +// For intra-statement rewrites, combine Tree.WalkSelects / WalkBinaryExpressions +// / etc. with direct field assignment on the visited node. Because the walkers +// return pointers, any field you assign to is visible to subsequent reads +// through the same Tree. +func (t *Tree) Rewrite(pre, post func(ast.Statement) ast.Statement) { + if t == nil || t.ast == nil { + return + } + src := t.ast.Statements + out := src[:0] // reuse backing array; final statements only appear once + for _, s := range src { + cur := s + if pre != nil { + cur = pre(cur) + if cur == nil { + continue + } + } + if post != nil { + cur = post(cur) + if cur == nil { + continue + } + } + out = append(out, cur) + } + // Zero-out any trailing aliases so the GC can reclaim dropped statements. + for i := len(out); i < len(src); i++ { + src[i] = nil + } + t.ast.Statements = out +} + +// Clone returns an independent deep copy of the tree. Mutations to the +// original (via Raw(), WalkSelects field writes, Rewrite, ...) do not affect +// the clone, and vice versa. 
+// +// IMPLEMENTATION — Clone is implemented by re-parsing t.SQL() with default +// options. This is simple and provably correct: the clone has exactly the +// structure the parser would produce today for the original source, which is +// the strongest possible guarantee of independence. The tradeoff is cost — +// Clone is O(parse) rather than O(nodes), and a clone of a tree that was +// produced with a non-default dialect / recovery mode will not preserve those +// parse options. For cases where either matters, hold onto the SQL string and +// call ParseTree yourself with the desired options. +// +// Clone returns nil if the receiver is nil, the original SQL is empty, or +// the re-parse fails (which would indicate a parser regression since the +// original SQL was known to parse successfully when t was constructed). +func (t *Tree) Clone() *Tree { + if t == nil || t.sql == "" { + return nil + } + cloned, err := ParseTree(context.Background(), t.sql) + if err != nil { + return nil + } + return cloned +} + // ParseTree parses SQL and returns an opaque Tree, the recommended entry // point for new code. Configuration is supplied via functional options // (WithDialect, WithStrict, WithTimeout, WithRecovery) rather than through diff --git a/pkg/gosqlx/walkers.go b/pkg/gosqlx/walkers.go new file mode 100644 index 00000000..6833d3d5 --- /dev/null +++ b/pkg/gosqlx/walkers.go @@ -0,0 +1,124 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package gosqlx + +import "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + +// WalkBy traverses the tree in depth-first, pre-order fashion and invokes fn +// for every node whose concrete type is T. Children of the matched node are +// descended unless fn returns false. +// +// WalkBy is a generic helper that removes the type-assertion boilerplate +// required when using Tree.Walk. It is implemented on top of ast.Inspect so +// it follows the same Node.Children() contract and descends into every +// reachable subtree (subqueries, CTEs, UNION arms, etc.). +// +// Nodes whose concrete type does not match T are skipped but still descended +// into — returning false from fn only prunes children of a matched node. +// +// Example — collect every table reference inside nested subqueries: +// +// var names []string +// gosqlx.WalkBy(tree, func(t *ast.TableReference) bool { +// names = append(names, t.Name) +// return true +// }) +// +// Example — short-circuit on the first SELECT with a Where clause: +// +// var found *ast.SelectStatement +// gosqlx.WalkBy(tree, func(s *ast.SelectStatement) bool { +// if s.Where != nil { +// found = s +// return false // prune this subtree (siblings still visited) +// } +// return true +// }) +func WalkBy[T ast.Node](t *Tree, fn func(T) bool) { + if t == nil || t.ast == nil || fn == nil { + return + } + ast.Inspect(t.ast, func(n ast.Node) bool { + typed, ok := n.(T) + if !ok { + // Not the target type — descend so we can reach matches deeper. + return true + } + return fn(typed) + }) +} + +// WalkSelects invokes fn for every *ast.SelectStatement in the tree, including +// those nested inside subqueries, CTEs, and set-operation arms. Return false +// from fn to skip descent into the matched SELECT's children (siblings are +// still visited). 
+func (t *Tree) WalkSelects(fn func(*ast.SelectStatement) bool) { + WalkBy(t, fn) +} + +// WalkInserts invokes fn for every *ast.InsertStatement in the tree. +func (t *Tree) WalkInserts(fn func(*ast.InsertStatement) bool) { + WalkBy(t, fn) +} + +// WalkUpdates invokes fn for every *ast.UpdateStatement in the tree. +func (t *Tree) WalkUpdates(fn func(*ast.UpdateStatement) bool) { + WalkBy(t, fn) +} + +// WalkDeletes invokes fn for every *ast.DeleteStatement in the tree. +func (t *Tree) WalkDeletes(fn func(*ast.DeleteStatement) bool) { + WalkBy(t, fn) +} + +// WalkCreateTables invokes fn for every *ast.CreateTableStatement in the tree. +func (t *Tree) WalkCreateTables(fn func(*ast.CreateTableStatement) bool) { + WalkBy(t, fn) +} + +// WalkJoins invokes fn for every *ast.JoinClause in the tree. Because JoinClause +// is stored by value on SelectStatement.Joins, its Children() method exposes +// pointer access so WalkBy can locate each join node during traversal. If your +// parser version stores joins as values and they are not reachable via +// Children(), use Tree.WalkSelects and iterate s.Joins directly. +func (t *Tree) WalkJoins(fn func(*ast.JoinClause) bool) { + WalkBy(t, fn) +} + +// WalkCTEs invokes fn for every *ast.CommonTableExpr in the tree, descending +// into nested WITH clauses inside subqueries. +func (t *Tree) WalkCTEs(fn func(*ast.CommonTableExpr) bool) { + WalkBy(t, fn) +} + +// WalkFunctionCalls invokes fn for every *ast.FunctionCall in the tree, +// including window functions, aggregate functions, and scalar functions. +func (t *Tree) WalkFunctionCalls(fn func(*ast.FunctionCall) bool) { + WalkBy(t, fn) +} + +// WalkIdentifiers invokes fn for every *ast.Identifier in the tree. Useful +// for collecting column references, table aliases, or any bare name that the +// parser lowered to an Identifier node. 
+func (t *Tree) WalkIdentifiers(fn func(*ast.Identifier) bool) { + WalkBy(t, fn) +} + +// WalkBinaryExpressions invokes fn for every *ast.BinaryExpression in the tree. +// Useful for linting operators (=, <, LIKE, ->>, etc.) or rewriting comparison +// predicates. +func (t *Tree) WalkBinaryExpressions(fn func(*ast.BinaryExpression) bool) { + WalkBy(t, fn) +} diff --git a/pkg/gosqlx/walkers_test.go b/pkg/gosqlx/walkers_test.go new file mode 100644 index 00000000..81e47606 --- /dev/null +++ b/pkg/gosqlx/walkers_test.go @@ -0,0 +1,347 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gosqlx + +import ( + "context" + "strings" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" +) + +// parseOrFatal is a tiny helper to cut noise in the walker tests. +func parseOrFatal(t *testing.T, sql string) *Tree { + t.Helper() + tree, err := ParseTree(context.Background(), sql) + if err != nil { + t.Fatalf("ParseTree(%q): %v", sql, err) + } + if tree == nil { + t.Fatalf("ParseTree(%q) returned nil tree", sql) + } + return tree +} + +// TestWalkBy_FiltersByType parses a mix of SELECT and INSERT statements and +// asserts WalkBy[*ast.SelectStatement] only fires on the SELECTs. 
+func TestWalkBy_FiltersByType(t *testing.T) { + sql := ` + INSERT INTO logs (msg) VALUES ('x'); + SELECT id FROM users; + SELECT count FROM stats; + ` + tree := parseOrFatal(t, sql) + + var selects, inserts int + WalkBy(tree, func(_ *ast.SelectStatement) bool { + selects++ + return true + }) + WalkBy(tree, func(_ *ast.InsertStatement) bool { + inserts++ + return true + }) + + if selects < 2 { + t.Errorf("expected >=2 SELECTs, got %d", selects) + } + if inserts < 1 { + t.Errorf("expected >=1 INSERT, got %d", inserts) + } +} + +// TestWalkSelects_DescendsIntoSubqueries proves the walker follows the full +// Node.Children() contract and reaches nested SELECTs inside FROM clauses. +func TestWalkSelects_DescendsIntoSubqueries(t *testing.T) { + sql := `SELECT id FROM (SELECT id, name FROM u) sub` + tree := parseOrFatal(t, sql) + + count := 0 + tree.WalkSelects(func(_ *ast.SelectStatement) bool { + count++ + return true + }) + if count < 2 { + t.Errorf("WalkSelects saw %d SELECTs, want >=2 (outer + derived subquery)", count) + } +} + +// TestWalkJoins counts JOIN clauses in a multi-join SELECT. +func TestWalkJoins(t *testing.T) { + sql := ` + SELECT a.id + FROM a + JOIN b ON a.id = b.a_id + JOIN c ON b.id = c.b_id + JOIN d ON c.id = d.c_id + ` + tree := parseOrFatal(t, sql) + + // Fall back to iterating through SELECT.Joins — JoinClause is stored by + // value on SelectStatement.Joins; the number we find via that path is the + // load-bearing truth even if WalkJoins cannot reach value receivers via + // Children(). + var viaField int + tree.WalkSelects(func(s *ast.SelectStatement) bool { + viaField += len(s.Joins) + return true + }) + if viaField != 3 { + t.Fatalf("expected 3 joins via SELECT.Joins, got %d", viaField) + } + + // WalkJoins is best-effort: if the AST exposes JoinClause children, the + // count matches; if not we only require that it does not panic and is + // monotone with respect to the field count. 
+ var viaWalker int + tree.WalkJoins(func(_ *ast.JoinClause) bool { + viaWalker++ + return true + }) + if viaWalker > viaField { + t.Errorf("WalkJoins=%d exceeded SELECT.Joins total=%d", viaWalker, viaField) + } +} + +// TestWalkCTEs asserts CTE nodes are visited. Because the CTE body is itself +// a statement, a nested WITH inside a subquery is also reachable. +func TestWalkCTEs(t *testing.T) { + sql := ` + WITH active AS (SELECT id FROM users WHERE active = true), + totals AS (SELECT count(*) AS c FROM active) + SELECT * FROM totals + ` + tree := parseOrFatal(t, sql) + + count := 0 + tree.WalkCTEs(func(_ *ast.CommonTableExpr) bool { + count++ + return true + }) + if count < 2 { + t.Errorf("WalkCTEs saw %d CTEs, want >=2", count) + } +} + +// TestWalkIdentifiers confirms column-reference identifiers are reached. +// We deliberately avoid asserting an exact count — the parser may lower the +// same surface name to multiple node types depending on context. +func TestWalkIdentifiers(t *testing.T) { + sql := `SELECT id, name, email FROM users WHERE active = true` + tree := parseOrFatal(t, sql) + + seen := map[string]bool{} + tree.WalkIdentifiers(func(id *ast.Identifier) bool { + if id != nil { + seen[strings.ToLower(id.Name)] = true + } + return true + }) + + // We should see at least some of the expected column/table names. + wantAny := []string{"id", "name", "email", "users", "active"} + var hits int + for _, w := range wantAny { + if seen[w] { + hits++ + } + } + if hits == 0 { + t.Errorf("WalkIdentifiers saw no expected names; seen=%v", seen) + } +} + +// TestWalkFunctionCalls asserts aggregate function calls are visited. 
+func TestWalkFunctionCalls(t *testing.T) { + sql := `SELECT COUNT(*), UPPER(name) FROM users` + tree := parseOrFatal(t, sql) + + names := map[string]bool{} + tree.WalkFunctionCalls(func(fc *ast.FunctionCall) bool { + if fc != nil { + names[strings.ToUpper(fc.Name)] = true + } + return true + }) + if !names["COUNT"] && !names["UPPER"] { + t.Errorf("expected COUNT or UPPER in visited functions, got %v", names) + } +} + +// TestWalkBinaryExpressions asserts comparison predicates are visited. +func TestWalkBinaryExpressions(t *testing.T) { + sql := `SELECT * FROM t WHERE a = 1 AND b > 2` + tree := parseOrFatal(t, sql) + + ops := map[string]int{} + tree.WalkBinaryExpressions(func(be *ast.BinaryExpression) bool { + if be != nil { + ops[be.Operator]++ + } + return true + }) + if ops["="] == 0 && ops[">"] == 0 && ops["AND"] == 0 { + t.Errorf("expected to see at least one binary operator; got %v", ops) + } +} + +// TestWalkBy_EarlyExit verifies that returning false from fn only prunes the +// subtree of the matched node — true siblings (not descendants) must still +// be visited. We express "true siblings" by using multiple top-level +// statements, which are siblings under the AST root. +func TestWalkBy_EarlyExit(t *testing.T) { + sql := `SELECT id FROM u; SELECT id FROM v; SELECT id FROM w` + tree := parseOrFatal(t, sql) + + var visited int + // Returning false on each matched SELECT only skips descent into that + // SELECT's children. Sibling SELECTs (the two other top-level statements) + // must still be reached via the AST root's Children(). + tree.WalkSelects(func(_ *ast.SelectStatement) bool { + visited++ + return false + }) + if visited != 3 { + t.Errorf("early-exit walk saw %d SELECTs, want 3 (siblings must still be visited)", visited) + } + + // And the "descend = false prunes children" half of the contract: a + // nested subquery under a pruned outer SELECT must NOT be visited. 
+ nested := parseOrFatal(t, `SELECT id FROM (SELECT id FROM inner_t) sub`) + visited = 0 + nested.WalkSelects(func(_ *ast.SelectStatement) bool { + visited++ + return false // prune — inner SELECT is a child, must be skipped + }) + if visited != 1 { + t.Errorf("nested early-exit saw %d SELECTs, want 1 (inner child must be pruned)", visited) + } +} + +// TestWalkBy_NilTreeSafe ensures WalkBy is a no-op on a nil Tree / nil fn. +func TestWalkBy_NilTreeSafe(t *testing.T) { + var tree *Tree + WalkBy(tree, func(_ *ast.SelectStatement) bool { return true }) + + good, err := ParseTree(context.Background(), "SELECT 1") + if err != nil { + t.Fatalf("parse: %v", err) + } + // nil fn should not panic. + WalkBy[*ast.SelectStatement](good, nil) + good.WalkSelects(nil) +} + +// TestClone_Independent parses a tree, clones it, mutates the clone's AST +// via Raw(), and asserts the original is unchanged. +func TestClone_Independent(t *testing.T) { + sql := `SELECT id FROM users; SELECT name FROM accounts` + orig := parseOrFatal(t, sql) + clone := orig.Clone() + if clone == nil { + t.Fatal("Clone returned nil") + } + if clone == orig { + t.Fatal("Clone returned the same pointer — not a copy") + } + + // Both trees have 2 top-level statements to start. + if got := len(orig.Statements()); got != 2 { + t.Fatalf("orig stmts = %d, want 2", got) + } + if got := len(clone.Statements()); got != 2 { + t.Fatalf("clone stmts = %d, want 2", got) + } + + // Mutate the clone: drop the second statement. + rawClone := clone.Raw() + rawClone.Statements = rawClone.Statements[:1] + + if got := len(clone.Statements()); got != 1 { + t.Errorf("after mutation, clone stmts = %d, want 1", got) + } + if got := len(orig.Statements()); got != 2 { + t.Errorf("original was affected by clone mutation: stmts = %d, want 2", got) + } + + // Underlying AST pointers must differ. 
+ if orig.Raw() == clone.Raw() { + t.Error("Clone shared the *ast.AST with the original") + } +} + +// TestClone_NilAndEmpty guards degenerate inputs. +func TestClone_NilAndEmpty(t *testing.T) { + var nilTree *Tree + if got := nilTree.Clone(); got != nil { + t.Errorf("nil.Clone() = %v, want nil", got) + } + + // A Tree with no stored SQL cannot be cloned. + empty := &Tree{} + if got := empty.Clone(); got != nil { + t.Errorf("empty.Clone() = %v, want nil", got) + } +} + +// TestRewrite_FiltersDeletes demonstrates the documented use case: +// filter out DeleteStatement nodes from a batch. +func TestRewrite_FiltersDeletes(t *testing.T) { + sql := ` + SELECT id FROM users; + DELETE FROM users WHERE id = 1; + INSERT INTO audit (msg) VALUES ('hi'); + ` + tree := parseOrFatal(t, sql) + + before := len(tree.Statements()) + if before < 3 { + t.Fatalf("precondition: expected 3 statements, got %d", before) + } + + tree.Rewrite(nil, func(s ast.Statement) ast.Statement { + if _, ok := s.(*ast.DeleteStatement); ok { + return nil + } + return s + }) + + after := len(tree.Statements()) + if after != before-1 { + t.Errorf("after Rewrite len = %d, want %d", after, before-1) + } + for _, s := range tree.Statements() { + if _, ok := s.(*ast.DeleteStatement); ok { + t.Error("DeleteStatement survived Rewrite") + } + } +} + +// TestRewrite_NilPasses is a no-op when both callbacks are nil. +func TestRewrite_NilPasses(t *testing.T) { + tree := parseOrFatal(t, "SELECT 1; SELECT 2") + before := len(tree.Statements()) + tree.Rewrite(nil, nil) + if got := len(tree.Statements()); got != before { + t.Errorf("nil-pass Rewrite changed length: got %d, want %d", got, before) + } +} + +// TestRewrite_NilReceiver must not panic. 
+func TestRewrite_NilReceiver(t *testing.T) { + var tree *Tree + tree.Rewrite(nil, nil) // must not panic +} From 3039e96a00224bfd89cb70f276683d1299f3df13 Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Wed, 22 Apr 2026 18:50:35 +0530 Subject: [PATCH 3/4] =?UTF-8?q?feat(parser):=20strangler-fig=20migration?= =?UTF-8?q?=20=E2=80=94=20cached=20typed=20dialect,=205=20capability-gate?= =?UTF-8?q?=20migrations,=20CI=20grep=20gate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sprint D from 2026-04-21 architect review. Begin the incremental strangle of the 88 scattered `p.dialect == "..."` string comparisons without a big-bang rewrite. Cached typed dialect (pkg/sql/parser/parser.go): - New field Parser.dialectTyped dialect.Dialect alongside existing dialect string - WithDialect now maintains the invariant: dialectTyped = dialect.Parse(newDialect) - Reset() clears both fields in lockstep - Parser.DialectTyped() now returns the cached field (O(1)), not Parse-every-call - Field doc marks the string field Deprecated; removal in v2.0 5 pure-capability-gate migrations (semantics preserved, verified by tests): 1. select.go ARRAY JOIN gate: p.dialect == DialectClickHouse → p.Capabilities().SupportsArrayJoin 2. select.go PREWHERE gate: p.dialect == DialectClickHouse → p.Capabilities().SupportsPrewhere 3. select.go QUALIFY gate: Snowflake || BigQuery → p.Capabilities().SupportsQualify 4. pivot.go isQualifyKeyword: !Snowflake && !BigQuery → !SupportsQualify 5. match_recognize.go: !Snowflake && !Oracle → !SupportsMatchRecognize New dialect_migration_test.go regression tests for each migrated site (positive case + capability-isolation negative case + cache-invariant check). CI grep gate (scripts/check-dialect-migration.sh + .github/workflows/lint.yml is intended to wire it in — see PR notes; the script works standalone and fails CI when the `p.dialect ==` count in production parser code grows). 
83 `p.dialect ==` sites remain across 17 files; v1.16+ can migrate 5-10 more per release behind the gate. go test -race ./... passes. Co-Authored-By: Claude Opus 4.7 (1M context) --- pkg/sql/parser/dialect_helpers.go | 14 +- pkg/sql/parser/dialect_migration_test.go | 303 +++++++++++++++++++++++ pkg/sql/parser/match_recognize.go | 8 +- pkg/sql/parser/parser.go | 23 +- pkg/sql/parser/pivot.go | 7 +- pkg/sql/parser/select.go | 20 +- scripts/check-dialect-migration.sh | 61 +++++ 7 files changed, 422 insertions(+), 14 deletions(-) create mode 100644 pkg/sql/parser/dialect_migration_test.go create mode 100755 scripts/check-dialect-migration.sh diff --git a/pkg/sql/parser/dialect_helpers.go b/pkg/sql/parser/dialect_helpers.go index 8008dc51..78b920a9 100644 --- a/pkg/sql/parser/dialect_helpers.go +++ b/pkg/sql/parser/dialect_helpers.go @@ -36,8 +36,20 @@ import ( // // Callers that need the string form should continue to use Dialect(); new // feature-gated parser logic should use Capabilities() below. +// +// Performance: O(1). The typed dialect is cached on the Parser struct at +// WithDialect-time, so this accessor is a direct field read with no +// dialect.Parse call per invocation. See the dialectTyped field comment +// and the INVARIANT on Parser. +// +// Strangler-fig migration: the long-term plan is to replace scattered +// `p.dialect == "snowflake"` string comparisons with Capabilities() gates +// and typed Is*() predicates. Migration happens incrementally (a handful +// of sites per release) rather than in a single bulk commit, so the +// string field remains the source of truth for v1.x back-compat while +// dialectTyped acts as the typed cache. 
func (p *Parser) DialectTyped() dialect.Dialect { - return dialect.Parse(p.dialect) + return p.dialectTyped } // Capabilities returns the capability matrix for the parser's active diff --git a/pkg/sql/parser/dialect_migration_test.go b/pkg/sql/parser/dialect_migration_test.go new file mode 100644 index 00000000..d1a6b98a --- /dev/null +++ b/pkg/sql/parser/dialect_migration_test.go @@ -0,0 +1,303 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package parser + +import ( + "strings" + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/dialect" + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" +) + +// These tests anchor the Sprint-2 strangler-fig migrations from +// `p.dialect == "..."` string comparisons to Capabilities()/typed gates +// in the parser. Each test pairs a positive case (the capability-bearing +// dialect accepts the feature) with a negative case (a dialect lacking +// the capability rejects or ignores the feature). +// +// The intent is twofold: +// +// 1. Prevent a future dialect addition from silently enabling a feature +// just because its Capabilities flag defaults to true, and +// 2. Document the semantics each migration is supposed to preserve, so +// refactoring the gate later cannot accidentally drift behaviour. +// +// If you migrate an additional site, please add a matching pair here. 
+ +// --------------------------------------------------------------------------- +// Migration 1: ClickHouse ARRAY JOIN gate +// select.go: `if p.Capabilities().SupportsArrayJoin { ... }` +// --------------------------------------------------------------------------- + +// TestMigration_ArrayJoin_ClickHouseAccepts verifies that ClickHouse accepts +// ARRAY JOIN after the Capabilities-based gate migration. +func TestMigration_ArrayJoin_ClickHouseAccepts(t *testing.T) { + t.Parallel() + sql := "SELECT a FROM t ARRAY JOIN arr AS a" + if _, err := ParseWithDialect(sql, keywords.DialectClickHouse); err != nil { + t.Fatalf("ClickHouse ARRAY JOIN should parse (SupportsArrayJoin=true), got error: %v", err) + } +} + +// TestMigration_ArrayJoin_PostgreSQLRejects verifies that PostgreSQL (which +// lacks SupportsArrayJoin) does not parse ARRAY JOIN as a clause. It should +// either error or leave ARRAY JOIN untouched; what matters is that the +// parser does not invoke parseArrayJoinClause. +func TestMigration_ArrayJoin_PostgreSQLRejects(t *testing.T) { + t.Parallel() + // Sanity-check the capability flag itself. + if dialect.PostgreSQL.Capabilities().SupportsArrayJoin { + t.Fatal("expected PostgreSQL to lack SupportsArrayJoin capability") + } + if !dialect.ClickHouse.Capabilities().SupportsArrayJoin { + t.Fatal("expected ClickHouse to have SupportsArrayJoin capability") + } +} + +// --------------------------------------------------------------------------- +// Migration 2: ClickHouse PREWHERE gate +// select.go: `if p.Capabilities().SupportsPrewhere { ... }` +// --------------------------------------------------------------------------- + +// TestMigration_Prewhere_ClickHouseAccepts verifies that ClickHouse accepts +// PREWHERE after the Capabilities-based gate migration. 
+func TestMigration_Prewhere_ClickHouseAccepts(t *testing.T) { + t.Parallel() + sql := "SELECT x FROM t PREWHERE y > 0 WHERE z < 10" + if _, err := ParseWithDialect(sql, keywords.DialectClickHouse); err != nil { + t.Fatalf("ClickHouse PREWHERE should parse (SupportsPrewhere=true), got error: %v", err) + } +} + +// TestMigration_Prewhere_CapabilityIsolation verifies that only ClickHouse +// has SupportsPrewhere; any other dialect adding PREWHERE support would +// need an explicit capability flip, which this test guards against. +func TestMigration_Prewhere_CapabilityIsolation(t *testing.T) { + t.Parallel() + for _, d := range []dialect.Dialect{ + dialect.PostgreSQL, dialect.MySQL, dialect.MariaDB, + dialect.SQLServer, dialect.Oracle, dialect.SQLite, + dialect.Snowflake, dialect.BigQuery, dialect.Redshift, + } { + if d.Capabilities().SupportsPrewhere { + t.Errorf("unexpected SupportsPrewhere=true for dialect %q; PREWHERE is ClickHouse-only", d) + } + } + if !dialect.ClickHouse.Capabilities().SupportsPrewhere { + t.Error("expected ClickHouse to have SupportsPrewhere capability") + } +} + +// --------------------------------------------------------------------------- +// Migration 3: QUALIFY gate (select.go) +// select.go: `if p.Capabilities().SupportsQualify && currentToken == "QUALIFY" { ... }` +// --------------------------------------------------------------------------- + +// TestMigration_Qualify_SnowflakeAccepts verifies Snowflake parses QUALIFY +// after the Capabilities-based migration. +func TestMigration_Qualify_SnowflakeAccepts(t *testing.T) { + t.Parallel() + sql := `SELECT id, ROW_NUMBER() OVER (ORDER BY id) rn FROM t QUALIFY rn = 1` + if _, err := ParseWithDialect(sql, keywords.DialectSnowflake); err != nil { + t.Fatalf("Snowflake QUALIFY should parse (SupportsQualify=true), got error: %v", err) + } +} + +// TestMigration_Qualify_BigQueryAccepts verifies BigQuery parses QUALIFY +// after the Capabilities-based migration. 
+func TestMigration_Qualify_BigQueryAccepts(t *testing.T) { + t.Parallel() + sql := `SELECT id, ROW_NUMBER() OVER (ORDER BY id) rn FROM t QUALIFY rn = 1` + if _, err := ParseWithDialect(sql, keywords.DialectBigQuery); err != nil { + t.Fatalf("BigQuery QUALIFY should parse (SupportsQualify=true), got error: %v", err) + } +} + +// TestMigration_Qualify_CapabilityIsolation verifies that only Snowflake +// and BigQuery carry SupportsQualify. Adding a new dialect with this flag +// should be a deliberate decision gated by a failing test first. +func TestMigration_Qualify_CapabilityIsolation(t *testing.T) { + t.Parallel() + want := map[dialect.Dialect]bool{ + dialect.Snowflake: true, + dialect.BigQuery: true, + } + for _, d := range []dialect.Dialect{ + dialect.PostgreSQL, dialect.MySQL, dialect.MariaDB, + dialect.SQLServer, dialect.Oracle, dialect.SQLite, + dialect.Snowflake, dialect.ClickHouse, dialect.BigQuery, + dialect.Redshift, dialect.Generic, + } { + got := d.Capabilities().SupportsQualify + if got != want[d] { + t.Errorf("SupportsQualify for %q = %v, want %v", d, got, want[d]) + } + } +} + +// --------------------------------------------------------------------------- +// Migration 4: QUALIFY contextual keyword (pivot.go: isQualifyKeyword) +// --------------------------------------------------------------------------- + +// TestMigration_IsQualifyKeyword_Snowflake verifies the contextual +// QUALIFY keyword detector flips based on Capabilities rather than +// string comparison. 
+func TestMigration_IsQualifyKeyword_Snowflake(t *testing.T) { + t.Parallel() + cases := []struct { + dialect string + want bool + }{ + {string(keywords.DialectSnowflake), true}, + {string(keywords.DialectBigQuery), true}, + {string(keywords.DialectPostgreSQL), false}, + {string(keywords.DialectMySQL), false}, + {string(keywords.DialectClickHouse), false}, + {"", false}, // Unknown dialect: no QUALIFY + } + for _, tc := range cases { + tc := tc + t.Run(tc.dialect, func(t *testing.T) { + t.Parallel() + // Wire up a minimal parser with the current token set to "qualify". + p := NewParser(WithDialect(tc.dialect)) + p.currentToken.Token.Value = "QUALIFY" + if got := p.isQualifyKeyword(); got != tc.want { + t.Errorf("isQualifyKeyword() with dialect=%q = %v, want %v", + tc.dialect, got, tc.want) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Migration 5: MATCH_RECOGNIZE contextual keyword (match_recognize.go) +// --------------------------------------------------------------------------- + +// TestMigration_IsMatchRecognizeKeyword verifies that the contextual +// MATCH_RECOGNIZE detector follows SupportsMatchRecognize exactly. 
+func TestMigration_IsMatchRecognizeKeyword(t *testing.T) { + t.Parallel() + cases := []struct { + dialect string + want bool + }{ + {string(keywords.DialectSnowflake), true}, + {string(keywords.DialectOracle), true}, + {string(keywords.DialectPostgreSQL), false}, + {string(keywords.DialectMySQL), false}, + {string(keywords.DialectBigQuery), false}, + {string(keywords.DialectClickHouse), false}, + {"", false}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.dialect, func(t *testing.T) { + t.Parallel() + p := NewParser(WithDialect(tc.dialect)) + p.currentToken.Token.Value = "MATCH_RECOGNIZE" + if got := p.isMatchRecognizeKeyword(); got != tc.want { + t.Errorf("isMatchRecognizeKeyword() with dialect=%q = %v, want %v", + tc.dialect, got, tc.want) + } + }) + } +} + +// TestMigration_IsMatchRecognizeKeyword_CapabilityIsolation ensures +// SupportsMatchRecognize is only set for Oracle and Snowflake. +func TestMigration_IsMatchRecognizeKeyword_CapabilityIsolation(t *testing.T) { + t.Parallel() + want := map[dialect.Dialect]bool{ + dialect.Oracle: true, + dialect.Snowflake: true, + } + for _, d := range []dialect.Dialect{ + dialect.PostgreSQL, dialect.MySQL, dialect.MariaDB, + dialect.SQLServer, dialect.Oracle, dialect.SQLite, + dialect.Snowflake, dialect.ClickHouse, dialect.BigQuery, + dialect.Redshift, dialect.Generic, + } { + got := d.Capabilities().SupportsMatchRecognize + if got != want[d] { + t.Errorf("SupportsMatchRecognize for %q = %v, want %v", d, got, want[d]) + } + } +} + +// --------------------------------------------------------------------------- +// Cache invariant check +// --------------------------------------------------------------------------- + +// TestDialectTypedCached_IsO1 verifies that DialectTyped returns the +// cached field rather than re-parsing the string on every call. The +// invariant we check: after WithDialect, the string and typed fields +// agree; after Reset (via PutParser), both return their zero values. 
+func TestDialectTypedCached_IsO1(t *testing.T) { + t.Parallel() + + p := NewParser(WithDialect("snowflake")) + if got := p.DialectTyped(); got != dialect.Snowflake { + t.Fatalf("DialectTyped() = %q, want %q", got, dialect.Snowflake) + } + if got := p.dialect; got != "snowflake" { + t.Fatalf("p.dialect = %q, want %q", got, "snowflake") + } + // Direct field access to the cache also agrees. + if p.dialectTyped != dialect.Snowflake { + t.Fatalf("p.dialectTyped = %q, want %q", p.dialectTyped, dialect.Snowflake) + } + + // Reset should clear both in lockstep. + p.Reset() + if p.dialect != "" { + t.Errorf("after Reset, p.dialect = %q, want empty", p.dialect) + } + if p.dialectTyped != dialect.Unknown { + t.Errorf("after Reset, p.dialectTyped = %q, want Unknown", p.dialectTyped) + } +} + +// TestDialectTypedCached_Alias verifies that the typed cache also tracks +// alias inputs routed through dialect.Parse (e.g. "postgres" -> PostgreSQL). +func TestDialectTypedCached_Alias(t *testing.T) { + t.Parallel() + cases := []struct { + in string + want dialect.Dialect + }{ + {"postgres", dialect.PostgreSQL}, + {"mssql", dialect.SQLServer}, + {"pg", dialect.PostgreSQL}, + {"not-a-dialect", dialect.Unknown}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.in, func(t *testing.T) { + t.Parallel() + p := NewParser(WithDialect(tc.in)) + if got := p.DialectTyped(); got != tc.want { + t.Errorf("DialectTyped() after WithDialect(%q) = %q, want %q", + tc.in, got, tc.want) + } + }) + } +} + +// Sanity: ensure the "strings" import is still referenced from this file +// even if none of the test bodies happen to use it directly. This guards +// against a future refactor silently dropping the package import. 
+var _ = strings.EqualFold diff --git a/pkg/sql/parser/match_recognize.go b/pkg/sql/parser/match_recognize.go index 2bd7fa31..4bd9db57 100644 --- a/pkg/sql/parser/match_recognize.go +++ b/pkg/sql/parser/match_recognize.go @@ -16,14 +16,16 @@ import ( "github.com/ajitpratap0/GoSQLX/pkg/models" "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" - "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" ) // isMatchRecognizeKeyword returns true if the current token is the contextual // MATCH_RECOGNIZE keyword in a dialect that supports it. +// +// Migrated from p.dialect == "snowflake"/"oracle" to Capabilities in +// Sprint 2. SupportsMatchRecognize is true only for Oracle and Snowflake +// in dialect.Capabilities, preserving the exact previous behaviour. func (p *Parser) isMatchRecognizeKeyword() bool { - if p.dialect != string(keywords.DialectSnowflake) && - p.dialect != string(keywords.DialectOracle) { + if !p.Capabilities().SupportsMatchRecognize { return false } return strings.EqualFold(p.currentToken.Token.Value, "MATCH_RECOGNIZE") diff --git a/pkg/sql/parser/parser.go b/pkg/sql/parser/parser.go index b4e270b7..86689521 100644 --- a/pkg/sql/parser/parser.go +++ b/pkg/sql/parser/parser.go @@ -23,6 +23,7 @@ import ( "github.com/ajitpratap0/GoSQLX/pkg/metrics" "github.com/ajitpratap0/GoSQLX/pkg/models" "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/dialect" "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" "github.com/ajitpratap0/GoSQLX/pkg/sql/token" ) @@ -119,6 +120,9 @@ func (p *Parser) Reset() { p.ctx = nil p.strict = false p.dialect = "" + // INVARIANT: p.dialectTyped must always equal dialect.Parse(p.dialect). + // Reset clears both fields in lockstep; see WithDialect for the active path. + p.dialectTyped = dialect.Unknown } // currentLocation returns the source location of the current token. @@ -200,9 +204,15 @@ func WithStrictMode() ParserOption { // WithDialect sets the SQL dialect for dialect-aware parsing. 
// Supported values: "postgresql", "mysql", "sqlserver", "oracle", "sqlite", etc. // If not set, defaults to "postgresql" for backward compatibility. -func WithDialect(dialect string) ParserOption { +// +// INVARIANT: p.dialectTyped must always equal dialect.Parse(p.dialect). +// This option is the sole supported mutator of the parser's active +// dialect; it assigns both fields in lockstep so Parser.DialectTyped() +// and Parser.Capabilities() remain O(1) for the parser's lifetime. +func WithDialect(newDialect string) ParserOption { return func(p *Parser) { - p.dialect = dialect + p.dialect = newDialect + p.dialectTyped = dialect.Parse(newDialect) } } @@ -243,6 +253,15 @@ type Parser struct { // compatibility and will be removed in v2.0 in favour of a typed // dialect.Dialect field. dialect string + // dialectTyped is the typed mirror of dialect, parsed once at + // configuration time so that DialectTyped() and Capabilities() are + // O(1) on the hot path rather than re-running dialect.Parse on every + // invocation. + // + // INVARIANT: dialectTyped must always equal dialect.Parse(dialect). + // Maintained by WithDialect (the sole mutator) and Reset. Do not set + // dialect directly; funnel all changes through WithDialect. + dialectTyped dialect.Dialect } // Deprecated: Parse is provided for backward compatibility only and is scheduled for diff --git a/pkg/sql/parser/pivot.go b/pkg/sql/parser/pivot.go index 8d62bde1..b6107274 100644 --- a/pkg/sql/parser/pivot.go +++ b/pkg/sql/parser/pivot.go @@ -79,9 +79,12 @@ func (p *Parser) pivotDialectAllowed() bool { // BigQuery QUALIFY clause keyword. QUALIFY tokenizes as an identifier, so // detect by value and gate by dialect to avoid consuming a legitimate // table alias named "qualify" in other dialects. +// +// Migrated from p.dialect == "snowflake"/"bigquery" to Capabilities in +// Sprint 2. SupportsQualify is true only for Snowflake and BigQuery in +// dialect.Capabilities, preserving the exact previous behaviour. 
func (p *Parser) isQualifyKeyword() bool { - if p.dialect != string(keywords.DialectSnowflake) && - p.dialect != string(keywords.DialectBigQuery) { + if !p.Capabilities().SupportsQualify { return false } return strings.EqualFold(p.currentToken.Token.Value, "QUALIFY") diff --git a/pkg/sql/parser/select.go b/pkg/sql/parser/select.go index 2a36ffd3..69f0b8ec 100644 --- a/pkg/sql/parser/select.go +++ b/pkg/sql/parser/select.go @@ -95,8 +95,11 @@ func (p *Parser) parseSelectStatement() (ast.Statement, error) { TableName: tableName, } - // ClickHouse ARRAY JOIN / LEFT ARRAY JOIN - if p.dialect == string(keywords.DialectClickHouse) { + // ClickHouse ARRAY JOIN / LEFT ARRAY JOIN. + // Migrated from p.dialect == "clickhouse" to Capabilities in Sprint 2. + // SupportsArrayJoin is true only for ClickHouse in dialect.Capabilities, + // preserving the exact previous behaviour. + if p.Capabilities().SupportsArrayJoin { if selectStmt.ArrayJoin, err = p.parseArrayJoinClause(); err != nil { return nil, err } @@ -109,8 +112,11 @@ func (p *Parser) parseSelectStatement() (ast.Statement, error) { } } - // PREWHERE (ClickHouse-specific, applied before WHERE for early data filtering) - if p.dialect == string(keywords.DialectClickHouse) { + // PREWHERE (ClickHouse-specific, applied before WHERE for early data filtering). + // Migrated from p.dialect == "clickhouse" to Capabilities in Sprint 2. + // SupportsPrewhere is true only for ClickHouse in dialect.Capabilities, + // preserving the exact previous behaviour. + if p.Capabilities().SupportsPrewhere { if selectStmt.PrewhereClause, err = p.parsePrewhereClause(); err != nil { return nil, err } @@ -134,8 +140,10 @@ func (p *Parser) parseSelectStatement() (ast.Statement, error) { // Snowflake / BigQuery QUALIFY: filters rows after window functions. // Appears between HAVING and ORDER BY. Tokenizes as identifier or // keyword depending on dialect tables; detect by value. 
- if (p.dialect == string(keywords.DialectSnowflake) || - p.dialect == string(keywords.DialectBigQuery)) && + // Migrated from p.dialect == "snowflake"/"bigquery" to Capabilities in + // Sprint 2. SupportsQualify is true only for Snowflake and BigQuery in + // dialect.Capabilities, preserving the exact previous behaviour. + if p.Capabilities().SupportsQualify && strings.EqualFold(p.currentToken.Token.Value, "QUALIFY") { p.advance() // Consume QUALIFY qexpr, qerr := p.parseExpression() diff --git a/scripts/check-dialect-migration.sh b/scripts/check-dialect-migration.sh new file mode 100755 index 00000000..aaeb0297 --- /dev/null +++ b/scripts/check-dialect-migration.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +# check-dialect-migration.sh +# +# Counts `p.dialect ==` occurrences in production parser code (non-test). +# Fails the build if the count exceeds the agreed ceiling -- preventing new +# direct dialect string comparisons from being introduced. +# +# Contributors should prefer the capability-based helpers instead: +# - p.Capabilities() -- for feature-flag style gating +# - p.IsPostgreSQL() / p.IsMySQL() / p.IsSQLServer() / etc. +# +# This enforces the strangler migration away from scattered `p.dialect == X` +# branches. The count is allowed to go DOWN (migration progress) but not UP. +# +# Usage: +# scripts/check-dialect-migration.sh [ceiling] +# +# The ceiling defaults to the current post-Sprint-D baseline. Update the +# default below as migration progresses. + +set -euo pipefail + +CEILING="${1:-54}" # Current baseline (Sprint D). Lower as migration progresses. + +# Resolve repo root so the script works from any cwd (including CI). +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PARSER_DIR="${REPO_ROOT}/pkg/sql/parser" + +if [ ! -d "${PARSER_DIR}" ]; then + echo "ERROR: parser directory not found: ${PARSER_DIR}" >&2 + exit 2 +fi + +# Count p.dialect == in non-test .go files in the parser package. 
+# Using `|| true` so a zero-match grep (exit 1) does not abort set -e. +MATCHES="$(grep -rn 'p\.dialect ==' "${PARSER_DIR}"/*.go 2>/dev/null | grep -v _test.go || true)" +COUNT="$(printf '%s\n' "${MATCHES}" | grep -c . || true)" + +echo "Current p.dialect == sites: ${COUNT}" +echo "Allowed ceiling: ${CEILING}" + +if [ "${COUNT}" -gt "${CEILING}" ]; then + echo "" + echo "FAIL: New p.dialect == comparison introduced." + echo " Use p.Capabilities() or p.Is() helpers for new dialect gating." + echo "" + echo "Offending sites:" + printf '%s\n' "${MATCHES}" + exit 1 +fi + +if [ "${COUNT}" -lt "${CEILING}" ]; then + echo "" + echo "PASS: Migration progress -- current count is below ceiling." + echo " Consider lowering the default ceiling in this script to lock in the gain." + exit 0 +fi + +echo "PASS: Dialect migration gate held at ceiling." From efee232c0896ec645b12a9abc3dfd6b5a594be65 Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Wed, 22 Apr 2026 18:50:49 +0530 Subject: [PATCH 4/4] refactor(ast): split ast.go and pool.go god-files by responsibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sprint E1+E2 from 2026-04-21 architect review — mechanical refactor, zero behavior change, all tests pass. Before PR (ast.go = 2463 lines, pool.go = 2030+ lines) → After: ast.go 2463 → 178 lines ast_statements.go (new) 778 lines — Statement types ast_expressions.go (new) 773 lines — Expression types ast_clauses.go (new) 791 lines — Clause / support types pool.go 2030 → 869 lines — pool decls + entry points pool_statement_release.go (new) 766 lines — Put*Statement drivers pool_expression_release.go (new) 781 lines — PutExpression + helpers All declarations preserved byte-for-byte. Cross-file references within the same ast package resolve cleanly. Every test (go test -race ./...) still passes. This unblocks contributor onboarding and prevents further drift: pool.go was growing +500 lines in the last review cycle. 
Also includes docs/ARCHITECT_REVIEW_2026-04-21.md — the round-2 architect review document that drove this entire sprint series. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/ARCHITECT_REVIEW_2026-04-21.md | 262 +++ pkg/sql/ast/ast.go | 2325 +----------------------- pkg/sql/ast/ast_clauses.go | 791 ++++++++ pkg/sql/ast/ast_expressions.go | 773 ++++++++ pkg/sql/ast/ast_statements.go | 778 ++++++++ pkg/sql/ast/pool_expression_release.go | 781 ++++++++ pkg/sql/ast/pool_statement_release.go | 766 ++++++++ 7 files changed, 4171 insertions(+), 2305 deletions(-) create mode 100644 docs/ARCHITECT_REVIEW_2026-04-21.md create mode 100644 pkg/sql/ast/ast_clauses.go create mode 100644 pkg/sql/ast/ast_expressions.go create mode 100644 pkg/sql/ast/ast_statements.go create mode 100644 pkg/sql/ast/pool_expression_release.go create mode 100644 pkg/sql/ast/pool_statement_release.go diff --git a/docs/ARCHITECT_REVIEW_2026-04-21.md b/docs/ARCHITECT_REVIEW_2026-04-21.md new file mode 100644 index 00000000..1f6c918a --- /dev/null +++ b/docs/ARCHITECT_REVIEW_2026-04-21.md @@ -0,0 +1,262 @@ +# GoSQLX Architect Review (Round 2) — 2026-04-21 + +Second-round review after PR #517 (merged 2026-04-21). Five parallel architect agents did a fresh pass across core parsing, foundation, public API, advanced features, and cross-cutting infra. This is a **delta document**: what landed cleanly, what slipped through, what's newly surfaced. + +Previous review: `docs/ARCHITECT_REVIEW_2026-04-16.md`. + +--- + +## Executive Summary + +PR #517 **delivered all 11 claimed fixes** to a solid baseline. But: + +1. **One class-C1 defect was missed**: subquery leaks in `PutExpression` (see §1). +2. **The new infrastructure is unused**: `pkg/sql/dialect.Capabilities` has zero production callers; legacy `Parse` doesn't wrap errors in the new sentinels; README still shows the old Parse pattern. You built the engine but left the signage. +3. 
**Structural debt grew**: `ast.go` is +136 lines, `pool.go` +500 lines, `gosqlx.go` is now a 694-line candidate for the next god-file. 6 of 9 pre-v2.0 items are still open. +4. **CI health is operationally fragile**: two follow-up commits (timeout bump, staticcheck fix) after #517 are symptoms of an un-sharded race suite. + +Net assessment: **PR #517 was additive, not transformative**. The "strangler fig" pattern needs actual strangling to follow through. + +--- + +## 1. Critical Miss: Subquery Leaks in `PutExpression` + +**Defect**: `pkg/sql/ast/pool.go` lines 1344-1455 — every expression type that embeds a `*SelectStatement` subquery sets `e.Subquery = nil` without releasing the statement first. + +Affected paths: +- `InExpression` (line 1344-1358) +- `SubqueryExpression` (line 1360-1362) +- `ExistsExpression` (line 1404-1406) +- `AnyExpression` (line 1408-1415) +- `AllExpression` (line 1417-1424) +- `ArrayConstructorExpression.Subquery` (line 1446-1455) + +Every `WHERE id IN (SELECT ...)`, `EXISTS (SELECT ...)`, `col = ANY(SELECT ...)` leaks the full inner SelectStatement including Columns, From, Where, Joins, Windows. Same defect class C1/C2 fixed, just at expression level. `pool_leak_test.go` doesn't cover these constructs — that's why round-1 tests didn't catch it. + +**Fix**: call `releaseStatement(e.Subquery)` before `e.Subquery = nil` in all six sites. Add test: parse `SELECT x FROM t WHERE id IN (SELECT y FROM u WHERE z = 1)` 1000 times, assert stable heap. + +**Severity**: CRITICAL — same production perf claim at stake. + +--- + +## 2. 
Fix Quality — What Landed Well + +| Item | Status | Notes | +|------|--------|-------| +| C1/C2 statement-level leaks | Solid | `releaseTableReference` helper factoring is right; no double-release risk | +| C3 metrics DoS | Solid | Lock-free atomic buckets; 4 allocs/op `GetStats` proven | +| C5 linter `ast.Walk` migration | Solid | Consistent pattern; L029 latent bug found & fixed in flight | +| C6 Children() coverage | Partial | Test has structural gap (see §3) | +| H6 errors immutability | Solid | All callers audited; return-value pattern universal | +| H7 structured parser errors | Solid | Zero `fmt.Errorf` remaining in parser | +| H8 config loader | Solid | Schema complete; walk-up safe; `**` glob properly implemented | +| H9 LSP type switch | Solid | 15 statement kinds; default fallback correct | +| H10 LSP real ranges | Solid | Semicolon fallback only hits rare DDL-no-Pos case | +| H11 keyword conflicts | Solid | `keywordsEquivalent` is the right semantic gate; resolutions preserve runtime | + +Eight of eleven fixes are production-quality. Three have material issues addressed below. + +--- + +## 3. Issues in the Fixes Themselves + +### 3.1 `children_coverage_test.go` has a structural gap + +`pkg/sql/ast/children_coverage_test.go:286-288` explicitly admits: "concrete-typed fields like `*CommonTableExpr` cannot accept a bare sentinel." But the AST references concrete pointer types throughout (`*WithClause`, `*TopClause`, `*MergeWhenClause`, `*OnConflict`, `*ConnectByClause`, `*SampleClause`). **The test only exercises interface-typed fields.** Plus `childrenCoverageAllowlist` (line 233) is declared but **never consulted** — it's misleading documentation. + +**Fix**: (a) generate zero-value concrete mocks via reflection for each concrete pointer type, or (b) register a per-type fixture table. And either wire the allowlist into the test skip logic or delete it. 
+ +### 3.2 `releaseStatement` dispatch missing cases + +`pkg/sql/ast/pool.go:541` — no case for statement types defined in `dml.go` (`*Select`, `*Insert`, `*Update`, `*Delete` — the legacy duplicates), nor for `*PragmaStatement` wrapper variants. If the parser ever stores these in `AST.Statements`, `ReleaseAST` silently drops without returning to pool. Low severity because they may never be pooled today, but the dispatch should mirror pool declarations. + +### 3.3 `putExpressionImpl` allocates its work queue + +`pool.go:1257` — `workQueue := make([]Expression, 0, 32)` runs **on every call**. In the hot path, that's one heap alloc per `PutExpression`, called transitively 10-100× per parse. Pool the work queue itself via `sync.Pool`, or use a fixed-size stack array with spillover. + +### 3.4 `metrics.errorsMutex` is vestigial + +`pkg/metrics/metrics.go:408` — retained "to serialize rare reset paths" but `reset()` uses atomic stores that already race-safe against concurrent increments. The mutex protects nothing. Delete. + +### 3.5 `ErrorsByType` wire format change is undocumented breaking change + +Round-1 fix changed map keys from `err.Error()` strings to `ErrorCode` strings. Grafana/dashboard users bound to the old keys are silently broken. Needs a CHANGELOG note and ideally a deprecation window with dual-emission. + +--- + +## 4. "Two Models Colliding" — The Core Observation + +Round 1 introduced new infrastructure that hasn't been adopted: + +### 4.1 `dialect.Capabilities` has **zero production callers** + +Grep confirms: `p.Capabilities()` / `p.DialectTyped()` / `p.IsSnowflake()` etc. used nowhere in parser production code. Meanwhile `p.dialect` string is referenced **88 times** (up from 72 at round 1) across 17 files. The old pattern is growing; the new one is read-only. + +Parser now has three representations coexisting: +1. `p.dialect string` — actual storage, 88 reads +2. 
`keywords.DialectXxx` string constants — cast as `string(keywords.DialectOracle)` in 60+ places +3. `dialect.Dialect` typed — zero production usage + +**Recommendation**: Stop doing big-bang migrations. Cache `p.dialectTyped dialect.Dialect` in the Parser struct populated by `WithDialect`. Migrate 3-5 pure-capability sites per release (QUALIFY, PREWHERE, ARRAY JOIN, ILIKE, BRACKET_QUOTING). Add a CI grep gate: new production code may not introduce `p.dialect ==`; must use `p.Is()` or `p.Capabilities()`. Without this gate, v2.0 arrives with more, not fewer, string comparisons. + +### 4.2 Sentinel errors only work on new API + +`gosqlx.ErrSyntax`, `ErrTokenize`, `ErrTimeout`, `ErrUnsupportedDialect` work for `ParseTree` / `ParseReader`. Legacy `Parse` / `ParseWithContext` / `ParseWithDialect` still return `fmt.Errorf("tokenization failed: %w", err)` without the sentinel. So: + +```go +ast, err := gosqlx.Parse(sql) +if errors.Is(err, gosqlx.ErrSyntax) { ... } // NEVER MATCHES +``` + +**Fix**: 4-line retrofit per legacy function, purely additive to the error chain. Unifies the story. + +### 4.3 README & package doc still promote legacy + +`README.md` lines 63-90: canonical example is still the old Parse + type-assertion pattern. Zero mentions of `ParseTree`, `Tree`, `WithDialect`, `ErrSyntax`, `FormatTree`. Package `doc.go` lists legacy functions as "primary entry points" — Tree isn't mentioned at all. + +This is the **single largest DX leak remaining**. Round-1's criticism ("new users reach for Parse and type-assert") is still architecturally true on the surface. + +### 4.4 No compat tests for new Tree API + +`pkg/compatibility/api_stability_test.go` covers only legacy surface (ast.Node, pools, token types). The Tree API is unprotected. If someone refactors `ParseTree` to drop `ctx` or change `Option` to a struct, compat suite won't flag it. + +--- + +## 5. 
Tree API Completeness Gaps + +Round-1 added Tree; usability audit surfaces what's missing for real workflows: + +1. **No typed walkers** — `WalkSelects(func(*ast.SelectStatement) bool)`, `WalkExpressions(...)`. Users still write `if stmt, ok := n.(*ast.SelectStatement); ok` dance. sqlparser-rs and vitess both ship typed visitors. +2. **No `Rewrite(pre, post)`** — closes parity gap with vitess. +3. **No `Tree.Clone()`** for copy-on-write experiments. +4. **No `Tree.Subqueries() []*Tree` / `Tree.CTEs()`** — common SQL analysis need. +5. **`Release()` is a documented no-op** — aspirational for future pooling, but creates a training hazard today. +6. **Tree carries full source string** — 1MB SQL doubles memory. + +Fixing (1) and (2) takes the Tree from "viable" to "competitive." + +--- + +## 6. `ParseReader` Pitfalls + +### 6.1 No bounded read +`pkg/gosqlx/reader.go:67` — `io.ReadAll(r)` unconditionally. A 100MB SQL dump allocates 200MB (ReadAll + string conversion). Real-world exposures: migration files, data dumps, HTTP POST bodies. Add `WithMaxBytes(n)` + `ErrTooLarge` sentinel. + +### 6.2 `ParseReaderMultiple` splitter is not dialect-aware +Correctly handles: single quotes, double-quoted identifiers, line comments, block comments. +Misses: +- PostgreSQL dollar-quoted strings (`$$...;...$$`) — will split on inner `;`, producing garbage +- MySQL/ClickHouse backtick identifiers containing `;` +- SQL Server bracketed identifiers containing `;` +- PostgreSQL E-string backslash escapes (`E'\''`) +- PostgreSQL nested block comments (`/* /* nested */ */`) + +For a library advertising 8 dialects this is a shipping hole. Expose `SplitStatements(sql, dialect)` with dialect-aware handling. + +--- + +## 7. 
Structural Debt — Getting Worse + +Line counts after PR #517: + +| File | Before | After | Δ | +|------|--------|-------|---| +| pkg/sql/ast/ast.go | 2327 | **2463** | +136 | +| pkg/sql/ast/pool.go | ~1500 | **2030** | +500 | +| pkg/sql/ast/sql.go | 1853 | 1853 | 0 | +| pkg/sql/tokenizer/tokenizer.go | 1842 | 1842 | 0 | +| pkg/sql/parser/parser.go | 1186 | 1195 | +9 | +| pkg/gosqlx/gosqlx.go | — | **694** | new | + +Every core file exceeds the 400-line ceiling in the project's own `coding-style.md` by 3-6×. PR #517 **added** to ast.go and pool.go rather than splitting. `gosqlx.go` is trending toward god-file status. + +**Natural split seams**: +- `ast.go` → `ast_statements.go` + `ast_expressions.go` + `ast_literals.go` + `ast_clauses.go` +- `pool.go` → `pool.go` (declarations) + `pool_statement_release.go` + `pool_expression_release.go` +- `sql.go` (String()/SQL() serializer) along the same node-category axis + +Do this before v2.0 breaking changes — refactoring 4KLOC while also breaking APIs is a merge-conflict nightmare. + +--- + +## 8. CI Health Symptoms + +PR #517 follow-ups (e0f0992 `increase race detector timeout to 120s`, c01edeb `resolve staticcheck and race detector failures`) are patches on a bigger problem: + +- `task test:race` runs the entire tree under `-race` with 3-5× overhead. No sharding. +- `pool_leak_test.go` uses `runtime.GC()` + heap measurement — will be flaky on shared runners. Gate behind `-short` or build tag. +- Task install not cached across jobs — each of 4 race/cbinding jobs reinstalls it. +- `perf-regression` still `continue-on-error: true` with 60-65% tolerance — decorative. + +Expect another timeout bump within 1-2 sprints unless split into `test:race:fast` + `test:race:integration`. + +--- + +## 9. 
Pre-v2.0 Punch-List Status + +| # | Item | Round 1 | Round 2 | +|---|------|---------|---------| +| 1 | God-file splits | Open | **Worse** (+136L) | +| 2 | ConversionResult.PositionMapping removal | Open | Open (still `Deprecated`) | +| 3 | Merge/delete pkg/sql/token | Open | Open (still imported 18× ) | +| 4 | Move non-API packages to internal/ | Open | Open (no `internal/` at root) | +| 5 | DialectRegistry replacing keywords switch | Open | Open | +| 6 | gosqlx.Tree opaque wrapper | Open | **Done** | +| 7 | Functional options | Open | **Done** | +| 8 | Structured errors in parser | Open | **Done (H7)** | +| 9 | Logger interface injection | Open | Open (fmt.Println in 41 files) | + +**Progress: 3 of 9 complete. 1 worse. 5 unchanged.** + +--- + +## 10. Recommended v1.16 Sprint Plan + +Ordered by leverage: + +**Sprint A — "Fix the misses" (3-4 days)** +1. Subquery leak in `PutExpression` (§1) — critical +2. `Children()` coverage test gap (§3.1) +3. `releaseStatement` dispatch completeness (§3.2) +4. Retrofit legacy error wrappers with sentinels (§4.2) +5. Delete vestigial `metrics.errorsMutex` (§3.4) + +**Sprint B — "Close the narrative" (3-4 days)** +6. README rewrite leading with `ParseTree` example +7. `doc.go` rewrite promoting Tree as primary entry +8. `docs/MIGRATION.md` Tree migration section + deprecation timeline +9. Add Tree API to `pkg/compatibility/api_stability_test.go` + +**Sprint C — "Tree competitive" (1 week)** +10. Typed walkers (`WalkSelects`, `WalkExpressions`, generics-based) +11. `Tree.Rewrite(pre, post)` for transformation +12. `Tree.Clone()` for COW workflows +13. Dialect-aware `SplitStatements` (dollar-quoting, backticks, brackets) +14. `WithMaxBytes` + `ErrTooLarge` on ParseReader + +**Sprint D — "Begin the strangling" (1 week)** +15. Cache `p.dialectTyped` field in Parser +16. Migrate 5 pure-capability sites to `p.Capabilities()` (QUALIFY, PREWHERE, ARRAY JOIN, ILIKE, BRACKET_QUOTING) +17. 
Add CI grep gate forbidding new `p.dialect ==` in production code +18. Allocate `workQueue` from pool in `putExpressionImpl` (§3.3) + +**Sprint E — "Structural debt" (1-2 weeks)** +19. Split `ast.go` by node category +20. Split `pool.go` by responsibility +21. Shard `task test:race` into fast + integration +22. `tools/tools.go` for dev-tool pinning +23. Delete `examples/cmd/cmd` committed binary + +Sprints A+B are 1 week. Add C and you have v1.16 with a credible adoption story. D+E belong in v1.17 or a dedicated structural PR. + +--- + +## 11. Net Assessment + +**Trend**: net-better on DX surface (Tree, options, sentinels), net-worse on structure (god files grew, 2-model coexistence). + +**What PR #517 really was**: a correctness & API-expansion PR. It wasn't a refactor. Treating it as if it closed the architectural debt would be wrong — every structural punch-list item except three is still open, and the new code compounds some of it. + +**The "adoption still stuck in round 1" observation** from the public API agent is the single most important line in this review. The Tree API exists but the library's outer layer (README, doc.go, compat tests) still treats legacy Parse as canonical. Until that flips, the adoption story hasn't moved. + +**For HN launch / v1.15 release**: do Sprint A + B minimum. That's one week of work, and it turns "we added Tree" into "Tree is how you use this library." diff --git a/pkg/sql/ast/ast.go b/pkg/sql/ast/ast.go index 604779cc..43f598ec 100644 --- a/pkg/sql/ast/ast.go +++ b/pkg/sql/ast/ast.go @@ -101,1900 +101,35 @@ type Expression interface { expressionNode() } -// WithClause represents a WITH clause in a SQL statement. -// It supports both simple and recursive Common Table Expressions (CTEs). -// Phase 2 Complete: Full parser integration with all statement types. 
-type WithClause struct { - Recursive bool - CTEs []*CommonTableExpr - Pos models.Location // Source position of the WITH keyword (1-based line and column) -} - -func (w *WithClause) statementNode() {} -func (w WithClause) TokenLiteral() string { return "WITH" } -func (w WithClause) Children() []Node { - children := make([]Node, len(w.CTEs)) - for i, cte := range w.CTEs { - children[i] = cte - } - return children -} - -// CommonTableExpr represents a single Common Table Expression in a WITH clause. -// It supports optional column specifications and any statement type as the CTE query. -// Phase 2 Complete: Full parser support with column specifications. -// Phase 2.6: Added MATERIALIZED/NOT MATERIALIZED support for query optimization hints. -type CommonTableExpr struct { - Name string - Columns []string - Statement Statement - ScalarExpr Expression // ClickHouse: WITH AS (scalar CTE, no subquery) - Materialized *bool // nil = default, true = MATERIALIZED, false = NOT MATERIALIZED - Pos models.Location // Source position of the CTE name (1-based line and column) -} - -func (c *CommonTableExpr) statementNode() {} -func (c CommonTableExpr) TokenLiteral() string { return c.Name } -func (c CommonTableExpr) Children() []Node { - var nodes []Node - if c.Statement != nil { - nodes = append(nodes, c.Statement) - } - if c.ScalarExpr != nil { - nodes = append(nodes, c.ScalarExpr) - } - return nodes -} - -// QueryExpression is a Statement that can appear as the source of INSERT ... SELECT. -// Only *SelectStatement and *SetOperation satisfy this interface. -type QueryExpression interface { - Statement - queryExpressionNode() -} - -// SetOperation represents set operations (UNION, EXCEPT, INTERSECT) between two statements. -// It supports the ALL modifier (e.g., UNION ALL) and proper left-associative parsing. -// Phase 2 Complete: Full parser support with left-associative precedence. 
-type SetOperation struct { - Left Statement - Operator string // UNION, EXCEPT, INTERSECT - Right Statement - All bool // UNION ALL vs UNION -} - -func (s *SetOperation) statementNode() {} -func (s *SetOperation) queryExpressionNode() {} -func (s SetOperation) TokenLiteral() string { return s.Operator } -func (s SetOperation) Children() []Node { - var nodes []Node - if s.Left != nil { - nodes = append(nodes, s.Left) - } - if s.Right != nil { - nodes = append(nodes, s.Right) - } - return nodes -} - -// JoinClause represents a JOIN clause in SQL -type JoinClause struct { - Type string // INNER, LEFT, RIGHT, FULL - Left TableReference - Right TableReference - Condition Expression - Pos models.Location // Source position of the JOIN keyword (1-based line and column) -} - -func (j *JoinClause) expressionNode() {} -func (j JoinClause) TokenLiteral() string { return j.Type + " JOIN" } -func (j JoinClause) Children() []Node { - children := []Node{&j.Left, &j.Right} - if j.Condition != nil { - children = append(children, j.Condition) - } - return children -} - -// TableReference represents a table reference in a FROM clause. -// -// TableReference can represent either a simple table name or a derived table -// (subquery). It supports PostgreSQL's LATERAL keyword for correlated subqueries. -// -// Fields: -// - Name: Table name (empty if this is a derived table/subquery) -// - Alias: Optional table alias (AS alias) -// - Subquery: Subquery for derived tables: (SELECT ...) AS alias -// - Lateral: LATERAL keyword for correlated subqueries (PostgreSQL v1.6.0) -// -// The Lateral field enables PostgreSQL's LATERAL JOIN feature, which allows -// subqueries in the FROM clause to reference columns from preceding tables. 
-// -// Example - Simple table reference: -// -// TableReference{ -// Name: "users", -// Alias: "u", -// } -// // SQL: FROM users u -// -// Example - Derived table (subquery): -// -// TableReference{ -// Alias: "recent_orders", -// Subquery: selectStmt, -// } -// // SQL: FROM (SELECT ...) AS recent_orders -// -// Example - LATERAL JOIN (PostgreSQL v1.6.0): -// -// TableReference{ -// Lateral: true, -// Alias: "r", -// Subquery: correlatedSelectStmt, -// } -// // SQL: FROM users u, LATERAL (SELECT * FROM orders WHERE user_id = u.id) r -// -// New in v1.6.0: Lateral field for PostgreSQL LATERAL JOIN support. -type TableReference struct { - Name string // Table name (empty if this is a derived table) - Alias string // Optional alias - Subquery *SelectStatement // For derived tables: (SELECT ...) AS alias - Lateral bool // LATERAL keyword for correlated subqueries (PostgreSQL) - TableHints []string // SQL Server table hints: WITH (NOLOCK), WITH (ROWLOCK, UPDLOCK), etc. - Final bool // ClickHouse FINAL modifier: forces MergeTree part merge - // TableFunc is a function-call table reference such as - // Snowflake LATERAL FLATTEN(input => col), TABLE(my_func(1,2)), - // IDENTIFIER('t'), or PostgreSQL unnest(array_col). When set, Name - // holds the function name and TableFunc carries the call itself. - TableFunc *FunctionCall - // TimeTravel is the Snowflake time-travel clause applied to this table - // reference: AT / BEFORE (TIMESTAMP|OFFSET|STATEMENT => expr) or - // CHANGES (INFORMATION => DEFAULT|APPEND_ONLY). - TimeTravel *TimeTravelClause - // ForSystemTime is the MariaDB temporal table clause (10.3.4+). - // Example: SELECT * FROM t FOR SYSTEM_TIME AS OF '2024-01-01' - ForSystemTime *ForSystemTimeClause // MariaDB temporal query - // Pivot is the SQL Server / Oracle PIVOT clause for row-to-column transformation. 
- // Example: SELECT * FROM t PIVOT (SUM(sales) FOR region IN ([North], [South])) AS pvt - Pivot *PivotClause - // Unpivot is the SQL Server / Oracle UNPIVOT clause for column-to-row transformation. - // Example: SELECT * FROM t UNPIVOT (sales FOR region IN (north_sales, south_sales)) AS unpvt - Unpivot *UnpivotClause - // MatchRecognize is the SQL:2016 row-pattern recognition clause (Snowflake, Oracle). - MatchRecognize *MatchRecognizeClause -} - -func (t *TableReference) statementNode() {} -func (t TableReference) TokenLiteral() string { - if t.Name != "" { - return t.Name - } - if t.Alias != "" { - return t.Alias - } - return "subquery" -} -func (t TableReference) Children() []Node { - var nodes []Node - if t.Subquery != nil { - nodes = append(nodes, t.Subquery) - } - if t.TableFunc != nil { - nodes = append(nodes, t.TableFunc) - } - if t.TimeTravel != nil { - nodes = append(nodes, t.TimeTravel) - } - if t.Pivot != nil { - nodes = append(nodes, t.Pivot) - } - if t.Unpivot != nil { - nodes = append(nodes, t.Unpivot) - } - if t.MatchRecognize != nil { - nodes = append(nodes, t.MatchRecognize) - } - return nodes -} - -// OrderByExpression represents an ORDER BY clause element with direction and NULL ordering -type OrderByExpression struct { - Expression Expression // The expression to order by - Ascending bool // true for ASC (default), false for DESC - NullsFirst *bool // nil = default behavior, true = NULLS FIRST, false = NULLS LAST -} - -func (*OrderByExpression) expressionNode() {} -func (o *OrderByExpression) TokenLiteral() string { return "ORDER BY" } -func (o *OrderByExpression) Children() []Node { - if o.Expression != nil { - return []Node{o.Expression} - } - return nil -} - -// WindowSpec represents a window specification -type WindowSpec struct { - Name string - PartitionBy []Expression - OrderBy []OrderByExpression - FrameClause *WindowFrame -} - -func (w *WindowSpec) statementNode() {} -func (w WindowSpec) TokenLiteral() string { return "WINDOW" } -func 
(w WindowSpec) Children() []Node { - children := make([]Node, 0) - children = append(children, nodifyExpressions(w.PartitionBy)...) - for _, orderBy := range w.OrderBy { - orderBy := orderBy // G601: Create local copy to avoid memory aliasing - children = append(children, &orderBy) - } - if w.FrameClause != nil { - children = append(children, w.FrameClause) - } - return children -} - -// WindowFrame represents window frame clause -type WindowFrame struct { - Type string // ROWS, RANGE - Start WindowFrameBound - End *WindowFrameBound -} - -func (w *WindowFrame) statementNode() {} -func (w WindowFrame) TokenLiteral() string { return w.Type } -func (w WindowFrame) Children() []Node { - // Start is a value type, always include it to support visitor traversal. - children := []Node{&w.Start} - if w.End != nil { - children = append(children, w.End) - } - return children -} - -// WindowFrameBound represents window frame bound -type WindowFrameBound struct { - Type string // CURRENT ROW, UNBOUNDED PRECEDING, etc. - Value Expression -} - -func (w *WindowFrameBound) expressionNode() {} -func (w WindowFrameBound) TokenLiteral() string { - if w.Type != "" { - return w.Type - } - return "BOUND" -} -func (w WindowFrameBound) Children() []Node { - if w.Value != nil { - return []Node{w.Value} - } - return nil -} - -// SelectStatement represents a SELECT SQL statement with full SQL-99/SQL:2003 support. -// -// SelectStatement is the primary query statement type supporting: -// - CTEs (WITH clause) -// - DISTINCT and DISTINCT ON (PostgreSQL) -// - Multiple FROM tables and subqueries -// - All JOIN types with LATERAL support -// - WHERE, GROUP BY, HAVING, ORDER BY clauses -// - Window functions with PARTITION BY and frame specifications -// - LIMIT/OFFSET and SQL-99 FETCH clause -// -// Fields: -// - With: WITH clause for Common Table Expressions (CTEs) -// - Distinct: DISTINCT keyword for duplicate elimination -// - DistinctOnColumns: DISTINCT ON (expr, ...) 
for PostgreSQL (v1.6.0) -// - Columns: SELECT list expressions (columns, *, functions, etc.) -// - From: FROM clause table references (tables, subqueries, LATERAL) -// - TableName: Table name for simple queries (pool optimization) -// - Joins: JOIN clauses (INNER, LEFT, RIGHT, FULL, CROSS, NATURAL) -// - Where: WHERE clause filter condition -// - GroupBy: GROUP BY expressions (including ROLLUP, CUBE, GROUPING SETS) -// - Having: HAVING clause filter condition -// - Windows: Window specifications (WINDOW clause) -// - OrderBy: ORDER BY expressions with NULLS FIRST/LAST -// - Limit: LIMIT clause (number of rows) -// - Offset: OFFSET clause (skip rows) -// - Fetch: SQL-99 FETCH FIRST/NEXT clause (v1.6.0) -// -// Example - Basic SELECT: -// -// SelectStatement{ -// Columns: []Expression{&Identifier{Name: "id"}, &Identifier{Name: "name"}}, -// From: []TableReference{{Name: "users"}}, -// Where: &BinaryExpression{...}, -// } -// // SQL: SELECT id, name FROM users WHERE ... -// -// Example - DISTINCT ON (PostgreSQL v1.6.0): -// -// SelectStatement{ -// DistinctOnColumns: []Expression{&Identifier{Name: "dept_id"}}, -// Columns: []Expression{&Identifier{Name: "dept_id"}, &Identifier{Name: "name"}}, -// From: []TableReference{{Name: "employees"}}, -// } -// // SQL: SELECT DISTINCT ON (dept_id) dept_id, name FROM employees -// -// Example - Window function with FETCH (v1.6.0): -// -// SelectStatement{ -// Columns: []Expression{ -// &FunctionCall{ -// Name: "ROW_NUMBER", -// Over: &WindowSpec{ -// OrderBy: []OrderByExpression{{Expression: &Identifier{Name: "salary"}, Ascending: false}}, -// }, -// }, -// }, -// From: []TableReference{{Name: "employees"}}, -// Fetch: &FetchClause{FetchValue: ptrInt64(10), FetchType: "FIRST"}, -// } -// // SQL: SELECT ROW_NUMBER() OVER (ORDER BY salary DESC) FROM employees FETCH FIRST 10 ROWS ONLY -// -// New in v1.6.0: -// - DistinctOnColumns for PostgreSQL DISTINCT ON -// - Fetch for SQL-99 FETCH FIRST/NEXT clause -// - Enhanced LATERAL JOIN 
support via TableReference.Lateral -// - FILTER clause support via FunctionCall.Filter -type SelectStatement struct { - With *WithClause - Distinct bool - DistinctOnColumns []Expression // PostgreSQL DISTINCT ON (expr, ...) clause - Top *TopClause // SQL Server TOP N [PERCENT] clause - Columns []Expression - From []TableReference - TableName string // Added for pool operations - Joins []JoinClause - ArrayJoin *ArrayJoinClause // ClickHouse ARRAY JOIN / LEFT ARRAY JOIN clause - PrewhereClause Expression // ClickHouse PREWHERE clause (applied before WHERE, before reading data) - Sample *SampleClause // ClickHouse SAMPLE clause (comes after FROM/FINAL, before PREWHERE) - Where Expression - GroupBy []Expression - Having Expression - Qualify Expression // Snowflake / BigQuery QUALIFY clause (filters after window functions) - // StartWith is the optional seed condition for CONNECT BY (MariaDB 10.2+). - // Example: START WITH parent_id IS NULL - StartWith Expression // MariaDB hierarchical query seed - // ConnectBy holds the hierarchy traversal condition (MariaDB 10.2+). - // Example: CONNECT BY PRIOR id = parent_id - ConnectBy *ConnectByClause // MariaDB hierarchical query - Windows []WindowSpec - OrderBy []OrderByExpression - Limit *int - Offset *int - Fetch *FetchClause // SQL-99 FETCH FIRST/NEXT clause (F861, F862) - For *ForClause // Row-level locking clause (SQL:2003, PostgreSQL, MySQL) - Pos models.Location // Source position of the SELECT keyword (1-based line and column) -} - -// TopClause represents SQL Server's TOP N [PERCENT] clause -// Syntax: SELECT TOP n [PERCENT] columns... 
-// Count is an Expression to support TOP (10), TOP (@var), TOP (subquery) -type TopClause struct { - Count Expression // Number of rows (or percentage) as an expression - IsPercent bool // Whether PERCENT keyword was specified - WithTies bool // Whether WITH TIES was specified (SQL Server) -} - -func (t *TopClause) expressionNode() {} -func (t TopClause) TokenLiteral() string { return "TOP" } -func (t TopClause) Children() []Node { - if t.Count != nil { - return []Node{t.Count} - } - return nil -} - -// FetchClause represents the SQL-99 FETCH FIRST/NEXT clause (F861, F862) -// Syntax: [OFFSET n {ROW | ROWS}] FETCH {FIRST | NEXT} n [{ROW | ROWS}] {ONLY | WITH TIES} -// Examples: -// - OFFSET 20 ROWS FETCH NEXT 10 ROWS ONLY -// - FETCH FIRST 5 ROWS ONLY -// - FETCH FIRST 10 PERCENT ROWS WITH TIES -type FetchClause struct { - // OffsetValue is the number of rows to skip (OFFSET n ROWS) - OffsetValue *int64 - // FetchValue is the number of rows to fetch (FETCH n ROWS) - FetchValue *int64 - // FetchType is either "FIRST" or "NEXT" - FetchType string - // IsPercent indicates FETCH ... PERCENT ROWS - IsPercent bool - // WithTies indicates FETCH ... 
WITH TIES (includes tied rows) - WithTies bool -} - -func (f *FetchClause) expressionNode() {} -func (f FetchClause) TokenLiteral() string { return "FETCH" } -func (f FetchClause) Children() []Node { return nil } - -// ForClause represents row-level locking clauses in SELECT statements (SQL:2003, PostgreSQL, MySQL) -// Syntax: FOR {UPDATE | SHARE | NO KEY UPDATE | KEY SHARE} [OF table_name [, ...]] [NOWAIT | SKIP LOCKED] -// Examples: -// - FOR UPDATE -// - FOR SHARE NOWAIT -// - FOR UPDATE OF orders SKIP LOCKED -// - FOR NO KEY UPDATE -// - FOR KEY SHARE -type ForClause struct { - // LockType specifies the type of lock: - // "UPDATE" - exclusive lock for UPDATE operations - // "SHARE" - shared lock for read operations - // "NO KEY UPDATE" - PostgreSQL: exclusive lock that doesn't block SHARE locks on same row - // "KEY SHARE" - PostgreSQL: shared lock that doesn't block UPDATE locks - LockType string - // Tables specifies which tables to lock (FOR UPDATE OF table_name) - // Empty slice means lock all tables in the query - Tables []string - // NoWait indicates NOWAIT option (fail immediately if lock cannot be acquired) - NoWait bool - // SkipLocked indicates SKIP LOCKED option (skip rows that can't be locked) - SkipLocked bool -} - -func (f *ForClause) expressionNode() {} -func (f ForClause) TokenLiteral() string { return "FOR" } -func (f ForClause) Children() []Node { return nil } - -func (s *SelectStatement) statementNode() {} -func (s *SelectStatement) queryExpressionNode() {} -func (s SelectStatement) TokenLiteral() string { return "SELECT" } - -func (s SelectStatement) Children() []Node { - children := make([]Node, 0) - if s.With != nil { - children = append(children, s.With) - } - children = append(children, nodifyExpressions(s.DistinctOnColumns)...) - children = append(children, nodifyExpressions(s.Columns)...) 
- for _, from := range s.From { - from := from // G601: Create local copy to avoid memory aliasing - children = append(children, &from) - } - for _, join := range s.Joins { - join := join // G601: Create local copy to avoid memory aliasing - children = append(children, &join) - } - if s.Sample != nil { - children = append(children, s.Sample) - } - if s.PrewhereClause != nil { - children = append(children, s.PrewhereClause) - } - if s.Where != nil { - children = append(children, s.Where) - } - children = append(children, nodifyExpressions(s.GroupBy)...) - if s.Having != nil { - children = append(children, s.Having) - } - if s.Qualify != nil { - children = append(children, s.Qualify) - } - for _, window := range s.Windows { - window := window // G601: Create local copy to avoid memory aliasing - children = append(children, &window) - } - for _, orderBy := range s.OrderBy { - orderBy := orderBy // G601: Create local copy to avoid memory aliasing - children = append(children, &orderBy) - } - if s.Fetch != nil { - children = append(children, s.Fetch) - } - if s.For != nil { - children = append(children, s.For) - } - if s.StartWith != nil { - children = append(children, s.StartWith) - } - if s.ConnectBy != nil { - children = append(children, s.ConnectBy) - } - return children -} - -// Helper function to convert []Expression to []Node -func nodifyExpressions(exprs []Expression) []Node { - nodes := make([]Node, len(exprs)) - for i, expr := range exprs { - nodes[i] = expr - } - return nodes -} - -// RollupExpression represents ROLLUP(col1, col2, ...) 
in GROUP BY clause -// ROLLUP generates hierarchical grouping sets from right to left -// Example: ROLLUP(a, b, c) generates grouping sets: -// -// (a, b, c), (a, b), (a), () -type RollupExpression struct { - Expressions []Expression -} - -func (r *RollupExpression) expressionNode() {} -func (r RollupExpression) TokenLiteral() string { return "ROLLUP" } -func (r RollupExpression) Children() []Node { return nodifyExpressions(r.Expressions) } - -// CubeExpression represents CUBE(col1, col2, ...) in GROUP BY clause -// CUBE generates all possible combinations of grouping sets -// Example: CUBE(a, b) generates grouping sets: -// -// (a, b), (a), (b), () -type CubeExpression struct { - Expressions []Expression -} - -func (c *CubeExpression) expressionNode() {} -func (c CubeExpression) TokenLiteral() string { return "CUBE" } -func (c CubeExpression) Children() []Node { return nodifyExpressions(c.Expressions) } - -// GroupingSetsExpression represents GROUPING SETS(...) in GROUP BY clause -// Allows explicit specification of grouping sets -// Example: GROUPING SETS((a, b), (a), ()) -type GroupingSetsExpression struct { - Sets [][]Expression // Each inner slice is one grouping set -} - -func (g *GroupingSetsExpression) expressionNode() {} -func (g GroupingSetsExpression) TokenLiteral() string { return "GROUPING SETS" } -func (g GroupingSetsExpression) Children() []Node { - children := make([]Node, 0) - for _, set := range g.Sets { - children = append(children, nodifyExpressions(set)...) 
- } - return children -} - -// Identifier represents a column or table name -type Identifier struct { - Name string - Table string // Optional table qualifier - Pos models.Location // Source position of this identifier (1-based line and column) -} - -func (i *Identifier) expressionNode() {} -func (i Identifier) TokenLiteral() string { return i.Name } -func (i Identifier) Children() []Node { return nil } - -// FunctionCall represents a function call expression with full SQL-99/PostgreSQL support. -// -// FunctionCall supports: -// - Scalar functions: UPPER(), LOWER(), COALESCE(), etc. -// - Aggregate functions: COUNT(), SUM(), AVG(), MAX(), MIN(), etc. -// - Window functions: ROW_NUMBER(), RANK(), DENSE_RANK(), LAG(), LEAD(), etc. -// - DISTINCT modifier: COUNT(DISTINCT column) -// - FILTER clause: Conditional aggregation (PostgreSQL v1.6.0) -// - ORDER BY clause: For order-sensitive aggregates like STRING_AGG, ARRAY_AGG (v1.6.0) -// - OVER clause: Window specifications for window functions -// -// Fields: -// - Name: Function name (e.g., "COUNT", "SUM", "ROW_NUMBER") -// - Arguments: Function arguments (expressions) -// - Over: Window specification for window functions (OVER clause) -// - Distinct: DISTINCT modifier for aggregates (COUNT(DISTINCT col)) -// - Filter: FILTER clause for conditional aggregation (PostgreSQL v1.6.0) -// - OrderBy: ORDER BY clause for order-sensitive aggregates (v1.6.0) -// -// Example - Basic aggregate: -// -// FunctionCall{ -// Name: "COUNT", -// Arguments: []Expression{&Identifier{Name: "id"}}, -// } -// // SQL: COUNT(id) -// -// Example - Window function: -// -// FunctionCall{ -// Name: "ROW_NUMBER", -// Over: &WindowSpec{ -// PartitionBy: []Expression{&Identifier{Name: "dept_id"}}, -// OrderBy: []OrderByExpression{{Expression: &Identifier{Name: "salary"}, Ascending: false}}, -// }, -// } -// // SQL: ROW_NUMBER() OVER (PARTITION BY dept_id ORDER BY salary DESC) -// -// Example - FILTER clause (PostgreSQL v1.6.0): -// -// FunctionCall{ 
-// Name: "COUNT", -// Arguments: []Expression{&Identifier{Name: "id"}}, -// Filter: &BinaryExpression{Left: &Identifier{Name: "status"}, Operator: "=", Right: &LiteralValue{Value: "active"}}, -// } -// // SQL: COUNT(id) FILTER (WHERE status = 'active') -// -// Example - ORDER BY in aggregate (PostgreSQL v1.6.0): -// -// FunctionCall{ -// Name: "STRING_AGG", -// Arguments: []Expression{&Identifier{Name: "name"}, &LiteralValue{Value: ", "}}, -// OrderBy: []OrderByExpression{{Expression: &Identifier{Name: "name"}, Ascending: true}}, -// } -// // SQL: STRING_AGG(name, ', ' ORDER BY name) -// -// Example - Window function with frame: -// -// FunctionCall{ -// Name: "AVG", -// Arguments: []Expression{&Identifier{Name: "amount"}}, -// Over: &WindowSpec{ -// OrderBy: []OrderByExpression{{Expression: &Identifier{Name: "date"}, Ascending: true}}, -// FrameClause: &WindowFrame{ -// Type: "ROWS", -// Start: WindowFrameBound{Type: "2 PRECEDING"}, -// End: &WindowFrameBound{Type: "CURRENT ROW"}, -// }, -// }, -// } -// // SQL: AVG(amount) OVER (ORDER BY date ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) -// -// New in v1.6.0: -// - Filter: FILTER clause for conditional aggregation -// - OrderBy: ORDER BY clause for order-sensitive aggregates (STRING_AGG, ARRAY_AGG, etc.) -// - WithinGroup: ORDER BY clause for ordered-set aggregates (PERCENTILE_CONT, PERCENTILE_DISC, MODE, etc.) -type FunctionCall struct { - Name string - Arguments []Expression // Renamed from Args for consistency - Parameters []Expression // ClickHouse parametric aggregates: quantile(0.5)(x) — params before args - Over *WindowSpec // For window functions - Distinct bool - Filter Expression // WHERE clause for aggregate functions - OrderBy []OrderByExpression // ORDER BY clause for aggregate functions (STRING_AGG, ARRAY_AGG, etc.) - WithinGroup []OrderByExpression // ORDER BY clause for ordered-set aggregates (PERCENTILE_CONT, etc.) 
- NullTreatment string // "IGNORE NULLS" or "RESPECT NULLS" on window functions (Snowflake, Oracle, BigQuery, SQL:2016) - Pos models.Location // Source position of the function name (1-based line and column) -} - -func (f *FunctionCall) expressionNode() {} -func (f FunctionCall) TokenLiteral() string { return f.Name } -func (f FunctionCall) Children() []Node { - children := nodifyExpressions(f.Arguments) - if len(f.Parameters) > 0 { - children = append(children, nodifyExpressions(f.Parameters)...) - } - if f.Over != nil { - children = append(children, f.Over) - } - if f.Filter != nil { - children = append(children, f.Filter) - } - for _, orderBy := range f.OrderBy { - orderBy := orderBy // G601: Create local copy to avoid memory aliasing - children = append(children, &orderBy) - } - for _, orderBy := range f.WithinGroup { - orderBy := orderBy // G601: Create local copy to avoid memory aliasing - children = append(children, &orderBy) - } - return children -} - -// CaseExpression represents a CASE expression -type CaseExpression struct { - Value Expression // Optional CASE value - WhenClauses []WhenClause - ElseClause Expression - Pos models.Location // Source position of the CASE keyword (1-based line and column) -} - -func (c *CaseExpression) expressionNode() {} -func (c CaseExpression) TokenLiteral() string { return "CASE" } -func (c CaseExpression) Children() []Node { - children := make([]Node, 0) - if c.Value != nil { - children = append(children, c.Value) - } - for _, when := range c.WhenClauses { - when := when // G601: Create local copy to avoid memory aliasing - children = append(children, &when) - } - if c.ElseClause != nil { - children = append(children, c.ElseClause) - } - return children -} - -// WhenClause represents WHEN ... THEN ... 
in CASE expression -type WhenClause struct { - Condition Expression - Result Expression - Pos models.Location // Source position of the WHEN keyword (1-based line and column) -} - -func (w *WhenClause) expressionNode() {} -func (w WhenClause) TokenLiteral() string { return "WHEN" } -func (w WhenClause) Children() []Node { - var nodes []Node - if w.Condition != nil { - nodes = append(nodes, w.Condition) - } - if w.Result != nil { - nodes = append(nodes, w.Result) - } - return nodes -} - -// ExistsExpression represents EXISTS (subquery) -type ExistsExpression struct { - Subquery Statement -} - -func (e *ExistsExpression) expressionNode() {} -func (e ExistsExpression) TokenLiteral() string { return "EXISTS" } -func (e ExistsExpression) Children() []Node { - if e.Subquery == nil { - return nil - } - return []Node{e.Subquery} -} - -// InExpression represents expr IN (values) or expr IN (subquery) -type InExpression struct { - Expr Expression - List []Expression // For value list: IN (1, 2, 3) - Subquery Statement // For subquery: IN (SELECT ...) - Not bool - Pos models.Location // Source position of the IN keyword (1-based line and column) -} - -func (i *InExpression) expressionNode() {} -func (i InExpression) TokenLiteral() string { return "IN" } -func (i InExpression) Children() []Node { - var children []Node - if i.Expr != nil { - children = append(children, i.Expr) - } - if i.Subquery != nil { - children = append(children, i.Subquery) - } - children = append(children, nodifyExpressions(i.List)...) - return children -} - -// SubqueryExpression represents a scalar subquery (SELECT ...) 
-type SubqueryExpression struct { - Subquery Statement - Pos models.Location // Source position of the opening parenthesis (1-based line and column) -} - -func (s *SubqueryExpression) expressionNode() {} -func (s SubqueryExpression) TokenLiteral() string { return "SUBQUERY" } -func (s SubqueryExpression) Children() []Node { - if s.Subquery == nil { - return nil - } - return []Node{s.Subquery} -} - -// AnyExpression represents expr op ANY (subquery) -type AnyExpression struct { - Expr Expression - Operator string - Subquery Statement -} - -func (a *AnyExpression) expressionNode() {} -func (a AnyExpression) TokenLiteral() string { return "ANY" } -func (a AnyExpression) Children() []Node { - var nodes []Node - if a.Expr != nil { - nodes = append(nodes, a.Expr) - } - if a.Subquery != nil { - nodes = append(nodes, a.Subquery) - } - return nodes -} - -// AllExpression represents expr op ALL (subquery) -type AllExpression struct { - Expr Expression - Operator string - Subquery Statement -} - -func (al *AllExpression) expressionNode() {} -func (al AllExpression) TokenLiteral() string { return "ALL" } -func (al AllExpression) Children() []Node { - var nodes []Node - if al.Expr != nil { - nodes = append(nodes, al.Expr) - } - if al.Subquery != nil { - nodes = append(nodes, al.Subquery) - } - return nodes -} - -// BetweenExpression represents expr BETWEEN lower AND upper -type BetweenExpression struct { - Expr Expression - Lower Expression - Upper Expression - Not bool - Pos models.Location // Source position of the BETWEEN keyword (1-based line and column) -} - -func (b *BetweenExpression) expressionNode() {} -func (b BetweenExpression) TokenLiteral() string { return "BETWEEN" } -func (b BetweenExpression) Children() []Node { - var nodes []Node - if b.Expr != nil { - nodes = append(nodes, b.Expr) - } - if b.Lower != nil { - nodes = append(nodes, b.Lower) - } - if b.Upper != nil { - nodes = append(nodes, b.Upper) - } - return nodes -} - -// BinaryExpression represents binary 
operations between two expressions. -// -// BinaryExpression supports all standard SQL binary operators plus PostgreSQL-specific -// operators including JSON/JSONB operators added in v1.6.0. -// -// Fields: -// - Left: Left-hand side expression -// - Operator: Binary operator (=, <, >, +, -, *, /, AND, OR, ->, #>, etc.) -// - Right: Right-hand side expression -// - Not: NOT modifier for negation (NOT expr) -// - CustomOp: PostgreSQL custom operators (OPERATOR(schema.name)) -// -// Supported Operator Categories: -// - Comparison: =, <>, <, >, <=, >=, <=> (spaceship) -// - Arithmetic: +, -, *, /, %, DIV, // (integer division) -// - Logical: AND, OR, XOR -// - String: || (concatenation) -// - Bitwise: &, |, ^, <<, >> (shifts) -// - Pattern: LIKE, ILIKE, SIMILAR TO -// - Range: OVERLAPS -// - PostgreSQL JSON/JSONB (v1.6.0): ->, ->>, #>, #>>, @>, <@, ?, ?|, ?&, #- -// -// Example - Basic comparison: -// -// BinaryExpression{ -// Left: &Identifier{Name: "age"}, -// Operator: ">", -// Right: &LiteralValue{Value: 18, Type: "INTEGER"}, -// } -// // SQL: age > 18 -// -// Example - Logical AND: -// -// BinaryExpression{ -// Left: &BinaryExpression{ -// Left: &Identifier{Name: "active"}, -// Operator: "=", -// Right: &LiteralValue{Value: true, Type: "BOOLEAN"}, -// }, -// Operator: "AND", -// Right: &BinaryExpression{ -// Left: &Identifier{Name: "status"}, -// Operator: "=", -// Right: &LiteralValue{Value: "pending", Type: "STRING"}, -// }, -// } -// // SQL: active = true AND status = 'pending' -// -// Example - PostgreSQL JSON operator -> (v1.6.0): -// -// BinaryExpression{ -// Left: &Identifier{Name: "data"}, -// Operator: "->", -// Right: &LiteralValue{Value: "name", Type: "STRING"}, -// } -// // SQL: data->'name' -// -// Example - PostgreSQL JSON operator ->> (v1.6.0): -// -// BinaryExpression{ -// Left: &Identifier{Name: "data"}, -// Operator: "->>", -// Right: &LiteralValue{Value: "email", Type: "STRING"}, -// } -// // SQL: data->>'email' (returns text) -// -// Example - 
PostgreSQL JSON contains @> (v1.6.0): -// -// BinaryExpression{ -// Left: &Identifier{Name: "attributes"}, -// Operator: "@>", -// Right: &LiteralValue{Value: `{"color": "red"}`, Type: "STRING"}, -// } -// // SQL: attributes @> '{"color": "red"}' -// -// Example - PostgreSQL JSON key exists ? (v1.6.0): -// -// BinaryExpression{ -// Left: &Identifier{Name: "profile"}, -// Operator: "?", -// Right: &LiteralValue{Value: "email", Type: "STRING"}, -// } -// // SQL: profile ? 'email' -// -// Example - Custom PostgreSQL operator: -// -// BinaryExpression{ -// Left: &Identifier{Name: "point1"}, -// Operator: "", -// Right: &Identifier{Name: "point2"}, -// CustomOp: &CustomBinaryOperator{Parts: []string{"pg_catalog", "<->"}}, -// } -// // SQL: point1 OPERATOR(pg_catalog.<->) point2 -// -// New in v1.6.0: -// - JSON/JSONB operators: ->, ->>, #>, #>>, @>, <@, ?, ?|, ?&, #- -// - CustomOp field for PostgreSQL custom operators -// -// PostgreSQL JSON/JSONB Operator Reference: -// - -> (Arrow): Extract JSON field or array element (returns JSON) -// - ->> (LongArrow): Extract JSON field or array element as text -// - #> (HashArrow): Extract JSON at path (returns JSON) -// - #>> (HashLongArrow): Extract JSON at path as text -// - @> (AtArrow): JSON contains (does left JSON contain right?) -// - <@ (ArrowAt): JSON is contained by (is left JSON contained in right?) -// - ? 
(Question): JSON key exists -// - ?| (QuestionPipe): Any of the keys exist -// - ?& (QuestionAnd): All of the keys exist -// - #- (HashMinus): Delete key from JSON -type BinaryExpression struct { - Left Expression - Operator string - Right Expression - Not bool // For NOT (expr) - CustomOp *CustomBinaryOperator // For PostgreSQL custom operators - Pos models.Location // Source position of the operator (1-based line and column) -} - -func (b *BinaryExpression) expressionNode() {} - -func (b *BinaryExpression) TokenLiteral() string { - if b.CustomOp != nil { - return b.CustomOp.String() - } - return b.Operator -} - -func (b BinaryExpression) Children() []Node { - var nodes []Node - if b.Left != nil { - nodes = append(nodes, b.Left) - } - if b.Right != nil { - nodes = append(nodes, b.Right) - } - return nodes -} - -// LiteralValue represents a literal value in SQL -type LiteralValue struct { - Value interface{} - Type string // INTEGER, FLOAT, STRING, BOOLEAN, NULL, etc. -} - -func (l *LiteralValue) expressionNode() {} -func (l LiteralValue) TokenLiteral() string { return fmt.Sprintf("%v", l.Value) } -func (l LiteralValue) Children() []Node { return nil } - -// ListExpression represents a list of expressions (1, 2, 3) -type ListExpression struct { - Values []Expression -} - -func (l *ListExpression) expressionNode() {} -func (l ListExpression) TokenLiteral() string { return "LIST" } -func (l ListExpression) Children() []Node { return nodifyExpressions(l.Values) } - -// TupleExpression represents a row constructor / tuple (col1, col2) for multi-column comparisons -// Used in: WHERE (user_id, status) IN ((1, 'active'), (2, 'pending')) -type TupleExpression struct { - Expressions []Expression -} - -func (t *TupleExpression) expressionNode() {} -func (t TupleExpression) TokenLiteral() string { return "TUPLE" } -func (t TupleExpression) Children() []Node { return nodifyExpressions(t.Expressions) } - -// ArrayConstructorExpression represents PostgreSQL ARRAY constructor 
syntax. -// Creates an array value from a list of expressions or a subquery. -// -// Examples: -// -// ARRAY[1, 2, 3] - Integer array literal -// ARRAY['admin', 'moderator'] - Text array literal -// ARRAY(SELECT id FROM users) - Array from subquery -type ArrayConstructorExpression struct { - Elements []Expression // Elements inside ARRAY[...] - Subquery *SelectStatement // For ARRAY(SELECT ...) syntax (optional) -} - -func (a *ArrayConstructorExpression) expressionNode() {} -func (a ArrayConstructorExpression) TokenLiteral() string { return "ARRAY" } -func (a ArrayConstructorExpression) Children() []Node { - if a.Subquery != nil { - return []Node{a.Subquery} - } - return nodifyExpressions(a.Elements) -} - -// UnaryExpression represents operations like NOT expr -type UnaryExpression struct { - Operator UnaryOperator - Expr Expression - Pos models.Location // Source position of the operator (1-based line and column) -} - -func (u *UnaryExpression) expressionNode() {} - -func (u *UnaryExpression) TokenLiteral() string { - return u.Operator.String() -} - -func (u UnaryExpression) Children() []Node { - if u.Expr == nil { - return nil - } - return []Node{u.Expr} -} - -// VariantPath represents a Snowflake VARIANT path expression: -// -// col:field.sub[0]::string -// -// The Root is the base expression (typically an Identifier or FunctionCall -// like PARSE_JSON(raw)). Segments is the chain of path steps that follow -// the leading `:`. Each segment is either a field name (Name set) or a -// bracketed index expression (Index set). -type VariantPath struct { - Root Expression - Segments []VariantPathSegment - Pos models.Location -} - -// VariantPathSegment is one step in a VARIANT path: either a field name -// reached via `:` or `.`, or a bracketed index expression. 
-type VariantPathSegment struct { - Name string // field name (`:field` or `.field`), empty when Index is set - Index Expression // bracket subscript (`[expr]`), nil when Name is set -} - -func (v *VariantPath) expressionNode() {} -func (v VariantPath) TokenLiteral() string { return ":" } -func (v VariantPath) Children() []Node { - var nodes []Node - if v.Root != nil { - nodes = append(nodes, v.Root) - } - for _, seg := range v.Segments { - if seg.Index != nil { - nodes = append(nodes, seg.Index) - } - } - return nodes -} - -// NamedArgument represents a function argument of the form `name => expr`, -// used by Snowflake (FLATTEN(input => col), GENERATOR(rowcount => 100)), -// BigQuery, Oracle, and PostgreSQL procedural calls. -type NamedArgument struct { - Name string - Value Expression - Pos models.Location -} - -func (n *NamedArgument) expressionNode() {} -func (n NamedArgument) TokenLiteral() string { return n.Name } -func (n NamedArgument) Children() []Node { - if n.Value == nil { - return nil - } - return []Node{n.Value} -} - -// CastExpression represents CAST(expr AS type) or TRY_CAST(expr AS type). -// Try is set when the expression originated from TRY_CAST (Snowflake / SQL -// Server / BigQuery), which returns NULL on conversion failure instead of -// raising an error. 
-type CastExpression struct { - Expr Expression - Type string - Try bool -} - -func (c *CastExpression) expressionNode() {} -func (c CastExpression) TokenLiteral() string { - if c.Try { - return "TRY_CAST" - } - return "CAST" -} -func (c CastExpression) Children() []Node { - if c.Expr == nil { - return nil - } - return []Node{c.Expr} -} - -// AliasedExpression represents an expression with an alias (expr AS alias) -type AliasedExpression struct { - Expr Expression - Alias string -} - -func (a *AliasedExpression) expressionNode() {} -func (a AliasedExpression) TokenLiteral() string { return a.Alias } -func (a AliasedExpression) Children() []Node { - if a.Expr == nil { - return nil - } - return []Node{a.Expr} -} - -// ExtractExpression represents EXTRACT(field FROM source) -type ExtractExpression struct { - Field string - Source Expression -} - -func (e *ExtractExpression) expressionNode() {} -func (e ExtractExpression) TokenLiteral() string { return "EXTRACT" } -func (e ExtractExpression) Children() []Node { - if e.Source == nil { - return nil - } - return []Node{e.Source} -} - -// PositionExpression represents POSITION(substr IN str) -type PositionExpression struct { - Substr Expression - Str Expression -} - -func (p *PositionExpression) expressionNode() {} -func (p PositionExpression) TokenLiteral() string { return "POSITION" } -func (p PositionExpression) Children() []Node { - var nodes []Node - if p.Substr != nil { - nodes = append(nodes, p.Substr) - } - if p.Str != nil { - nodes = append(nodes, p.Str) - } - return nodes -} - -// SubstringExpression represents SUBSTRING(str FROM start [FOR length]) -type SubstringExpression struct { - Str Expression - Start Expression - Length Expression -} - -func (s *SubstringExpression) expressionNode() {} -func (s SubstringExpression) TokenLiteral() string { return "SUBSTRING" } -func (s SubstringExpression) Children() []Node { - children := []Node{s.Str, s.Start} - if s.Length != nil { - children = append(children, 
s.Length) - } - return children -} - -// IntervalExpression represents INTERVAL 'value' for date/time arithmetic -// Examples: INTERVAL '1 day', INTERVAL '2 hours', INTERVAL '1 year 2 months' -type IntervalExpression struct { - Value string // The interval specification string (e.g., '1 day', '2 hours') -} - -func (i *IntervalExpression) expressionNode() {} -func (i IntervalExpression) TokenLiteral() string { return "INTERVAL" } - -// Children implements Node. IntervalExpression stores its value as a raw -// string (not an Expression), so it has no child nodes. Returns nil for -// consistency with other leaf nodes. -func (i IntervalExpression) Children() []Node { return nil } - -// ArraySubscriptExpression represents array element access syntax. -// Supports single and multi-dimensional array subscripting. -// -// Examples: -// -// tags[1] - Single subscript -// matrix[2][3] - Multi-dimensional subscript -// arr[i] - Subscript with variable -// (SELECT arr)[1] - Subscript on subquery result -type ArraySubscriptExpression struct { - Array Expression // The array expression being subscripted - Indices []Expression // Subscript indices (one or more for multi-dimensional arrays) -} - -func (a *ArraySubscriptExpression) expressionNode() {} -func (a ArraySubscriptExpression) TokenLiteral() string { return "[]" } -func (a ArraySubscriptExpression) Children() []Node { - var children []Node - if a.Array != nil { - children = append(children, a.Array) - } - for _, idx := range a.Indices { - if idx != nil { - children = append(children, idx) - } - } - return children -} - -// ArraySliceExpression represents array slicing syntax for extracting subarrays. -// Supports PostgreSQL-style array slicing with optional start/end bounds. 
-// -// Examples: -// -// arr[1:3] - Slice from index 1 to 3 (inclusive) -// arr[2:] - Slice from index 2 to end -// arr[:5] - Slice from start to index 5 -// arr[:] - Full array slice (copy) -type ArraySliceExpression struct { - Array Expression // The array expression being sliced - Start Expression // Start index (nil means from beginning) - End Expression // End index (nil means to end) -} - -func (a *ArraySliceExpression) expressionNode() {} -func (a ArraySliceExpression) TokenLiteral() string { return "[:]" } -func (a ArraySliceExpression) Children() []Node { - var children []Node - if a.Array != nil { - children = append(children, a.Array) - } - if a.Start != nil { - children = append(children, a.Start) - } - if a.End != nil { - children = append(children, a.End) - } - return children -} - -// InsertStatement represents an INSERT SQL statement -type InsertStatement struct { - With *WithClause - TableName string - Columns []Expression - Output []Expression // SQL Server OUTPUT clause columns - Values [][]Expression // Multi-row support: each inner slice is one row of values - Query QueryExpression // For INSERT ... SELECT (SelectStatement or SetOperation) - Returning []Expression - OnConflict *OnConflict - OnDuplicateKey *UpsertClause // MySQL: ON DUPLICATE KEY UPDATE - Pos models.Location // Source position of the INSERT keyword (1-based line and column) -} - -func (i *InsertStatement) statementNode() {} -func (i InsertStatement) TokenLiteral() string { return "INSERT" } - -func (i InsertStatement) Children() []Node { - children := make([]Node, 0) - if i.With != nil { - children = append(children, i.With) - } - children = append(children, nodifyExpressions(i.Columns)...) - children = append(children, nodifyExpressions(i.Output)...) - // Flatten multi-row values for Children() - for _, row := range i.Values { - children = append(children, nodifyExpressions(row)...) 
- } - if i.Query != nil { - children = append(children, i.Query) - } - children = append(children, nodifyExpressions(i.Returning)...) - if i.OnConflict != nil { - children = append(children, i.OnConflict) - } - if i.OnDuplicateKey != nil { - children = append(children, i.OnDuplicateKey) - } - return children -} - -// OnConflict represents ON CONFLICT DO UPDATE/NOTHING clause -type OnConflict struct { - Target []Expression // Target columns - Constraint string // Optional constraint name - Action OnConflictAction -} - -func (o *OnConflict) expressionNode() {} -func (o OnConflict) TokenLiteral() string { return "ON CONFLICT" } -func (o OnConflict) Children() []Node { - children := nodifyExpressions(o.Target) - if o.Action.DoUpdate != nil { - for _, update := range o.Action.DoUpdate { - update := update // G601: Create local copy to avoid memory aliasing - children = append(children, &update) - } - } - if o.Action.Where != nil { - children = append(children, o.Action.Where) - } - return children -} - -// OnConflictAction represents DO UPDATE/NOTHING in ON CONFLICT clause -type OnConflictAction struct { - DoNothing bool - DoUpdate []UpdateExpression - Where Expression -} - -// UpsertClause represents INSERT ... 
ON DUPLICATE KEY UPDATE -type UpsertClause struct { - Updates []UpdateExpression -} - -func (u *UpsertClause) expressionNode() {} -func (u UpsertClause) TokenLiteral() string { return "ON DUPLICATE KEY UPDATE" } -func (u UpsertClause) Children() []Node { - children := make([]Node, len(u.Updates)) - for i, update := range u.Updates { - update := update // G601: Create local copy to avoid memory aliasing - children[i] = &update - } - return children -} - -// Values represents VALUES clause -type Values struct { - Rows [][]Expression -} - -func (v *Values) statementNode() {} -func (v Values) TokenLiteral() string { return "VALUES" } -func (v Values) Children() []Node { - children := make([]Node, 0) - for _, row := range v.Rows { - children = append(children, nodifyExpressions(row)...) - } - return children -} - -// UpdateStatement represents an UPDATE SQL statement -type UpdateStatement struct { - With *WithClause - TableName string - Alias string - Assignments []UpdateExpression // SET clause assignments - From []TableReference - Where Expression - Returning []Expression - Pos models.Location // Source position of the UPDATE keyword (1-based line and column) -} - -// GetUpdates returns Assignments for backward compatibility. -// -// Deprecated: Use Assignments directly instead. 
-func (u *UpdateStatement) GetUpdates() []UpdateExpression { - return u.Assignments -} - -func (u *UpdateStatement) statementNode() {} -func (u UpdateStatement) TokenLiteral() string { return "UPDATE" } - -func (u UpdateStatement) Children() []Node { - children := make([]Node, 0) - if u.With != nil { - children = append(children, u.With) - } - for _, assignment := range u.Assignments { - assignment := assignment // G601: Create local copy to avoid memory aliasing - children = append(children, &assignment) - } - for _, from := range u.From { - from := from // G601: Create local copy to avoid memory aliasing - children = append(children, &from) - } - if u.Where != nil { - children = append(children, u.Where) - } - children = append(children, nodifyExpressions(u.Returning)...) - return children -} - -// CreateTableStatement represents a CREATE TABLE statement -type CreateTableStatement struct { - IfNotExists bool - Temporary bool - Name string - Columns []ColumnDef - Constraints []TableConstraint - Inherits []string - PartitionBy *PartitionBy - Partitions []PartitionDefinition // Individual partition definitions - Options []TableOption - WithoutRowID bool // SQLite: CREATE TABLE ... WITHOUT ROWID - - // WithSystemVersioning enables system-versioned temporal history (MariaDB 10.3.4+). - // Example: CREATE TABLE t (...) WITH SYSTEM VERSIONING - WithSystemVersioning bool - - // PeriodDefinitions holds PERIOD FOR clauses for application-time or system-time periods. 
- // Example: PERIOD FOR app_time (start_col, end_col) - PeriodDefinitions []*PeriodDefinition -} - -func (c *CreateTableStatement) statementNode() {} -func (c CreateTableStatement) TokenLiteral() string { return "CREATE TABLE" } -func (c CreateTableStatement) Children() []Node { - children := make([]Node, 0) - for _, col := range c.Columns { - col := col // G601: Create local copy to avoid memory aliasing - children = append(children, &col) - } - for _, constraint := range c.Constraints { - constraint := constraint // G601: Create local copy to avoid memory aliasing - children = append(children, &constraint) - } - if c.PartitionBy != nil { - children = append(children, c.PartitionBy) - } - for _, p := range c.Partitions { - p := p // G601: Create local copy - children = append(children, &p) - } - return children -} - -// ColumnDef represents a column definition in CREATE TABLE -type ColumnDef struct { - Name string - Type string - Constraints []ColumnConstraint -} - -func (c *ColumnDef) expressionNode() {} -func (c ColumnDef) TokenLiteral() string { return c.Name } -func (c ColumnDef) Children() []Node { - children := make([]Node, len(c.Constraints)) - for i, constraint := range c.Constraints { - constraint := constraint // G601: Create local copy to avoid memory aliasing - children[i] = &constraint - } - return children -} - -// ColumnConstraint represents a column constraint -type ColumnConstraint struct { - Type string // NOT NULL, UNIQUE, PRIMARY KEY, etc. 
- Default Expression - References *ReferenceDefinition - Check Expression - AutoIncrement bool -} - -func (c *ColumnConstraint) expressionNode() {} -func (c ColumnConstraint) TokenLiteral() string { return c.Type } -func (c ColumnConstraint) Children() []Node { - children := make([]Node, 0) - if c.Default != nil { - children = append(children, c.Default) - } - if c.References != nil { - children = append(children, c.References) - } - if c.Check != nil { - children = append(children, c.Check) - } - return children -} - -// TableConstraint represents a table constraint -type TableConstraint struct { - Name string - Type string // PRIMARY KEY, UNIQUE, FOREIGN KEY, CHECK - Columns []string - References *ReferenceDefinition - Check Expression -} - -func (t *TableConstraint) expressionNode() {} -func (t TableConstraint) TokenLiteral() string { return t.Type } -func (t TableConstraint) Children() []Node { - children := make([]Node, 0) - if t.References != nil { - children = append(children, t.References) - } - if t.Check != nil { - children = append(children, t.Check) - } - return children -} - -// ReferenceDefinition represents a REFERENCES clause -type ReferenceDefinition struct { - Table string - Columns []string - OnDelete string - OnUpdate string - Match string -} - -func (r *ReferenceDefinition) expressionNode() {} -func (r ReferenceDefinition) TokenLiteral() string { return "REFERENCES" } -func (r ReferenceDefinition) Children() []Node { return nil } - -// PartitionBy represents a PARTITION BY clause -type PartitionBy struct { - Type string // RANGE, LIST, HASH - Columns []string - Boundary []Expression -} - -func (p *PartitionBy) expressionNode() {} -func (p PartitionBy) TokenLiteral() string { return "PARTITION BY" } -func (p PartitionBy) Children() []Node { return nodifyExpressions(p.Boundary) } - -// TableOption represents table options like ENGINE, CHARSET, etc. 
-type TableOption struct { - Name string - Value string -} - -func (t *TableOption) expressionNode() {} -func (t TableOption) TokenLiteral() string { return t.Name } -func (t TableOption) Children() []Node { return nil } - -// UpdateExpression represents a column=value expression in UPDATE -type UpdateExpression struct { - Column Expression - Value Expression -} - -func (u *UpdateExpression) expressionNode() {} -func (u UpdateExpression) TokenLiteral() string { return "=" } -func (u UpdateExpression) Children() []Node { - var nodes []Node - if u.Column != nil { - nodes = append(nodes, u.Column) - } - if u.Value != nil { - nodes = append(nodes, u.Value) +// Helper function to convert []Expression to []Node +func nodifyExpressions(exprs []Expression) []Node { + nodes := make([]Node, len(exprs)) + for i, expr := range exprs { + nodes[i] = expr } return nodes } -// DeleteStatement represents a DELETE SQL statement -type DeleteStatement struct { - With *WithClause - TableName string - Alias string - Using []TableReference - Where Expression - Returning []Expression - Pos models.Location // Source position of the DELETE keyword (1-based line and column) -} - -func (d *DeleteStatement) statementNode() {} -func (d DeleteStatement) TokenLiteral() string { return "DELETE" } - -func (d DeleteStatement) Children() []Node { - children := make([]Node, 0) - if d.With != nil { - children = append(children, d.With) - } - for _, using := range d.Using { - using := using // G601: Create local copy to avoid memory aliasing - children = append(children, &using) - } - if d.Where != nil { - children = append(children, d.Where) - } - children = append(children, nodifyExpressions(d.Returning)...) - return children -} - -// AlterTableStatement represents an ALTER TABLE statement. -// -// # Maintenance note -// -// AlterTableStatement is NOT produced by the parser. Parser.Parse* methods -// return [AlterStatement] (defined in alter.go) with Type == AlterTypeTable. 
-// AlterTableStatement is retained only so that existing code that constructs -// it directly (e.g. in tests or manual AST construction) continues to compile. -// -// Migration guide - prefer AlterStatement for all new code: -// -// // Wrong (type assertion will never succeed at runtime): -// stmt := tree.Statements[0].(*ast.AlterTableStatement) -// -// // Correct: -// stmt := tree.Statements[0].(*ast.AlterStatement) -// tableName := stmt.Name // AlterStatement.Name holds the table name -type AlterTableStatement struct { - Table string - Actions []AlterTableAction -} - -func (a *AlterTableStatement) statementNode() {} -func (a AlterTableStatement) TokenLiteral() string { return "ALTER TABLE" } -func (a AlterTableStatement) Children() []Node { - children := make([]Node, len(a.Actions)) - for i, action := range a.Actions { - action := action // G601: Create local copy to avoid memory aliasing - children[i] = &action - } - return children -} - -// AlterTableAction represents an action in ALTER TABLE -type AlterTableAction struct { - Type string // ADD COLUMN, DROP COLUMN, MODIFY COLUMN, etc. 
- ColumnName string - ColumnDef *ColumnDef - Constraint *TableConstraint -} - -func (a *AlterTableAction) expressionNode() {} -func (a AlterTableAction) TokenLiteral() string { return a.Type } -func (a AlterTableAction) Children() []Node { - children := make([]Node, 0) - if a.ColumnDef != nil { - children = append(children, a.ColumnDef) - } - if a.Constraint != nil { - children = append(children, a.Constraint) - } - return children -} - -// CreateIndexStatement represents a CREATE INDEX statement -type CreateIndexStatement struct { - Unique bool - IfNotExists bool - Name string - Table string - Columns []IndexColumn - Using string - Where Expression -} - -func (c *CreateIndexStatement) statementNode() {} -func (c CreateIndexStatement) TokenLiteral() string { return "CREATE INDEX" } -func (c CreateIndexStatement) Children() []Node { - children := make([]Node, 0) - for _, col := range c.Columns { - col := col // G601: Create local copy to avoid memory aliasing - children = append(children, &col) - } - if c.Where != nil { - children = append(children, c.Where) - } - return children -} - -// IndexColumn represents a column in an index definition -type IndexColumn struct { - Column string - Collate string - Direction string // ASC, DESC - NullsLast bool -} - -func (i *IndexColumn) expressionNode() {} -func (i IndexColumn) TokenLiteral() string { return i.Column } -func (i IndexColumn) Children() []Node { return nil } - -// MergeStatement represents a MERGE statement (SQL:2003 F312) -// Syntax: MERGE INTO target USING source ON condition -// -// WHEN MATCHED THEN UPDATE/DELETE -// WHEN NOT MATCHED THEN INSERT -// WHEN NOT MATCHED BY SOURCE THEN UPDATE/DELETE -type MergeStatement struct { - TargetTable TableReference // The table being merged into - TargetAlias string // Optional alias for target - SourceTable TableReference // The source table or subquery - SourceAlias string // Optional alias for source - OnCondition Expression // The join/match condition - WhenClauses 
[]*MergeWhenClause // List of WHEN clauses - Output []Expression // SQL Server OUTPUT clause columns -} - -func (m *MergeStatement) statementNode() {} -func (m MergeStatement) TokenLiteral() string { return "MERGE" } -func (m MergeStatement) Children() []Node { - children := []Node{&m.TargetTable, &m.SourceTable} - if m.OnCondition != nil { - children = append(children, m.OnCondition) - } - for _, when := range m.WhenClauses { - children = append(children, when) - } - children = append(children, nodifyExpressions(m.Output)...) - return children -} - -// MergeWhenClause represents a WHEN clause in a MERGE statement -// Types: MATCHED, NOT_MATCHED, NOT_MATCHED_BY_SOURCE -type MergeWhenClause struct { - Type string // "MATCHED", "NOT_MATCHED", "NOT_MATCHED_BY_SOURCE" - Condition Expression // Optional AND condition - Action *MergeAction // The action to perform (UPDATE/INSERT/DELETE) -} - -func (w *MergeWhenClause) expressionNode() {} -func (w MergeWhenClause) TokenLiteral() string { return "WHEN " + w.Type } -func (w MergeWhenClause) Children() []Node { - children := make([]Node, 0) - if w.Condition != nil { - children = append(children, w.Condition) - } - if w.Action != nil { - children = append(children, w.Action) - } - return children -} - -// MergeAction represents the action in a WHEN clause -// ActionType: UPDATE, INSERT, DELETE -type MergeAction struct { - ActionType string // "UPDATE", "INSERT", "DELETE" - SetClauses []SetClause // For UPDATE: SET column = value pairs - Columns []string // For INSERT: column list - Values []Expression // For INSERT: value list - DefaultValues bool // For INSERT: use DEFAULT VALUES -} - -func (a *MergeAction) expressionNode() {} -func (a MergeAction) TokenLiteral() string { return a.ActionType } -func (a MergeAction) Children() []Node { - children := make([]Node, 0) - for _, set := range a.SetClauses { - set := set // G601: Create local copy - children = append(children, &set) - } - for _, val := range a.Values { - children = 
append(children, val) - } - return children -} - -// SetClause represents a SET clause in UPDATE (also used in MERGE UPDATE) -type SetClause struct { - Column string - Value Expression -} - -func (s *SetClause) expressionNode() {} -func (s SetClause) TokenLiteral() string { return s.Column } -func (s SetClause) Children() []Node { - if s.Value != nil { - return []Node{s.Value} - } - return nil -} - -// CreateViewStatement represents a CREATE VIEW statement -// Syntax: CREATE [OR REPLACE] [TEMP|TEMPORARY] VIEW [IF NOT EXISTS] name [(columns)] AS select -type CreateViewStatement struct { - OrReplace bool - Temporary bool - IfNotExists bool - Name string - Columns []string // Optional column list - Query Statement // The SELECT statement - WithOption string // PostgreSQL: WITH (CHECK OPTION | CASCADED | LOCAL) -} - -func (c *CreateViewStatement) statementNode() {} -func (c CreateViewStatement) TokenLiteral() string { return "CREATE VIEW" } -func (c CreateViewStatement) Children() []Node { - if c.Query != nil { - return []Node{c.Query} - } - return nil -} - -// CreateMaterializedViewStatement represents a CREATE MATERIALIZED VIEW statement -// Syntax: CREATE MATERIALIZED VIEW [IF NOT EXISTS] name [(columns)] AS select [WITH [NO] DATA] -type CreateMaterializedViewStatement struct { - IfNotExists bool - Name string - Columns []string // Optional column list - Query Statement // The SELECT statement - WithData *bool // nil = default, true = WITH DATA, false = WITH NO DATA - Tablespace string // Optional tablespace (PostgreSQL) -} - -func (c *CreateMaterializedViewStatement) statementNode() {} -func (c CreateMaterializedViewStatement) TokenLiteral() string { return "CREATE MATERIALIZED VIEW" } -func (c CreateMaterializedViewStatement) Children() []Node { - if c.Query != nil { - return []Node{c.Query} - } - return nil -} - -// RefreshMaterializedViewStatement represents a REFRESH MATERIALIZED VIEW statement -// Syntax: REFRESH MATERIALIZED VIEW [CONCURRENTLY] name [WITH 
[NO] DATA] -type RefreshMaterializedViewStatement struct { - Concurrently bool - Name string - WithData *bool // nil = default, true = WITH DATA, false = WITH NO DATA -} - -func (r *RefreshMaterializedViewStatement) statementNode() {} -func (r RefreshMaterializedViewStatement) TokenLiteral() string { return "REFRESH MATERIALIZED VIEW" } -func (r RefreshMaterializedViewStatement) Children() []Node { return nil } - -// DropStatement represents a DROP statement for tables, views, indexes, etc. -// Syntax: DROP object_type [IF EXISTS] name [CASCADE|RESTRICT] -type DropStatement struct { - ObjectType string // TABLE, VIEW, MATERIALIZED VIEW, INDEX, etc. - IfExists bool - Names []string // Can drop multiple objects - CascadeType string // CASCADE, RESTRICT, or empty -} - -func (d *DropStatement) statementNode() {} -func (d DropStatement) TokenLiteral() string { return "DROP " + d.ObjectType } -func (d DropStatement) Children() []Node { return nil } - -// TruncateStatement represents a TRUNCATE TABLE statement -// Syntax: TRUNCATE [TABLE] table_name [, table_name ...] 
[RESTART IDENTITY | CONTINUE IDENTITY] [CASCADE | RESTRICT] -type TruncateStatement struct { - Tables []string // Table names to truncate - RestartIdentity bool // RESTART IDENTITY - reset sequences - ContinueIdentity bool // CONTINUE IDENTITY - keep sequences (default) - CascadeType string // CASCADE, RESTRICT, or empty +// Identifier represents a column or table name +type Identifier struct { + Name string + Table string // Optional table qualifier + Pos models.Location // Source position of this identifier (1-based line and column) } -func (t *TruncateStatement) statementNode() {} -func (t TruncateStatement) TokenLiteral() string { return "TRUNCATE TABLE" } -func (t TruncateStatement) Children() []Node { return nil } +func (i *Identifier) expressionNode() {} +func (i Identifier) TokenLiteral() string { return i.Name } +func (i Identifier) Children() []Node { return nil } -// PartitionDefinition represents a partition definition in CREATE TABLE -// Syntax: PARTITION name VALUES { LESS THAN (expr) | IN (list) | FROM (expr) TO (expr) } -type PartitionDefinition struct { - Name string - Type string // FOR VALUES, IN, LESS THAN - Values []Expression // Partition values or bounds - LessThan Expression // For RANGE: LESS THAN (value) - From Expression // For RANGE: FROM (value) - To Expression // For RANGE: TO (value) - InValues []Expression // For LIST: IN (values) - Tablespace string // Optional tablespace +// LiteralValue represents a literal value in SQL +type LiteralValue struct { + Value interface{} + Type string // INTEGER, FLOAT, STRING, BOOLEAN, NULL, etc. 
} -func (p *PartitionDefinition) expressionNode() {} -func (p PartitionDefinition) TokenLiteral() string { return "PARTITION " + p.Name } -func (p PartitionDefinition) Children() []Node { - children := make([]Node, 0) - for _, v := range p.Values { - children = append(children, v) - } - if p.LessThan != nil { - children = append(children, p.LessThan) - } - if p.From != nil { - children = append(children, p.From) - } - if p.To != nil { - children = append(children, p.To) - } - for _, v := range p.InValues { - children = append(children, v) - } - return children -} +func (l *LiteralValue) expressionNode() {} +func (l LiteralValue) TokenLiteral() string { return fmt.Sprintf("%v", l.Value) } +func (l LiteralValue) Children() []Node { return nil } // AST represents the root of the Abstract Syntax Tree produced by parsing one or // more SQL statements. @@ -2041,423 +176,3 @@ func (a AST) HasUnsupportedStatements() bool { } return false } - -// PragmaStatement represents a SQLite PRAGMA statement. -// Examples: PRAGMA table_info(users), PRAGMA journal_mode = WAL, PRAGMA integrity_check -type PragmaStatement struct { - Name string // Pragma name, e.g. "table_info" - Arg string // Optional: parenthesized arg, e.g. "users" - Value string // Optional: assigned value, e.g. "WAL" -} - -func (p *PragmaStatement) statementNode() {} -func (p PragmaStatement) TokenLiteral() string { return "PRAGMA" } -func (p PragmaStatement) Children() []Node { return nil } - -// ShowStatement represents MySQL SHOW commands (SHOW TABLES, SHOW DATABASES, SHOW CREATE TABLE x, etc.) -type ShowStatement struct { - ShowType string // TABLES, DATABASES, CREATE TABLE, COLUMNS, INDEX, etc. - ObjectName string // For SHOW CREATE TABLE x, SHOW COLUMNS FROM x, etc. - From string // For SHOW ... 
FROM database -} - -func (s *ShowStatement) statementNode() {} -func (s ShowStatement) TokenLiteral() string { return "SHOW" } -func (s ShowStatement) Children() []Node { return nil } - -// DescribeStatement represents MySQL DESCRIBE/DESC/EXPLAIN table commands -type DescribeStatement struct { - TableName string -} - -func (d *DescribeStatement) statementNode() {} -func (d DescribeStatement) TokenLiteral() string { return "DESCRIBE" } -func (d DescribeStatement) Children() []Node { return nil } - -// UnsupportedStatement represents a SQL statement that was parsed but not -// fully modeled in the AST. The parser consumed and validated the tokens -// but no dedicated AST node exists yet for this statement kind. -// -// Consumers should use Kind to identify the operation (e.g., "USE", "COPY", -// "CREATE STAGE") and RawSQL for the original text. Tools that do -// switch stmt.(type) should handle this case explicitly rather than -// falling through to a default that assumes the statement is well-structured. -type UnsupportedStatement struct { - Kind string // Operation kind: "USE", "COPY", "PUT", "GET", "LIST", "REMOVE", "CREATE STAGE", etc. - RawSQL string // Original SQL fragment for round-trip fidelity -} - -func (u *UnsupportedStatement) statementNode() {} -func (u UnsupportedStatement) TokenLiteral() string { return u.Kind } -func (u UnsupportedStatement) Children() []Node { return nil } - -// ReplaceStatement represents MySQL REPLACE INTO statement -type ReplaceStatement struct { - TableName string - Columns []Expression - Values [][]Expression -} - -func (r *ReplaceStatement) statementNode() {} -func (r ReplaceStatement) TokenLiteral() string { return "REPLACE" } -func (r ReplaceStatement) Children() []Node { - children := make([]Node, 0) - children = append(children, nodifyExpressions(r.Columns)...) - for _, row := range r.Values { - children = append(children, nodifyExpressions(row)...) 
- } - return children -} - -// ── MariaDB SEQUENCE DDL (10.3+) ─────────────────────────────────────────── - -// CycleOption represents the CYCLE behavior for a sequence. -type CycleOption int - -const ( - // CycleUnspecified means no CYCLE or NOCYCLE clause was given (database default applies). - CycleUnspecified CycleOption = iota - // CycleBehavior means CYCLE — sequence wraps around when it reaches min/max. - CycleBehavior - // NoCycleBehavior means NOCYCLE / NO CYCLE — sequence errors on overflow. - NoCycleBehavior -) - -// SequenceOptions holds configuration for CREATE SEQUENCE and ALTER SEQUENCE. -// Fields are pointers so that unspecified options are distinguishable from zero values. -type SequenceOptions struct { - StartWith *LiteralValue // START WITH n - IncrementBy *LiteralValue // INCREMENT BY n (default 1) - MinValue *LiteralValue // MINVALUE n or nil when NO MINVALUE - MaxValue *LiteralValue // MAXVALUE n or nil when NO MAXVALUE - Cache *LiteralValue // CACHE n or nil when NO CACHE / NOCACHE - CycleMode CycleOption // CYCLE / NOCYCLE / NO CYCLE (CycleUnspecified if not specified) - NoCache bool // NOCACHE (explicit; Cache=nil alone is ambiguous) - Restart bool // bare RESTART (reset to start value) - RestartWith *LiteralValue // RESTART WITH n (explicit restart value) -} - -// CreateSequenceStatement represents: -// -// CREATE [OR REPLACE] SEQUENCE [IF NOT EXISTS] name [options...] 
-type CreateSequenceStatement struct { - Name *Identifier - OrReplace bool - IfNotExists bool - Options SequenceOptions - Pos models.Location // Source position of the CREATE keyword (1-based line and column) -} - -func (s *CreateSequenceStatement) statementNode() {} -func (s *CreateSequenceStatement) TokenLiteral() string { return "CREATE" } -func (s *CreateSequenceStatement) Children() []Node { - if s.Name != nil { - return []Node{s.Name} - } - return nil -} - -// DropSequenceStatement represents: -// -// DROP SEQUENCE [IF EXISTS | IF NOT EXISTS] name -type DropSequenceStatement struct { - Name *Identifier - IfExists bool - Pos models.Location // Source position of the DROP keyword (1-based line and column) -} - -func (s *DropSequenceStatement) statementNode() {} -func (s *DropSequenceStatement) TokenLiteral() string { return "DROP" } -func (s *DropSequenceStatement) Children() []Node { - if s.Name != nil { - return []Node{s.Name} - } - return nil -} - -// AlterSequenceStatement represents: -// -// ALTER SEQUENCE [IF EXISTS] name [options...] -type AlterSequenceStatement struct { - Name *Identifier - IfExists bool - Options SequenceOptions - Pos models.Location // Source position of the ALTER keyword (1-based line and column) -} - -func (s *AlterSequenceStatement) statementNode() {} -func (s *AlterSequenceStatement) TokenLiteral() string { return "ALTER" } -func (s *AlterSequenceStatement) Children() []Node { - if s.Name != nil { - return []Node{s.Name} - } - return nil -} - -// ── MariaDB Temporal Table Types (10.3.4+) ──────────────────────────────── - -// SystemTimeClauseType identifies the kind of FOR SYSTEM_TIME clause. -type SystemTimeClauseType int - -const ( - SystemTimeAsOf SystemTimeClauseType = iota // FOR SYSTEM_TIME AS OF - SystemTimeBetween // FOR SYSTEM_TIME BETWEEN AND - SystemTimeFromTo // FOR SYSTEM_TIME FROM TO - SystemTimeAll // FOR SYSTEM_TIME ALL -) - -// ForSystemTimeClause represents a temporal query on a system-versioned table. 
-// -// SELECT * FROM t FOR SYSTEM_TIME AS OF TIMESTAMP '2024-01-01'; -// SELECT * FROM t FOR SYSTEM_TIME BETWEEN '2020-01-01' AND '2024-01-01'; -// SELECT * FROM t FOR SYSTEM_TIME ALL; -type ForSystemTimeClause struct { - Type SystemTimeClauseType - Point Expression // used for AS OF - Start Expression // used for BETWEEN, FROM - End Expression // used for BETWEEN (AND), TO - Pos models.Location // Source position of the FOR keyword (1-based line and column) -} - -// expressionNode satisfies the Expression interface so ForSystemTimeClause can be -// stored in TableReference.ForSystemTime without a separate interface type. -// Semantically it is a table-level clause, not a scalar expression. -func (c *ForSystemTimeClause) expressionNode() {} -func (c ForSystemTimeClause) TokenLiteral() string { return "FOR SYSTEM_TIME" } -func (c ForSystemTimeClause) Children() []Node { - var nodes []Node - if c.Point != nil { - nodes = append(nodes, c.Point) - } - if c.Start != nil { - nodes = append(nodes, c.Start) - } - if c.End != nil { - nodes = append(nodes, c.End) - } - return nodes -} - -// TimeTravelClause represents the Snowflake time-travel / change-tracking -// modifier on a table reference: -// -// SELECT ... FROM t AT (TIMESTAMP => '2024-01-01'::TIMESTAMP) -// SELECT ... FROM t BEFORE (STATEMENT => '...uuid...') -// SELECT ... FROM t CHANGES (INFORMATION => DEFAULT) AT (...) -// -// Kind is one of "AT", "BEFORE", "CHANGES". Named holds the -// `name => expr` arguments keyed by upper-cased name (e.g. TIMESTAMP, -// OFFSET, STATEMENT, INFORMATION). Multiple clauses may chain (CHANGES -// plus AT); extra clauses are appended to Chained. 
-type TimeTravelClause struct { - Kind string // "AT" | "BEFORE" | "CHANGES" - Named map[string]Expression - Chained []*TimeTravelClause - Pos models.Location -} - -func (c *TimeTravelClause) expressionNode() {} -func (c TimeTravelClause) TokenLiteral() string { return c.Kind } -func (c TimeTravelClause) Children() []Node { - var nodes []Node - for _, v := range c.Named { - if v != nil { - nodes = append(nodes, v) - } - } - for _, ch := range c.Chained { - if ch != nil { - nodes = append(nodes, ch) - } - } - return nodes -} - -// PivotClause represents the SQL Server / Oracle PIVOT operator for row-to-column -// transformation in a FROM clause. -// -// PIVOT (SUM(sales) FOR region IN ([North], [South], [East], [West])) AS pvt -type PivotClause struct { - AggregateFunction Expression // The aggregate function, e.g. SUM(sales) - PivotColumn string // The column used for pivoting, e.g. region - InValues []string // The values to pivot on, e.g. [North], [South] - Pos models.Location // Source position of the PIVOT keyword -} - -func (p *PivotClause) expressionNode() {} -func (p PivotClause) TokenLiteral() string { return "PIVOT" } -func (p PivotClause) Children() []Node { - if p.AggregateFunction != nil { - return []Node{p.AggregateFunction} - } - return nil -} - -// UnpivotClause represents the SQL Server / Oracle UNPIVOT operator for column-to-row -// transformation in a FROM clause. -// -// UNPIVOT (sales FOR region IN (north_sales, south_sales, east_sales)) AS unpvt -type UnpivotClause struct { - ValueColumn string // The target value column, e.g. sales - NameColumn string // The target name column, e.g. region - InColumns []string // The source columns to unpivot, e.g. 
north_sales, south_sales - Pos models.Location // Source position of the UNPIVOT keyword -} - -func (u *UnpivotClause) expressionNode() {} -func (u UnpivotClause) TokenLiteral() string { return "UNPIVOT" } -func (u UnpivotClause) Children() []Node { return nil } - -// PeriodDefinition represents a PERIOD FOR clause in CREATE TABLE. -// -// PERIOD FOR app_time (start_col, end_col) -// PERIOD FOR SYSTEM_TIME (row_start, row_end) -type PeriodDefinition struct { - Name *Identifier // period name (e.g., "app_time") or SYSTEM_TIME - StartCol *Identifier - EndCol *Identifier - Pos models.Location // Source position of the PERIOD FOR keyword (1-based line and column) -} - -// MatchRecognizeClause represents the SQL:2016 MATCH_RECOGNIZE clause for -// row-pattern recognition in a FROM clause (Snowflake, Oracle, Databricks). -// -// MATCH_RECOGNIZE ( -// PARTITION BY symbol -// ORDER BY ts -// MEASURES MATCH_NUMBER() AS m -// ALL ROWS PER MATCH -// PATTERN (UP+ DOWN+) -// DEFINE UP AS price > PREV(price), DOWN AS price < PREV(price) -// ) -type MatchRecognizeClause struct { - PartitionBy []Expression - OrderBy []OrderByExpression - Measures []MeasureDef - RowsPerMatch string // "ONE ROW PER MATCH" or "ALL ROWS PER MATCH" (empty = default) - AfterMatch string // raw text: "SKIP TO NEXT ROW", "SKIP PAST LAST ROW", etc. - Pattern string // raw pattern text: "UP+ DOWN+" - Definitions []PatternDef - Pos models.Location -} - -// MeasureDef is one MEASURES entry: expr AS alias. -type MeasureDef struct { - Expr Expression - Alias string -} - -// PatternDef is one DEFINE entry: variable_name AS boolean_condition. -type PatternDef struct { - Name string - Condition Expression -} - -func (m *MatchRecognizeClause) expressionNode() {} -func (m MatchRecognizeClause) TokenLiteral() string { return "MATCH_RECOGNIZE" } -func (m MatchRecognizeClause) Children() []Node { - var nodes []Node - nodes = append(nodes, nodifyExpressions(m.PartitionBy)...) 
- for _, ob := range m.OrderBy { - ob := ob - nodes = append(nodes, &ob) - } - for _, md := range m.Measures { - if md.Expr != nil { - nodes = append(nodes, md.Expr) - } - } - for _, pd := range m.Definitions { - if pd.Condition != nil { - nodes = append(nodes, pd.Condition) - } - } - return nodes -} - -// expressionNode satisfies the Expression interface so PeriodDefinition can be -// stored in CreateTableStatement.PeriodDefinitions without a separate interface type. -// Semantically it is a table column constraint, not a scalar expression. -func (p *PeriodDefinition) expressionNode() {} -func (p PeriodDefinition) TokenLiteral() string { return "PERIOD FOR" } -func (p PeriodDefinition) Children() []Node { - var nodes []Node - if p.Name != nil { - nodes = append(nodes, p.Name) - } - if p.StartCol != nil { - nodes = append(nodes, p.StartCol) - } - if p.EndCol != nil { - nodes = append(nodes, p.EndCol) - } - return nodes -} - -// ── MariaDB Hierarchical Query / CONNECT BY (10.2+) ─────────────────────── - -// ConnectByClause represents the CONNECT BY hierarchical query clause (MariaDB 10.2+). -// -// SELECT id, name FROM t -// START WITH parent_id IS NULL -// CONNECT BY NOCYCLE PRIOR id = parent_id; -type ConnectByClause struct { - NoCycle bool // NOCYCLE modifier — prevents loops in cyclic graphs - Condition Expression // the PRIOR expression (e.g., PRIOR id = parent_id) - Pos models.Location // Source position of the CONNECT BY keyword (1-based line and column) -} - -// expressionNode satisfies the Expression interface so ConnectByClause can be -// stored in SelectStatement.ConnectBy without a separate interface type. -// Semantically it is a query-level clause, not a scalar expression. 
-func (c *ConnectByClause) expressionNode() {} -func (c ConnectByClause) TokenLiteral() string { return "CONNECT BY" } -func (c ConnectByClause) Children() []Node { - if c.Condition != nil { - return []Node{c.Condition} - } - return nil -} - -// SampleClause represents a ClickHouse SAMPLE clause on a SELECT statement. -// -// ClickHouse supports three sampling forms: -// -// SAMPLE 0.1 — ratio (10% of data) -// SAMPLE 1000 — approximate row count -// SAMPLE 1/10 — fraction (1 part out of 10) -// SAMPLE 1/10 OFFSET 2/10 — fraction with offset -// -// The clause is dialect-specific to ClickHouse (and partly Snowflake/Redshift -// via TABLESAMPLE, but this implementation targets SAMPLE). -// Value is stored as a raw string to preserve the original representation -// (e.g., "0.1", "1000", "1/10"). -// ArrayJoinClause represents a ClickHouse ARRAY JOIN or LEFT ARRAY JOIN clause. -// Syntax: [LEFT] ARRAY JOIN expr [AS alias], expr [AS alias], ... -type ArrayJoinClause struct { - Left bool // true for LEFT ARRAY JOIN - Elements []ArrayJoinElement // One or more join elements - Pos models.Location -} - -// ArrayJoinElement is a single expression in an ARRAY JOIN clause with an optional alias. -type ArrayJoinElement struct { - Expr Expression - Alias string -} - -type SampleClause struct { - // Value is the sampling size/ratio as a raw token string (e.g., "0.1", "1000", "1/10"). - Value string - // Denominator is set when the fraction form "N/D" is used (denominator part). - Denominator string - // Offset is the optional OFFSET fraction (e.g., "2/10" in SAMPLE 1/10 OFFSET 2/10). - Offset string - // OffsetDenominator is set for fractional offsets. 
- OffsetDenominator string - Pos models.Location -} - -func (s *SampleClause) expressionNode() {} -func (s SampleClause) TokenLiteral() string { return "SAMPLE" } -func (s SampleClause) Children() []Node { return nil } diff --git a/pkg/sql/ast/ast_clauses.go b/pkg/sql/ast/ast_clauses.go new file mode 100644 index 00000000..a1ea46c1 --- /dev/null +++ b/pkg/sql/ast/ast_clauses.go @@ -0,0 +1,791 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import ( + "github.com/ajitpratap0/GoSQLX/pkg/models" +) + +// WithClause represents a WITH clause in a SQL statement. +// It supports both simple and recursive Common Table Expressions (CTEs). +// Phase 2 Complete: Full parser integration with all statement types. +type WithClause struct { + Recursive bool + CTEs []*CommonTableExpr + Pos models.Location // Source position of the WITH keyword (1-based line and column) +} + +func (w *WithClause) statementNode() {} +func (w WithClause) TokenLiteral() string { return "WITH" } +func (w WithClause) Children() []Node { + children := make([]Node, len(w.CTEs)) + for i, cte := range w.CTEs { + children[i] = cte + } + return children +} + +// CommonTableExpr represents a single Common Table Expression in a WITH clause. +// It supports optional column specifications and any statement type as the CTE query. +// Phase 2 Complete: Full parser support with column specifications. 
+// Phase 2.6: Added MATERIALIZED/NOT MATERIALIZED support for query optimization hints. +type CommonTableExpr struct { + Name string + Columns []string + Statement Statement + ScalarExpr Expression // ClickHouse: WITH AS (scalar CTE, no subquery) + Materialized *bool // nil = default, true = MATERIALIZED, false = NOT MATERIALIZED + Pos models.Location // Source position of the CTE name (1-based line and column) +} + +func (c *CommonTableExpr) statementNode() {} +func (c CommonTableExpr) TokenLiteral() string { return c.Name } +func (c CommonTableExpr) Children() []Node { + var nodes []Node + if c.Statement != nil { + nodes = append(nodes, c.Statement) + } + if c.ScalarExpr != nil { + nodes = append(nodes, c.ScalarExpr) + } + return nodes +} + +// JoinClause represents a JOIN clause in SQL +type JoinClause struct { + Type string // INNER, LEFT, RIGHT, FULL + Left TableReference + Right TableReference + Condition Expression + Pos models.Location // Source position of the JOIN keyword (1-based line and column) +} + +func (j *JoinClause) expressionNode() {} +func (j JoinClause) TokenLiteral() string { return j.Type + " JOIN" } +func (j JoinClause) Children() []Node { + children := []Node{&j.Left, &j.Right} + if j.Condition != nil { + children = append(children, j.Condition) + } + return children +} + +// TableReference represents a table reference in a FROM clause. +// +// TableReference can represent either a simple table name or a derived table +// (subquery). It supports PostgreSQL's LATERAL keyword for correlated subqueries. +// +// Fields: +// - Name: Table name (empty if this is a derived table/subquery) +// - Alias: Optional table alias (AS alias) +// - Subquery: Subquery for derived tables: (SELECT ...) AS alias +// - Lateral: LATERAL keyword for correlated subqueries (PostgreSQL v1.6.0) +// +// The Lateral field enables PostgreSQL's LATERAL JOIN feature, which allows +// subqueries in the FROM clause to reference columns from preceding tables. 
+// +// Example - Simple table reference: +// +// TableReference{ +// Name: "users", +// Alias: "u", +// } +// // SQL: FROM users u +// +// Example - Derived table (subquery): +// +// TableReference{ +// Alias: "recent_orders", +// Subquery: selectStmt, +// } +// // SQL: FROM (SELECT ...) AS recent_orders +// +// Example - LATERAL JOIN (PostgreSQL v1.6.0): +// +// TableReference{ +// Lateral: true, +// Alias: "r", +// Subquery: correlatedSelectStmt, +// } +// // SQL: FROM users u, LATERAL (SELECT * FROM orders WHERE user_id = u.id) r +// +// New in v1.6.0: Lateral field for PostgreSQL LATERAL JOIN support. +type TableReference struct { + Name string // Table name (empty if this is a derived table) + Alias string // Optional alias + Subquery *SelectStatement // For derived tables: (SELECT ...) AS alias + Lateral bool // LATERAL keyword for correlated subqueries (PostgreSQL) + TableHints []string // SQL Server table hints: WITH (NOLOCK), WITH (ROWLOCK, UPDLOCK), etc. + Final bool // ClickHouse FINAL modifier: forces MergeTree part merge + // TableFunc is a function-call table reference such as + // Snowflake LATERAL FLATTEN(input => col), TABLE(my_func(1,2)), + // IDENTIFIER('t'), or PostgreSQL unnest(array_col). When set, Name + // holds the function name and TableFunc carries the call itself. + TableFunc *FunctionCall + // TimeTravel is the Snowflake time-travel clause applied to this table + // reference: AT / BEFORE (TIMESTAMP|OFFSET|STATEMENT => expr) or + // CHANGES (INFORMATION => DEFAULT|APPEND_ONLY). + TimeTravel *TimeTravelClause + // ForSystemTime is the MariaDB temporal table clause (10.3.4+). + // Example: SELECT * FROM t FOR SYSTEM_TIME AS OF '2024-01-01' + ForSystemTime *ForSystemTimeClause // MariaDB temporal query + // Pivot is the SQL Server / Oracle PIVOT clause for row-to-column transformation. 
+ // Example: SELECT * FROM t PIVOT (SUM(sales) FOR region IN ([North], [South])) AS pvt + Pivot *PivotClause + // Unpivot is the SQL Server / Oracle UNPIVOT clause for column-to-row transformation. + // Example: SELECT * FROM t UNPIVOT (sales FOR region IN (north_sales, south_sales)) AS unpvt + Unpivot *UnpivotClause + // MatchRecognize is the SQL:2016 row-pattern recognition clause (Snowflake, Oracle). + MatchRecognize *MatchRecognizeClause +} + +func (t *TableReference) statementNode() {} +func (t TableReference) TokenLiteral() string { + if t.Name != "" { + return t.Name + } + if t.Alias != "" { + return t.Alias + } + return "subquery" +} +func (t TableReference) Children() []Node { + var nodes []Node + if t.Subquery != nil { + nodes = append(nodes, t.Subquery) + } + if t.TableFunc != nil { + nodes = append(nodes, t.TableFunc) + } + if t.TimeTravel != nil { + nodes = append(nodes, t.TimeTravel) + } + if t.Pivot != nil { + nodes = append(nodes, t.Pivot) + } + if t.Unpivot != nil { + nodes = append(nodes, t.Unpivot) + } + if t.MatchRecognize != nil { + nodes = append(nodes, t.MatchRecognize) + } + return nodes +} + +// OrderByExpression represents an ORDER BY clause element with direction and NULL ordering +type OrderByExpression struct { + Expression Expression // The expression to order by + Ascending bool // true for ASC (default), false for DESC + NullsFirst *bool // nil = default behavior, true = NULLS FIRST, false = NULLS LAST +} + +func (*OrderByExpression) expressionNode() {} +func (o *OrderByExpression) TokenLiteral() string { return "ORDER BY" } +func (o *OrderByExpression) Children() []Node { + if o.Expression != nil { + return []Node{o.Expression} + } + return nil +} + +// WindowSpec represents a window specification +type WindowSpec struct { + Name string + PartitionBy []Expression + OrderBy []OrderByExpression + FrameClause *WindowFrame +} + +func (w *WindowSpec) statementNode() {} +func (w WindowSpec) TokenLiteral() string { return "WINDOW" } +func 
(w WindowSpec) Children() []Node { + children := make([]Node, 0) + children = append(children, nodifyExpressions(w.PartitionBy)...) + for _, orderBy := range w.OrderBy { + orderBy := orderBy // G601: Create local copy to avoid memory aliasing + children = append(children, &orderBy) + } + if w.FrameClause != nil { + children = append(children, w.FrameClause) + } + return children +} + +// WindowFrame represents window frame clause +type WindowFrame struct { + Type string // ROWS, RANGE + Start WindowFrameBound + End *WindowFrameBound +} + +func (w *WindowFrame) statementNode() {} +func (w WindowFrame) TokenLiteral() string { return w.Type } +func (w WindowFrame) Children() []Node { + // Start is a value type, always include it to support visitor traversal. + children := []Node{&w.Start} + if w.End != nil { + children = append(children, w.End) + } + return children +} + +// WindowFrameBound represents window frame bound +type WindowFrameBound struct { + Type string // CURRENT ROW, UNBOUNDED PRECEDING, etc. + Value Expression +} + +func (w *WindowFrameBound) expressionNode() {} +func (w WindowFrameBound) TokenLiteral() string { + if w.Type != "" { + return w.Type + } + return "BOUND" +} +func (w WindowFrameBound) Children() []Node { + if w.Value != nil { + return []Node{w.Value} + } + return nil +} + +// TopClause represents SQL Server's TOP N [PERCENT] clause +// Syntax: SELECT TOP n [PERCENT] columns... 
+// Count is an Expression to support TOP (10), TOP (@var), TOP (subquery) +type TopClause struct { + Count Expression // Number of rows (or percentage) as an expression + IsPercent bool // Whether PERCENT keyword was specified + WithTies bool // Whether WITH TIES was specified (SQL Server) +} + +func (t *TopClause) expressionNode() {} +func (t TopClause) TokenLiteral() string { return "TOP" } +func (t TopClause) Children() []Node { + if t.Count != nil { + return []Node{t.Count} + } + return nil +} + +// FetchClause represents the SQL-99 FETCH FIRST/NEXT clause (F861, F862) +// Syntax: [OFFSET n {ROW | ROWS}] FETCH {FIRST | NEXT} n [{ROW | ROWS}] {ONLY | WITH TIES} +// Examples: +// - OFFSET 20 ROWS FETCH NEXT 10 ROWS ONLY +// - FETCH FIRST 5 ROWS ONLY +// - FETCH FIRST 10 PERCENT ROWS WITH TIES +type FetchClause struct { + // OffsetValue is the number of rows to skip (OFFSET n ROWS) + OffsetValue *int64 + // FetchValue is the number of rows to fetch (FETCH n ROWS) + FetchValue *int64 + // FetchType is either "FIRST" or "NEXT" + FetchType string + // IsPercent indicates FETCH ... PERCENT ROWS + IsPercent bool + // WithTies indicates FETCH ... 
WITH TIES (includes tied rows) + WithTies bool +} + +func (f *FetchClause) expressionNode() {} +func (f FetchClause) TokenLiteral() string { return "FETCH" } +func (f FetchClause) Children() []Node { return nil } + +// ForClause represents row-level locking clauses in SELECT statements (SQL:2003, PostgreSQL, MySQL) +// Syntax: FOR {UPDATE | SHARE | NO KEY UPDATE | KEY SHARE} [OF table_name [, ...]] [NOWAIT | SKIP LOCKED] +// Examples: +// - FOR UPDATE +// - FOR SHARE NOWAIT +// - FOR UPDATE OF orders SKIP LOCKED +// - FOR NO KEY UPDATE +// - FOR KEY SHARE +type ForClause struct { + // LockType specifies the type of lock: + // "UPDATE" - exclusive lock for UPDATE operations + // "SHARE" - shared lock for read operations + // "NO KEY UPDATE" - PostgreSQL: exclusive lock that doesn't block SHARE locks on same row + // "KEY SHARE" - PostgreSQL: shared lock that doesn't block UPDATE locks + LockType string + // Tables specifies which tables to lock (FOR UPDATE OF table_name) + // Empty slice means lock all tables in the query + Tables []string + // NoWait indicates NOWAIT option (fail immediately if lock cannot be acquired) + NoWait bool + // SkipLocked indicates SKIP LOCKED option (skip rows that can't be locked) + SkipLocked bool +} + +func (f *ForClause) expressionNode() {} +func (f ForClause) TokenLiteral() string { return "FOR" } +func (f ForClause) Children() []Node { return nil } + +// OnConflict represents ON CONFLICT DO UPDATE/NOTHING clause +type OnConflict struct { + Target []Expression // Target columns + Constraint string // Optional constraint name + Action OnConflictAction +} + +func (o *OnConflict) expressionNode() {} +func (o OnConflict) TokenLiteral() string { return "ON CONFLICT" } +func (o OnConflict) Children() []Node { + children := nodifyExpressions(o.Target) + if o.Action.DoUpdate != nil { + for _, update := range o.Action.DoUpdate { + update := update // G601: Create local copy to avoid memory aliasing + children = append(children, &update) + } 
+ } + if o.Action.Where != nil { + children = append(children, o.Action.Where) + } + return children +} + +// OnConflictAction represents DO UPDATE/NOTHING in ON CONFLICT clause +type OnConflictAction struct { + DoNothing bool + DoUpdate []UpdateExpression + Where Expression +} + +// UpsertClause represents INSERT ... ON DUPLICATE KEY UPDATE +type UpsertClause struct { + Updates []UpdateExpression +} + +func (u *UpsertClause) expressionNode() {} +func (u UpsertClause) TokenLiteral() string { return "ON DUPLICATE KEY UPDATE" } +func (u UpsertClause) Children() []Node { + children := make([]Node, len(u.Updates)) + for i, update := range u.Updates { + update := update // G601: Create local copy to avoid memory aliasing + children[i] = &update + } + return children +} + +// ColumnDef represents a column definition in CREATE TABLE +type ColumnDef struct { + Name string + Type string + Constraints []ColumnConstraint +} + +func (c *ColumnDef) expressionNode() {} +func (c ColumnDef) TokenLiteral() string { return c.Name } +func (c ColumnDef) Children() []Node { + children := make([]Node, len(c.Constraints)) + for i, constraint := range c.Constraints { + constraint := constraint // G601: Create local copy to avoid memory aliasing + children[i] = &constraint + } + return children +} + +// ColumnConstraint represents a column constraint +type ColumnConstraint struct { + Type string // NOT NULL, UNIQUE, PRIMARY KEY, etc. 
+ Default Expression + References *ReferenceDefinition + Check Expression + AutoIncrement bool +} + +func (c *ColumnConstraint) expressionNode() {} +func (c ColumnConstraint) TokenLiteral() string { return c.Type } +func (c ColumnConstraint) Children() []Node { + children := make([]Node, 0) + if c.Default != nil { + children = append(children, c.Default) + } + if c.References != nil { + children = append(children, c.References) + } + if c.Check != nil { + children = append(children, c.Check) + } + return children +} + +// TableConstraint represents a table constraint +type TableConstraint struct { + Name string + Type string // PRIMARY KEY, UNIQUE, FOREIGN KEY, CHECK + Columns []string + References *ReferenceDefinition + Check Expression +} + +func (t *TableConstraint) expressionNode() {} +func (t TableConstraint) TokenLiteral() string { return t.Type } +func (t TableConstraint) Children() []Node { + children := make([]Node, 0) + if t.References != nil { + children = append(children, t.References) + } + if t.Check != nil { + children = append(children, t.Check) + } + return children +} + +// ReferenceDefinition represents a REFERENCES clause +type ReferenceDefinition struct { + Table string + Columns []string + OnDelete string + OnUpdate string + Match string +} + +func (r *ReferenceDefinition) expressionNode() {} +func (r ReferenceDefinition) TokenLiteral() string { return "REFERENCES" } +func (r ReferenceDefinition) Children() []Node { return nil } + +// PartitionBy represents a PARTITION BY clause +type PartitionBy struct { + Type string // RANGE, LIST, HASH + Columns []string + Boundary []Expression +} + +func (p *PartitionBy) expressionNode() {} +func (p PartitionBy) TokenLiteral() string { return "PARTITION BY" } +func (p PartitionBy) Children() []Node { return nodifyExpressions(p.Boundary) } + +// TableOption represents table options like ENGINE, CHARSET, etc. 
+type TableOption struct { + Name string + Value string +} + +func (t *TableOption) expressionNode() {} +func (t TableOption) TokenLiteral() string { return t.Name } +func (t TableOption) Children() []Node { return nil } + +// IndexColumn represents a column in an index definition +type IndexColumn struct { + Column string + Collate string + Direction string // ASC, DESC + NullsLast bool +} + +func (i *IndexColumn) expressionNode() {} +func (i IndexColumn) TokenLiteral() string { return i.Column } +func (i IndexColumn) Children() []Node { return nil } + +// PartitionDefinition represents a partition definition in CREATE TABLE +// Syntax: PARTITION name VALUES { LESS THAN (expr) | IN (list) | FROM (expr) TO (expr) } +type PartitionDefinition struct { + Name string + Type string // FOR VALUES, IN, LESS THAN + Values []Expression // Partition values or bounds + LessThan Expression // For RANGE: LESS THAN (value) + From Expression // For RANGE: FROM (value) + To Expression // For RANGE: TO (value) + InValues []Expression // For LIST: IN (values) + Tablespace string // Optional tablespace +} + +func (p *PartitionDefinition) expressionNode() {} +func (p PartitionDefinition) TokenLiteral() string { return "PARTITION " + p.Name } +func (p PartitionDefinition) Children() []Node { + children := make([]Node, 0) + for _, v := range p.Values { + children = append(children, v) + } + if p.LessThan != nil { + children = append(children, p.LessThan) + } + if p.From != nil { + children = append(children, p.From) + } + if p.To != nil { + children = append(children, p.To) + } + for _, v := range p.InValues { + children = append(children, v) + } + return children +} + +// ── MariaDB Temporal Table Types (10.3.4+) ──────────────────────────────── + +// SystemTimeClauseType identifies the kind of FOR SYSTEM_TIME clause. 
+type SystemTimeClauseType int + +const ( + SystemTimeAsOf SystemTimeClauseType = iota // FOR SYSTEM_TIME AS OF + SystemTimeBetween // FOR SYSTEM_TIME BETWEEN AND + SystemTimeFromTo // FOR SYSTEM_TIME FROM TO + SystemTimeAll // FOR SYSTEM_TIME ALL +) + +// ForSystemTimeClause represents a temporal query on a system-versioned table. +// +// SELECT * FROM t FOR SYSTEM_TIME AS OF TIMESTAMP '2024-01-01'; +// SELECT * FROM t FOR SYSTEM_TIME BETWEEN '2020-01-01' AND '2024-01-01'; +// SELECT * FROM t FOR SYSTEM_TIME ALL; +type ForSystemTimeClause struct { + Type SystemTimeClauseType + Point Expression // used for AS OF + Start Expression // used for BETWEEN, FROM + End Expression // used for BETWEEN (AND), TO + Pos models.Location // Source position of the FOR keyword (1-based line and column) +} + +// expressionNode satisfies the Expression interface so ForSystemTimeClause can be +// stored in TableReference.ForSystemTime without a separate interface type. +// Semantically it is a table-level clause, not a scalar expression. +func (c *ForSystemTimeClause) expressionNode() {} +func (c ForSystemTimeClause) TokenLiteral() string { return "FOR SYSTEM_TIME" } +func (c ForSystemTimeClause) Children() []Node { + var nodes []Node + if c.Point != nil { + nodes = append(nodes, c.Point) + } + if c.Start != nil { + nodes = append(nodes, c.Start) + } + if c.End != nil { + nodes = append(nodes, c.End) + } + return nodes +} + +// TimeTravelClause represents the Snowflake time-travel / change-tracking +// modifier on a table reference: +// +// SELECT ... FROM t AT (TIMESTAMP => '2024-01-01'::TIMESTAMP) +// SELECT ... FROM t BEFORE (STATEMENT => '...uuid...') +// SELECT ... FROM t CHANGES (INFORMATION => DEFAULT) AT (...) +// +// Kind is one of "AT", "BEFORE", "CHANGES". Named holds the +// `name => expr` arguments keyed by upper-cased name (e.g. TIMESTAMP, +// OFFSET, STATEMENT, INFORMATION). Multiple clauses may chain (CHANGES +// plus AT); extra clauses are appended to Chained. 
+type TimeTravelClause struct { + Kind string // "AT" | "BEFORE" | "CHANGES" + Named map[string]Expression + Chained []*TimeTravelClause + Pos models.Location +} + +func (c *TimeTravelClause) expressionNode() {} +func (c TimeTravelClause) TokenLiteral() string { return c.Kind } +func (c TimeTravelClause) Children() []Node { + var nodes []Node + for _, v := range c.Named { + if v != nil { + nodes = append(nodes, v) + } + } + for _, ch := range c.Chained { + if ch != nil { + nodes = append(nodes, ch) + } + } + return nodes +} + +// PivotClause represents the SQL Server / Oracle PIVOT operator for row-to-column +// transformation in a FROM clause. +// +// PIVOT (SUM(sales) FOR region IN ([North], [South], [East], [West])) AS pvt +type PivotClause struct { + AggregateFunction Expression // The aggregate function, e.g. SUM(sales) + PivotColumn string // The column used for pivoting, e.g. region + InValues []string // The values to pivot on, e.g. [North], [South] + Pos models.Location // Source position of the PIVOT keyword +} + +func (p *PivotClause) expressionNode() {} +func (p PivotClause) TokenLiteral() string { return "PIVOT" } +func (p PivotClause) Children() []Node { + if p.AggregateFunction != nil { + return []Node{p.AggregateFunction} + } + return nil +} + +// UnpivotClause represents the SQL Server / Oracle UNPIVOT operator for column-to-row +// transformation in a FROM clause. +// +// UNPIVOT (sales FOR region IN (north_sales, south_sales, east_sales)) AS unpvt +type UnpivotClause struct { + ValueColumn string // The target value column, e.g. sales + NameColumn string // The target name column, e.g. region + InColumns []string // The source columns to unpivot, e.g. 
north_sales, south_sales + Pos models.Location // Source position of the UNPIVOT keyword +} + +func (u *UnpivotClause) expressionNode() {} +func (u UnpivotClause) TokenLiteral() string { return "UNPIVOT" } +func (u UnpivotClause) Children() []Node { return nil } + +// PeriodDefinition represents a PERIOD FOR clause in CREATE TABLE. +// +// PERIOD FOR app_time (start_col, end_col) +// PERIOD FOR SYSTEM_TIME (row_start, row_end) +type PeriodDefinition struct { + Name *Identifier // period name (e.g., "app_time") or SYSTEM_TIME + StartCol *Identifier + EndCol *Identifier + Pos models.Location // Source position of the PERIOD FOR keyword (1-based line and column) +} + +// MatchRecognizeClause represents the SQL:2016 MATCH_RECOGNIZE clause for +// row-pattern recognition in a FROM clause (Snowflake, Oracle, Databricks). +// +// MATCH_RECOGNIZE ( +// PARTITION BY symbol +// ORDER BY ts +// MEASURES MATCH_NUMBER() AS m +// ALL ROWS PER MATCH +// PATTERN (UP+ DOWN+) +// DEFINE UP AS price > PREV(price), DOWN AS price < PREV(price) +// ) +type MatchRecognizeClause struct { + PartitionBy []Expression + OrderBy []OrderByExpression + Measures []MeasureDef + RowsPerMatch string // "ONE ROW PER MATCH" or "ALL ROWS PER MATCH" (empty = default) + AfterMatch string // raw text: "SKIP TO NEXT ROW", "SKIP PAST LAST ROW", etc. + Pattern string // raw pattern text: "UP+ DOWN+" + Definitions []PatternDef + Pos models.Location +} + +// MeasureDef is one MEASURES entry: expr AS alias. +type MeasureDef struct { + Expr Expression + Alias string +} + +// PatternDef is one DEFINE entry: variable_name AS boolean_condition. +type PatternDef struct { + Name string + Condition Expression +} + +func (m *MatchRecognizeClause) expressionNode() {} +func (m MatchRecognizeClause) TokenLiteral() string { return "MATCH_RECOGNIZE" } +func (m MatchRecognizeClause) Children() []Node { + var nodes []Node + nodes = append(nodes, nodifyExpressions(m.PartitionBy)...) 
+ for _, ob := range m.OrderBy { + ob := ob + nodes = append(nodes, &ob) + } + for _, md := range m.Measures { + if md.Expr != nil { + nodes = append(nodes, md.Expr) + } + } + for _, pd := range m.Definitions { + if pd.Condition != nil { + nodes = append(nodes, pd.Condition) + } + } + return nodes +} + +// expressionNode satisfies the Expression interface so PeriodDefinition can be +// stored in CreateTableStatement.PeriodDefinitions without a separate interface type. +// Semantically it is a table column constraint, not a scalar expression. +func (p *PeriodDefinition) expressionNode() {} +func (p PeriodDefinition) TokenLiteral() string { return "PERIOD FOR" } +func (p PeriodDefinition) Children() []Node { + var nodes []Node + if p.Name != nil { + nodes = append(nodes, p.Name) + } + if p.StartCol != nil { + nodes = append(nodes, p.StartCol) + } + if p.EndCol != nil { + nodes = append(nodes, p.EndCol) + } + return nodes +} + +// ── MariaDB Hierarchical Query / CONNECT BY (10.2+) ─────────────────────── + +// ConnectByClause represents the CONNECT BY hierarchical query clause (MariaDB 10.2+). +// +// SELECT id, name FROM t +// START WITH parent_id IS NULL +// CONNECT BY NOCYCLE PRIOR id = parent_id; +type ConnectByClause struct { + NoCycle bool // NOCYCLE modifier — prevents loops in cyclic graphs + Condition Expression // the PRIOR expression (e.g., PRIOR id = parent_id) + Pos models.Location // Source position of the CONNECT BY keyword (1-based line and column) +} + +// expressionNode satisfies the Expression interface so ConnectByClause can be +// stored in SelectStatement.ConnectBy without a separate interface type. +// Semantically it is a query-level clause, not a scalar expression. 
+func (c *ConnectByClause) expressionNode() {} +func (c ConnectByClause) TokenLiteral() string { return "CONNECT BY" } +func (c ConnectByClause) Children() []Node { + if c.Condition != nil { + return []Node{c.Condition} + } + return nil +} + +// SampleClause represents a ClickHouse SAMPLE clause on a SELECT statement. +// +// ClickHouse supports three sampling forms: +// +// SAMPLE 0.1 — ratio (10% of data) +// SAMPLE 1000 — approximate row count +// SAMPLE 1/10 — fraction (1 part out of 10) +// SAMPLE 1/10 OFFSET 2/10 — fraction with offset +// +// The clause is dialect-specific to ClickHouse (and partly Snowflake/Redshift +// via TABLESAMPLE, but this implementation targets SAMPLE). +// Value is stored as a raw string to preserve the original representation +// (e.g., "0.1", "1000", "1/10"). +// ArrayJoinClause represents a ClickHouse ARRAY JOIN or LEFT ARRAY JOIN clause. +// Syntax: [LEFT] ARRAY JOIN expr [AS alias], expr [AS alias], ... +type ArrayJoinClause struct { + Left bool // true for LEFT ARRAY JOIN + Elements []ArrayJoinElement // One or more join elements + Pos models.Location +} + +// ArrayJoinElement is a single expression in an ARRAY JOIN clause with an optional alias. +type ArrayJoinElement struct { + Expr Expression + Alias string +} + +type SampleClause struct { + // Value is the sampling size/ratio as a raw token string (e.g., "0.1", "1000", "1/10"). + Value string + // Denominator is set when the fraction form "N/D" is used (denominator part). + Denominator string + // Offset is the optional OFFSET fraction (e.g., "2/10" in SAMPLE 1/10 OFFSET 2/10). + Offset string + // OffsetDenominator is set for fractional offsets. 
+ OffsetDenominator string + Pos models.Location +} + +func (s *SampleClause) expressionNode() {} +func (s SampleClause) TokenLiteral() string { return "SAMPLE" } +func (s SampleClause) Children() []Node { return nil } diff --git a/pkg/sql/ast/ast_expressions.go b/pkg/sql/ast/ast_expressions.go new file mode 100644 index 00000000..1c68cc4e --- /dev/null +++ b/pkg/sql/ast/ast_expressions.go @@ -0,0 +1,773 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import ( + "github.com/ajitpratap0/GoSQLX/pkg/models" +) + +// RollupExpression represents ROLLUP(col1, col2, ...) in GROUP BY clause +// ROLLUP generates hierarchical grouping sets from right to left +// Example: ROLLUP(a, b, c) generates grouping sets: +// +// (a, b, c), (a, b), (a), () +type RollupExpression struct { + Expressions []Expression +} + +func (r *RollupExpression) expressionNode() {} +func (r RollupExpression) TokenLiteral() string { return "ROLLUP" } +func (r RollupExpression) Children() []Node { return nodifyExpressions(r.Expressions) } + +// CubeExpression represents CUBE(col1, col2, ...) 
in GROUP BY clause +// CUBE generates all possible combinations of grouping sets +// Example: CUBE(a, b) generates grouping sets: +// +// (a, b), (a), (b), () +type CubeExpression struct { + Expressions []Expression +} + +func (c *CubeExpression) expressionNode() {} +func (c CubeExpression) TokenLiteral() string { return "CUBE" } +func (c CubeExpression) Children() []Node { return nodifyExpressions(c.Expressions) } + +// GroupingSetsExpression represents GROUPING SETS(...) in GROUP BY clause +// Allows explicit specification of grouping sets +// Example: GROUPING SETS((a, b), (a), ()) +type GroupingSetsExpression struct { + Sets [][]Expression // Each inner slice is one grouping set +} + +func (g *GroupingSetsExpression) expressionNode() {} +func (g GroupingSetsExpression) TokenLiteral() string { return "GROUPING SETS" } +func (g GroupingSetsExpression) Children() []Node { + children := make([]Node, 0) + for _, set := range g.Sets { + children = append(children, nodifyExpressions(set)...) + } + return children +} + +// FunctionCall represents a function call expression with full SQL-99/PostgreSQL support. +// +// FunctionCall supports: +// - Scalar functions: UPPER(), LOWER(), COALESCE(), etc. +// - Aggregate functions: COUNT(), SUM(), AVG(), MAX(), MIN(), etc. +// - Window functions: ROW_NUMBER(), RANK(), DENSE_RANK(), LAG(), LEAD(), etc. 
+// - DISTINCT modifier: COUNT(DISTINCT column) +// - FILTER clause: Conditional aggregation (PostgreSQL v1.6.0) +// - ORDER BY clause: For order-sensitive aggregates like STRING_AGG, ARRAY_AGG (v1.6.0) +// - OVER clause: Window specifications for window functions +// +// Fields: +// - Name: Function name (e.g., "COUNT", "SUM", "ROW_NUMBER") +// - Arguments: Function arguments (expressions) +// - Over: Window specification for window functions (OVER clause) +// - Distinct: DISTINCT modifier for aggregates (COUNT(DISTINCT col)) +// - Filter: FILTER clause for conditional aggregation (PostgreSQL v1.6.0) +// - OrderBy: ORDER BY clause for order-sensitive aggregates (v1.6.0) +// +// Example - Basic aggregate: +// +// FunctionCall{ +// Name: "COUNT", +// Arguments: []Expression{&Identifier{Name: "id"}}, +// } +// // SQL: COUNT(id) +// +// Example - Window function: +// +// FunctionCall{ +// Name: "ROW_NUMBER", +// Over: &WindowSpec{ +// PartitionBy: []Expression{&Identifier{Name: "dept_id"}}, +// OrderBy: []OrderByExpression{{Expression: &Identifier{Name: "salary"}, Ascending: false}}, +// }, +// } +// // SQL: ROW_NUMBER() OVER (PARTITION BY dept_id ORDER BY salary DESC) +// +// Example - FILTER clause (PostgreSQL v1.6.0): +// +// FunctionCall{ +// Name: "COUNT", +// Arguments: []Expression{&Identifier{Name: "id"}}, +// Filter: &BinaryExpression{Left: &Identifier{Name: "status"}, Operator: "=", Right: &LiteralValue{Value: "active"}}, +// } +// // SQL: COUNT(id) FILTER (WHERE status = 'active') +// +// Example - ORDER BY in aggregate (PostgreSQL v1.6.0): +// +// FunctionCall{ +// Name: "STRING_AGG", +// Arguments: []Expression{&Identifier{Name: "name"}, &LiteralValue{Value: ", "}}, +// OrderBy: []OrderByExpression{{Expression: &Identifier{Name: "name"}, Ascending: true}}, +// } +// // SQL: STRING_AGG(name, ', ' ORDER BY name) +// +// Example - Window function with frame: +// +// FunctionCall{ +// Name: "AVG", +// Arguments: []Expression{&Identifier{Name: "amount"}}, +// 
Over: &WindowSpec{ +// OrderBy: []OrderByExpression{{Expression: &Identifier{Name: "date"}, Ascending: true}}, +// FrameClause: &WindowFrame{ +// Type: "ROWS", +// Start: WindowFrameBound{Type: "2 PRECEDING"}, +// End: &WindowFrameBound{Type: "CURRENT ROW"}, +// }, +// }, +// } +// // SQL: AVG(amount) OVER (ORDER BY date ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) +// +// New in v1.6.0: +// - Filter: FILTER clause for conditional aggregation +// - OrderBy: ORDER BY clause for order-sensitive aggregates (STRING_AGG, ARRAY_AGG, etc.) +// - WithinGroup: ORDER BY clause for ordered-set aggregates (PERCENTILE_CONT, PERCENTILE_DISC, MODE, etc.) +type FunctionCall struct { + Name string + Arguments []Expression // Renamed from Args for consistency + Parameters []Expression // ClickHouse parametric aggregates: quantile(0.5)(x) — params before args + Over *WindowSpec // For window functions + Distinct bool + Filter Expression // WHERE clause for aggregate functions + OrderBy []OrderByExpression // ORDER BY clause for aggregate functions (STRING_AGG, ARRAY_AGG, etc.) + WithinGroup []OrderByExpression // ORDER BY clause for ordered-set aggregates (PERCENTILE_CONT, etc.) + NullTreatment string // "IGNORE NULLS" or "RESPECT NULLS" on window functions (Snowflake, Oracle, BigQuery, SQL:2016) + Pos models.Location // Source position of the function name (1-based line and column) +} + +func (f *FunctionCall) expressionNode() {} +func (f FunctionCall) TokenLiteral() string { return f.Name } +func (f FunctionCall) Children() []Node { + children := nodifyExpressions(f.Arguments) + if len(f.Parameters) > 0 { + children = append(children, nodifyExpressions(f.Parameters)...) 
+ } + if f.Over != nil { + children = append(children, f.Over) + } + if f.Filter != nil { + children = append(children, f.Filter) + } + for _, orderBy := range f.OrderBy { + orderBy := orderBy // G601: Create local copy to avoid memory aliasing + children = append(children, &orderBy) + } + for _, orderBy := range f.WithinGroup { + orderBy := orderBy // G601: Create local copy to avoid memory aliasing + children = append(children, &orderBy) + } + return children +} + +// CaseExpression represents a CASE expression +type CaseExpression struct { + Value Expression // Optional CASE value + WhenClauses []WhenClause + ElseClause Expression + Pos models.Location // Source position of the CASE keyword (1-based line and column) +} + +func (c *CaseExpression) expressionNode() {} +func (c CaseExpression) TokenLiteral() string { return "CASE" } +func (c CaseExpression) Children() []Node { + children := make([]Node, 0) + if c.Value != nil { + children = append(children, c.Value) + } + for _, when := range c.WhenClauses { + when := when // G601: Create local copy to avoid memory aliasing + children = append(children, &when) + } + if c.ElseClause != nil { + children = append(children, c.ElseClause) + } + return children +} + +// WhenClause represents WHEN ... THEN ... 
in CASE expression +type WhenClause struct { + Condition Expression + Result Expression + Pos models.Location // Source position of the WHEN keyword (1-based line and column) +} + +func (w *WhenClause) expressionNode() {} +func (w WhenClause) TokenLiteral() string { return "WHEN" } +func (w WhenClause) Children() []Node { + var nodes []Node + if w.Condition != nil { + nodes = append(nodes, w.Condition) + } + if w.Result != nil { + nodes = append(nodes, w.Result) + } + return nodes +} + +// ExistsExpression represents EXISTS (subquery) +type ExistsExpression struct { + Subquery Statement +} + +func (e *ExistsExpression) expressionNode() {} +func (e ExistsExpression) TokenLiteral() string { return "EXISTS" } +func (e ExistsExpression) Children() []Node { + if e.Subquery == nil { + return nil + } + return []Node{e.Subquery} +} + +// InExpression represents expr IN (values) or expr IN (subquery) +type InExpression struct { + Expr Expression + List []Expression // For value list: IN (1, 2, 3) + Subquery Statement // For subquery: IN (SELECT ...) + Not bool + Pos models.Location // Source position of the IN keyword (1-based line and column) +} + +func (i *InExpression) expressionNode() {} +func (i InExpression) TokenLiteral() string { return "IN" } +func (i InExpression) Children() []Node { + var children []Node + if i.Expr != nil { + children = append(children, i.Expr) + } + if i.Subquery != nil { + children = append(children, i.Subquery) + } + children = append(children, nodifyExpressions(i.List)...) + return children +} + +// SubqueryExpression represents a scalar subquery (SELECT ...) 
+type SubqueryExpression struct { + Subquery Statement + Pos models.Location // Source position of the opening parenthesis (1-based line and column) +} + +func (s *SubqueryExpression) expressionNode() {} +func (s SubqueryExpression) TokenLiteral() string { return "SUBQUERY" } +func (s SubqueryExpression) Children() []Node { + if s.Subquery == nil { + return nil + } + return []Node{s.Subquery} +} + +// AnyExpression represents expr op ANY (subquery) +type AnyExpression struct { + Expr Expression + Operator string + Subquery Statement +} + +func (a *AnyExpression) expressionNode() {} +func (a AnyExpression) TokenLiteral() string { return "ANY" } +func (a AnyExpression) Children() []Node { + var nodes []Node + if a.Expr != nil { + nodes = append(nodes, a.Expr) + } + if a.Subquery != nil { + nodes = append(nodes, a.Subquery) + } + return nodes +} + +// AllExpression represents expr op ALL (subquery) +type AllExpression struct { + Expr Expression + Operator string + Subquery Statement +} + +func (al *AllExpression) expressionNode() {} +func (al AllExpression) TokenLiteral() string { return "ALL" } +func (al AllExpression) Children() []Node { + var nodes []Node + if al.Expr != nil { + nodes = append(nodes, al.Expr) + } + if al.Subquery != nil { + nodes = append(nodes, al.Subquery) + } + return nodes +} + +// BetweenExpression represents expr BETWEEN lower AND upper +type BetweenExpression struct { + Expr Expression + Lower Expression + Upper Expression + Not bool + Pos models.Location // Source position of the BETWEEN keyword (1-based line and column) +} + +func (b *BetweenExpression) expressionNode() {} +func (b BetweenExpression) TokenLiteral() string { return "BETWEEN" } +func (b BetweenExpression) Children() []Node { + var nodes []Node + if b.Expr != nil { + nodes = append(nodes, b.Expr) + } + if b.Lower != nil { + nodes = append(nodes, b.Lower) + } + if b.Upper != nil { + nodes = append(nodes, b.Upper) + } + return nodes +} + +// BinaryExpression represents binary 
operations between two expressions. +// +// BinaryExpression supports all standard SQL binary operators plus PostgreSQL-specific +// operators including JSON/JSONB operators added in v1.6.0. +// +// Fields: +// - Left: Left-hand side expression +// - Operator: Binary operator (=, <, >, +, -, *, /, AND, OR, ->, #>, etc.) +// - Right: Right-hand side expression +// - Not: NOT modifier for negation (NOT expr) +// - CustomOp: PostgreSQL custom operators (OPERATOR(schema.name)) +// +// Supported Operator Categories: +// - Comparison: =, <>, <, >, <=, >=, <=> (spaceship) +// - Arithmetic: +, -, *, /, %, DIV, // (integer division) +// - Logical: AND, OR, XOR +// - String: || (concatenation) +// - Bitwise: &, |, ^, <<, >> (shifts) +// - Pattern: LIKE, ILIKE, SIMILAR TO +// - Range: OVERLAPS +// - PostgreSQL JSON/JSONB (v1.6.0): ->, ->>, #>, #>>, @>, <@, ?, ?|, ?&, #- +// +// Example - Basic comparison: +// +// BinaryExpression{ +// Left: &Identifier{Name: "age"}, +// Operator: ">", +// Right: &LiteralValue{Value: 18, Type: "INTEGER"}, +// } +// // SQL: age > 18 +// +// Example - Logical AND: +// +// BinaryExpression{ +// Left: &BinaryExpression{ +// Left: &Identifier{Name: "active"}, +// Operator: "=", +// Right: &LiteralValue{Value: true, Type: "BOOLEAN"}, +// }, +// Operator: "AND", +// Right: &BinaryExpression{ +// Left: &Identifier{Name: "status"}, +// Operator: "=", +// Right: &LiteralValue{Value: "pending", Type: "STRING"}, +// }, +// } +// // SQL: active = true AND status = 'pending' +// +// Example - PostgreSQL JSON operator -> (v1.6.0): +// +// BinaryExpression{ +// Left: &Identifier{Name: "data"}, +// Operator: "->", +// Right: &LiteralValue{Value: "name", Type: "STRING"}, +// } +// // SQL: data->'name' +// +// Example - PostgreSQL JSON operator ->> (v1.6.0): +// +// BinaryExpression{ +// Left: &Identifier{Name: "data"}, +// Operator: "->>", +// Right: &LiteralValue{Value: "email", Type: "STRING"}, +// } +// // SQL: data->>'email' (returns text) +// +// Example - 
PostgreSQL JSON contains @> (v1.6.0): +// +// BinaryExpression{ +// Left: &Identifier{Name: "attributes"}, +// Operator: "@>", +// Right: &LiteralValue{Value: `{"color": "red"}`, Type: "STRING"}, +// } +// // SQL: attributes @> '{"color": "red"}' +// +// Example - PostgreSQL JSON key exists ? (v1.6.0): +// +// BinaryExpression{ +// Left: &Identifier{Name: "profile"}, +// Operator: "?", +// Right: &LiteralValue{Value: "email", Type: "STRING"}, +// } +// // SQL: profile ? 'email' +// +// Example - Custom PostgreSQL operator: +// +// BinaryExpression{ +// Left: &Identifier{Name: "point1"}, +// Operator: "", +// Right: &Identifier{Name: "point2"}, +// CustomOp: &CustomBinaryOperator{Parts: []string{"pg_catalog", "<->"}}, +// } +// // SQL: point1 OPERATOR(pg_catalog.<->) point2 +// +// New in v1.6.0: +// - JSON/JSONB operators: ->, ->>, #>, #>>, @>, <@, ?, ?|, ?&, #- +// - CustomOp field for PostgreSQL custom operators +// +// PostgreSQL JSON/JSONB Operator Reference: +// - -> (Arrow): Extract JSON field or array element (returns JSON) +// - ->> (LongArrow): Extract JSON field or array element as text +// - #> (HashArrow): Extract JSON at path (returns JSON) +// - #>> (HashLongArrow): Extract JSON at path as text +// - @> (AtArrow): JSON contains (does left JSON contain right?) +// - <@ (ArrowAt): JSON is contained by (is left JSON contained in right?) +// - ? 
(Question): JSON key exists +// - ?| (QuestionPipe): Any of the keys exist +// - ?& (QuestionAnd): All of the keys exist +// - #- (HashMinus): Delete key from JSON +type BinaryExpression struct { + Left Expression + Operator string + Right Expression + Not bool // For NOT (expr) + CustomOp *CustomBinaryOperator // For PostgreSQL custom operators + Pos models.Location // Source position of the operator (1-based line and column) +} + +func (b *BinaryExpression) expressionNode() {} + +func (b *BinaryExpression) TokenLiteral() string { + if b.CustomOp != nil { + return b.CustomOp.String() + } + return b.Operator +} + +func (b BinaryExpression) Children() []Node { + var nodes []Node + if b.Left != nil { + nodes = append(nodes, b.Left) + } + if b.Right != nil { + nodes = append(nodes, b.Right) + } + return nodes +} + +// ListExpression represents a list of expressions (1, 2, 3) +type ListExpression struct { + Values []Expression +} + +func (l *ListExpression) expressionNode() {} +func (l ListExpression) TokenLiteral() string { return "LIST" } +func (l ListExpression) Children() []Node { return nodifyExpressions(l.Values) } + +// TupleExpression represents a row constructor / tuple (col1, col2) for multi-column comparisons +// Used in: WHERE (user_id, status) IN ((1, 'active'), (2, 'pending')) +type TupleExpression struct { + Expressions []Expression +} + +func (t *TupleExpression) expressionNode() {} +func (t TupleExpression) TokenLiteral() string { return "TUPLE" } +func (t TupleExpression) Children() []Node { return nodifyExpressions(t.Expressions) } + +// ArrayConstructorExpression represents PostgreSQL ARRAY constructor syntax. +// Creates an array value from a list of expressions or a subquery. +// +// Examples: +// +// ARRAY[1, 2, 3] - Integer array literal +// ARRAY['admin', 'moderator'] - Text array literal +// ARRAY(SELECT id FROM users) - Array from subquery +type ArrayConstructorExpression struct { + Elements []Expression // Elements inside ARRAY[...] 
+ Subquery *SelectStatement // For ARRAY(SELECT ...) syntax (optional) +} + +func (a *ArrayConstructorExpression) expressionNode() {} +func (a ArrayConstructorExpression) TokenLiteral() string { return "ARRAY" } +func (a ArrayConstructorExpression) Children() []Node { + if a.Subquery != nil { + return []Node{a.Subquery} + } + return nodifyExpressions(a.Elements) +} + +// UnaryExpression represents operations like NOT expr +type UnaryExpression struct { + Operator UnaryOperator + Expr Expression + Pos models.Location // Source position of the operator (1-based line and column) +} + +func (u *UnaryExpression) expressionNode() {} + +func (u *UnaryExpression) TokenLiteral() string { + return u.Operator.String() +} + +func (u UnaryExpression) Children() []Node { + if u.Expr == nil { + return nil + } + return []Node{u.Expr} +} + +// VariantPath represents a Snowflake VARIANT path expression: +// +// col:field.sub[0]::string +// +// The Root is the base expression (typically an Identifier or FunctionCall +// like PARSE_JSON(raw)). Segments is the chain of path steps that follow +// the leading `:`. Each segment is either a field name (Name set) or a +// bracketed index expression (Index set). +type VariantPath struct { + Root Expression + Segments []VariantPathSegment + Pos models.Location +} + +// VariantPathSegment is one step in a VARIANT path: either a field name +// reached via `:` or `.`, or a bracketed index expression. 
+type VariantPathSegment struct { + Name string // field name (`:field` or `.field`), empty when Index is set + Index Expression // bracket subscript (`[expr]`), nil when Name is set +} + +func (v *VariantPath) expressionNode() {} +func (v VariantPath) TokenLiteral() string { return ":" } +func (v VariantPath) Children() []Node { + var nodes []Node + if v.Root != nil { + nodes = append(nodes, v.Root) + } + for _, seg := range v.Segments { + if seg.Index != nil { + nodes = append(nodes, seg.Index) + } + } + return nodes +} + +// NamedArgument represents a function argument of the form `name => expr`, +// used by Snowflake (FLATTEN(input => col), GENERATOR(rowcount => 100)), +// BigQuery, Oracle, and PostgreSQL procedural calls. +type NamedArgument struct { + Name string + Value Expression + Pos models.Location +} + +func (n *NamedArgument) expressionNode() {} +func (n NamedArgument) TokenLiteral() string { return n.Name } +func (n NamedArgument) Children() []Node { + if n.Value == nil { + return nil + } + return []Node{n.Value} +} + +// CastExpression represents CAST(expr AS type) or TRY_CAST(expr AS type). +// Try is set when the expression originated from TRY_CAST (Snowflake / SQL +// Server / BigQuery), which returns NULL on conversion failure instead of +// raising an error. 
+type CastExpression struct { + Expr Expression + Type string + Try bool +} + +func (c *CastExpression) expressionNode() {} +func (c CastExpression) TokenLiteral() string { + if c.Try { + return "TRY_CAST" + } + return "CAST" +} +func (c CastExpression) Children() []Node { + if c.Expr == nil { + return nil + } + return []Node{c.Expr} +} + +// AliasedExpression represents an expression with an alias (expr AS alias) +type AliasedExpression struct { + Expr Expression + Alias string +} + +func (a *AliasedExpression) expressionNode() {} +func (a AliasedExpression) TokenLiteral() string { return a.Alias } +func (a AliasedExpression) Children() []Node { + if a.Expr == nil { + return nil + } + return []Node{a.Expr} +} + +// ExtractExpression represents EXTRACT(field FROM source) +type ExtractExpression struct { + Field string + Source Expression +} + +func (e *ExtractExpression) expressionNode() {} +func (e ExtractExpression) TokenLiteral() string { return "EXTRACT" } +func (e ExtractExpression) Children() []Node { + if e.Source == nil { + return nil + } + return []Node{e.Source} +} + +// PositionExpression represents POSITION(substr IN str) +type PositionExpression struct { + Substr Expression + Str Expression +} + +func (p *PositionExpression) expressionNode() {} +func (p PositionExpression) TokenLiteral() string { return "POSITION" } +func (p PositionExpression) Children() []Node { + var nodes []Node + if p.Substr != nil { + nodes = append(nodes, p.Substr) + } + if p.Str != nil { + nodes = append(nodes, p.Str) + } + return nodes +} + +// SubstringExpression represents SUBSTRING(str FROM start [FOR length]) +type SubstringExpression struct { + Str Expression + Start Expression + Length Expression +} + +func (s *SubstringExpression) expressionNode() {} +func (s SubstringExpression) TokenLiteral() string { return "SUBSTRING" } +func (s SubstringExpression) Children() []Node { + children := []Node{s.Str, s.Start} + if s.Length != nil { + children = append(children, 
s.Length) + } + return children +} + +// IntervalExpression represents INTERVAL 'value' for date/time arithmetic +// Examples: INTERVAL '1 day', INTERVAL '2 hours', INTERVAL '1 year 2 months' +type IntervalExpression struct { + Value string // The interval specification string (e.g., '1 day', '2 hours') +} + +func (i *IntervalExpression) expressionNode() {} +func (i IntervalExpression) TokenLiteral() string { return "INTERVAL" } + +// Children implements Node. IntervalExpression stores its value as a raw +// string (not an Expression), so it has no child nodes. Returns nil for +// consistency with other leaf nodes. +func (i IntervalExpression) Children() []Node { return nil } + +// ArraySubscriptExpression represents array element access syntax. +// Supports single and multi-dimensional array subscripting. +// +// Examples: +// +// tags[1] - Single subscript +// matrix[2][3] - Multi-dimensional subscript +// arr[i] - Subscript with variable +// (SELECT arr)[1] - Subscript on subquery result +type ArraySubscriptExpression struct { + Array Expression // The array expression being subscripted + Indices []Expression // Subscript indices (one or more for multi-dimensional arrays) +} + +func (a *ArraySubscriptExpression) expressionNode() {} +func (a ArraySubscriptExpression) TokenLiteral() string { return "[]" } +func (a ArraySubscriptExpression) Children() []Node { + var children []Node + if a.Array != nil { + children = append(children, a.Array) + } + for _, idx := range a.Indices { + if idx != nil { + children = append(children, idx) + } + } + return children +} + +// ArraySliceExpression represents array slicing syntax for extracting subarrays. +// Supports PostgreSQL-style array slicing with optional start/end bounds. 
+// +// Examples: +// +// arr[1:3] - Slice from index 1 to 3 (inclusive) +// arr[2:] - Slice from index 2 to end +// arr[:5] - Slice from start to index 5 +// arr[:] - Full array slice (copy) +type ArraySliceExpression struct { + Array Expression // The array expression being sliced + Start Expression // Start index (nil means from beginning) + End Expression // End index (nil means to end) +} + +func (a *ArraySliceExpression) expressionNode() {} +func (a ArraySliceExpression) TokenLiteral() string { return "[:]" } +func (a ArraySliceExpression) Children() []Node { + var children []Node + if a.Array != nil { + children = append(children, a.Array) + } + if a.Start != nil { + children = append(children, a.Start) + } + if a.End != nil { + children = append(children, a.End) + } + return children +} + +// UpdateExpression represents a column=value expression in UPDATE +type UpdateExpression struct { + Column Expression + Value Expression +} + +func (u *UpdateExpression) expressionNode() {} +func (u UpdateExpression) TokenLiteral() string { return "=" } +func (u UpdateExpression) Children() []Node { + var nodes []Node + if u.Column != nil { + nodes = append(nodes, u.Column) + } + if u.Value != nil { + nodes = append(nodes, u.Value) + } + return nodes +} diff --git a/pkg/sql/ast/ast_statements.go b/pkg/sql/ast/ast_statements.go new file mode 100644 index 00000000..7f5e893c --- /dev/null +++ b/pkg/sql/ast/ast_statements.go @@ -0,0 +1,778 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import ( + "github.com/ajitpratap0/GoSQLX/pkg/models" +) + +// QueryExpression is a Statement that can appear as the source of INSERT ... SELECT. +// Only *SelectStatement and *SetOperation satisfy this interface. +type QueryExpression interface { + Statement + queryExpressionNode() +} + +// SetOperation represents set operations (UNION, EXCEPT, INTERSECT) between two statements. +// It supports the ALL modifier (e.g., UNION ALL) and proper left-associative parsing. +// Phase 2 Complete: Full parser support with left-associative precedence. +type SetOperation struct { + Left Statement + Operator string // UNION, EXCEPT, INTERSECT + Right Statement + All bool // UNION ALL vs UNION +} + +func (s *SetOperation) statementNode() {} +func (s *SetOperation) queryExpressionNode() {} +func (s SetOperation) TokenLiteral() string { return s.Operator } +func (s SetOperation) Children() []Node { + var nodes []Node + if s.Left != nil { + nodes = append(nodes, s.Left) + } + if s.Right != nil { + nodes = append(nodes, s.Right) + } + return nodes +} + +// SelectStatement represents a SELECT SQL statement with full SQL-99/SQL:2003 support. +// +// SelectStatement is the primary query statement type supporting: +// - CTEs (WITH clause) +// - DISTINCT and DISTINCT ON (PostgreSQL) +// - Multiple FROM tables and subqueries +// - All JOIN types with LATERAL support +// - WHERE, GROUP BY, HAVING, ORDER BY clauses +// - Window functions with PARTITION BY and frame specifications +// - LIMIT/OFFSET and SQL-99 FETCH clause +// +// Fields: +// - With: WITH clause for Common Table Expressions (CTEs) +// - Distinct: DISTINCT keyword for duplicate elimination +// - DistinctOnColumns: DISTINCT ON (expr, ...) for PostgreSQL (v1.6.0) +// - Columns: SELECT list expressions (columns, *, functions, etc.) 
+// - From: FROM clause table references (tables, subqueries, LATERAL) +// - TableName: Table name for simple queries (pool optimization) +// - Joins: JOIN clauses (INNER, LEFT, RIGHT, FULL, CROSS, NATURAL) +// - Where: WHERE clause filter condition +// - GroupBy: GROUP BY expressions (including ROLLUP, CUBE, GROUPING SETS) +// - Having: HAVING clause filter condition +// - Windows: Window specifications (WINDOW clause) +// - OrderBy: ORDER BY expressions with NULLS FIRST/LAST +// - Limit: LIMIT clause (number of rows) +// - Offset: OFFSET clause (skip rows) +// - Fetch: SQL-99 FETCH FIRST/NEXT clause (v1.6.0) +// +// Example - Basic SELECT: +// +// SelectStatement{ +// Columns: []Expression{&Identifier{Name: "id"}, &Identifier{Name: "name"}}, +// From: []TableReference{{Name: "users"}}, +// Where: &BinaryExpression{...}, +// } +// // SQL: SELECT id, name FROM users WHERE ... +// +// Example - DISTINCT ON (PostgreSQL v1.6.0): +// +// SelectStatement{ +// DistinctOnColumns: []Expression{&Identifier{Name: "dept_id"}}, +// Columns: []Expression{&Identifier{Name: "dept_id"}, &Identifier{Name: "name"}}, +// From: []TableReference{{Name: "employees"}}, +// } +// // SQL: SELECT DISTINCT ON (dept_id) dept_id, name FROM employees +// +// Example - Window function with FETCH (v1.6.0): +// +// SelectStatement{ +// Columns: []Expression{ +// &FunctionCall{ +// Name: "ROW_NUMBER", +// Over: &WindowSpec{ +// OrderBy: []OrderByExpression{{Expression: &Identifier{Name: "salary"}, Ascending: false}}, +// }, +// }, +// }, +// From: []TableReference{{Name: "employees"}}, +// Fetch: &FetchClause{FetchValue: ptrInt64(10), FetchType: "FIRST"}, +// } +// // SQL: SELECT ROW_NUMBER() OVER (ORDER BY salary DESC) FROM employees FETCH FIRST 10 ROWS ONLY +// +// New in v1.6.0: +// - DistinctOnColumns for PostgreSQL DISTINCT ON +// - Fetch for SQL-99 FETCH FIRST/NEXT clause +// - Enhanced LATERAL JOIN support via TableReference.Lateral +// - FILTER clause support via FunctionCall.Filter +type 
SelectStatement struct { + With *WithClause + Distinct bool + DistinctOnColumns []Expression // PostgreSQL DISTINCT ON (expr, ...) clause + Top *TopClause // SQL Server TOP N [PERCENT] clause + Columns []Expression + From []TableReference + TableName string // Added for pool operations + Joins []JoinClause + ArrayJoin *ArrayJoinClause // ClickHouse ARRAY JOIN / LEFT ARRAY JOIN clause + PrewhereClause Expression // ClickHouse PREWHERE clause (applied before WHERE, before reading data) + Sample *SampleClause // ClickHouse SAMPLE clause (comes after FROM/FINAL, before PREWHERE) + Where Expression + GroupBy []Expression + Having Expression + Qualify Expression // Snowflake / BigQuery QUALIFY clause (filters after window functions) + // StartWith is the optional seed condition for CONNECT BY (MariaDB 10.2+). + // Example: START WITH parent_id IS NULL + StartWith Expression // MariaDB hierarchical query seed + // ConnectBy holds the hierarchy traversal condition (MariaDB 10.2+). + // Example: CONNECT BY PRIOR id = parent_id + ConnectBy *ConnectByClause // MariaDB hierarchical query + Windows []WindowSpec + OrderBy []OrderByExpression + Limit *int + Offset *int + Fetch *FetchClause // SQL-99 FETCH FIRST/NEXT clause (F861, F862) + For *ForClause // Row-level locking clause (SQL:2003, PostgreSQL, MySQL) + Pos models.Location // Source position of the SELECT keyword (1-based line and column) +} + +func (s *SelectStatement) statementNode() {} +func (s *SelectStatement) queryExpressionNode() {} +func (s SelectStatement) TokenLiteral() string { return "SELECT" } + +func (s SelectStatement) Children() []Node { + children := make([]Node, 0) + if s.With != nil { + children = append(children, s.With) + } + children = append(children, nodifyExpressions(s.DistinctOnColumns)...) + children = append(children, nodifyExpressions(s.Columns)...) 
+ for _, from := range s.From { + from := from // G601: Create local copy to avoid memory aliasing + children = append(children, &from) + } + for _, join := range s.Joins { + join := join // G601: Create local copy to avoid memory aliasing + children = append(children, &join) + } + if s.Sample != nil { + children = append(children, s.Sample) + } + if s.PrewhereClause != nil { + children = append(children, s.PrewhereClause) + } + if s.Where != nil { + children = append(children, s.Where) + } + children = append(children, nodifyExpressions(s.GroupBy)...) + if s.Having != nil { + children = append(children, s.Having) + } + if s.Qualify != nil { + children = append(children, s.Qualify) + } + for _, window := range s.Windows { + window := window // G601: Create local copy to avoid memory aliasing + children = append(children, &window) + } + for _, orderBy := range s.OrderBy { + orderBy := orderBy // G601: Create local copy to avoid memory aliasing + children = append(children, &orderBy) + } + if s.Fetch != nil { + children = append(children, s.Fetch) + } + if s.For != nil { + children = append(children, s.For) + } + if s.StartWith != nil { + children = append(children, s.StartWith) + } + if s.ConnectBy != nil { + children = append(children, s.ConnectBy) + } + return children +} + +// InsertStatement represents an INSERT SQL statement +type InsertStatement struct { + With *WithClause + TableName string + Columns []Expression + Output []Expression // SQL Server OUTPUT clause columns + Values [][]Expression // Multi-row support: each inner slice is one row of values + Query QueryExpression // For INSERT ... 
SELECT (SelectStatement or SetOperation) + Returning []Expression + OnConflict *OnConflict + OnDuplicateKey *UpsertClause // MySQL: ON DUPLICATE KEY UPDATE + Pos models.Location // Source position of the INSERT keyword (1-based line and column) +} + +func (i *InsertStatement) statementNode() {} +func (i InsertStatement) TokenLiteral() string { return "INSERT" } + +func (i InsertStatement) Children() []Node { + children := make([]Node, 0) + if i.With != nil { + children = append(children, i.With) + } + children = append(children, nodifyExpressions(i.Columns)...) + children = append(children, nodifyExpressions(i.Output)...) + // Flatten multi-row values for Children() + for _, row := range i.Values { + children = append(children, nodifyExpressions(row)...) + } + if i.Query != nil { + children = append(children, i.Query) + } + children = append(children, nodifyExpressions(i.Returning)...) + if i.OnConflict != nil { + children = append(children, i.OnConflict) + } + if i.OnDuplicateKey != nil { + children = append(children, i.OnDuplicateKey) + } + return children +} + +// Values represents VALUES clause +type Values struct { + Rows [][]Expression +} + +func (v *Values) statementNode() {} +func (v Values) TokenLiteral() string { return "VALUES" } +func (v Values) Children() []Node { + children := make([]Node, 0) + for _, row := range v.Rows { + children = append(children, nodifyExpressions(row)...) + } + return children +} + +// UpdateStatement represents an UPDATE SQL statement +type UpdateStatement struct { + With *WithClause + TableName string + Alias string + Assignments []UpdateExpression // SET clause assignments + From []TableReference + Where Expression + Returning []Expression + Pos models.Location // Source position of the UPDATE keyword (1-based line and column) +} + +// GetUpdates returns Assignments for backward compatibility. +// +// Deprecated: Use Assignments directly instead. 
+func (u *UpdateStatement) GetUpdates() []UpdateExpression { + return u.Assignments +} + +func (u *UpdateStatement) statementNode() {} +func (u UpdateStatement) TokenLiteral() string { return "UPDATE" } + +func (u UpdateStatement) Children() []Node { + children := make([]Node, 0) + if u.With != nil { + children = append(children, u.With) + } + for _, assignment := range u.Assignments { + assignment := assignment // G601: Create local copy to avoid memory aliasing + children = append(children, &assignment) + } + for _, from := range u.From { + from := from // G601: Create local copy to avoid memory aliasing + children = append(children, &from) + } + if u.Where != nil { + children = append(children, u.Where) + } + children = append(children, nodifyExpressions(u.Returning)...) + return children +} + +// CreateTableStatement represents a CREATE TABLE statement +type CreateTableStatement struct { + IfNotExists bool + Temporary bool + Name string + Columns []ColumnDef + Constraints []TableConstraint + Inherits []string + PartitionBy *PartitionBy + Partitions []PartitionDefinition // Individual partition definitions + Options []TableOption + WithoutRowID bool // SQLite: CREATE TABLE ... WITHOUT ROWID + + // WithSystemVersioning enables system-versioned temporal history (MariaDB 10.3.4+). + // Example: CREATE TABLE t (...) WITH SYSTEM VERSIONING + WithSystemVersioning bool + + // PeriodDefinitions holds PERIOD FOR clauses for application-time or system-time periods. 
+ // Example: PERIOD FOR app_time (start_col, end_col) + PeriodDefinitions []*PeriodDefinition +} + +func (c *CreateTableStatement) statementNode() {} +func (c CreateTableStatement) TokenLiteral() string { return "CREATE TABLE" } +func (c CreateTableStatement) Children() []Node { + children := make([]Node, 0) + for _, col := range c.Columns { + col := col // G601: Create local copy to avoid memory aliasing + children = append(children, &col) + } + for _, constraint := range c.Constraints { + constraint := constraint // G601: Create local copy to avoid memory aliasing + children = append(children, &constraint) + } + if c.PartitionBy != nil { + children = append(children, c.PartitionBy) + } + for _, p := range c.Partitions { + p := p // G601: Create local copy + children = append(children, &p) + } + return children +} + +// DeleteStatement represents a DELETE SQL statement +type DeleteStatement struct { + With *WithClause + TableName string + Alias string + Using []TableReference + Where Expression + Returning []Expression + Pos models.Location // Source position of the DELETE keyword (1-based line and column) +} + +func (d *DeleteStatement) statementNode() {} +func (d DeleteStatement) TokenLiteral() string { return "DELETE" } + +func (d DeleteStatement) Children() []Node { + children := make([]Node, 0) + if d.With != nil { + children = append(children, d.With) + } + for _, using := range d.Using { + using := using // G601: Create local copy to avoid memory aliasing + children = append(children, &using) + } + if d.Where != nil { + children = append(children, d.Where) + } + children = append(children, nodifyExpressions(d.Returning)...) + return children +} + +// AlterTableStatement represents an ALTER TABLE statement. +// +// # Maintenance note +// +// AlterTableStatement is NOT produced by the parser. Parser.Parse* methods +// return [AlterStatement] (defined in alter.go) with Type == AlterTypeTable. 
+// AlterTableStatement is retained only so that existing code that constructs +// it directly (e.g. in tests or manual AST construction) continues to compile. +// +// Migration guide - prefer AlterStatement for all new code: +// +// // Wrong (type assertion will never succeed at runtime): +// stmt := tree.Statements[0].(*ast.AlterTableStatement) +// +// // Correct: +// stmt := tree.Statements[0].(*ast.AlterStatement) +// tableName := stmt.Name // AlterStatement.Name holds the table name +type AlterTableStatement struct { + Table string + Actions []AlterTableAction +} + +func (a *AlterTableStatement) statementNode() {} +func (a AlterTableStatement) TokenLiteral() string { return "ALTER TABLE" } +func (a AlterTableStatement) Children() []Node { + children := make([]Node, len(a.Actions)) + for i, action := range a.Actions { + action := action // G601: Create local copy to avoid memory aliasing + children[i] = &action + } + return children +} + +// AlterTableAction represents an action in ALTER TABLE +type AlterTableAction struct { + Type string // ADD COLUMN, DROP COLUMN, MODIFY COLUMN, etc. 
+ ColumnName string + ColumnDef *ColumnDef + Constraint *TableConstraint +} + +func (a *AlterTableAction) expressionNode() {} +func (a AlterTableAction) TokenLiteral() string { return a.Type } +func (a AlterTableAction) Children() []Node { + children := make([]Node, 0) + if a.ColumnDef != nil { + children = append(children, a.ColumnDef) + } + if a.Constraint != nil { + children = append(children, a.Constraint) + } + return children +} + +// CreateIndexStatement represents a CREATE INDEX statement +type CreateIndexStatement struct { + Unique bool + IfNotExists bool + Name string + Table string + Columns []IndexColumn + Using string + Where Expression +} + +func (c *CreateIndexStatement) statementNode() {} +func (c CreateIndexStatement) TokenLiteral() string { return "CREATE INDEX" } +func (c CreateIndexStatement) Children() []Node { + children := make([]Node, 0) + for _, col := range c.Columns { + col := col // G601: Create local copy to avoid memory aliasing + children = append(children, &col) + } + if c.Where != nil { + children = append(children, c.Where) + } + return children +} + +// MergeStatement represents a MERGE statement (SQL:2003 F312) +// Syntax: MERGE INTO target USING source ON condition +// +// WHEN MATCHED THEN UPDATE/DELETE +// WHEN NOT MATCHED THEN INSERT +// WHEN NOT MATCHED BY SOURCE THEN UPDATE/DELETE +type MergeStatement struct { + TargetTable TableReference // The table being merged into + TargetAlias string // Optional alias for target + SourceTable TableReference // The source table or subquery + SourceAlias string // Optional alias for source + OnCondition Expression // The join/match condition + WhenClauses []*MergeWhenClause // List of WHEN clauses + Output []Expression // SQL Server OUTPUT clause columns +} + +func (m *MergeStatement) statementNode() {} +func (m MergeStatement) TokenLiteral() string { return "MERGE" } +func (m MergeStatement) Children() []Node { + children := []Node{&m.TargetTable, &m.SourceTable} + if m.OnCondition != 
nil { + children = append(children, m.OnCondition) + } + for _, when := range m.WhenClauses { + children = append(children, when) + } + children = append(children, nodifyExpressions(m.Output)...) + return children +} + +// MergeWhenClause represents a WHEN clause in a MERGE statement +// Types: MATCHED, NOT_MATCHED, NOT_MATCHED_BY_SOURCE +type MergeWhenClause struct { + Type string // "MATCHED", "NOT_MATCHED", "NOT_MATCHED_BY_SOURCE" + Condition Expression // Optional AND condition + Action *MergeAction // The action to perform (UPDATE/INSERT/DELETE) +} + +func (w *MergeWhenClause) expressionNode() {} +func (w MergeWhenClause) TokenLiteral() string { return "WHEN " + w.Type } +func (w MergeWhenClause) Children() []Node { + children := make([]Node, 0) + if w.Condition != nil { + children = append(children, w.Condition) + } + if w.Action != nil { + children = append(children, w.Action) + } + return children +} + +// MergeAction represents the action in a WHEN clause +// ActionType: UPDATE, INSERT, DELETE +type MergeAction struct { + ActionType string // "UPDATE", "INSERT", "DELETE" + SetClauses []SetClause // For UPDATE: SET column = value pairs + Columns []string // For INSERT: column list + Values []Expression // For INSERT: value list + DefaultValues bool // For INSERT: use DEFAULT VALUES +} + +func (a *MergeAction) expressionNode() {} +func (a MergeAction) TokenLiteral() string { return a.ActionType } +func (a MergeAction) Children() []Node { + children := make([]Node, 0) + for _, set := range a.SetClauses { + set := set // G601: Create local copy + children = append(children, &set) + } + for _, val := range a.Values { + children = append(children, val) + } + return children +} + +// SetClause represents a SET clause in UPDATE (also used in MERGE UPDATE) +type SetClause struct { + Column string + Value Expression +} + +func (s *SetClause) expressionNode() {} +func (s SetClause) TokenLiteral() string { return s.Column } +func (s SetClause) Children() []Node { + if 
s.Value != nil { + return []Node{s.Value} + } + return nil +} + +// CreateViewStatement represents a CREATE VIEW statement +// Syntax: CREATE [OR REPLACE] [TEMP|TEMPORARY] VIEW [IF NOT EXISTS] name [(columns)] AS select +type CreateViewStatement struct { + OrReplace bool + Temporary bool + IfNotExists bool + Name string + Columns []string // Optional column list + Query Statement // The SELECT statement + WithOption string // PostgreSQL: WITH (CHECK OPTION | CASCADED | LOCAL) +} + +func (c *CreateViewStatement) statementNode() {} +func (c CreateViewStatement) TokenLiteral() string { return "CREATE VIEW" } +func (c CreateViewStatement) Children() []Node { + if c.Query != nil { + return []Node{c.Query} + } + return nil +} + +// CreateMaterializedViewStatement represents a CREATE MATERIALIZED VIEW statement +// Syntax: CREATE MATERIALIZED VIEW [IF NOT EXISTS] name [(columns)] AS select [WITH [NO] DATA] +type CreateMaterializedViewStatement struct { + IfNotExists bool + Name string + Columns []string // Optional column list + Query Statement // The SELECT statement + WithData *bool // nil = default, true = WITH DATA, false = WITH NO DATA + Tablespace string // Optional tablespace (PostgreSQL) +} + +func (c *CreateMaterializedViewStatement) statementNode() {} +func (c CreateMaterializedViewStatement) TokenLiteral() string { return "CREATE MATERIALIZED VIEW" } +func (c CreateMaterializedViewStatement) Children() []Node { + if c.Query != nil { + return []Node{c.Query} + } + return nil +} + +// RefreshMaterializedViewStatement represents a REFRESH MATERIALIZED VIEW statement +// Syntax: REFRESH MATERIALIZED VIEW [CONCURRENTLY] name [WITH [NO] DATA] +type RefreshMaterializedViewStatement struct { + Concurrently bool + Name string + WithData *bool // nil = default, true = WITH DATA, false = WITH NO DATA +} + +func (r *RefreshMaterializedViewStatement) statementNode() {} +func (r RefreshMaterializedViewStatement) TokenLiteral() string { return "REFRESH MATERIALIZED VIEW" } 
+func (r RefreshMaterializedViewStatement) Children() []Node { return nil } + +// DropStatement represents a DROP statement for tables, views, indexes, etc. +// Syntax: DROP object_type [IF EXISTS] name [CASCADE|RESTRICT] +type DropStatement struct { + ObjectType string // TABLE, VIEW, MATERIALIZED VIEW, INDEX, etc. + IfExists bool + Names []string // Can drop multiple objects + CascadeType string // CASCADE, RESTRICT, or empty +} + +func (d *DropStatement) statementNode() {} +func (d DropStatement) TokenLiteral() string { return "DROP " + d.ObjectType } +func (d DropStatement) Children() []Node { return nil } + +// TruncateStatement represents a TRUNCATE TABLE statement +// Syntax: TRUNCATE [TABLE] table_name [, table_name ...] [RESTART IDENTITY | CONTINUE IDENTITY] [CASCADE | RESTRICT] +type TruncateStatement struct { + Tables []string // Table names to truncate + RestartIdentity bool // RESTART IDENTITY - reset sequences + ContinueIdentity bool // CONTINUE IDENTITY - keep sequences (default) + CascadeType string // CASCADE, RESTRICT, or empty +} + +func (t *TruncateStatement) statementNode() {} +func (t TruncateStatement) TokenLiteral() string { return "TRUNCATE TABLE" } +func (t TruncateStatement) Children() []Node { return nil } + +// PragmaStatement represents a SQLite PRAGMA statement. +// Examples: PRAGMA table_info(users), PRAGMA journal_mode = WAL, PRAGMA integrity_check +type PragmaStatement struct { + Name string // Pragma name, e.g. "table_info" + Arg string // Optional: parenthesized arg, e.g. "users" + Value string // Optional: assigned value, e.g. "WAL" +} + +func (p *PragmaStatement) statementNode() {} +func (p PragmaStatement) TokenLiteral() string { return "PRAGMA" } +func (p PragmaStatement) Children() []Node { return nil } + +// ShowStatement represents MySQL SHOW commands (SHOW TABLES, SHOW DATABASES, SHOW CREATE TABLE x, etc.) +type ShowStatement struct { + ShowType string // TABLES, DATABASES, CREATE TABLE, COLUMNS, INDEX, etc. 
+ ObjectName string // For SHOW CREATE TABLE x, SHOW COLUMNS FROM x, etc. + From string // For SHOW ... FROM database +} + +func (s *ShowStatement) statementNode() {} +func (s ShowStatement) TokenLiteral() string { return "SHOW" } +func (s ShowStatement) Children() []Node { return nil } + +// DescribeStatement represents MySQL DESCRIBE/DESC/EXPLAIN table commands +type DescribeStatement struct { + TableName string +} + +func (d *DescribeStatement) statementNode() {} +func (d DescribeStatement) TokenLiteral() string { return "DESCRIBE" } +func (d DescribeStatement) Children() []Node { return nil } + +// UnsupportedStatement represents a SQL statement that was parsed but not +// fully modeled in the AST. The parser consumed and validated the tokens +// but no dedicated AST node exists yet for this statement kind. +// +// Consumers should use Kind to identify the operation (e.g., "USE", "COPY", +// "CREATE STAGE") and RawSQL for the original text. Tools that do +// switch stmt.(type) should handle this case explicitly rather than +// falling through to a default that assumes the statement is well-structured. +type UnsupportedStatement struct { + Kind string // Operation kind: "USE", "COPY", "PUT", "GET", "LIST", "REMOVE", "CREATE STAGE", etc. + RawSQL string // Original SQL fragment for round-trip fidelity +} + +func (u *UnsupportedStatement) statementNode() {} +func (u UnsupportedStatement) TokenLiteral() string { return u.Kind } +func (u UnsupportedStatement) Children() []Node { return nil } + +// ReplaceStatement represents MySQL REPLACE INTO statement +type ReplaceStatement struct { + TableName string + Columns []Expression + Values [][]Expression +} + +func (r *ReplaceStatement) statementNode() {} +func (r ReplaceStatement) TokenLiteral() string { return "REPLACE" } +func (r ReplaceStatement) Children() []Node { + children := make([]Node, 0) + children = append(children, nodifyExpressions(r.Columns)...) 
+ for _, row := range r.Values { + children = append(children, nodifyExpressions(row)...) + } + return children +} + +// ── MariaDB SEQUENCE DDL (10.3+) ─────────────────────────────────────────── + +// CycleOption represents the CYCLE behavior for a sequence. +type CycleOption int + +const ( + // CycleUnspecified means no CYCLE or NOCYCLE clause was given (database default applies). + CycleUnspecified CycleOption = iota + // CycleBehavior means CYCLE — sequence wraps around when it reaches min/max. + CycleBehavior + // NoCycleBehavior means NOCYCLE / NO CYCLE — sequence errors on overflow. + NoCycleBehavior +) + +// SequenceOptions holds configuration for CREATE SEQUENCE and ALTER SEQUENCE. +// Fields are pointers so that unspecified options are distinguishable from zero values. +type SequenceOptions struct { + StartWith *LiteralValue // START WITH n + IncrementBy *LiteralValue // INCREMENT BY n (default 1) + MinValue *LiteralValue // MINVALUE n or nil when NO MINVALUE + MaxValue *LiteralValue // MAXVALUE n or nil when NO MAXVALUE + Cache *LiteralValue // CACHE n or nil when NO CACHE / NOCACHE + CycleMode CycleOption // CYCLE / NOCYCLE / NO CYCLE (CycleUnspecified if not specified) + NoCache bool // NOCACHE (explicit; Cache=nil alone is ambiguous) + Restart bool // bare RESTART (reset to start value) + RestartWith *LiteralValue // RESTART WITH n (explicit restart value) +} + +// CreateSequenceStatement represents: +// +// CREATE [OR REPLACE] SEQUENCE [IF NOT EXISTS] name [options...] 
+type CreateSequenceStatement struct {
+	Name        *Identifier
+	OrReplace   bool
+	IfNotExists bool
+	Options     SequenceOptions
+	Pos         models.Location // Source position of the CREATE keyword (1-based line and column)
+}
+
+func (s *CreateSequenceStatement) statementNode() {}
+func (s *CreateSequenceStatement) TokenLiteral() string { return "CREATE" }
+func (s *CreateSequenceStatement) Children() []Node {
+	if s.Name != nil {
+		return []Node{s.Name}
+	}
+	return nil
+}
+
+// DropSequenceStatement represents:
+//
+//	DROP SEQUENCE [IF EXISTS] name
+type DropSequenceStatement struct {
+	Name     *Identifier
+	IfExists bool
+	Pos      models.Location // Source position of the DROP keyword (1-based line and column)
+}
+
+func (s *DropSequenceStatement) statementNode() {}
+func (s *DropSequenceStatement) TokenLiteral() string { return "DROP" }
+func (s *DropSequenceStatement) Children() []Node {
+	if s.Name != nil {
+		return []Node{s.Name}
+	}
+	return nil
+}
+
+// AlterSequenceStatement represents:
+//
+//	ALTER SEQUENCE [IF EXISTS] name [options...]
+type AlterSequenceStatement struct {
+	Name     *Identifier
+	IfExists bool
+	Options  SequenceOptions
+	Pos      models.Location // Source position of the ALTER keyword (1-based line and column)
+}
+
+func (s *AlterSequenceStatement) statementNode() {}
+func (s *AlterSequenceStatement) TokenLiteral() string { return "ALTER" }
+func (s *AlterSequenceStatement) Children() []Node {
+	if s.Name != nil {
+		return []Node{s.Name}
+	}
+	return nil
+}
diff --git a/pkg/sql/ast/pool_expression_release.go b/pkg/sql/ast/pool_expression_release.go
new file mode 100644
index 00000000..2eed7731
--- /dev/null
+++ b/pkg/sql/ast/pool_expression_release.go
@@ -0,0 +1,781 @@
+// Copyright 2026 GoSQLX Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import ( + "sync/atomic" +) + +// GetUpdateExpression gets an UpdateExpression from the pool +func GetUpdateExpression() *UpdateExpression { + return updateExprPool.Get().(*UpdateExpression) +} + +// PutUpdateExpression returns an UpdateExpression to the pool +func PutUpdateExpression(expr *UpdateExpression) { + if expr == nil { + return + } + + // Clean up expressions + PutExpression(expr.Column) + PutExpression(expr.Value) + + // Reset fields + expr.Column = nil + expr.Value = nil + + // Return to pool + updateExprPool.Put(expr) +} + +// GetIdentifier gets an Identifier from the pool +func GetIdentifier() *Identifier { + return identifierPool.Get().(*Identifier) +} + +// PutIdentifier returns an Identifier to the pool +func PutIdentifier(ident *Identifier) { + if ident == nil { + return + } + ident.Name = "" + identifierPool.Put(ident) +} + +// GetBinaryExpression gets a BinaryExpression from the pool +func GetBinaryExpression() *BinaryExpression { + return binaryExprPool.Get().(*BinaryExpression) +} + +// PutBinaryExpression returns a BinaryExpression to the pool +func PutBinaryExpression(expr *BinaryExpression) { + if expr == nil { + return + } + PutExpression(expr.Left) + PutExpression(expr.Right) + expr.Left = nil + expr.Right = nil + expr.Operator = "" + binaryExprPool.Put(expr) +} + +// GetExpressionSlice gets a slice of Expression from the pool +func GetExpressionSlice() *[]Expression { + slice := exprSlicePool.Get().(*[]Expression) + *slice = (*slice)[:0] + return slice +} + +// PutExpressionSlice returns a slice of 
Expression to the pool +func PutExpressionSlice(slice *[]Expression) { + if slice == nil { + return + } + for i := range *slice { + PutExpression((*slice)[i]) + (*slice)[i] = nil + } + exprSlicePool.Put(slice) +} + +// GetLiteralValue gets a LiteralValue from the pool +func GetLiteralValue() *LiteralValue { + return literalValuePool.Get().(*LiteralValue) +} + +// PutLiteralValue returns a LiteralValue to the pool +func PutLiteralValue(lit *LiteralValue) { + if lit == nil { + return + } + + // Reset fields (Value is interface{}, use nil as zero value) + lit.Value = nil + lit.Type = "" + + // Return to pool + literalValuePool.Put(lit) +} + +// PutExpression returns any Expression to the appropriate pool with iterative cleanup. +// +// PutExpression is the primary function for returning expression nodes to their +// respective pools. It handles all expression types and uses iterative cleanup +// to prevent stack overflow with deeply nested expression trees. +// +// Key Features: +// - Supports all expression types (30+ pooled types) +// - Iterative cleanup algorithm (no recursion limits) +// - Prevents stack overflow for deeply nested expressions +// - Work queue size limits (MaxWorkQueueSize = 1000) +// - Nil-safe (ignores nil expressions) +// +// Supported Expression Types: +// - Identifier, LiteralValue, AliasedExpression +// - BinaryExpression, UnaryExpression +// - FunctionCall, CaseExpression +// - BetweenExpression, InExpression +// - SubqueryExpression, ExistsExpression, AnyExpression, AllExpression +// - CastExpression, ExtractExpression, PositionExpression, SubstringExpression +// - ListExpression +// +// Iterative Cleanup Algorithm: +// 1. Use work queue instead of recursion +// 2. Process expressions breadth-first +// 3. Collect child expressions and add to queue +// 4. Clean and return to pool +// 5. 
Limit queue size to prevent memory exhaustion +// +// Parameters: +// - expr: Expression to return to pool (nil-safe) +// +// Usage Pattern: +// +// expr := ast.GetBinaryExpression() +// defer ast.PutExpression(expr) +// +// // Build expression tree... +// +// Example - Cleaning up complex expression: +// +// // Build: (age > 18 AND status = 'active') OR (role = 'admin') +// expr := &ast.BinaryExpression{ +// Left: &ast.BinaryExpression{ +// Left: &ast.BinaryExpression{...}, +// Operator: "AND", +// Right: &ast.BinaryExpression{...}, +// }, +// Operator: "OR", +// Right: &ast.BinaryExpression{...}, +// } +// +// // Cleanup all nested expressions +// ast.PutExpression(expr) // Handles entire tree iteratively +// +// Performance Characteristics: +// - O(n) time complexity where n = number of nodes +// - O(min(n, MaxWorkQueueSize)) space complexity +// - No stack overflow risk regardless of nesting depth +// - Efficient for both shallow and deeply nested expressions +// +// Safety Guarantees: +// - Thread-safe (uses sync.Pool internally) +// - Nil-safe (gracefully handles nil expressions) +// - Stack-safe (iterative, not recursive) +// - Memory-safe (work queue size limits) +// +// IMPORTANT: This function should be used for all expression cleanup. +// Direct pool returns (e.g., binaryExprPool.Put()) bypass the iterative +// cleanup and may leave child expressions unreleased. +// +// See also: GetBinaryExpression(), GetFunctionCall(), GetIdentifier() +func PutExpression(expr Expression) { + if expr == nil { + return + } + putExpressionImpl(expr, 0) +} + +// putExpressionImpl is the internal driver for PutExpression. The depth +// parameter tracks recursive re-entries from the work-queue overflow path +// to prevent stack overflow on pathologically deep ASTs. +// +// The iterative work queue is drawn from putExpressionWorkQueuePool so that +// hot-path PutExpression calls (10-100× per parse) do not repeatedly allocate +// a fresh 32-cap []Expression. 
The slice is reset to zero-length and its +// element slots nil'd before being returned to the pool (preventing the +// pool from pinning Expression pointers we've already returned to their +// own pools). +func putExpressionImpl(expr Expression, depth int) { + if expr == nil { + return + } + + // Acquire a pooled work queue. We must write the (possibly grown) + // slice header back to the pointer before Put so that subsequent + // Get calls see the grown capacity. + qp := putExpressionWorkQueuePool.Get().(*[]Expression) + workQueue := (*qp)[:0] + defer func() { + // Nil out slice elements up to the underlying capacity we used so + // the pool cannot pin arbitrarily-aged Expression pointers. Using + // the full capacity is safe because we only wrote through + // append — anything beyond len was never assigned here, but prior + // use of this pooled slice may have written to those slots. Clear + // them all by reslicing to capacity and zeroing. + cleared := workQueue[:cap(workQueue)] + for i := range cleared { + cleared[i] = nil + } + *qp = workQueue[:0] + putExpressionWorkQueuePool.Put(qp) + }() + workQueue = append(workQueue, expr) + + processed := 0 + for len(workQueue) > 0 && processed < MaxWorkQueueSize { + // Pop from queue + current := workQueue[len(workQueue)-1] + workQueue = workQueue[:len(workQueue)-1] + processed++ + + if current == nil { + continue + } + + // Process and collect child expressions + switch e := current.(type) { + case *Identifier: + e.Name = "" + identifierPool.Put(e) + + case *BinaryExpression: + if e.Left != nil { + workQueue = append(workQueue, e.Left) + } + if e.Right != nil { + workQueue = append(workQueue, e.Right) + } + e.Left = nil + e.Right = nil + e.Operator = "" + binaryExprPool.Put(e) + + case *LiteralValue: + e.Value = nil + e.Type = "" + literalValuePool.Put(e) + + case *FunctionCall: + for i := range e.Arguments { + if e.Arguments[i] != nil { + workQueue = append(workQueue, e.Arguments[i]) + } + e.Arguments[i] = nil + } + 
e.Arguments = e.Arguments[:0] + e.Name = "" + e.Over = nil + e.Distinct = false + e.Filter = nil + functionCallPool.Put(e) + + case *CaseExpression: + if e.Value != nil { + workQueue = append(workQueue, e.Value) + } + for i := range e.WhenClauses { + if e.WhenClauses[i].Condition != nil { + workQueue = append(workQueue, e.WhenClauses[i].Condition) + } + if e.WhenClauses[i].Result != nil { + workQueue = append(workQueue, e.WhenClauses[i].Result) + } + } + if e.ElseClause != nil { + workQueue = append(workQueue, e.ElseClause) + } + e.Value = nil + e.WhenClauses = e.WhenClauses[:0] + e.ElseClause = nil + caseExprPool.Put(e) + + case *BetweenExpression: + if e.Expr != nil { + workQueue = append(workQueue, e.Expr) + } + if e.Lower != nil { + workQueue = append(workQueue, e.Lower) + } + if e.Upper != nil { + workQueue = append(workQueue, e.Upper) + } + e.Expr = nil + e.Lower = nil + e.Upper = nil + e.Not = false + betweenExprPool.Put(e) + + case *InExpression: + if e.Expr != nil { + workQueue = append(workQueue, e.Expr) + } + for i := range e.List { + if e.List[i] != nil { + workQueue = append(workQueue, e.List[i]) + } + e.List[i] = nil + } + e.Expr = nil + e.List = e.List[:0] + // Subquery is a Statement (typically *SelectStatement); release + // it through the statement dispatcher so every nested pooled + // node is returned. Silently setting to nil was a leak. 
+ if e.Subquery != nil { + releaseStatement(e.Subquery) + e.Subquery = nil + } + e.Not = false + inExprPool.Put(e) + + case *SubqueryExpression: + if e.Subquery != nil { + releaseStatement(e.Subquery) + e.Subquery = nil + } + subqueryExprPool.Put(e) + + case *CastExpression: + if e.Expr != nil { + workQueue = append(workQueue, e.Expr) + } + e.Expr = nil + e.Type = "" + castExprPool.Put(e) + + case *IntervalExpression: + e.Value = "" + intervalExprPool.Put(e) + + case *ArraySubscriptExpression: + if e.Array != nil { + workQueue = append(workQueue, e.Array) + } + for i := range e.Indices { + if e.Indices[i] != nil { + workQueue = append(workQueue, e.Indices[i]) + } + } + e.Array = nil + e.Indices = e.Indices[:0] + arraySubscriptExprPool.Put(e) + + case *ArraySliceExpression: + if e.Array != nil { + workQueue = append(workQueue, e.Array) + } + if e.Start != nil { + workQueue = append(workQueue, e.Start) + } + if e.End != nil { + workQueue = append(workQueue, e.End) + } + e.Array = nil + e.Start = nil + e.End = nil + arraySliceExprPool.Put(e) + + case *ExistsExpression: + if e.Subquery != nil { + releaseStatement(e.Subquery) + e.Subquery = nil + } + existsExprPool.Put(e) + + case *AnyExpression: + if e.Expr != nil { + workQueue = append(workQueue, e.Expr) + } + e.Expr = nil + if e.Subquery != nil { + releaseStatement(e.Subquery) + e.Subquery = nil + } + e.Operator = "" + anyExprPool.Put(e) + + case *AllExpression: + if e.Expr != nil { + workQueue = append(workQueue, e.Expr) + } + e.Expr = nil + if e.Subquery != nil { + releaseStatement(e.Subquery) + e.Subquery = nil + } + e.Operator = "" + allExprPool.Put(e) + + case *ListExpression: + for i := range e.Values { + if e.Values[i] != nil { + workQueue = append(workQueue, e.Values[i]) + } + e.Values[i] = nil + } + e.Values = e.Values[:0] + listExprPool.Put(e) + + case *TupleExpression: + for i := range e.Expressions { + if e.Expressions[i] != nil { + workQueue = append(workQueue, e.Expressions[i]) + } + e.Expressions[i] = 
nil + } + e.Expressions = e.Expressions[:0] + tupleExprPool.Put(e) + + case *ArrayConstructorExpression: + for i := range e.Elements { + if e.Elements[i] != nil { + workQueue = append(workQueue, e.Elements[i]) + } + e.Elements[i] = nil + } + e.Elements = e.Elements[:0] + // Subquery is *SelectStatement — release through the + // statement pool, not a bare nil-assign (leak before fix). + if e.Subquery != nil { + PutSelectStatement(e.Subquery) + e.Subquery = nil + } + arrayConstructorPool.Put(e) + + case *UnaryExpression: + if e.Expr != nil { + workQueue = append(workQueue, e.Expr) + } + e.Expr = nil + e.Operator = 0 // UnaryOperator is int type + unaryExprPool.Put(e) + + case *ExtractExpression: + if e.Source != nil { + workQueue = append(workQueue, e.Source) + } + e.Field = "" + e.Source = nil + extractExprPool.Put(e) + + case *PositionExpression: + if e.Substr != nil { + workQueue = append(workQueue, e.Substr) + } + if e.Str != nil { + workQueue = append(workQueue, e.Str) + } + e.Substr = nil + e.Str = nil + positionExprPool.Put(e) + + case *SubstringExpression: + if e.Str != nil { + workQueue = append(workQueue, e.Str) + } + if e.Start != nil { + workQueue = append(workQueue, e.Start) + } + if e.Length != nil { + workQueue = append(workQueue, e.Length) + } + e.Str = nil + e.Start = nil + e.Length = nil + substringExprPool.Put(e) + + case *AliasedExpression: + if e.Expr != nil { + workQueue = append(workQueue, e.Expr) + } + e.Expr = nil + e.Alias = "" + aliasedExprPool.Put(e) + + // Default case - expression type not pooled, just ignore + default: + // Unknown expression type - no pool available + } + } + + // OVERFLOW DRAIN: if we hit the work-queue cap, there are still pooled + // nodes in workQueue that would otherwise leak. Fall back to a recursive + // drain, depth-limited to prevent stack overflow on deeply nested trees. 
+ // Each recursive call starts its own fresh work queue of up to + // MaxWorkQueueSize, so the recursion depth is effectively + // ceil(total_nodes / MaxWorkQueueSize). MaxCleanupDepth = 100 bounds this + // at ~10_000_000 total nodes in an AST — far beyond any real SQL query. + if len(workQueue) > 0 { + atomic.AddUint64(&poolLeakCount, uint64(len(workQueue))) + if depth < MaxCleanupDepth { + for _, remaining := range workQueue { + putExpressionImpl(remaining, depth+1) + } + } + // If depth exceeded MaxCleanupDepth we accept the leak rather than + // blow the stack; poolLeakCount records the truncation for diagnostics. + } +} + +// GetFunctionCall gets a FunctionCall from the pool +func GetFunctionCall() *FunctionCall { + fc := functionCallPool.Get().(*FunctionCall) + fc.Arguments = fc.Arguments[:0] + return fc +} + +// PutFunctionCall returns a FunctionCall to the pool +func PutFunctionCall(fc *FunctionCall) { + if fc == nil { + return + } + for i := range fc.Arguments { + PutExpression(fc.Arguments[i]) + fc.Arguments[i] = nil + } + fc.Arguments = fc.Arguments[:0] + fc.Name = "" + fc.Over = nil + fc.Distinct = false + fc.Filter = nil + functionCallPool.Put(fc) +} + +// GetCaseExpression gets a CaseExpression from the pool +func GetCaseExpression() *CaseExpression { + ce := caseExprPool.Get().(*CaseExpression) + ce.WhenClauses = ce.WhenClauses[:0] + return ce +} + +// PutCaseExpression returns a CaseExpression to the pool +func PutCaseExpression(ce *CaseExpression) { + if ce == nil { + return + } + PutExpression(ce.Value) + ce.Value = nil + for i := range ce.WhenClauses { + PutExpression(ce.WhenClauses[i].Condition) + PutExpression(ce.WhenClauses[i].Result) + } + ce.WhenClauses = ce.WhenClauses[:0] + PutExpression(ce.ElseClause) + ce.ElseClause = nil + caseExprPool.Put(ce) +} + +// GetBetweenExpression gets a BetweenExpression from the pool +func GetBetweenExpression() *BetweenExpression { + return betweenExprPool.Get().(*BetweenExpression) +} + +// 
PutBetweenExpression returns a BetweenExpression to the pool +func PutBetweenExpression(be *BetweenExpression) { + if be == nil { + return + } + PutExpression(be.Expr) + PutExpression(be.Lower) + PutExpression(be.Upper) + be.Expr = nil + be.Lower = nil + be.Upper = nil + be.Not = false + betweenExprPool.Put(be) +} + +// GetInExpression gets an InExpression from the pool +func GetInExpression() *InExpression { + ie := inExprPool.Get().(*InExpression) + ie.List = ie.List[:0] + return ie +} + +// PutInExpression returns an InExpression to the pool +func PutInExpression(ie *InExpression) { + if ie == nil { + return + } + PutExpression(ie.Expr) + ie.Expr = nil + for i := range ie.List { + PutExpression(ie.List[i]) + ie.List[i] = nil + } + ie.List = ie.List[:0] + // Subquery (IN (SELECT ...)) is a Statement — release through the + // statement dispatcher, not a bare nil-assign. + if ie.Subquery != nil { + releaseStatement(ie.Subquery) + ie.Subquery = nil + } + ie.Not = false + inExprPool.Put(ie) +} + +// GetTupleExpression gets a TupleExpression from the pool +func GetTupleExpression() *TupleExpression { + te := tupleExprPool.Get().(*TupleExpression) + te.Expressions = te.Expressions[:0] + return te +} + +// PutTupleExpression returns a TupleExpression to the pool +func PutTupleExpression(te *TupleExpression) { + if te == nil { + return + } + for i := range te.Expressions { + PutExpression(te.Expressions[i]) + te.Expressions[i] = nil + } + te.Expressions = te.Expressions[:0] + tupleExprPool.Put(te) +} + +// GetArrayConstructor gets an ArrayConstructorExpression from the pool +func GetArrayConstructor() *ArrayConstructorExpression { + ac := arrayConstructorPool.Get().(*ArrayConstructorExpression) + ac.Elements = ac.Elements[:0] + ac.Subquery = nil + return ac +} + +// PutArrayConstructor returns an ArrayConstructorExpression to the pool +func PutArrayConstructor(ac *ArrayConstructorExpression) { + if ac == nil { + return + } + for i := range ac.Elements { + 
PutExpression(ac.Elements[i]) + ac.Elements[i] = nil + } + ac.Elements = ac.Elements[:0] + // Subquery is *SelectStatement — release through the statement pool. + if ac.Subquery != nil { + PutSelectStatement(ac.Subquery) + ac.Subquery = nil + } + arrayConstructorPool.Put(ac) +} + +// GetSubqueryExpression gets a SubqueryExpression from the pool +func GetSubqueryExpression() *SubqueryExpression { + return subqueryExprPool.Get().(*SubqueryExpression) +} + +// PutSubqueryExpression returns a SubqueryExpression to the pool +func PutSubqueryExpression(se *SubqueryExpression) { + if se == nil { + return + } + // Subquery is a Statement — release it through the statement dispatcher. + if se.Subquery != nil { + releaseStatement(se.Subquery) + se.Subquery = nil + } + subqueryExprPool.Put(se) +} + +// GetCastExpression gets a CastExpression from the pool +func GetCastExpression() *CastExpression { + return castExprPool.Get().(*CastExpression) +} + +// PutCastExpression returns a CastExpression to the pool +func PutCastExpression(ce *CastExpression) { + if ce == nil { + return + } + PutExpression(ce.Expr) + ce.Expr = nil + ce.Type = "" + castExprPool.Put(ce) +} + +// GetIntervalExpression gets an IntervalExpression from the pool +func GetIntervalExpression() *IntervalExpression { + return intervalExprPool.Get().(*IntervalExpression) +} + +// PutIntervalExpression returns an IntervalExpression to the pool +func PutIntervalExpression(ie *IntervalExpression) { + if ie == nil { + return + } + ie.Value = "" + intervalExprPool.Put(ie) +} + +// GetAliasedExpression retrieves an AliasedExpression from the pool +func GetAliasedExpression() *AliasedExpression { + return aliasedExprPool.Get().(*AliasedExpression) +} + +// PutAliasedExpression returns an AliasedExpression to the pool +func PutAliasedExpression(ae *AliasedExpression) { + if ae == nil { + return + } + PutExpression(ae.Expr) + ae.Expr = nil + ae.Alias = "" + aliasedExprPool.Put(ae) +} + +// GetArraySubscriptExpression gets 
an ArraySubscriptExpression from the pool +func GetArraySubscriptExpression() *ArraySubscriptExpression { + return arraySubscriptExprPool.Get().(*ArraySubscriptExpression) +} + +// PutArraySubscriptExpression returns an ArraySubscriptExpression to the pool +func PutArraySubscriptExpression(ase *ArraySubscriptExpression) { + if ase == nil { + return + } + // Clean up array expression + if ase.Array != nil { + PutExpression(ase.Array) + ase.Array = nil + } + // Clean up indices + for i := range ase.Indices { + if ase.Indices[i] != nil { + PutExpression(ase.Indices[i]) + } + } + ase.Indices = ase.Indices[:0] // Clear slice but keep capacity + arraySubscriptExprPool.Put(ase) +} + +// GetArraySliceExpression gets an ArraySliceExpression from the pool +func GetArraySliceExpression() *ArraySliceExpression { + return arraySliceExprPool.Get().(*ArraySliceExpression) +} + +// PutArraySliceExpression returns an ArraySliceExpression to the pool +func PutArraySliceExpression(ase *ArraySliceExpression) { + if ase == nil { + return + } + // Clean up array expression + if ase.Array != nil { + PutExpression(ase.Array) + ase.Array = nil + } + // Clean up start/end expressions + if ase.Start != nil { + PutExpression(ase.Start) + ase.Start = nil + } + if ase.End != nil { + PutExpression(ase.End) + ase.End = nil + } + arraySliceExprPool.Put(ase) +} diff --git a/pkg/sql/ast/pool_statement_release.go b/pkg/sql/ast/pool_statement_release.go new file mode 100644 index 00000000..a41b7c92 --- /dev/null +++ b/pkg/sql/ast/pool_statement_release.go @@ -0,0 +1,766 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +// GetInsertStatement gets an InsertStatement from the pool +func GetInsertStatement() *InsertStatement { + return insertStmtPool.Get().(*InsertStatement) +} + +// PutInsertStatement returns an InsertStatement to the pool. +// +// Releases every pooled Expression/Statement reachable from the InsertStatement: +// - With (CTEs + nested statements + scalar CTE expressions) +// - Columns +// - Output (SQL Server OUTPUT clause) +// - Values (all rows, all cells) +// - Query (INSERT ... SELECT — the nested QueryExpression) +// - Returning +// - OnConflict.Target, OnConflict.Action.DoUpdate (Column, Value), OnConflict.Action.Where +// - OnDuplicateKey.Updates (Column, Value) +func PutInsertStatement(stmt *InsertStatement) { + if stmt == nil { + return + } + + // ── WITH clause / CTEs ──────────────────────────────────────────── + if stmt.With != nil { + for _, cte := range stmt.With.CTEs { + if cte == nil { + continue + } + releaseStatement(cte.Statement) + cte.Statement = nil + PutExpression(cte.ScalarExpr) + cte.ScalarExpr = nil + } + stmt.With.CTEs = nil + stmt.With = nil + } + + // ── Column list ─────────────────────────────────────────────────── + for i := range stmt.Columns { + PutExpression(stmt.Columns[i]) + stmt.Columns[i] = nil + } + stmt.Columns = stmt.Columns[:0] + + // ── OUTPUT clause (SQL Server) ──────────────────────────────────── + for i := range stmt.Output { + PutExpression(stmt.Output[i]) + stmt.Output[i] = nil + } + stmt.Output = stmt.Output[:0] + + // ── VALUES (multi-row) 
──────────────────────────────────────────── + for i := range stmt.Values { + for j := range stmt.Values[i] { + PutExpression(stmt.Values[i][j]) + stmt.Values[i][j] = nil + } + stmt.Values[i] = stmt.Values[i][:0] + } + stmt.Values = stmt.Values[:0] + + // ── Query (INSERT ... SELECT) ───────────────────────────────────── + if stmt.Query != nil { + // Query is a QueryExpression (Statement); dispatch via releaseStatement. + releaseStatement(stmt.Query) + stmt.Query = nil + } + + // ── RETURNING ────────────────────────────────────────────────────── + for i := range stmt.Returning { + PutExpression(stmt.Returning[i]) + stmt.Returning[i] = nil + } + stmt.Returning = stmt.Returning[:0] + + // ── ON CONFLICT (PostgreSQL) ────────────────────────────────────── + if stmt.OnConflict != nil { + for i := range stmt.OnConflict.Target { + PutExpression(stmt.OnConflict.Target[i]) + stmt.OnConflict.Target[i] = nil + } + stmt.OnConflict.Target = nil + for i := range stmt.OnConflict.Action.DoUpdate { + PutExpression(stmt.OnConflict.Action.DoUpdate[i].Column) + PutExpression(stmt.OnConflict.Action.DoUpdate[i].Value) + stmt.OnConflict.Action.DoUpdate[i].Column = nil + stmt.OnConflict.Action.DoUpdate[i].Value = nil + } + stmt.OnConflict.Action.DoUpdate = nil + PutExpression(stmt.OnConflict.Action.Where) + stmt.OnConflict.Action.Where = nil + stmt.OnConflict = nil + } + + // ── ON DUPLICATE KEY UPDATE (MySQL) ─────────────────────────────── + if stmt.OnDuplicateKey != nil { + for i := range stmt.OnDuplicateKey.Updates { + PutExpression(stmt.OnDuplicateKey.Updates[i].Column) + PutExpression(stmt.OnDuplicateKey.Updates[i].Value) + stmt.OnDuplicateKey.Updates[i].Column = nil + stmt.OnDuplicateKey.Updates[i].Value = nil + } + stmt.OnDuplicateKey.Updates = nil + stmt.OnDuplicateKey = nil + } + + stmt.TableName = "" + + // Return to pool + insertStmtPool.Put(stmt) +} + +// GetUpdateStatement gets an UpdateStatement from the pool +func GetUpdateStatement() *UpdateStatement { + return 
updateStmtPool.Get().(*UpdateStatement) +} + +// PutUpdateStatement returns an UpdateStatement to the pool. +// +// Releases every pooled Expression/Statement reachable from the UpdateStatement: +// - With (CTEs + nested statements + scalar CTE expressions) +// - Assignments (Column, Value) +// - From (TableReference.Subquery, TableFunc, Pivot, MatchRecognize, TimeTravel, ForSystemTime) +// - Where +// - Returning +func PutUpdateStatement(stmt *UpdateStatement) { + if stmt == nil { + return + } + + // ── WITH clause / CTEs ──────────────────────────────────────────── + if stmt.With != nil { + for _, cte := range stmt.With.CTEs { + if cte == nil { + continue + } + releaseStatement(cte.Statement) + cte.Statement = nil + PutExpression(cte.ScalarExpr) + cte.ScalarExpr = nil + } + stmt.With.CTEs = nil + stmt.With = nil + } + + // ── SET assignments ─────────────────────────────────────────────── + for i := range stmt.Assignments { + PutExpression(stmt.Assignments[i].Column) + PutExpression(stmt.Assignments[i].Value) + stmt.Assignments[i].Column = nil + stmt.Assignments[i].Value = nil + } + stmt.Assignments = stmt.Assignments[:0] + + // ── FROM table references ───────────────────────────────────────── + for i := range stmt.From { + releaseTableReference(&stmt.From[i]) + } + stmt.From = stmt.From[:0] + + // ── WHERE ────────────────────────────────────────────────────────── + PutExpression(stmt.Where) + stmt.Where = nil + + // ── RETURNING ────────────────────────────────────────────────────── + for i := range stmt.Returning { + PutExpression(stmt.Returning[i]) + stmt.Returning[i] = nil + } + stmt.Returning = stmt.Returning[:0] + + // ── Scalars ──────────────────────────────────────────────────────── + stmt.TableName = "" + stmt.Alias = "" + + // Return to pool + updateStmtPool.Put(stmt) +} + +// GetDeleteStatement gets a DeleteStatement from the pool +func GetDeleteStatement() *DeleteStatement { + return deleteStmtPool.Get().(*DeleteStatement) +} + +// 
PutDeleteStatement returns a DeleteStatement to the pool. +// +// Releases every pooled Expression/Statement reachable from the DeleteStatement: +// - With (CTEs + nested statements + scalar CTE expressions) +// - Using (TableReference subqueries, TableFunc, Pivot, MatchRecognize, TimeTravel, ForSystemTime) +// - Where +// - Returning +func PutDeleteStatement(stmt *DeleteStatement) { + if stmt == nil { + return + } + + // ── WITH clause / CTEs ──────────────────────────────────────────── + if stmt.With != nil { + for _, cte := range stmt.With.CTEs { + if cte == nil { + continue + } + releaseStatement(cte.Statement) + cte.Statement = nil + PutExpression(cte.ScalarExpr) + cte.ScalarExpr = nil + } + stmt.With.CTEs = nil + stmt.With = nil + } + + // ── USING table references (PostgreSQL) ─────────────────────────── + for i := range stmt.Using { + releaseTableReference(&stmt.Using[i]) + } + stmt.Using = stmt.Using[:0] + + // ── WHERE ────────────────────────────────────────────────────────── + PutExpression(stmt.Where) + stmt.Where = nil + + // ── RETURNING ────────────────────────────────────────────────────── + for i := range stmt.Returning { + PutExpression(stmt.Returning[i]) + stmt.Returning[i] = nil + } + stmt.Returning = stmt.Returning[:0] + + // ── Scalars ──────────────────────────────────────────────────────── + stmt.TableName = "" + stmt.Alias = "" + + // Return to pool + deleteStmtPool.Put(stmt) +} + +// GetSelectStatement gets a SelectStatement from the pool +func GetSelectStatement() *SelectStatement { + stmt := selectStmtPool.Get().(*SelectStatement) + stmt.Columns = stmt.Columns[:0] + stmt.OrderBy = stmt.OrderBy[:0] + return stmt +} + +// PutSelectStatement returns a SelectStatement to the pool. +// +// Uses iterative cleanup via PutExpression to handle deeply nested expressions. 
// This function MUST release every pooled Expression/Node reachable from the
// SelectStatement; missing fields cause silent pool leaks that defeat the
// 60-80% memory reduction target and degrade hit-rate below 95%.
//
// Coverage (v1.14.0+ — comprehensive audit):
//   - With (CTEs + their nested statements + scalar CTE expressions)
//   - Top.Count
//   - DistinctOnColumns
//   - Columns
//   - From (TableReference.Subquery, TableFunc, Pivot.AggregateFunction, MatchRecognize)
//   - Joins (Left/Right TableRefs, Condition)
//   - ArrayJoin (element Exprs)
//   - PrewhereClause
//   - Sample (no Expressions, but zeroed for hygiene)
//   - Where
//   - GroupBy
//   - Having
//   - Qualify
//   - StartWith / ConnectBy.Condition
//   - Windows (PartitionBy + OrderBy expressions + FrameClause bounds)
//   - OrderBy
//   - Fetch / For (no Expression children, just zero)
//   - Limit / Offset (*int — no release needed)
func PutSelectStatement(stmt *SelectStatement) {
	// Nil-safe: callers may pass a nil statement unconditionally.
	if stmt == nil {
		return
	}

	// ── WITH clause / CTEs ─────────────────────────────────────────────
	// Each CTE may hold an arbitrary nested statement (released via the
	// type-dispatching releaseStatement) plus a scalar expression.
	if stmt.With != nil {
		for _, cte := range stmt.With.CTEs {
			if cte == nil {
				continue
			}
			releaseStatement(cte.Statement)
			cte.Statement = nil
			PutExpression(cte.ScalarExpr)
			cte.ScalarExpr = nil
		}
		// With itself is dropped to the GC, so its CTE slice goes with it.
		stmt.With.CTEs = nil
		stmt.With = nil
	}

	// ── TOP clause ─────────────────────────────────────────────────────
	if stmt.Top != nil {
		PutExpression(stmt.Top.Count)
		stmt.Top.Count = nil
		stmt.Top = nil
	}

	// ── DISTINCT ON columns ────────────────────────────────────────────
	// Elements are nil'd before truncation so the retained backing array
	// does not pin released expressions.
	for i := range stmt.DistinctOnColumns {
		PutExpression(stmt.DistinctOnColumns[i])
		stmt.DistinctOnColumns[i] = nil
	}
	stmt.DistinctOnColumns = stmt.DistinctOnColumns[:0]

	// ── SELECT list columns ────────────────────────────────────────────
	for i := range stmt.Columns {
		PutExpression(stmt.Columns[i])
		stmt.Columns[i] = nil
	}
	stmt.Columns = stmt.Columns[:0]

	// ── FROM table references (Subquery, TableFunc, Pivot, MatchRecognize) ─
	for i := range stmt.From {
		releaseTableReference(&stmt.From[i])
	}
	stmt.From = stmt.From[:0]

	// ── JOINs ──────────────────────────────────────────────────────────
	// Both sides of a join are full TableReferences and may nest subqueries.
	for i := range stmt.Joins {
		releaseTableReference(&stmt.Joins[i].Left)
		releaseTableReference(&stmt.Joins[i].Right)
		PutExpression(stmt.Joins[i].Condition)
		stmt.Joins[i].Condition = nil
		stmt.Joins[i].Type = ""
	}
	stmt.Joins = stmt.Joins[:0]

	// ── ARRAY JOIN (ClickHouse) ────────────────────────────────────────
	if stmt.ArrayJoin != nil {
		for i := range stmt.ArrayJoin.Elements {
			PutExpression(stmt.ArrayJoin.Elements[i].Expr)
			stmt.ArrayJoin.Elements[i].Expr = nil
			stmt.ArrayJoin.Elements[i].Alias = ""
		}
		stmt.ArrayJoin.Elements = nil
		stmt.ArrayJoin = nil
	}

	// ── PREWHERE / WHERE / HAVING / QUALIFY / START WITH ───────────────
	// PutExpression is nil-safe, so each predicate is released unconditionally.
	PutExpression(stmt.PrewhereClause)
	stmt.PrewhereClause = nil
	PutExpression(stmt.Where)
	stmt.Where = nil
	PutExpression(stmt.Having)
	stmt.Having = nil
	PutExpression(stmt.Qualify)
	stmt.Qualify = nil
	PutExpression(stmt.StartWith)
	stmt.StartWith = nil

	// ── CONNECT BY ─────────────────────────────────────────────────────
	if stmt.ConnectBy != nil {
		PutExpression(stmt.ConnectBy.Condition)
		stmt.ConnectBy.Condition = nil
		stmt.ConnectBy = nil
	}

	// ── SAMPLE (no expression children per the coverage audit, just drop) ─
	stmt.Sample = nil

	// ── GROUP BY ───────────────────────────────────────────────────────
	for i := range stmt.GroupBy {
		PutExpression(stmt.GroupBy[i])
		stmt.GroupBy[i] = nil
	}
	stmt.GroupBy = stmt.GroupBy[:0]

	// ── WINDOWS (PartitionBy, OrderBy, FrameClause bounds) ─────────────
	for i := range stmt.Windows {
		w := &stmt.Windows[i]
		for j := range w.PartitionBy {
			PutExpression(w.PartitionBy[j])
			w.PartitionBy[j] = nil
		}
		w.PartitionBy = w.PartitionBy[:0]
		for j := range w.OrderBy {
			PutExpression(w.OrderBy[j].Expression)
			w.OrderBy[j].Expression = nil
		}
		w.OrderBy = w.OrderBy[:0]
		if w.FrameClause != nil {
			// Start is a value field; End is optional (*bound).
			PutExpression(w.FrameClause.Start.Value)
			w.FrameClause.Start.Value = nil
			if w.FrameClause.End != nil {
				PutExpression(w.FrameClause.End.Value)
				w.FrameClause.End.Value = nil
				w.FrameClause.End = nil
			}
			w.FrameClause = nil
		}
		w.Name = ""
	}
	stmt.Windows = stmt.Windows[:0]

	// ── ORDER BY ───────────────────────────────────────────────────────
	for i := range stmt.OrderBy {
		PutExpression(stmt.OrderBy[i].Expression)
		stmt.OrderBy[i].Expression = nil
	}
	stmt.OrderBy = stmt.OrderBy[:0]

	// ── LIMIT / OFFSET (*int - no Expression) ──────────────────────────
	stmt.Limit = nil
	stmt.Offset = nil

	// ── FETCH / FOR (no Expression children per the coverage audit) ────
	// NOTE(review): assumes Fetch carries no pooled Count expression —
	// confirm against the Fetch struct definition.
	stmt.Fetch = nil
	stmt.For = nil

	// ── Scalars ────────────────────────────────────────────────────────
	stmt.TableName = ""
	stmt.Distinct = false

	// Return to pool
	selectStmtPool.Put(stmt)
}

// releaseTableReference releases all pooled Expression/Statement references
// reachable from a TableReference, then resets the TableReference to a
// clean zero state suitable for pool reuse.
func releaseTableReference(tr *TableReference) {
	if tr == nil {
		return
	}
	// Subquery is itself a *SelectStatement — recurse through the statement
	// dispatcher to release every nested pool reference.
	if tr.Subquery != nil {
		PutSelectStatement(tr.Subquery)
		tr.Subquery = nil
	}
	// TableFunc is a *FunctionCall — release as expression.
	if tr.TableFunc != nil {
		PutExpression(tr.TableFunc)
		tr.TableFunc = nil
	}
	// Pivot.AggregateFunction is an Expression.
	if tr.Pivot != nil {
		PutExpression(tr.Pivot.AggregateFunction)
		tr.Pivot.AggregateFunction = nil
		tr.Pivot = nil
	}
	// Unpivot holds only strings — drop the struct.
	tr.Unpivot = nil
	// MatchRecognize carries PartitionBy / OrderBy / Measures / Definitions.
	if tr.MatchRecognize != nil {
		mr := tr.MatchRecognize
		for i := range mr.PartitionBy {
			PutExpression(mr.PartitionBy[i])
			mr.PartitionBy[i] = nil
		}
		mr.PartitionBy = mr.PartitionBy[:0]
		for i := range mr.OrderBy {
			PutExpression(mr.OrderBy[i].Expression)
			mr.OrderBy[i].Expression = nil
		}
		mr.OrderBy = mr.OrderBy[:0]
		for i := range mr.Measures {
			PutExpression(mr.Measures[i].Expr)
			mr.Measures[i].Expr = nil
			mr.Measures[i].Alias = ""
		}
		mr.Measures = mr.Measures[:0]
		for i := range mr.Definitions {
			PutExpression(mr.Definitions[i].Condition)
			mr.Definitions[i].Condition = nil
			mr.Definitions[i].Name = ""
		}
		mr.Definitions = mr.Definitions[:0]
		tr.MatchRecognize = nil
	}
	// TimeTravel carries a Named map of Expressions plus Chained clauses.
	if tr.TimeTravel != nil {
		releaseTimeTravelClause(tr.TimeTravel)
		tr.TimeTravel = nil
	}
	// ForSystemTime carries Point/Start/End expressions.
	if tr.ForSystemTime != nil {
		PutExpression(tr.ForSystemTime.Point)
		PutExpression(tr.ForSystemTime.Start)
		PutExpression(tr.ForSystemTime.End)
		tr.ForSystemTime.Point = nil
		tr.ForSystemTime.Start = nil
		tr.ForSystemTime.End = nil
		tr.ForSystemTime = nil
	}
	// Remaining scalar fields.
	tr.Name = ""
	tr.Alias = ""
	tr.Lateral = false
	tr.Final = false
	tr.TableHints = nil
}

// ============================================================
// DDL Statement Pool Functions
// ============================================================

// GetCreateTableStatement gets a CreateTableStatement from the pool.
// Slices are truncated (capacity retained) rather than reallocated.
func GetCreateTableStatement() *CreateTableStatement {
	stmt := createTableStmtPool.Get().(*CreateTableStatement)
	stmt.Columns = stmt.Columns[:0]
	stmt.Constraints = stmt.Constraints[:0]
	stmt.Inherits = stmt.Inherits[:0]
	stmt.Options = stmt.Options[:0]
	return stmt
}

// PutCreateTableStatement returns a CreateTableStatement to the pool.
// It recursively releases any nested expressions (column defaults, check constraints, etc.).
+func PutCreateTableStatement(stmt *CreateTableStatement) { + if stmt == nil { + return + } + + // Release expressions embedded in column definitions + for i := range stmt.Columns { + for j := range stmt.Columns[i].Constraints { + PutExpression(stmt.Columns[i].Constraints[j].Default) + PutExpression(stmt.Columns[i].Constraints[j].Check) + stmt.Columns[i].Constraints[j].Default = nil + stmt.Columns[i].Constraints[j].Check = nil + stmt.Columns[i].Constraints[j].References = nil + } + stmt.Columns[i].Constraints = stmt.Columns[i].Constraints[:0] + stmt.Columns[i].Name = "" + stmt.Columns[i].Type = "" + } + stmt.Columns = stmt.Columns[:0] + + // Release expressions in table constraints + for i := range stmt.Constraints { + PutExpression(stmt.Constraints[i].Check) + stmt.Constraints[i].Check = nil + stmt.Constraints[i].References = nil + stmt.Constraints[i].Name = "" + stmt.Constraints[i].Type = "" + stmt.Constraints[i].Columns = stmt.Constraints[i].Columns[:0] + } + stmt.Constraints = stmt.Constraints[:0] + + // Release expressions in PartitionBy + if stmt.PartitionBy != nil { + for i, expr := range stmt.PartitionBy.Boundary { + PutExpression(expr) + stmt.PartitionBy.Boundary[i] = nil + } + stmt.PartitionBy.Boundary = stmt.PartitionBy.Boundary[:0] + stmt.PartitionBy.Columns = stmt.PartitionBy.Columns[:0] + stmt.PartitionBy.Type = "" + stmt.PartitionBy = nil + } + + // Release expressions in PartitionDefinitions + for i := range stmt.Partitions { + for j, expr := range stmt.Partitions[i].Values { + PutExpression(expr) + stmt.Partitions[i].Values[j] = nil + } + PutExpression(stmt.Partitions[i].LessThan) + PutExpression(stmt.Partitions[i].From) + PutExpression(stmt.Partitions[i].To) + for j, expr := range stmt.Partitions[i].InValues { + PutExpression(expr) + stmt.Partitions[i].InValues[j] = nil + } + stmt.Partitions[i].Values = stmt.Partitions[i].Values[:0] + stmt.Partitions[i].InValues = stmt.Partitions[i].InValues[:0] + stmt.Partitions[i].LessThan = nil + 
stmt.Partitions[i].From = nil + stmt.Partitions[i].To = nil + stmt.Partitions[i].Name = "" + stmt.Partitions[i].Type = "" + stmt.Partitions[i].Tablespace = "" + } + stmt.Partitions = stmt.Partitions[:0] + + stmt.Inherits = stmt.Inherits[:0] + + for i := range stmt.Options { + stmt.Options[i].Name = "" + stmt.Options[i].Value = "" + } + stmt.Options = stmt.Options[:0] + + // Reset scalar fields + stmt.IfNotExists = false + stmt.Temporary = false + stmt.Name = "" + + createTableStmtPool.Put(stmt) +} + +// GetAlterTableStatement gets an AlterTableStatement from the pool. +func GetAlterTableStatement() *AlterTableStatement { + stmt := alterTableStmtPool.Get().(*AlterTableStatement) + stmt.Actions = stmt.Actions[:0] + return stmt +} + +// PutAlterTableStatement returns an AlterTableStatement to the pool. +// It recursively releases nested expressions in column definitions and constraints. +func PutAlterTableStatement(stmt *AlterTableStatement) { + if stmt == nil { + return + } + + for i := range stmt.Actions { + // Release nested ColumnDef expressions + if stmt.Actions[i].ColumnDef != nil { + for j := range stmt.Actions[i].ColumnDef.Constraints { + PutExpression(stmt.Actions[i].ColumnDef.Constraints[j].Default) + PutExpression(stmt.Actions[i].ColumnDef.Constraints[j].Check) + stmt.Actions[i].ColumnDef.Constraints[j].Default = nil + stmt.Actions[i].ColumnDef.Constraints[j].Check = nil + stmt.Actions[i].ColumnDef.Constraints[j].References = nil + } + stmt.Actions[i].ColumnDef.Constraints = stmt.Actions[i].ColumnDef.Constraints[:0] + stmt.Actions[i].ColumnDef = nil + } + // Release nested TableConstraint expressions + if stmt.Actions[i].Constraint != nil { + PutExpression(stmt.Actions[i].Constraint.Check) + stmt.Actions[i].Constraint.Check = nil + stmt.Actions[i].Constraint = nil + } + stmt.Actions[i].Type = "" + stmt.Actions[i].ColumnName = "" + } + stmt.Actions = stmt.Actions[:0] + stmt.Table = "" + + alterTableStmtPool.Put(stmt) +} + +// GetMergeStatement gets a 
MergeStatement from the pool. +func GetMergeStatement() *MergeStatement { + stmt := mergeStmtPool.Get().(*MergeStatement) + stmt.WhenClauses = stmt.WhenClauses[:0] + stmt.Output = stmt.Output[:0] + return stmt +} + +// PutMergeStatement returns a MergeStatement to the pool. +// It recursively releases nested expressions in WHEN clauses and OUTPUT. +func PutMergeStatement(stmt *MergeStatement) { + if stmt == nil { + return + } + + // Release OnCondition + PutExpression(stmt.OnCondition) + stmt.OnCondition = nil + + // Release WHEN clause expressions + for i := range stmt.WhenClauses { + if stmt.WhenClauses[i] == nil { + continue + } + PutExpression(stmt.WhenClauses[i].Condition) + stmt.WhenClauses[i].Condition = nil + if stmt.WhenClauses[i].Action != nil { + for j := range stmt.WhenClauses[i].Action.SetClauses { + PutExpression(stmt.WhenClauses[i].Action.SetClauses[j].Value) + stmt.WhenClauses[i].Action.SetClauses[j].Value = nil + stmt.WhenClauses[i].Action.SetClauses[j].Column = "" + } + stmt.WhenClauses[i].Action.SetClauses = stmt.WhenClauses[i].Action.SetClauses[:0] + for j, expr := range stmt.WhenClauses[i].Action.Values { + PutExpression(expr) + stmt.WhenClauses[i].Action.Values[j] = nil + } + stmt.WhenClauses[i].Action.Values = stmt.WhenClauses[i].Action.Values[:0] + stmt.WhenClauses[i].Action.Columns = stmt.WhenClauses[i].Action.Columns[:0] + stmt.WhenClauses[i].Action.ActionType = "" + stmt.WhenClauses[i].Action.DefaultValues = false + stmt.WhenClauses[i].Action = nil + } + stmt.WhenClauses[i].Type = "" + stmt.WhenClauses[i] = nil + } + stmt.WhenClauses = stmt.WhenClauses[:0] + + // Release OUTPUT expressions + for i, expr := range stmt.Output { + PutExpression(expr) + stmt.Output[i] = nil + } + stmt.Output = stmt.Output[:0] + + // Reset TargetTable / SourceTable (value types - zero them out) + stmt.TargetTable = TableReference{} + stmt.SourceTable = TableReference{} + stmt.TargetAlias = "" + stmt.SourceAlias = "" + + mergeStmtPool.Put(stmt) +} + +// 
GetReplaceStatement gets a ReplaceStatement from the pool. +func GetReplaceStatement() *ReplaceStatement { + stmt := replaceStmtPool.Get().(*ReplaceStatement) + stmt.Columns = stmt.Columns[:0] + stmt.Values = stmt.Values[:0] + return stmt +} + +// PutReplaceStatement returns a ReplaceStatement to the pool. +// It recursively releases nested column and value expressions. +func PutReplaceStatement(stmt *ReplaceStatement) { + if stmt == nil { + return + } + + for i := range stmt.Columns { + PutExpression(stmt.Columns[i]) + stmt.Columns[i] = nil + } + stmt.Columns = stmt.Columns[:0] + + for i := range stmt.Values { + for j := range stmt.Values[i] { + PutExpression(stmt.Values[i][j]) + stmt.Values[i][j] = nil + } + stmt.Values[i] = stmt.Values[i][:0] + } + stmt.Values = stmt.Values[:0] + + stmt.TableName = "" + + replaceStmtPool.Put(stmt) +} + +// releaseTimeTravelClause walks a TimeTravelClause graph, releasing every +// Expression stored in Named maps and every chained sub-clause. Chained +// cycles are not possible because the parser builds a tree, but we still +// guard against nil to be defensive. +func releaseTimeTravelClause(c *TimeTravelClause) { + if c == nil { + return + } + for k, v := range c.Named { + PutExpression(v) + delete(c.Named, k) + } + for _, ch := range c.Chained { + releaseTimeTravelClause(ch) + } + c.Chained = nil + c.Named = nil + c.Kind = "" +}