Skip to content

Commit c58e6db

Browse files
committed
fix(git-hook): Pass branch as bytes
Dulwich expects repo branch to be bytes, not str. This seems to be a gap in the type hinting of dulwich, as other arguments to client.clone can be passed as str. Regardless, as a workaround, we can encode branch as bytes.
1 parent c318bbc commit c58e6db

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

airflow_dbt_python/hooks/fs/git.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,11 @@ def _download(
157157
client, path, branch = self.get_git_client_path(source)
158158

159159
client.clone(
160-
path, str(destination), mkdir=not destination.exists(), branch=branch
160+
path,
161+
str(destination),
162+
mkdir=not destination.exists(),
163+
# NOTE: Dulwich expects branch to be bytes if defined.
164+
branch=branch.encode("utf-8") if isinstance(branch, str) else branch,
161165
)
162166

163167
def get_git_client_path(self, url: URL) -> Tuple[GitClients, str, Optional[str]]:

tests/hooks/test_git_hook.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,16 @@ def repo_name():
294294
return "test/test_shop"
295295

296296

297+
@pytest.fixture
298+
def repo_branch(request) -> bytes | None:
299+
"""A configurable local git repo branch."""
300+
try:
301+
return request.param
302+
except AttributeError:
303+
# Default to dulwich's
304+
return None
305+
306+
297307
@pytest.fixture
298308
def repo_dir(tmp_path):
299309
"""A testing local git repo directory."""
@@ -303,9 +313,9 @@ def repo_dir(tmp_path):
303313

304314

305315
@pytest.fixture
306-
def repo(repo_dir, dbt_project_file, test_files, profiles_file):
316+
def repo(repo_dir, dbt_project_file, test_files, profiles_file, repo_branch):
307317
"""Initialize a git repo with some dbt test files."""
308-
repo = Repo.init(repo_dir)
318+
repo = Repo.init(repo_dir, default_branch=repo_branch)
309319
shutil.copyfile(dbt_project_file, repo_dir / "dbt_project.yml")
310320
repo.stage("dbt_project.yml")
311321

@@ -364,6 +374,31 @@ def test_download_dbt_project_with_local_server(
364374
assert_dir_contents(local_repo_path, expected, exact=False)
365375

366376

377+
@no_git_local_server
378+
@pytest.mark.parametrize("repo_branch", ["test-branch".encode("utf-8")], indirect=True)
379+
def test_download_dbt_project_with_custom_branch_from_local_server(
380+
tmp_path, git_server, repo_name, assert_dir_contents, repo_branch
381+
):
382+
"""Test downloading a dbt project from a local git server."""
383+
local_path = tmp_path / "local"
384+
fs_hook = DbtGitFSHook()
385+
server_address, server_port = git_server
386+
source = URL(
387+
f"git://{server_address}:{server_port}/{repo_name}@{repo_branch.decode('utf-8')}"
388+
)
389+
local_repo_path = fs_hook.download_dbt_project(source, local_path)
390+
391+
expected = [
392+
URL(local_repo_path / "dbt_project.yml"),
393+
URL(local_repo_path / "models" / "a_model.sql"),
394+
URL(local_repo_path / "models" / "another_model.sql"),
395+
URL(local_repo_path / "seeds" / "a_seed.csv"),
396+
]
397+
398+
assert local_repo_path.exists()
399+
assert_dir_contents(local_repo_path, expected, exact=False)
400+
401+
367402
@pytest.fixture
368403
def pre_run(hook, repo_dir):
369404
"""Fixture to run a dbt run task."""

0 commit comments

Comments
 (0)