Skip to content

Commit b705b42

Browse files
committed
feat(operator): Add new operator for dbt docs generate
Support for dbt docs generate itself is pretty trivial as we only needed a new configuration and dbt docs generate does not take many configuration arguments. The biggest challenge of supporting dbt docs generate is solving the issue of running dbt docs generate in a temporary directory: there isn't much sense in executing dbt docs generate if we cannot use the results of the execution. The alternative of pushing the dbt artifacts to XCom is, of course, still available, but I wanted to use this an excuse to develop something else. So, we also added a push_dbt_project method to our S3 hook. This method is called right before exiting the temporary directory, and can be controlled with a push_dbt_project argument to the operator. The argument is by default True for the new DbtDocsGenerateOperator so that we can push the documentation files to S3. It is also True by default when running DbtDepsOperator since there was no way for us install dependencies. In the future, I'd like us to support more "dbt backends" besides S3: maybe support for pulling projects from Github? SSH and SFTP connections? Lots of possibilities!
1 parent e9f48ff commit b705b42

File tree

7 files changed

+419
-28
lines changed

7 files changed

+419
-28
lines changed

airflow_dbt_python/hooks/dbt.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from dbt.task.debug import DebugTask
2626
from dbt.task.deps import DepsTask
2727
from dbt.task.freshness import FreshnessTask
28+
from dbt.task.generate import GenerateTask
2829
from dbt.task.list import ListTask
2930
from dbt.task.parse import ParseTask
3031
from dbt.task.run import RunTask
@@ -321,6 +322,15 @@ class DepsTaskConfig(BaseConfig):
321322
which: str = dataclasses.field(default="deps", init=False)
322323

323324

325+
@dataclass
326+
class GenerateTaskConfig(SelectionConfig):
327+
"""Generate task arguments."""
328+
329+
cls: BaseTask = dataclasses.field(default=GenerateTask, init=False)
330+
compile: bool = True
331+
which: str = dataclasses.field(default="generate", init=False)
332+
333+
324334
@dataclass
325335
class ListTaskConfig(SelectionConfig):
326336
"""Dbt list task arguments."""
@@ -428,6 +438,7 @@ class ConfigFactory(FromStrMixin, Enum):
428438
CLEAN = CleanTaskConfig
429439
DEBUG = DebugTaskConfig
430440
DEPS = DepsTaskConfig
441+
GENERATE = GenerateTaskConfig
431442
LIST = ListTaskConfig
432443
PARSE = ParseTaskConfig
433444
RUN = RunTaskConfig

