Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions candle-binding/regex_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package candle_binding

import (
"context"
"fmt"
"regexp"
"strings"
"time"
)

// RegexProviderConfig holds the configuration for the regex provider.
type RegexProviderConfig struct {
MaxPatterns int `yaml:"max_patterns"`
MaxPatternLength int `yaml:"max_pattern_length"`
MaxInputLength int `yaml:"max_input_length"`
DefaultTimeoutMs int `yaml:"default_timeout_ms"`
Patterns []RegexPattern `yaml:"patterns"`
}

// RegexPattern defines a single regex pattern.
type RegexPattern struct {
ID string `yaml:"id"`
Pattern string `yaml:"pattern"`
Flags string `yaml:"flags"`
Category string `yaml:"category"`
}

// RegexProvider is a ReDoS-safe regex scanner.
// It uses Go's built-in regexp package, which is based on RE2 and is not
// vulnerable to regular expression denial of service attacks.
type RegexProvider struct {
compiled []*regexp.Regexp
patterns []RegexPattern
timeout time.Duration
maxInputLength int
testDelay time.Duration // For testing purposes
}

// MatchResult represents a single regex match.
type MatchResult struct {
PatternID string
Category string
Match string
StartIndex int
EndIndex int
}

// NewRegexProvider creates a new RegexProvider.
func NewRegexProvider(cfg RegexProviderConfig, options ...func(*RegexProvider)) (*RegexProvider, error) {
if len(cfg.Patterns) > cfg.MaxPatterns {
return nil, fmt.Errorf("number of patterns (%d) exceeds max_patterns (%d)", len(cfg.Patterns), cfg.MaxPatterns)
}

compiled := make([]*regexp.Regexp, 0, len(cfg.Patterns))
for _, p := range cfg.Patterns {
if len(p.Pattern) > cfg.MaxPatternLength {
return nil, fmt.Errorf("pattern length for ID '%s' (%d) exceeds max_pattern_length (%d)", p.ID, len(p.Pattern), cfg.MaxPatternLength)
}

pattern := p.Pattern
if strings.Contains(p.Flags, "i") {
pattern = "(?i)" + pattern
}

re, err := regexp.Compile(pattern)
if err != nil {
return nil, fmt.Errorf("failed to compile pattern ID '%s': %w", p.ID, err)
}
compiled = append(compiled, re)
}

rp := &RegexProvider{
compiled: compiled,
patterns: cfg.Patterns,
timeout: time.Duration(cfg.DefaultTimeoutMs) * time.Millisecond,
maxInputLength: cfg.MaxInputLength,
}

for _, option := range options {
option(rp)
}

return rp, nil
}

// WithTestDelay is a functional option to add a delay for testing timeouts.
func WithTestDelay(d time.Duration) func(*RegexProvider) {
return func(rp *RegexProvider) {
rp.testDelay = d
}
}

// Scan scans the input string for matches.
// The scan is performed in a separate goroutine and is subject to a timeout.
// The timeout check is performed between each pattern, so a single very slow
// pattern can still block for longer than the timeout. However, Go's regex
// engine is very fast and not vulnerable to ReDoS, so this is not a major
// concern in practice.
func (rp *RegexProvider) Scan(input string) ([]MatchResult, error) {
if len(input) > rp.maxInputLength {
return nil, fmt.Errorf("input length (%d) exceeds max_input_length (%d)", len(input), rp.maxInputLength)
}

ctx, cancel := context.WithTimeout(context.Background(), rp.timeout)
defer cancel()

resultChan := make(chan struct {
matches []MatchResult
err error
}, 1)

go func() {
var matches []MatchResult
for i, re := range rp.compiled {
select {
case <-ctx.Done():
// The context was cancelled, so we don't need to continue.
return
default:
// Introduce a delay for testing purposes
if rp.testDelay > 0 {
time.Sleep(rp.testDelay)
}

locs := re.FindAllStringIndex(input, -1)
for _, loc := range locs {
matches = append(matches, MatchResult{
PatternID: rp.patterns[i].ID,
Category: rp.patterns[i].Category,
Match: input[loc[0]:loc[1]],
StartIndex: loc[0],
EndIndex: loc[1],
})
}
}
}
resultChan <- struct {
matches []MatchResult
err error
}{matches, nil}
}()

select {
case res := <-resultChan:
return res.matches, res.err
case <-ctx.Done():
return nil, fmt.Errorf("regex scan timed out after %v", rp.timeout)
}
}
58 changes: 58 additions & 0 deletions candle-binding/regex_provider_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package candle_binding

