From 04ef425f35a80bc2ffb9588ff1a54c3a9788b42d Mon Sep 17 00:00:00 2001 From: Jake Bailey <5341706+jakebailey@users.noreply.github.com> Date: Fri, 14 Nov 2025 16:36:56 -0800 Subject: [PATCH] Make client requests type safe, unmarshal --- internal/fourslash/fourslash.go | 6 +- internal/lsp/lsproto/_generate/generate.mts | 30 ++++ internal/lsp/lsproto/jsonrpc.go | 36 ++--- internal/lsp/lsproto/lsp.go | 40 ++++++ internal/lsp/lsproto/lsp_generated.go | 145 ++++++++++++++++++++ internal/lsp/server.go | 21 ++- 6 files changed, 237 insertions(+), 41 deletions(-) diff --git a/internal/fourslash/fourslash.go b/internal/fourslash/fourslash.go index b60bca8ae9..1b94b9e2e3 100644 --- a/internal/fourslash/fourslash.go +++ b/internal/fourslash/fourslash.go @@ -350,8 +350,7 @@ func getCapabilitiesWithDefaults(capabilities *lsproto.ClientCapabilities) *lspr func sendRequest[Params, Resp any](t *testing.T, f *FourslashTest, info lsproto.RequestInfo[Params, Resp], params Params) (*lsproto.Message, Resp, bool) { id := f.nextID() - req := lsproto.NewRequestMessage( - info.Method, + req := info.NewRequestMessage( lsproto.NewID(lsproto.IntegerOrString{Integer: &id}), params, ) @@ -393,8 +392,7 @@ func sendRequest[Params, Resp any](t *testing.T, f *FourslashTest, info lsproto. } func sendNotification[Params any](t *testing.T, f *FourslashTest, info lsproto.NotificationInfo[Params], params Params) { - notification := lsproto.NewNotificationMessage( - info.Method, + notification := info.NewNotificationMessage( params, ) f.writeMsg(t, notification.Message()) diff --git a/internal/lsp/lsproto/_generate/generate.mts b/internal/lsp/lsproto/_generate/generate.mts index 68bfb5778e..ed981a8c5b 100644 --- a/internal/lsp/lsproto/_generate/generate.mts +++ b/internal/lsp/lsproto/_generate/generate.mts @@ -784,6 +784,36 @@ function generateCode() { writeLine("}"); writeLine(""); + // Generate unmarshalResult function + writeLine("func unmarshalResult(method Method, data []byte) (any, error) {"); + writeLine("\tswitch method {"); + + // Only requests have results, not notifications + for (const request of model.requests) { + const methodName = methodNameIdentifier(request.method); + + if (!("result" in request)) { + continue; + } + + let responseTypeName: string; + if (request.typeName && request.typeName.endsWith("Request")) { + responseTypeName = request.typeName.replace(/Request$/, "Response"); + } + else { + responseTypeName = `${methodName}Response`; + } + + writeLine(`\tcase Method${methodName}:`); + writeLine(`\t\treturn unmarshalValue[${responseTypeName}](data)`); + } + + writeLine("\tdefault:"); + writeLine(`\t\treturn unmarshalAny(data)`); + writeLine("\t}"); + writeLine("}"); + writeLine(""); + writeLine("// Methods"); writeLine("const ("); for (const request of requestsAndNotifications) { diff --git a/internal/lsp/lsproto/jsonrpc.go b/internal/lsp/lsproto/jsonrpc.go index d78894e824..19d75ac8ad 100644 --- a/internal/lsp/lsproto/jsonrpc.go +++ b/internal/lsp/lsproto/jsonrpc.go @@ -105,8 +105,10 @@ func (m *Message) UnmarshalJSON(data []byte) error { Method Method `json:"method"` ID *ID `json:"id,omitzero"` Params jsontext.Value `json:"params"` - Result any `json:"result,omitzero"` - Error *ResponseError `json:"error,omitzero"` + // We don't have a method in the response, so we have no idea what to decode. + // Store the raw text and let the caller decode it. + Result jsontext.Value `json:"result,omitzero"` + Error *ResponseError `json:"error,omitzero"` } if err := json.Unmarshal(data, &raw); err != nil { return fmt.Errorf("%w: %w", ErrInvalidRequest, err) @@ -114,10 +116,9 @@ func (m *Message) UnmarshalJSON(data []byte) error { if raw.ID != nil && raw.Method == "" { m.Kind = MessageKindResponse m.msg = &ResponseMessage{ - JSONRPC: raw.JSONRPC, - ID: raw.ID, - Result: raw.Result, - Error: raw.Error, + ID: raw.ID, + Result: raw.Result, + Error: raw.Error, } return nil } @@ -138,10 +139,9 @@ func (m *Message) UnmarshalJSON(data []byte) error { } m.msg = &RequestMessage{ - JSONRPC: raw.JSONRPC, - ID: raw.ID, - Method: raw.Method, - Params: params, + ID: raw.ID, + Method: raw.Method, + Params: params, } return nil @@ -151,14 +151,6 @@ func (m *Message) MarshalJSON() ([]byte, error) { return json.Marshal(m.msg) } -func NewNotificationMessage(method Method, params any) *RequestMessage { - return &RequestMessage{ - JSONRPC: JSONRPCVersion{}, - Method: method, - Params: params, - } -} - type RequestMessage struct { JSONRPC JSONRPCVersion `json:"jsonrpc"` ID *ID `json:"id,omitzero"` @@ -166,14 +158,6 @@ type RequestMessage struct { Params any `json:"params,omitzero"` } -func NewRequestMessage(method Method, id *ID, params any) *RequestMessage { - return &RequestMessage{ - ID: id, - Method: method, - Params: params, - } -} - func (r *RequestMessage) Message() *Message { return &Message{ Kind: MessageKindRequest, diff --git a/internal/lsp/lsproto/lsp.go b/internal/lsp/lsproto/lsp.go index cc27437593..e251b444f1 100644 --- a/internal/lsp/lsproto/lsp.go +++ b/internal/lsp/lsproto/lsp.go @@ -71,6 +71,14 @@ func unmarshalPtrTo[T any](data []byte) (*T, error) { return &v, nil } +func unmarshalValue[T any](data []byte) (T, error) { + var v T + if err := json.Unmarshal(data, &v); err != nil { + return *new(T), fmt.Errorf("failed to unmarshal %T: %w", (*T)(nil), err) + } + return v, nil +} + func unmarshalAny(data []byte) (any, error) { var v any if err := json.Unmarshal(data, &v); err != nil { @@ -129,11 +137,43 @@ type RequestInfo[Params, Resp any] struct { Method Method } +func (info RequestInfo[Params, Resp]) UnmarshalResult(result any) (Resp, error) { + if r, ok := result.(Resp); ok { + return r, nil + } + + raw, ok := result.(jsontext.Value) + if !ok { + return *new(Resp), fmt.Errorf("expected jsontext.Value, got %T", result) + } + + r, err := unmarshalResult(info.Method, raw) + if err != nil { + return *new(Resp), err + } + return r.(Resp), nil +} + +func (info RequestInfo[Params, Resp]) NewRequestMessage(id *ID, params Params) *RequestMessage { + return &RequestMessage{ + ID: id, + Method: info.Method, + Params: params, + } +} + type NotificationInfo[Params any] struct { _ [0]Params Method Method } +func (info NotificationInfo[Params]) NewNotificationMessage(params Params) *RequestMessage { + return &RequestMessage{ + Method: info.Method, + Params: params, + } +} + type Null struct{} func (Null) UnmarshalJSONFrom(dec *jsontext.Decoder) error { diff --git a/internal/lsp/lsproto/lsp_generated.go b/internal/lsp/lsproto/lsp_generated.go index 075d574c00..7c71cb1a2c 100644 --- a/internal/lsp/lsproto/lsp_generated.go +++ b/internal/lsp/lsproto/lsp_generated.go @@ -20709,6 +20709,151 @@ func unmarshalParams(method Method, data []byte) (any, error) { } } +func unmarshalResult(method Method, data []byte) (any, error) { + switch method { + case MethodTextDocumentImplementation: + return unmarshalValue[ImplementationResponse](data) + case MethodTextDocumentTypeDefinition: + return unmarshalValue[TypeDefinitionResponse](data) + case MethodWorkspaceWorkspaceFolders: + return unmarshalValue[WorkspaceFoldersResponse](data) + case MethodWorkspaceConfiguration: + return unmarshalValue[ConfigurationResponse](data) + case MethodTextDocumentDocumentColor: + return unmarshalValue[DocumentColorResponse](data) + case MethodTextDocumentColorPresentation: + return unmarshalValue[ColorPresentationResponse](data) + case MethodTextDocumentFoldingRange: + return unmarshalValue[FoldingRangeResponse](data) + case MethodWorkspaceFoldingRangeRefresh: + return unmarshalValue[FoldingRangeRefreshResponse](data) + case MethodTextDocumentDeclaration: + return unmarshalValue[DeclarationResponse](data) + case MethodTextDocumentSelectionRange: + return unmarshalValue[SelectionRangeResponse](data) + case MethodWindowWorkDoneProgressCreate: + return unmarshalValue[WorkDoneProgressCreateResponse](data) + case MethodTextDocumentPrepareCallHierarchy: + return unmarshalValue[CallHierarchyPrepareResponse](data) + case MethodCallHierarchyIncomingCalls: + return unmarshalValue[CallHierarchyIncomingCallsResponse](data) + case MethodCallHierarchyOutgoingCalls: + return unmarshalValue[CallHierarchyOutgoingCallsResponse](data) + case MethodTextDocumentSemanticTokensFull: + return unmarshalValue[SemanticTokensResponse](data) + case MethodTextDocumentSemanticTokensFullDelta: + return unmarshalValue[SemanticTokensDeltaResponse](data) + case MethodTextDocumentSemanticTokensRange: + return unmarshalValue[SemanticTokensRangeResponse](data) + case MethodWorkspaceSemanticTokensRefresh: + return unmarshalValue[SemanticTokensRefreshResponse](data) + case MethodWindowShowDocument: + return unmarshalValue[ShowDocumentResponse](data) + case MethodTextDocumentLinkedEditingRange: + return unmarshalValue[LinkedEditingRangeResponse](data) + case MethodWorkspaceWillCreateFiles: + return unmarshalValue[WillCreateFilesResponse](data) + case MethodWorkspaceWillRenameFiles: + return unmarshalValue[WillRenameFilesResponse](data) + case MethodWorkspaceWillDeleteFiles: + return unmarshalValue[WillDeleteFilesResponse](data) + case MethodTextDocumentMoniker: + return unmarshalValue[MonikerResponse](data) + case MethodTextDocumentPrepareTypeHierarchy: + return unmarshalValue[TypeHierarchyPrepareResponse](data) + case MethodTypeHierarchySupertypes: + return unmarshalValue[TypeHierarchySupertypesResponse](data) + case MethodTypeHierarchySubtypes: + return unmarshalValue[TypeHierarchySubtypesResponse](data) + case MethodTextDocumentInlineValue: + return unmarshalValue[InlineValueResponse](data) + case MethodWorkspaceInlineValueRefresh: + return unmarshalValue[InlineValueRefreshResponse](data) + case MethodTextDocumentInlayHint: + return unmarshalValue[InlayHintResponse](data) + case MethodInlayHintResolve: + return unmarshalValue[InlayHintResolveResponse](data) + case MethodWorkspaceInlayHintRefresh: + return unmarshalValue[InlayHintRefreshResponse](data) + case MethodTextDocumentDiagnostic: + return unmarshalValue[DocumentDiagnosticResponse](data) + case MethodWorkspaceDiagnostic: + return unmarshalValue[WorkspaceDiagnosticResponse](data) + case MethodWorkspaceDiagnosticRefresh: + return unmarshalValue[DiagnosticRefreshResponse](data) + case MethodTextDocumentInlineCompletion: + return unmarshalValue[InlineCompletionResponse](data) + case MethodWorkspaceTextDocumentContent: + return unmarshalValue[TextDocumentContentResponse](data) + case MethodWorkspaceTextDocumentContentRefresh: + return unmarshalValue[TextDocumentContentRefreshResponse](data) + case MethodClientRegisterCapability: + return unmarshalValue[RegistrationResponse](data) + case MethodClientUnregisterCapability: + return unmarshalValue[UnregistrationResponse](data) + case MethodInitialize: + return unmarshalValue[InitializeResponse](data) + case MethodShutdown: + return unmarshalValue[ShutdownResponse](data) + case MethodWindowShowMessageRequest: + return unmarshalValue[ShowMessageResponse](data) + case MethodTextDocumentWillSaveWaitUntil: + return unmarshalValue[WillSaveTextDocumentWaitUntilResponse](data) + case MethodTextDocumentCompletion: + return unmarshalValue[CompletionResponse](data) + case MethodCompletionItemResolve: + return unmarshalValue[CompletionResolveResponse](data) + case MethodTextDocumentHover: + return unmarshalValue[HoverResponse](data) + case MethodTextDocumentSignatureHelp: + return unmarshalValue[SignatureHelpResponse](data) + case MethodTextDocumentDefinition: + return unmarshalValue[DefinitionResponse](data) + case MethodTextDocumentReferences: + return unmarshalValue[ReferencesResponse](data) + case MethodTextDocumentDocumentHighlight: + return unmarshalValue[DocumentHighlightResponse](data) + case MethodTextDocumentDocumentSymbol: + return unmarshalValue[DocumentSymbolResponse](data) + case MethodTextDocumentCodeAction: + return unmarshalValue[CodeActionResponse](data) + case MethodCodeActionResolve: + return unmarshalValue[CodeActionResolveResponse](data) + case MethodWorkspaceSymbol: + return unmarshalValue[WorkspaceSymbolResponse](data) + case MethodWorkspaceSymbolResolve: + return unmarshalValue[WorkspaceSymbolResolveResponse](data) + case MethodTextDocumentCodeLens: + return unmarshalValue[CodeLensResponse](data) + case MethodCodeLensResolve: + return unmarshalValue[CodeLensResolveResponse](data) + case MethodWorkspaceCodeLensRefresh: + return unmarshalValue[CodeLensRefreshResponse](data) + case MethodTextDocumentDocumentLink: + return unmarshalValue[DocumentLinkResponse](data) + case MethodDocumentLinkResolve: + return unmarshalValue[DocumentLinkResolveResponse](data) + case MethodTextDocumentFormatting: + return unmarshalValue[DocumentFormattingResponse](data) + case MethodTextDocumentRangeFormatting: + return unmarshalValue[DocumentRangeFormattingResponse](data) + case MethodTextDocumentRangesFormatting: + return unmarshalValue[DocumentRangesFormattingResponse](data) + case MethodTextDocumentOnTypeFormatting: + return unmarshalValue[DocumentOnTypeFormattingResponse](data) + case MethodTextDocumentRename: + return unmarshalValue[RenameResponse](data) + case MethodTextDocumentPrepareRename: + return unmarshalValue[PrepareRenameResponse](data) + case MethodWorkspaceExecuteCommand: + return unmarshalValue[ExecuteCommandResponse](data) + case MethodWorkspaceApplyEdit: + return unmarshalValue[ApplyWorkspaceEditResponse](data) + default: + return unmarshalAny(data) + } +} + // Methods const ( // A request to resolve the implementation locations of a symbol at a given text diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 85e5b42447..de880048c1 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -165,7 +165,7 @@ type Server struct { // WatchFiles implements project.Client. func (s *Server) WatchFiles(ctx context.Context, id project.WatcherID, watchers []*lsproto.FileSystemWatcher) error { - _, err := s.sendRequest(ctx, lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ + _, err := sendClientRequest(ctx, s, lsproto.ClientRegisterCapabilityInfo, &lsproto.RegistrationParams{ Registrations: []*lsproto.Registration{ { Id: string(id), @@ -187,7 +187,7 @@ func (s *Server) WatchFiles(ctx context.Context, id project.WatcherID, watchers // UnwatchFiles implements project.Client. func (s *Server) UnwatchFiles(ctx context.Context, id project.WatcherID) error { if s.watchers.Has(id) { - _, err := s.sendRequest(ctx, lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{ + _, err := sendClientRequest(ctx, s, lsproto.ClientUnregisterCapabilityInfo, &lsproto.UnregistrationParams{ Unregisterations: []*lsproto.Unregistration{ { Id: string(id), @@ -212,7 +212,7 @@ func (s *Server) RefreshDiagnostics(ctx context.Context) error { return nil } - if _, err := s.sendRequest(ctx, lsproto.MethodWorkspaceDiagnosticRefresh, nil); err != nil { + if _, err := sendClientRequest(ctx, s, lsproto.WorkspaceDiagnosticRefreshInfo, nil); err != nil { return fmt.Errorf("failed to refresh diagnostics: %w", err) } @@ -225,7 +225,7 @@ func (s *Server) RequestConfiguration(ctx context.Context) (*lsutil.UserPreferen // if no configuration request capapbility, return default preferences return s.session.NewUserPreferences(), nil } - result, err := s.sendRequest(ctx, lsproto.MethodWorkspaceConfiguration, &lsproto.ConfigurationParams{ + configs, err := sendClientRequest(ctx, s, lsproto.WorkspaceConfigurationInfo, &lsproto.ConfigurationParams{ Items: []*lsproto.ConfigurationItem{ { Section: ptrTo("typescript"), @@ -235,7 +235,6 @@ func (s *Server) RequestConfiguration(ctx context.Context) (*lsutil.UserPreferen if err != nil { return nil, fmt.Errorf("configure request failed: %w", err) } - configs := result.([]any) s.Log(fmt.Sprintf("\n\nconfiguration: %+v, %T\n\n", configs, configs)) userPreferences := s.session.NewUserPreferences() for _, item := range configs { @@ -391,9 +390,9 @@ func (s *Server) writeLoop(ctx context.Context) error { } } -func (s *Server) sendRequest(ctx context.Context, method lsproto.Method, params any) (any, error) { +func sendClientRequest[Req, Resp any](ctx context.Context, s *Server, info lsproto.RequestInfo[Req, Resp], params Req) (Resp, error) { id := lsproto.NewIDString(fmt.Sprintf("ts%d", s.clientSeq.Add(1))) - req := lsproto.NewRequestMessage(method, id, params) + req := info.NewRequestMessage(id, params) responseChan := make(chan *lsproto.ResponseMessage, 1) s.pendingServerRequestsMu.Lock() @@ -410,12 +409,12 @@ func (s *Server) sendRequest(ctx context.Context, method lsproto.Method, params close(respChan) delete(s.pendingServerRequests, *id) } - return nil, ctx.Err() + return *new(Resp), ctx.Err() case resp := <-responseChan: if resp.Error != nil { - return nil, fmt.Errorf("request failed: %s", resp.Error.String()) + return *new(Resp), fmt.Errorf("request failed: %s", resp.Error.String()) } - return resp.Result, nil + return info.UnmarshalResult(resp.Result) } } @@ -747,7 +746,7 @@ func (s *Server) handleInitialized(ctx context.Context, params *lsproto.Initiali } s.session.InitializeWithConfig(userPreferences) - _, err = s.sendRequest(ctx, lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ + _, err = sendClientRequest(ctx, s, lsproto.ClientRegisterCapabilityInfo, &lsproto.RegistrationParams{ Registrations: []*lsproto.Registration{ { Id: "typescript-config-watch-id",