Skip to content

Commit 2c38043

Browse files
authored
Merge pull request #89 from oracle/dev/python_models
Python Models
2 parents 6c18965 + 5fc6504 commit 2c38043

File tree

20 files changed

+319
-50
lines changed

20 files changed

+319
-50
lines changed

.github/workflows/oracle-xe-adapter-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
- name: Install dbt-oracle with core dependencies
4949
run: |
5050
python -m pip install --upgrade pip
51-
pip install pytest dbt-tests-adapter==1.5.0
51+
pip install pytest dbt-tests-adapter==1.5.1
5252
pip install -r requirements.txt
5353
pip install -e .
5454

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,4 @@ doc/build.gitbak
146146
.venv1.3/
147147
.venv1.4/
148148
.venv1.5/
149+
dbt_adbs_py_test_project

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Configuration variables
2-
VERSION=1.5.0
2+
VERSION=1.5.1
33
PROJ_DIR?=$(shell pwd)
44
VENV_DIR?=${PROJ_DIR}/.bldenv
55
BUILD_DIR=${PROJ_DIR}/build

dbt/adapters/oracle/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17-
version = "1.5.0"
17+
version = "1.5.1"

dbt/adapters/oracle/connections.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from dbt.adapters.base import Credentials
2626
from dbt.adapters.sql import SQLConnectionManager
2727
from dbt.contracts.connection import AdapterResponse
28+
from dbt.events.functions import fire_event
29+
from dbt.events.types import ConnectionUsed, SQLQuery, SQLCommit, SQLQueryStatus
2830
from dbt.events import AdapterLogger
31+
from dbt.events.contextvars import get_node_info
32+
from dbt.utils import cast_to_str
2933

3034
from dbt.version import __version__ as dbt_version
3135
from dbt.adapters.oracle.connection_helper import oracledb, SQLNET_ORA_CONFIG
@@ -105,6 +109,9 @@ class OracleAdapterCredentials(Credentials):
105109
retry_count: Optional[int] = 1
106110
retry_delay: Optional[int] = 3
107111

112+
# Fetch an auth token to run Python UDF
113+
oml_auth_token_uri: Optional[str] = None
114+
108115

109116
_ALIASES = {
110117
'dbname': 'database',
@@ -129,7 +136,7 @@ def _connection_keys(self) -> Tuple[str]:
129136
'service', 'connection_string',
130137
'shardingkey', 'supershardingkey',
131138
'cclass', 'purity', 'retry_count',
132-
'retry_delay'
139+
'retry_delay', 'oml_auth_token_uri'
133140
)
134141

135142
@classmethod
@@ -293,20 +300,36 @@ def add_query(
293300
if auto_begin and connection.transaction_open is False:
294301
self.begin()
295302

296-
logger.debug('Using {} connection "{}".'
297-
.format(self.TYPE, connection.name))
303+
fire_event(
304+
ConnectionUsed(
305+
conn_type=self.TYPE,
306+
conn_name=cast_to_str(connection.name),
307+
node_info=get_node_info(),
308+
)
309+
)
298310

299311
with self.exception_handler(sql):
300312
if abridge_sql_log:
301313
log_sql = '{}...'.format(sql[:512])
302314
else:
303315
log_sql = sql
304316

305-
logger.debug(f'On {connection.name}: f{log_sql}')
317+
fire_event(
318+
SQLQuery(
319+
conn_name=cast_to_str(connection.name), sql=log_sql, node_info=get_node_info()
320+
)
321+
)
322+
306323
pre = time.time()
307324
cursor = connection.handle.cursor()
308325
cursor.execute(sql, bindings)
309-
logger.debug(f"SQL status: {self.get_status(cursor)} in {(time.time() - pre)} seconds")
326+
fire_event(
327+
SQLQueryStatus(
328+
status=str(self.get_response(cursor)),
329+
elapsed=round((time.time() - pre)),
330+
node_info=get_node_info(),
331+
)
332+
)
310333
return connection, cursor
311334

312335
def add_begin_query(self):
@@ -317,3 +340,10 @@ def add_begin_query(self):
317340
@classmethod
318341
def data_type_code_to_name(cls, type_code) -> str:
319342
return DATATYPES[type_code.name]
343+
344+
def commit(self):
345+
connection = self.get_thread_connection()
346+
fire_event(SQLCommit(conn_name=connection.name, node_info=get_node_info()))
347+
self.add_commit_query()
348+
connection.transaction_open = False
349+
return connection

dbt/adapters/oracle/impl.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17+
import datetime
1718
from typing import (
1819
Optional, List, Set
1920
)
@@ -24,6 +25,7 @@
2425
Dict)
2526

