Skip to content

Commit 543ad55

Browse files
committed
feat: Provide env_vars to update environment variables
Environment variables are used by dbt to template configuration files. Although users can now modify the environment of Airflow workers, there is no interface for them to do it in airflow-dbt-python. This commit adds an env_vars argument that can take a dictionary of environment variables to set during dbt execution. This can be useful for folks who wish to keep their profiles.yml templates.
1 parent c2907b4 commit 543ad55

File tree

10 files changed

+276
-36
lines changed

10 files changed

+276
-36
lines changed

airflow_dbt_python/hooks/dbt.py

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def run_dbt_task(
185185
delete_before_upload: bool = False,
186186
replace_on_upload: bool = False,
187187
artifacts: Optional[Iterable[str]] = None,
188+
env_vars: Optional[Dict[str, Any]] = None,
188189
**kwargs,
189190
) -> DbtTaskResult:
190191
"""Run a dbt task with a given configuration and return the results.
@@ -208,6 +209,7 @@ def run_dbt_task(
208209
upload_dbt_project=upload_dbt_project,
209210
delete_before_upload=delete_before_upload,
210211
replace_on_upload=replace_on_upload,
212+
env_vars=env_vars,
211213
) as dbt_dir:
212214
config.dbt_task.pre_init_hook(config)
213215
self.ensure_profiles(config.profiles_dir)
@@ -272,6 +274,7 @@ def dbt_directory(
272274
upload_dbt_project: bool = False,
273275
delete_before_upload: bool = False,
274276
replace_on_upload: bool = False,
277+
env_vars: Optional[Dict[str, Any]] = None,
275278
) -> Iterator[str]:
276279
"""Provides a temporary directory to execute dbt.
277280
@@ -284,43 +287,46 @@ def dbt_directory(
284287
Yields:
285288
The temporary directory's name.
286289
"""
290+
from airflow_dbt_python.utils.env import update_environment
291+
287292
store_profiles_dir = config.profiles_dir
288293
store_project_dir = config.project_dir
289294

290-
with TemporaryDirectory(prefix="airflow_tmp") as tmp_dir:
291-
self.log.info("Initializing temporary directory: %s", tmp_dir)
292-
293-
try:
294-
project_dir, profiles_dir = self.prepare_directory(
295-
tmp_dir,
296-
store_project_dir,
297-
store_profiles_dir,
298-
)
299-
except Exception as e:
300-
raise AirflowException(
301-
"Failed to prepare temporary directory for dbt execution"
302-
) from e
303-
304-
config.project_dir = project_dir
305-
config.profiles_dir = profiles_dir
306-
307-
if getattr(config, "state", None) is not None:
308-
state = Path(getattr(config, "state", ""))
309-
# Since we are running in a temporary directory, we need to make
310-
# state paths relative to this temporary directory.
311-
if not state.is_absolute():
312-
setattr(config, "state", str(Path(tmp_dir) / state))
313-
314-
yield tmp_dir
315-
316-
if upload_dbt_project is True:
317-
self.log.info("Uploading dbt project to: %s", store_project_dir)
318-
self.upload_dbt_project(
319-
tmp_dir,
320-
store_project_dir,
321-
replace=replace_on_upload,
322-
delete_before=delete_before_upload,
323-
)
295+
with update_environment(env_vars):
296+
with TemporaryDirectory(prefix="airflow_tmp") as tmp_dir:
297+
self.log.info("Initializing temporary directory: %s", tmp_dir)
298+
299+
try:
300+
project_dir, profiles_dir = self.prepare_directory(
301+
tmp_dir,
302+
store_project_dir,
303+
store_profiles_dir,
304+
)
305+
except Exception as e:
306+
raise AirflowException(
307+
"Failed to prepare temporary directory for dbt execution"
308+
) from e
309+
310+
config.project_dir = project_dir
311+
config.profiles_dir = profiles_dir
312+
313+
if getattr(config, "state", None) is not None:
314+
state = Path(getattr(config, "state", ""))
315+
# Since we are running in a temporary directory, we need to make
316+
# state paths relative to this temporary directory.
317+
if not state.is_absolute():
318+
setattr(config, "state", str(Path(tmp_dir) / state))
319+
320+
yield tmp_dir
321+
322+
if upload_dbt_project is True:
323+
self.log.info("Uploading dbt project to: %s", store_project_dir)
324+
self.upload_dbt_project(
325+
tmp_dir,
326+
store_project_dir,
327+
replace=replace_on_upload,
328+
delete_before=delete_before_upload,
329+
)
324330

325331
config.profiles_dir = store_profiles_dir
326332
config.project_dir = store_project_dir

airflow_dbt_python/operators/dbt.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
from dataclasses import asdict, is_dataclass
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, Optional, Union
8+
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
99

1010
from airflow.exceptions import AirflowException
1111
from airflow.models.baseoperator import BaseOperator
@@ -24,6 +24,7 @@
2424
"target",
2525
"state",
2626
"vars",
27+
"env_vars",
2728
]
2829

2930

@@ -95,6 +96,7 @@ def __init__(
9596
upload_dbt_project: bool = False,
9697
delete_before_upload: bool = False,
9798
replace_on_upload: bool = False,
99+
env_vars: Optional[Dict[str, Any]] = None,
98100
**kwargs,
99101
) -> None:
100102
super().__init__(**kwargs)
@@ -145,6 +147,7 @@ def __init__(
145147
self.upload_dbt_project = upload_dbt_project
146148
self.delete_before_upload = delete_before_upload
147149
self.replace_on_upload = replace_on_upload
150+
self.env_vars = env_vars
148151

149152
self._dbt_hook: Optional[DbtHook] = None
150153

airflow_dbt_python/utils/env.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""Provides utilities to interact with environment variables."""
2+
import copy
3+
import os
4+
from contextlib import contextmanager
5+
from typing import Any, Dict, Optional
6+
7+
8+
@contextmanager
9+
def update_environment(env_vars: Optional[Dict[str, Any]] = None):
10+
"""Update current environment with env_vars and restore afterwards."""
11+
if not env_vars:
12+
# Nothing to update or restore afterwards, so we return early
13+
yield os.environ
14+
return
15+
16+
restore_env = copy.deepcopy(os.environ)
17+
os.environ.update({k: str(v) for k, v in env_vars.items()})
18+
19+
try:
20+
yield os.environ
21+
finally:
22+
os.environ = restore_env

tests/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,21 @@ def profiles_file(tmp_path_factory, database):
185185
return p
186186

187187

188+
@pytest.fixture(scope="session")
189+
def profiles_file_with_env(tmp_path_factory):
190+
"""Create a profiles.yml file that relies on env variables for db connection."""
191+
p = tmp_path_factory.mktemp(".dbt_with_env") / "profiles.yml"
192+
profiles_content = PROFILES.format(
193+
host="\"{{ env_var('DBT_HOST') }}\"",
194+
user="\"{{ env_var('DBT_USER') }}\"",
195+
port="\"{{ env_var('DBT_PORT') | int }}\"",
196+
password="\"{{ env_var('DBT_ENV_SECRET_PASSWORD') }}\"",
197+
dbname="\"{{ env_var('DBT_DBNAME') }}\"",
198+
)
199+
p.write_text(profiles_content)
200+
return p
201+
202+
188203
@pytest.fixture(scope="session")
189204
def airflow_conns(database):
190205
"""Create Airflow connections for testing.

tests/hooks/dbt/test_dbt_hook_base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Unit test module for the dbt hook base class."""
2+
import os
23
from pathlib import Path
34

45
import pytest
@@ -304,6 +305,28 @@ def test_dbt_directory_with_no_state(
304305
assert getattr(config, "state", None) is None
305306

306307

308+
def test_dbt_directory_with_env_vars(hook, profiles_file_with_env, dbt_project_file):
309+
"""Test dbt_directory sets environment variables."""
310+
config = RunTaskConfig(
311+
project_dir=dbt_project_file.parent,
312+
profiles_dir=profiles_file_with_env.parent,
313+
state="target/",
314+
)
315+
316+
assert "TEST_ENVAR0" not in os.environ
317+
assert "TEST_ENVAR1" not in os.environ
318+
319+
env_vars = {"TEST_ENVAR0": 1, "TEST_ENVAR1": "abc"}
320+
321+
with hook.dbt_directory(config, env_vars=env_vars) as tmp_dir:
322+
assert Path(tmp_dir).exists()
323+
assert os.environ.get("TEST_ENVAR0") == "1"
324+
assert os.environ.get("TEST_ENVAR1") == "abc"
325+
326+
assert "TEST_ENVAR0" not in os.environ
327+
assert "TEST_ENVAR1" not in os.environ
328+
329+
307330
@no_s3_remote
308331
def test_dbt_base_dbt_directory_changed_to_s3(
309332
dbt_project_file, profiles_file, s3_bucket, s3_hook, hook

tests/hooks/dbt/test_dbt_run.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Unit test module for running dbt run with the DbtHook."""
2+
import os
23
from pathlib import Path
34

45
import pytest
@@ -26,6 +27,36 @@ def test_dbt_run_task(hook, profiles_file, dbt_project_file, model_files):
2627
assert run_result.node.unique_id == f"model.test.model_{index}"
2728

2829

30+
def test_dbt_run_task_with_env_vars_profile(
31+
hook, profiles_file_with_env, dbt_project_file, model_files, database
32+
):
33+
"""Test a dbt run task that can render env variables in profile."""
34+
env = {
35+
"DBT_HOST": database.host,
36+
"DBT_USER": database.user,
37+
"DBT_PORT": str(database.port),
38+
"DBT_ENV_SECRET_PASSWORD": database.password,
39+
"DBT_DBNAME": database.dbname,
40+
}
41+
42+
result = hook.run_dbt_task(
43+
"run",
44+
project_dir=dbt_project_file.parent,
45+
profiles_dir=profiles_file_with_env.parent,
46+
select=[str(m.stem) for m in model_files],
47+
env_vars=env,
48+
)
49+
50+
assert result.success is True
51+
52+
assert len(result.run_results) == 3
53+
54+
# Start from 2 as model_1 is ephemeral, and ephemeral models are not built.
55+
for index, run_result in enumerate(result.run_results, start=2):
56+
assert run_result.status == RunStatus.Success
57+
assert run_result.node.unique_id == f"model.test.model_{index}"
58+
59+
2960
def test_dbt_run_task_one_file(hook, profiles_file, dbt_project_file, model_files):
3061
"""Test a dbt run task for only one file."""
3162
result = hook.run_dbt_task(

tests/hooks/dbt/test_dbt_seed.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,35 @@ def test_dbt_seed_task(profiles_file, dbt_project_file, seed_files):
1818

1919
assert result.success is True
2020

21+
assert result.run_results is not None
22+
assert len(result.run_results) == 2
23+
for index, run_result in enumerate(result.run_results, start=1):
24+
assert run_result.status == RunStatus.Success
25+
assert run_result.node.unique_id == f"seed.test.seed_{index}"
26+
27+
28+
def test_dbt_seed_task_with_env_vars_profile(
29+
hook, profiles_file_with_env, dbt_project_file, model_files, database
30+
):
31+
"""Test a dbt seed task that can render env variables in profile."""
32+
env = {
33+
"DBT_HOST": database.host,
34+
"DBT_USER": database.user,
35+
"DBT_PORT": str(database.port),
36+
"DBT_ENV_SECRET_PASSWORD": database.password,
37+
"DBT_DBNAME": database.dbname,
38+
}
39+
40+
result = hook.run_dbt_task(
41+
"run",
42+
project_dir=dbt_project_file.parent,
43+
profiles_dir=profiles_file_with_env.parent,
44+
select=[str(m.stem) for m in model_files],
45+
env_vars=env,
46+
)
47+
48+
assert result.success is True
49+
2150
assert len(result.run_results) == 2
2251
for index, run_result in enumerate(result.run_results, start=1):
2352
assert run_result.status == RunStatus.Success

tests/operators/test_dbt_run.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,45 @@ def test_dbt_run_models(profiles_file, dbt_project_file, model_files, logs_dir):
112112
)
113113

114114

115+
def test_dbt_run_models_with_env_vars(
116+
profiles_file_with_env, dbt_project_file, model_files, logs_dir, database
117+
):
118+
"""Test execution of DbtRunOperator with all models using env vars in profile."""
119+
env = {
120+
"DBT_HOST": database.host,
121+
"DBT_USER": database.user,
122+
"DBT_PORT": str(database.port),
123+
"DBT_ENV_SECRET_PASSWORD": database.password,
124+
"DBT_DBNAME": database.dbname,
125+
}
126+
127+
op = DbtRunOperator(
128+
task_id="dbt_task",
129+
project_dir=dbt_project_file.parent,
130+
profiles_dir=profiles_file_with_env.parent,
131+
models=[str(m.stem) for m in model_files],
132+
do_xcom_push=True,
133+
debug=True,
134+
env_vars=env,
135+
)
136+
137+
execution_results = op.execute({})
138+
run_result = execution_results["results"][0]
139+
140+
assert run_result["status"] == RunStatus.Success
141+
142+
log_file = logs_dir / "dbt.log"
143+
assert log_file.exists()
144+
145+
with open(log_file) as f:
146+
logs = f.read()
147+
148+
assert (
149+
"OK created view model public.model_4" in logs
150+
or "OK created sql view model public.model_4" in logs
151+
)
152+
153+
115154
def test_dbt_run_models_full_refresh(profiles_file, dbt_project_file, model_files):
116155
"""Test dbt run operator with all model files and full-refresh."""
117156
op = DbtRunOperator(

tests/utils/test_configs.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Unit test module for dbt task configuration utilities."""
2+
import os
3+
from unittest.mock import patch
4+
25
import pytest
3-
from dbt.exceptions import DbtProfileError
6+
from dbt.exceptions import DbtProfileError, EnvVarMissingError
47
from dbt.task.build import BuildTask
58
from dbt.task.compile import CompileTask
69
from dbt.task.debug import DebugTask
@@ -252,6 +255,40 @@ def test_base_config_create_dbt_profile(hook, profiles_file, dbt_project_file):
252255
assert target["type"] == "postgres"
253256

254257

258+
def test_base_config_create_dbt_profile_with_env_vars(
259+
profiles_file_with_env, dbt_project_file, database
260+
):
261+
"""Test the create_dbt_profile with a profiles file that contains env vars."""
262+
config = BaseConfig(
263+
project_dir=dbt_project_file.parent,
264+
profiles_dir=profiles_file_with_env.parent,
265+
)
266+
267+
with pytest.raises(EnvVarMissingError):
268+
# No environment set yet, we should fail.
269+
profile = config.create_dbt_profile()
270+
271+
env = {
272+
"DBT_HOST": database.host,
273+
"DBT_USER": database.user,
274+
"DBT_PORT": str(database.port),
275+
"DBT_ENV_SECRET_PASSWORD": database.password,
276+
"DBT_DBNAME": database.dbname,
277+
}
278+
279+
with patch.dict(os.environ, env):
280+
profile = config.create_dbt_profile()
281+
assert profile.credentials.password == database.password
282+
283+
target = profile.to_target_dict()
284+
assert target["name"] == "test"
285+
assert target["type"] == "postgres"
286+
assert target["host"] == database.host
287+
assert target["user"] == database.user
288+
assert target["port"] == database.port
289+
assert target["dbname"] == database.dbname
290+
291+
255292
def test_base_config_create_dbt_profile_with_extra_target(
256293
hook, profiles_file, dbt_project_file, airflow_conns
257294
):

0 commit comments

Comments
 (0)