diff --git a/src/git/src/mcp_server_git/server.py b/src/git/src/mcp_server_git/server.py index 1d0298b465..b9202293db 100644 --- a/src/git/src/mcp_server_git/server.py +++ b/src/git/src/mcp_server_git/server.py @@ -1,4 +1,5 @@ import logging +import re from pathlib import Path from typing import Sequence, Optional from mcp.server import Server @@ -92,6 +93,19 @@ class GitBranch(BaseModel): ) +class GitCurrentBranch(BaseModel): + repo_path: str + +class GitDefaultBranch(BaseModel): + repo_path: str + remote: str = Field( + "origin", + description="The remote to get the default branch for (defaults to 'origin')" + ) + +class GitRemote(BaseModel): + repo_path: str + class GitTools(str, Enum): STATUS = "git_status" DIFF_UNSTAGED = "git_diff_unstaged" @@ -106,6 +120,9 @@ class GitTools(str, Enum): SHOW = "git_show" BRANCH = "git_branch" + CURRENT_BRANCH = "git_current_branch" + DEFAULT_BRANCH = "git_default_branch" + REMOTE = "git_remote" def git_status(repo: git.Repo) -> str: return repo.git.status() @@ -268,6 +285,42 @@ def git_branch(repo: git.Repo, branch_type: str, contains: str | None = None, no return branch_info +def git_current_branch(repo: git.Repo) -> str: + if repo.head.is_detached: + return f"HEAD detached at {repo.head.commit.hexsha[:7]}" + return repo.active_branch.name + +def git_default_branch(repo: git.Repo, remote: str = "origin") -> str: + # Try git ls-remote --symref to detect remote HEAD + try: + output = repo.git.ls_remote("--symref", remote, "HEAD") + # Output format: "ref: refs/heads/main\tHEAD\n\tHEAD" + match = re.search(r"^ref: refs/heads/(\S+)\t", output, re.MULTILINE) + if match: + return f"{remote}/{match.group(1)}" + except git.GitCommandError: + pass + + # Try local ref resolution via rev-parse (returns "origin/main" directly) + try: + return repo.git.rev_parse("--abbrev-ref", f"{remote}/HEAD") + except git.GitCommandError: + pass + + # Fallback: check for common local branch names + local_branches = [ref.name for ref in repo.branches] + if "main" in local_branches: + return f"{remote}/main" + if "master" in local_branches: + return f"{remote}/master" + + raise ValueError( + f"Could not determine the default branch for remote '{remote}'" + ) + +def git_remote(repo: git.Repo) -> str: + return repo.git.remote("-v") + async def serve(repository: Path | None) -> None: logger = logging.getLogger(__name__) @@ -345,8 +398,22 @@ async def list_tools() -> list[Tool]: name=GitTools.BRANCH, description="List Git branches", inputSchema=GitBranch.model_json_schema(), - - ) + ), + Tool( + name=GitTools.CURRENT_BRANCH, + description="Returns the name of the currently checked out branch, or the commit SHA if HEAD is detached", + inputSchema=GitCurrentBranch.model_json_schema(), + ), + Tool( + name=GitTools.DEFAULT_BRANCH, + description="Returns the default branch for a remote in '/' format (e.g., 'origin/main'). Queries the remote directly, with fallback to local ref resolution and branch detection.", + inputSchema=GitDefaultBranch.model_json_schema(), + ), + Tool( + name=GitTools.REMOTE, + description="Lists all configured remotes with their fetch and push URLs", + inputSchema=GitRemote.model_json_schema(), + ), ] async def list_repos() -> Sequence[str]: @@ -488,6 +555,30 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: text=result )] + case GitTools.CURRENT_BRANCH: + result = git_current_branch(repo) + return [TextContent( + type="text", + text=result + )] + + case GitTools.DEFAULT_BRANCH: + result = git_default_branch( + repo, + arguments.get("remote", "origin") + ) + return [TextContent( + type="text", + text=result + )] + + case GitTools.REMOTE: + result = git_remote(repo) + return [TextContent( + type="text", + text=result + )] + case _: raise ValueError(f"Unknown tool: {name}") diff --git a/src/git/tests/test_server.py b/src/git/tests/test_server.py index 054bf8c756..32b0095960 100644 --- a/src/git/tests/test_server.py +++ b/src/git/tests/test_server.py @@ -5,6 +5,9 @@ from mcp_server_git.server import ( git_checkout, git_branch, + git_current_branch, + git_default_branch, + git_remote, git_add, git_status, git_diff_unstaged, @@ -423,3 +426,145 @@ def test_git_checkout_rejects_malicious_refs(test_repository): # Cleanup malicious_ref_path.unlink() + + +# Tests for git_current_branch + +def test_git_current_branch(test_repository): + result = git_current_branch(test_repository) + assert result == test_repository.active_branch.name + +def test_git_current_branch_detached_head(test_repository): + commit_sha = test_repository.head.commit.hexsha + test_repository.git.checkout(commit_sha) + result = git_current_branch(test_repository) + assert "detached" in result.lower() + assert commit_sha[:7] in result + + +# Tests for git_default_branch + +def test_git_default_branch_fallback_local(test_repository): + """Repo with no remote; falls back to detecting the local default branch name.""" + default_branch = test_repository.active_branch.name + result = git_default_branch(test_repository) + assert result == f"origin/{default_branch}" + +def test_git_default_branch_with_remote(tmp_path): + """Create a bare remote repo, add it as origin, verify ls-remote detection works.""" + # Create a bare repo to act as the remote + bare_path = tmp_path / "bare_remote.git" + bare_repo = git.Repo.init(bare_path, bare=True) + + # Create a local repo and push to the bare remote + local_path = tmp_path / "local_repo" + local_repo = git.Repo.init(local_path) + + Path(local_path / "test.txt").write_text("test") + local_repo.index.add(["test.txt"]) + local_repo.index.commit("initial commit") + + local_repo.create_remote("origin", str(bare_path)) + local_repo.git.push("--set-upstream", "origin", local_repo.active_branch.name) + + result = git_default_branch(local_repo) + assert result == f"origin/{local_repo.active_branch.name}" + + shutil.rmtree(local_path) + shutil.rmtree(bare_path) + +def test_git_default_branch_custom_remote(tmp_path): + """Add a remote with a non-'origin' name, verify the remote parameter selects it.""" + bare_path = tmp_path / "custom_remote.git" + bare_repo = git.Repo.init(bare_path, bare=True) + + local_path = tmp_path / "local_repo" + local_repo = git.Repo.init(local_path) + + Path(local_path / "test.txt").write_text("test") + local_repo.index.add(["test.txt"]) + local_repo.index.commit("initial commit") + + local_repo.create_remote("upstream", str(bare_path)) + local_repo.git.push("--set-upstream", "upstream", local_repo.active_branch.name) + + result = git_default_branch(local_repo, remote="upstream") + assert result == f"upstream/{local_repo.active_branch.name}" + + shutil.rmtree(local_path) + shutil.rmtree(bare_path) + +def test_git_default_branch_undetectable(tmp_path): + """Repo with no remotes and no main/master branch; should raise ValueError.""" + repo_path = tmp_path / "no_default_repo" + repo = git.Repo.init(repo_path) + + # Create a commit on a non-standard branch name + repo.git.checkout("-b", "develop") + Path(repo_path / "test.txt").write_text("test") + repo.index.add(["test.txt"]) + repo.index.commit("initial commit") + + with pytest.raises(ValueError, match="Could not determine the default branch"): + git_default_branch(repo) + + shutil.rmtree(repo_path) + +def test_git_default_branch_revparse_fallback(tmp_path): + """When ls-remote fails but local ref cache exists, rev-parse fallback should work.""" + # Create a bare repo to act as the remote + bare_path = tmp_path / "bare_remote.git" + git.Repo.init(bare_path, bare=True) + + # Create a local repo and push to the bare remote + local_path = tmp_path / "local_repo" + local_repo = git.Repo.init(local_path) + + Path(local_path / "test.txt").write_text("test") + local_repo.index.add(["test.txt"]) + local_repo.index.commit("initial commit") + + active_branch = local_repo.active_branch.name + local_repo.create_remote("origin", str(bare_path)) + local_repo.git.push("--set-upstream", "origin", active_branch) + + # Populate local ref cache for origin/HEAD + local_repo.git.remote("set-head", "origin", "--auto") + + # Replace remote URL with an invalid path so ls-remote will fail + local_repo.git.remote("set-url", "origin", "/nonexistent/path") + + result = git_default_branch(local_repo) + assert result == f"origin/{active_branch}" + + shutil.rmtree(local_path) + shutil.rmtree(bare_path) + + +# Tests for git_remote + +def test_git_remote_no_remotes(test_repository): + """Repo with no remotes; verify empty output.""" + result = git_remote(test_repository) + assert result == "" + +def test_git_remote_with_remote(tmp_path): + """Repo with a remote configured; verify remote name and URL appear in output.""" + bare_path = tmp_path / "bare_remote.git" + git.Repo.init(bare_path, bare=True) + + local_path = tmp_path / "local_repo" + local_repo = git.Repo.init(local_path) + + Path(local_path / "test.txt").write_text("test") + local_repo.index.add(["test.txt"]) + local_repo.index.commit("initial commit") + + local_repo.create_remote("origin", str(bare_path)) + + result = git_remote(local_repo) + assert "origin" in result + assert str(bare_path) in result + + shutil.rmtree(local_path) + shutil.rmtree(bare_path)