Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions sqlglot/generators/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from sqlglot.errors import UnsupportedError
from sqlglot.generator import unsupported_args
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.scope import build_scope
from sqlglot.parsers.exasol import DATE_UNITS

Expand Down Expand Up @@ -232,6 +233,33 @@ def _add_date_sql(self: ExasolGenerator, expression: DATE_ADD_OR_SUB) -> str:
return self.func(f"ADD_{unit}S", expression.this, offset_expr)


def _add_cte_column_aliases(expression: exp.Expr) -> exp.Expr:
    """
    Give every unaliased, non-column projection in a CTE's SELECT list a
    synthetic alias (``_col_0``, ``_col_1``, ...).

    Exasol rejects unaliased non-column expressions inside CTE SELECT lists.
    Projections that are already aliased, bare column references, or stars
    (which Exasol expands by itself and which would be invalid wrapped in an
    alias) pass through untouched.
    """
    if not (isinstance(expression, exp.Select) and isinstance(expression.parent, exp.CTE)):
        return expression

    next_index = 0
    rewritten: list[exp.Expr] = []
    for projection in expression.expressions:
        keep_as_is = (
            isinstance(projection, (exp.Alias, exp.Column, exp.Star))
            or projection.find(exp.Star)
        )
        if keep_as_is:
            rewritten.append(projection)
        else:
            synthetic = exp.to_identifier(f"_col_{next_index}", quoted=True)
            rewritten.append(exp.alias_(projection, synthetic))
            next_index += 1

    expression.set("expressions", rewritten)
    return expression


def _group_by_all(expression: exp.Expr) -> exp.Expr:
if not isinstance(expression, exp.Select):
return expression
Expand Down Expand Up @@ -392,12 +420,14 @@ def datatype_sql(self, expression: exp.DataType) -> str:
exp.CommentColumnConstraint: lambda self, e: f"COMMENT IS {self.sql(e, 'this')}",
exp.Select: transforms.preprocess(
[
_add_cte_column_aliases,
_qualify_unscoped_star,
_add_local_prefix_for_aliases,
_group_by_all,
]
),
exp.SubstringIndex: _substring_index_sql,
exp.JSONObject: lambda self, e: self.jsonobject_sql(e),
exp.WeekOfYear: rename_func("WEEK"),
# https://docs.exasol.com/db/latest/sql_references/functions/alphabeticallistfunctions/to_date.htm
exp.Date: rename_func("TO_DATE"),
Expand Down Expand Up @@ -904,6 +934,61 @@ def jsonextract_sql(self, expression: exp.JSONExtract) -> str:

return sql

def jsonobject_sql(self, expression: exp.JSONObject) -> str:
    """
    Render JSON_OBJECT as a string concatenation, since Exasol has no
    native JSON_OBJECT function.

    Each key/value pair is emitted as a ``"key": value`` fragment:
      * string literals are wrapped in double quotes,
      * text-typed expressions are wrapped in quotes via a NULL-safe CASE
        that emits the JSON literal ``null`` for NULL values,
      * all other values are CAST to VARCHAR(100) with COALESCE to ``null``.

    Falls back to the generic function renderer for shapes this transform
    does not understand (non-JSONKeyValue entries, missing key or value).
    """
    pairs = expression.expressions or []
    if not pairs:
        # JSON_OBJECT() with no arguments is just the empty-object literal.
        return self.sql(exp.Literal.string("{}"))

    concat_args: list[exp.Expression] = [exp.Literal.string("{")]
    for i, pair in enumerate(pairs):
        if not isinstance(pair, exp.JSONKeyValue):
            return self.function_fallback_sql(expression)
        key = pair.this
        value = pair.args.get("expression")
        if key is None or value is None:
            return self.function_fallback_sql(expression)

        # Escape embedded double quotes in the key (" -> \") so the emitted
        # JSON fragment stays well-formed.
        escaped_key = key.name.replace('"', '\\"')
        prefix = ", " if i > 0 else ""
        concat_args.append(exp.Literal.string(f'{prefix}"{escaped_key}": '))

        if value.is_string:
            # Literal strings can be quoted directly.
            wrapped: exp.Expression = exp.Concat(
                expressions=[
                    exp.Literal.string('"'),
                    value,
                    exp.Literal.string('"'),
                ]
            )
        else:
            # Annotate on demand so we can distinguish text from other types.
            typed_value = value if value.type else annotate_types(value, dialect=self.dialect)
            if typed_value.is_type(*exp.DataType.TEXT_TYPES):
                # Text expressions need quoting, but NULL must become the
                # unquoted JSON literal null.
                wrapped = exp.Case(
                    ifs=[
                        exp.If(
                            this=typed_value.is_(exp.Null()),
                            true=exp.Literal.string("null"),
                        )
                    ],
                    default=exp.Concat(
                        expressions=[
                            exp.Literal.string('"'),
                            typed_value.copy(),
                            exp.Literal.string('"'),
                        ]
                    ),
                )
            else:
                # Non-text values are stringified unquoted; NULL -> 'null'.
                wrapped = exp.Coalesce(
                    this=exp.cast(typed_value, exp.DataType.build("VARCHAR(100)")),
                    expressions=[exp.Literal.string("null")],
                )

        concat_args.append(wrapped)

    concat_args.append(exp.Literal.string("}"))
    return self.sql(exp.Concat(expressions=concat_args))

@unsupported_args("flag")
def regexplike_sql(self, expression: exp.RegexpLike) -> str:
if not expression.args.get("full_match"):
Expand Down
115 changes: 115 additions & 0 deletions tests/dialects/test_exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,3 +950,118 @@ def test_group_by_all(self):
write="exasol",
unsupported_level=ErrorLevel.RAISE,
)

