Skip to content

Commit fba02c9

Browse files
committed
implements case when
1 parent ad71c92 commit fba02c9

File tree

8 files changed

+416
-34
lines changed

8 files changed

+416
-34
lines changed

es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,4 +1888,106 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
18881888
.replaceAll(",LocalDate", ", LocalDate")
18891889
.replaceAll("=DateTimeFormatter", " = DateTimeFormatter")
18901890
}
1891+
1892+
it should "handle case function as script field" in {
1893+
val select: ElasticSearchRequest =
1894+
SQLQuery(caseWhen)
1895+
val query = select.query
1896+
println(query)
1897+
query shouldBe
1898+
"""{
1899+
| "query": {
1900+
| "match_all": {}
1901+
| },
1902+
| "script_fields": {
1903+
| "c": {
1904+
| "script": {
1905+
| "lang": "painless",
1906+
| "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }"
1907+
| }
1908+
| }
1909+
| },
1910+
| "_source": {
1911+
| "includes": [
1912+
| "identifier"
1913+
| ]
1914+
| }
1915+
|}""".stripMargin
1916+
.replaceAll("\\s+", "")
1917+
.replaceAll("defv", " def v")
1918+
.replaceAll("defd", " def d")
1919+
.replaceAll("defe", " def e")
1920+
.replaceAll("defl", " def l")
1921+
.replaceAll("if\\(", "if (")
1922+
.replaceAll("\\{if", "{ if")
1923+
.replaceAll("=\\(", " = (")
1924+
.replaceAll("\\?", " ? ")
1925+
.replaceAll(":null", " : null")
1926+
.replaceAll("null:", "null : ")
1927+
.replaceAll("false:", "false : ")
1928+
.replaceAll("return", " return ")
1929+
.replaceAll("between\\(s,", "between(s, ")
1930+
.replaceAll(";", "; ")
1931+
.replaceAll(";if", "; if")
1932+
.replaceAll("==", " == ")
1933+
.replaceAll("!=", " != ")
1934+
.replaceAll("&&", " && ")
1935+
.replaceAll("\\|\\|", " || ")
1936+
.replaceAll(";\\s\\s", "; ")
1937+
.replaceAll(">", " > ")
1938+
.replaceAll("if \\(\\s*def", "if (def")
1939+
.replaceAll("ChronoUnit", " ChronoUnit")
1940+
}
1941+
1942+
it should "handle case with expression function as script field" in {
1943+
val select: ElasticSearchRequest =
1944+
SQLQuery(caseWhenExpr)
1945+
val query = select.query
1946+
println(query)
1947+
query shouldBe
1948+
"""{
1949+
| "query": {
1950+
| "match_all": {}
1951+
| },
1952+
| "script_fields": {
1953+
| "c": {
1954+
| "script": {
1955+
| "lang": "painless",
1956+
| "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS); def val0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); if (expr == val0) return val0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }"
1957+
| }
1958+
| }
1959+
| },
1960+
| "_source": {
1961+
| "includes": [
1962+
| "identifier"
1963+
| ]
1964+
| }
1965+
|}""".stripMargin
1966+
.replaceAll("\\s+", "")
1967+
.replaceAll("defv", " def v")
1968+
.replaceAll("defd", " def d")
1969+
.replaceAll("defe", " def e")
1970+
.replaceAll("defl", " def l")
1971+
.replaceAll("if\\(", "if (")
1972+
.replaceAll("\\{if", "{ if")
1973+
.replaceAll("=\\(", " = (")
1974+
.replaceAll("\\?", " ? ")
1975+
.replaceAll(":null", " : null")
1976+
.replaceAll("null:", "null : ")
1977+
.replaceAll("false:", "false : ")
1978+
.replaceAll("return", " return ")
1979+
.replaceAll("between\\(s,", "between(s, ")
1980+
.replaceAll(";", "; ")
1981+
.replaceAll(";if", "; if")
1982+
.replaceAll("==", " == ")
1983+
.replaceAll("!=", " != ")
1984+
.replaceAll("&&", " && ")
1985+
.replaceAll("\\|\\|", " || ")
1986+
.replaceAll(";\\s\\s", "; ")
1987+
.replaceAll(">", " > ")
1988+
.replaceAll("if \\(\\s*def", "if (def")
1989+
.replaceAll("ChronoUnit", " ChronoUnit")
1990+
.replaceAll("=ZonedDateTime", " = ZonedDateTime")
1991+
}
1992+
18911993
}

sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,4 +1877,106 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
18771877
.replaceAll(",LocalDate", ", LocalDate")
18781878
.replaceAll("=DateTimeFormatter", " = DateTimeFormatter")
18791879
}
1880+
1881+
it should "handle case function as script field" in {
1882+
val select: ElasticSearchRequest =
1883+
SQLQuery(caseWhen)
1884+
val query = select.query
1885+
println(query)
1886+
query shouldBe
1887+
"""{
1888+
| "query": {
1889+
| "match_all": {}
1890+
| },
1891+
| "script_fields": {
1892+
| "c": {
1893+
| "script": {
1894+
| "lang": "painless",
1895+
| "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }"
1896+
| }
1897+
| }
1898+
| },
1899+
| "_source": {
1900+
| "includes": [
1901+
| "identifier"
1902+
| ]
1903+
| }
1904+
|}""".stripMargin
1905+
.replaceAll("\\s+", "")
1906+
.replaceAll("defv", " def v")
1907+
.replaceAll("defd", " def d")
1908+
.replaceAll("defe", " def e")
1909+
.replaceAll("defl", " def l")
1910+
.replaceAll("if\\(", "if (")
1911+
.replaceAll("\\{if", "{ if")
1912+
.replaceAll("=\\(", " = (")
1913+
.replaceAll("\\?", " ? ")
1914+
.replaceAll(":null", " : null")
1915+
.replaceAll("null:", "null : ")
1916+
.replaceAll("false:", "false : ")
1917+
.replaceAll("return", " return ")
1918+
.replaceAll("between\\(s,", "between(s, ")
1919+
.replaceAll(";", "; ")
1920+
.replaceAll(";if", "; if")
1921+
.replaceAll("==", " == ")
1922+
.replaceAll("!=", " != ")
1923+
.replaceAll("&&", " && ")
1924+
.replaceAll("\\|\\|", " || ")
1925+
.replaceAll(";\\s\\s", "; ")
1926+
.replaceAll(">", " > ")
1927+
.replaceAll("if \\(\\s*def", "if (def")
1928+
.replaceAll("ChronoUnit", " ChronoUnit")
1929+
}
1930+
1931+
it should "handle case with expression function as script field" in {
1932+
val select: ElasticSearchRequest =
1933+
SQLQuery(caseWhenExpr)
1934+
val query = select.query
1935+
println(query)
1936+
query shouldBe
1937+
"""{
1938+
| "query": {
1939+
| "match_all": {}
1940+
| },
1941+
| "script_fields": {
1942+
| "c": {
1943+
| "script": {
1944+
| "lang": "painless",
1945+
| "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS); def val0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); if (expr == val0) return val0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }"
1946+
| }
1947+
| }
1948+
| },
1949+
| "_source": {
1950+
| "includes": [
1951+
| "identifier"
1952+
| ]
1953+
| }
1954+
|}""".stripMargin
1955+
.replaceAll("\\s+", "")
1956+
.replaceAll("defv", " def v")
1957+
.replaceAll("defd", " def d")
1958+
.replaceAll("defe", " def e")
1959+
.replaceAll("defl", " def l")
1960+
.replaceAll("if\\(", "if (")
1961+
.replaceAll("\\{if", "{ if")
1962+
.replaceAll("=\\(", " = (")
1963+
.replaceAll("\\?", " ? ")
1964+
.replaceAll(":null", " : null")
1965+
.replaceAll("null:", "null : ")
1966+
.replaceAll("false:", "false : ")
1967+
.replaceAll("return", " return ")
1968+
.replaceAll("between\\(s,", "between(s, ")
1969+
.replaceAll(";", "; ")
1970+
.replaceAll(";if", "; if")
1971+
.replaceAll("==", " == ")
1972+
.replaceAll("!=", " != ")
1973+
.replaceAll("&&", " && ")
1974+
.replaceAll("\\|\\|", " || ")
1975+
.replaceAll(";\\s\\s", "; ")
1976+
.replaceAll(">", " > ")
1977+
.replaceAll("if \\(\\s*def", "if (def")
1978+
.replaceAll("ChronoUnit", " ChronoUnit")
1979+
.replaceAll("=ZonedDateTime", " = ZonedDateTime")
1980+
}
1981+
18801982
}

