diff --git a/internal/gate/decide.go b/internal/gate/decide.go index c73ca2d..4357a14 100644 --- a/internal/gate/decide.go +++ b/internal/gate/decide.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "os" + "strconv" "strings" "time" @@ -98,6 +99,26 @@ func (e Engine) Decide(ctx context.Context, p parsers.ParseResult) Result { } func (e Engine) failOpenOrClosed(err error) Result { + // Rate-limit errors carry structured quota info on the hosted backend. + // Render those specially in both fail-open and fail-closed modes — the + // raw "server: rate limited" line is technically correct but useless + // at the terminal. + var rate *server.RateLimitedError + if errors.As(err, &rate) { + if e.Policy.FailClosed { + return Result{ + Decision: DecisionBlock, + Message: renderRateLimited(rate, true), + ServerError: err, + } + } + return Result{ + Decision: DecisionAllow, + Message: renderRateLimited(rate, false), + ServerError: err, + } + } + if e.Policy.FailClosed { return Result{ Decision: DecisionBlock, @@ -109,6 +130,8 @@ func (e Engine) failOpenOrClosed(err error) Result { if errors.Is(err, server.ErrServerUnreachable) { hint = " (set REFUSE_FAIL_CLOSED=1 to require gate)" } else if errors.Is(err, server.ErrRateLimited) { + // Hit only when the server returned a 429 with no parseable body + // (older self-hosted versions, upstream 429 from a proxy). hint = " — rate limited, allowing install (upgrade plan to raise the limit)" } else if errors.Is(err, server.ErrUnauthorized) { hint = " — set REFUSE_API_KEY or run `refuse init`" @@ -120,6 +143,74 @@ func (e Engine) failOpenOrClosed(err error) Result { } } +// renderRateLimited formats a rich, multi-line message for a 429 from the +// hosted backend. Includes used/limit, reset window, optional upgrade URL, +// and the action knob (REFUSE_FAIL_CLOSED) so the user can change behavior +// without reading docs. 
+func renderRateLimited(rate *server.RateLimitedError, blocked bool) string { + var b strings.Builder + b.WriteString("refuse: rate limited — account quota ") + b.WriteString(formatThousands(rate.Used)) + b.WriteString("/") + b.WriteString(formatThousands(rate.Limit)) + b.WriteString(" used") + if !rate.PeriodEnd.IsZero() { + days := int(time.Until(rate.PeriodEnd).Hours() / 24) + if days < 0 { + days = 0 + } + fmt.Fprintf(&b, " (resets %s, %d %s)", + rate.PeriodEnd.Format("Jan 2"), + days, + pluralize(days, "day", "days")) + } + if rate.UpgradeURL != "" { + b.WriteString("\n upgrade: ") + b.WriteString(rate.UpgradeURL) + } + if blocked { + b.WriteString("\n install blocked (REFUSE_FAIL_CLOSED=1)") + } else { + b.WriteString("\n install allowed — set REFUSE_FAIL_CLOSED=1 to block on rate limit") + } + return b.String() +} + +// formatThousands turns 100000 into "100,000". Avoids pulling in +// golang.org/x/text just for this. +func formatThousands(n int64) string { + if n < 0 { + return "-" + formatThousands(-n) + } + s := strconv.FormatInt(n, 10) + if len(s) <= 3 { + return s + } + // Walk back-to-front inserting commas every 3 digits. + out := make([]byte, 0, len(s)+(len(s)-1)/3) + first := len(s) % 3 + if first > 0 { + out = append(out, s[:first]...) + if len(s) > first { + out = append(out, ',') + } + } + for i := first; i < len(s); i += 3 { + out = append(out, s[i:i+3]...) 
+ if i+3 < len(s) { + out = append(out, ',') + } + } + return string(out) +} + +func pluralize(n int, singular, plural string) string { + if n == 1 { + return singular + } + return plural +} + func renderBlock(resp server.BatchCheckResponse) string { var b strings.Builder b.WriteString("refuse: blocked — vulnerable package(s)\n") diff --git a/internal/gate/decide_test.go b/internal/gate/decide_test.go index 5824411..6329b55 100644 --- a/internal/gate/decide_test.go +++ b/internal/gate/decide_test.go @@ -3,10 +3,12 @@ package gate import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" "testing" + "time" "github.com/RefuseHQ/refuse-cli/internal/config" "github.com/RefuseHQ/refuse-cli/internal/parsers" @@ -122,6 +124,95 @@ func TestGateFailClosedOnServerError(t *testing.T) { } } +// quotaServer returns a server whose POSTs all 429 with a quota body shaped +// like mcp.refuse.dev's. periodEnd is rendered fresh on each call so the +// "resets in N days" math always reads sensibly relative to the test's +// wall clock. 
+func quotaServer(t *testing.T, used, limit int64, plan string, periodEnd time.Time, upgrade string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprintf(w, `{ + "error":"Quota exceeded", + "quota":{"used":%d,"limit":%d,"plan":%q,"period":"2026-05-28","period_end":%q}, + "upgrade":%q + }`, used, limit, plan, periodEnd.Format(time.RFC3339), upgrade) + })) +} + +func TestGateRateLimitMessage_FailOpen(t *testing.T) { + resetAt := time.Now().Add(16 * 24 * time.Hour) + srv := quotaServer(t, 100_000, 100_000, "free", resetAt, "https://refuse.dev/pricing") + defer srv.Close() + + eng := Engine{Client: server.New(srv.URL, ""), Policy: config.Policy{SeverityThreshold: "high"}} + res := eng.Decide(context.Background(), parsers.ParseResult{ + IsInstall: true, Mode: parsers.ModeDirect, + Packages: []parsers.PkgRef{{Ecosystem: "npm", Name: "x", Version: "1.0.0"}}, + }) + if res.Decision != DecisionAllow { + t.Fatalf("expected DecisionAllow, got %v", res.Decision) + } + must := func(s string) { + t.Helper() + if !strings.Contains(res.Message, s) { + t.Errorf("message missing %q\n--- got ---\n%s", s, res.Message) + } + } + must("refuse: rate limited") + must("account quota 100,000/100,000 used") + must("https://refuse.dev/pricing") + must("REFUSE_FAIL_CLOSED=1 to block") + if strings.Contains(res.Message, "fail-open") || strings.Contains(res.Message, "fail-closed") { + t.Errorf("avoid jargon in the rendered message\n--- got ---\n%s", res.Message) + } +} + +func TestGateRateLimitMessage_FailClosed(t *testing.T) { + resetAt := time.Now().Add(2 * 24 * time.Hour) + srv := quotaServer(t, 5_000, 5_000, "free", resetAt, "") + defer srv.Close() + + eng := Engine{Client: server.New(srv.URL, ""), Policy: config.Policy{SeverityThreshold: "high", FailClosed: true}} + res := eng.Decide(context.Background(), 
parsers.ParseResult{ + IsInstall: true, Mode: parsers.ModeDirect, + Packages: []parsers.PkgRef{{Ecosystem: "npm", Name: "x", Version: "1.0.0"}}, + }) + if res.Decision != DecisionBlock { + t.Fatalf("expected DecisionBlock under fail-closed, got %v", res.Decision) + } + if !strings.Contains(res.Message, "install blocked") { + t.Errorf("expected 'install blocked' in message:\n%s", res.Message) + } + // No upgrade URL was supplied — message shouldn't make one up. + if strings.Contains(res.Message, "https://") { + t.Errorf("message synthesized an upgrade URL the server didn't return:\n%s", res.Message) + } +} + +func TestGateRateLimit_FallbackWhenNoBody(t *testing.T) { + // Some 429s come from upstream proxies with no quota body — we should + // still fall back to the plain "rate limited" message and fail-open + // without crashing. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + })) + defer srv.Close() + + eng := Engine{Client: server.New(srv.URL, ""), Policy: config.Policy{SeverityThreshold: "high"}} + res := eng.Decide(context.Background(), parsers.ParseResult{ + IsInstall: true, Mode: parsers.ModeDirect, + Packages: []parsers.PkgRef{{Ecosystem: "npm", Name: "x", Version: "1.0.0"}}, + }) + if res.Decision != DecisionAllow { + t.Fatalf("expected DecisionAllow on bodyless 429, got %v", res.Decision) + } + if !strings.Contains(res.Message, "rate limited") { + t.Errorf("expected 'rate limited' in fallback message:\n%s", res.Message) + } +} + func TestGateAllowsLockfileMode(t *testing.T) { srv := fakeServer(t, server.BatchCheckResponse{}, 200) defer srv.Close() diff --git a/internal/server/client.go b/internal/server/client.go index a79ce2d..d92b6e3 100644 --- a/internal/server/client.go +++ b/internal/server/client.go @@ -48,9 +48,39 @@ func New(baseURL, apiKey string) *Client { // ErrUnauthorized is returned on a 401 from the server (missing or invalid key). 
// ErrUnauthorized is returned on a 401 from the server (missing or invalid key).
var ErrUnauthorized = errors.New("server: unauthorized")

// ErrRateLimited is returned on a 429. When the response carried a parseable
// quota body, a *RateLimitedError is returned instead; it unwraps to this
// sentinel, so errors.Is(err, ErrRateLimited) holds in both cases while
// callers that want the quota details can use errors.As.
var ErrRateLimited = errors.New("server: rate limited")

// RateLimitedError wraps ErrRateLimited with the per-account quota
// information sent alongside a 429 response. The body looks like:
//
//	{ "error": "Quota exceeded",
//	  "quota": { "used": …, "limit": …, "period": "YYYY-MM-DD",
//	             "period_end": "ISO-8601", "plan": "free" },
//	  "upgrade": "https://refuse.dev/pricing" }
//
// When the body is empty, malformed, or has no positive limit, the client
// reports the plain ErrRateLimited instead.
type RateLimitedError struct {
	Used       int64
	Limit      int64
	Plan       string
	Period     string    // YYYY-MM-DD cycle start
	PeriodEnd  time.Time // zero if the server didn't send a usable date
	UpgradeURL string    // empty when none was offered
}

// Compile-time check that *RateLimitedError satisfies error.
var _ error = (*RateLimitedError)(nil)

// Error implements error. The message intentionally omits "server: " so it
// renders naturally when callers prepend their own prefix ("refuse: %v").
func (e *RateLimitedError) Error() string { return "rate limited" }

// Unwrap lets errors.Is(err, ErrRateLimited) see through this wrapper.
func (e *RateLimitedError) Unwrap() error { return ErrRateLimited }

// ErrServerUnreachable is returned on connection / timeout errors.
var ErrServerUnreachable = errors.New("server: unreachable")

// parseRateLimitBody extracts the structured quota information from the body
// of a 429 response.
//
// It returns nil when the body is empty, is not valid JSON, or carries no
// positive quota limit; callers should then fall back to ErrRateLimited. A
// period_end that is absent or not RFC 3339 leaves PeriodEnd as the zero
// time rather than rejecting the whole body.
func parseRateLimitBody(body []byte) *RateLimitedError {
	if len(body) == 0 {
		return nil
	}
	var payload struct {
		Quota struct {
			Used      int64  `json:"used"`
			Limit     int64  `json:"limit"`
			Plan      string `json:"plan"`
			Period    string `json:"period"`
			PeriodEnd string `json:"period_end"`
		} `json:"quota"`
		Upgrade string `json:"upgrade"`
	}
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil
	}
	// A response without a positive limit isn't a usable quota body: it may
	// be a generic 429 from other middleware (a CDN, an upstream proxy), and
	// a zero or negative limit cannot be rendered as a meaningful "used/limit".
	if payload.Quota.Limit <= 0 {
		return nil
	}
	out := &RateLimitedError{
		Used:       payload.Quota.Used,
		Limit:      payload.Quota.Limit,
		Plan:       payload.Quota.Plan,
		Period:     payload.Quota.Period,
		UpgradeURL: payload.Upgrade,
	}
	if payload.Quota.PeriodEnd != "" {
		// time.Parse accepts fractional seconds in the input even though the
		// RFC3339 layout does not spell them out.
		if end, err := time.Parse(time.RFC3339, payload.Quota.PeriodEnd); err == nil {
			out.PeriodEnd = end
		}
	}
	return out
}

