diff --git a/pkg/sql/parser/performance_regression_norace.go b/pkg/sql/parser/performance_regression_norace.go index 7eb3ee7a..8ec4a064 100644 --- a/pkg/sql/parser/performance_regression_norace.go +++ b/pkg/sql/parser/performance_regression_norace.go @@ -1,5 +1,4 @@ //go:build !race -// +build !race package parser diff --git a/pkg/sql/parser/performance_regression_race.go b/pkg/sql/parser/performance_regression_race.go index f53c8545..8de2c614 100644 --- a/pkg/sql/parser/performance_regression_race.go +++ b/pkg/sql/parser/performance_regression_race.go @@ -1,5 +1,4 @@ //go:build race -// +build race package parser diff --git a/pkg/sql/parser/recovery.go b/pkg/sql/parser/recovery.go new file mode 100644 index 00000000..7eb76202 --- /dev/null +++ b/pkg/sql/parser/recovery.go @@ -0,0 +1,122 @@ +package parser + +import ( + "fmt" + + "github.com/ajitpratap0/GoSQLX/pkg/models" + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/token" +) + +// ParseError represents a parse error with position information. +type ParseError struct { + Msg string + TokenIdx int + Line int + Column int + TokenType string + Literal string +} + +func (e *ParseError) Error() string { + if e.Line > 0 { + return fmt.Sprintf("parse error at line %d, column %d (token %d): %s", e.Line, e.Column, e.TokenIdx, e.Msg) + } + return fmt.Sprintf("parse error at token %d: %s", e.TokenIdx, e.Msg) +} + +// isStatementStartingKeyword checks if the current token is a statement-starting keyword. +func (p *Parser) isStatementStartingKeyword() bool { + if p.currentToken.ModelType != modelTypeUnset { + switch p.currentToken.ModelType { + case models.TokenTypeSelect, models.TokenTypeInsert, models.TokenTypeUpdate, + models.TokenTypeDelete, models.TokenTypeCreate, models.TokenTypeAlter, + models.TokenTypeDrop, models.TokenTypeWith, models.TokenTypeMerge, + models.TokenTypeRefresh, models.TokenTypeTruncate: + return true + } + } + // Fallback: string comparison for tokens without ModelType (e.g., tests) + switch string(p.currentToken.Type) { + case "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", + "WITH", "MERGE", "REFRESH", "TRUNCATE": + return true + } + return false +} + +// synchronize advances the parser past the current error to a synchronization point: +// either past a semicolon or to a statement-starting keyword. +func (p *Parser) synchronize() { + for p.currentPos < len(p.tokens) && !p.isType(models.TokenTypeEOF) { + // If we hit a semicolon, consume it and stop + if p.isType(models.TokenTypeSemicolon) { + p.advance() + return + } + // If we hit a statement-starting keyword, stop (don't consume it) + if p.isStatementStartingKeyword() { + return + } + p.advance() + } +} + +// ParseWithRecovery parses a token stream, recovering from errors to collect multiple +// errors and return a partial AST with successfully parsed statements. +// +// Unlike Parse(), which stops at the first error, this method uses synchronization +// tokens (semicolons and statement-starting keywords) to skip past errors and +// continue parsing subsequent statements. +// +// Parameters: +// - tokens: Slice of parser tokens to parse +// +// Returns: +// - []ast.Statement: Successfully parsed statements (may be empty) +// - []error: All parse errors encountered (each includes position information) +func (p *Parser) ParseWithRecovery(tokens []token.Token) ([]ast.Statement, []error) { + p.tokens = tokens + p.currentPos = 0 + if len(tokens) > 0 { + p.currentToken = tokens[0] + } + + var statements []ast.Statement + var errors []error + + for p.currentPos < len(tokens) && !p.isType(models.TokenTypeEOF) { + // Skip semicolons between statements + if p.isType(models.TokenTypeSemicolon) { + p.advance() + continue + } + + savedPos := p.currentPos + stmt, err := p.parseStatement() + if err != nil { + // Create a ParseError with position info + loc := p.currentLocation() + pe := &ParseError{ + Msg: err.Error(), + TokenIdx: savedPos, + Line: loc.Line, + Column: loc.Column, + } + if savedPos < len(tokens) { + pe.TokenType = string(tokens[savedPos].Type) + pe.Literal = tokens[savedPos].Literal + } + errors = append(errors, pe) + p.synchronize() + } else { + statements = append(statements, stmt) + // Optionally consume semicolon after statement + if p.isType(models.TokenTypeSemicolon) { + p.advance() + } + } + } + + return statements, errors +} diff --git a/pkg/sql/parser/recovery_multi_error_test.go b/pkg/sql/parser/recovery_multi_error_test.go new file mode 100644 index 00000000..085b3946 --- /dev/null +++ b/pkg/sql/parser/recovery_multi_error_test.go @@ -0,0 +1,177 @@ +package parser + +import ( + "testing" + + "github.com/ajitpratap0/GoSQLX/pkg/sql/token" +) + +func eof() token.Token { + return token.Token{Type: "EOF", Literal: ""} +} + +func semi() token.Token { + return token.Token{Type: ";", Literal: ";"} +} + +func tok(typ, lit string) token.Token { + return token.Token{Type: token.Type(typ), Literal: lit} +} + +// TestParseWithRecovery_MultipleErrors tests that multiple syntax errors are all reported. +func TestParseWithRecovery_MultipleErrors(t *testing.T) { + // "INVALID1 foo; INVALID2 bar;" + tokens := []token.Token{ + tok("IDENT", "INVALID1"), tok("IDENT", "foo"), semi(), + tok("IDENT", "INVALID2"), tok("IDENT", "bar"), semi(), + eof(), + } + p := NewParser() + stmts, errs := p.ParseWithRecovery(tokens) + if len(stmts) != 0 { + t.Errorf("expected 0 statements, got %d", len(stmts)) + } + if len(errs) < 2 { + t.Errorf("expected at least 2 errors, got %d", len(errs)) + } + // Each error should be a *ParseError with position info + for i, err := range errs { + if _, ok := err.(*ParseError); !ok { + t.Errorf("error %d is not a *ParseError: %T", i, err) + } + } +} + +// TestParseWithRecovery_FirstValidSecondInvalid tests partial AST with valid+invalid mix. +func TestParseWithRecovery_FirstValidSecondInvalid(t *testing.T) { + // "SELECT * FROM users; INVALID foo;" + tokens := []token.Token{ + tok("SELECT", "SELECT"), tok("*", "*"), tok("FROM", "FROM"), tok("IDENT", "users"), semi(), + tok("IDENT", "INVALID"), tok("IDENT", "foo"), semi(), + eof(), + } + p := NewParser() + stmts, errs := p.ParseWithRecovery(tokens) + if len(stmts) != 1 { + t.Errorf("expected 1 statement, got %d", len(stmts)) + } + if len(errs) != 1 { + t.Errorf("expected 1 error, got %d", len(errs)) + } +} + +// TestParseWithRecovery_AllInvalid tests that all-invalid input returns empty AST + multiple errors. +func TestParseWithRecovery_AllInvalid(t *testing.T) { + tokens := []token.Token{ + tok("IDENT", "BAD1"), semi(), + tok("IDENT", "BAD2"), semi(), + tok("IDENT", "BAD3"), semi(), + eof(), + } + p := NewParser() + stmts, errs := p.ParseWithRecovery(tokens) + if len(stmts) != 0 { + t.Errorf("expected 0 statements, got %d", len(stmts)) + } + if len(errs) != 3 { + t.Errorf("expected 3 errors, got %d", len(errs)) + } +} + +// TestParseWithRecovery_UnclosedParen tests recovery after unclosed parenthesis. +func TestParseWithRecovery_UnclosedParen(t *testing.T) { + // "SELECT (1 + ; SELECT * FROM users;" + tokens := []token.Token{ + tok("SELECT", "SELECT"), tok("(", "("), tok("INT", "1"), tok("+", "+"), semi(), + tok("SELECT", "SELECT"), tok("*", "*"), tok("FROM", "FROM"), tok("IDENT", "users"), semi(), + eof(), + } + p := NewParser() + stmts, errs := p.ParseWithRecovery(tokens) + if len(errs) < 1 { + t.Errorf("expected at least 1 error, got %d", len(errs)) + } + if len(stmts) < 1 { + t.Errorf("expected at least 1 successfully parsed statement, got %d", len(stmts)) + } +} + +// TestParseWithRecovery_InvalidExpression tests recovery after invalid expression. +func TestParseWithRecovery_InvalidExpression(t *testing.T) { + // "SELECT FROM; SELECT 1;" + // First SELECT has no columns (invalid), second is valid + tokens := []token.Token{ + tok("SELECT", "SELECT"), tok("FROM", "FROM"), semi(), + tok("SELECT", "SELECT"), tok("INT", "1"), semi(), + eof(), + } + p := NewParser() + stmts, errs := p.ParseWithRecovery(tokens) + // The first SELECT FROM might parse differently depending on parser internals, + // but we should get at least one statement or error + if len(stmts)+len(errs) < 2 { + t.Errorf("expected at least 2 total statements+errors, got stmts=%d errs=%d", len(stmts), len(errs)) + } +} + +// TestParseWithRecovery_RecoveryToKeyword tests recovery skipping to next statement keyword. +func TestParseWithRecovery_RecoveryToKeyword(t *testing.T) { + // "INVALID foo bar SELECT * FROM users;" + // No semicolon after invalid part, should recover at SELECT keyword + tokens := []token.Token{ + tok("IDENT", "INVALID"), tok("IDENT", "foo"), tok("IDENT", "bar"), + tok("SELECT", "SELECT"), tok("*", "*"), tok("FROM", "FROM"), tok("IDENT", "users"), semi(), + eof(), + } + p := NewParser() + stmts, errs := p.ParseWithRecovery(tokens) + if len(errs) != 1 { + t.Errorf("expected 1 error, got %d", len(errs)) + } + if len(stmts) != 1 { + t.Errorf("expected 1 statement, got %d", len(stmts)) + } +} + +// TestParseWithRecovery_EmptyInput tests empty token stream. +func TestParseWithRecovery_EmptyInput(t *testing.T) { + tokens := []token.Token{eof()} + p := NewParser() + stmts, errs := p.ParseWithRecovery(tokens) + if len(stmts) != 0 { + t.Errorf("expected 0 statements, got %d", len(stmts)) + } + if len(errs) != 0 { + t.Errorf("expected 0 errors, got %d", len(errs)) + } +} + +// TestParseWithRecovery_AllValid tests that all-valid input returns all statements. +func TestParseWithRecovery_AllValid(t *testing.T) { + tokens := []token.Token{ + tok("SELECT", "SELECT"), tok("INT", "1"), semi(), + tok("SELECT", "SELECT"), tok("INT", "2"), semi(), + eof(), + } + p := NewParser() + stmts, errs := p.ParseWithRecovery(tokens) + if len(stmts) != 2 { + t.Errorf("expected 2 statements, got %d", len(stmts)) + } + if len(errs) != 0 { + t.Errorf("expected 0 errors, got %d", len(errs)) + } +} + +// TestParseError_ErrorMessage tests ParseError formatting. +func TestParseError_ErrorMessage(t *testing.T) { + e := &ParseError{Msg: "unexpected token", TokenIdx: 5} + if e.Error() != "parse error at token 5: unexpected token" { + t.Errorf("unexpected error message: %s", e.Error()) + } + + e2 := &ParseError{Msg: "bad syntax", TokenIdx: 3, Line: 2, Column: 10} + if e2.Error() != "parse error at line 2, column 10 (token 3): bad syntax" { + t.Errorf("unexpected error message: %s", e2.Error()) + } +} diff --git a/pkg/sql/tokenizer/norace.go b/pkg/sql/tokenizer/norace.go index 03f8f581..94e3829a 100644 --- a/pkg/sql/tokenizer/norace.go +++ b/pkg/sql/tokenizer/norace.go @@ -1,5 +1,4 @@ //go:build !race -// +build !race package tokenizer diff --git a/pkg/sql/tokenizer/race.go b/pkg/sql/tokenizer/race.go index 99b1ff7b..f1aae592 100644 --- a/pkg/sql/tokenizer/race.go +++ b/pkg/sql/tokenizer/race.go @@ -1,5 +1,4 @@ //go:build race -// +build race package tokenizer