airflow_dbt_python/hooks/s3.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ class DbtS3Hook(S3Hook):
1515
all the files corresponding to a project.
1616
"""
1717

18-
def get_dbt_profiles(
18+
def pull_dbt_profiles(
1919
self, s3_profiles_url: str, profiles_dir: Optional[str] = None
2020
) -> Path:
21-
"""Fetch a dbt profiles file from S3.
21+
"""Pull a dbt profiles file from S3.
2222
23-
Fetches dbt profiles.yml file from the directory given by s3_profiles_url
24-
and pulls it to profiles_dir/profiles.yml.
23+
Pulls dbt profiles.yml file from the directory given by s3_profiles_url
24+
and saves it to profiles_dir/profiles.yml.
2525
2626
Args:
2727
s3_profiles_url: An S3 URL to a directory containing the dbt profiles file.
@@ -63,13 +63,13 @@ def download_one_s3_object(self, target: Path, s3_object):
6363
# exists.
6464
self.log.warning("A file with no name was found in S3 at %s", s3_object)
6565

66-
def get_dbt_project(
66+
def pull_dbt_project(
6767
self, s3_project_url: str, project_dir: Optional[str] = None
6868
) -> Path:
69-
"""Fetch all dbt project files from S3.
69+
"""Pull all dbt project files from S3.
7070
71-
Fetches the dbt project files from the directory given by s3_project_url
72-
and pulls them to project_dir. However, if the URL points to a zip file,
71+
Pulls the dbt project files from the directory given by s3_project_url
72+
and saves them to project_dir. However, if the URL points to a zip file,
7373
we assume it contains all the project files, and only download and unzip that
7474
instead.
7575
@@ -116,7 +116,7 @@ def get_dbt_project(
116116
def download_many_s3_keys(
117117
self, bucket_name: str, s3_keys: list[str], target_dir: Path, prefix: str
118118
):
119-
"""Download multiple s3 keys."""
119+
"""Download multiple S3 keys."""
120120
for s3_object_key in s3_keys:
121121
s3_object = self.get_key(key=s3_object_key, bucket_name=bucket_name)
122122
path_file = Path(s3_object_key).relative_to(prefix)
@@ -131,3 +131,35 @@ def download_many_s3_keys(
131131
local_project_file.parent.mkdir(parents=True, exist_ok=True)
132132

133133
self.download_one_s3_object(local_project_file, s3_object)
134+
135+
def push_dbt_project(self, s3_project_url: str, project_dir: str):
136+
"""Push a dbt project to S3."""
137+
bucket_name, key = self.parse_s3_url(s3_project_url)
138+
dbt_project_files = Path(project_dir).glob("**/*")
139+
140+
if key.endswith(".zip"):
141+
zip_file_path = Path(project_dir) / "dbt_project.zip"
142+
with ZipFile(zip_file_path, "w") as zf:
143+
for _file in dbt_project_files:
144+
zf.write(_file, arcname=_file.relative_to(project_dir))
145+
146+
self.load_file(
147+
zip_file_path, key=s3_project_url, bucket_name=bucket_name, replace=True
148+
)
149+
zip_file_path.unlink()
150+
151+
else:
152+
for _file in dbt_project_files:
153+
if _file.is_dir():
154+
continue
155+
156+
s3_key = f"s3://{bucket_name}/{key}{ _file.relative_to(project_dir)}"
157+
158+
self.load_file(
159+
filename=_file,
160+
key=s3_key,
161+
bucket_name=bucket_name,
162+
replace=True,
163+
)
164+
165+
self.log.info("Pushed dbt project to: %s", s3_project_url)

airflow_dbt_python/operators/dbt.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
# Extra features configuration
8989
s3_conn_id: str = "aws_default",
9090
do_xcom_push_artifacts: Optional[list[str]] = None,
91+
push_dbt_project: bool = False,
9192
**kwargs,
9293
) -> None:
9394
super().__init__(**kwargs)
@@ -132,6 +133,7 @@ def __init__(
132133

133134
self.s3_conn_id = s3_conn_id
134135
self.do_xcom_push_artifacts = do_xcom_push_artifacts
136+
self.push_dbt_project = push_dbt_project
135137
self._s3_hook = None
136138
self._dbt_hook = None
137139

@@ -222,7 +224,10 @@ def dbt_directory(self) -> Iterator[str]:
222224
"""Provides a temporary directory to execute dbt.
223225
224226
Creates a temporary directory for dbt to run in and prepares the dbt files
225-
if they need to be pulled from S3.
227+
if they need to be pulled from S3. If a S3 backend is being used, and
228+
self.push_dbt_project is True, before leaving the temporary directory, we push
229+
back the project to S3. Pushing back a project enables commands like deps or
230+
docs generate.
226231
227232
Yields:
228233
The temporary directory's name.
@@ -242,22 +247,29 @@ def dbt_directory(self) -> Iterator[str]:
242247

243248
yield tmp_dir
244249

250+
if (
251+
self.push_dbt_project is True
252+
and urlparse(str(store_project_dir)).scheme == "s3"
253+
):
254+
self.log.info("Pushing dbt project back to S3: %s", store_project_dir)
255+
self.s3_hook.push_dbt_project(store_project_dir, tmp_dir)
256+
245257
self.profiles_dir = store_profiles_dir
246258
self.project_dir = store_project_dir
247259

