Skip to content

Commit 45c7a27

Browse files
authored
Merge pull request #92 from oracle/dev/async_python_models
Optimized Python Models
2 parents 2c38043 + 6c6cd59 commit 45c7a27

File tree

14 files changed

+265
-57
lines changed

14 files changed

+265
-57
lines changed

.github/workflows/oracle-xe-adapter-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
- name: Install dbt-oracle with core dependencies
4949
run: |
5050
python -m pip install --upgrade pip
51-
pip install pytest dbt-tests-adapter==1.5.1
51+
pip install pytest dbt-tests-adapter==1.5.2
5252
pip install -r requirements.txt
5353
pip install -e .
5454

dbt/adapters/oracle/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17-
version = "1.5.1"
17+
version = "1.5.2"

dbt/adapters/oracle/connections.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ class OracleAdapterCredentials(Credentials):
109109
retry_count: Optional[int] = 1
110110
retry_delay: Optional[int] = 3
111111

112-
# Fetch an auth token to run Python UDF
113-
oml_auth_token_uri: Optional[str] = None
112+
# Base URL for ADB-S OML REST API
113+
oml_cloud_service_url: Optional[str] = None
114114

115115

116116
_ALIASES = {
@@ -136,7 +136,7 @@ def _connection_keys(self) -> Tuple[str]:
136136
'service', 'connection_string',
137137
'shardingkey', 'supershardingkey',
138138
'cclass', 'purity', 'retry_count',
139-
'retry_delay', 'oml_auth_token_uri'
139+
'retry_delay', 'oml_cloud_service_url'
140140
)
141141

142142
@classmethod

dbt/adapters/oracle/impl.py

Lines changed: 18 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
from dbt.utils import filter_null_values
4343

4444
from dbt.adapters.oracle.keyword_catalog import KEYWORDS
45+
from dbt.adapters.oracle.python_submissions import OracleADBSPythonJob
46+
from dbt.adapters.oracle.connections import AdapterResponse
4547

4648
logger = AdapterLogger("oracle")
4749

@@ -367,49 +369,27 @@ def get_oml_auth_token(self) -> str:
367369

