Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions airflow_dbt_python/hooks/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,39 @@ def upload_dbt_project(
project_dir, destination, replace=replace, delete_before=delete_before
)

def upload_dbt_artifacts(
self,
project_dir: URLLike,
artifacts: Iterable[str],
destination: URLLike,
replace: bool = False,
delete_before: bool = False,
) -> None:
"""Push dbt artifacts from a given project_dir.

This operation is delegated to a DbtRemoteHook. An optional connection id is
supported for remotes that require it.
"""
scheme = urlparse(str(destination)).scheme
remote = self.get_remote(scheme, self.project_conn_id)

return remote.upload_dbt_artifacts(
project_dir,
artifacts,
destination,
replace=replace,
delete_before=delete_before,
)

def run_dbt_task(
self,
command: str,
upload_dbt_project: bool = False,
delete_before_upload: bool = False,
replace_on_upload: bool = False,
artifacts: Optional[Iterable[str]] = None,
upload_dbt_artifacts: Optional[Iterable[str]] = None,
upload_dbt_artifacts_destination: Optional[URLLike] = None,
env_vars: Optional[Dict[str, Any]] = None,
**kwargs,
) -> DbtTaskResult:
Expand Down Expand Up @@ -281,6 +307,18 @@ def run_dbt_task(

saved_artifacts[artifact] = json_artifact

if (
upload_dbt_artifacts is not None
and upload_dbt_artifacts_destination is not None
):
self.upload_dbt_artifacts(
dbt_dir,
upload_dbt_artifacts,
upload_dbt_artifacts_destination,
replace_on_upload,
delete_before_upload,
)

return DbtTaskResult(success, results, saved_artifacts)

def get_dbt_task_config(self, command: str, **config_kwargs) -> BaseConfig:
Expand Down
36 changes: 35 additions & 1 deletion airflow_dbt_python/hooks/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Type
from typing import Iterable, Optional, Type

from airflow.utils.log.logging_mixin import LoggingMixin

Expand Down Expand Up @@ -137,6 +137,40 @@ def upload_dbt_project(
if destination_url.is_archive():
source_url.unlink()

def upload_dbt_artifacts(
self,
source: URLLike,
artifacts: Iterable[str],
destination: URLLike,
replace: bool = False,
delete_before: bool = False,
):
"""Upload specific dbt artifacts from a given source.

Args:
source: URLLike to a directory containing a dbt project.
artifacts: A list of artifacts to upload.
destination: URLLike to a directory where the dbt artifacts will be stored.
replace: Flag to indicate whether to replace existing files.
delete_before: Flag to indicate wheter to clear any existing files before
uploading the dbt project.
"""
source_url = URL(source) / "target"
destination_url = URL(destination)

for artifact in artifacts:
artifact_path = source_url / artifact
self.log.info(
"Uploading dbt artifact from %s to %s", artifact_path, destination
)

self.upload(
source=artifact_path,
destination=destination_url / artifact,
replace=replace,
delete_before=delete_before,
)


def get_remote(scheme: str, conn_id: Optional[str] = None) -> DbtRemoteHook:
"""Get a DbtRemoteHook as long as the scheme is supported.
Expand Down
9 changes: 9 additions & 0 deletions airflow_dbt_python/operators/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.xcom import XCOM_RETURN_KEY

from airflow_dbt_python.utils.url import URLLike

if TYPE_CHECKING:
from dbt.contracts.results import RunResult

Expand Down Expand Up @@ -47,6 +49,9 @@ class DbtBaseOperator(BaseOperator):
log_cache_events: Flag to enable logging of cache events.
s3_conn_id: An s3 Airflow connection ID to use when pulling dbt files from s3.
do_xcom_push_artifacts: A list of dbt artifacts to XCom push.
upload_dbt_artifacts: A list of dbt artifacts to upload.
upload_dbt_artifacts_destination: Destination to upload dbt artifacts to.
This value also determines the remote to use when uploading the artifacts.
"""

template_fields = base_template_fields
Expand Down Expand Up @@ -93,6 +98,8 @@ def __init__(
profiles_conn_id: Optional[str] = None,
project_conn_id: Optional[str] = None,
do_xcom_push_artifacts: Optional[list[str]] = None,
upload_artifacts: Optional[list[str]] = None,
upload_artifacts_destination: Optional[URLLike] = None,
upload_dbt_project: bool = False,
delete_before_upload: bool = False,
replace_on_upload: bool = False,
Expand Down Expand Up @@ -144,6 +151,8 @@ def __init__(
self.profiles_conn_id = profiles_conn_id
self.project_conn_id = project_conn_id
self.do_xcom_push_artifacts = do_xcom_push_artifacts
self.upload_dbt_artifacts = upload_artifacts
self.upload_dbt_artifacts_destination = upload_artifacts_destination
self.upload_dbt_project = upload_dbt_project
self.delete_before_upload = delete_before_upload
self.replace_on_upload = replace_on_upload
Expand Down