From 1d6881a91967b665921375dbc21859caaaaecc6b Mon Sep 17 00:00:00 2001 From: Ilyas Mohamed Date: Tue, 17 Oct 2023 11:57:29 +0100 Subject: [PATCH] Add functionality to be able to push dbt artifacts using remote --- airflow_dbt_python/hooks/dbt.py | 38 +++++++++++++++++++++++++++++ airflow_dbt_python/hooks/remote.py | 36 ++++++++++++++++++++++++++- airflow_dbt_python/operators/dbt.py | 9 +++++++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/airflow_dbt_python/hooks/dbt.py b/airflow_dbt_python/hooks/dbt.py index dee971ab..be7b6f9d 100644 --- a/airflow_dbt_python/hooks/dbt.py +++ b/airflow_dbt_python/hooks/dbt.py @@ -200,6 +200,30 @@ def upload_dbt_project( project_dir, destination, replace=replace, delete_before=delete_before ) + def upload_dbt_artifacts( + self, + project_dir: URLLike, + artifacts: Iterable[str], + destination: URLLike, + replace: bool = False, + delete_before: bool = False, + ) -> None: + """Push dbt artifacts from a given project_dir. + + This operation is delegated to a DbtRemoteHook. An optional connection id is + supported for remotes that require it. 
+ """ + scheme = urlparse(str(destination)).scheme + remote = self.get_remote(scheme, self.project_conn_id) + + return remote.upload_dbt_artifacts( + project_dir, + artifacts, + destination, + replace=replace, + delete_before=delete_before, + ) + def run_dbt_task( self, command: str, @@ -207,6 +231,8 @@ def run_dbt_task( delete_before_upload: bool = False, replace_on_upload: bool = False, artifacts: Optional[Iterable[str]] = None, + upload_dbt_artifacts: Optional[Iterable[str]] = None, + upload_dbt_artifacts_destination: Optional[URLLike] = None, env_vars: Optional[Dict[str, Any]] = None, **kwargs, ) -> DbtTaskResult: @@ -281,6 +307,18 @@ def run_dbt_task( saved_artifacts[artifact] = json_artifact + if ( + upload_dbt_artifacts is not None + and upload_dbt_artifacts_destination is not None + ): + self.upload_dbt_artifacts( + dbt_dir, + upload_dbt_artifacts, + upload_dbt_artifacts_destination, + replace_on_upload, + delete_before_upload, + ) + return DbtTaskResult(success, results, saved_artifacts) def get_dbt_task_config(self, command: str, **config_kwargs) -> BaseConfig: diff --git a/airflow_dbt_python/hooks/remote.py b/airflow_dbt_python/hooks/remote.py index 88f7de7c..d2d6b733 100644 --- a/airflow_dbt_python/hooks/remote.py +++ b/airflow_dbt_python/hooks/remote.py @@ -6,7 +6,7 @@ """ from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, Type +from typing import Iterable, Optional, Type from airflow.utils.log.logging_mixin import LoggingMixin @@ -137,6 +137,40 @@ def upload_dbt_project( if destination_url.is_archive(): source_url.unlink() + def upload_dbt_artifacts( + self, + source: URLLike, + artifacts: Iterable[str], + destination: URLLike, + replace: bool = False, + delete_before: bool = False, + ): + """Upload specific dbt artifacts from a given source. + + Args: + source: URLLike to a directory containing a dbt project. + artifacts: A list of artifacts to upload. 
+ destination: URLLike to a directory where the dbt artifacts will be stored. + replace: Flag to indicate whether to replace existing files. + delete_before: Flag to indicate whether to clear any existing files before + uploading the dbt artifacts. + """ + source_url = URL(source) / "target" + destination_url = URL(destination) + + for artifact in artifacts: + artifact_path = source_url / artifact + self.log.info( + "Uploading dbt artifact from %s to %s", artifact_path, destination + ) + + self.upload( + source=artifact_path, + destination=destination_url / artifact, + replace=replace, + delete_before=delete_before, + ) + def get_remote(scheme: str, conn_id: Optional[str] = None) -> DbtRemoteHook: """Get a DbtRemoteHook as long as the scheme is supported. diff --git a/airflow_dbt_python/operators/dbt.py b/airflow_dbt_python/operators/dbt.py index 85d4a3bd..0add91d3 100644 --- a/airflow_dbt_python/operators/dbt.py +++ b/airflow_dbt_python/operators/dbt.py @@ -11,6 +11,8 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.xcom import XCOM_RETURN_KEY +from airflow_dbt_python.utils.url import URLLike + if TYPE_CHECKING: from dbt.contracts.results import RunResult @@ -47,6 +49,9 @@ class DbtBaseOperator(BaseOperator): log_cache_events: Flag to enable logging of cache events. s3_conn_id: An s3 Airflow connection ID to use when pulling dbt files from s3. do_xcom_push_artifacts: A list of dbt artifacts to XCom push. + upload_dbt_artifacts: A list of dbt artifacts to upload. + upload_dbt_artifacts_destination: Destination to upload dbt artifacts to. + This value also determines the remote to use when uploading the artifacts. 
""" template_fields = base_template_fields @@ -93,6 +98,8 @@ def __init__( profiles_conn_id: Optional[str] = None, project_conn_id: Optional[str] = None, do_xcom_push_artifacts: Optional[list[str]] = None, + upload_artifacts: Optional[list[str]] = None, + upload_artifacts_destination: Optional[URLLike] = None, upload_dbt_project: bool = False, delete_before_upload: bool = False, replace_on_upload: bool = False, @@ -144,6 +151,8 @@ def __init__( self.profiles_conn_id = profiles_conn_id self.project_conn_id = project_conn_id self.do_xcom_push_artifacts = do_xcom_push_artifacts + self.upload_dbt_artifacts = upload_artifacts + self.upload_dbt_artifacts_destination = upload_artifacts_destination self.upload_dbt_project = upload_dbt_project self.delete_before_upload = delete_before_upload self.replace_on_upload = replace_on_upload