2627
import agate
28+
import requests
2729

2830
import dbt.exceptions
2931
from dbt.adapters.base.relation import BaseRelation, InformationSchema
@@ -345,3 +347,69 @@ def render_raw_columns_constraints(cls, raw_columns: Dict[str, Dict[str, Any]])
345347
rendered_column_constraints.append(" ".join(rendered_column_constraint))
346348

347349
return rendered_column_constraints
350+
351+
def get_oml_auth_token(self) -> str:
352+
if self.config.credentials.oml_auth_token_uri is None:
353+
raise dbt.exceptions.DbtRuntimeError("oml_auth_token_uri should be set to run dbt-py models")
354+
data = {
355+
"grant_type": "password",
356+
"username": self.config.credentials.user,
357+
"password": self.config.credentials.password
358+
}
359+
try:
360+
r = requests.post(url=self.config.credentials.oml_auth_token_uri,
361+
json=data)
362+
r.raise_for_status()
363+
except requests.exceptions.RequestException:
364+
raise dbt.exceptions.DbtRuntimeError("Error getting OML OAuth2.0 token")
365+
else:
366+
return r.json()["accessToken"]
367+
368+
def submit_python_job(self, parsed_model: dict, compiled_code: str):
369+
"""Submit user defined Python function
370+
371+
The function pyqEval when used in Oracle Autonomous Database,
372+
calls a user-defined Python function.
373+
374+
pyqEval(PAR_LST, OUT_FMT, SRC_NAME, SRC_OWNER, ENV_NAME)
375+
376+
- PAR_LST -> Parameter List
377+
- OUT_FMT -> JSON clob of the columns
378+
- ENV_NAME -> Name of conda environment
379+
380+
381+
"""
382+
identifier = parsed_model["alias"]
383+
oml_oauth_access_token = self.get_oml_auth_token()
384+
py_q_script_name = f"{identifier}_dbt_py_script"
385+
py_q_eval_output_fmt = '{"result":"number"}'
386+
py_q_eval_result_table = f"o$pt_dbt_pyqeval_{identifier}_tmp_{datetime.datetime.utcnow().strftime('%H%M%S')}"
387+
388+
conda_env_name = parsed_model["config"].get("conda_env_name")
389+
if conda_env_name:
390+
logger.info("Custom python environment is %s", conda_env_name)
391+
py_q_eval_sql = f"""CREATE GLOBAL TEMPORARY TABLE {py_q_eval_result_table}
392+
AS SELECT * FROM TABLE(pyqEval(par_lst => NULL,
393+
out_fmt => ''{py_q_eval_output_fmt}'',
394+
scr_name => ''{py_q_script_name}'',
395+
scr_owner => NULL,
396+
env_name => ''{conda_env_name}''))"""
397+
else:
398+
py_q_eval_sql = f"""CREATE GLOBAL TEMPORARY TABLE {py_q_eval_result_table}
399+
AS SELECT * FROM TABLE(pyqEval(par_lst => NULL,
400+
out_fmt => ''{py_q_eval_output_fmt}'',
401+
scr_name => ''{py_q_script_name}'',
402+
scr_owner => NULL))"""
403+
404+
py_exec_main_sql = f"""
405+
BEGIN
406+
sys.pyqSetAuthToken('{oml_oauth_access_token}');
407+
sys.pyqScriptCreate('{py_q_script_name}', '{compiled_code.strip()}', FALSE, TRUE);
408+
EXECUTE IMMEDIATE '{py_q_eval_sql}';
409+
EXECUTE IMMEDIATE 'DROP TABLE {py_q_eval_result_table}';
410+
sys.pyqScriptDrop('{py_q_script_name}');
411+
END;
412+
"""
413+
response, _ = self.execute(sql=py_exec_main_sql)
414+
logger.info(response)
415+
return response

dbt/include/oracle/macros/adapters.sql

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -136,29 +136,33 @@
136136
{%- endmacro %}
137137

138138

