Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/github-mcp-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ var (
ExportTranslations: viper.GetBool("export-translations"),
EnableCommandLogging: viper.GetBool("enable-command-logging"),
LogFilePath: viper.GetString("log-file"),
LogLevel: viper.GetString("log-level"),
ContentWindowSize: viper.GetInt("content-window-size"),
LockdownMode: viper.GetBool("lockdown-mode"),
InsidersMode: viper.GetBool("insiders"),
Expand Down Expand Up @@ -137,6 +138,7 @@ var (
ExportTranslations: viper.GetBool("export-translations"),
EnableCommandLogging: viper.GetBool("enable-command-logging"),
LogFilePath: viper.GetString("log-file"),
LogLevel: viper.GetString("log-level"),
ContentWindowSize: viper.GetInt("content-window-size"),
LockdownMode: viper.GetBool("lockdown-mode"),
RepoAccessCacheTTL: &ttl,
Expand Down Expand Up @@ -168,6 +170,7 @@ func init() {
rootCmd.PersistentFlags().Bool("dynamic-toolsets", false, "Enable dynamic toolsets")
rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations")
rootCmd.PersistentFlags().String("log-file", "", "Path to log file")
rootCmd.PersistentFlags().String("log-level", "", "Log level (debug, info, warn, error). Defaults to debug when --log-file is set, info otherwise.")
rootCmd.PersistentFlags().Bool("enable-command-logging", false, "When enabled, the server will log all command requests and responses to the log file")
rootCmd.PersistentFlags().Bool("export-translations", false, "Save translations to a JSON file")
rootCmd.PersistentFlags().String("gh-host", "", "Specify the GitHub hostname (for GitHub Enterprise etc.)")
Expand All @@ -190,6 +193,7 @@ func init() {
_ = viper.BindPFlag("dynamic_toolsets", rootCmd.PersistentFlags().Lookup("dynamic-toolsets"))
_ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only"))
_ = viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file"))
_ = viper.BindPFlag("log-level", rootCmd.PersistentFlags().Lookup("log-level"))
_ = viper.BindPFlag("enable-command-logging", rootCmd.PersistentFlags().Lookup("enable-command-logging"))
_ = viper.BindPFlag("export-translations", rootCmd.PersistentFlags().Lookup("export-translations"))
_ = viper.BindPFlag("host", rootCmd.PersistentFlags().Lookup("gh-host"))
Expand Down
22 changes: 20 additions & 2 deletions internal/ghmcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ type StdioServerConfig struct {
// Path to the log file if not stderr
LogFilePath string

// LogLevel overrides the default log level ("debug", "info", "warn", "error").
// When empty, the default depends on LogFilePath: file-backed logs default
// to Debug (local troubleshooting), stderr logs default to Info.
LogLevel string

// Content window size
ContentWindowSize int

Expand Down Expand Up @@ -248,11 +253,14 @@ func RunStdioServer(cfg StdioServerConfig) error {
return fmt.Errorf("failed to open log file: %w", err)
}
logOutput = file
slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelDebug})
} else {
logOutput = os.Stderr
slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelInfo})
}
level, err := observability.ParseLogLevel(cfg.LogLevel, defaultLogLevel(cfg.LogFilePath))
if err != nil {
return err
}
slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: level})
logger := slog.New(slogHandler)
logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode)

Expand Down Expand Up @@ -335,6 +343,16 @@ func RunStdioServer(cfg StdioServerConfig) error {
return nil
}

// defaultLogLevel returns the log level used when the user hasn't specified
// one. Debug when a log file is configured (intended for local debugging),
// Info when logs go to stderr (keeps terminal output quiet).
func defaultLogLevel(logFilePath string) slog.Level {
if logFilePath != "" {
return slog.LevelDebug
}
return slog.LevelInfo
}

// createFeatureChecker returns a FeatureFlagChecker that resolves features
// using the centralized ResolveFeatureFlags function. For the local server,
// features are resolved once at startup from --features CLI flag + insiders mode.
Expand Down
31 changes: 24 additions & 7 deletions pkg/github/dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"log/slog"
"net/http"
"os"

ghcontext "github.com/github/github-mcp-server/pkg/context"
"github.com/github/github-mcp-server/pkg/http/transport"
Expand Down Expand Up @@ -187,7 +186,15 @@ func (d BaseDeps) GetFlags(_ context.Context) FeatureFlags { return d.Flags }
func (d BaseDeps) GetContentWindowSize() int { return d.ContentWindowSize }

