Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions internal/gate/decide.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"os"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -98,6 +99,26 @@ func (e Engine) Decide(ctx context.Context, p parsers.ParseResult) Result {
}

func (e Engine) failOpenOrClosed(err error) Result {
// Rate-limit errors carry structured quota info on the hosted backend.
// Render those specially in both fail-open and fail-closed modes — the
// raw "server: rate limited" line is technically correct but useless
// at the terminal.
var rate *server.RateLimitedError
if errors.As(err, &rate) {
if e.Policy.FailClosed {
return Result{
Decision: DecisionBlock,
Message: renderRateLimited(rate, true),
ServerError: err,
}
}
return Result{
Decision: DecisionAllow,
Message: renderRateLimited(rate, false),
ServerError: err,
}
}

if e.Policy.FailClosed {
return Result{
Decision: DecisionBlock,
Expand All @@ -109,6 +130,8 @@ func (e Engine) failOpenOrClosed(err error) Result {
if errors.Is(err, server.ErrServerUnreachable) {
hint = " (set REFUSE_FAIL_CLOSED=1 to require gate)"
} else if errors.Is(err, server.ErrRateLimited) {
// Hit only when the server returned a 429 with no parseable body
// (older self-hosted versions, upstream 429 from a proxy).
hint = " — rate limited, allowing install (upgrade plan to raise the limit)"
} else if errors.Is(err, server.ErrUnauthorized) {
hint = " — set REFUSE_API_KEY or run `refuse init`"
Expand All @@ -120,6 +143,74 @@ func (e Engine) failOpenOrClosed(err error) Result {
}
}

// renderRateLimited builds the multi-line terminal message shown when the
// backend answered 429 with structured quota details.
//
// The first line reports used/limit and, when rate.PeriodEnd is set, the
// reset date plus the whole days remaining (never negative). An "upgrade"
// line follows only if the server supplied a URL. The last line states
// whether the install was blocked or allowed and names REFUSE_FAIL_CLOSED,
// the switch that controls that outcome.
func renderRateLimited(rate *server.RateLimitedError, blocked bool) string {
	var msg strings.Builder

	msg.WriteString("refuse: rate limited — account quota " +
		formatThousands(rate.Used) + "/" + formatThousands(rate.Limit) + " used")

	if end := rate.PeriodEnd; !end.IsZero() {
		// Truncate to whole days; a period end already in the past shows 0.
		left := int(time.Until(end).Hours() / 24)
		if left < 0 {
			left = 0
		}
		msg.WriteString(fmt.Sprintf(" (resets %s, %d %s)",
			end.Format("Jan 2"), left, pluralize(left, "day", "days")))
	}

	if url := rate.UpgradeURL; url != "" {
		msg.WriteString("\n upgrade: " + url)
	}

	outcome := "\n install allowed — set REFUSE_FAIL_CLOSED=1 to block on rate limit"
	if blocked {
		outcome = "\n install blocked (REFUSE_FAIL_CLOSED=1)"
	}
	msg.WriteString(outcome)

	return msg.String()
}

// formatThousands renders n in base 10 with a comma between every group of
// three digits, e.g. 100000 -> "100,000" and -1234 -> "-1,234". It avoids
// pulling in golang.org/x/text just for this.
//
// The sign is split off the formatted string rather than negating n, so
// math.MinInt64 (whose negation overflows back to itself) is handled
// without recursing forever.
func formatThousands(n int64) string {
	digits := strconv.FormatInt(n, 10)
	sign := ""
	if n < 0 {
		sign, digits = "-", digits[1:]
	}
	if len(digits) <= 3 {
		return sign + digits
	}

	// The leading group holds 1-3 digits; every later group holds exactly 3
	// and is preceded by a comma.
	head := len(digits) % 3
	if head == 0 {
		head = 3
	}
	out := make([]byte, 0, len(sign)+len(digits)+(len(digits)-1)/3)
	out = append(out, sign...)
	out = append(out, digits[:head]...)
	for i := head; i < len(digits); i += 3 {
		out = append(out, ',')
		out = append(out, digits[i:i+3]...)
	}
	return string(out)
}

// pluralize picks the noun form that matches a count: singular when n is
// exactly 1, plural for every other value (including 0 and negatives).
func pluralize(n int, singular, plural string) string {
	switch n {
	case 1:
		return singular
	default:
		return plural
	}
}

func renderBlock(resp server.BatchCheckResponse) string {
var b strings.Builder
b.WriteString("refuse: blocked — vulnerable package(s)\n")
Expand Down
91 changes: 91 additions & 0 deletions internal/gate/decide_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package gate
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/RefuseHQ/refuse-cli/internal/config"
"github.com/RefuseHQ/refuse-cli/internal/parsers"
Expand Down Expand Up @@ -122,6 +124,95 @@ func TestGateFailClosedOnServerError(t *testing.T) {
}
}

// quotaServer starts an httptest server that answers every request with
// HTTP 429 and a JSON quota body in the shape mcp.refuse.dev uses. The
// used/limit/plan/upgrade values and periodEnd (formatted as RFC 3339) are
// substituted into the body; callers pass a periodEnd relative to
// time.Now() so the "resets in N days" math reads sensibly against the
// test's wall clock. The caller is responsible for closing the server.
func quotaServer(t *testing.T, used, limit int64, plan string, periodEnd time.Time, upgrade string) *httptest.Server {
	t.Helper()

	// Kept flush-left: whitespace inside a raw string literal is part of
	// the bytes written to the response.
	const quotaJSON = `{
"error":"Quota exceeded",
"quota":{"used":%d,"limit":%d,"plan":%q,"period":"2026-05-28","period_end":%q},
"upgrade":%q
}`

	respond := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusTooManyRequests)
		fmt.Fprintf(w, quotaJSON, used, limit, plan, periodEnd.Format(time.RFC3339), upgrade)
	}
	return httptest.NewServer(http.HandlerFunc(respond))
}

// TestGateRateLimitMessage_FailOpen checks that, with the default
// (fail-open) policy, a 429 carrying a quota body yields DecisionAllow and
// a message that shows the formatted quota, the upgrade URL, and the
// REFUSE_FAIL_CLOSED hint — without using the words "fail-open" or
// "fail-closed".
func TestGateRateLimitMessage_FailOpen(t *testing.T) {
	periodEnd := time.Now().Add(16 * 24 * time.Hour)
	srv := quotaServer(t, 100_000, 100_000, "free", periodEnd, "https://refuse.dev/pricing")
	defer srv.Close()

	eng := Engine{
		Client: server.New(srv.URL, ""),
		Policy: config.Policy{SeverityThreshold: "high"},
	}
	install := parsers.ParseResult{
		IsInstall: true,
		Mode:      parsers.ModeDirect,
		Packages:  []parsers.PkgRef{{Ecosystem: "npm", Name: "x", Version: "1.0.0"}},
	}

	res := eng.Decide(context.Background(), install)
	if res.Decision != DecisionAllow {
		t.Fatalf("expected DecisionAllow, got %v", res.Decision)
	}

	// Every fragment below must appear somewhere in the rendered message.
	required := []string{
		"refuse: rate limited",
		"account quota 100,000/100,000 used",
		"https://refuse.dev/pricing",
		"REFUSE_FAIL_CLOSED=1 to block",
	}
	for _, want := range required {
		if !strings.Contains(res.Message, want) {
			t.Errorf("message missing %q\n--- got ---\n%s", want, res.Message)
		}
	}

	// Internal terminology must not leak to the user; report at most once.
	for _, jargon := range []string{"fail-open", "fail-closed"} {
		if strings.Contains(res.Message, jargon) {
			t.Errorf("avoid jargon in the rendered message\n--- got ---\n%s", res.Message)
			break
		}
	}
}

// TestGateRateLimitMessage_FailClosed checks that, with FailClosed set, a
// 429 carrying a quota body yields DecisionBlock with an "install blocked"
// message, and that no URL appears when the server sent an empty upgrade
// field.
func TestGateRateLimitMessage_FailClosed(t *testing.T) {
	periodEnd := time.Now().Add(2 * 24 * time.Hour)
	srv := quotaServer(t, 5_000, 5_000, "free", periodEnd, "")
	defer srv.Close()

	eng := Engine{
		Client: server.New(srv.URL, ""),
		Policy: config.Policy{SeverityThreshold: "high", FailClosed: true},
	}
	install := parsers.ParseResult{
		IsInstall: true,
		Mode:      parsers.ModeDirect,
		Packages:  []parsers.PkgRef{{Ecosystem: "npm", Name: "x", Version: "1.0.0"}},
	}

	res := eng.Decide(context.Background(), install)
	if res.Decision != DecisionBlock {
		t.Fatalf("expected DecisionBlock under fail-closed, got %v", res.Decision)
	}

	msg := res.Message
	if !strings.Contains(msg, "install blocked") {
		t.Errorf("expected 'install blocked' in message:\n%s", msg)
	}
	// The server offered no upgrade URL, so the message must not invent one.
	if strings.Contains(msg, "https://") {
		t.Errorf("message synthesized an upgrade URL the server didn't return:\n%s", msg)
	}
}

// TestGateRateLimit_FallbackWhenNoBody covers a 429 that arrives with an
// empty body (for example from an upstream proxy). With the default policy
// the gate must still allow the install and mention "rate limited" rather
// than crash.
func TestGateRateLimit_FallbackWhenNoBody(t *testing.T) {
	bare429 := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
	}
	srv := httptest.NewServer(http.HandlerFunc(bare429))
	defer srv.Close()

	eng := Engine{
		Client: server.New(srv.URL, ""),
		Policy: config.Policy{SeverityThreshold: "high"},
	}
	install := parsers.ParseResult{
		IsInstall: true,
		Mode:      parsers.ModeDirect,
		Packages:  []parsers.PkgRef{{Ecosystem: "npm", Name: "x", Version: "1.0.0"}},
	}

	res := eng.Decide(context.Background(), install)
	if res.Decision != DecisionAllow {
		t.Fatalf("expected DecisionAllow on bodyless 429, got %v", res.Decision)
	}
	if !strings.Contains(res.Message, "rate limited") {
		t.Errorf("expected 'rate limited' in fallback message:\n%s", res.Message)
	}
}

func TestGateAllowsLockfileMode(t *testing.T) {
srv := fakeServer(t, server.BatchCheckResponse{}, 200)
defer srv.Close()
Expand Down
76 changes: 75 additions & 1 deletion internal/server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,39 @@ func New(baseURL, apiKey string) *Client {
// ErrUnauthorized is returned on a 401 from the server (missing or invalid
// key). It is a sentinel value: match it with errors.Is rather than by
// comparing message text.
var ErrUnauthorized = errors.New("server: unauthorized")

// ErrRateLimited is returned on a 429. When the server provided a parseable
// quota body, a *RateLimitedError is returned instead — it unwraps to this
// sentinel so existing `errors.Is(err, ErrRateLimited)` checks keep working
// while richer callers can `errors.As` to extract used/limit/reset details.
var ErrRateLimited = errors.New("server: rate limited")

// RateLimitedError is the structured form of a 429: it carries the
// per-account quota details that mcp.refuse.dev sends in the response
// body, and unwraps to ErrRateLimited. The body it is built from looks
// like:
//
//	{ "error": "Quota exceeded",
//	  "quota": { "used": …, "limit": …, "period": "YYYY-MM-DD",
//	             "period_end": "ISO-8601", "plan": "free" },
//	  "upgrade": "https://refuse.dev/pricing" }
//
// An empty or malformed body produces the plain ErrRateLimited sentinel
// instead, so callers can always tell they hit a rate limit.
type RateLimitedError struct {
	// Used and Limit are the quota counters reported by the server.
	Used  int64
	Limit int64

	// Plan is the plan name from the quota body (e.g. "free").
	Plan string

	// Period is the cycle start, formatted YYYY-MM-DD.
	Period string

	// PeriodEnd is zero if the server didn't send a usable date.
	PeriodEnd time.Time

	// UpgradeURL is empty when none was offered (e.g. Pro users).
	UpgradeURL string
}

// Error implements error. The text deliberately leaves out the "server: "
// prefix so it reads naturally after a caller's own prefix ("refuse: %v").
func (e *RateLimitedError) Error() string {
	return "rate limited"
}

// Unwrap exposes ErrRateLimited so `errors.Is(err, ErrRateLimited)` matches
// through this wrapper.
func (e *RateLimitedError) Unwrap() error {
	return ErrRateLimited
}

// ErrServerUnreachable is returned on connection / timeout errors. It is a
// sentinel value: match it with errors.Is rather than by comparing message
// text.
var ErrServerUnreachable = errors.New("server: unreachable")

Expand Down Expand Up @@ -115,8 +145,52 @@ func (c *Client) post(ctx context.Context, path string, in any, out any) error {
case http.StatusUnauthorized:
return ErrUnauthorized
case http.StatusTooManyRequests:
body, _ := io.ReadAll(resp.Body)
if re := parseRateLimitBody(body); re != nil {
return re
}
return ErrRateLimited
}
b, _ := io.ReadAll(resp.Body)
return fmt.Errorf("server: %d %s: %s", resp.StatusCode, resp.Status, string(b))
}

// parseRateLimitBody decodes the quota JSON that mcp.refuse.dev attaches to
// a 429 response.
//
// It returns nil — signalling the caller to fall back to ErrRateLimited —
// when the body is empty, is not valid JSON, or has a zero/missing
// quota.limit (which suggests a generic 429 from some other middleware
// such as Cloudflare or an upstream proxy). A period_end that is absent or
// not RFC 3339 is tolerated: the result is still returned, with PeriodEnd
// left at its zero value.
func parseRateLimitBody(body []byte) *RateLimitedError {
	type quotaFields struct {
		Used      int64  `json:"used"`
		Limit     int64  `json:"limit"`
		Plan      string `json:"plan"`
		Period    string `json:"period"`
		PeriodEnd string `json:"period_end"`
	}
	type envelope struct {
		Quota   quotaFields `json:"quota"`
		Upgrade string      `json:"upgrade"`
	}

	if len(body) == 0 {
		return nil
	}
	var env envelope
	if json.Unmarshal(body, &env) != nil {
		return nil
	}
	q := env.Quota
	if q.Limit == 0 {
		// Not a quota response; let the caller use the plain sentinel.
		return nil
	}

	// Best effort: an unparseable date just leaves the reset time unknown.
	var periodEnd time.Time
	if q.PeriodEnd != "" {
		if parsed, err := time.Parse(time.RFC3339, q.PeriodEnd); err == nil {
			periodEnd = parsed
		}
	}

	return &RateLimitedError{
		Used:       q.Used,
		Limit:      q.Limit,
		Plan:       q.Plan,
		Period:     q.Period,
		PeriodEnd:  periodEnd,
		UpgradeURL: env.Upgrade,
	}
}
Loading
Loading