Skip to content

Commit 070c990

Browse files
committed
Use dialect-specific constant for position overflow semantics
1 parent 8ea3299 commit 070c990

File tree

7 files changed

+52
-8
lines changed

7 files changed

+52
-8
lines changed

sqlglot/dialects/bigquery.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,8 @@ def _build_datetime(args: t.List) -> exp.Func:
244244

245245
def _build_regexp_extract(
246246
expr_type: t.Type[E], default_group: t.Optional[exp.Expression] = None
247-
) -> t.Callable[[t.List], E]:
248-
def _builder(args: t.List) -> E:
247+
) -> t.Callable[[t.List, BigQuery], E]:
248+
def _builder(args: t.List, dialect: BigQuery) -> E:
249249
try:
250250
group = re.compile(args[1].name).groups == 1
251251
except re.error:
@@ -258,6 +258,11 @@ def _builder(args: t.List) -> E:
258258
position=seq_get(args, 2),
259259
occurrence=seq_get(args, 3),
260260
group=exp.Literal.number(1) if group else default_group,
261+
**(
262+
{"null_if_pos_overflow": dialect.REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL}
263+
if expr_type is exp.RegexpExtract
264+
else {}
265+
),
261266
)
262267

263268
return _builder

sqlglot/dialects/dialect.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,9 @@ class Dialect(metaclass=_Dialect):
638638
REGEXP_EXTRACT_DEFAULT_GROUP = 0
639639
"""The default value for the capturing group."""
640640

641+
REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL = True
642+
"""Whether REGEXP_EXTRACT returns NULL when the position arg exceeds the string length."""
643+
641644
SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
642645
exp.Except: True,
643646
exp.Intersect: True,
@@ -1965,11 +1968,21 @@ def _builder(args: t.List) -> exp.Expression:
19651968

19661969
def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
19671970
def _builder(args: t.List, dialect: Dialect) -> E:
1971+
# The "position" argument specifies the index of the character in the string at which matching starts.
1972+
# `null_if_pos_overflow` reflects the dialect's behavior when position is greater than the string
1973+
# length: if true, the function returns NULL; if false, it returns an empty string. `null_if_pos_overflow` is
1974+
# only needed for exp.RegexpExtract - exp.RegexpExtractAll always returns an empty array if
1975+
# position overflows.
19681976
return expr_type(
19691977
this=seq_get(args, 0),
19701978
expression=seq_get(args, 1),
19711979
group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP),
19721980
parameters=seq_get(args, 3),
1981+
**(
1982+
{"null_if_pos_overflow": dialect.REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL}
1983+
if expr_type is exp.RegexpExtract
1984+
else {}
1985+
),
19731986
)
19741987

19751988
return _builder

sqlglot/dialects/duckdb.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,11 +1448,13 @@ def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
14481448
params = expression.args.get("parameters")
14491449
position = expression.args.get("position")
14501450
occurrence = expression.args.get("occurrence")
1451+
null_if_pos_overflow = expression.args.get("null_if_pos_overflow")
1452+
14511453
if position and (not position.is_int or position.to_py() > 1):
1452-
# substring returns '' if position > len(string), but the '' shouldn't carry through to REGEXP_EXTRACT
1453-
this = exp.Nullif(
1454-
this=exp.Substring(this=this, start=position), expression=exp.Literal.string("")
1455-
)
1454+
this = exp.Substring(this=this, start=position)
1455+
1456+
if null_if_pos_overflow:
1457+
this = exp.Nullif(this=this, expression=exp.Literal.string(""))
14561458

14571459
# Do not render group if there is no following argument,
14581460
# and it's the default value for this dialect

sqlglot/dialects/redshift.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class Redshift(Postgres):
4848
HEX_LOWERCASE = True
4949
HAS_DISTINCT_ARRAY_CONSTRUCTORS = True
5050
COALESCE_COMPARISON_NON_STANDARD = True
51+
REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL = False
5152

5253
# ref: https://docs.aws.amazon.com/redshift/latest/dg/r_FORMAT_strings.html
5354
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
@@ -69,6 +70,13 @@ class Parser(Postgres.Parser):
6970
"DATE_DIFF": _build_date_delta(exp.TsOrDsDiff),
7071
"GETDATE": exp.CurrentTimestamp.from_arg_list,
7172
"LISTAGG": exp.GroupConcat.from_arg_list,
73+
"REGEXP_SUBSTR": lambda args: exp.RegexpExtract(
74+
this=seq_get(args, 0),
75+
expression=seq_get(args, 1),
76+
position=seq_get(args, 2),
77+
occurrence=seq_get(args, 3),
78+
parameters=seq_get(args, 4),
79+
),
7280
"SPLIT_TO_ARRAY": lambda args: exp.StringToArray(
7381
this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string(",")
7482
),
@@ -201,6 +209,7 @@ class Generator(Postgres.Generator):
201209
exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
202210
exp.GroupConcat: rename_func("LISTAGG"),
203211
exp.Hex: lambda self, e: self.func("UPPER", self.func("TO_HEX", self.sql(e, "this"))),
212+
exp.RegexpExtract: rename_func("REGEXP_SUBSTR"),
204213
exp.Select: transforms.preprocess(
205214
[
206215
transforms.eliminate_window_clause,

sqlglot/dialects/snowflake.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,15 +344,20 @@ def _transform_generate_date_array(expression: exp.Expression) -> exp.Expression
344344
return expression
345345

346346

347-
def _build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
348-
def _builder(args: t.List) -> E:
347+
def _build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Snowflake], E]:
348+
def _builder(args: t.List, dialect: Snowflake) -> E:
349349
return expr_type(
350350
this=seq_get(args, 0),
351351
expression=seq_get(args, 1),
352352
position=seq_get(args, 2),
353353
occurrence=seq_get(args, 3),
354354
parameters=seq_get(args, 4),
355355
group=seq_get(args, 5) or exp.Literal.number(0),
356+
**(
357+
{"null_if_pos_overflow": dialect.REGEXP_EXTRACT_POSITION_OVERFLOW_RETURNS_NULL}
358+
if expr_type is exp.RegexpExtract
359+
else {}
360+
),
356361
)
357362

358363
return _builder

sqlglot/expressions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7582,6 +7582,7 @@ class RegexpExtract(Func):
75827582
"occurrence": False,
75837583
"parameters": False,
75847584
"group": False,
7585+
"null_if_pos_overflow": False, # for transpilation target behavior
75857586
}
75867587

75877588

tests/dialects/test_redshift.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,3 +724,12 @@ def test_fetch_to_limit(self):
724724
"postgres": "SELECT * FROM t FETCH FIRST 1 ROWS ONLY",
725725
},
726726
)
727+
728+
def test_regexp_extract(self):
729+
self.validate_all(
730+
"SELECT REGEXP_SUBSTR(abc, 'pattern(group)', 2) FROM table",
731+
write={
732+
"redshift": '''SELECT REGEXP_SUBSTR(abc, 'pattern(group)', 2) FROM "table"''',
733+
"duckdb": '''SELECT REGEXP_EXTRACT(SUBSTRING(abc, 2), 'pattern(group)') FROM "table"''',
734+
},
735+
)

0 commit comments

Comments
 (0)