// Logger implements ToolDependencies.
func (d BaseDeps) Logger(_ context.Context) *slog.Logger {
// If an enriched logger has been attached to ctx by ToolLoggingMiddleware
// (via observability.ContextWithLogger), that logger is returned so tool
// handlers inherit request-scoped attributes such as tool name and
// mcp.method. Otherwise the base logger from the observability exporters
// is returned.
func (d BaseDeps) Logger(ctx context.Context) *slog.Logger {
if l := observability.LoggerFromContext(ctx); l != nil {
return l
}
return d.Obsv.Logger()
}

Expand All @@ -206,8 +213,12 @@ func (d BaseDeps) IsFeatureEnabled(ctx context.Context, flagName string) bool {

enabled, err := d.featureChecker(ctx, flagName)
if err != nil {
// Log error but don't fail the tool - treat as disabled
fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err)
// Treat errors as disabled, but surface them via the logger so
// operators can diagnose feature-flag backend issues in production.
d.Logger(ctx).Warn("feature flag check failed",
slog.String("flag", flagName),
slog.String("error", err.Error()),
)
return false
}

Expand Down Expand Up @@ -406,7 +417,11 @@ func (d *RequestDeps) GetFlags(ctx context.Context) FeatureFlags {
func (d *RequestDeps) GetContentWindowSize() int { return d.ContentWindowSize }

// Logger implements ToolDependencies.
func (d *RequestDeps) Logger(_ context.Context) *slog.Logger {
// See BaseDeps.Logger for the context-scoped logger fallback behaviour.
func (d *RequestDeps) Logger(ctx context.Context) *slog.Logger {
if l := observability.LoggerFromContext(ctx); l != nil {
return l
}
return d.obsv.Logger()
}

Expand All @@ -423,8 +438,10 @@ func (d *RequestDeps) IsFeatureEnabled(ctx context.Context, flagName string) boo

enabled, err := d.featureChecker(ctx, flagName)
if err != nil {
// Log error but don't fail the tool - treat as disabled
fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err)
d.Logger(ctx).Warn("feature flag check failed",
slog.String("flag", flagName),
slog.String("error", err.Error()),
)
return false
}

Expand Down
116 changes: 116 additions & 0 deletions pkg/github/logging_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package github

import (
"context"
"log/slog"
"time"

"github.com/github/github-mcp-server/pkg/observability"
"github.com/modelcontextprotocol/go-sdk/mcp"
)

// MCPMethodCallTool is the JSON-RPC method name for MCP tool invocations.
// The SDK keeps its equivalent constant unexported, so we mirror it here.
const MCPMethodCallTool = "tools/call"

// ToolLoggingMiddleware returns an MCP middleware that uniformly logs every
// tool invocation with its name, duration, and outcome, and exposes a
// request-scoped *slog.Logger via observability.ContextWithLogger so tool
// handlers can retrieve an enriched logger from deps.Logger(ctx).
//
// Logging policy:
// - tools/call success: logged at Debug with tool name and duration.
// - tools/call failure (error return or IsError result): logged at Error
// with tool name, duration, and the error when present.
// - Non-tool methods pass through without emitting any log line; the
// middleware only attaches an enriched logger so downstream code can
// still benefit from the method tag if it chooses to log.
//
// The base logger comes from ToolDependencies on the context (populated by
// InjectDepsMiddleware), so this middleware must be registered AFTER
// InjectDepsMiddleware in the receiving middleware chain.
func ToolLoggingMiddleware() mcp.Middleware {
return func(next mcp.MethodHandler) mcp.MethodHandler {
return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) {
deps, ok := DepsFromContext(ctx)
if !ok {
// Deps not injected yet; nothing we can do but pass through.
return next(ctx, method, req)
}

base := deps.Logger(ctx)
if base == nil {
return next(ctx, method, req)
}

logger := base.With(slog.String("mcp.method", method))
toolName := toolNameFromRequest(method, req)
if toolName != "" {
logger = logger.With(slog.String("mcp.tool", toolName))
}

ctx = observability.ContextWithLogger(ctx, logger)

// Only time+log for tool calls. Other methods (initialize,
// resources/list, etc.) are infrastructure chatter we leave
// to the SDK unless a handler chooses to log explicitly.
if method != MCPMethodCallTool {
return next(ctx, method, req)
}

start := time.Now()
result, err := next(ctx, method, req)
duration := time.Since(start)

switch {
case err != nil:
logger.LogAttrs(ctx, slog.LevelError, "tool call failed",
slog.Duration("duration", duration),
slog.String("error", err.Error()),
)
case isErrorResult(result):
logger.LogAttrs(ctx, slog.LevelError, "tool call returned error result",
slog.Duration("duration", duration),
)
default:
logger.LogAttrs(ctx, slog.LevelDebug, "tool call succeeded",
slog.Duration("duration", duration),
)
}

return result, err
}
}
}

// toolNameFromRequest extracts the tool name from a tools/call request.
// Returns "" for other methods or when the name cannot be determined.
func toolNameFromRequest(method string, req mcp.Request) string {
if method != MCPMethodCallTool || req == nil {
return ""
}
switch p := req.GetParams().(type) {
case *mcp.CallToolParams:
if p != nil {
return p.Name
}
case *mcp.CallToolParamsRaw:
if p != nil {
return p.Name
}
}
return ""
}

// isErrorResult reports whether the MCP result represents a tool-reported
// error (CallToolResult.IsError == true). A returned Go error is handled
// separately by the caller.
func isErrorResult(r mcp.Result) bool {
if r == nil {
return false
}
if ctr, ok := r.(*mcp.CallToolResult); ok && ctr != nil {
return ctr.IsError
}
return false
}
Loading
Loading