// TestParseRateLimitBody_FullPayload checks that every field of a complete
// quota body is carried over into the RateLimitedError.
func TestParseRateLimitBody_FullPayload(t *testing.T) {
	body := []byte(`{
		"error": "Quota exceeded",
		"quota": {
			"used": 100000,
			"limit": 100000,
			"plan": "free",
			"period": "2026-05-28",
			"period_end": "2026-06-27T08:11:17.540Z"
		},
		"upgrade": "https://refuse.dev/pricing"
	}`)

	re := parseRateLimitBody(body)
	if re == nil {
		t.Fatal("expected RateLimitedError, got nil")
	}
	if got, want := re.Used, int64(100000); got != want {
		t.Errorf("Used = %d, want %d", got, want)
	}
	if got, want := re.Limit, int64(100000); got != want {
		t.Errorf("Limit = %d, want %d", got, want)
	}
	if got, want := re.Plan, "free"; got != want {
		t.Errorf("Plan = %q, want %q", got, want)
	}
	if got, want := re.Period, "2026-05-28"; got != want {
		t.Errorf("Period = %q, want %q", got, want)
	}
	// Fail loudly if the fixture itself does not parse; otherwise a zero
	// wantEnd would compare equal to a zero PeriodEnd and hide a regression.
	wantEnd, err := time.Parse(time.RFC3339, "2026-06-27T08:11:17.540Z")
	if err != nil {
		t.Fatalf("parsing fixture timestamp: %v", err)
	}
	if re.PeriodEnd.IsZero() || !re.PeriodEnd.Equal(wantEnd) {
		t.Errorf("PeriodEnd = %s, want %s", re.PeriodEnd, wantEnd)
	}
	if got, want := re.UpgradeURL, "https://refuse.dev/pricing"; got != want {
		t.Errorf("UpgradeURL = %q, want %q", got, want)
	}
}

