Skip to content

Commit 918dc7c

Browse files
committed
fix painless script for nullif function for temporal
1 parent 0aaaf71 commit 918dc7c

File tree

5 files changed

+20
-21
lines changed

5 files changed

+20
-21
lines changed

es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,7 +1958,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
19581958
| "c": {
19591959
| "script": {
19601960
| "lang": "painless",
1961-
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); def param2 = LocalDate.parse(\"2025-09-11\", DateTimeFormatter.ofPattern(\"yyyy-MM-dd\")).minus(2, ChronoUnit.DAYS); def param3 = param1 == param2 ? null : param1; def param4 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); param3 != null ? param3 : param4"
1961+
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); def param2 = LocalDate.parse(\"2025-09-11\", DateTimeFormatter.ofPattern(\"yyyy-MM-dd\")).minus(2, ChronoUnit.DAYS); def param3 = param1 == null || param1.isEqual(param2) ? null : param1; def param4 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); param3 != null ? param3 : param4"
19621962
| }
19631963
| }
19641964
| },
@@ -2016,7 +2016,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
20162016
| "c": {
20172017
| "script": {
20182018
| "lang": "painless",
2019-
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); def param2 = LocalDate.parse(\"2025-09-11\", DateTimeFormatter.ofPattern(\"yyyy-MM-dd\")); def param3 = param1 == param2 ? null : param1; def param4 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(2, ChronoUnit.HOURS); try { param3 != null ? param3 : param4 } catch (Exception e) { return null; }"
2019+
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); def param2 = LocalDate.parse(\"2025-09-11\", DateTimeFormatter.ofPattern(\"yyyy-MM-dd\")); def param3 = param1 == null || param1.isEqual(param2) ? null : param1; def param4 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(2, ChronoUnit.HOURS); try { param3 != null ? param3 : param4 } catch (Exception e) { return null; }"
20202020
| }
20212021
| },
20222022
| "c2": {

sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,7 +1962,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
19621962
| "c": {
19631963
| "script": {
19641964
| "lang": "painless",
1965-
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); def param2 = LocalDate.parse(\"2025-09-11\", DateTimeFormatter.ofPattern(\"yyyy-MM-dd\")).minus(2, ChronoUnit.DAYS); def param3 = param1 == param2 ? null : param1; def param4 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); param3 != null ? param3 : param4"
1965+
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); def param2 = LocalDate.parse(\"2025-09-11\", DateTimeFormatter.ofPattern(\"yyyy-MM-dd\")).minus(2, ChronoUnit.DAYS); def param3 = param1 == null || param1.isEqual(param2) ? null : param1; def param4 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); param3 != null ? param3 : param4"
19661966
| }
19671967
| }
19681968
| },
@@ -2020,7 +2020,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers {
20202020
| "c": {
20212021
| "script": {
20222022
| "lang": "painless",
2023-
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); def param2 = LocalDate.parse(\"2025-09-11\", DateTimeFormatter.ofPattern(\"yyyy-MM-dd\")); def param3 = param1 == param2 ? null : param1; def param4 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(2, ChronoUnit.HOURS); try { param3 != null ? param3 : param4 } catch (Exception e) { return null; }"
2023+
| "source": "def param1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value.toLocalDate()); def param2 = LocalDate.parse(\"2025-09-11\", DateTimeFormatter.ofPattern(\"yyyy-MM-dd\")); def param3 = param1 == null || param1.isEqual(param2) ? null : param1; def param4 = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(2, ChronoUnit.HOURS); try { param3 != null ? param3 : param4 } catch (Exception e) { return null; }"
20242024
| }
20252025
| },
20262026
| "c2": {

sql/src/main/scala/app/softnetwork/elastic/sql/function/cond/package.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@ import app.softnetwork.elastic.sql.{
2424
PainlessScript,
2525
TokenRegex
2626
}
27-
import app.softnetwork.elastic.sql.`type`.{SQLAny, SQLBool, SQLType, SQLTypeUtils, SQLTypes}
27+
import app.softnetwork.elastic.sql.`type`.{
28+
SQLAny,
29+
SQLBool,
30+
SQLTemporal,
31+
SQLType,
32+
SQLTypeUtils,
33+
SQLTypes
34+
}
2835
import app.softnetwork.elastic.sql.parser.Validator
2936
import app.softnetwork.elastic.sql.query.{CriteriaWithConditionalFunction, Expression}
3037

@@ -117,8 +124,6 @@ package object cond {
117124
// Reprend l’idée de SQLValues mais pour n’importe quel token
118125
override def baseType: SQLType = SQLTypeUtils.leastCommonSuperType(argTypes)
119126

120-
override def applyType(in: SQLType): SQLType = baseType
121-
122127
override def validate(): Either[String, Unit] = {
123128
if (values.isEmpty) Left("COALESCE requires at least one argument")
124129
else Right(())
@@ -153,8 +158,6 @@ package object cond {
153158

154159
override def baseType: SQLType = SQLTypeUtils.leastCommonSuperType(argTypes)
155160

156-
override def applyType(in: SQLType): SQLType = baseType
157-
158161
private[this] def checkIfExpressionNullable(expr: PainlessScript): Boolean = expr match {
159162
case f: FunctionChain if f.functions.nonEmpty => true
160163
case _ => false
@@ -170,7 +173,12 @@ package object cond {
170173
callArgs match {
171174
case List(arg0, arg1) =>
172175
val expr =
173-
s"${arg0.trim} == ${arg1.trim} ? null : $arg0"
176+
out match {
177+
case SQLTypes.Varchar =>
178+
s"$arg0 == null || $arg0.compareTo($arg1) == 0 ? null : $arg0"
179+
case _: SQLTemporal => s"$arg0 == null || $arg0.isEqual($arg1) ? null : $arg0"
180+
case _ => s"$arg0 == $arg1 ? null : $arg0"
181+
}
174182
context match {
175183
case Some(ctx) =>
176184
ctx.addParam(LiteralParam(expr)) match {
@@ -214,10 +222,7 @@ package object cond {
214222
s"$exprPart $whenThen$elsePart $END"
215223
}
216224

217-
override def baseType: SQLType =
218-
SQLTypeUtils.leastCommonSuperType(argTypes)
219-
220-
override def applyType(in: SQLType): SQLType = baseType
225+
override def baseType: SQLType = SQLTypeUtils.leastCommonSuperType(argTypes)
221226

222227
override def validate(): Either[String, Unit] = {
223228
if (conditions.isEmpty) Left("CASE WHEN requires at least one condition")

sql/src/main/scala/app/softnetwork/elastic/sql/function/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ package object function {
145145
override def in: SQLType = inputType
146146
override def baseType: SQLType = outputType
147147

148-
override def applyType(in: SQLType): SQLType = outputType
148+
override def applyType(in: SQLType): SQLType = baseType
149149

150150
override def sql: String =
151151
s"${fun.map(_.sql).getOrElse("")}(${args.map(_.sql).mkString(argsSeparator)})"

sql/src/main/scala/app/softnetwork/elastic/sql/function/time/package.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,11 @@ package object time {
6262
case _ => None
6363
}
6464

65-
//private[this] var _out: SQLType = outputType
66-
67-
//override def out: SQLType = _out
68-
6965
override def applyType(in: SQLType): SQLType = {
7066
interval.checkType(in) match {
7167
case Left(_) => baseType
7268
case Right(_) => cast(in)
7369
}
74-
//_out = interval.checkType(in).getOrElse(out)
75-
//_out
7670
}
7771

7872
override def validate(): Either[String, Unit] = interval.checkType(out) match {

0 commit comments

Comments
 (0)