Skip to content

Commit 1217cc8

Browse files
committed
update query validator to cache validated queries + update type compatibility check
1 parent 7de76bd commit 1217cc8

File tree

1 file changed

+66
-54
lines changed

1 file changed

+66
-54
lines changed

macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala

Lines changed: 66 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ trait SQLQueryValidator {
4242
// 1. Extract the SQL query (must be a literal)
4343
val sqlQuery = extractSQLString(c)(query)
4444

45+
// ✅ Check if already validated
46+
if (SQLQueryValidator.isCached(sqlQuery)) {
47+
debug(c)(s"✅ Query already validated (cached): $sqlQuery")
48+
return sqlQuery
49+
}
50+
4551
if (sys.props.get("elastic.sql.debug").contains("true")) {
4652
c.info(c.enclosingPosition, s"Validating SQL: $sqlQuery", force = false)
4753
}
@@ -75,6 +81,9 @@ trait SQLQueryValidator {
7581
debug(c)("=" * 80)
7682

7783
// 8. Return the validated request
84+
// ✅ Mark as validated
85+
SQLQueryValidator.markValidated(sqlQuery)
86+
7887
sqlQuery
7988
}
8089

@@ -167,17 +176,22 @@ trait SQLQueryValidator {
167176
// ============================================================
168177
// Reject SELECT * (incompatible with compile-time validation)
169178
// ============================================================
170-
private def rejectSelectStar(c: blackbox.Context)(
179+
private def rejectSelectStar[T: c.WeakTypeTag](c: blackbox.Context)(
171180
parsedQuery: SQLSearchRequest,
172181
sqlQuery: String
173182
): Unit = {
183+
import c.universe._
174184

175185
// Check if any field is a wildcard (*)
176186
val hasWildcard = parsedQuery.select.fields.exists { field =>
177187
field.identifier.name == "*"
178188
}
179189

180190
if (hasWildcard) {
191+
val tpe = weakTypeOf[T]
192+
val requiredFields = getRequiredFields(c)(tpe)
193+
val fieldNames = requiredFields.keys.mkString(", ")
194+
181195
c.abort(
182196
c.enclosingPosition,
183197
s"""❌ SELECT * is not allowed with compile-time validation.
@@ -190,11 +204,11 @@ trait SQLQueryValidator {
190204
| • Schema changes will break silently at runtime
191205
|
192206
|Solution:
193-
| 1. Explicitly list all required fields:
194-
| SELECT id, name, price FROM products
207+
| 1. Explicitly list all required fields for ${tpe.typeSymbol.name}:
208+
| SELECT $fieldNames FROM ...
195209
|
196210
| 2. Use the *Unchecked() variant for dynamic queries:
197-
| searchAsUnchecked[Product](SQLQuery("SELECT * FROM products"))
211+
| searchAsUnchecked[${tpe.typeSymbol.name}](SQLQuery("SELECT * FROM ..."))
198212
|
199213
|Best Practice:
200214
| Always explicitly select only the fields you need.
@@ -278,12 +292,16 @@ trait SQLQueryValidator {
278292
*/
279293
private def extractQueryFields(parsedQuery: SQLSearchRequest): Set[String] = {
280294
parsedQuery.select.fields.flatMap { field =>
281-
val f = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name)
282-
/*field.identifier.nestedElement match {
283-
case Some(nested) => List(f, nested.innerHitsName)
284-
case None => List(f)
285-
}*/
286-
List(f)
295+
val fieldName = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name)
296+
297+
// ✅ Manage nested fields (ex: "children.name" → "children", "children.name")
298+
val nestedParts = fieldName.split("\\.").toList
299+
300+
// Return all levels of nested fields
301+
// Ex: "children.address.city" → ["children", "children.address", "children.address.city"]
302+
nestedParts.indices.map { i =>
303+
nestedParts.take(i + 1).mkString(".")
304+
}
287305
}.toSet
288306
}
289307

@@ -408,74 +426,63 @@ trait SQLQueryValidator {
408426
): Boolean = {
409427
import c.universe._
410428

429+
val underlyingType = if (scalaType <:< typeOf[Option[_]]) {
430+
scalaType.typeArgs.headOption.getOrElse(scalaType)
431+
} else {
432+
scalaType
433+
}
434+
411435
sqlType match {
412436
case SQLTypes.TinyInt =>
413-
scalaType =:= typeOf[Byte] ||
414-
scalaType =:= typeOf[Short] ||
415-
scalaType =:= typeOf[Int] ||
416-
scalaType =:= typeOf[Long] ||
417-
scalaType =:= typeOf[Option[Byte]] ||
418-
scalaType =:= typeOf[Option[Short]] ||
419-
scalaType =:= typeOf[Option[Int]] ||
420-
scalaType =:= typeOf[Option[Long]]
437+
underlyingType =:= typeOf[Byte] ||
438+
underlyingType =:= typeOf[Short] ||
439+
underlyingType =:= typeOf[Int] ||
440+
underlyingType =:= typeOf[Long]
421441

422442
case SQLTypes.SmallInt =>
423-
scalaType =:= typeOf[Short] ||
424-
scalaType =:= typeOf[Int] ||
425-
scalaType =:= typeOf[Long] ||
426-
scalaType =:= typeOf[Option[Short]] ||
427-
scalaType =:= typeOf[Option[Int]] ||
428-
scalaType =:= typeOf[Option[Long]]
443+
underlyingType =:= typeOf[Short] ||
444+
underlyingType =:= typeOf[Int] ||
445+
underlyingType =:= typeOf[Long]
429446

430447
case SQLTypes.Int =>
431-
scalaType =:= typeOf[Int] ||
432-
scalaType =:= typeOf[Long] ||
433-
scalaType =:= typeOf[Option[Int]] ||
434-
scalaType =:= typeOf[Option[Long]]
448+
underlyingType =:= typeOf[Int] ||
449+
underlyingType =:= typeOf[Long]
435450

436451
case SQLTypes.BigInt =>
437-
scalaType =:= typeOf[Long] ||
438-
scalaType =:= typeOf[BigInt] ||
439-
scalaType =:= typeOf[Option[Long]] ||
440-
scalaType =:= typeOf[Option[BigInt]]
452+
underlyingType =:= typeOf[Long] ||
453+
underlyingType =:= typeOf[BigInt]
441454

442455
case SQLTypes.Double | SQLTypes.Real =>
443-
scalaType =:= typeOf[Double] ||
444-
scalaType =:= typeOf[Float] ||
445-
scalaType =:= typeOf[Option[Double]] ||
446-
scalaType =:= typeOf[Option[Float]]
456+
underlyingType =:= typeOf[Double] ||
457+
underlyingType =:= typeOf[Float]
447458

448459
case SQLTypes.Char =>
449-
scalaType =:= typeOf[String] || // CHAR(n) → String
450-
scalaType =:= typeOf[Char] || // CHAR(1) → Char
451-
scalaType =:= typeOf[Option[String]] ||
452-
scalaType =:= typeOf[Option[Char]]
460+
underlyingType =:= typeOf[String] || // CHAR(n) → String
461+
underlyingType =:= typeOf[Char] // CHAR(1) → Char
453462

454463
case SQLTypes.Varchar =>
455-
scalaType =:= typeOf[String] ||
456-
scalaType =:= typeOf[Option[String]]
464+
underlyingType =:= typeOf[String]
457465

458466
case SQLTypes.Boolean =>
459-
scalaType =:= typeOf[Boolean] ||
460-
scalaType =:= typeOf[Option[Boolean]]
467+
underlyingType =:= typeOf[Boolean]
461468

462469
case SQLTypes.Time =>
463-
scalaType.toString.contains("Instant") ||
464-
scalaType.toString.contains("LocalTime")
470+
underlyingType.toString.contains("Instant") ||
471+
underlyingType.toString.contains("LocalTime")
465472

466473
case SQLTypes.Date =>
467-
scalaType.toString.contains("Date") ||
468-
scalaType.toString.contains("Instant") ||
469-
scalaType.toString.contains("LocalDate")
474+
underlyingType.toString.contains("Date") ||
475+
underlyingType.toString.contains("Instant") ||
476+
underlyingType.toString.contains("LocalDate")
470477

471478
case SQLTypes.DateTime | SQLTypes.Timestamp =>
472-
scalaType.toString.contains("LocalDateTime") ||
473-
scalaType.toString.contains("ZonedDateTime") ||
474-
scalaType.toString.contains("Instant")
479+
underlyingType.toString.contains("LocalDateTime") ||
480+
underlyingType.toString.contains("ZonedDateTime") ||
481+
underlyingType.toString.contains("Instant")
475482

476483
case SQLTypes.Struct =>
477-
if (scalaType.typeSymbol.isClass && scalaType.typeSymbol.asClass.isCaseClass) {
478-
// validateStructFields(c)(sqlField, scalaType)
484+
if (underlyingType.typeSymbol.isClass && underlyingType.typeSymbol.asClass.isCaseClass) {
485+
// TODO validateStructFields(c)(sqlField, underlyingType)
479486
true
480487
} else {
481488
false
@@ -555,4 +562,9 @@ trait SQLQueryValidator {
555562

556563
object SQLQueryValidator {
557564
val DEBUG: Boolean = sys.props.get("sql.macro.debug").contains("true")
565+
// ✅ Cache pour éviter les validations redondantes
566+
private val validationCache = scala.collection.mutable.Map[String, Boolean]()
567+
568+
private def isCached(sql: String): Boolean = validationCache.contains(sql.trim)
569+
private def markValidated(sql: String): Unit = validationCache(sql.trim) = true
558570
}

0 commit comments

Comments
 (0)