Skip to content

Commit 2100d82

Browse files
committed
get rid of ParametrizedFunction, fix extract function, fix cast with interval function
1 parent 9906da0 commit 2100d82

File tree

5 files changed

+175
-33
lines changed

5 files changed

+175
-33
lines changed

es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1953,7 +1953,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
19531953
| "c": {
19541954
| "script": {
19551955
| "lang": "painless",
1956-
| "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? ((e0).atStartOfDay(ZoneId.of('Z')).minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }"
1956+
| "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? (e0.minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }"
19571957
| }
19581958
| }
19591959
| },
@@ -1991,4 +1991,74 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
19911991
.replaceAll("=e", " = e")
19921992
}
19931993

1994+
it should "handle extract function as script field" in {
1995+
val select: ElasticSearchRequest =
1996+
SQLQuery(extract)
1997+
val query = select.query
1998+
println(query)
1999+
query shouldBe
2000+
"""{
2001+
| "query": {
2002+
| "match_all": {}
2003+
| },
2004+
| "script_fields": {
2005+
| "day": {
2006+
| "script": {
2007+
| "lang": "painless",
2008+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.DAYS) : null)"
2009+
| }
2010+
| },
2011+
| "month": {
2012+
| "script": {
2013+
| "lang": "painless",
2014+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.MONTHS) : null)"
2015+
| }
2016+
| },
2017+
| "year": {
2018+
| "script": {
2019+
| "lang": "painless",
2020+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.YEARS) : null)"
2021+
| }
2022+
| },
2023+
| "hour": {
2024+
| "script": {
2025+
| "lang": "painless",
2026+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.HOURS) : null)"
2027+
| }
2028+
| },
2029+
| "minute": {
2030+
| "script": {
2031+
| "lang": "painless",
2032+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.MINUTES) : null)"
2033+
| }
2034+
| },
2035+
| "second": {
2036+
| "script": {
2037+
| "lang": "painless",
2038+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.SECONDS) : null)"
2039+
| }
2040+
| }
2041+
| },
2042+
| "_source": true
2043+
|}""".stripMargin
2044+
.replaceAll("\\s+", "")
2045+
.replaceAll("defe", "def e")
2046+
.replaceAll("if\\(", "if (")
2047+
.replaceAll("=\\(", " = (")
2048+
.replaceAll("\\?", " ? ")
2049+
.replaceAll(":null", " : null")
2050+
.replaceAll("null:", "null : ")
2051+
.replaceAll("return", " return ")
2052+
.replaceAll("between\\(s,", "between(s, ")
2053+
.replaceAll(";", "; ")
2054+
.replaceAll(";if", "; if")
2055+
.replaceAll("==", " == ")
2056+
.replaceAll("!=", " != ")
2057+
.replaceAll("&&", " && ")
2058+
.replaceAll("\\|\\|", " || ")
2059+
.replaceAll(";\\s\\s", "; ")
2060+
.replaceAll(">", " > ")
2061+
.replaceAll("if \\(\\s*def", "if (def")
2062+
}
2063+
19942064
}

sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1942,7 +1942,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
19421942
| "c": {
19431943
| "script": {
19441944
| "lang": "painless",
1945-
| "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? ((e0).atStartOfDay(ZoneId.of('Z')).minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }"
1945+
| "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? (e0.minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }"
19461946
| }
19471947
| }
19481948
| },
@@ -1980,4 +1980,74 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
19801980
.replaceAll("=e", " = e")
19811981
}
19821982

1983+
it should "handle extract function as script field" in {
1984+
val select: ElasticSearchRequest =
1985+
SQLQuery(extract)
1986+
val query = select.query
1987+
println(query)
1988+
query shouldBe
1989+
"""{
1990+
| "query": {
1991+
| "match_all": {}
1992+
| },
1993+
| "script_fields": {
1994+
| "day": {
1995+
| "script": {
1996+
| "lang": "painless",
1997+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.DAYS) : null)"
1998+
| }
1999+
| },
2000+
| "month": {
2001+
| "script": {
2002+
| "lang": "painless",
2003+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.MONTHS) : null)"
2004+
| }
2005+
| },
2006+
| "year": {
2007+
| "script": {
2008+
| "lang": "painless",
2009+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.YEARS) : null)"
2010+
| }
2011+
| },
2012+
| "hour": {
2013+
| "script": {
2014+
| "lang": "painless",
2015+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.HOURS) : null)"
2016+
| }
2017+
| },
2018+
| "minute": {
2019+
| "script": {
2020+
| "lang": "painless",
2021+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.MINUTES) : null)"
2022+
| }
2023+
| },
2024+
| "second": {
2025+
| "script": {
2026+
| "lang": "painless",
2027+
| "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.SECONDS) : null)"
2028+
| }
2029+
| }
2030+
| },
2031+
| "_source": true
2032+
|}""".stripMargin
2033+
.replaceAll("\\s+", "")
2034+
.replaceAll("defe", "def e")
2035+
.replaceAll("if\\(", "if (")
2036+
.replaceAll("=\\(", " = (")
2037+
.replaceAll("\\?", " ? ")
2038+
.replaceAll(":null", " : null")
2039+
.replaceAll("null:", "null : ")
2040+
.replaceAll("return", " return ")
2041+
.replaceAll("between\\(s,", "between(s, ")
2042+
.replaceAll(";", "; ")
2043+
.replaceAll(";if", "; if")
2044+
.replaceAll("==", " == ")
2045+
.replaceAll("!=", " != ")
2046+
.replaceAll("&&", " && ")
2047+
.replaceAll("\\|\\|", " || ")
2048+
.replaceAll(";\\s\\s", "; ")
2049+
.replaceAll(">", " > ")
2050+
.replaceAll("if \\(\\s*def", "if (def")
2051+
}
2052+
19832053
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -137,18 +137,6 @@ sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType]
137137
override def applyType(in: SQLType): SQLType = in
138138
}
139139

