diff --git a/bridge.go b/bridge.go index 9f2c424..bef401c 100644 --- a/bridge.go +++ b/bridge.go @@ -4,14 +4,15 @@ import ( "context" "fmt" "net/http" + "strings" "sync" "sync/atomic" "cdr.dev/slog" "github.com/coder/aibridge/mcp" - "go.opentelemetry.io/otel/trace" - "github.com/hashicorp/go-multierror" + "github.com/sony/gobreaker/v2" + "go.opentelemetry.io/otel/trace" ) // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; @@ -30,6 +31,10 @@ type RequestBridge struct { mcpProxy mcp.ServerProxier + // circuitBreakers manages circuit breakers for upstream providers. + // When enabled, it protects against cascading failures from upstream rate limits. + circuitBreakers *CircuitBreakers + inflightReqs atomic.Int32 inflightWG sync.WaitGroup // For graceful shutdown. @@ -49,12 +54,35 @@ var _ http.Handler = &RequestBridge{} // // mcpProxy will be closed when the [RequestBridge] is closed. func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { + return NewRequestBridgeWithCircuitBreaker(ctx, providers, recorder, mcpProxy, logger, metrics, tracer, nil) +} + +// NewRequestBridgeWithCircuitBreaker creates a new *[RequestBridge] with per-provider circuit breaker configuration. +// The cbConfigs map is keyed by provider name. Providers not in the map will not have circuit breaker protection. +// Pass nil to disable circuit breakers entirely. +func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbConfigs map[string]CircuitBreakerConfig) (*RequestBridge, error) { mux := http.NewServeMux() + // Create circuit breakers with metrics callback + var onChange func(name string, from, to gobreaker.State) + if metrics != nil { + onChange = func(name string, from, to gobreaker.State) { + provider, endpoint, _ := strings.Cut(name, ":") + metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(stateToGaugeValue(to)) + if to == gobreaker.StateOpen { + metrics.CircuitBreakerTrips.WithLabelValues(provider, endpoint).Inc() + } + } + } + cbs := NewCircuitBreakers(cbConfigs, onChange) + for _, provider := range providers { // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer)) + handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) + // Wrap with circuit breaker middleware if configured for this provider + handler = CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler).ServeHTTP + mux.HandleFunc(path, handler) } // Any requests which passthrough to this will be reverse-proxied to the upstream. @@ -77,11 +105,12 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record inflightCtx, cancel := context.WithCancel(context.Background()) return &RequestBridge{ - mux: mux, - logger: logger, - mcpProxy: mcpProxy, - inflightCtx: inflightCtx, - inflightCancel: cancel, + mux: mux, + logger: logger, + mcpProxy: mcpProxy, + circuitBreakers: cbs, + inflightCtx: inflightCtx, + inflightCancel: cancel, closed: make(chan struct{}, 1), }, nil @@ -153,6 +182,11 @@ func (b *RequestBridge) InflightRequests() int32 { return b.inflightReqs.Load() } +// CircuitBreakers returns the circuit breakers for this bridge. +func (b *RequestBridge) CircuitBreakers() *CircuitBreakers { + return b.circuitBreakers +} + // mergeContexts merges two contexts together, so that if either is cancelled // the returned context is cancelled. The context values will only be used from // the first context. diff --git a/circuit_breaker.go b/circuit_breaker.go new file mode 100644 index 0000000..9ffd0cc --- /dev/null +++ b/circuit_breaker.go @@ -0,0 +1,199 @@ +package aibridge + +import ( + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/sony/gobreaker/v2" +) + +// CircuitBreakerConfig holds configuration for circuit breakers. +// Fields match gobreaker.Settings for clarity. +type CircuitBreakerConfig struct { + // MaxRequests is the maximum number of requests allowed in half-open state. + MaxRequests uint32 + // Interval is the cyclic period of the closed state for clearing internal counts. + Interval time.Duration + // Timeout is how long the circuit stays open before transitioning to half-open. + Timeout time.Duration + // FailureThreshold is the number of consecutive failures that triggers the circuit to open. + FailureThreshold uint32 + // IsFailure determines if a status code should count as a failure. + // If nil, defaults to 429, 503, and 529 (Anthropic overloaded). + IsFailure func(statusCode int) bool +} + +// DefaultCircuitBreakerConfig returns sensible defaults for circuit breaker configuration. +func DefaultCircuitBreakerConfig() CircuitBreakerConfig { + return CircuitBreakerConfig{ + FailureThreshold: 5, + Interval: 10 * time.Second, + Timeout: 30 * time.Second, + MaxRequests: 3, + IsFailure: DefaultIsFailure, + } +} + +// DefaultIsFailure returns true for status codes that typically indicate +// upstream overload: 429 (Too Many Requests), 503 (Service Unavailable), +// and 529 (Anthropic Overloaded). +func DefaultIsFailure(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests, // 429 + http.StatusServiceUnavailable, // 503 + 529: // Anthropic "Overloaded" + return true + default: + return false + } +} + +// CircuitBreakers manages per-endpoint circuit breakers using sony/gobreaker. +// Circuit breakers are keyed by "provider:endpoint" for per-endpoint isolation. +type CircuitBreakers struct { + breakers sync.Map // map[string]*gobreaker.CircuitBreaker[any] + configs map[string]CircuitBreakerConfig + onChange func(name string, from, to gobreaker.State) +} + +// NewCircuitBreakers creates a new circuit breaker manager with per-provider configs. +// The configs map is keyed by provider name. Providers not in the map will not have +// circuit breaker protection. +func NewCircuitBreakers(configs map[string]CircuitBreakerConfig, onChange func(name string, from, to gobreaker.State)) *CircuitBreakers { + return &CircuitBreakers{ + configs: configs, + onChange: onChange, + } +} + +// getConfig returns the config for a provider, or nil if not configured. +func (c *CircuitBreakers) getConfig(provider string) *CircuitBreakerConfig { + if c.configs == nil { + return nil + } + cfg, ok := c.configs[provider] + if !ok { + return nil + } + return &cfg +} + +// getOrCreate returns the circuit breaker for a provider/endpoint, creating if needed. +// Returns nil if the provider is not configured. +func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.CircuitBreaker[any] { + cfg := c.getConfig(provider) + if cfg == nil { + return nil + } + + key := provider + ":" + endpoint + if v, ok := c.breakers.Load(key); ok { + return v.(*gobreaker.CircuitBreaker[any]) + } + + settings := gobreaker.Settings{ + Name: key, + MaxRequests: cfg.MaxRequests, + Interval: cfg.Interval, + Timeout: cfg.Timeout, + ReadyToTrip: func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= cfg.FailureThreshold + }, + OnStateChange: func(name string, from, to gobreaker.State) { + if c.onChange != nil { + c.onChange(name, from, to) + } + }, + } + + cb := gobreaker.NewCircuitBreaker[any](settings) + actual, _ := c.breakers.LoadOrStore(key, cb) + return actual.(*gobreaker.CircuitBreaker[any]) +} + +// statusCapturingWriter wraps http.ResponseWriter to capture the status code. +type statusCapturingWriter struct { + http.ResponseWriter + statusCode int + headerWritten bool +} + +func (w *statusCapturingWriter) WriteHeader(code int) { + if !w.headerWritten { + w.statusCode = code + w.headerWritten = true + } + w.ResponseWriter.WriteHeader(code) +} + +func (w *statusCapturingWriter) Write(b []byte) (int, error) { + if !w.headerWritten { + w.statusCode = http.StatusOK + w.headerWritten = true + } + return w.ResponseWriter.Write(b) +} + +// CircuitBreakerMiddleware returns middleware that wraps handlers with circuit breaker protection. +// It captures the response status code to determine success/failure without provider-specific logic. +func CircuitBreakerMiddleware(cbs *CircuitBreakers, metrics *Metrics, provider string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + cfg := cbs.getConfig(provider) + if cfg == nil { + // No config for this provider, pass through + return next + } + + isFailure := cfg.IsFailure + if isFailure == nil { + isFailure = DefaultIsFailure + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + endpoint := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", provider)) + + // Check if circuit is open + cb := cbs.getOrCreate(provider, endpoint) + if cb != nil && cb.State() == gobreaker.StateOpen { + if metrics != nil { + metrics.CircuitBreakerRejects.WithLabelValues(provider, endpoint).Inc() + } + http.Error(w, "circuit breaker is open", http.StatusServiceUnavailable) + return + } + + // Wrap response writer to capture status code + sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} + next.ServeHTTP(sw, r) + + // Record result + if cb != nil { + if isFailure(sw.statusCode) { + _, _ = cb.Execute(func() (any, error) { + return nil, fmt.Errorf("upstream error: %d", sw.statusCode) + }) + } else { + _, _ = cb.Execute(func() (any, error) { return nil, nil }) + } + } + }) + } +} + +// stateToGaugeValue converts gobreaker.State to a gauge value. +// closed=0, half-open=0.5, open=1 +func stateToGaugeValue(s gobreaker.State) float64 { + switch s { + case gobreaker.StateClosed: + return 0 + case gobreaker.StateHalfOpen: + return 0.5 + case gobreaker.StateOpen: + return 1 + default: + return 0 + } +} diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go new file mode 100644 index 0000000..dca23e2 --- /dev/null +++ b/circuit_breaker_test.go @@ -0,0 +1,262 @@ +package aibridge + +import ( + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/sony/gobreaker/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCircuitBreakerMiddleware_TripsOnUpstreamErrors(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + + // Mock upstream that returns 429 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + }) + + // Create circuit breaker with low threshold + cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ + "test": { + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + }, + }, nil) + + // Wrap upstream with circuit breaker middleware + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // First 2 requests hit upstream, get 429 + for i := 0; i < 2; i++ { + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(2), upstreamCalls.Load()) + + // Third request should get 503 "circuit breaker is open" without hitting upstream + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Contains(t, string(body), "circuit breaker is open") + assert.Equal(t, int32(2), upstreamCalls.Load()) // No new upstream call + + // Wait for timeout, verify recovery + time.Sleep(60 * time.Millisecond) + + // Next request should hit upstream again (half-open state) + resp, err = http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, int32(3), upstreamCalls.Load()) +} + +func TestCircuitBreakerMiddleware_PerEndpointIsolation(t *testing.T) { + t.Parallel() + + chatCalls := atomic.Int32{} + responsesCalls := atomic.Int32{} + + // Mock upstream - /chat returns 429, /responses returns 200 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test/v1/chat/completions" { + chatCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + } else { + responsesCalls.Add(1) + w.WriteHeader(http.StatusOK) + } + }) + + cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ + "test": { + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, + }, nil) + + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // Trip circuit on /chat/completions + resp, err := http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + + // /chat/completions should now be blocked + resp, err = http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, int32(1), chatCalls.Load()) // Only 1 call, second was blocked + + // /responses should still work + resp, err = http.Get(server.URL + "/test/v1/responses") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(1), responsesCalls.Load()) +} + +func TestCircuitBreakerMiddleware_NotConfigured(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + }) + + // No config for "test" provider + cbs := NewCircuitBreakers(nil, nil) + + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // All requests should pass through even with 429s + for i := 0; i < 10; i++ { + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(10), upstreamCalls.Load()) +} + +func TestCircuitBreakerMiddleware_RecoveryAfterSuccess(t *testing.T) { + t.Parallel() + + var returnError atomic.Bool + returnError.Store(true) + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if returnError.Load() { + w.WriteHeader(http.StatusTooManyRequests) + } else { + w.WriteHeader(http.StatusOK) + } + }) + + cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ + "test": { + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + }, + }, nil) + + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // Trip the circuit + for i := 0; i < 2; i++ { + resp, _ := http.Get(server.URL + "/test/v1/messages") + resp.Body.Close() + } + + // Circuit should be open + resp, _ := http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + resp.Body.Close() + + // Wait for timeout, switch upstream to success + time.Sleep(60 * time.Millisecond) + returnError.Store(false) + + // Half-open: one request allowed + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() + + // Circuit should be closed now, more requests allowed + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +func TestCircuitBreakerMiddleware_CustomIsFailure(t *testing.T) { + t.Parallel() + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) // 502 + }) + + // Custom IsFailure that treats 502 as failure + cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ + "test": { + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + IsFailure: func(statusCode int) bool { + return statusCode == http.StatusBadGateway + }, + }, + }, nil) + + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // First request returns 502, trips circuit + resp, _ := http.Get(server.URL + "/test/v1/messages") + resp.Body.Close() + + // Second request should be blocked + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + resp.Body.Close() +} + +func TestDefaultIsFailure(t *testing.T) { + t.Parallel() + + tests := []struct { + statusCode int + isFailure bool + }{ + {http.StatusOK, false}, + {http.StatusBadRequest, false}, + {http.StatusUnauthorized, false}, + {http.StatusTooManyRequests, true}, // 429 + {http.StatusInternalServerError, false}, + {http.StatusBadGateway, false}, + {http.StatusServiceUnavailable, true}, // 503 + {529, true}, // Anthropic Overloaded + } + + for _, tt := range tests { + assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode) + } +} + +func TestStateToGaugeValue(t *testing.T) { + t.Parallel() + + assert.Equal(t, float64(0), stateToGaugeValue(gobreaker.StateClosed)) + assert.Equal(t, float64(0.5), stateToGaugeValue(gobreaker.StateHalfOpen)) + assert.Equal(t, float64(1), stateToGaugeValue(gobreaker.StateOpen)) +} diff --git a/go.mod b/go.mod index 9a62089..4715f99 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,8 @@ require ( go.opentelemetry.io/otel/trace v1.38.0 ) +require github.com/sony/gobreaker/v2 v2.3.0 + require ( github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect diff --git a/go.sum b/go.sum index 385345d..fff1ee3 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sony/gobreaker/v2 v2.3.0 h1:7VYxZ69QXRQ2Q4eEawHn6eU4FiuwovzJwsUMA03Lu4I= +github.com/sony/gobreaker/v2 v2.3.0/go.mod h1:pTyFJgcZ3h2tdQVLZZruK2C0eoFL1fb/G83wK1ZQl+s= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/interception.go b/interception.go index 46ec7bd..62201aa 100644 --- a/interception.go +++ b/interception.go @@ -45,6 +45,8 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() + route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) + interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) @@ -93,7 +95,6 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server return } - route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) log := logger.With( slog.F("route", route), slog.F("provider", p.Name()), diff --git a/metrics.go b/metrics.go index 32d5a78..f744d10 100644 --- a/metrics.go +++ b/metrics.go @@ -28,6 +28,11 @@ type Metrics struct { // Tool-related metrics. InjectedToolUseCount *prometheus.CounterVec NonInjectedToolUseCount *prometheus.CounterVec + + // Circuit breaker metrics. + CircuitBreakerState *prometheus.GaugeVec // Current state (0=closed, 1=open, 2=half-open) + CircuitBreakerTrips *prometheus.CounterVec // Total times circuit opened + CircuitBreakerRejects *prometheus.CounterVec // Requests rejected due to open circuit } // NewMetrics creates AND registers metrics. It will panic if a collector has already been registered. @@ -102,5 +107,26 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { Name: "total", Help: "The number of times an AI model selected a tool to be invoked by the client.", }, append(baseLabels, "name")), + + // Circuit breaker metrics. + + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ + Subsystem: "circuit_breaker", + Name: "state", + Help: "Current state of the circuit breaker (0=closed, 1=open, 2=half-open).", + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "trips_total", + Help: "Total number of times the circuit breaker has tripped open.", + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "rejects_total", + Help: "Total number of requests rejected due to open circuit breaker.", + }, []string{"provider", "endpoint"}), } }