248260
def prepare_directory(self, tmp_dir: str):
249261
"""Prepares a dbt directory by pulling files from S3."""
250262
if urlparse(str(self.profiles_dir)).scheme == "s3":
251263
self.log.info("Fetching profiles.yml from S3: %s", self.profiles_dir)
252-
profiles_file_path = self.s3_hook.get_dbt_profiles(
264+
profiles_file_path = self.s3_hook.pull_dbt_profiles(
253265
self.profiles_dir,
254266
tmp_dir,
255267
)
256268
self.profiles_dir = str(profiles_file_path.parent) + "/"
257269

258270
if urlparse(str(self.project_dir)).scheme == "s3":
259271
self.log.info("Fetching dbt project from S3: %s", self.project_dir)
260-
project_dir_path = self.s3_hook.get_dbt_project(
272+
project_dir_path = self.s3_hook.pull_dbt_project(
261273
self.project_dir,
262274
tmp_dir,
263275
)
@@ -463,15 +475,34 @@ class DbtDepsOperator(DbtBaseOperator):
463475
https://docs.getdbt.com/reference/commands/deps.
464476
"""
465477

466-
def __init__(self, **kwargs) -> None:
478+
def __init__(self, push_dbt_project=True, **kwargs) -> None:
467479
super().__init__(**kwargs)
480+
self.push_dbt_project = push_dbt_project
468481

469482
@property
470483
def command(self) -> str:
471484
"""Return the deps command."""
472485
return "deps"
473486

474487

488+
class DbtDocsGenerateOperator(DbtBaseOperator):
489+
"""Executes a dbt docs generate command.
490+
491+
The documentation for the dbt command can be found here:
492+
https://docs.getdbt.com/reference/commands/cmd-docs.
493+
"""
494+
495+
def __init__(self, compile=True, push_dbt_project=True, **kwargs) -> None:
496+
super().__init__(**kwargs)
497+
self.compile = compile
498+
self.push_dbt_project = push_dbt_project
499+
500+
@property
501+
def command(self) -> str:
502+
"""Return the generate command."""
503+
return "generate"
504+
505+
475506
class DbtCleanOperator(DbtBaseOperator):
476507
"""Executes a dbt clean command.
477508
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Unit test module for running dbt docs generate with the DbtHook."""
2+
3+
4+
def test_dbt_docs_generate_task(hook, profiles_file, dbt_project_file, model_files):
5+
"""Test a dbt docs generate task."""
6+
import shutil
7+
8+
target_dir = dbt_project_file.parent / "target"
9+
if target_dir.exists() is True:
10+
shutil.rmtree(target_dir)
11+
assert target_dir.exists() is False
12+
13+
factory = hook.get_config_factory("generate")
14+
config = factory.create_config(
15+
project_dir=dbt_project_file.parent,
16+
profiles_dir=profiles_file.parent,
17+
)
18+
success, results = hook.run_dbt_task(config)
19+
20+
assert success is True
21+
assert results is not None
22+
assert target_dir.exists() is True
23+
24+
index = target_dir / "index.html"
25+
assert index.exists() is True
26+
27+
manifest = target_dir / "manifest.json"
28+
assert manifest.exists() is True
29+
30+
catalog = target_dir / "catalog.json"
31+
assert catalog.exists() is True
32+
33+
34+
def test_dbt_docs_generate_task_no_compile(
35+
hook, profiles_file, dbt_project_file, model_files
36+
):
37+
"""Test a dbt docs generate task without compiling."""
38+
import shutil
39+
40+
target_dir = dbt_project_file.parent / "target"
41+
if target_dir.exists() is True:
42+
shutil.rmtree(target_dir)
43+
assert target_dir.exists() is False
44+
45+
factory = hook.get_config_factory("generate")
46+
config = factory.create_config(
47+
project_dir=dbt_project_file.parent,
48+
profiles_dir=profiles_file.parent,
49+
compile=False,
50+
)
51+
success, results = hook.run_dbt_task(config)
52+
53+
assert success is True
54+
assert results is not None
55+
assert target_dir.exists() is True
56+
57+
index = target_dir / "index.html"
58+
assert index.exists() is True
59+
60+
manifest = target_dir / "manifest.json"
61+
assert manifest.exists() is False
62+
63+
catalog = target_dir / "catalog.json"
64+
assert catalog.exists() is True

0 commit comments

Comments
 (0)