From 7700a8fe223d80ca3b942da8aecbe56188a7d7a3 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 11 Dec 2025 13:50:43 +0000 Subject: [PATCH 1/5] feat: add circuit breaker for upstream provider overload protection Implement per-provider circuit breakers that detect upstream rate limiting (429/503/529 status codes) and temporarily stop sending requests when providers are overloaded. Key features: - Per-provider circuit breakers (Anthropic, OpenAI) - Configurable failure threshold, time window, and cooldown period - Half-open state allows gradual recovery testing - Prometheus metrics for monitoring (state gauge, trips counter, rejects counter) - Thread-safe implementation with proper state machine transitions - Disabled by default for backward compatibility Circuit breaker states: - Closed: normal operation, tracking failures within sliding window - Open: all requests rejected with 503, waiting for cooldown - Half-Open: limited requests allowed to test if upstream recovered Status codes that trigger circuit breaker: - 429 Too Many Requests - 503 Service Unavailable - 529 Anthropic Overloaded Relates to: https://github.com/coder/internal/issues/1153 --- bridge.go | 45 ++++- circuit_breaker.go | 349 ++++++++++++++++++++++++++++++++++++ circuit_breaker_test.go | 382 ++++++++++++++++++++++++++++++++++++++++ interception.go | 52 +++++- metrics.go | 26 +++ 5 files changed, 847 insertions(+), 7 deletions(-) create mode 100644 circuit_breaker.go create mode 100644 circuit_breaker_test.go diff --git a/bridge.go b/bridge.go index 9f2c424..16872a0 100644 --- a/bridge.go +++ b/bridge.go @@ -30,6 +30,10 @@ type RequestBridge struct { mcpProxy mcp.ServerProxier + // circuitBreakers manages circuit breakers for upstream providers. + // When enabled, it protects against cascading failures from upstream rate limits. + circuitBreakers *CircuitBreakerManager + inflightReqs atomic.Int32 inflightWG sync.WaitGroup // For graceful shutdown. @@ -49,12 +53,34 @@ var _ http.Handler = &RequestBridge{} // // mcpProxy will be closed when the [RequestBridge] is closed. func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { + return NewRequestBridgeWithCircuitBreaker(ctx, providers, recorder, mcpProxy, logger, metrics, tracer, DefaultCircuitBreakerConfig()) +} + +// NewRequestBridgeWithCircuitBreaker creates a new *[RequestBridge] with custom circuit breaker configuration. +// See [NewRequestBridge] for more details. +func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbConfig CircuitBreakerConfig) (*RequestBridge, error) { mux := http.NewServeMux() + // Create circuit breaker manager + cbManager := NewCircuitBreakerManager(cbConfig) + + // Set up metrics callback if metrics are provided + if metrics != nil { + cbManager.SetStateChangeCallback(func(provider string, from, to CircuitState) { + metrics.CircuitBreakerState.WithLabelValues(provider).Set(float64(to)) + if to == CircuitOpen { + metrics.CircuitBreakerTrips.WithLabelValues(provider).Inc() + } + }) + } + for _, provider := range providers { + // Pre-create circuit breaker for this provider + cbManager.GetOrCreate(provider.Name()) + // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer)) + mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer, cbManager)) } // Any requests which passthrough to this will be reverse-proxied to the upstream. @@ -77,11 +103,12 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record inflightCtx, cancel := context.WithCancel(context.Background()) return &RequestBridge{ - mux: mux, - logger: logger, - mcpProxy: mcpProxy, - inflightCtx: inflightCtx, - inflightCancel: cancel, + mux: mux, + logger: logger, + mcpProxy: mcpProxy, + circuitBreakers: cbManager, + inflightCtx: inflightCtx, + inflightCancel: cancel, closed: make(chan struct{}, 1), }, nil @@ -153,6 +180,12 @@ func (b *RequestBridge) InflightRequests() int32 { return b.inflightReqs.Load() } +// CircuitBreakers returns the circuit breaker manager for this bridge. +// This can be used to query circuit breaker states or configure callbacks. +func (b *RequestBridge) CircuitBreakers() *CircuitBreakerManager { + return b.circuitBreakers +} + // mergeContexts merges two contexts together, so that if either is cancelled // the returned context is cancelled. The context values will only be used from // the first context. diff --git a/circuit_breaker.go b/circuit_breaker.go new file mode 100644 index 0000000..48800ee --- /dev/null +++ b/circuit_breaker.go @@ -0,0 +1,349 @@ +package aibridge + +import ( + "net/http" + "sync" + "time" +) + +// CircuitState represents the current state of a circuit breaker. +type CircuitState int + +const ( + // CircuitClosed is the normal state - all requests pass through. + CircuitClosed CircuitState = iota + // CircuitOpen is the tripped state - requests are rejected immediately. + CircuitOpen + // CircuitHalfOpen is the testing state - limited requests pass through. + CircuitHalfOpen +) + +func (s CircuitState) String() string { + switch s { + case CircuitClosed: + return "closed" + case CircuitOpen: + return "open" + case CircuitHalfOpen: + return "half-open" + default: + return "unknown" + } +} + +// CircuitBreakerConfig holds configuration for a circuit breaker. +type CircuitBreakerConfig struct { + // Enabled controls whether the circuit breaker is active. + // If false, all requests pass through regardless of failures. + Enabled bool + // FailureThreshold is the number of failures within the window that triggers the circuit to open. + FailureThreshold int64 + // Window is the time window for counting failures. + Window time.Duration + // Cooldown is how long the circuit stays open before transitioning to half-open. + Cooldown time.Duration + // HalfOpenMaxRequests is the maximum number of requests allowed in half-open state + // before deciding whether to close or re-open the circuit. + HalfOpenMaxRequests int64 +} + +// DefaultCircuitBreakerConfig returns sensible defaults for circuit breaker configuration. +func DefaultCircuitBreakerConfig() CircuitBreakerConfig { + return CircuitBreakerConfig{ + Enabled: false, // Disabled by default for backward compatibility + FailureThreshold: 5, + Window: 10 * time.Second, + Cooldown: 30 * time.Second, + HalfOpenMaxRequests: 3, + } +} + +// CircuitBreaker implements the circuit breaker pattern to protect against +// upstream service failures. It tracks failures from upstream providers +// (like rate limit errors) and temporarily blocks requests when the +// failure threshold is exceeded. +type CircuitBreaker struct { + mu sync.RWMutex + + // Current state + state CircuitState + failures int64 // Failure count in current window + windowStart time.Time // Start of current failure counting window + openedAt time.Time // When circuit transitioned to open + + // Half-open state tracking + halfOpenSuccesses int64 + halfOpenFailures int64 + + // Configuration + config CircuitBreakerConfig + + // Provider name for logging/metrics + provider string + + // Optional metrics callback + onStateChange func(provider string, from, to CircuitState) +} + +// NewCircuitBreaker creates a new circuit breaker for the given provider. +func NewCircuitBreaker(provider string, config CircuitBreakerConfig) *CircuitBreaker { + return &CircuitBreaker{ + state: CircuitClosed, + windowStart: time.Now(), + config: config, + provider: provider, + } +} + +// SetStateChangeCallback sets a callback that is invoked when the circuit state changes. +// This is useful for metrics and logging. +func (cb *CircuitBreaker) SetStateChangeCallback(fn func(provider string, from, to CircuitState)) { + cb.mu.Lock() + defer cb.mu.Unlock() + cb.onStateChange = fn +} + +// Allow checks if a request should be allowed through. +// Returns true if the request can proceed, false if it should be rejected. +func (cb *CircuitBreaker) Allow() bool { + if !cb.config.Enabled { + return true + } + + cb.mu.Lock() + defer cb.mu.Unlock() + + now := time.Now() + + switch cb.state { + case CircuitClosed: + return true + + case CircuitOpen: + // Check if cooldown period has elapsed + if now.Sub(cb.openedAt) >= cb.config.Cooldown { + cb.transitionTo(CircuitHalfOpen) + return true + } + return false + + case CircuitHalfOpen: + // Allow limited requests in half-open state + totalHalfOpenRequests := cb.halfOpenSuccesses + cb.halfOpenFailures + return totalHalfOpenRequests < cb.config.HalfOpenMaxRequests + } + + return true +} + +// RecordSuccess records a successful request. +// This is called after a request completes successfully. +func (cb *CircuitBreaker) RecordSuccess() { + if !cb.config.Enabled { + return + } + + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case CircuitHalfOpen: + cb.halfOpenSuccesses++ + // If we've had enough successes in half-open, close the circuit + if cb.halfOpenSuccesses >= cb.config.HalfOpenMaxRequests { + cb.transitionTo(CircuitClosed) + } + case CircuitClosed: + // Reset failure count on success (sliding window behavior) + // This helps prevent false positives from old failures + cb.maybeResetWindow() + } +} + +// RecordFailure records a failed request. +// statusCode is the HTTP status code from the upstream response. +// Returns true if this failure caused the circuit to trip open. +func (cb *CircuitBreaker) RecordFailure(statusCode int) bool { + if !cb.config.Enabled { + return false + } + + // Only count specific error codes as circuit-breaker failures + if !isCircuitBreakerFailure(statusCode) { + return false + } + + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case CircuitClosed: + cb.maybeResetWindow() + cb.failures++ + if cb.failures >= cb.config.FailureThreshold { + cb.transitionTo(CircuitOpen) + return true + } + + case CircuitHalfOpen: + cb.halfOpenFailures++ + // Any failure in half-open state re-opens the circuit + cb.transitionTo(CircuitOpen) + return true + } + + return false +} + +// State returns the current state of the circuit breaker. +func (cb *CircuitBreaker) State() CircuitState { + cb.mu.RLock() + defer cb.mu.RUnlock() + return cb.state +} + +// Provider returns the provider name this circuit breaker is for. +func (cb *CircuitBreaker) Provider() string { + return cb.provider +} + +// Failures returns the current failure count. +func (cb *CircuitBreaker) Failures() int64 { + cb.mu.RLock() + defer cb.mu.RUnlock() + return cb.failures +} + +// transitionTo changes the circuit state. Must be called with lock held. +func (cb *CircuitBreaker) transitionTo(newState CircuitState) { + oldState := cb.state + if oldState == newState { + return + } + + cb.state = newState + now := time.Now() + + switch newState { + case CircuitOpen: + cb.openedAt = now + case CircuitHalfOpen: + cb.halfOpenSuccesses = 0 + cb.halfOpenFailures = 0 + case CircuitClosed: + cb.failures = 0 + cb.windowStart = now + } + + if cb.onStateChange != nil { + // Call callback without holding lock to avoid deadlocks + callback := cb.onStateChange + go callback(cb.provider, oldState, newState) + } +} + +// maybeResetWindow resets the failure count if the window has elapsed. +// Must be called with lock held. +func (cb *CircuitBreaker) maybeResetWindow() { + now := time.Now() + if now.Sub(cb.windowStart) >= cb.config.Window { + cb.failures = 0 + cb.windowStart = now + } +} + +// isCircuitBreakerFailure returns true if the given HTTP status code +// should count as a failure for circuit breaker purposes. +// We specifically track rate limiting and overload errors from upstream. +func isCircuitBreakerFailure(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests: // 429 - Rate limited + return true + case http.StatusServiceUnavailable: // 503 - Service unavailable + return true + case 529: // Anthropic-specific "Overloaded" error + return true + default: + return false + } +} + +// CircuitBreakerManager manages circuit breakers for multiple providers. +type CircuitBreakerManager struct { + mu sync.RWMutex + breakers map[string]*CircuitBreaker + config CircuitBreakerConfig + + // Metrics callbacks + onStateChange func(provider string, from, to CircuitState) +} + +// NewCircuitBreakerManager creates a new manager with the given configuration. +func NewCircuitBreakerManager(config CircuitBreakerConfig) *CircuitBreakerManager { + return &CircuitBreakerManager{ + breakers: make(map[string]*CircuitBreaker), + config: config, + } +} + +// SetStateChangeCallback sets the callback for state changes on all circuit breakers. +func (m *CircuitBreakerManager) SetStateChangeCallback(fn func(provider string, from, to CircuitState)) { + m.mu.Lock() + defer m.mu.Unlock() + m.onStateChange = fn + + // Update existing breakers + for _, cb := range m.breakers { + cb.SetStateChangeCallback(fn) + } +} + +// GetOrCreate returns the circuit breaker for the given provider, +// creating one if it doesn't exist. +func (m *CircuitBreakerManager) GetOrCreate(provider string) *CircuitBreaker { + m.mu.RLock() + if cb, ok := m.breakers[provider]; ok { + m.mu.RUnlock() + return cb + } + m.mu.RUnlock() + + m.mu.Lock() + defer m.mu.Unlock() + + // Double-check after acquiring write lock + if cb, ok := m.breakers[provider]; ok { + return cb + } + + cb := NewCircuitBreaker(provider, m.config) + if m.onStateChange != nil { + cb.SetStateChangeCallback(m.onStateChange) + } + m.breakers[provider] = cb + return cb +} + +// Get returns the circuit breaker for the given provider, or nil if not found. +func (m *CircuitBreakerManager) Get(provider string) *CircuitBreaker { + m.mu.RLock() + defer m.mu.RUnlock() + return m.breakers[provider] +} + +// AllStates returns the current state of all circuit breakers. +func (m *CircuitBreakerManager) AllStates() map[string]CircuitState { + m.mu.RLock() + defer m.mu.RUnlock() + + states := make(map[string]CircuitState, len(m.breakers)) + for provider, cb := range m.breakers { + states[provider] = cb.State() + } + return states +} + +// Config returns the configuration used by this manager. +func (m *CircuitBreakerManager) Config() CircuitBreakerConfig { + return m.config +} diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go new file mode 100644 index 0000000..4d52a15 --- /dev/null +++ b/circuit_breaker_test.go @@ -0,0 +1,382 @@ +package aibridge + +import ( + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCircuitBreaker_DefaultConfig(t *testing.T) { + t.Parallel() + + cfg := DefaultCircuitBreakerConfig() + assert.False(t, cfg.Enabled, "should be disabled by default") + assert.Equal(t, int64(5), cfg.FailureThreshold) + assert.Equal(t, 10*time.Second, cfg.Window) + assert.Equal(t, 30*time.Second, cfg.Cooldown) + assert.Equal(t, int64(3), cfg.HalfOpenMaxRequests) +} + +func TestCircuitBreaker_DisabledByDefault(t *testing.T) { + t.Parallel() + + cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig()) + + // Should always allow when disabled + assert.True(t, cb.Allow()) + + // Recording failures should not affect state when disabled + for i := 0; i < 100; i++ { + cb.RecordFailure(http.StatusTooManyRequests) + } + assert.True(t, cb.Allow()) + assert.Equal(t, CircuitClosed, cb.State()) +} + +func TestCircuitBreaker_StateTransitions(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 3, + Window: time.Minute, // Long window so it doesn't reset during test + Cooldown: 50 * time.Millisecond, + HalfOpenMaxRequests: 2, + } + cb := NewCircuitBreaker("test", cfg) + + // Start in closed state + assert.Equal(t, CircuitClosed, cb.State()) + assert.True(t, cb.Allow()) + + // Record failures below threshold + cb.RecordFailure(http.StatusTooManyRequests) + cb.RecordFailure(http.StatusTooManyRequests) + assert.Equal(t, CircuitClosed, cb.State()) + assert.True(t, cb.Allow()) + + // Third failure should trip the circuit + tripped := cb.RecordFailure(http.StatusTooManyRequests) + assert.True(t, tripped) + assert.Equal(t, CircuitOpen, cb.State()) + assert.False(t, cb.Allow()) + + // Wait for cooldown + time.Sleep(60 * time.Millisecond) + + // Should transition to half-open and allow request + assert.True(t, cb.Allow()) + assert.Equal(t, CircuitHalfOpen, cb.State()) + + // Success in half-open should eventually close + cb.RecordSuccess() + cb.RecordSuccess() + assert.Equal(t, CircuitClosed, cb.State()) + assert.True(t, cb.Allow()) +} + +func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 2, + Window: time.Minute, + Cooldown: 50 * time.Millisecond, + HalfOpenMaxRequests: 3, + } + cb := NewCircuitBreaker("test", cfg) + + // Trip the circuit + cb.RecordFailure(http.StatusTooManyRequests) + cb.RecordFailure(http.StatusTooManyRequests) + assert.Equal(t, CircuitOpen, cb.State()) + + // Wait for cooldown + time.Sleep(60 * time.Millisecond) + + // Transition to half-open + assert.True(t, cb.Allow()) + assert.Equal(t, CircuitHalfOpen, cb.State()) + + // Failure in half-open should re-open circuit + tripped := cb.RecordFailure(http.StatusServiceUnavailable) + assert.True(t, tripped) + assert.Equal(t, CircuitOpen, cb.State()) + assert.False(t, cb.Allow()) +} + +func TestCircuitBreaker_OnlyCountsRelevantStatusCodes(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 2, + Window: time.Minute, + Cooldown: time.Minute, + HalfOpenMaxRequests: 2, + } + cb := NewCircuitBreaker("test", cfg) + + // Non-circuit-breaker status codes should not count + cb.RecordFailure(http.StatusBadRequest) // 400 + cb.RecordFailure(http.StatusUnauthorized) // 401 + cb.RecordFailure(http.StatusInternalServerError) // 500 + cb.RecordFailure(http.StatusBadGateway) // 502 + assert.Equal(t, CircuitClosed, cb.State()) + assert.Equal(t, int64(0), cb.Failures()) + + // These should count + cb.RecordFailure(http.StatusTooManyRequests) // 429 + assert.Equal(t, int64(1), cb.Failures()) + + cb.RecordFailure(http.StatusServiceUnavailable) // 503 + assert.Equal(t, CircuitOpen, cb.State()) +} + +func TestCircuitBreaker_Anthropic529(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 1, + Window: time.Minute, + Cooldown: time.Minute, + HalfOpenMaxRequests: 1, + } + cb := NewCircuitBreaker("anthropic", cfg) + + // Anthropic-specific 529 "Overloaded" should trip the circuit + tripped := cb.RecordFailure(529) + assert.True(t, tripped) + assert.Equal(t, CircuitOpen, cb.State()) +} + +func TestCircuitBreaker_WindowReset(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 3, + Window: 50 * time.Millisecond, // Short window + Cooldown: time.Minute, + HalfOpenMaxRequests: 2, + } + cb := NewCircuitBreaker("test", cfg) + + // Record failures + cb.RecordFailure(http.StatusTooManyRequests) + cb.RecordFailure(http.StatusTooManyRequests) + assert.Equal(t, int64(2), cb.Failures()) + + // Wait for window to expire + time.Sleep(60 * time.Millisecond) + + // Next failure should reset counter (due to window expiry) + cb.RecordFailure(http.StatusTooManyRequests) + assert.Equal(t, int64(1), cb.Failures()) + assert.Equal(t, CircuitClosed, cb.State()) +} + +func TestCircuitBreaker_ConcurrentAccess(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 100, + Window: time.Minute, + Cooldown: time.Minute, + HalfOpenMaxRequests: 10, + } + cb := NewCircuitBreaker("test", cfg) + + var wg sync.WaitGroup + numGoroutines := 50 + opsPerGoroutine := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < opsPerGoroutine; j++ { + cb.Allow() + cb.RecordSuccess() + cb.RecordFailure(http.StatusTooManyRequests) + cb.State() + cb.Failures() + } + }() + } + + wg.Wait() + // Should not panic or deadlock +} + +func TestCircuitBreaker_StateChangeCallback(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 2, + Window: time.Minute, + Cooldown: 50 * time.Millisecond, + HalfOpenMaxRequests: 1, + } + cb := NewCircuitBreaker("test", cfg) + + var mu sync.Mutex + var transitions []struct { + from, to CircuitState + } + + cb.SetStateChangeCallback(func(provider string, from, to CircuitState) { + mu.Lock() + defer mu.Unlock() + transitions = append(transitions, struct{ from, to CircuitState }{from, to}) + }) + + // Trip the circuit + cb.RecordFailure(http.StatusTooManyRequests) + cb.RecordFailure(http.StatusTooManyRequests) + + // Wait for callback + time.Sleep(10 * time.Millisecond) + + // Wait for cooldown and trigger half-open + time.Sleep(60 * time.Millisecond) + cb.Allow() + + // Wait for callback + time.Sleep(10 * time.Millisecond) + + // Success to close + cb.RecordSuccess() + + // Wait for callback + time.Sleep(10 * time.Millisecond) + + mu.Lock() + defer mu.Unlock() + require.Len(t, transitions, 3) + assert.Equal(t, CircuitClosed, transitions[0].from) + assert.Equal(t, CircuitOpen, transitions[0].to) + assert.Equal(t, CircuitOpen, transitions[1].from) + assert.Equal(t, CircuitHalfOpen, transitions[1].to) + assert.Equal(t, CircuitHalfOpen, transitions[2].from) + assert.Equal(t, CircuitClosed, transitions[2].to) +} + +func TestCircuitBreakerManager_GetOrCreate(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 5, + Window: time.Minute, + Cooldown: time.Minute, + } + manager := NewCircuitBreakerManager(cfg) + + // First call should create + cb1 := manager.GetOrCreate("anthropic") + require.NotNil(t, cb1) + assert.Equal(t, "anthropic", cb1.Provider()) + + // Second call should return same instance + cb2 := manager.GetOrCreate("anthropic") + assert.Same(t, cb1, cb2) + + // Different provider gets different instance + cb3 := manager.GetOrCreate("openai") + require.NotNil(t, cb3) + assert.NotSame(t, cb1, cb3) + assert.Equal(t, "openai", cb3.Provider()) +} + +func TestCircuitBreakerManager_AllStates(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 1, + Window: time.Minute, + Cooldown: time.Minute, + } + manager := NewCircuitBreakerManager(cfg) + + manager.GetOrCreate("anthropic") + manager.GetOrCreate("openai") + + // Trip one circuit + manager.Get("anthropic").RecordFailure(http.StatusTooManyRequests) + + states := manager.AllStates() + assert.Equal(t, CircuitOpen, states["anthropic"]) + assert.Equal(t, CircuitClosed, states["openai"]) +} + +func TestCircuitBreakerManager_ConcurrentGetOrCreate(t *testing.T) { + t.Parallel() + + cfg := DefaultCircuitBreakerConfig() + cfg.Enabled = true + manager := NewCircuitBreakerManager(cfg) + + var wg sync.WaitGroup + var results [100]*CircuitBreaker + + for i := 0; i < 100; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx] = manager.GetOrCreate("test-provider") + }(i) + } + + wg.Wait() + + // All should be the same instance + first := results[0] + for i := 1; i < 100; i++ { + assert.Same(t, first, results[i]) + } +} + +func TestIsCircuitBreakerFailure(t *testing.T) { + t.Parallel() + + tests := []struct { + statusCode int + isFailure bool + }{ + {http.StatusOK, false}, + {http.StatusBadRequest, false}, + {http.StatusUnauthorized, false}, + {http.StatusForbidden, false}, + {http.StatusNotFound, false}, + {http.StatusTooManyRequests, true}, // 429 + {http.StatusInternalServerError, false}, + {http.StatusBadGateway, false}, + {http.StatusServiceUnavailable, true}, // 503 + {529, true}, // Anthropic Overloaded + } + + for _, tt := range tests { + t.Run(http.StatusText(tt.statusCode), func(t *testing.T) { + assert.Equal(t, tt.isFailure, isCircuitBreakerFailure(tt.statusCode)) + }) + } +} + +func TestCircuitState_String(t *testing.T) { + t.Parallel() + + assert.Equal(t, "closed", CircuitClosed.String()) + assert.Equal(t, "open", CircuitOpen.String()) + assert.Equal(t, "half-open", CircuitHalfOpen.String()) + assert.Equal(t, "unknown", CircuitState(99).String()) +} diff --git a/interception.go b/interception.go index 46ec7bd..d1d8444 100644 --- a/interception.go +++ b/interception.go @@ -40,11 +40,26 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. -func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { +func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbManager *CircuitBreakerManager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() + // Check circuit breaker before proceeding + cb := cbManager.GetOrCreate(p.Name()) + if !cb.Allow() { + span.SetStatus(codes.Error, "circuit breaker open") + logger.Debug(ctx, "request rejected by circuit breaker", + slog.F("provider", p.Name()), + slog.F("circuit_state", cb.State().String()), + ) + if metrics != nil { + metrics.CircuitBreakerRejects.WithLabelValues(p.Name()).Inc() + } + http.Error(w, fmt.Sprintf("%s is currently unavailable due to upstream rate limiting. Please try again later.", p.Name()), http.StatusServiceUnavailable) + return + } + interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) @@ -116,11 +131,24 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server } span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) log.Warn(ctx, "interception failed", slog.Error(err)) + + // Record failure for circuit breaker - extract status code if available + if statusCode := extractStatusCodeFromError(err); statusCode > 0 { + if cb.RecordFailure(statusCode) { + log.Warn(ctx, "circuit breaker tripped", + slog.F("provider", p.Name()), + slog.F("status_code", statusCode), + ) + } + } } else { if metrics != nil { metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusCompleted, route, r.Method, actor.id).Add(1) } log.Debug(ctx, "interception ended") + + // Record success for circuit breaker + cb.RecordSuccess() } asyncRecorder.RecordInterceptionEnded(ctx, &InterceptionRecordEnded{ID: interceptor.ID().String()}) @@ -128,3 +156,25 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server asyncRecorder.Wait() } } + +// extractStatusCodeFromError attempts to extract an HTTP status code from an error. +// This is used for circuit breaker failure tracking. +func extractStatusCodeFromError(err error) int { + if err == nil { + return 0 + } + + // Check for Anthropic error response + var antErr *AnthropicErrorResponse + if errors.As(err, &antErr) && antErr != nil { + return antErr.StatusCode + } + + // Check for OpenAI error response + var oaiErr *OpenAIErrorResponse + if errors.As(err, &oaiErr) && oaiErr != nil { + return oaiErr.StatusCode + } + + return 0 +} diff --git a/metrics.go b/metrics.go index 32d5a78..565029a 100644 --- a/metrics.go +++ b/metrics.go @@ -28,6 +28,11 @@ type Metrics struct { // Tool-related metrics. InjectedToolUseCount *prometheus.CounterVec NonInjectedToolUseCount *prometheus.CounterVec + + // Circuit breaker metrics. + CircuitBreakerState *prometheus.GaugeVec // Current state (0=closed, 1=open, 2=half-open) + CircuitBreakerTrips *prometheus.CounterVec // Total times circuit opened + CircuitBreakerRejects *prometheus.CounterVec // Requests rejected due to open circuit } // NewMetrics creates AND registers metrics. It will panic if a collector has already been registered. @@ -102,5 +107,26 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { Name: "total", Help: "The number of times an AI model selected a tool to be invoked by the client.", }, append(baseLabels, "name")), + + // Circuit breaker metrics. + + // Pessimistic cardinality: 2 providers = up to 2. + CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ + Subsystem: "circuit_breaker", + Name: "state", + Help: "Current state of the circuit breaker (0=closed, 1=open, 2=half-open).", + }, []string{"provider"}), + // Pessimistic cardinality: 2 providers = up to 2. + CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "trips_total", + Help: "Total number of times the circuit breaker has tripped open.", + }, []string{"provider"}), + // Pessimistic cardinality: 2 providers = up to 2. + CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "rejects_total", + Help: "Total number of requests rejected due to open circuit breaker.", + }, []string{"provider"}), } } From aad288c4cacf7c854dd757fc315153c5485347c6 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Fri, 12 Dec 2025 13:25:59 +0000 Subject: [PATCH 2/5] chore: apply make fmt --- circuit_breaker.go | 10 +++++----- circuit_breaker_test.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/circuit_breaker.go b/circuit_breaker.go index 48800ee..d863cad 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -66,11 +66,11 @@ type CircuitBreaker struct { mu sync.RWMutex // Current state - state CircuitState - failures int64 // Failure count in current window + state CircuitState + failures int64 // Failure count in current window windowStart time.Time // Start of current failure counting window - openedAt time.Time // When circuit transitioned to open - + openedAt time.Time // When circuit transitioned to open + // Half-open state tracking halfOpenSuccesses int64 halfOpenFailures int64 @@ -291,7 +291,7 @@ func (m *CircuitBreakerManager) SetStateChangeCallback(fn func(provider string, m.mu.Lock() defer m.mu.Unlock() m.onStateChange = fn - + // Update existing breakers for _, cb := range m.breakers { cb.SetStateChangeCallback(fn) diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index 4d52a15..3f84ba6 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -131,9 +131,9 @@ func TestCircuitBreaker_OnlyCountsRelevantStatusCodes(t *testing.T) { assert.Equal(t, int64(0), cb.Failures()) // These should count - cb.RecordFailure(http.StatusTooManyRequests) // 429 + cb.RecordFailure(http.StatusTooManyRequests) // 429 assert.Equal(t, int64(1), cb.Failures()) - + cb.RecordFailure(http.StatusServiceUnavailable) // 503 assert.Equal(t, CircuitOpen, cb.State()) } From 47253f193976b840bb6d2c85fdcfa4c1951a58d1 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Tue, 16 Dec 2025 10:27:35 +0000 Subject: [PATCH 3/5] refactor: use sony/gobreaker for circuit breakers with per-endpoint isolation - Replace custom circuit breaker implementation with sony/gobreaker - Change from per-provider to per-endpoint circuit breakers (e.g., OpenAI chat completions failing won't block responses API) - Simplify API: CircuitBreakers manages all breakers internally - Update metrics to include endpoint label - Simplify tests to focus on key behaviors Based on PR review feedback suggesting use of established library and per-endpoint granularity for better fault isolation. --- bridge.go | 30 ++-- circuit_breaker.go | 348 ++++++++++------------------------------ circuit_breaker_test.go | 266 ++++++++---------------------- go.mod | 2 + go.sum | 2 + interception.go | 20 +-- metrics.go | 12 +- 7 files changed, 186 insertions(+), 494 deletions(-) diff --git a/bridge.go b/bridge.go index 16872a0..3147486 100644 --- a/bridge.go +++ b/bridge.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "sync" "sync/atomic" @@ -32,7 +33,7 @@ type RequestBridge struct { // circuitBreakers manages circuit breakers for upstream providers. // When enabled, it protects against cascading failures from upstream rate limits. - circuitBreakers *CircuitBreakerManager + circuitBreakers *CircuitBreakers inflightReqs atomic.Int32 inflightWG sync.WaitGroup // For graceful shutdown. @@ -61,26 +62,24 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbConfig CircuitBreakerConfig) (*RequestBridge, error) { mux := http.NewServeMux() - // Create circuit breaker manager - cbManager := NewCircuitBreakerManager(cbConfig) - - // Set up metrics callback if metrics are provided + // Create circuit breakers with metrics callback + var onChange func(name string, from, to CircuitState) if metrics != nil { - cbManager.SetStateChangeCallback(func(provider string, from, to CircuitState) { - metrics.CircuitBreakerState.WithLabelValues(provider).Set(float64(to)) + onChange = func(name string, from, to CircuitState) { + provider, endpoint, _ := strings.Cut(name, ":") + metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(float64(to)) if to == CircuitOpen { - metrics.CircuitBreakerTrips.WithLabelValues(provider).Inc() + metrics.CircuitBreakerTrips.WithLabelValues(provider, endpoint).Inc() } - }) + } } + cbs := NewCircuitBreakers(cbConfig, onChange) for _, provider := range providers { - // Pre-create circuit breaker for this provider - cbManager.GetOrCreate(provider.Name()) // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer, cbManager)) + mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer, cbs)) } // Any requests which passthrough to this will be reverse-proxied to the upstream. @@ -106,7 +105,7 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide mux: mux, logger: logger, mcpProxy: mcpProxy, - circuitBreakers: cbManager, + circuitBreakers: cbs, inflightCtx: inflightCtx, inflightCancel: cancel, @@ -180,9 +179,8 @@ func (b *RequestBridge) InflightRequests() int32 { return b.inflightReqs.Load() } -// CircuitBreakers returns the circuit breaker manager for this bridge. -// This can be used to query circuit breaker states or configure callbacks. -func (b *RequestBridge) CircuitBreakers() *CircuitBreakerManager { +// CircuitBreakers returns the circuit breakers for this bridge. +func (b *RequestBridge) CircuitBreakers() *CircuitBreakers { return b.circuitBreakers } diff --git a/circuit_breaker.go b/circuit_breaker.go index d863cad..d27f0cd 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -1,9 +1,12 @@ package aibridge import ( + "fmt" "net/http" "sync" "time" + + "github.com/sony/gobreaker/v2" ) // CircuitState represents the current state of a circuit breaker. @@ -31,19 +34,31 @@ func (s CircuitState) String() string { } } -// CircuitBreakerConfig holds configuration for a circuit breaker. +// toCircuitState converts gobreaker.State to our CircuitState. +func toCircuitState(s gobreaker.State) CircuitState { + switch s { + case gobreaker.StateClosed: + return CircuitClosed + case gobreaker.StateOpen: + return CircuitOpen + case gobreaker.StateHalfOpen: + return CircuitHalfOpen + default: + return CircuitClosed + } +} + +// CircuitBreakerConfig holds configuration for circuit breakers. type CircuitBreakerConfig struct { - // Enabled controls whether the circuit breaker is active. - // If false, all requests pass through regardless of failures. + // Enabled controls whether circuit breakers are active. Enabled bool - // FailureThreshold is the number of failures within the window that triggers the circuit to open. + // FailureThreshold is the number of consecutive failures that triggers the circuit to open. FailureThreshold int64 // Window is the time window for counting failures. Window time.Duration // Cooldown is how long the circuit stays open before transitioning to half-open. Cooldown time.Duration - // HalfOpenMaxRequests is the maximum number of requests allowed in half-open state - // before deciding whether to close or re-open the circuit. + // HalfOpenMaxRequests is the maximum number of requests allowed in half-open state. HalfOpenMaxRequests int64 } @@ -58,292 +73,97 @@ func DefaultCircuitBreakerConfig() CircuitBreakerConfig { } } -// CircuitBreaker implements the circuit breaker pattern to protect against -// upstream service failures. It tracks failures from upstream providers -// (like rate limit errors) and temporarily blocks requests when the -// failure threshold is exceeded. -type CircuitBreaker struct { - mu sync.RWMutex - - // Current state - state CircuitState - failures int64 // Failure count in current window - windowStart time.Time // Start of current failure counting window - openedAt time.Time // When circuit transitioned to open - - // Half-open state tracking - halfOpenSuccesses int64 - halfOpenFailures int64 - - // Configuration - config CircuitBreakerConfig - - // Provider name for logging/metrics - provider string - - // Optional metrics callback - onStateChange func(provider string, from, to CircuitState) -} - -// NewCircuitBreaker creates a new circuit breaker for the given provider. -func NewCircuitBreaker(provider string, config CircuitBreakerConfig) *CircuitBreaker { - return &CircuitBreaker{ - state: CircuitClosed, - windowStart: time.Now(), - config: config, - provider: provider, - } -} - -// SetStateChangeCallback sets a callback that is invoked when the circuit state changes. -// This is useful for metrics and logging. -func (cb *CircuitBreaker) SetStateChangeCallback(fn func(provider string, from, to CircuitState)) { - cb.mu.Lock() - defer cb.mu.Unlock() - cb.onStateChange = fn -} - -// Allow checks if a request should be allowed through. -// Returns true if the request can proceed, false if it should be rejected. -func (cb *CircuitBreaker) Allow() bool { - if !cb.config.Enabled { - return true - } - - cb.mu.Lock() - defer cb.mu.Unlock() - - now := time.Now() - - switch cb.state { - case CircuitClosed: - return true - - case CircuitOpen: - // Check if cooldown period has elapsed - if now.Sub(cb.openedAt) >= cb.config.Cooldown { - cb.transitionTo(CircuitHalfOpen) - return true - } - return false - - case CircuitHalfOpen: - // Allow limited requests in half-open state - totalHalfOpenRequests := cb.halfOpenSuccesses + cb.halfOpenFailures - return totalHalfOpenRequests < cb.config.HalfOpenMaxRequests - } - - return true -} - -// RecordSuccess records a successful request. -// This is called after a request completes successfully. -func (cb *CircuitBreaker) RecordSuccess() { - if !cb.config.Enabled { - return - } - - cb.mu.Lock() - defer cb.mu.Unlock() - - switch cb.state { - case CircuitHalfOpen: - cb.halfOpenSuccesses++ - // If we've had enough successes in half-open, close the circuit - if cb.halfOpenSuccesses >= cb.config.HalfOpenMaxRequests { - cb.transitionTo(CircuitClosed) - } - case CircuitClosed: - // Reset failure count on success (sliding window behavior) - // This helps prevent false positives from old failures - cb.maybeResetWindow() - } -} - -// RecordFailure records a failed request. -// statusCode is the HTTP status code from the upstream response. -// Returns true if this failure caused the circuit to trip open. -func (cb *CircuitBreaker) RecordFailure(statusCode int) bool { - if !cb.config.Enabled { - return false - } - - // Only count specific error codes as circuit-breaker failures - if !isCircuitBreakerFailure(statusCode) { - return false - } - - cb.mu.Lock() - defer cb.mu.Unlock() - - switch cb.state { - case CircuitClosed: - cb.maybeResetWindow() - cb.failures++ - if cb.failures >= cb.config.FailureThreshold { - cb.transitionTo(CircuitOpen) - return true - } - - case CircuitHalfOpen: - cb.halfOpenFailures++ - // Any failure in half-open state re-opens the circuit - cb.transitionTo(CircuitOpen) - return true - } - - return false -} - -// State returns the current state of the circuit breaker. -func (cb *CircuitBreaker) State() CircuitState { - cb.mu.RLock() - defer cb.mu.RUnlock() - return cb.state -} - -// Provider returns the provider name this circuit breaker is for. -func (cb *CircuitBreaker) Provider() string { - return cb.provider -} - -// Failures returns the current failure count. -func (cb *CircuitBreaker) Failures() int64 { - cb.mu.RLock() - defer cb.mu.RUnlock() - return cb.failures -} - -// transitionTo changes the circuit state. Must be called with lock held. -func (cb *CircuitBreaker) transitionTo(newState CircuitState) { - oldState := cb.state - if oldState == newState { - return - } - - cb.state = newState - now := time.Now() - - switch newState { - case CircuitOpen: - cb.openedAt = now - case CircuitHalfOpen: - cb.halfOpenSuccesses = 0 - cb.halfOpenFailures = 0 - case CircuitClosed: - cb.failures = 0 - cb.windowStart = now - } - - if cb.onStateChange != nil { - // Call callback without holding lock to avoid deadlocks - callback := cb.onStateChange - go callback(cb.provider, oldState, newState) - } -} - -// maybeResetWindow resets the failure count if the window has elapsed. -// Must be called with lock held. -func (cb *CircuitBreaker) maybeResetWindow() { - now := time.Now() - if now.Sub(cb.windowStart) >= cb.config.Window { - cb.failures = 0 - cb.windowStart = now - } -} - // isCircuitBreakerFailure returns true if the given HTTP status code // should count as a failure for circuit breaker purposes. -// We specifically track rate limiting and overload errors from upstream. func isCircuitBreakerFailure(statusCode int) bool { switch statusCode { - case http.StatusTooManyRequests: // 429 - Rate limited - return true - case http.StatusServiceUnavailable: // 503 - Service unavailable - return true - case 529: // Anthropic-specific "Overloaded" error + case http.StatusTooManyRequests, // 429 + http.StatusServiceUnavailable, // 503 + 529: // Anthropic "Overloaded" return true default: return false } } -// CircuitBreakerManager manages circuit breakers for multiple providers. -type CircuitBreakerManager struct { - mu sync.RWMutex - breakers map[string]*CircuitBreaker +// CircuitBreakers manages per-endpoint circuit breakers using sony/gobreaker. +// Circuit breakers are keyed by "provider:endpoint" for per-endpoint isolation. +type CircuitBreakers struct { + breakers sync.Map // map[string]*gobreaker.CircuitBreaker[any] config CircuitBreakerConfig - - // Metrics callbacks - onStateChange func(provider string, from, to CircuitState) + onChange func(name string, from, to CircuitState) } -// NewCircuitBreakerManager creates a new manager with the given configuration. -func NewCircuitBreakerManager(config CircuitBreakerConfig) *CircuitBreakerManager { - return &CircuitBreakerManager{ - breakers: make(map[string]*CircuitBreaker), +// NewCircuitBreakers creates a new circuit breaker manager. +func NewCircuitBreakers(config CircuitBreakerConfig, onChange func(name string, from, to CircuitState)) *CircuitBreakers { + return &CircuitBreakers{ config: config, + onChange: onChange, } } -// SetStateChangeCallback sets the callback for state changes on all circuit breakers. -func (m *CircuitBreakerManager) SetStateChangeCallback(fn func(provider string, from, to CircuitState)) { - m.mu.Lock() - defer m.mu.Unlock() - m.onStateChange = fn - - // Update existing breakers - for _, cb := range m.breakers { - cb.SetStateChangeCallback(fn) +// Allow checks if a request to provider/endpoint should be allowed. +func (c *CircuitBreakers) Allow(provider, endpoint string) bool { + if !c.config.Enabled { + return true } + cb := c.getOrCreate(provider, endpoint) + return cb.State() != gobreaker.StateOpen } -// GetOrCreate returns the circuit breaker for the given provider, -// creating one if it doesn't exist. -func (m *CircuitBreakerManager) GetOrCreate(provider string) *CircuitBreaker { - m.mu.RLock() - if cb, ok := m.breakers[provider]; ok { - m.mu.RUnlock() - return cb - } - m.mu.RUnlock() - - m.mu.Lock() - defer m.mu.Unlock() - - // Double-check after acquiring write lock - if cb, ok := m.breakers[provider]; ok { - return cb +// RecordSuccess records a successful request. +func (c *CircuitBreakers) RecordSuccess(provider, endpoint string) { + if !c.config.Enabled { + return } + cb := c.getOrCreate(provider, endpoint) + _, _ = cb.Execute(func() (any, error) { return nil, nil }) +} - cb := NewCircuitBreaker(provider, m.config) - if m.onStateChange != nil { - cb.SetStateChangeCallback(m.onStateChange) +// RecordFailure records a failed request. Returns true if this caused the circuit to open. +func (c *CircuitBreakers) RecordFailure(provider, endpoint string, statusCode int) bool { + if !c.config.Enabled || !isCircuitBreakerFailure(statusCode) { + return false } - m.breakers[provider] = cb - return cb + cb := c.getOrCreate(provider, endpoint) + before := cb.State() + _, _ = cb.Execute(func() (any, error) { + return nil, fmt.Errorf("upstream error: %d", statusCode) + }) + return before != gobreaker.StateOpen && cb.State() == gobreaker.StateOpen } -// Get returns the circuit breaker for the given provider, or nil if not found. -func (m *CircuitBreakerManager) Get(provider string) *CircuitBreaker { - m.mu.RLock() - defer m.mu.RUnlock() - return m.breakers[provider] +// State returns the current state for a provider/endpoint. +func (c *CircuitBreakers) State(provider, endpoint string) CircuitState { + if !c.config.Enabled { + return CircuitClosed + } + cb := c.getOrCreate(provider, endpoint) + return toCircuitState(cb.State()) } -// AllStates returns the current state of all circuit breakers. -func (m *CircuitBreakerManager) AllStates() map[string]CircuitState { - m.mu.RLock() - defer m.mu.RUnlock() +func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.CircuitBreaker[any] { + key := provider + ":" + endpoint + if v, ok := c.breakers.Load(key); ok { + return v.(*gobreaker.CircuitBreaker[any]) + } - states := make(map[string]CircuitState, len(m.breakers)) - for provider, cb := range m.breakers { - states[provider] = cb.State() + settings := gobreaker.Settings{ + Name: key, + MaxRequests: uint32(c.config.HalfOpenMaxRequests), + Interval: c.config.Window, + Timeout: c.config.Cooldown, + ReadyToTrip: func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= uint32(c.config.FailureThreshold) + }, + OnStateChange: func(name string, from, to gobreaker.State) { + if c.onChange != nil { + c.onChange(name, toCircuitState(from), toCircuitState(to)) + } + }, } - return states -} -// Config returns the configuration used by this manager. -func (m *CircuitBreakerManager) Config() CircuitBreakerConfig { - return m.config + cb := gobreaker.NewCircuitBreaker[any](settings) + actual, _ := c.breakers.LoadOrStore(key, cb) + return actual.(*gobreaker.CircuitBreaker[any]) } diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index 3f84ba6..0d8deb9 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -21,96 +21,86 @@ func TestCircuitBreaker_DefaultConfig(t *testing.T) { assert.Equal(t, int64(3), cfg.HalfOpenMaxRequests) } -func TestCircuitBreaker_DisabledByDefault(t *testing.T) { +func TestCircuitBreakers_DisabledByDefault(t *testing.T) { t.Parallel() - cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig()) + cbs := NewCircuitBreakers(DefaultCircuitBreakerConfig(), nil) // Should always allow when disabled - assert.True(t, cb.Allow()) + assert.True(t, cbs.Allow("anthropic", "/v1/messages")) // Recording failures should not affect state when disabled for i := 0; i < 100; i++ { - cb.RecordFailure(http.StatusTooManyRequests) + cbs.RecordFailure("anthropic", "/v1/messages", http.StatusTooManyRequests) } - assert.True(t, cb.Allow()) - assert.Equal(t, CircuitClosed, cb.State()) + assert.True(t, cbs.Allow("anthropic", "/v1/messages")) + assert.Equal(t, CircuitClosed, cbs.State("anthropic", "/v1/messages")) } -func TestCircuitBreaker_StateTransitions(t *testing.T) { +func TestCircuitBreakers_StateTransitions(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ Enabled: true, FailureThreshold: 3, - Window: time.Minute, // Long window so it doesn't reset during test + Window: time.Minute, Cooldown: 50 * time.Millisecond, HalfOpenMaxRequests: 2, } - cb := NewCircuitBreaker("test", cfg) + cbs := NewCircuitBreakers(cfg, nil) // Start in closed state - assert.Equal(t, CircuitClosed, cb.State()) - assert.True(t, cb.Allow()) + assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.True(t, cbs.Allow("test", "/api")) // Record failures below threshold - cb.RecordFailure(http.StatusTooManyRequests) - cb.RecordFailure(http.StatusTooManyRequests) - assert.Equal(t, CircuitClosed, cb.State()) - assert.True(t, cb.Allow()) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) + assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) // Third failure should trip the circuit - tripped := cb.RecordFailure(http.StatusTooManyRequests) + tripped := cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cb.State()) - assert.False(t, cb.Allow()) + assert.Equal(t, CircuitOpen, cbs.State("test", "/api")) + assert.False(t, cbs.Allow("test", "/api")) // Wait for cooldown time.Sleep(60 * time.Millisecond) // Should transition to half-open and allow request - assert.True(t, cb.Allow()) - assert.Equal(t, CircuitHalfOpen, cb.State()) + assert.True(t, cbs.Allow("test", "/api")) + assert.Equal(t, CircuitHalfOpen, cbs.State("test", "/api")) // Success in half-open should eventually close - cb.RecordSuccess() - cb.RecordSuccess() - assert.Equal(t, CircuitClosed, cb.State()) - assert.True(t, cb.Allow()) + cbs.RecordSuccess("test", "/api") + cbs.RecordSuccess("test", "/api") + assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) } -func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { +func TestCircuitBreakers_PerEndpointIsolation(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ Enabled: true, - FailureThreshold: 2, + FailureThreshold: 1, Window: time.Minute, - Cooldown: 50 * time.Millisecond, - HalfOpenMaxRequests: 3, + Cooldown: time.Minute, + HalfOpenMaxRequests: 1, } - cb := NewCircuitBreaker("test", cfg) - - // Trip the circuit - cb.RecordFailure(http.StatusTooManyRequests) - cb.RecordFailure(http.StatusTooManyRequests) - assert.Equal(t, CircuitOpen, cb.State()) + cbs := NewCircuitBreakers(cfg, nil) - // Wait for cooldown - time.Sleep(60 * time.Millisecond) + // Trip circuit for one endpoint + cbs.RecordFailure("openai", "/v1/chat/completions", http.StatusTooManyRequests) + assert.Equal(t, CircuitOpen, cbs.State("openai", "/v1/chat/completions")) - // Transition to half-open - assert.True(t, cb.Allow()) - assert.Equal(t, CircuitHalfOpen, cb.State()) - - // Failure in half-open should re-open circuit - tripped := cb.RecordFailure(http.StatusServiceUnavailable) - assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cb.State()) - assert.False(t, cb.Allow()) + // Other endpoints should still be closed + assert.Equal(t, CircuitClosed, cbs.State("openai", "/v1/responses")) + assert.Equal(t, CircuitClosed, cbs.State("anthropic", "/v1/messages")) + assert.True(t, cbs.Allow("openai", "/v1/responses")) + assert.True(t, cbs.Allow("anthropic", "/v1/messages")) } -func TestCircuitBreaker_OnlyCountsRelevantStatusCodes(t *testing.T) { +func TestCircuitBreakers_OnlyCountsRelevantStatusCodes(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ @@ -120,25 +110,22 @@ func TestCircuitBreaker_OnlyCountsRelevantStatusCodes(t *testing.T) { Cooldown: time.Minute, HalfOpenMaxRequests: 2, } - cb := NewCircuitBreaker("test", cfg) + cbs := NewCircuitBreakers(cfg, nil) // Non-circuit-breaker status codes should not count - cb.RecordFailure(http.StatusBadRequest) // 400 - cb.RecordFailure(http.StatusUnauthorized) // 401 - cb.RecordFailure(http.StatusInternalServerError) // 500 - cb.RecordFailure(http.StatusBadGateway) // 502 - assert.Equal(t, CircuitClosed, cb.State()) - assert.Equal(t, int64(0), cb.Failures()) + cbs.RecordFailure("test", "/api", http.StatusBadRequest) // 400 + cbs.RecordFailure("test", "/api", http.StatusUnauthorized) // 401 + cbs.RecordFailure("test", "/api", http.StatusInternalServerError) // 500 + cbs.RecordFailure("test", "/api", http.StatusBadGateway) // 502 + assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) // These should count - cb.RecordFailure(http.StatusTooManyRequests) // 429 - assert.Equal(t, int64(1), cb.Failures()) - - cb.RecordFailure(http.StatusServiceUnavailable) // 503 - assert.Equal(t, CircuitOpen, cb.State()) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) // 429 + cbs.RecordFailure("test", "/api", http.StatusServiceUnavailable) // 503 + assert.Equal(t, CircuitOpen, cbs.State("test", "/api")) } -func TestCircuitBreaker_Anthropic529(t *testing.T) { +func TestCircuitBreakers_Anthropic529(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ @@ -148,75 +135,44 @@ func TestCircuitBreaker_Anthropic529(t *testing.T) { Cooldown: time.Minute, HalfOpenMaxRequests: 1, } - cb := NewCircuitBreaker("anthropic", cfg) + cbs := NewCircuitBreakers(cfg, nil) // Anthropic-specific 529 "Overloaded" should trip the circuit - tripped := cb.RecordFailure(529) + tripped := cbs.RecordFailure("anthropic", "/v1/messages", 529) assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cb.State()) -} - -func TestCircuitBreaker_WindowReset(t *testing.T) { - t.Parallel() - - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 3, - Window: 50 * time.Millisecond, // Short window - Cooldown: time.Minute, - HalfOpenMaxRequests: 2, - } - cb := NewCircuitBreaker("test", cfg) - - // Record failures - cb.RecordFailure(http.StatusTooManyRequests) - cb.RecordFailure(http.StatusTooManyRequests) - assert.Equal(t, int64(2), cb.Failures()) - - // Wait for window to expire - time.Sleep(60 * time.Millisecond) - - // Next failure should reset counter (due to window expiry) - cb.RecordFailure(http.StatusTooManyRequests) - assert.Equal(t, int64(1), cb.Failures()) - assert.Equal(t, CircuitClosed, cb.State()) + assert.Equal(t, CircuitOpen, cbs.State("anthropic", "/v1/messages")) } -func TestCircuitBreaker_ConcurrentAccess(t *testing.T) { +func TestCircuitBreakers_ConcurrentAccess(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ Enabled: true, - FailureThreshold: 100, + FailureThreshold: 1000, Window: time.Minute, Cooldown: time.Minute, HalfOpenMaxRequests: 10, } - cb := NewCircuitBreaker("test", cfg) + cbs := NewCircuitBreakers(cfg, nil) var wg sync.WaitGroup - numGoroutines := 50 - opsPerGoroutine := 100 - - for i := 0; i < numGoroutines; i++ { + for i := 0; i < 50; i++ { wg.Add(1) go func() { defer wg.Done() - for j := 0; j < opsPerGoroutine; j++ { - cb.Allow() - cb.RecordSuccess() - cb.RecordFailure(http.StatusTooManyRequests) - cb.State() - cb.Failures() + for j := 0; j < 100; j++ { + cbs.Allow("test", "/api") + cbs.RecordSuccess("test", "/api") + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) + cbs.State("test", "/api") } }() } - wg.Wait() // Should not panic or deadlock } -func TestCircuitBreaker_StateChangeCallback(t *testing.T) { +func TestCircuitBreakers_StateChangeCallback(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ @@ -226,38 +182,29 @@ func TestCircuitBreaker_StateChangeCallback(t *testing.T) { Cooldown: 50 * time.Millisecond, HalfOpenMaxRequests: 1, } - cb := NewCircuitBreaker("test", cfg) var mu sync.Mutex - var transitions []struct { - from, to CircuitState - } + var transitions []struct{ from, to CircuitState } - cb.SetStateChangeCallback(func(provider string, from, to CircuitState) { + cbs := NewCircuitBreakers(cfg, func(name string, from, to CircuitState) { mu.Lock() defer mu.Unlock() transitions = append(transitions, struct{ from, to CircuitState }{from, to}) }) // Trip the circuit - cb.RecordFailure(http.StatusTooManyRequests) - cb.RecordFailure(http.StatusTooManyRequests) - - // Wait for callback - time.Sleep(10 * time.Millisecond) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) // Wait for cooldown and trigger half-open time.Sleep(60 * time.Millisecond) - cb.Allow() - - // Wait for callback - time.Sleep(10 * time.Millisecond) + cbs.Allow("test", "/api") // Success to close - cb.RecordSuccess() + cbs.RecordSuccess("test", "/api") - // Wait for callback - time.Sleep(10 * time.Millisecond) + // Wait for callbacks + time.Sleep(20 * time.Millisecond) mu.Lock() defer mu.Unlock() @@ -270,82 +217,6 @@ func TestCircuitBreaker_StateChangeCallback(t *testing.T) { assert.Equal(t, CircuitClosed, transitions[2].to) } -func TestCircuitBreakerManager_GetOrCreate(t *testing.T) { - t.Parallel() - - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 5, - Window: time.Minute, - Cooldown: time.Minute, - } - manager := NewCircuitBreakerManager(cfg) - - // First call should create - cb1 := manager.GetOrCreate("anthropic") - require.NotNil(t, cb1) - assert.Equal(t, "anthropic", cb1.Provider()) - - // Second call should return same instance - cb2 := manager.GetOrCreate("anthropic") - assert.Same(t, cb1, cb2) - - // Different provider gets different instance - cb3 := manager.GetOrCreate("openai") - require.NotNil(t, cb3) - assert.NotSame(t, cb1, cb3) - assert.Equal(t, "openai", cb3.Provider()) -} - -func TestCircuitBreakerManager_AllStates(t *testing.T) { - t.Parallel() - - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1, - Window: time.Minute, - Cooldown: time.Minute, - } - manager := NewCircuitBreakerManager(cfg) - - manager.GetOrCreate("anthropic") - manager.GetOrCreate("openai") - - // Trip one circuit - manager.Get("anthropic").RecordFailure(http.StatusTooManyRequests) - - states := manager.AllStates() - assert.Equal(t, CircuitOpen, states["anthropic"]) - assert.Equal(t, CircuitClosed, states["openai"]) -} - -func TestCircuitBreakerManager_ConcurrentGetOrCreate(t *testing.T) { - t.Parallel() - - cfg := DefaultCircuitBreakerConfig() - cfg.Enabled = true - manager := NewCircuitBreakerManager(cfg) - - var wg sync.WaitGroup - var results [100]*CircuitBreaker - - for i := 0; i < 100; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - results[idx] = manager.GetOrCreate("test-provider") - }(i) - } - - wg.Wait() - - // All should be the same instance - first := results[0] - for i := 1; i < 100; i++ { - assert.Same(t, first, results[i]) - } -} - func TestIsCircuitBreakerFailure(t *testing.T) { t.Parallel() @@ -356,11 +227,8 @@ func TestIsCircuitBreakerFailure(t *testing.T) { {http.StatusOK, false}, {http.StatusBadRequest, false}, {http.StatusUnauthorized, false}, - {http.StatusForbidden, false}, - {http.StatusNotFound, false}, {http.StatusTooManyRequests, true}, // 429 {http.StatusInternalServerError, false}, - {http.StatusBadGateway, false}, {http.StatusServiceUnavailable, true}, // 503 {529, true}, // Anthropic Overloaded } diff --git a/go.mod b/go.mod index 9a62089..4715f99 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,8 @@ require ( go.opentelemetry.io/otel/trace v1.38.0 ) +require github.com/sony/gobreaker/v2 v2.3.0 + require ( github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect diff --git a/go.sum b/go.sum index 385345d..fff1ee3 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sony/gobreaker/v2 v2.3.0 h1:7VYxZ69QXRQ2Q4eEawHn6eU4FiuwovzJwsUMA03Lu4I= +github.com/sony/gobreaker/v2 v2.3.0/go.mod h1:pTyFJgcZ3h2tdQVLZZruK2C0eoFL1fb/G83wK1ZQl+s= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/interception.go b/interception.go index d1d8444..f1d7472 100644 --- a/interception.go +++ b/interception.go @@ -40,23 +40,26 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. -func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbManager *CircuitBreakerManager) http.HandlerFunc { +func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbs *CircuitBreakers) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() + // Extract endpoint (route) for per-endpoint circuit breaker + route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) + // Check circuit breaker before proceeding - cb := cbManager.GetOrCreate(p.Name()) - if !cb.Allow() { + if !cbs.Allow(p.Name(), route) { span.SetStatus(codes.Error, "circuit breaker open") logger.Debug(ctx, "request rejected by circuit breaker", slog.F("provider", p.Name()), - slog.F("circuit_state", cb.State().String()), + slog.F("endpoint", route), + slog.F("circuit_state", cbs.State(p.Name(), route).String()), ) if metrics != nil { - metrics.CircuitBreakerRejects.WithLabelValues(p.Name()).Inc() + metrics.CircuitBreakerRejects.WithLabelValues(p.Name(), route).Inc() } - http.Error(w, fmt.Sprintf("%s is currently unavailable due to upstream rate limiting. Please try again later.", p.Name()), http.StatusServiceUnavailable) + http.Error(w, fmt.Sprintf("%s %s is currently unavailable due to upstream rate limiting. Please try again later.", p.Name(), route), http.StatusServiceUnavailable) return } @@ -108,7 +111,6 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server return } - route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) log := logger.With( slog.F("route", route), slog.F("provider", p.Name()), @@ -134,7 +136,7 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server // Record failure for circuit breaker - extract status code if available if statusCode := extractStatusCodeFromError(err); statusCode > 0 { - if cb.RecordFailure(statusCode) { + if cbs.RecordFailure(p.Name(), route, statusCode) { log.Warn(ctx, "circuit breaker tripped", slog.F("provider", p.Name()), slog.F("status_code", statusCode), @@ -148,7 +150,7 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server log.Debug(ctx, "interception ended") // Record success for circuit breaker - cb.RecordSuccess() + cbs.RecordSuccess(p.Name(), route) } asyncRecorder.RecordInterceptionEnded(ctx, &InterceptionRecordEnded{ID: interceptor.ID().String()}) diff --git a/metrics.go b/metrics.go index 565029a..f744d10 100644 --- a/metrics.go +++ b/metrics.go @@ -110,23 +110,23 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { // Circuit breaker metrics. - // Pessimistic cardinality: 2 providers = up to 2. + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ Subsystem: "circuit_breaker", Name: "state", Help: "Current state of the circuit breaker (0=closed, 1=open, 2=half-open).", - }, []string{"provider"}), - // Pessimistic cardinality: 2 providers = up to 2. + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "trips_total", Help: "Total number of times the circuit breaker has tripped open.", - }, []string{"provider"}), - // Pessimistic cardinality: 2 providers = up to 2. + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "rejects_total", Help: "Total number of requests rejected due to open circuit breaker.", - }, []string{"provider"}), + }, []string{"provider", "endpoint"}), } } From 8cf2d18baddab1342b417179f4440c2a2f8268c7 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Tue, 16 Dec 2025 10:30:03 +0000 Subject: [PATCH 4/5] refactor: align CircuitBreakerConfig fields with gobreaker.Settings Rename fields to match gobreaker naming convention: - Window -> Interval - Cooldown -> Timeout - HalfOpenMaxRequests -> MaxRequests - FailureThreshold type int64 -> uint32 --- circuit_breaker.go | 33 ++++++++++---------- circuit_breaker_test.go | 68 ++++++++++++++++++++--------------------- 2 files changed, 51 insertions(+), 50 deletions(-) diff --git a/circuit_breaker.go b/circuit_breaker.go index d27f0cd..5ecf674 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -49,27 +49,28 @@ func toCircuitState(s gobreaker.State) CircuitState { } // CircuitBreakerConfig holds configuration for circuit breakers. +// Fields match gobreaker.Settings for clarity. type CircuitBreakerConfig struct { // Enabled controls whether circuit breakers are active. Enabled bool + // MaxRequests is the maximum number of requests allowed in half-open state. + MaxRequests uint32 + // Interval is the cyclic period of the closed state for clearing internal counts. + Interval time.Duration + // Timeout is how long the circuit stays open before transitioning to half-open. + Timeout time.Duration // FailureThreshold is the number of consecutive failures that triggers the circuit to open. - FailureThreshold int64 - // Window is the time window for counting failures. - Window time.Duration - // Cooldown is how long the circuit stays open before transitioning to half-open. - Cooldown time.Duration - // HalfOpenMaxRequests is the maximum number of requests allowed in half-open state. - HalfOpenMaxRequests int64 + FailureThreshold uint32 } // DefaultCircuitBreakerConfig returns sensible defaults for circuit breaker configuration. func DefaultCircuitBreakerConfig() CircuitBreakerConfig { return CircuitBreakerConfig{ - Enabled: false, // Disabled by default for backward compatibility - FailureThreshold: 5, - Window: 10 * time.Second, - Cooldown: 30 * time.Second, - HalfOpenMaxRequests: 3, + Enabled: false, // Disabled by default for backward compatibility + FailureThreshold: 5, + Interval: 10 * time.Second, + Timeout: 30 * time.Second, + MaxRequests: 3, } } @@ -150,11 +151,11 @@ func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.Circ settings := gobreaker.Settings{ Name: key, - MaxRequests: uint32(c.config.HalfOpenMaxRequests), - Interval: c.config.Window, - Timeout: c.config.Cooldown, + MaxRequests: c.config.MaxRequests, + Interval: c.config.Interval, + Timeout: c.config.Timeout, ReadyToTrip: func(counts gobreaker.Counts) bool { - return counts.ConsecutiveFailures >= uint32(c.config.FailureThreshold) + return counts.ConsecutiveFailures >= c.config.FailureThreshold }, OnStateChange: func(name string, from, to gobreaker.State) { if c.onChange != nil { diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index 0d8deb9..c7e2c82 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -15,10 +15,10 @@ func TestCircuitBreaker_DefaultConfig(t *testing.T) { cfg := DefaultCircuitBreakerConfig() assert.False(t, cfg.Enabled, "should be disabled by default") - assert.Equal(t, int64(5), cfg.FailureThreshold) - assert.Equal(t, 10*time.Second, cfg.Window) - assert.Equal(t, 30*time.Second, cfg.Cooldown) - assert.Equal(t, int64(3), cfg.HalfOpenMaxRequests) + assert.Equal(t, uint32(5), cfg.FailureThreshold) + assert.Equal(t, 10*time.Second, cfg.Interval) + assert.Equal(t, 30*time.Second, cfg.Timeout) + assert.Equal(t, uint32(3), cfg.MaxRequests) } func TestCircuitBreakers_DisabledByDefault(t *testing.T) { @@ -41,11 +41,11 @@ func TestCircuitBreakers_StateTransitions(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 3, - Window: time.Minute, - Cooldown: 50 * time.Millisecond, - HalfOpenMaxRequests: 2, + Enabled: true, + FailureThreshold: 3, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 2, } cbs := NewCircuitBreakers(cfg, nil) @@ -81,11 +81,11 @@ func TestCircuitBreakers_PerEndpointIsolation(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1, - Window: time.Minute, - Cooldown: time.Minute, - HalfOpenMaxRequests: 1, + Enabled: true, + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, } cbs := NewCircuitBreakers(cfg, nil) @@ -104,11 +104,11 @@ func TestCircuitBreakers_OnlyCountsRelevantStatusCodes(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 2, - Window: time.Minute, - Cooldown: time.Minute, - HalfOpenMaxRequests: 2, + Enabled: true, + FailureThreshold: 2, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 2, } cbs := NewCircuitBreakers(cfg, nil) @@ -129,11 +129,11 @@ func TestCircuitBreakers_Anthropic529(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1, - Window: time.Minute, - Cooldown: time.Minute, - HalfOpenMaxRequests: 1, + Enabled: true, + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, } cbs := NewCircuitBreakers(cfg, nil) @@ -147,11 +147,11 @@ func TestCircuitBreakers_ConcurrentAccess(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1000, - Window: time.Minute, - Cooldown: time.Minute, - HalfOpenMaxRequests: 10, + Enabled: true, + FailureThreshold: 1000, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 10, } cbs := NewCircuitBreakers(cfg, nil) @@ -176,11 +176,11 @@ func TestCircuitBreakers_StateChangeCallback(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 2, - Window: time.Minute, - Cooldown: 50 * time.Millisecond, - HalfOpenMaxRequests: 1, + Enabled: true, + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, } var mu sync.Mutex From 8e44145e9bb28a2fb0ceefcbcb7cbe85380e780b Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Tue, 16 Dec 2025 10:32:01 +0000 Subject: [PATCH 5/5] refactor: remove CircuitState, use gobreaker.State directly --- bridge.go | 10 ++++---- circuit_breaker.go | 51 +++++----------------------------------- circuit_breaker_test.go | 52 +++++++++++++++++------------------------ interception.go | 2 +- 4 files changed, 34 insertions(+), 81 deletions(-) diff --git a/bridge.go b/bridge.go index 3147486..4cafa6b 100644 --- a/bridge.go +++ b/bridge.go @@ -10,9 +10,9 @@ import ( "cdr.dev/slog" "github.com/coder/aibridge/mcp" - "go.opentelemetry.io/otel/trace" - "github.com/hashicorp/go-multierror" + "github.com/sony/gobreaker/v2" + "go.opentelemetry.io/otel/trace" ) // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; @@ -63,12 +63,12 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide mux := http.NewServeMux() // Create circuit breakers with metrics callback - var onChange func(name string, from, to CircuitState) + var onChange func(name string, from, to gobreaker.State) if metrics != nil { - onChange = func(name string, from, to CircuitState) { + onChange = func(name string, from, to gobreaker.State) { provider, endpoint, _ := strings.Cut(name, ":") metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(float64(to)) - if to == CircuitOpen { + if to == gobreaker.StateOpen { metrics.CircuitBreakerTrips.WithLabelValues(provider, endpoint).Inc() } } diff --git a/circuit_breaker.go b/circuit_breaker.go index 5ecf674..06754bb 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -9,45 +9,6 @@ import ( "github.com/sony/gobreaker/v2" ) -// CircuitState represents the current state of a circuit breaker. -type CircuitState int - -const ( - // CircuitClosed is the normal state - all requests pass through. - CircuitClosed CircuitState = iota - // CircuitOpen is the tripped state - requests are rejected immediately. - CircuitOpen - // CircuitHalfOpen is the testing state - limited requests pass through. - CircuitHalfOpen -) - -func (s CircuitState) String() string { - switch s { - case CircuitClosed: - return "closed" - case CircuitOpen: - return "open" - case CircuitHalfOpen: - return "half-open" - default: - return "unknown" - } -} - -// toCircuitState converts gobreaker.State to our CircuitState. -func toCircuitState(s gobreaker.State) CircuitState { - switch s { - case gobreaker.StateClosed: - return CircuitClosed - case gobreaker.StateOpen: - return CircuitOpen - case gobreaker.StateHalfOpen: - return CircuitHalfOpen - default: - return CircuitClosed - } -} - // CircuitBreakerConfig holds configuration for circuit breakers. // Fields match gobreaker.Settings for clarity. type CircuitBreakerConfig struct { @@ -92,11 +53,11 @@ func isCircuitBreakerFailure(statusCode int) bool { type CircuitBreakers struct { breakers sync.Map // map[string]*gobreaker.CircuitBreaker[any] config CircuitBreakerConfig - onChange func(name string, from, to CircuitState) + onChange func(name string, from, to gobreaker.State) } // NewCircuitBreakers creates a new circuit breaker manager. -func NewCircuitBreakers(config CircuitBreakerConfig, onChange func(name string, from, to CircuitState)) *CircuitBreakers { +func NewCircuitBreakers(config CircuitBreakerConfig, onChange func(name string, from, to gobreaker.State)) *CircuitBreakers { return &CircuitBreakers{ config: config, onChange: onChange, @@ -135,12 +96,12 @@ func (c *CircuitBreakers) RecordFailure(provider, endpoint string, statusCode in } // State returns the current state for a provider/endpoint. -func (c *CircuitBreakers) State(provider, endpoint string) CircuitState { +func (c *CircuitBreakers) State(provider, endpoint string) gobreaker.State { if !c.config.Enabled { - return CircuitClosed + return gobreaker.StateClosed } cb := c.getOrCreate(provider, endpoint) - return toCircuitState(cb.State()) + return cb.State() } func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.CircuitBreaker[any] { @@ -159,7 +120,7 @@ func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.Circ }, OnStateChange: func(name string, from, to gobreaker.State) { if c.onChange != nil { - c.onChange(name, toCircuitState(from), toCircuitState(to)) + c.onChange(name, from, to) } }, } diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index c7e2c82..ca45620 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/sony/gobreaker/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -34,7 +35,7 @@ func TestCircuitBreakers_DisabledByDefault(t *testing.T) { cbs.RecordFailure("anthropic", "/v1/messages", http.StatusTooManyRequests) } assert.True(t, cbs.Allow("anthropic", "/v1/messages")) - assert.Equal(t, CircuitClosed, cbs.State("anthropic", "/v1/messages")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("anthropic", "/v1/messages")) } func TestCircuitBreakers_StateTransitions(t *testing.T) { @@ -50,18 +51,18 @@ func TestCircuitBreakers_StateTransitions(t *testing.T) { cbs := NewCircuitBreakers(cfg, nil) // Start in closed state - assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) assert.True(t, cbs.Allow("test", "/api")) // Record failures below threshold cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) - assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) // Third failure should trip the circuit tripped := cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateOpen, cbs.State("test", "/api")) assert.False(t, cbs.Allow("test", "/api")) // Wait for cooldown @@ -69,12 +70,12 @@ func TestCircuitBreakers_StateTransitions(t *testing.T) { // Should transition to half-open and allow request assert.True(t, cbs.Allow("test", "/api")) - assert.Equal(t, CircuitHalfOpen, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateHalfOpen, cbs.State("test", "/api")) // Success in half-open should eventually close cbs.RecordSuccess("test", "/api") cbs.RecordSuccess("test", "/api") - assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) } func TestCircuitBreakers_PerEndpointIsolation(t *testing.T) { @@ -91,11 +92,11 @@ func TestCircuitBreakers_PerEndpointIsolation(t *testing.T) { // Trip circuit for one endpoint cbs.RecordFailure("openai", "/v1/chat/completions", http.StatusTooManyRequests) - assert.Equal(t, CircuitOpen, cbs.State("openai", "/v1/chat/completions")) + assert.Equal(t, gobreaker.StateOpen, cbs.State("openai", "/v1/chat/completions")) // Other endpoints should still be closed - assert.Equal(t, CircuitClosed, cbs.State("openai", "/v1/responses")) - assert.Equal(t, CircuitClosed, cbs.State("anthropic", "/v1/messages")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("openai", "/v1/responses")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("anthropic", "/v1/messages")) assert.True(t, cbs.Allow("openai", "/v1/responses")) assert.True(t, cbs.Allow("anthropic", "/v1/messages")) } @@ -117,12 +118,12 @@ func TestCircuitBreakers_OnlyCountsRelevantStatusCodes(t *testing.T) { cbs.RecordFailure("test", "/api", http.StatusUnauthorized) // 401 cbs.RecordFailure("test", "/api", http.StatusInternalServerError) // 500 cbs.RecordFailure("test", "/api", http.StatusBadGateway) // 502 - assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) // These should count cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) // 429 cbs.RecordFailure("test", "/api", http.StatusServiceUnavailable) // 503 - assert.Equal(t, CircuitOpen, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateOpen, cbs.State("test", "/api")) } func TestCircuitBreakers_Anthropic529(t *testing.T) { @@ -140,7 +141,7 @@ func TestCircuitBreakers_Anthropic529(t *testing.T) { // Anthropic-specific 529 "Overloaded" should trip the circuit tripped := cbs.RecordFailure("anthropic", "/v1/messages", 529) assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cbs.State("anthropic", "/v1/messages")) + assert.Equal(t, gobreaker.StateOpen, cbs.State("anthropic", "/v1/messages")) } func TestCircuitBreakers_ConcurrentAccess(t *testing.T) { @@ -184,12 +185,12 @@ func TestCircuitBreakers_StateChangeCallback(t *testing.T) { } var mu sync.Mutex - var transitions []struct{ from, to CircuitState } + var transitions []struct{ from, to gobreaker.State } - cbs := NewCircuitBreakers(cfg, func(name string, from, to CircuitState) { + cbs := NewCircuitBreakers(cfg, func(name string, from, to gobreaker.State) { mu.Lock() defer mu.Unlock() - transitions = append(transitions, struct{ from, to CircuitState }{from, to}) + transitions = append(transitions, struct{ from, to gobreaker.State }{from, to}) }) // Trip the circuit @@ -209,12 +210,12 @@ func TestCircuitBreakers_StateChangeCallback(t *testing.T) { mu.Lock() defer mu.Unlock() require.Len(t, transitions, 3) - assert.Equal(t, CircuitClosed, transitions[0].from) - assert.Equal(t, CircuitOpen, transitions[0].to) - assert.Equal(t, CircuitOpen, transitions[1].from) - assert.Equal(t, CircuitHalfOpen, transitions[1].to) - assert.Equal(t, CircuitHalfOpen, transitions[2].from) - assert.Equal(t, CircuitClosed, transitions[2].to) + assert.Equal(t, gobreaker.StateClosed, transitions[0].from) + assert.Equal(t, gobreaker.StateOpen, transitions[0].to) + assert.Equal(t, gobreaker.StateOpen, transitions[1].from) + assert.Equal(t, gobreaker.StateHalfOpen, transitions[1].to) + assert.Equal(t, gobreaker.StateHalfOpen, transitions[2].from) + assert.Equal(t, gobreaker.StateClosed, transitions[2].to) } func TestIsCircuitBreakerFailure(t *testing.T) { @@ -239,12 +240,3 @@ func TestIsCircuitBreakerFailure(t *testing.T) { }) } } - -func TestCircuitState_String(t *testing.T) { - t.Parallel() - - assert.Equal(t, "closed", CircuitClosed.String()) - assert.Equal(t, "open", CircuitOpen.String()) - assert.Equal(t, "half-open", CircuitHalfOpen.String()) - assert.Equal(t, "unknown", CircuitState(99).String()) -} diff --git a/interception.go b/interception.go index f1d7472..da589be 100644 --- a/interception.go +++ b/interception.go @@ -54,7 +54,7 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server logger.Debug(ctx, "request rejected by circuit breaker", slog.F("provider", p.Name()), slog.F("endpoint", route), - slog.F("circuit_state", cbs.State(p.Name(), route).String()), + slog.F("circuit_state", cbs.State(p.Name(), route)), ) if metrics != nil { metrics.CircuitBreakerRejects.WithLabelValues(p.Name(), route).Inc()