diff --git a/pkg/api/v1/registry.go b/pkg/api/v1/registry.go index c96d969aff..0911fc1e39 100644 --- a/pkg/api/v1/registry.go +++ b/pkg/api/v1/registry.go @@ -26,9 +26,10 @@ import ( // Desktop clients (Studio) match on this value to display the correct UI. const RegistryAuthRequiredCode = "registry_auth_required" -// registryAuthErrorResponse is the JSON body for HTTP 503 auth-required errors. -// Studio uses the "code" field to detect this specific condition and prompt the user. -type registryAuthErrorResponse struct { +// registryErrorResponse is the JSON body for structured HTTP 503 error responses. +// The "code" field allows clients (e.g. Studio) to distinguish between +// "registry_auth_required" and "registry_unavailable" conditions. +type registryErrorResponse struct { Code string `json:"code"` Message string `json:"message"` } @@ -38,7 +39,7 @@ type registryAuthErrorResponse struct { // but thv serve itself lacks a valid registry credential. This is a server-side dependency // issue, not a client auth failure (which would be 401). func writeRegistryAuthRequiredError(w http.ResponseWriter) { - body := registryAuthErrorResponse{ + body := registryErrorResponse{ Code: RegistryAuthRequiredCode, Message: "Registry authentication required. Run 'thv registry login' to authenticate.", } @@ -47,6 +48,22 @@ func writeRegistryAuthRequiredError(w http.ResponseWriter) { _ = json.NewEncoder(w).Encode(body) } +// RegistryUnavailableCode is the machine-readable error code returned in the +// structured JSON 503 response when the upstream registry is unreachable. +const RegistryUnavailableCode = "registry_unavailable" + +// writeRegistryUnavailableError writes a structured JSON 503 response when the +// upstream registry cannot be reached or returns an unexpected error (e.g. 404). +func writeRegistryUnavailableError(w http.ResponseWriter, unavailableErr *regpkg.UnavailableError) { + body := registryErrorResponse{ + Code: RegistryUnavailableCode, + Message: unavailableErr.Error(), + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + _ = json.NewEncoder(w).Encode(body) +} + // resolveAuthStatus returns the auth_status and auth_type strings for API responses // by delegating to the AuthManager. func (rr *RegistryRoutes) resolveAuthStatus() (authStatus, authType string) { @@ -244,6 +261,12 @@ func (rr *RegistryRoutes) getCurrentProvider(w http.ResponseWriter) (regpkg.Prov writeRegistryAuthRequiredError(w) return nil, false } + var unavailableErr *regpkg.UnavailableError + if errors.As(err, &unavailableErr) { + slog.Error("upstream registry unavailable", "error", err) + writeRegistryUnavailableError(w, unavailableErr) + return nil, false + } http.Error(w, "Failed to get registry provider", http.StatusInternalServerError) slog.Error("failed to get registry provider", "error", err) return nil, false @@ -342,6 +365,12 @@ func (rr *RegistryRoutes) listRegistries(w http.ResponseWriter, _ *http.Request) writeRegistryAuthRequiredError(w) return } + var unavailableErr *regpkg.UnavailableError + if errors.As(err, &unavailableErr) { + slog.Error("upstream registry unavailable", "error", err) + writeRegistryUnavailableError(w, unavailableErr) + return + } http.Error(w, "Failed to get registry", http.StatusInternalServerError) return } @@ -417,6 +446,12 @@ func (rr *RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { writeRegistryAuthRequiredError(w) return } + var unavailableErr *regpkg.UnavailableError + if errors.As(err, &unavailableErr) { + slog.Error("upstream registry unavailable", "error", err) + writeRegistryUnavailableError(w, unavailableErr) + return + } http.Error(w, "Failed to get registry", http.StatusInternalServerError) return } @@ -660,6 +695,12 @@ func (rr *RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { writeRegistryAuthRequiredError(w) return } + var unavailableErr *regpkg.UnavailableError + if errors.As(err, &unavailableErr) { + slog.Error("upstream registry unavailable", "error", err) + writeRegistryUnavailableError(w, unavailableErr) + return + } slog.Error("failed to get registry", "error", err) http.Error(w, "Failed to get registry", http.StatusInternalServerError) return diff --git a/pkg/api/v1/registry_test.go b/pkg/api/v1/registry_test.go index 9461bfc83b..c28c8a226c 100644 --- a/pkg/api/v1/registry_test.go +++ b/pkg/api/v1/registry_test.go @@ -49,6 +49,96 @@ func CreateTestConfigProvider(t *testing.T, cfg *config.Config) (config.Provider } } +// TestRegistryAPI_GetEndpoint_UnavailableUpstream tests that GET endpoints return +// 503 with a structured JSON response when the upstream registry API is unreachable +// or returns an unexpected error (e.g. 404 because the URL path is wrong). +// +//nolint:paralleltest // Uses global registry provider singleton +func TestRegistryAPI_GetEndpoint_UnavailableUpstream(t *testing.T) { + // Mock server that returns 404 (simulates a wrong registry API URL) + notFoundServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "404 page not found", http.StatusNotFound) + })) + defer notFoundServer.Close() + + // Configure registry to point at the mock 404 server + cfg := &config.Config{ + RegistryApiUrl: notFoundServer.URL, + AllowPrivateRegistryIp: true, + } + configProvider, cleanup := CreateTestConfigProvider(t, cfg) + defer cleanup() + + registry.ResetDefaultProvider() + t.Cleanup(registry.ResetDefaultProvider) + + routes := &RegistryRoutes{ + configProvider: configProvider, + configService: registry.NewConfiguratorWithProvider(configProvider), + serveMode: true, + } + + endpoints := []struct { + name string + method string + path string + handler http.HandlerFunc + urlParams map[string]string + }{ + { + name: "listRegistries", + method: http.MethodGet, + path: "/", + handler: routes.listRegistries, + }, + { + name: "getRegistry", + method: http.MethodGet, + path: "/default", + handler: routes.getRegistry, + urlParams: map[string]string{"name": "default"}, + }, + { + name: "listServers", + method: http.MethodGet, + path: "/default/servers", + handler: routes.listServers, + urlParams: map[string]string{"name": "default"}, + }, + } + + for _, ep := range endpoints { + t.Run(ep.name, func(t *testing.T) { + registry.ResetDefaultProvider() + + req := httptest.NewRequest(ep.method, ep.path, nil) + if ep.urlParams != nil { + rctx := chi.NewRouteContext() + for k, v := range ep.urlParams { + rctx.URLParams.Add(k, v) + } + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + } + + w := httptest.NewRecorder() + ep.handler(w, req) + + assert.Equal(t, http.StatusServiceUnavailable, w.Code, + "Expected 503 Service Unavailable for unreachable upstream registry") + + var body registryErrorResponse + err := json.NewDecoder(w.Body).Decode(&body) + require.NoError(t, err, "Response should be valid JSON") + assert.Equal(t, RegistryUnavailableCode, body.Code, + "Response code should be registry_unavailable") + assert.Contains(t, body.Message, "unavailable", + "Response message should indicate unavailability") + assert.Contains(t, w.Header().Get("Content-Type"), "application/json", + "Response Content-Type should be application/json") + }) + } +} + func TestRegistryRouter(t *testing.T) { t.Parallel() diff --git a/pkg/registry/errors.go b/pkg/registry/errors.go new file mode 100644 index 0000000000..256b27f2d4 --- /dev/null +++ b/pkg/registry/errors.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import "fmt" + +// UnavailableError indicates the upstream registry is unreachable +// or returned an unexpected (non-auth) error such as 404, timeout, or +// connection refused. API handlers translate this into HTTP 503. +type UnavailableError struct { + URL string + Err error +} + +func (e *UnavailableError) Error() string { + if e.URL != "" { + return fmt.Sprintf("upstream registry at %s is unavailable: %s", e.URL, e.Err) + } + return fmt.Sprintf("upstream registry is unavailable: %s", e.Err) +} + +func (e *UnavailableError) Unwrap() error { + return e.Err +} diff --git a/pkg/registry/errors_test.go b/pkg/registry/errors_test.go new file mode 100644 index 0000000000..71e4c77794 --- /dev/null +++ b/pkg/registry/errors_test.go @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnavailableError_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err *UnavailableError + expected string + }{ + { + name: "with URL", + err: &UnavailableError{ + URL: "https://example.com/registry", + Err: fmt.Errorf("connection refused"), + }, + expected: "upstream registry at https://example.com/registry is unavailable: connection refused", + }, + { + name: "without URL", + err: &UnavailableError{ + Err: fmt.Errorf("timeout"), + }, + expected: "upstream registry is unavailable: timeout", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.expected, tt.err.Error()) + }) + } +} + +func TestUnavailableError_Unwrap(t *testing.T) { + t.Parallel() + + inner := fmt.Errorf("registry API returned status 404") + err := &UnavailableError{URL: "https://example.com", Err: inner} + + assert.Equal(t, inner, errors.Unwrap(err)) +} + +func TestUnavailableError_ErrorsAs(t *testing.T) { + t.Parallel() + + inner := fmt.Errorf("registry API returned status 404") + original := &UnavailableError{URL: "https://example.com", Err: inner} + wrapped := fmt.Errorf("failed to create provider: %w", original) + + var target *UnavailableError + require.True(t, errors.As(wrapped, &target)) + assert.Equal(t, "https://example.com", target.URL) + assert.Equal(t, inner, target.Err) +} diff --git a/pkg/registry/provider_api.go b/pkg/registry/provider_api.go index 2b063c0570..c8bbb9f167 100644 --- a/pkg/registry/provider_api.go +++ b/pkg/registry/provider_api.go @@ -69,7 +69,7 @@ func NewAPIRegistryProvider(apiURL string, allowPrivateIp bool, tokenSource auth apiURL, auth.ErrRegistryAuthRequired, ) } - return nil, fmt.Errorf("API endpoint not functional: %w", err) + return nil, &UnavailableError{URL: apiURL, Err: err} } } @@ -96,7 +96,7 @@ func (p *APIRegistryProvider) GetRegistry() (*types.Registry, error) { if errors.Is(err, api.ErrRegistryUnauthorized) { return nil, fmt.Errorf("registry rejected credentials: %w", auth.ErrRegistryAuthRequired) } - return nil, fmt.Errorf("failed to list servers from API: %w", err) + return nil, &UnavailableError{URL: p.apiURL, Err: err} } // Convert servers to ToolHive format