diff --git a/go/core/error.go b/go/core/error.go index ba70d7f76e..4482fc21fd 100644 --- a/go/core/error.go +++ b/go/core/error.go @@ -37,11 +37,12 @@ type ReflectionError struct { // GenkitError is the base error type for Genkit errors. type GenkitError struct { - Message string `json:"message"` // Exclude from default JSON if embedded elsewhere - Status StatusName `json:"status"` - HTTPCode int `json:"-"` // Exclude from default JSON - Details map[string]any `json:"details"` // Use map for arbitrary details - Source *string `json:"source,omitempty"` // Pointer for optional + Message string `json:"message"` // Exclude from default JSON if embedded elsewhere + Status StatusName `json:"status"` + HTTPCode int `json:"-"` // Exclude from default JSON + Details map[string]any `json:"details"` // Use map for arbitrary details + Source *string `json:"source,omitempty"` // Pointer for optional + originalError error // The wrapped error, if any } // UserFacingError is the base error type for user facing errors. @@ -70,7 +71,6 @@ func (e *UserFacingError) Error() string { // NewError creates a new GenkitError with a stack trace. func NewError(status StatusName, message string, args ...any) *GenkitError { - // Prevents a compile-time warning about non-constant message. msg := message ge := &GenkitError{ @@ -78,6 +78,14 @@ func NewError(status StatusName, message string, args ...any) *GenkitError { Message: fmt.Sprintf(msg, args...), } + // scan args for the last error to wrap it (Iterate backwards) + for i := len(args) - 1; i >= 0; i-- { + if err, ok := args[i].(error); ok { + ge.originalError = err + break + } + } + errStack := string(debug.Stack()) if errStack != "" { ge.Details = make(map[string]any) @@ -91,14 +99,28 @@ func (e *GenkitError) Error() string { return e.Message } +// Unwrap implements the standard error unwrapping interface. +// This allows errors.Is and errors.As to work with GenkitError. +func (e *GenkitError) Unwrap() error { + return e.originalError +} + // ToReflectionError returns a JSON-serializable representation for reflection API responses. func (e *GenkitError) ToReflectionError() ReflectionError { - errDetails := &ReflectionErrorDetails{} - if stackVal, ok := e.Details["stack"].(string); ok { - errDetails.Stack = &stackVal - } - if traceVal, ok := e.Details["traceId"].(string); ok { - errDetails.TraceID = &traceVal + var errDetails *ReflectionErrorDetails + if e.Details != nil { + stackVal, stackOk := e.Details["stack"].(string) + traceVal, traceOk := e.Details["traceId"].(string) + + if stackOk || traceOk { + errDetails = &ReflectionErrorDetails{} + if stackOk { + errDetails.Stack = &stackVal + } + if traceOk { + errDetails.TraceID = &traceVal + } + } } return ReflectionError{ Details: errDetails, diff --git a/go/core/error_test.go b/go/core/error_test.go index 60ff503bdc..a2e26a25a8 100644 --- a/go/core/error_test.go +++ b/go/core/error_test.go @@ -157,8 +157,73 @@ func TestGenkitErrorToReflectionError(t *testing.T) { if re.Message != "success" { t.Errorf("Message = %q, want %q", re.Message, "success") } - if re.Details.Stack != nil { - t.Error("expected nil stack") + if re.Details != nil { + t.Error("expected nil details") + } + }) +} + +// testCustomError is a helper type for the errors.As subtest. +type testCustomError struct { + code int +} + +func (e *testCustomError) Error() string { + return fmt.Sprintf("custom error %d", e.code) +} + +func TestGenkitErrorUnwrap(t *testing.T) { + t.Run("errors.Is matches original cause", func(t *testing.T) { + original := errors.New("original failure") + gErr := NewError(INTERNAL, "something happened: %v", original) + + if !errors.Is(gErr, original) { + t.Errorf("expected errors.Is to return true, but got false") + } + if gErr.Unwrap() != original { + t.Errorf("Unwrap() returned wrong error") + } + }) + + t.Run("errors.As extracts typed cause", func(t *testing.T) { + cause := &testCustomError{code: 42} + ge := NewError(INTERNAL, "failed: %v", cause) + + var target *testCustomError + if !errors.As(ge, &target) { + t.Fatal("errors.As failed to find *testCustomError") + } + if target.code != 42 { + t.Errorf("target.code = %d, want 42", target.code) + } + }) + + t.Run("no args returns nil", func(t *testing.T) { + ge := NewError(INTERNAL, "no args error") + + if ge.Unwrap() != nil { + t.Errorf("Unwrap() = %v, want nil", ge.Unwrap()) + } + }) + + t.Run("multiple errors preserves the last one", func(t *testing.T) { + first := errors.New("first") + second := errors.New("second") + ge := NewError(INTERNAL, "two errors: %v %v", first, second) + + if ge.Unwrap() != second { + t.Errorf("Unwrap() = %v, want %v (last error)", ge.Unwrap(), second) + } + if !errors.Is(ge, second) { + t.Error("errors.Is(ge, second) = false, want true") + } + }) + + t.Run("non-error args returns nil", func(t *testing.T) { + ge := NewError(INTERNAL, "value is %d and %s", 42, "hello") + + if ge.Unwrap() != nil { + t.Errorf("Unwrap() = %v, want nil", ge.Unwrap()) } }) }