import (
"fmt"
"testing"
)

func BenchmarkRegexProvider_Scan(b *testing.B) {
cfg := RegexProviderConfig{
MaxPatterns: 100,
MaxPatternLength: 1000,
MaxInputLength: 10000,
DefaultTimeoutMs: 1000,
Patterns: []RegexPattern{
{ID: "email", Pattern: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`},
{ID: "word", Pattern: "hello"},
{ID: "case", Pattern: "World", Flags: "i"},
},
}
rp, err := NewRegexProvider(cfg)
if err != nil {
b.Fatalf("failed to create regex provider: %v", err)
}

input := "my email is test@example.com, say hello to the beautiful World"

b.Run("SinglePattern", func(b *testing.B) {
singlePatternCfg := RegexProviderConfig{
MaxPatterns: 1,
MaxPatternLength: 100,
MaxInputLength: 1000,
DefaultTimeoutMs: 100,
Patterns: []RegexPattern{
{ID: "email", Pattern: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`},
},
}
singleRp, _ := NewRegexProvider(singlePatternCfg)
for i := 0; i < b.N; i++ {
_, _ = singleRp.Scan(input)
}
})

b.Run("MultiPattern", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _ = rp.Scan(input)
}
})

b.Run("LargeInput", func(b *testing.B) {
largeInput := ""
for i := 0; i < 100; i++ {
largeInput += fmt.Sprintf("email%d@example.com ", i)
}
for i := 0; i < b.N; i++ {
_, _ = rp.Scan(largeInput)
}
})
}
190 changes: 190 additions & 0 deletions candle-binding/regex_provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
package candle_binding

import (
"strings"
"testing"
"time"
)