140-
sealed trait ParametrizedFunction extends SQLFunction {
141-
def params: Seq[String]
142-
override def toSQL(base: String): String = {
143-
params match {
144-
case Nil => s"$sql($base)"
145-
case _ =>
146-
val paramsStr = params.mkString(", ")
147-
s"$sql($paramsStr)($base)"
148-
}
149-
}
150-
}
151-
152140
sealed trait AggregateFunction extends SQLFunction
153141
case object Count extends SQLExpr("count") with AggregateFunction
154142
case object Min extends SQLExpr("min") with AggregateFunction
@@ -367,36 +355,35 @@ case class DateTrunc(identifier: SQLIdentifier, unit: TimeUnit)
367355
case class Extract(unit: TimeUnit, override val sql: String = "extract")
368356
extends SQLExpr(sql)
369357
with DateTimeFunction
370-
with SQLTransformFunction[SQLTemporal, SQLNumeric]
371-
with ParametrizedFunction {
358+
with SQLTransformFunction[SQLTemporal, SQLNumeric] {
372359
override def inputType: SQLTemporal = SQLTypes.Temporal
373360
override def outputType: SQLNumeric = SQLTypes.Numeric
374-
override def params: Seq[String] = Seq(unit.sql)
361+
override def toSQL(base: String): String = s"$sql(${unit.sql} from $base)"
375362
override def painless: String = s".get(${unit.painless})"
376363
}
377364

378365
object YEAR extends Extract(Year, Year.sql) {
379-
override def params: Seq[String] = Seq.empty
366+
override def toSQL(base: String): String = s"$sql($base)"
380367
}
381368

382369
object MONTH extends Extract(Month, Month.sql) {
383-
override def params: Seq[String] = Seq.empty
370+
override def toSQL(base: String): String = s"$sql($base)"
384371
}
385372

386373
object DAY extends Extract(Day, Day.sql) {
387-
override def params: Seq[String] = Seq.empty
374+
override def toSQL(base: String): String = s"$sql($base)"
388375
}
389376

390377
object HOUR extends Extract(Hour, Hour.sql) {
391-
override def params: Seq[String] = Seq.empty
378+
override def toSQL(base: String): String = s"$sql($base)"
392379
}
393380

394381
object MINUTE extends Extract(Minute, Minute.sql) {
395-
override def params: Seq[String] = Seq.empty
382+
override def toSQL(base: String): String = s"$sql($base)"
396383
}
397384

398385
object SECOND extends Extract(Second, Second.sql) {
399-
override def params: Seq[String] = Seq.empty
386+
override def toSQL(base: String): String = s"$sql($base)"
400387
}
401388

402389
case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit)
@@ -732,7 +719,8 @@ case class SQLCaseWhen(
732719
s"def val$idx = $c; if (expr == val$idx) return ${SQLTypeUtils.coerce(i.toPainless(s"val$idx"), i.out, out, nullable = false)};"
733720
else {
734721
cond.asInstanceOf[Identifier].nullable = false
735-
s"def e$idx = ${i.checkNotNull}; def val$idx = e$idx != null ? ${SQLTypeUtils.coerce(cond.asInstanceOf[Identifier].toPainless(s"e$idx"), cond.out, out, nullable = false)} : null; if (expr == val$idx) return ${SQLTypeUtils
722+
s"def e$idx = ${i.checkNotNull}; def val$idx = e$idx != null ? ${SQLTypeUtils
723+
.coerce(cond.asInstanceOf[Identifier].toPainless(s"e$idx"), cond.out, out, nullable = false)} : null; if (expr == val$idx) return ${SQLTypeUtils
736724
.coerce(i.toPainless(s"e$idx"), i.out, out, nullable = false)};"
737725
}
738726
case _ =>

sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,10 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser =>
173173
DateTrunc(i, u)
174174
}
175175

176-
def extract: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] =
177-
"(?i)extract".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ =>
178-
Extract(u)
176+
def extract_identifier: PackratParser[SQLIdentifier] =
177+
"(?i)extract".r ~ start ~ time_unit ~ "(?i)from".r ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ end ^^ {
178+
case _ ~ _ ~ u ~ _ ~ i ~ _ =>
179+
i.copy(functions = Extract(u) +: i.functions)
179180
}
180181

181182
def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] =
@@ -197,7 +198,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser =>
197198
Second.regex ^^ (_ => SECOND)
198199

199200
def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] =
200-
extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second
201+
extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second
201202

202203
def date_add: PackratParser[DateFunction with SQLFunctionWithIdentifier] =
203204
"(?i)date_add".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ {
@@ -311,6 +312,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser =>
311312
// les plus spécifiques en premier
312313
identifierWithTransformation | // transformations appliquées à un identifier
313314
date_diff_identifier | // date_diff(...) retournant un identifier-like
315+
extract_identifier |
314316
identifierWithSystemFunction | // CURRENT_DATE, NOW, etc. (+/- interval)
315317
identifierWithArithmeticFunction | // foo - interval ...
316318
identifierWithTemporalFunction | // chaîne de fonctions appliquées à un identifier
@@ -483,10 +485,10 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser =>
483485
string_type | datetime_type | timestamp_type | date_type | time_type | boolean_type | long_type | double_type | int_type
484486

485487
private[this] def castFunctionWithIdentifier: PackratParser[SQLIdentifier] =
486-
"(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ~ arithmeticFunction.? ^^ {
488+
"(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ~ arithmeticFunction.? ^^ {
487489
case _ ~ _ ~ i ~ as ~ t ~ _ ~ a =>
488490
i.copy(functions =
489-
(SQLCast(i, targetType = t, as = as.isDefined) +: i.functions) ++ a.toList
491+
a.toList ++ (SQLCast(i, targetType = t, as = as.isDefined) +: i.functions)
490492
)
491493
}
492494

@@ -552,7 +554,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser =>
552554
def alias: PackratParser[SQLAlias] = Alias.regex.? ~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) }
553555

554556
def field: PackratParser[Field] =
555-
(identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | case_when_identifier | identifier) ~ alias.? ^^ {
557+
(identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | case_when_identifier | identifier) ~ alias.? ^^ {
556558
case i ~ a =>
557559
SQLField(i, a)
558560
}
@@ -612,7 +614,7 @@ trait SQLWhereParser {
612614
private def diff: PackratParser[SQLComparisonOperator] = Diff.sql ^^ (_ => Diff)
613615

614616
private def any_identifier: PackratParser[SQLIdentifier] =
615-
identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier
617+
identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier
616618

617619
private def equality: PackratParser[SQLExpression] =
618620
not.? ~ any_identifier ~ (eq | ne | diff) ~ (boolean | literal | double | long | any_identifier) ^^ {

sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ object Queries {
147147
"select case when lastUpdated > now - interval 7 day then lastUpdated when isnotnull(lastSeen) then lastSeen + interval 2 day else createdAt end as c, identifier from Table"
148148
val caseWhenExpr: String =
149149
"select case current_date - interval 7 day when cast(lastUpdated as date) - interval 3 day then lastUpdated when lastSeen then lastSeen + interval 2 day else createdAt end as c, identifier from Table"
150+
151+
val extract: String =
152+
"select extract(day from createdAt) as day, extract(month from createdAt) as month, extract(year from createdAt) as year, extract(hour from createdAt) as hour, extract(minute from createdAt) as minute, extract(second from createdAt) as second from Table"
150153
}
151154

152155
/** Created by smanciot on 15/02/17.
@@ -579,8 +582,17 @@ class SQLParserSpec extends AnyFlatSpec with Matchers {
579582

580583
it should "parse case when with expression" in {
581584
val result = SQLParser(caseWhenExpr)
585+
result.toOption
586+
.flatMap(_.left.toOption.map(_.sql))
587+
.getOrElse("")
588+
.equalsIgnoreCase(caseWhenExpr) shouldBe true
589+
}
590+
591+
it should "parse extract function" in {
592+
val result = SQLParser(extract)
582593
result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===(
583-
caseWhenExpr
594+
extract
584595
)
585596
}
597+
586598
}

0 commit comments

Comments
 (0)