def test_group_by_alias_local(self):
    """Aliases referenced in GROUP BY / HAVING must get a LOCAL. prefix."""
    # (mysql input, expected exasol output) pairs; the exasol form must
    # also roundtrip through itself unchanged.
    cases = [
        # GROUP BY bare alias -> LOCAL prefix
        (
            "SELECT city, COUNT(*) AS cnt FROM t GROUP BY cnt",
            "SELECT city, COUNT(*) AS cnt FROM t GROUP BY LOCAL.cnt",
        ),
        # GROUP BY expression alias -> LOCAL prefix
        (
            "SELECT YEAR(a_date) AS a_year FROM t GROUP BY a_year",
            "SELECT YEAR(TO_DATE(a_date)) AS a_year FROM t GROUP BY LOCAL.a_year",
        ),
        # GROUP BY non-alias column -> unchanged
        (
            "SELECT city, COUNT(*) FROM t GROUP BY city",
            "SELECT city, COUNT(*) FROM t GROUP BY city",
        ),
        # HAVING alias -> LOCAL prefix
        (
            "SELECT COUNT(*) AS cnt FROM t HAVING cnt > 1",
            "SELECT COUNT(*) AS cnt FROM t HAVING LOCAL.cnt > 1",
        ),
    ]
    for mysql_sql, exasol_sql in cases:
        self.validate_all(
            exasol_sql,
            read={"mysql": mysql_sql},
            write={"exasol": exasol_sql},
        )

def test_cte_literal_auto_alias(self):
    """Non-column CTE projections get synthetic "_col_N" aliases for Exasol."""
    # (mysql input, expected exasol output) pairs.
    roundtrips = [
        # Integer literal in CTE gets synthetic alias
        (
            "WITH cte AS (SELECT 12345) SELECT * FROM cte",
            'WITH cte AS (SELECT 12345 AS "_col_0") SELECT * FROM cte',
        ),
        # Mixed literals get sequential aliases
        (
            "WITH cte AS (SELECT 12345, 'value') SELECT * FROM cte",
            'WITH cte AS (SELECT 12345 AS "_col_0", \'value\' AS "_col_1") SELECT * FROM cte',
        ),
        # Function call / arithmetic gets alias
        (
            "WITH cte AS (SELECT LOWER(name), 1 + 2 FROM t) SELECT * FROM cte",
            'WITH cte AS (SELECT LOWER(name) AS "_col_0", 1 + 2 AS "_col_1") SELECT * FROM cte',
        ),
    ]
    for mysql_sql, exasol_sql in roundtrips:
        self.validate_all(
            exasol_sql,
            read={"mysql": mysql_sql},
            write={"exasol": exasol_sql},
        )

    # These must survive an exasol -> exasol roundtrip untouched.
    identities = (
        # Existing alias preserved
        "WITH cte AS (SELECT 12345 AS id, 'value' AS name) SELECT * FROM cte",
        # Bare column: no alias injected
        "WITH cte AS (SELECT col FROM t) SELECT * FROM cte",
        # Bare star inside CTE: no alias injected (wrapping star would be invalid SQL)
        "WITH cte AS (SELECT * FROM t) SELECT * FROM cte",
        # Qualified star inside CTE: no alias injected
        "WITH cte AS (SELECT t.* FROM t) SELECT * FROM cte",
        # Nested CTE: outer star still passes through unwrapped
        "WITH outer_cte AS (WITH inner_cte AS (SELECT 1 AS x) SELECT * FROM inner_cte) SELECT * FROM outer_cte",
        # Non-CTE subquery: no alias injected
        "SELECT * FROM (SELECT 1) AS t",
    )
    for sql in identities:
        self.validate_identity(sql)

def test_json_object(self):
    """JSON_OBJECT is lowered to a concatenation expression for Exasol."""
    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.qualify import qualify

    def render(sql, schema):
        # Qualify + type-annotate a MySQL query, then generate Exasol SQL.
        tree = parse_one(sql, read="mysql")
        annotated = annotate_types(
            qualify(tree, schema=schema, dialect="mysql"),
            schema=schema,
            dialect="mysql",
        )
        return annotated.sql("exasol")

    # Empty JSON_OBJECT -> '{}' literal
    self.validate_all(
        "SELECT '{}'",
        read={"mysql": "SELECT JSON_OBJECT()"},
        write={"exasol": "SELECT '{}'"},
    )

    # String-typed column: NULL-safe quoted value via CASE
    result = render(
        "SELECT JSON_OBJECT('name', str_name) AS j FROM t",
        {"t": {"str_name": "VARCHAR"}},
    )
    self.assertIn("CASE WHEN", result)
    self.assertIn("IS NULL THEN 'null'", result)
    self.assertIn("'\"name\": '", result)

    # Numeric-typed column: COALESCE(CAST(..) AS VARCHAR(100)), 'null')
    result = render(
        "SELECT JSON_OBJECT('id', int_id) AS j FROM t",
        {"t": {"int_id": "INT"}},
    )
    self.assertIn("COALESCE(CAST(", result)
    self.assertIn("AS VARCHAR(100)), 'null')", result)

    # Multi-pair with comma separators
    result = render(
        "SELECT JSON_OBJECT('id', int_id, 'name', str_name) FROM t",
        {"t": {"int_id": "INT", "str_name": "VARCHAR"}},
    )
    self.assertIn("'\"id\": '", result)
    self.assertIn("', \"name\": '", result)
Loading