diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 7ff132229..20ba2711b 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -101,6 +101,8 @@ var ( Version: version, Host: viper.GetString("host"), Port: viper.GetInt("port"), + BaseURL: viper.GetString("base-url"), + ResourcePath: viper.GetString("base-path"), ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), @@ -134,7 +136,11 @@ func init() { rootCmd.PersistentFlags().Bool("lockdown-mode", false, "Enable lockdown mode") rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") - rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port") + + // HTTP-specific flags + httpCmd.Flags().Int("port", 8082, "HTTP server port") + httpCmd.Flags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") + httpCmd.Flags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -150,7 +156,9 @@ func init() { _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) _ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) - _ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")) + _ = viper.BindPFlag("port", httpCmd.Flags().Lookup("port")) + _ = viper.BindPFlag("base-url", httpCmd.Flags().Lookup("base-url")) + _ = viper.BindPFlag("base-path", httpCmd.Flags().Lookup("base-path")) // Add subcommands rootCmd.AddCommand(stdioCmd) diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 7fa38a73d..9bb98b86b 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -8,6 +8,7 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/middleware" + "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/go-chi/chi/v5" @@ -25,11 +26,13 @@ type Handler struct { t translations.TranslationHelperFunc githubMcpServerFactory GitHubMCPServerFactoryFunc inventoryFactoryFunc InventoryFactoryFunc + oauthCfg *oauth.Config } type HandlerOptions struct { GitHubMcpServerFactory GitHubMCPServerFactoryFunc InventoryFactory InventoryFactoryFunc + OAuthConfig *oauth.Config FeatureChecker inventory.FeatureFlagChecker } @@ -47,6 +50,12 @@ func WithInventoryFactory(f InventoryFactoryFunc) HandlerOption { } } +func WithOAuthConfig(cfg *oauth.Config) HandlerOption { + return func(o *HandlerOptions) { + o.OAuthConfig = cfg + } +} + func WithFeatureChecker(checker inventory.FeatureFlagChecker) HandlerOption { return func(o *HandlerOptions) { o.FeatureChecker = checker @@ -83,14 +92,20 @@ func NewHTTPMcpHandler( t: t, githubMcpServerFactory: githubMcpServerFactory, inventoryFactoryFunc: inventoryFactory, + oauthCfg: opts.OAuthConfig, } } +func (h *Handler) RegisterMiddleware(r chi.Router) { + r.Use( + middleware.ExtractUserToken(h.oauthCfg), + middleware.WithRequestConfig, + ) +} + // RegisterRoutes registers the routes for the MCP server // URL-based values take precedence over header-based values func (h *Handler) RegisterRoutes(r chi.Router) { - r.Use(middleware.WithRequestConfig) - // Base routes r.Mount("/", h) r.With(withReadonly).Mount("/readonly", h) @@ -154,7 +169,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Stateless: true, }) - middleware.ExtractUserToken()(mcpHandler).ServeHTTP(w, r) + mcpHandler.ServeHTTP(w, r) } func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) { diff --git a/pkg/http/handler_test.go b/pkg/http/handler_test.go index d02797330..70258436c 100644 --- a/pkg/http/handler_test.go +++ b/pkg/http/handler_test.go @@ -11,6 +11,7 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/http/middleware" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/go-chi/chi/v5" @@ -294,6 +295,7 @@ func TestHTTPHandlerRoutes(t *testing.T) { // Create router and register routes r := chi.NewRouter() + r.Use(middleware.WithRequestConfig) handler.RegisterRoutes(r) // Create request diff --git a/pkg/http/headers/headers.go b/pkg/http/headers/headers.go index 20d436c7c..5ffe30806 100644 --- a/pkg/http/headers/headers.go +++ b/pkg/http/headers/headers.go @@ -21,6 +21,11 @@ const ( // RealIPHeader is a standard HTTP Header used to indicate the real IP address of the client. RealIPHeader = "X-Real-IP" + // ForwardedHostHeader is a standard HTTP Header for preserving the original Host header when proxying. + ForwardedHostHeader = "X-Forwarded-Host" + // ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying. + ForwardedProtoHeader = "X-Forwarded-Proto" + // RequestHmacHeader is used to authenticate requests to the Raw API. RequestHmacHeader = "Request-Hmac" diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index 6369abf14..26973a548 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -10,6 +10,7 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" httpheaders "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/http/mark" + "github.com/github/github-mcp-server/pkg/http/oauth" ) type authType int @@ -39,14 +40,14 @@ var supportedThirdPartyTokenPrefixes = []string{ // were 40 characters long and only contained the characters a-f and 0-9. var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`) -func ExtractUserToken() func(next http.Handler) http.Handler { +func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, token, err := parseAuthorizationHeader(r) if err != nil { // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec if errors.Is(err, errMissingAuthorizationHeader) { - // sendAuthChallenge(w, r, cfg, obsv) + sendAuthChallenge(w, r, oauthCfg) return } // For other auth errors (bad format, unsupported), return 400 @@ -62,6 +63,16 @@ func ExtractUserToken() func(next http.Handler) http.Handler { }) } } + +// sendAuthChallenge sends a 401 Unauthorized response with WWW-Authenticate header +// containing the OAuth protected resource metadata URL as per RFC 6750 and MCP spec. +func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.Config) { + resourcePath := oauth.ResolveResourcePath(r, oauthCfg) + resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, resourcePath) + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL)) + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} + func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) { authHeader := req.Header.Get(httpheaders.AuthorizationHeader) if authHeader == "" { diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go new file mode 100644 index 000000000..ecdcf95ab --- /dev/null +++ b/pkg/http/oauth/oauth.go @@ -0,0 +1,243 @@ +// Package oauth provides OAuth 2.0 Protected Resource Metadata (RFC 9728) support +// for the GitHub MCP Server HTTP mode. +package oauth + +import ( + "fmt" + "net/http" + "strings" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +const ( + // OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata. + OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource" + + // DefaultAuthorizationServer is GitHub's OAuth authorization server. + DefaultAuthorizationServer = "https://github.com/login/oauth" +) + +// SupportedScopes lists all OAuth scopes that may be required by MCP tools. +var SupportedScopes = []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", +} + +// Config holds the OAuth configuration for the MCP server. +type Config struct { + // BaseURL is the publicly accessible URL where this server is hosted. + // This is used to construct the OAuth resource URL. + BaseURL string + + // AuthorizationServer is the OAuth authorization server URL. + // Defaults to GitHub's OAuth server if not specified. + AuthorizationServer string + + // ResourcePath is the externally visible base path for the MCP server (e.g., "/mcp"). + // This is used to restore the original path when a proxy strips a base path before forwarding. + // If empty, requests are treated as already using the external path. + ResourcePath string +} + +// AuthHandler handles OAuth-related HTTP endpoints. +type AuthHandler struct { + cfg *Config +} + +// NewAuthHandler creates a new OAuth auth handler. +func NewAuthHandler(cfg *Config) (*AuthHandler, error) { + if cfg == nil { + cfg = &Config{} + } + + // Default authorization server to GitHub + if cfg.AuthorizationServer == "" { + cfg.AuthorizationServer = DefaultAuthorizationServer + } + + return &AuthHandler{ + cfg: cfg, + }, nil +} + +// routePatterns defines the route patterns for OAuth protected resource metadata. +var routePatterns = []string{ + "", // Root: /.well-known/oauth-protected-resource + "/readonly", // Read-only mode + "/insiders", // Insiders mode + "/x/{toolset}", + "/x/{toolset}/readonly", +} + +// RegisterRoutes registers the OAuth protected resource metadata routes. +func (h *AuthHandler) RegisterRoutes(r chi.Router) { + for _, pattern := range routePatterns { + for _, route := range h.routesForPattern(pattern) { + path := OAuthProtectedResourcePrefix + route + r.Handle(path, h.metadataHandler()) + } + } +} + +func (h *AuthHandler) metadataHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resourcePath := resolveResourcePath( + strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix), + h.cfg.ResourcePath, + ) + resourceURL := h.buildResourceURL(r, resourcePath) + + metadata := &oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{h.cfg.AuthorizationServer}, + ResourceName: "GitHub MCP Server", + ScopesSupported: SupportedScopes, + BearerMethodsSupported: []string{"header"}, + } + + auth.ProtectedResourceMetadataHandler(metadata).ServeHTTP(w, r) + }) +} + +// routesForPattern generates route variants for a given pattern. +// GitHub strips the /mcp prefix before forwarding, so we register both variants: +// - With /mcp prefix: for direct access or when GitHub doesn't strip +// - Without /mcp prefix: for when GitHub has stripped the prefix +func (h *AuthHandler) routesForPattern(pattern string) []string { + basePaths := []string{""} + if basePath := normalizeBasePath(h.cfg.ResourcePath); basePath != "" { + basePaths = append(basePaths, basePath) + } else { + basePaths = append(basePaths, "/mcp") + } + + routes := make([]string, 0, len(basePaths)*2) + for _, basePath := range basePaths { + routes = append(routes, joinRoute(basePath, pattern)) + routes = append(routes, joinRoute(basePath, pattern)+"/") + } + + return routes +} + +// resolveResourcePath returns the externally visible resource path, +// restoring the configured base path when proxies strip it before forwarding. +func resolveResourcePath(path, basePath string) string { + if path == "" { + path = "/" + } + base := normalizeBasePath(basePath) + if base == "" { + return path + } + if path == "/" { + return base + } + if path == base || strings.HasPrefix(path, base+"/") { + return path + } + return base + path +} + +// ResolveResourcePath returns the externally visible resource path for a request. +// Exported for use by middleware. +func ResolveResourcePath(r *http.Request, cfg *Config) string { + basePath := "" + if cfg != nil { + basePath = cfg.ResourcePath + } + return resolveResourcePath(r.URL.Path, basePath) +} + +// buildResourceURL constructs the full resource URL for OAuth metadata. +func (h *AuthHandler) buildResourceURL(r *http.Request, resourcePath string) string { + host, scheme := GetEffectiveHostAndScheme(r, h.cfg) + baseURL := fmt.Sprintf("%s://%s", scheme, host) + if h.cfg.BaseURL != "" { + baseURL = strings.TrimSuffix(h.cfg.BaseURL, "/") + } + if resourcePath == "" { + resourcePath = "/" + } + if !strings.HasPrefix(resourcePath, "/") { + resourcePath = "/" + resourcePath + } + return baseURL + resourcePath +} + +// GetEffectiveHostAndScheme returns the effective host and scheme for a request. +func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { //nolint:revive + if fh := r.Header.Get(headers.ForwardedHostHeader); fh != "" { + host = fh + } else { + host = r.Host + } + if host == "" { + host = "localhost" + } + if fp := r.Header.Get(headers.ForwardedProtoHeader); fp != "" { + scheme = strings.ToLower(fp) + } else { + if r.TLS != nil { + scheme = "https" + } else { + scheme = "http" + } + } + return +} + +// BuildResourceMetadataURL constructs the full URL to the OAuth protected resource metadata endpoint. +func BuildResourceMetadataURL(r *http.Request, cfg *Config, resourcePath string) string { + host, scheme := GetEffectiveHostAndScheme(r, cfg) + suffix := "" + if resourcePath != "" && resourcePath != "/" { + if !strings.HasPrefix(resourcePath, "/") { + suffix = "/" + resourcePath + } else { + suffix = resourcePath + } + } + if cfg != nil && cfg.BaseURL != "" { + return strings.TrimSuffix(cfg.BaseURL, "/") + OAuthProtectedResourcePrefix + suffix + } + return fmt.Sprintf("%s://%s%s%s", scheme, host, OAuthProtectedResourcePrefix, suffix) +} + +func normalizeBasePath(path string) string { + trimmed := strings.TrimSpace(path) + if trimmed == "" || trimmed == "/" { + return "" + } + if !strings.HasPrefix(trimmed, "/") { + trimmed = "/" + trimmed + } + return strings.TrimSuffix(trimmed, "/") +} + +func joinRoute(basePath, pattern string) string { + if basePath == "" { + return pattern + } + if pattern == "" { + return basePath + } + if strings.HasSuffix(basePath, "/") { + return strings.TrimSuffix(basePath, "/") + pattern + } + return basePath + pattern +} diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go new file mode 100644 index 000000000..9133e8331 --- /dev/null +++ b/pkg/http/oauth/oauth_test.go @@ -0,0 +1,615 @@ +package oauth + +import ( + "crypto/tls" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAuthHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + expectedAuthServer string + expectedResourcePath string + }{ + { + name: "nil config uses defaults", + cfg: nil, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "empty config uses defaults", + cfg: &Config{}, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "", + }, + { + name: "custom authorization server", + cfg: &Config{ + AuthorizationServer: "https://custom.example.com/oauth", + }, + expectedAuthServer: "https://custom.example.com/oauth", + expectedResourcePath: "", + }, + { + name: "custom base URL and resource path", + cfg: &Config{ + BaseURL: "https://example.com", + ResourcePath: "/mcp", + }, + expectedAuthServer: DefaultAuthorizationServer, + expectedResourcePath: "/mcp", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + require.NotNil(t, handler) + + assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer) + }) + } +} + +func TestGetEffectiveHostAndScheme(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupRequest func() *http.Request + cfg *Config + expectedHost string + expectedScheme string + }{ + { + name: "basic request without forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", // defaults to http + }, + { + name: "request with X-Forwarded-Host header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "http", + }, + { + name: "request with X-Forwarded-Proto header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "request with both forwarding headers", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + cfg: &Config{}, + expectedHost: "public.example.com", + expectedScheme: "https", + }, + { + name: "request with TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + { + name: "X-Forwarded-Proto takes precedence over TLS", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{} + req.Header.Set(headers.ForwardedProtoHeader, "http") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "http", + }, + { + name: "scheme is lowercased", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + req.Header.Set(headers.ForwardedProtoHeader, "HTTPS") + return req + }, + cfg: &Config{}, + expectedHost: "example.com", + expectedScheme: "https", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + host, scheme := GetEffectiveHostAndScheme(req, tc.cfg) + + assert.Equal(t, tc.expectedHost, host) + assert.Equal(t, tc.expectedScheme, scheme) + }) + } +} + +func TestResolveResourcePath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + expectedPath string + }{ + { + name: "no base path uses request path", + cfg: &Config{}, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/x/repos", nil) + }, + expectedPath: "/x/repos", + }, + { + name: "base path restored for root", + cfg: &Config{ + ResourcePath: "/mcp", + }, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/", nil) + }, + expectedPath: "/mcp", + }, + { + name: "base path restored for nested", + cfg: &Config{ + ResourcePath: "/mcp", + }, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/readonly", nil) + }, + expectedPath: "/mcp/readonly", + }, + { + name: "base path preserved when already present", + cfg: &Config{ + ResourcePath: "/mcp", + }, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/mcp/readonly/", nil) + }, + expectedPath: "/mcp/readonly/", + }, + { + name: "custom base path restored", + cfg: &Config{ + ResourcePath: "/api", + }, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/x/repos", nil) + }, + expectedPath: "/api/x/repos", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + path := ResolveResourcePath(req, tc.cfg) + + assert.Equal(t, tc.expectedPath, path) + }) + } +} + +func TestBuildResourceMetadataURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + setupRequest func() *http.Request + resourcePath string + expectedURL string + }{ + { + name: "root path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/", + expectedURL: "http://api.example.com/.well-known/oauth-protected-resource", + }, + { + name: "resource path preserves trailing slash", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp/", + expectedURL: "http://api.example.com/.well-known/oauth-protected-resource/mcp/", + }, + { + name: "with custom resource path", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "http://api.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with base URL config", + cfg: &Config{ + BaseURL: "https://custom.example.com", + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "/mcp", + expectedURL: "https://custom.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "with forwarded headers", + cfg: &Config{}, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Host = "internal.example.com" + req.Header.Set(headers.ForwardedHostHeader, "public.example.com") + req.Header.Set(headers.ForwardedProtoHeader, "https") + return req + }, + resourcePath: "/mcp", + expectedURL: "https://public.example.com/.well-known/oauth-protected-resource/mcp", + }, + { + name: "nil config uses request host", + cfg: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = "api.example.com" + return req + }, + resourcePath: "", + expectedURL: "http://api.example.com/.well-known/oauth-protected-resource", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := tc.setupRequest() + url := BuildResourceMetadataURL(req, tc.cfg, tc.resourcePath) + + assert.Equal(t, tc.expectedURL, url) + }) + } +} + +func TestHandleProtectedResource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + path string + host string + method string + expectedStatusCode int + expectedScopes []string + validateResponse func(t *testing.T, body map[string]any) + }{ + { + name: "GET request returns protected resource metadata", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + expectedScopes: SupportedScopes, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "GitHub MCP Server", body["resource_name"]) + assert.Equal(t, "https://api.example.com/", body["resource"]) + + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) + }, + }, + { + name: "OPTIONS request for CORS preflight", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodOptions, + expectedStatusCode: http.StatusNoContent, + }, + { + name: "path with /mcp suffix", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix + "/mcp", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "https://api.example.com/mcp", body["resource"]) + }, + }, + { + name: "path with /readonly suffix", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix + "/readonly", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "https://api.example.com/readonly", body["resource"]) + }, + }, + { + name: "path with trailing slash", + cfg: &Config{ + BaseURL: "https://api.example.com", + }, + path: OAuthProtectedResourcePrefix + "/mcp/", + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + assert.Equal(t, "https://api.example.com/mcp/", body["resource"]) + }, + }, + { + name: "custom authorization server in response", + cfg: &Config{ + BaseURL: "https://api.example.com", + AuthorizationServer: "https://custom.auth.example.com/oauth", + }, + path: OAuthProtectedResourcePrefix, + host: "api.example.com", + method: http.MethodGet, + expectedStatusCode: http.StatusOK, + validateResponse: func(t *testing.T, body map[string]any) { + t.Helper() + authServers, ok := body["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, authServers, 1) + assert.Equal(t, "https://custom.auth.example.com/oauth", authServers[0]) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(tc.cfg) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(tc.method, tc.path, nil) + req.Host = tc.host + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatusCode, rec.Code) + + // Check CORS headers + assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin")) + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "GET") + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "OPTIONS") + + if tc.method == http.MethodGet && tc.validateResponse != nil { + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var body map[string]any + err := json.Unmarshal(rec.Body.Bytes(), &body) + require.NoError(t, err) + + tc.validateResponse(t, body) + + // Verify scopes if expected + if tc.expectedScopes != nil { + scopes, ok := body["scopes_supported"].([]any) + require.True(t, ok) + assert.Len(t, scopes, len(tc.expectedScopes)) + } + } + }) + } +} + +func TestRegisterRoutes(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{ + BaseURL: "https://api.example.com", + }) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + // List of expected routes that should be registered + expectedRoutes := []string{ + OAuthProtectedResourcePrefix, + OAuthProtectedResourcePrefix + "/", + OAuthProtectedResourcePrefix + "/mcp", + OAuthProtectedResourcePrefix + "/mcp/", + OAuthProtectedResourcePrefix + "/readonly", + OAuthProtectedResourcePrefix + "/readonly/", + OAuthProtectedResourcePrefix + "/mcp/readonly", + OAuthProtectedResourcePrefix + "/mcp/readonly/", + OAuthProtectedResourcePrefix + "/x/repos", + OAuthProtectedResourcePrefix + "/mcp/x/repos", + } + + for _, route := range expectedRoutes { + t.Run("route:"+route, func(t *testing.T) { + // Test GET + req := httptest.NewRequest(http.MethodGet, route, nil) + req.Host = "api.example.com" + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "GET %s should return 200", route) + + // Test OPTIONS (CORS preflight) + req = httptest.NewRequest(http.MethodOptions, route, nil) + req.Host = "api.example.com" + rec = httptest.NewRecorder() + router.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNoContent, rec.Code, "OPTIONS %s should return 204", route) + }) + } +} + +func TestSupportedScopes(t *testing.T) { + t.Parallel() + + // Verify all expected scopes are present + expectedScopes := []string{ + "repo", + "read:org", + "read:user", + "user:email", + "read:packages", + "write:packages", + "read:project", + "project", + "gist", + "notifications", + "workflow", + "codespace", + } + + assert.Equal(t, expectedScopes, SupportedScopes) +} + +func TestProtectedResourceResponseFormat(t *testing.T) { + t.Parallel() + + handler, err := NewAuthHandler(&Config{ + BaseURL: "https://api.example.com", + }) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(http.MethodGet, OAuthProtectedResourcePrefix, nil) + req.Host = "api.example.com" + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response map[string]any + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + // Verify all required RFC 9728 fields are present + assert.Contains(t, response, "resource") + assert.Contains(t, response, "authorization_servers") + assert.Contains(t, response, "bearer_methods_supported") + assert.Contains(t, response, "scopes_supported") + + // Verify resource name (optional but we include it) + assert.Contains(t, response, "resource_name") + assert.Equal(t, "GitHub MCP Server", response["resource_name"]) + + // Verify bearer_methods_supported contains "header" + bearerMethods, ok := response["bearer_methods_supported"].([]any) + require.True(t, ok) + assert.Contains(t, bearerMethods, "header") + + // Verify authorization_servers is an array with GitHub OAuth + authServers, ok := response["authorization_servers"].([]any) + require.True(t, ok) + assert.Len(t, authServers, 1) + assert.Equal(t, DefaultAuthorizationServer, authServers[0]) +} + +func TestOAuthProtectedResourcePrefix(t *testing.T) { + t.Parallel() + + // RFC 9728 specifies this well-known path + assert.Equal(t, "/.well-known/oauth-protected-resource", OAuthProtectedResourcePrefix) +} + +func TestDefaultAuthorizationServer(t *testing.T) { + t.Parallel() + + assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer) +} diff --git a/pkg/http/server.go b/pkg/http/server.go index 8ea8c641c..c2aad4c61 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -14,6 +14,7 @@ import ( ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/translations" @@ -38,6 +39,14 @@ type ServerConfig struct { // Port to listen on (default: 8082) Port int + // BaseURL is the publicly accessible URL of this server for OAuth resource metadata. + // If not set, the server will derive the URL from incoming request headers. + BaseURL string + + // ResourcePath is the externally visible base path for this server (e.g., "/mcp"). + // This is used to restore the original path when a proxy strips a base path before forwarding. + ResourcePath string + // ExportTranslations indicates if we should export translations // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions ExportTranslations bool @@ -107,8 +116,30 @@ func RunHTTPServer(cfg ServerConfig) error { r := chi.NewRouter() - handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, WithFeatureChecker(featureChecker)) - handler.RegisterRoutes(r) + // Register OAuth protected resource metadata endpoints + oauthCfg := &oauth.Config{ + BaseURL: cfg.BaseURL, + ResourcePath: cfg.ResourcePath, + } + oauthHandler, err := oauth.NewAuthHandler(oauthCfg) + if err != nil { + return fmt.Errorf("failed to create OAuth handler: %w", err) + } + + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg)) + + // MCP routes with middleware + r.Group(func(r chi.Router) { + handler.RegisterMiddleware(r) + handler.RegisterRoutes(r) + }) + logger.Info("MCP endpoints registered", "baseURL", cfg.BaseURL) + + // OAuth routes without MCP middleware + r.Group(func(r chi.Router) { + oauthHandler.RegisterRoutes(r) + }) + logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) addr := fmt.Sprintf(":%d", cfg.Port) httpSvr := http.Server{