func TestNewRegexProvider(t *testing.T) {
t.Run("ValidConfig", func(t *testing.T) {
cfg := RegexProviderConfig{
MaxPatterns: 10,
MaxPatternLength: 100,
MaxInputLength: 1000,
DefaultTimeoutMs: 50,
Patterns: []RegexPattern{
{ID: "email", Pattern: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`},
},
}
_, err := NewRegexProvider(cfg)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
})

t.Run("TooManyPatterns", func(t *testing.T) {
cfg := RegexProviderConfig{
MaxPatterns: 1,
Patterns: []RegexPattern{
{ID: "p1", Pattern: "a"},
{ID: "p2", Pattern: "b"},
},
}
_, err := NewRegexProvider(cfg)
if err == nil {
t.Fatal("expected an error for too many patterns, got nil")
}
})

t.Run("PatternTooLong", func(t *testing.T) {
cfg := RegexProviderConfig{
MaxPatterns: 10,
MaxPatternLength: 5,
Patterns: []RegexPattern{
{ID: "long", Pattern: "abcdef"},
},
}
_, err := NewRegexProvider(cfg)
if err == nil {
t.Fatal("expected an error for pattern too long, got nil")
}
})

t.Run("InvalidRegex", func(t *testing.T) {
cfg := RegexProviderConfig{
MaxPatterns: 10,
MaxPatternLength: 100,
Patterns: []RegexPattern{
{ID: "invalid", Pattern: `[`},
},
}
_, err := NewRegexProvider(cfg)
if err == nil {
t.Fatal("expected an error for invalid regex, got nil")
}
})
}

func TestRegexProvider_Scan(t *testing.T) {
cfg := RegexProviderConfig{
MaxPatterns: 10,
MaxPatternLength: 100,
MaxInputLength: 1000,
DefaultTimeoutMs: 100,
Patterns: []RegexPattern{
{ID: "email", Pattern: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`, Category: "pii"},
{ID: "word", Pattern: "hello", Category: "greeting"},
{ID: "case", Pattern: "World", Flags: "i", Category: "case-test"},
},
}
rp, err := NewRegexProvider(cfg)
if err != nil {
t.Fatalf("failed to create regex provider: %v", err)
}

t.Run("SimpleMatch", func(t *testing.T) {
input := "say hello to the world"
matches, err := rp.Scan(input)
if err != nil {
t.Fatalf("scan failed: %v", err)
}
if len(matches) != 2 { // "hello" and "world" (case-insensitive)
t.Fatalf("expected 2 matches, got %d", len(matches))
}
})

t.Run("CaseInsensitiveMatch", func(t *testing.T) {
input := "hello WORLD"
matches, err := rp.Scan(input)
if err != nil {
t.Fatalf("scan failed: %v", err)
}
if len(matches) != 2 {
t.Fatalf("expected 2 matches, got %d", len(matches))
}
})

t.Run("MultipleMatches", func(t *testing.T) {
input := "my email is test@example.com, say hello"
matches, err := rp.Scan(input)
if err != nil {
t.Fatalf("scan failed: %v", err)
}
if len(matches) != 2 {
t.Fatalf("expected 2 matches, got %d", len(matches))
}
})

t.Run("NoMatch", func(t *testing.T) {
input := "nothing to see here"
matches, err := rp.Scan(input)
if err != nil {
t.Fatalf("scan failed: %v", err)
}
if len(matches) != 0 {
t.Fatalf("expected 0 matches, got %d", len(matches))
}
})

t.Run("InputTooLong", func(t *testing.T) {
rp.maxInputLength = 5
_, err := rp.Scan("abcdef")
if err == nil {
t.Fatal("expected an error for input too long, got nil")
}
rp.maxInputLength = 1000 // reset
})

t.Run("Timeout", func(t *testing.T) {
cfg := RegexProviderConfig{
MaxPatterns: 1,
MaxPatternLength: 100,
MaxInputLength: 1000,
DefaultTimeoutMs: 10, // 10ms
Patterns: []RegexPattern{
{ID: "any", Pattern: `.`},
},
}
// Create a provider with a 20ms delay, which is longer than the timeout
rp, err := NewRegexProvider(cfg, WithTestDelay(20*time.Millisecond))
if err != nil {
t.Fatalf("failed to create regex provider: %v", err)
}

_, err = rp.Scan("a")
if err == nil {
t.Fatal("expected a timeout error, got nil")
}
if !strings.Contains(err.Error(), "timed out") {
t.Errorf("expected timeout error, got: %v", err)
}
})

t.Run("ReDoSAttackVector", func(t *testing.T) {
// This pattern is a known ReDoS vector for backtracking regex engines.
// Go's engine is not vulnerable, so this should execute quickly.
cfg := RegexProviderConfig{
MaxPatterns: 1,
MaxPatternLength: 100,
MaxInputLength: 1000,
DefaultTimeoutMs: 500, // 500ms timeout
Patterns: []RegexPattern{
{ID: "redos", Pattern: `(a+)+$`},
},
}
rp, err := NewRegexProvider(cfg)
if err != nil {
t.Fatalf("failed to create regex provider: %v", err)
}

// A long string of 'a's followed by a non-matching character.
// In a vulnerable engine, this would cause catastrophic backtracking.
input := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"

_, err = rp.Scan(input)
if err != nil {
t.Fatalf("scan failed for ReDoS pattern: %v", err)
}
})
}
Loading