diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 8f2ae5852..0420ff75a 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -90,6 +90,7 @@ var ( ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), + LogLevel: viper.GetString("log-level"), ContentWindowSize: viper.GetInt("content-window-size"), LockdownMode: viper.GetBool("lockdown-mode"), InsidersMode: viper.GetBool("insiders"), @@ -137,6 +138,7 @@ var ( ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), + LogLevel: viper.GetString("log-level"), ContentWindowSize: viper.GetInt("content-window-size"), LockdownMode: viper.GetBool("lockdown-mode"), RepoAccessCacheTTL: &ttl, @@ -168,6 +170,7 @@ func init() { rootCmd.PersistentFlags().Bool("dynamic-toolsets", false, "Enable dynamic toolsets") rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations") rootCmd.PersistentFlags().String("log-file", "", "Path to log file") + rootCmd.PersistentFlags().String("log-level", "", "Log level (debug, info, warn, error). 
Defaults to debug when --log-file is set, info otherwise.") rootCmd.PersistentFlags().Bool("enable-command-logging", false, "When enabled, the server will log all command requests and responses to the log file") rootCmd.PersistentFlags().Bool("export-translations", false, "Save translations to a JSON file") rootCmd.PersistentFlags().String("gh-host", "", "Specify the GitHub hostname (for GitHub Enterprise etc.)") @@ -190,6 +193,7 @@ func init() { _ = viper.BindPFlag("dynamic_toolsets", rootCmd.PersistentFlags().Lookup("dynamic-toolsets")) _ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only")) _ = viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file")) + _ = viper.BindPFlag("log-level", rootCmd.PersistentFlags().Lookup("log-level")) _ = viper.BindPFlag("enable-command-logging", rootCmd.PersistentFlags().Lookup("enable-command-logging")) _ = viper.BindPFlag("export-translations", rootCmd.PersistentFlags().Lookup("export-translations")) _ = viper.BindPFlag("host", rootCmd.PersistentFlags().Lookup("gh-host")) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 3f81ac3f7..873ef5a52 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -214,6 +214,11 @@ type StdioServerConfig struct { // Path to the log file if not stderr LogFilePath string + // LogLevel overrides the default log level ("debug", "info", "warn", "error"). + // When empty, the default depends on LogFilePath: file-backed logs default + // to Debug (local troubleshooting), stderr logs default to Info. 
+ LogLevel string + // Content window size ContentWindowSize int @@ -248,11 +253,14 @@ func RunStdioServer(cfg StdioServerConfig) error { return fmt.Errorf("failed to open log file: %w", err) } logOutput = file - slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelDebug}) } else { logOutput = os.Stderr - slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelInfo}) } + level, err := observability.ParseLogLevel(cfg.LogLevel, defaultLogLevel(cfg.LogFilePath)) + if err != nil { + return err + } + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: level}) logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) @@ -335,6 +343,16 @@ func RunStdioServer(cfg StdioServerConfig) error { return nil } +// defaultLogLevel returns the log level used when the user hasn't specified +// one. Debug when a log file is configured (intended for local debugging), +// Info when logs go to stderr (keeps terminal output quiet). +func defaultLogLevel(logFilePath string) slog.Level { + if logFilePath != "" { + return slog.LevelDebug + } + return slog.LevelInfo +} + // createFeatureChecker returns a FeatureFlagChecker that resolves features // using the centralized ResolveFeatureFlags function. For the local server, // features are resolved once at startup from --features CLI flag + insiders mode. 
diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 57c6133a8..fbb9980f6 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -6,7 +6,6 @@ import ( "fmt" "log/slog" "net/http" - "os" ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/http/transport" @@ -187,7 +186,15 @@ func (d BaseDeps) GetFlags(_ context.Context) FeatureFlags { return d.Flags } func (d BaseDeps) GetContentWindowSize() int { return d.ContentWindowSize } // Logger implements ToolDependencies. -func (d BaseDeps) Logger(_ context.Context) *slog.Logger { +// If an enriched logger has been attached to ctx by ToolLoggingMiddleware +// (via observability.ContextWithLogger), that logger is returned so tool +// handlers inherit request-scoped attributes such as tool name and +// mcp.method. Otherwise the base logger from the observability exporters +// is returned. +func (d BaseDeps) Logger(ctx context.Context) *slog.Logger { + if l := observability.LoggerFromContext(ctx); l != nil { + return l + } return d.Obsv.Logger() } @@ -206,8 +213,12 @@ func (d BaseDeps) IsFeatureEnabled(ctx context.Context, flagName string) bool { enabled, err := d.featureChecker(ctx, flagName) if err != nil { - // Log error but don't fail the tool - treat as disabled - fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) + // Treat errors as disabled, but surface them via the logger so + // operators can diagnose feature-flag backend issues in production. + d.Logger(ctx).Warn("feature flag check failed", + slog.String("flag", flagName), + slog.String("error", err.Error()), + ) return false } @@ -406,7 +417,11 @@ func (d *RequestDeps) GetFlags(ctx context.Context) FeatureFlags { func (d *RequestDeps) GetContentWindowSize() int { return d.ContentWindowSize } // Logger implements ToolDependencies. 
-func (d *RequestDeps) Logger(_ context.Context) *slog.Logger { +// See BaseDeps.Logger for the context-scoped logger fallback behaviour. +func (d *RequestDeps) Logger(ctx context.Context) *slog.Logger { + if l := observability.LoggerFromContext(ctx); l != nil { + return l + } return d.obsv.Logger() } @@ -423,8 +438,10 @@ func (d *RequestDeps) IsFeatureEnabled(ctx context.Context, flagName string) boo enabled, err := d.featureChecker(ctx, flagName) if err != nil { - // Log error but don't fail the tool - treat as disabled - fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) + d.Logger(ctx).Warn("feature flag check failed", + slog.String("flag", flagName), + slog.String("error", err.Error()), + ) return false } diff --git a/pkg/github/logging_middleware.go b/pkg/github/logging_middleware.go new file mode 100644 index 000000000..f33f1bd35 --- /dev/null +++ b/pkg/github/logging_middleware.go @@ -0,0 +1,116 @@ +package github + +import ( + "context" + "log/slog" + "time" + + "github.com/github/github-mcp-server/pkg/observability" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// MCPMethodCallTool is the JSON-RPC method name for MCP tool invocations. +// The SDK keeps its equivalent constant unexported, so we mirror it here. +const MCPMethodCallTool = "tools/call" + +// ToolLoggingMiddleware returns an MCP middleware that uniformly logs every +// tool invocation with its name, duration, and outcome, and exposes a +// request-scoped *slog.Logger via observability.ContextWithLogger so tool +// handlers can retrieve an enriched logger from deps.Logger(ctx). +// +// Logging policy: +// - tools/call success: logged at Debug with tool name and duration. +// - tools/call failure (error return or IsError result): logged at Error +// with tool name, duration, and the error when present. 
+//   - Non-tool methods pass through without emitting any log line; the
+//     middleware only attaches an enriched logger so downstream code can
+//     still benefit from the method tag if it chooses to log.
+//
+// The base logger comes from ToolDependencies on the context (populated by
+// InjectDepsMiddleware), so InjectDepsMiddleware must run before (wrap) this
+// middleware. NOTE(review): confirm registration order gives that nesting.
+func ToolLoggingMiddleware() mcp.Middleware {
+	return func(next mcp.MethodHandler) mcp.MethodHandler {
+		return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) {
+			deps, ok := DepsFromContext(ctx)
+			if !ok {
+				// Deps not injected yet; nothing we can do but pass through.
+				return next(ctx, method, req)
+			}
+
+			base := deps.Logger(ctx)
+			if base == nil {
+				return next(ctx, method, req)
+			}
+
+			logger := base.With(slog.String("mcp.method", method))
+			toolName := toolNameFromRequest(method, req)
+			if toolName != "" {
+				logger = logger.With(slog.String("mcp.tool", toolName))
+			}
+
+			ctx = observability.ContextWithLogger(ctx, logger)
+
+			// Only time+log for tool calls. Other methods (initialize,
+			// resources/list, etc.) are infrastructure chatter we leave
+			// to the SDK unless a handler chooses to log explicitly.
+			if method != MCPMethodCallTool {
+				return next(ctx, method, req)
+			}
+
+			start := time.Now()
+			result, err := next(ctx, method, req)
+			duration := time.Since(start)
+
+			switch {
+			case err != nil:
+				logger.LogAttrs(ctx, slog.LevelError, "tool call failed",
+					slog.Duration("duration", duration),
+					slog.String("error", err.Error()),
+				)
+			case isErrorResult(result):
+				logger.LogAttrs(ctx, slog.LevelError, "tool call returned error result",
+					slog.Duration("duration", duration),
+				)
+			default:
+				logger.LogAttrs(ctx, slog.LevelDebug, "tool call succeeded",
+					slog.Duration("duration", duration),
+				)
+			}
+
+			return result, err
+		}
+	}
+}
+
+// toolNameFromRequest extracts the tool name from a tools/call request.
+// Returns "" for other methods or when the name cannot be determined. +func toolNameFromRequest(method string, req mcp.Request) string { + if method != MCPMethodCallTool || req == nil { + return "" + } + switch p := req.GetParams().(type) { + case *mcp.CallToolParams: + if p != nil { + return p.Name + } + case *mcp.CallToolParamsRaw: + if p != nil { + return p.Name + } + } + return "" +} + +// isErrorResult reports whether the MCP result represents a tool-reported +// error (CallToolResult.IsError == true). A returned Go error is handled +// separately by the caller. +func isErrorResult(r mcp.Result) bool { + if r == nil { + return false + } + if ctr, ok := r.(*mcp.CallToolResult); ok && ctr != nil { + return ctr.IsError + } + return false +} diff --git a/pkg/github/logging_middleware_test.go b/pkg/github/logging_middleware_test.go new file mode 100644 index 000000000..e024fc3b7 --- /dev/null +++ b/pkg/github/logging_middleware_test.go @@ -0,0 +1,145 @@ +package github + +import ( + "bytes" + "context" + "errors" + "log/slog" + "strings" + "testing" + + "github.com/github/github-mcp-server/pkg/observability" + "github.com/github/github-mcp-server/pkg/observability/metrics" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeToolDeps implements just enough of ToolDependencies to drive the +// logging middleware. Unused methods panic so we notice if callers grow a +// dependency on them. 
+type fakeToolDeps struct { + ToolDependencies + logger *slog.Logger +} + +func (f fakeToolDeps) Logger(_ context.Context) *slog.Logger { return f.logger } + +func newTestLogger(level slog.Level) (*slog.Logger, *bytes.Buffer) { + buf := &bytes.Buffer{} + h := slog.NewTextHandler(buf, &slog.HandlerOptions{Level: level}) + return slog.New(h), buf +} + +func callToolRequest(name string) mcp.Request { + return &mcp.CallToolRequest{Params: &mcp.CallToolParamsRaw{Name: name}} +} + +func TestToolLoggingMiddleware_LogsToolSuccessAtDebug(t *testing.T) { + logger, buf := newTestLogger(slog.LevelDebug) + deps := fakeToolDeps{logger: logger} + + ctx := ContextWithDeps(context.Background(), deps) + + handler := ToolLoggingMiddleware()(func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + // Tool handlers should see the enriched logger via the context. + assert.NotNil(t, observability.LoggerFromContext(ctx)) + return &mcp.CallToolResult{}, nil + }) + + _, err := handler(ctx, MCPMethodCallTool, callToolRequest("create_issue")) + require.NoError(t, err) + + out := buf.String() + assert.Contains(t, out, "level=DEBUG") + assert.Contains(t, out, `msg="tool call succeeded"`) + assert.Contains(t, out, "mcp.method=tools/call") + assert.Contains(t, out, "mcp.tool=create_issue") + assert.Contains(t, out, "duration=") +} + +func TestToolLoggingMiddleware_LogsToolErrorAtError(t *testing.T) { + logger, buf := newTestLogger(slog.LevelDebug) + deps := fakeToolDeps{logger: logger} + ctx := ContextWithDeps(context.Background(), deps) + + wantErr := errors.New("boom") + handler := ToolLoggingMiddleware()(func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return nil, wantErr + }) + + _, err := handler(ctx, MCPMethodCallTool, callToolRequest("create_issue")) + require.ErrorIs(t, err, wantErr) + + out := buf.String() + assert.Contains(t, out, "level=ERROR") + assert.Contains(t, out, `msg="tool call failed"`) + assert.Contains(t, out, 
"mcp.tool=create_issue") + assert.Contains(t, out, "error=boom") +} + +func TestToolLoggingMiddleware_LogsIsErrorResult(t *testing.T) { + logger, buf := newTestLogger(slog.LevelDebug) + deps := fakeToolDeps{logger: logger} + ctx := ContextWithDeps(context.Background(), deps) + + handler := ToolLoggingMiddleware()(func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return &mcp.CallToolResult{IsError: true}, nil + }) + + _, err := handler(ctx, MCPMethodCallTool, callToolRequest("get_repo")) + require.NoError(t, err) + + out := buf.String() + assert.Contains(t, out, "level=ERROR") + assert.Contains(t, out, `msg="tool call returned error result"`) +} + +func TestToolLoggingMiddleware_NonToolMethodSilent(t *testing.T) { + logger, buf := newTestLogger(slog.LevelDebug) + deps := fakeToolDeps{logger: logger} + ctx := ContextWithDeps(context.Background(), deps) + + var sawLogger *slog.Logger + handler := ToolLoggingMiddleware()(func(innerCtx context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + sawLogger = observability.LoggerFromContext(innerCtx) + return nil, nil + }) + + _, err := handler(ctx, "tools/list", &mcp.ListToolsRequest{Params: &mcp.ListToolsParams{}}) + require.NoError(t, err) + + assert.NotNil(t, sawLogger, "non-tool methods should still get the enriched logger") + // No success/failure log lines for non-tool methods. + assert.False(t, strings.Contains(buf.String(), "tool call"), + "middleware should not log tool outcomes for non-tool methods; got: %s", buf.String()) +} + +func TestToolLoggingMiddleware_MissingDepsPassesThrough(t *testing.T) { + called := false + handler := ToolLoggingMiddleware()(func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + called = true + return nil, nil + }) + + // No deps injected — middleware must not panic and must still call next. 
+	_, err := handler(context.Background(), MCPMethodCallTool, callToolRequest("x"))
+	require.NoError(t, err)
+	assert.True(t, called)
+}
+
+// Exercise the Logger(ctx) fallback in BaseDeps: when the context carries
+// an enriched logger (as set by ToolLoggingMiddleware), deps.Logger(ctx)
+// should return it rather than the base logger.
+func TestBaseDeps_Logger_UsesContextLogger(t *testing.T) {
+	base, _ := newTestLogger(slog.LevelInfo)
+	obsv, err := observability.NewExporters(base, metrics.NewNoopMetrics())
+	require.NoError(t, err)
+	d := BaseDeps{Obsv: obsv}
+
+	enriched := base.With("tool", "x")
+	ctx := observability.ContextWithLogger(context.Background(), enriched)
+
+	assert.Equal(t, enriched, d.Logger(ctx))
+	assert.Equal(t, base, d.Logger(context.Background()))
+}
diff --git a/pkg/github/server.go b/pkg/github/server.go
index ee41e90e9..a085328e9 100644
--- a/pkg/github/server.go
+++ b/pkg/github/server.go
@@ -107,6 +107,9 @@ func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependenci
 	// and any middleware that needs to read or modify the context should be before it.
 	ghServer.AddReceivingMiddleware(middleware...)
