Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions snuba/query/processors/physical/hexint_column_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,22 @@ def __init__(self, columns: Set[str], size: int = 16) -> None:
self._size = size

def _translate_literal(self, exp: Literal) -> Literal:
if not isinstance(exp.value, str):
raise ColumnTypeError("Invalid hexint", should_report=False)

if exp.value == "":
raise ColumnTypeError("Invalid hexint", should_report=False, skip_optimization=True)

try:
assert isinstance(exp.value, str)
# 128 bit integers in clickhouse need to be referenced as strings
if self._size == 32:
return Literal(alias=exp.alias, value=str(int(exp.value, 16)))
return Literal(alias=exp.alias, value=int(exp.value, 16))
except (AssertionError, ValueError):
translated = int(exp.value, 16)
except ValueError:
raise ColumnTypeError("Invalid hexint", should_report=False)

# 128 bit integers in clickhouse need to be referenced as strings
if self._size == 32:
return Literal(alias=exp.alias, value=str(translated))
return Literal(alias=exp.alias, value=translated)

def _process_expressions(self, exp: Expression) -> Expression:
if isinstance(exp, Column) and exp.column_name in self.columns:
hex = f.hex(column(exp.column_name))
Expand All @@ -60,12 +67,19 @@ def _process_expressions(self, exp: Expression) -> Expression:

class HexIntArrayColumnProcessor(BaseTypeConverter):
def _translate_literal(self, exp: Literal) -> Literal:
if not isinstance(exp.value, str):
raise ColumnTypeError("Invalid hexint", report=False)

if exp.value == "":
raise ColumnTypeError("Invalid hexint", report=False, skip_optimization=True)

try:
assert isinstance(exp.value, str)
return Literal(alias=exp.alias, value=int(exp.value, 16))
except (AssertionError, ValueError):
translated = int(exp.value, 16)
except ValueError:
raise ColumnTypeError("Invalid hexint", report=False)

return Literal(alias=exp.alias, value=translated)

def _process_expressions(self, exp: Expression) -> Expression:
if isinstance(exp, Column) and exp.column_name in self.columns:
return FunctionCall(
Expand Down
29 changes: 21 additions & 8 deletions snuba/query/processors/physical/type_converters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Set
from typing import Optional, Set

from snuba.clickhouse.query import Query
from snuba.query.conditions import ConditionFunctions
Expand Down Expand Up @@ -156,14 +156,25 @@ def assert_literal(lit: Expression) -> Literal:
assert isinstance(lit, Literal)
return lit

def _translate_or_skip(lit: Literal) -> Optional[Literal]:
try:
return self._translate_literal(lit)
except ColumnTypeError as exc:
if getattr(exc, "extra_data", {}).get("skip_optimization"):
return None
raise

match = self.__condition_matcher.match(exp)
if match is not None:
translated_literal = _translate_or_skip(assert_literal(match.expression("literal")))
if translated_literal is None:
return exp
return FunctionCall(
exp.alias,
match.string("operator"),
(
self.__strip_column_alias(match.expression("col")),
self._translate_literal(assert_literal(match.expression("literal"))),
translated_literal,
),
)

Expand All @@ -178,15 +189,17 @@ def assert_literal(lit: Expression) -> Literal:
assert isinstance(param, Literal)

wrapper = tuple if collection_func.function_name == "tuple" else list
new_params = []
for lit in collection_func.parameters:
translated = _translate_or_skip(assert_literal(lit))
if translated is None:
return exp
new_params.append(translated)

new_collection_func = FunctionCall(
collection_func.alias,
collection_func.function_name,
parameters=wrapper(
[
self._translate_literal(assert_literal(lit))
for lit in collection_func.parameters
]
),
parameters=wrapper(new_params),
)
return FunctionCall(
exp.alias,
Expand Down
47 changes: 38 additions & 9 deletions tests/query/processors/test_hexint_column_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
binary_condition(
ConditionFunctions.IN,
Column(None, None, "column1"),
FunctionCall(
None, "tuple", (Literal(None, "a" * 16), Literal(None, "b" * 16))
),
FunctionCall(None, "tuple", (Literal(None, "a" * 16), Literal(None, "b" * 16))),
),
"in(column1, (12297829382473034410, 13527612320720337851))",
id="in_operator",
Expand All @@ -40,9 +38,7 @@
binary_condition(
ConditionFunctions.IN,
Column(None, None, "column1"),
FunctionCall(
None, "array", (Literal(None, "a" * 16), Literal(None, "b" * 16))
),
FunctionCall(None, "array", (Literal(None, "a" * 16), Literal(None, "b" * 16))),
),
"in(column1, [12297829382473034410, 13527612320720337851])",
id="array_in_operator",
Expand Down Expand Up @@ -77,9 +73,7 @@ def test_hexint_column_processor(unprocessed: Expression, formatted_value: str)
)
hex = f.hex(column("column1"))

HexIntColumnProcessor(set(["column1"])).process_query(
unprocessed_query, HTTPQuerySettings()
)
HexIntColumnProcessor(set(["column1"])).process_query(unprocessed_query, HTTPQuerySettings())
assert unprocessed_query.get_selected_columns() == [
SelectedExpression(
"column1",
Expand All @@ -104,3 +98,38 @@ def test_hexint_column_processor(unprocessed: Expression, formatted_value: str)
assert condition is not None
ret = condition.accept(ClickhouseExpressionFormatter())
assert ret == formatted_value


def test_hexint_processor_skips_empty_literal_optimization() -> None:
unprocessed_query = Query(
Table("transactions", ColumnSet([]), storage_key=StorageKey("dontmatter")),
selected_columns=[SelectedExpression("column1", Column(None, None, "column1"))],
condition=binary_condition(
ConditionFunctions.EQ,
FunctionCall(
None,
"cast",
(Column(None, None, "column1"), Literal(None, "String")),
),
Literal(None, ""),
),
)

HexIntColumnProcessor(set(["column1"])).process_query(unprocessed_query, HTTPQuerySettings())

condition = unprocessed_query.get_condition()
assert isinstance(condition, FunctionCall)
assert condition.function_name == ConditionFunctions.EQ

lhs, rhs = condition.parameters
assert isinstance(rhs, Literal)
assert rhs.value == ""

assert isinstance(lhs, FunctionCall)
assert lhs.function_name == "cast"
cast_arg, cast_target = lhs.parameters
assert isinstance(cast_target, Literal)
assert cast_target.value == "String"

assert isinstance(cast_arg, FunctionCall)
assert cast_arg.function_name == "lower"
Loading