diff --git a/candle-binding/regex_provider.go b/candle-binding/regex_provider.go new file mode 100644 index 000000000..14baa2040 --- /dev/null +++ b/candle-binding/regex_provider.go @@ -0,0 +1,149 @@ +package candle_binding + +import ( + "context" + "fmt" + "regexp" + "strings" + "time" +) + +// RegexProviderConfig holds the configuration for the regex provider. +type RegexProviderConfig struct { + MaxPatterns int `yaml:"max_patterns"` + MaxPatternLength int `yaml:"max_pattern_length"` + MaxInputLength int `yaml:"max_input_length"` + DefaultTimeoutMs int `yaml:"default_timeout_ms"` + Patterns []RegexPattern `yaml:"patterns"` +} + +// RegexPattern defines a single regex pattern. +type RegexPattern struct { + ID string `yaml:"id"` + Pattern string `yaml:"pattern"` + Flags string `yaml:"flags"` + Category string `yaml:"category"` +} + +// RegexProvider is a ReDoS-safe regex scanner. +// It uses Go's built-in regexp package, which is based on RE2 and is not +// vulnerable to regular expression denial of service attacks. +type RegexProvider struct { + compiled []*regexp.Regexp + patterns []RegexPattern + timeout time.Duration + maxInputLength int + testDelay time.Duration // For testing purposes +} + +// MatchResult represents a single regex match. +type MatchResult struct { + PatternID string + Category string + Match string + StartIndex int + EndIndex int +} + +// NewRegexProvider creates a new RegexProvider. +func NewRegexProvider(cfg RegexProviderConfig, options ...func(*RegexProvider)) (*RegexProvider, error) { + if len(cfg.Patterns) > cfg.MaxPatterns { + return nil, fmt.Errorf("number of patterns (%d) exceeds max_patterns (%d)", len(cfg.Patterns), cfg.MaxPatterns) + } + + compiled := make([]*regexp.Regexp, 0, len(cfg.Patterns)) + for _, p := range cfg.Patterns { + if len(p.Pattern) > cfg.MaxPatternLength { + return nil, fmt.Errorf("pattern length for ID '%s' (%d) exceeds max_pattern_length (%d)", p.ID, len(p.Pattern), cfg.MaxPatternLength) + } + + pattern := p.Pattern + if strings.Contains(p.Flags, "i") { + pattern = "(?i)" + pattern + } + + re, err := regexp.Compile(pattern) + if err != nil { + return nil, fmt.Errorf("failed to compile pattern ID '%s': %w", p.ID, err) + } + compiled = append(compiled, re) + } + + rp := &RegexProvider{ + compiled: compiled, + patterns: cfg.Patterns, + timeout: time.Duration(cfg.DefaultTimeoutMs) * time.Millisecond, + maxInputLength: cfg.MaxInputLength, + } + + for _, option := range options { + option(rp) + } + + return rp, nil +} + +// WithTestDelay is a functional option to add a delay for testing timeouts. +func WithTestDelay(d time.Duration) func(*RegexProvider) { + return func(rp *RegexProvider) { + rp.testDelay = d + } +} + +// Scan scans the input string for matches. +// The scan is performed in a separate goroutine and is subject to a timeout. +// The timeout check is performed between each pattern, so a single very slow +// pattern can still block for longer than the timeout. However, Go's regex +// engine is very fast and not vulnerable to ReDoS, so this is not a major +// concern in practice. +func (rp *RegexProvider) Scan(input string) ([]MatchResult, error) { + if len(input) > rp.maxInputLength { + return nil, fmt.Errorf("input length (%d) exceeds max_input_length (%d)", len(input), rp.maxInputLength) + } + + ctx, cancel := context.WithTimeout(context.Background(), rp.timeout) + defer cancel() + + resultChan := make(chan struct { + matches []MatchResult + err error + }, 1) + + go func() { + var matches []MatchResult + for i, re := range rp.compiled { + select { + case <-ctx.Done(): + // The context was cancelled, so we don't need to continue. + return + default: + // Introduce a delay for testing purposes + if rp.testDelay > 0 { + time.Sleep(rp.testDelay) + } + + locs := re.FindAllStringIndex(input, -1) + for _, loc := range locs { + matches = append(matches, MatchResult{ + PatternID: rp.patterns[i].ID, + Category: rp.patterns[i].Category, + Match: input[loc[0]:loc[1]], + StartIndex: loc[0], + EndIndex: loc[1], + }) + } + } + } + resultChan <- struct { + matches []MatchResult + err error + }{matches, nil} + }() + + select { + case res := <-resultChan: + return res.matches, res.err + case <-ctx.Done(): + return nil, fmt.Errorf("regex scan timed out after %v", rp.timeout) + } +} diff --git a/candle-binding/regex_provider_bench_test.go b/candle-binding/regex_provider_bench_test.go new file mode 100644 index 000000000..800534a08 --- /dev/null +++ b/candle-binding/regex_provider_bench_test.go @@ -0,0 +1,58 @@ +package candle_binding + +import ( + "fmt" + "testing" +) + +func BenchmarkRegexProvider_Scan(b *testing.B) { + cfg := RegexProviderConfig{ + MaxPatterns: 100, + MaxPatternLength: 1000, + MaxInputLength: 10000, + DefaultTimeoutMs: 1000, + Patterns: []RegexPattern{ + {ID: "email", Pattern: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`}, + {ID: "word", Pattern: "hello"}, + {ID: "case", Pattern: "World", Flags: "i"}, + }, + } + rp, err := NewRegexProvider(cfg) + if err != nil { + b.Fatalf("failed to create regex provider: %v", err) + } + + input := "my email is test@example.com, say hello to the beautiful World" + + b.Run("SinglePattern", func(b *testing.B) { + singlePatternCfg := RegexProviderConfig{ + MaxPatterns: 1, + MaxPatternLength: 100, + MaxInputLength: 1000, + DefaultTimeoutMs: 100, + Patterns: []RegexPattern{ + {ID: "email", Pattern: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`}, + }, + } + singleRp, _ := NewRegexProvider(singlePatternCfg) + for i := 0; i < b.N; i++ { + _, _ = singleRp.Scan(input) + } + }) + + b.Run("MultiPattern", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = rp.Scan(input) + } + }) + + b.Run("LargeInput", func(b *testing.B) { + largeInput := "" + for i := 0; i < 100; i++ { + largeInput += fmt.Sprintf("email%d@example.com ", i) + } + for i := 0; i < b.N; i++ { + _, _ = rp.Scan(largeInput) + } + }) +} diff --git a/candle-binding/regex_provider_test.go b/candle-binding/regex_provider_test.go new file mode 100644 index 000000000..6d2473896 --- /dev/null +++ b/candle-binding/regex_provider_test.go @@ -0,0 +1,190 @@ +package candle_binding + +import ( + "strings" + "testing" + "time" +) + +func TestNewRegexProvider(t *testing.T) { + t.Run("ValidConfig", func(t *testing.T) { + cfg := RegexProviderConfig{ + MaxPatterns: 10, + MaxPatternLength: 100, + MaxInputLength: 1000, + DefaultTimeoutMs: 50, + Patterns: []RegexPattern{ + {ID: "email", Pattern: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`}, + }, + } + _, err := NewRegexProvider(cfg) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + }) + + t.Run("TooManyPatterns", func(t *testing.T) { + cfg := RegexProviderConfig{ + MaxPatterns: 1, + Patterns: []RegexPattern{ + {ID: "p1", Pattern: "a"}, + {ID: "p2", Pattern: "b"}, + }, + } + _, err := NewRegexProvider(cfg) + if err == nil { + t.Fatal("expected an error for too many patterns, got nil") + } + }) + + t.Run("PatternTooLong", func(t *testing.T) { + cfg := RegexProviderConfig{ + MaxPatterns: 10, + MaxPatternLength: 5, + Patterns: []RegexPattern{ + {ID: "long", Pattern: "abcdef"}, + }, + } + _, err := NewRegexProvider(cfg) + if err == nil { + t.Fatal("expected an error for pattern too long, got nil") + } + }) + + t.Run("InvalidRegex", func(t *testing.T) { + cfg := RegexProviderConfig{ + MaxPatterns: 10, + MaxPatternLength: 100, + Patterns: []RegexPattern{ + {ID: "invalid", Pattern: `[`}, + }, + } + _, err := NewRegexProvider(cfg) + if err == nil { + t.Fatal("expected an error for invalid regex, got nil") + } + }) +} + +func TestRegexProvider_Scan(t *testing.T) { + cfg := RegexProviderConfig{ + MaxPatterns: 10, + MaxPatternLength: 100, + MaxInputLength: 1000, + DefaultTimeoutMs: 100, + Patterns: []RegexPattern{ + {ID: "email", Pattern: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`, Category: "pii"}, + {ID: "word", Pattern: "hello", Category: "greeting"}, + {ID: "case", Pattern: "World", Flags: "i", Category: "case-test"}, + }, + } + rp, err := NewRegexProvider(cfg) + if err != nil { + t.Fatalf("failed to create regex provider: %v", err) + } + + t.Run("SimpleMatch", func(t *testing.T) { + input := "say hello to the world" + matches, err := rp.Scan(input) + if err != nil { + t.Fatalf("scan failed: %v", err) + } + if len(matches) != 2 { // "hello" and "world" (case-insensitive) + t.Fatalf("expected 2 matches, got %d", len(matches)) + } + }) + + t.Run("CaseInsensitiveMatch", func(t *testing.T) { + input := "hello WORLD" + matches, err := rp.Scan(input) + if err != nil { + t.Fatalf("scan failed: %v", err) + } + if len(matches) != 2 { + t.Fatalf("expected 2 matches, got %d", len(matches)) + } + }) + + t.Run("MultipleMatches", func(t *testing.T) { + input := "my email is test@example.com, say hello" + matches, err := rp.Scan(input) + if err != nil { + t.Fatalf("scan failed: %v", err) + } + if len(matches) != 2 { + t.Fatalf("expected 2 matches, got %d", len(matches)) + } + }) + + t.Run("NoMatch", func(t *testing.T) { + input := "nothing to see here" + matches, err := rp.Scan(input) + if err != nil { + t.Fatalf("scan failed: %v", err) + } + if len(matches) != 0 { + t.Fatalf("expected 0 matches, got %d", len(matches)) + } + }) + + t.Run("InputTooLong", func(t *testing.T) { + rp.maxInputLength = 5 + _, err := rp.Scan("abcdef") + if err == nil { + t.Fatal("expected an error for input too long, got nil") + } + rp.maxInputLength = 1000 // reset + }) + + t.Run("Timeout", func(t *testing.T) { + cfg := RegexProviderConfig{ + MaxPatterns: 1, + MaxPatternLength: 100, + MaxInputLength: 1000, + DefaultTimeoutMs: 10, // 10ms + Patterns: []RegexPattern{ + {ID: "any", Pattern: `.`}, + }, + } + // Create a provider with a 20ms delay, which is longer than the timeout + rp, err := NewRegexProvider(cfg, WithTestDelay(20*time.Millisecond)) + if err != nil { + t.Fatalf("failed to create regex provider: %v", err) + } + + _, err = rp.Scan("a") + if err == nil { + t.Fatal("expected a timeout error, got nil") + } + if !strings.Contains(err.Error(), "timed out") { + t.Errorf("expected timeout error, got: %v", err) + } + }) + + t.Run("ReDoSAttackVector", func(t *testing.T) { + // This pattern is a known ReDoS vector for backtracking regex engines. + // Go's engine is not vulnerable, so this should execute quickly. + cfg := RegexProviderConfig{ + MaxPatterns: 1, + MaxPatternLength: 100, + MaxInputLength: 1000, + DefaultTimeoutMs: 500, // 500ms timeout + Patterns: []RegexPattern{ + {ID: "redos", Pattern: `(a+)+$`}, + }, + } + rp, err := NewRegexProvider(cfg) + if err != nil { + t.Fatalf("failed to create regex provider: %v", err) + } + + // A long string of 'a's followed by a non-matching character. + // In a vulnerable engine, this would cause catastrophic backtracking. + input := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" + + _, err = rp.Scan(input) + if err != nil { + t.Fatalf("scan failed for ReDoS pattern: %v", err) + } + }) +}