139-
{% macro oracle__create_table_as(temporary, relation, sql) -%}
140-
{%- set sql_header = config.get('sql_header', none) -%}
141-
{%- set parallel = config.get('parallel', none) -%}
142-
{%- set compression_clause = config.get('table_compression_clause', none) -%}
143-
{%- set contract_config = config.get('contract') -%}
144-
145-
{{ sql_header if sql_header is not none }}
146-
147-
create {% if temporary -%}
148-
global temporary
149-
{%- endif %} table {{ relation.include(schema=(not temporary)) }}
150-
{%- if contract_config.enforced -%}
151-
{{ get_assert_columns_equivalent(sql) }}
152-
{{ get_table_columns_and_constraints() }}
153-
{%- set sql = get_select_subquery(sql) %}
154-
{% endif %}
155-
{% if temporary -%} on commit preserve rows {%- endif %}
156-
{% if not temporary -%}
157-
{% if parallel %} parallel {{ parallel }}{% endif %}
158-
{% if compression_clause %} {{ compression_clause }} {% endif %}
159-
{%- endif %}
160-
as
161-
{{ sql }}
139+
{% macro oracle__create_table_as(temporary, relation, sql, language='sql') -%}
140+
{%- if language == 'sql' -%}
141+
{%- set sql_header = config.get('sql_header', none) -%}
142+
{%- set parallel = config.get('parallel', none) -%}
143+
{%- set compression_clause = config.get('table_compression_clause', none) -%}
144+
{%- set contract_config = config.get('contract') -%}
145+
{{ sql_header if sql_header is not none }}
146+
create {% if temporary -%}
147+
global temporary
148+
{%- endif %} table {{ relation.include(schema=(not temporary)) }}
149+
{%- if contract_config.enforced -%}
150+
{{ get_assert_columns_equivalent(sql) }}
151+
{{ get_table_columns_and_constraints() }}
152+
{%- set sql = get_select_subquery(sql) %}
153+
{% endif %}
154+
{% if temporary -%} on commit preserve rows {%- endif %}
155+
{% if not temporary -%}
156+
{% if parallel %} parallel {{ parallel }}{% endif %}
157+
{% if compression_clause %} {{ compression_clause }} {% endif %}
158+
{%- endif %}
159+
as
160+
{{ sql }}
161+
{%- elif language == 'python' -%}
162+
{{ py_write_table(compiled_code=compiled_code, target_relation=relation, temporary=temporary) }}
163+
{%- else -%}
164+
{% do exceptions.raise_compiler_error("oracle__create_table_as macro didn't get supported language, it got %s" % language) %}
165+
{%- endif -%}
162166

163167
{%- endmacro %}
164168
{% macro oracle__create_view_as(relation, sql) -%}

dbt/include/oracle/macros/materializations/incremental/incremental.sql

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
#}
17-
{% materialization incremental, adapter='oracle' %}
17+
{% materialization incremental, adapter='oracle', supported_languages=['sql', 'python'] %}
1818

1919
{% set unique_key = config.get('unique_key') %}
2020
{% set full_refresh_mode = flags.FULL_REFRESH %}
21-
21+
{%- set language = model['language'] -%}
2222
{% set target_relation = this.incorporate(type='table') %}
2323
{% set existing_relation = load_relation(this) %}
2424
{% set tmp_relation = make_temp_relation(this) %}
@@ -32,7 +32,7 @@
3232

3333
{% set to_drop = [] %}
3434
{% if existing_relation is none %}
35-
{% set build_sql = create_table_as(False, target_relation, sql) %}
35+
{% set build_sql = create_table_as(False, target_relation, sql, language) %}
3636
{% elif existing_relation.is_view or full_refresh_mode %}
3737
{#-- Make sure the backup doesn't exist so we don't encounter issues with the rename below #}
3838
{% set backup_identifier = existing_relation.identifier ~ "__dbt_backup" %}
@@ -43,12 +43,16 @@
4343
{% else %}
4444
{% do adapter.rename_relation(existing_relation, backup_relation) %}
4545
{% endif %}
46-
{% set build_sql = create_table_as(False, target_relation, sql) %}
46+
{% set build_sql = create_table_as(False, target_relation, sql, language) %}
4747
{% do to_drop.append(backup_relation) %}
4848
{% else %}
4949
{% set tmp_relation = make_temp_relation(target_relation) %}
5050
{% do to_drop.append(tmp_relation) %}
51-
{% do run_query(create_table_as(True, tmp_relation, sql)) %}
51+
{% call statement("make_tmp_relation", language=language) %}
52+
{{create_table_as(True, tmp_relation, sql, language)}}
53+
{% endcall %}
54+
{#-- After this language should be SQL --#}
55+
{% set language = 'sql' %}
5256
{% do adapter.expand_target_column_types(
5357
from_relation=tmp_relation,
5458
to_relation=target_relation) %}
@@ -66,7 +70,7 @@
6670

6771
{% endif %}
6872

69-
{% call statement("main") %}
73+
{% call statement("main", language=language) %}
7074
{{ build_sql }}
7175
{% endcall %}
7276

0 commit comments

Comments
 (0)