|
2 | 2 |
|
3 | 3 | from pathlib import Path |
4 | 4 | from typing import Optional |
| 5 | +from zipfile import ZipFile |
5 | 6 |
|
6 | 7 | from airflow.hooks.S3_hook import S3Hook |
7 | 8 |
|
@@ -43,47 +44,77 @@ def get_dbt_profiles( |
43 | 44 | else: |
44 | 45 | local_profiles_file = Path(profiles_dir) / "profiles.yml" |
45 | 46 |
|
46 | | - self.log.info("Saving profiles file to: %s", local_profiles_file) |
47 | | - with open(local_profiles_file, "wb+") as f: |
48 | | - s3_object.download_fileobj(f) |
| 47 | + self.download_one_s3_object(local_profiles_file, s3_object) |
49 | 48 | return local_profiles_file |
50 | 49 |
|
| 50 | + def download_one_s3_object(self, target: Path, s3_object): |
| 51 | + """Download a single s3 object.""" |
| 52 | + self.log.info("Saving profiles file to: %s", target) |
| 53 | + |
| 54 | + with open(target, "wb+") as f: |
| 55 | + s3_object.download_fileobj(f) |
| 56 | + |
51 | 57 | def get_dbt_project( |
52 | 58 | self, s3_project_url: str, project_dir: Optional[str] = None |
53 | 59 | ) -> Path: |
54 | 60 | """Fetch all dbt project files from S3. |
55 | 61 |
|
56 | 62 | Fetches the dbt project files from the directory given by s3_project_url |
57 | | - and pulls them to project_dir. |
| 63 | + and pulls them to project_dir. However, if the URL points to a zip file, |
| 64 | + we assume it contains all the project files, and only download and unzip that |
| 65 | + instead. |
58 | 66 |
|
59 | 67 | Arguments: |
60 | | - s3_project_url: An S3 URL to a directory containing the dbt project files. |
61 | | - project_dir: An optional directory to download the S3 project files into. |
62 | | - If not provided, one will be created using the S3 URL. |
| 68 | + s3_project_url: An S3 URL to a directory containing the dbt project files |
| 69 | + or a zip file containing all project files. |
| 70 | + project_dir: An optional directory to download/unzip the S3 project files |
| 71 | + into. If not provided, one will be created using the S3 URL. |
63 | 72 |
|
64 | 73 | Returns: |
65 | 74 | A Path to the local directory containing the dbt project files. |
66 | 75 | """ |
67 | | - self.log.info("Downloading dbt project file from: %s", s3_project_url) |
| 76 | + self.log.info("Downloading dbt project files from: %s", s3_project_url) |
68 | 77 | bucket_name, key_prefix = self.parse_s3_url(s3_project_url) |
69 | | - if not key_prefix.endswith("/"): |
70 | | - key_prefix += "/" |
71 | | - s3_object_keys = self.list_keys(bucket_name=bucket_name, prefix=f"{key_prefix}") |
72 | 78 |
|
73 | 79 | if project_dir is None: |
74 | 80 | local_project_dir = Path(bucket_name) / key_prefix |
75 | 81 | else: |
76 | 82 | local_project_dir = Path(project_dir) |
77 | 83 |
|
78 | | - for s3_object_key in s3_object_keys: |
| 84 | + if key_prefix.endswith(".zip"): |
| 85 | + s3_object = self.get_key(key=key_prefix, bucket_name=bucket_name) |
| 86 | + target = local_project_dir / "dbt_project.zip" |
| 87 | + self.download_one_s3_object(target, s3_object) |
| 88 | + |
| 89 | + with ZipFile(target, "r") as zf: |
| 90 | + zf.extractall(local_project_dir) |
| 91 | + |
| 92 | + target.unlink() |
| 93 | + |
| 94 | + else: |
| 95 | + if not key_prefix.endswith("/"): |
| 96 | + key_prefix += "/" |
| 97 | + s3_object_keys = self.list_keys( |
| 98 | + bucket_name=bucket_name, prefix=f"{key_prefix}" |
| 99 | + ) |
| 100 | + |
| 101 | + self.download_many_s3_keys( |
| 102 | + bucket_name, s3_object_keys, local_project_dir, key_prefix |
| 103 | + ) |
| 104 | + |
| 105 | + return local_project_dir |
| 106 | + |
| 107 | + def download_many_s3_keys( |
| 108 | + self, bucket_name: str, s3_keys: list[str], target_dir: Path, prefix: str |
| 109 | + ): |
| 110 | + """Download multiple s3 keys.""" |
| 111 | + print(s3_keys) |
| 112 | + for s3_object_key in s3_keys: |
79 | 113 | s3_object = self.get_key(key=s3_object_key, bucket_name=bucket_name) |
80 | | - path_file = Path(s3_object_key).relative_to(f"{key_prefix}") |
81 | | - local_project_file = local_project_dir / path_file |
| 114 | + path_file = Path(s3_object_key).relative_to(prefix) |
| 115 | + local_project_file = target_dir / path_file |
82 | 116 | local_project_file.parent.mkdir(parents=True, exist_ok=True) |
83 | 117 |
|
84 | 118 | self.log.info("Saving %s to: %s", s3_object_key, local_project_file) |
85 | 119 |
|
86 | | - with open(local_project_file, "wb+") as f: |
87 | | - s3_object.download_fileobj(f) |
88 | | - |
89 | | - return local_project_dir |
| 120 | + self.download_one_s3_object(local_project_file, s3_object) |
0 commit comments