Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions pkg/api/v1/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ import (
// Desktop clients (Studio) match on this value to display the correct UI.
const RegistryAuthRequiredCode = "registry_auth_required"

// registryAuthErrorResponse is the JSON body for HTTP 503 auth-required errors.
// Studio uses the "code" field to detect this specific condition and prompt the user.
type registryAuthErrorResponse struct {
// registryErrorResponse is the JSON body for structured HTTP 503 error responses.
// The "code" field allows clients (e.g. Studio) to distinguish between
// "registry_auth_required" and "registry_unavailable" conditions.
type registryErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}
Expand All @@ -38,7 +39,7 @@ type registryAuthErrorResponse struct {
// but thv serve itself lacks a valid registry credential. This is a server-side dependency
// issue, not a client auth failure (which would be 401).
func writeRegistryAuthRequiredError(w http.ResponseWriter) {
body := registryAuthErrorResponse{
body := registryErrorResponse{
Code: RegistryAuthRequiredCode,
Message: "Registry authentication required. Run 'thv registry login' to authenticate.",
}
Expand All @@ -47,6 +48,22 @@ func writeRegistryAuthRequiredError(w http.ResponseWriter) {
_ = json.NewEncoder(w).Encode(body)
}

// RegistryUnavailableCode is the machine-readable error code returned in the
// structured JSON 503 response when the upstream registry is unreachable.
const RegistryUnavailableCode = "registry_unavailable"

// writeRegistryUnavailableError writes a structured JSON 503 response when the
// upstream registry cannot be reached or returns an unexpected error (e.g. 404).
func writeRegistryUnavailableError(w http.ResponseWriter, unavailableErr *regpkg.UnavailableError) {
body := registryErrorResponse{
Code: RegistryUnavailableCode,
Message: unavailableErr.Error(),
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusServiceUnavailable)
_ = json.NewEncoder(w).Encode(body)
}

// resolveAuthStatus returns the auth_status and auth_type strings for API responses
// by delegating to the AuthManager.
func (rr *RegistryRoutes) resolveAuthStatus() (authStatus, authType string) {
Expand Down Expand Up @@ -244,6 +261,12 @@ func (rr *RegistryRoutes) getCurrentProvider(w http.ResponseWriter) (regpkg.Prov
writeRegistryAuthRequiredError(w)
return nil, false
}
var unavailableErr *regpkg.UnavailableError
if errors.As(err, &unavailableErr) {
slog.Error("upstream registry unavailable", "error", err)
writeRegistryUnavailableError(w, unavailableErr)
return nil, false
}
http.Error(w, "Failed to get registry provider", http.StatusInternalServerError)
slog.Error("failed to get registry provider", "error", err)
return nil, false
Expand Down Expand Up @@ -342,6 +365,12 @@ func (rr *RegistryRoutes) listRegistries(w http.ResponseWriter, _ *http.Request)
writeRegistryAuthRequiredError(w)
return
}
var unavailableErr *regpkg.UnavailableError
if errors.As(err, &unavailableErr) {
slog.Error("upstream registry unavailable", "error", err)
writeRegistryUnavailableError(w, unavailableErr)
return
}
http.Error(w, "Failed to get registry", http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -417,6 +446,12 @@ func (rr *RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) {
writeRegistryAuthRequiredError(w)
return
}
var unavailableErr *regpkg.UnavailableError
if errors.As(err, &unavailableErr) {
slog.Error("upstream registry unavailable", "error", err)
writeRegistryUnavailableError(w, unavailableErr)
return
}
http.Error(w, "Failed to get registry", http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -660,6 +695,12 @@ func (rr *RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) {
writeRegistryAuthRequiredError(w)
return
}
var unavailableErr *regpkg.UnavailableError
if errors.As(err, &unavailableErr) {
slog.Error("upstream registry unavailable", "error", err)
writeRegistryUnavailableError(w, unavailableErr)
return
}
slog.Error("failed to get registry", "error", err)
http.Error(w, "Failed to get registry", http.StatusInternalServerError)
return
Expand Down
90 changes: 90 additions & 0 deletions pkg/api/v1/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,96 @@ func CreateTestConfigProvider(t *testing.T, cfg *config.Config) (config.Provider
}
}

// TestRegistryAPI_GetEndpoint_UnavailableUpstream tests that GET endpoints return
// 503 with a structured JSON response when the upstream registry API is unreachable
// or returns an unexpected error (e.g. 404 because the URL path is wrong).
//
//nolint:paralleltest // Uses global registry provider singleton
func TestRegistryAPI_GetEndpoint_UnavailableUpstream(t *testing.T) {
// Mock server that returns 404 (simulates a wrong registry API URL)
notFoundServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "404 page not found", http.StatusNotFound)
}))
defer notFoundServer.Close()

// Configure registry to point at the mock 404 server
cfg := &config.Config{
RegistryApiUrl: notFoundServer.URL,
AllowPrivateRegistryIp: true,
}
configProvider, cleanup := CreateTestConfigProvider(t, cfg)
defer cleanup()

registry.ResetDefaultProvider()
t.Cleanup(registry.ResetDefaultProvider)

routes := &RegistryRoutes{
configProvider: configProvider,
configService: registry.NewConfiguratorWithProvider(configProvider),
serveMode: true,
}

endpoints := []struct {
name string
method string
path string
handler http.HandlerFunc
urlParams map[string]string
}{
{
name: "listRegistries",
method: http.MethodGet,
path: "/",
handler: routes.listRegistries,
},
{
name: "getRegistry",
method: http.MethodGet,
path: "/default",
handler: routes.getRegistry,
urlParams: map[string]string{"name": "default"},
},
{
name: "listServers",
method: http.MethodGet,
path: "/default/servers",
handler: routes.listServers,
urlParams: map[string]string{"name": "default"},
},
}

for _, ep := range endpoints {
t.Run(ep.name, func(t *testing.T) {
registry.ResetDefaultProvider()

req := httptest.NewRequest(ep.method, ep.path, nil)
if ep.urlParams != nil {
rctx := chi.NewRouteContext()
for k, v := range ep.urlParams {
rctx.URLParams.Add(k, v)
}
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
}

w := httptest.NewRecorder()
ep.handler(w, req)

assert.Equal(t, http.StatusServiceUnavailable, w.Code,
"Expected 503 Service Unavailable for unreachable upstream registry")

var body registryErrorResponse
err := json.NewDecoder(w.Body).Decode(&body)
require.NoError(t, err, "Response should be valid JSON")
assert.Equal(t, RegistryUnavailableCode, body.Code,
"Response code should be registry_unavailable")
assert.Contains(t, body.Message, "unavailable",
"Response message should indicate unavailability")
assert.Contains(t, w.Header().Get("Content-Type"), "application/json",
"Response Content-Type should be application/json")
})
}
}

func TestRegistryRouter(t *testing.T) {
t.Parallel()

Expand Down
25 changes: 25 additions & 0 deletions pkg/registry/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package registry

import "fmt"

// UnavailableError indicates the upstream registry is unreachable
// or returned an unexpected (non-auth) error such as 404, timeout, or
// connection refused. API handlers translate this into HTTP 503.
type UnavailableError struct {
URL string
Err error
}

func (e *UnavailableError) Error() string {
if e.URL != "" {
return fmt.Sprintf("upstream registry at %s is unavailable: %s", e.URL, e.Err)
}
return fmt.Sprintf("upstream registry is unavailable: %s", e.Err)
}

func (e *UnavailableError) Unwrap() error {
return e.Err
}
68 changes: 68 additions & 0 deletions pkg/registry/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package registry

import (
"errors"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestUnavailableError_Error(t *testing.T) {
t.Parallel()

tests := []struct {
name string
err *UnavailableError
expected string
}{
{
name: "with URL",
err: &UnavailableError{
URL: "https://example.com/registry",
Err: fmt.Errorf("connection refused"),
},
expected: "upstream registry at https://example.com/registry is unavailable: connection refused",
},
{
name: "without URL",
err: &UnavailableError{
Err: fmt.Errorf("timeout"),
},
expected: "upstream registry is unavailable: timeout",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.expected, tt.err.Error())
})
}
}

func TestUnavailableError_Unwrap(t *testing.T) {
t.Parallel()

inner := fmt.Errorf("registry API returned status 404")
err := &UnavailableError{URL: "https://example.com", Err: inner}

assert.Equal(t, inner, errors.Unwrap(err))
}

func TestUnavailableError_ErrorsAs(t *testing.T) {
t.Parallel()

inner := fmt.Errorf("registry API returned status 404")
original := &UnavailableError{URL: "https://example.com", Err: inner}
wrapped := fmt.Errorf("failed to create provider: %w", original)

var target *UnavailableError
require.True(t, errors.As(wrapped, &target))
assert.Equal(t, "https://example.com", target.URL)
assert.Equal(t, inner, target.Err)
}
4 changes: 2 additions & 2 deletions pkg/registry/provider_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func NewAPIRegistryProvider(apiURL string, allowPrivateIp bool, tokenSource auth
apiURL, auth.ErrRegistryAuthRequired,
)
}
return nil, fmt.Errorf("API endpoint not functional: %w", err)
return nil, &UnavailableError{URL: apiURL, Err: err}
}
}

Expand All @@ -96,7 +96,7 @@ func (p *APIRegistryProvider) GetRegistry() (*types.Registry, error) {
if errors.Is(err, api.ErrRegistryUnauthorized) {
return nil, fmt.Errorf("registry rejected credentials: %w", auth.ErrRegistryAuthRequired)
}
return nil, fmt.Errorf("failed to list servers from API: %w", err)
return nil, &UnavailableError{URL: p.apiURL, Err: err}
}

// Convert servers to ToolHive format
Expand Down
Loading