368370
def submit_python_job(self, parsed_model: dict, compiled_code: str):
    """Submit a user defined Python function and run it through the OML REST API.

    The compiled model code is first stored in the database as a named
    script (``sys.pyqScriptCreate``), then executed by an
    ``OracleADBSPythonJob`` and finally removed again
    (``sys.pyqScriptDrop``).

    https://docs.oracle.com/en/database/oracle/machine-learning/oml4py/1/mlepe/op-py-scripts-v1-do-eval-scriptname-post.html

    Args:
        parsed_model: parsed model node; ``alias`` and ``config`` are read.
        compiled_code: Python source of the model to register as a script.

    Returns:
        The adapter response of the statement which drops the script.
    """
    identifier = parsed_model["alias"]
    py_q_script_name = f"{identifier}_dbt_py_script"
    # NOTE(review): compiled_code is interpolated inside a single-quoted
    # PL/SQL literal; presumably quotes are escaped upstream - confirm.
    py_q_create_script = f"""
          BEGIN
               sys.pyqScriptCreate('{py_q_script_name}', '{compiled_code.strip()}', FALSE, TRUE);
          END;
    """
    py_q_drop_script = f"""
         BEGIN
              sys.pyqScriptDrop('{py_q_script_name}');
         END;
    """
    self.execute(sql=py_q_create_script)
    try:
        python_job = OracleADBSPythonJob(parsed_model=parsed_model,
                                         credential=self.config.credentials)
        python_job()
    finally:
        # Always remove the stored script, even when the job fails, so that
        # failed runs do not leave scripts behind in the script repository.
        response, _ = self.execute(sql=py_q_drop_script)
    logger.info(response)
    return response
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
"""
2+
Copyright (c) 2023, Oracle and/or its affiliates.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
import datetime
17+
import http
18+
import json
19+
from typing import Dict
20+
21+
import requests
22+
import time
23+
24+
import dbt.exceptions
25+
from dbt.adapters.oracle import OracleAdapterCredentials
26+
from dbt.events import AdapterLogger
27+
from dbt.ui import red, green
28+
29+
# ADB-S OML Rest API minimum timeout is 1800 seconds
30+
DEFAULT_TIMEOUT_IN_SECONDS = 1800
31+
DEFAULT_DELAY_BETWEEN_POLL_IN_SECONDS = 2
32+
33+
OMLUSERS_OAUTH_API = "/omlusers/api/oauth2/v1/token"
34+
OML_DO_EVAL_API = "/oml/api/py-scripts/v1/do-eval/{script_name}"
35+
36+
logger = AdapterLogger("oracle")
37+
38+
39+
class OracleOML4PYClient:
    """Small REST client for the OML service reachable under a base URL.

    A bearer token is obtained from the token endpoint
    (base URL + ``OMLUSERS_OAUTH_API``), cached together with its expiry
    time and added to the headers of every call made via :meth:`request`.
    """

    def __init__(self, oml_cloud_service_url, username, password):
        # The URL is optional in the credentials, so fail early with a clear
        # message instead of a TypeError on string concatenation below.
        if not oml_cloud_service_url:
            raise ValueError("oml_cloud_service_url is required to run Python models")
        # Strip trailing slashes so joining with "/..." paths never yields "//"
        self.base_url = oml_cloud_service_url.rstrip("/")
        self._username = username
        self._password = password
        self.token = None
        self.token_expires_at = None
        self.token_url = self.base_url + OMLUSERS_OAUTH_API
        self._session = requests.Session()

    @staticmethod
    def _utcnow():
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.datetime.now(datetime.timezone.utc)

    @property
    def session(self):
        """The ``requests`` session shared by all calls of this client."""
        return self._session

    def get_token(self):
        """Return a usable access token, fetching or refreshing it if needed.

        - expiry known and less than one minute away -> refresh the token
        - otherwise, a cached token exists -> return it
        - otherwise -> request a new token with username and password
        """
        if self.token_expires_at and self.token_expires_at - self._utcnow() < datetime.timedelta(minutes=1):
            return self._get_token(grant_type="refresh_token")
        elif self.token:  # Token is valid
            return self.token
        else:  # Generate a new token
            return self._get_token(grant_type="password")

    def _get_token(self, grant_type="password"):
        """POST to ``self.token_url`` and cache the returned access token.

        Args:
            grant_type: ``"password"`` sends username and password; any other
                value sends the currently cached token to be refreshed.

        Returns:
            The new access token.

        Raises:
            requests.HTTPError: if the token endpoint answers with an error status.
        """
        data = {"grant_type": grant_type}
        if grant_type == "password":
            data["username"] = self._username
            data["password"] = self._password
        else:
            data["token"] = self.token

        r = self.session.post(
            url=self.token_url,
            json=data,
            headers={
                "Accept": "application/json",
                "Content-type": "application/json",
            },
        )
        r.raise_for_status()
        response = r.json()
        self.token = response["accessToken"]
        # "expiresIn" is interpreted as a number of seconds from now
        self.token_expires_at = self._utcnow() + datetime.timedelta(seconds=response["expiresIn"])
        return self.token

    @property
    def default_headers(self):
        """Default headers added to every request"""
        return {
            "Content-type": "application/json",
            "Accept": "application/json",
            "Authorization": f"Bearer {self.get_token()}",
        }

    def request(self, method: str, path: str,
                raise_for_status: bool = False,
                **kwargs) -> "requests.Response":
        """
        Description:
            Perform a desired action (GET, PUT, POST) on a certain resource

        Args:
            method (str) -> HTTP verb like GET, PUT, POST, etc
            path (str) -> path to the resource e.g. /job/{job_id}; a value
                          which already starts with the base URL is used as is
            raise_for_status (bool) -> True if HTTPError should be raised in case of an error else False
            **kwargs -> passed through to requests.Session.request

        Returns:
            object of type requests.Response

        Raises:
            requests.HTTPError() in case of an error, if raise_for_status is True

        """
        url = path if path.startswith(self.base_url) else self.base_url + path
        # Refreshes the Authorization header (and the token when required)
        self.session.headers.update(self.default_headers)
        r = self.session.request(method=method, url=url, **kwargs)
        try:
            r.raise_for_status()
        except requests.HTTPError:
            if raise_for_status:
                raise
        return r
124+
125+
126+
class OracleADBSPythonJob:
    """Callable to submit Python Script to ADB-S

    The script ``<alias>_dbt_py_script`` is run through the do-eval REST
    endpoint, either in blocking mode or, when the model config sets
    ``async_flag``, as an async job which is polled until it finishes,
    fails or the configured ``timeout`` (in seconds) elapses.
    """

    def __init__(self,
                 parsed_model: Dict,
                 credential: OracleAdapterCredentials) -> None:
        self.identifier = parsed_model["alias"]
        self.py_q_script_name = f"{self.identifier}_dbt_py_script"
        self.conda_env_name = parsed_model["config"].get("conda_env_name")
        self.timeout = parsed_model["config"].get("timeout", DEFAULT_TIMEOUT_IN_SECONDS)
        self.async_flag = parsed_model["config"].get("async_flag", False)
        self.service = parsed_model["config"].get("service", "HIGH")
        self.oml4py_client = OracleOML4PYClient(oml_cloud_service_url=credential.oml_cloud_service_url,
                                                username=credential.user,
                                                password=credential.password)

    def schedule_async_job_and_wait_for_completion(self, data):
        """Start an async job and poll its status until completion.

        Args:
            data: JSON-serializable request body for the do-eval endpoint.

        Raises:
            dbt.exceptions.DbtRuntimeError: if the job cannot be scheduled,
                its status cannot be checked, it fails, or it does not
                complete within ``self.timeout`` seconds.
        """
        logger.info(f"Running Python async job using {data}")
        try:
            r = self.oml4py_client.request(method="POST",
                                           path=OML_DO_EVAL_API.format(script_name=self.py_q_script_name),
                                           data=json.dumps(data),
                                           raise_for_status=False)
            if r.status_code in (http.HTTPStatus.BAD_REQUEST, http.HTTPStatus.INTERNAL_SERVER_ERROR):
                logger.error(red(r.json()))
            r.raise_for_status()
        except requests.exceptions.RequestException as e:
            logger.error(red(f"Error {e} scheduling async Python job for model {self.identifier}"))
            raise dbt.exceptions.DbtRuntimeError(f"Error scheduling Python model {self.identifier}") from e

        # Without a job location there is nothing to poll; report it as a
        # scheduling failure rather than letting a KeyError escape.
        job_location = r.headers.get("location")
        if not job_location:
            logger.error(red(f"No job location returned for async Python model {self.identifier}"))
            raise dbt.exceptions.DbtRuntimeError(f"Error scheduling Python model {self.identifier}")
        logger.info(f"Started async job {job_location}")
        # Monotonic clock: the timeout must not be affected by system clock changes
        start_time = time.monotonic()

        while time.monotonic() - start_time < self.timeout:
            logger.debug(f"Checking Job status for : {job_location}")
            try:
                job_status = self.oml4py_client.request(method="GET",
                                                        path=job_location,
                                                        raise_for_status=False)
                job_status_code = job_status.status_code
                logger.debug(f"Job status code is: {job_status_code}")
                if job_status_code == http.HTTPStatus.FOUND:
                    # FOUND is treated as "job finished"; fetch its result
                    logger.info(green(f"Job {job_location} completed"))
                    job_result = self.oml4py_client.request(method="GET",
                                                            path=f"{job_location}/result",
                                                            raise_for_status=False)
                    job_result_json = job_result.json()
                    if 'errorMessage' in job_result_json:
                        logger.error(red(f"FAILURE - Python model {self.identifier} Job failure is: {job_result_json}"))
                        raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
                    job_result.raise_for_status()
                    logger.info(green(f"SUCCESS - Python model {self.identifier} Job result is: {job_result_json}"))
                    return
                elif job_status_code == http.HTTPStatus.INTERNAL_SERVER_ERROR:
                    logger.error(red(f"FAILURE - Job status is: {job_status.json()}"))
                    raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
                else:
                    logger.debug(f"Python model {self.identifier} job status is: {job_status.json()}")
                    job_status.raise_for_status()

            except requests.exceptions.RequestException as e:
                logger.error(red(f"Error {e} checking status of Python job {job_location} for model {self.identifier}"))
                raise dbt.exceptions.DbtRuntimeError(f"Error checking status for job {job_location}") from e

            time.sleep(DEFAULT_DELAY_BETWEEN_POLL_IN_SECONDS)
        logger.error(red(f"Timeout error for Python model {self.identifier}"))
        raise dbt.exceptions.DbtRuntimeError(f"Timeout error for Python model {self.identifier}")

    def __call__(self, *args, **kwargs):
        """Run the script, asynchronously when ``async_flag`` is set, else blocking.

        Raises:
            dbt.exceptions.DbtRuntimeError: if the request fails or the
                response carries an ``errorMessage``.
        """
        data = {
            "service": self.service
        }
        if self.async_flag:
            data["asyncFlag"] = self.async_flag
            data["timeout"] = self.timeout
        if self.conda_env_name:
            data["envName"] = self.conda_env_name

        if self.async_flag:
            self.schedule_async_job_and_wait_for_completion(data=data)
        else:  # Run in blocking mode
            logger.info(f"Running Python model {self.identifier} with args {data}")
            try:
                r = self.oml4py_client.request(method="POST",
                                               path=OML_DO_EVAL_API.format(script_name=self.py_q_script_name),
                                               data=json.dumps(data),
                                               raise_for_status=False)
                job_result = r.json()
                if 'errorMessage' in job_result:
                    logger.error(red(f"FAILURE - Python model {self.identifier} Job failure is: {job_result}"))
                    raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
                r.raise_for_status()
                logger.info(green(f"SUCCESS - Python model {self.identifier} Job result is: {job_result}"))
            except requests.exceptions.RequestException as e:
                logger.error(red(f"Error {e} running Python model {self.identifier}"))
                raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}") from e
225+

dbt/include/oracle/macros/adapters.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
{{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }}
2626
{% endmacro %}
2727

28-
{% macro oracle__get_empty_subquery_sql(select_sql) %}
28+
{% macro oracle__get_empty_subquery_sql(select_sql, select_sql_header=none) %}
2929
select * from (
3030
{{ select_sql }}
3131
) dbt_sbq_tmp
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
def model(dbt, session):
22
# Must be either table or incremental (view is not currently supported)
33
dbt.config(materialized="table")
4+
dbt.config(async_flag=True)
5+
dbt.config(timeout=900) # In seconds
6+
dbt.config(service="HIGH") # LOW, MEDIUM, HIGH
47
# oml.core.DataFrame representing a datasource
58
s_df = dbt.ref("sales_cost")
69
return s_df

dbt_adbs_test_project/profiles.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dbt_test:
1111
service: "{{ env_var('DBT_ORACLE_SERVICE') }}"
1212
#database: "{{ env_var('DBT_ORACLE_DATABASE') }}"
1313
schema: "{{ env_var('DBT_ORACLE_SCHEMA') }}"
14-
oml_auth_token_uri: "{{ env_var('DBT_ORACLE_OML_AUTH_TOKEN_API')}}"
14+
oml_cloud_service_url: "{{ env_var('DBT_ORACLE_OML_CLOUD_SERVICE_URL')}}"
1515
retry_count: 1
1616
retry_delay: 5
1717
shardingkey:

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
dbt-core==1.5.1
1+
dbt-core==1.5.2
22
cx_Oracle==8.3.0
3-
oracledb==1.3.1
3+
oracledb==1.3.2
44

requirements_dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ tox
66
coverage
77
twine
88
pytest
9-
dbt-tests-adapter==1.5.1
9+
dbt-tests-adapter==1.5.2

0 commit comments

Comments
 (0)