// TestParseRateLimitBody_UnwrapsToErrRateLimited checks that the structured
// error still matches the ErrRateLimited sentinel.
func TestParseRateLimitBody_UnwrapsToErrRateLimited(t *testing.T) {
	body := []byte(`{"quota":{"used":1,"limit":2}}`)
	re := parseRateLimitBody(body)
	if re == nil {
		t.Fatal("expected RateLimitedError")
	}
	if !errors.Is(re, ErrRateLimited) {
		t.Error("errors.Is(re, ErrRateLimited) should be true")
	}
}

// TestParseRateLimitBody_RejectsNonQuotaBody checks that bodies which are not
// usable quota responses yield nil.
func TestParseRateLimitBody_RejectsNonQuotaBody(t *testing.T) {
	cases := map[string][]byte{
		"empty":          nil,
		"invalid json":   []byte("not json at all"),
		"no quota":       []byte(`{"error":"rate limit exceeded by upstream proxy"}`),
		"zero limit":     []byte(`{"quota":{"used":5,"limit":0}}`),
		"negative limit": []byte(`{"quota":{"used":5,"limit":-1}}`),
	}
	for name, body := range cases {
		t.Run(name, func(t *testing.T) {
			if got := parseRateLimitBody(body); got != nil {
				t.Errorf("expected nil for %q body, got %+v", name, got)
			}
		})
	}
}

// TestRateLimitedError_ErrorString pins the short error text.
func TestRateLimitedError_ErrorString(t *testing.T) {
	re := &RateLimitedError{Used: 1, Limit: 2}
	// Intentionally short — callers prepend "refuse: " themselves.
	if got, want := re.Error(), "rate limited"; got != want {
		t.Errorf("Error() = %q, want %q", got, want)
	}
}