sql/src/main/scala/app/softnetwork/elastic/sql/SQLDelimiter.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@ sealed trait SQLDelimiter extends SQLToken
44

55
sealed trait StartDelimiter extends SQLDelimiter
66
case object StartPredicate extends SQLExpr("(") with StartDelimiter
7+
case object StartCase extends SQLExpr("case") with StartDelimiter
8+
case object WhenCase extends SQLExpr("when") with StartDelimiter
79

810
sealed trait EndDelimiter extends SQLDelimiter
911
case object EndPredicate extends SQLExpr(")") with EndDelimiter
1012
case object Separator extends SQLExpr(",") with EndDelimiter
13+
case object EndCase extends SQLExpr("end") with EndDelimiter
14+
case object ThenCase extends SQLExpr("then") with EndDelimiter

sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala

Lines changed: 110 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
package app.softnetwork.elastic.sql
22

3-
import scala.util.Try
43
import scala.util.matching.Regex
54

65
sealed trait SQLFunction extends SQLRegex {
76
def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql
87
def applyType(in: SQLType): SQLType = out
9-
var expr: SQLToken = SQLNull
10-
def applyTo(expr: SQLToken): Unit = {
11-
this.expr = expr
8+
private[this] var _expr: SQLToken = SQLNull
9+
def expr_=(e: SQLToken): Unit = {
10+
_expr = e
1211
}
13-
override def nullable: Boolean = Try(expr.nullable).getOrElse(true)
12+
def expr: SQLToken = _expr
13+
override def nullable: Boolean = expr.nullable
1414
}
1515

1616
sealed trait SQLFunctionWithIdentifier extends SQLFunction {
@@ -72,11 +72,10 @@ trait SQLFunctionChain extends SQLFunction {
7272

7373
override def system: Boolean = functions.lastOption.exists(_.system)
7474

75-
override def applyTo(expr: SQLToken): Unit = {
76-
super.applyTo(expr)
77-
val orderedFunctions = functions.reverse
78-
orderedFunctions.foldLeft(expr) { (currentExpr, fun) =>
79-
fun.applyTo(currentExpr)
75+
def applyTo(expr: SQLToken): Unit = {
76+
this.expr = expr
77+
functions.reverse.foldLeft(expr) { (currentExpr, fun) =>
78+
fun.expr = currentExpr
8079
fun
8180
}
8281
}
@@ -662,8 +661,105 @@ case class SQLCast(value: PainlessScript, targetType: SQLType, as: Boolean = tru
662661

663662
override def toPainless(base: String, idx: Int): String =
664663
SQLTypeUtils.coerce(base, value.out, targetType, value.nullable)
665-
/*if (nullable)
666-
s"(def e$idx = $base; e$idx != null ? ${SQLTypeUtils.coerce(s"e$idx", value.out, out, nullable = false)}$painless : null)"
667-
else
668-
s"${SQLTypeUtils.coerce(base, value.out, targetType, nullable = value.nullable)}$painless"*/
664+
}
665+
666+
case class SQLCaseWhen(
667+
expression: Option[PainlessScript],
668+
conditions: List[(PainlessScript, PainlessScript)],
669+
default: Option[PainlessScript]
670+
) extends SQLTransformFunction[SQLAny, SQLAny] {
671+
override def inputType: SQLAny = SQLTypes.Any
672+
override def outputType: SQLAny = SQLTypes.Any
673+
674+
override def sql: String = {
675+
val exprPart = expression.map(e => s"$Case ${e.sql}").getOrElse(Case.sql)
676+
val whenThen = conditions
677+
.map { case (cond, res) => s"$When ${cond.sql} $Then ${res.sql}" }
678+
.mkString(" ")
679+
val elsePart = default.map(d => s" $Else ${d.sql}").getOrElse("")
680+
s"$exprPart $whenThen$elsePart $End"
681+
}
682+
683+
override def out: SQLType =
684+
SQLTypeUtils.leastCommonSuperType(
685+
conditions.map(_._2.out) ++ default.map(_.out).toList
686+
)
687+
688+
override def applyType(in: SQLType): SQLType = out
689+
690+
override def validate(): Either[String, Unit] = {
691+
if (conditions.isEmpty) Left("CASE WHEN requires at least one condition")
692+
else if (
693+
expression.isEmpty && conditions.exists { case (cond, _) => cond.out != SQLTypes.Boolean }
694+
)
695+
Left("CASE WHEN conditions must be of type BOOLEAN")
696+
else if (
697+
expression.isDefined && conditions.exists { case (cond, _) =>
698+
!SQLTypeUtils.matches(cond.out, expression.get.out)
699+
}
700+
)
701+
Left("CASE WHEN conditions must be of the same type as the expression")
702+
else Right(())
703+
}
704+
705+
override def painless: String = {
706+
val base =
707+
expression match {
708+
case Some(expr) =>
709+
s"def expr = ${SQLTypeUtils.coerce(expr, expr.out)}; "
710+
case _ => ""
711+
}
712+
val cases = conditions.zipWithIndex
713+
.map { case ((cond, res), zindex) =>
714+
expression match {
715+
case Some(expr) =>
716+
val c = SQLTypeUtils.coerce(cond, expr.out)
717+
if (cond.sql == res.sql) {
718+
s"def val$zindex = $c; if (expr == val$zindex) return val$zindex;"
719+
} else {
720+
val _res = {
721+
res match {
722+
case i: Identifier =>
723+
val name = i.name
724+
cond match {
725+
case e: Expression if e.identifier.name == name =>
726+
e.identifier.nullable = false
727+
e
728+
case i: Identifier if i.name == name =>
729+
i.nullable = false
730+
i
731+
case _ => res
732+
}
733+
case _ => res
734+
}
735+
}
736+
val r = SQLTypeUtils.coerce(_res, out)
737+
s"if (expr == $c) return $r;"
738+
}
739+
case None =>
740+
val c = SQLTypeUtils.coerce(cond, SQLTypes.Boolean)
741+
val r =
742+
cond match {
743+
case e: Expression =>
744+
val name = e.identifier.name
745+
res match {
746+
case i: Identifier if i.name == name => "left"
747+
case _ => SQLTypeUtils.coerce(res, out)
748+
}
749+
case _ => SQLTypeUtils.coerce(res, out)
750+
}
751+
s"if ($c) return $r;"
752+
}
753+
}
754+
.mkString(" ")
755+
val defaultCase = default
756+
.map(d => s"def dval = ${SQLTypeUtils.coerce(d, out)}; return dval;")
757+
.getOrElse("return null;")
758+
s"{ $base$cases $defaultCase }"
759+
}
760+
761+
override def toPainless(base: String, idx: Int): String = s"$base$painless"
762+
763+
override def nullable: Boolean =
764+
conditions.exists { case (_, res) => res.nullable } || default.forall(_.nullable)
669765
}

0 commit comments

Comments
 (0)