+	// ToolLoggingMiddleware reads deps from ctx, so it must run inside InjectDepsMiddleware.
+	// Each AddReceivingMiddleware call wraps the handler built so far, so register it first.
+	ghServer.AddReceivingMiddleware(ToolLoggingMiddleware())
 	ghServer.AddReceivingMiddleware(InjectDepsMiddleware(deps))
 	ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext)
 
 	if unrecognized := inv.UnrecognizedToolsets(); len(unrecognized) > 0 {
diff --git a/pkg/http/server.go b/pkg/http/server.go
index d1e8192ba..38d115634 100644
--- a/pkg/http/server.go
+++ b/pkg/http/server.go
@@ -52,6 +52,11 @@ type ServerConfig struct {
 	// Path to the log file if not stderr
 	LogFilePath string
 
+	// LogLevel overrides the default log level ("debug", "info", "warn", "error").
+ // When empty, defaults depend on LogFilePath: Debug when writing to a file, + // Info when writing to stderr. + LogLevel string + // Content window size ContentWindowSize int @@ -103,11 +108,14 @@ func RunHTTPServer(cfg ServerConfig) error { return fmt.Errorf("failed to open log file: %w", err) } logOutput = file - slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelDebug}) } else { logOutput = os.Stderr - slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelInfo}) } + level, err := observability.ParseLogLevel(cfg.LogLevel, defaultHTTPLogLevel(cfg.LogFilePath)) + if err != nil { + return err + } + slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: level}) logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "lockdownEnabled", cfg.LockdownMode, "readOnly", cfg.ReadOnly, "insidersMode", cfg.InsidersMode) @@ -212,6 +220,15 @@ func RunHTTPServer(cfg ServerConfig) error { return nil } +// defaultHTTPLogLevel mirrors the ghmcp stdio default: Debug when a log +// file is configured, Info when writing to stderr. +func defaultHTTPLogLevel(logFilePath string) slog.Level { + if logFilePath != "" { + return slog.LevelDebug + } + return slog.LevelInfo +} + func initGlobalToolScopeMap(t translations.TranslationHelperFunc) error { // Build inventory with all tools to extract scope information inv, err := inventory.NewBuilder(). diff --git a/pkg/observability/log_level.go b/pkg/observability/log_level.go new file mode 100644 index 000000000..7c6b45d5c --- /dev/null +++ b/pkg/observability/log_level.go @@ -0,0 +1,28 @@ +package observability + +import ( + "fmt" + "log/slog" + "strings" +) + +// ParseLogLevel parses a textual log level (case-insensitive) into a slog.Level. +// Accepts "debug", "info", "warn"/"warning", "error". An empty string returns +// the provided default. Unknown values produce an error. 
+func ParseLogLevel(s string, def slog.Level) (slog.Level, error) {
+	// Normalise once so the empty check and the level match see the same text.
+	switch strings.ToLower(strings.TrimSpace(s)) {
+	case "":
+		return def, nil
+	case "debug":
+		return slog.LevelDebug, nil
+	case "info":
+		return slog.LevelInfo, nil
+	case "warn", "warning":
+		return slog.LevelWarn, nil
+	case "error":
+		return slog.LevelError, nil
+	default:
+		return def, fmt.Errorf("unknown log level %q (want one of: debug, info, warn, error)", s)
+	}
+}
diff --git a/pkg/observability/log_level_test.go b/pkg/observability/log_level_test.go
new file mode 100644
index 000000000..b8fd15b4e
--- /dev/null
+++ b/pkg/observability/log_level_test.go
@@ -0,0 +1,40 @@
+package observability
+
+import (
+	"log/slog"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestParseLogLevel(t *testing.T) {
+	cases := []struct {
+		in   string
+		want slog.Level
+	}{
+		{"debug", slog.LevelDebug},
+		{"DEBUG", slog.LevelDebug},
+		{" info ", slog.LevelInfo},
+		{"warn", slog.LevelWarn},
+		{"warning", slog.LevelWarn},
+		{"error", slog.LevelError},
+	}
+	for _, tc := range cases {
+		got, err := ParseLogLevel(tc.in, slog.LevelInfo)
+		require.NoError(t, err, tc.in)
+		assert.Equal(t, tc.want, got, tc.in)
+	}
+}
+
+func TestParseLogLevel_EmptyReturnsDefault(t *testing.T) {
+	got, err := ParseLogLevel("", slog.LevelWarn)
+	require.NoError(t, err)
+	assert.Equal(t, slog.LevelWarn, got)
+}
+
+func TestParseLogLevel_Unknown(t *testing.T) {
+	got, err := ParseLogLevel("verbose", slog.LevelInfo)
+	require.Error(t, err)
+	assert.Equal(t, slog.LevelInfo, got)
+}
diff --git a/pkg/observability/logger_context.go b/pkg/observability/logger_context.go
new file mode 100644
index 000000000..3eb4bbb8b
--- /dev/null
+++ b/pkg/observability/logger_context.go
@@ -0,0 +1,32 @@
+package observability
+
+import (
+	"context"
+	"log/slog"
+)
+
+// loggerContextKey is the context key for request-scoped *slog.Logger.
+// Using a private type prevents collisions with other packages. +type loggerContextKey struct{} + +// ContextWithLogger returns a new context that carries the supplied logger. +// Use this to attach a logger enriched with request-scoped attributes +// (e.g. tool name, request id) so downstream code can pick it up via +// LoggerFromContext. +func ContextWithLogger(ctx context.Context, logger *slog.Logger) context.Context { + if logger == nil { + return ctx + } + return context.WithValue(ctx, loggerContextKey{}, logger) +} + +// LoggerFromContext returns the request-scoped logger stored in ctx by +// ContextWithLogger, or nil if none was set. Callers should fall back to +// a base logger in that case. +func LoggerFromContext(ctx context.Context) *slog.Logger { + if ctx == nil { + return nil + } + logger, _ := ctx.Value(loggerContextKey{}).(*slog.Logger) + return logger +} diff --git a/pkg/observability/logger_context_test.go b/pkg/observability/logger_context_test.go new file mode 100644 index 000000000..52db39c7f --- /dev/null +++ b/pkg/observability/logger_context_test.go @@ -0,0 +1,29 @@ +package observability + +import ( + "context" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoggerContext_RoundTrip(t *testing.T) { + logger := slog.New(slog.DiscardHandler) + ctx := ContextWithLogger(context.Background(), logger) + assert.Equal(t, logger, LoggerFromContext(ctx)) +} + +func TestLoggerFromContext_Empty(t *testing.T) { + assert.Nil(t, LoggerFromContext(context.Background())) + // Defensive: nil context should not panic. Use a typed nil so staticcheck's + // SA1012 (which flags untyped nil Context literals) stays quiet. + var nilCtx context.Context + assert.Nil(t, LoggerFromContext(nilCtx)) +} + +func TestContextWithLogger_Nil(t *testing.T) { + // Storing a nil logger should not mask later reads. + ctx := ContextWithLogger(context.Background(), nil) + assert.Nil(t, LoggerFromContext(ctx)) +}