@@ -42,6 +42,12 @@ trait SQLQueryValidator {
4242 // 1. Extract the SQL query (must be a literal)
4343 val sqlQuery = extractSQLString(c)(query)
4444
45+ // ✅ Check if already validated
46+ if (SQLQueryValidator .isCached(sqlQuery)) {
47+ debug(c)(s " ✅ Query already validated (cached): $sqlQuery" )
48+ return sqlQuery
49+ }
50+
4551 if (sys.props.get(" elastic.sql.debug" ).contains(" true" )) {
4652 c.info(c.enclosingPosition, s " Validating SQL: $sqlQuery" , force = false )
4753 }
@@ -75,6 +81,9 @@ trait SQLQueryValidator {
7581 debug(c)(" =" * 80 )
7682
7783 // 8. Return the validated request
84+ // ✅ Mark as validated
85+ SQLQueryValidator .markValidated(sqlQuery)
86+
7887 sqlQuery
7988 }
8089
@@ -167,17 +176,22 @@ trait SQLQueryValidator {
167176 // ============================================================
168177 // Reject SELECT * (incompatible with compile-time validation)
169178 // ============================================================
170- private def rejectSelectStar (c : blackbox.Context )(
179+ private def rejectSelectStar [ T : c. WeakTypeTag ] (c : blackbox.Context )(
171180 parsedQuery : SQLSearchRequest ,
172181 sqlQuery : String
173182 ): Unit = {
183+ import c .universe ._
174184
175185 // Check if any field is a wildcard (*)
176186 val hasWildcard = parsedQuery.select.fields.exists { field =>
177187 field.identifier.name == " *"
178188 }
179189
180190 if (hasWildcard) {
191+ val tpe = weakTypeOf[T ]
192+ val requiredFields = getRequiredFields(c)(tpe)
193+ val fieldNames = requiredFields.keys.mkString(" , " )
194+
181195 c.abort(
182196 c.enclosingPosition,
183197 s """ ❌ SELECT * is not allowed with compile-time validation.
@@ -190,11 +204,11 @@ trait SQLQueryValidator {
190204 | • Schema changes will break silently at runtime
191205 |
192206 |Solution:
193- | 1. Explicitly list all required fields:
194- | SELECT id, name, price FROM products
207+ | 1. Explicitly list all required fields for ${tpe.typeSymbol.name} :
208+ | SELECT $fieldNames FROM ...
195209 |
196210 | 2. Use the *Unchecked() variant for dynamic queries:
197- | searchAsUnchecked[Product ](SQLQuery("SELECT * FROM products "))
211+ | searchAsUnchecked[ ${tpe.typeSymbol.name} ](SQLQuery("SELECT * FROM ... "))
198212 |
199213 |Best Practice:
200214 | Always explicitly select only the fields you need.
@@ -278,12 +292,16 @@ trait SQLQueryValidator {
278292 */
279293 private def extractQueryFields (parsedQuery : SQLSearchRequest ): Set [String ] = {
280294 parsedQuery.select.fields.flatMap { field =>
281- val f = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name)
282- /* field.identifier.nestedElement match {
283- case Some(nested) => List(f, nested.innerHitsName)
284- case None => List(f)
285- }*/
286- List (f)
295+ val fieldName = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name)
296+
297+ // ✅ Manage nested fields (ex: "children.name" → "children", "children.name")
298+ val nestedParts = fieldName.split(" \\ ." ).toList
299+
300+ // Return all levels of nested fields
301+ // Ex: "children.address.city" → ["children", "children.address", "children.address.city"]
302+ nestedParts.indices.map { i =>
303+ nestedParts.take(i + 1 ).mkString(" ." )
304+ }
287305 }.toSet
288306 }
289307
@@ -408,74 +426,63 @@ trait SQLQueryValidator {
408426 ): Boolean = {
409427 import c .universe ._
410428
429+ val underlyingType = if (scalaType <:< typeOf[Option [_]]) {
430+ scalaType.typeArgs.headOption.getOrElse(scalaType)
431+ } else {
432+ scalaType
433+ }
434+
411435 sqlType match {
412436 case SQLTypes .TinyInt =>
413- scalaType =:= typeOf[Byte ] ||
414- scalaType =:= typeOf[Short ] ||
415- scalaType =:= typeOf[Int ] ||
416- scalaType =:= typeOf[Long ] ||
417- scalaType =:= typeOf[Option [Byte ]] ||
418- scalaType =:= typeOf[Option [Short ]] ||
419- scalaType =:= typeOf[Option [Int ]] ||
420- scalaType =:= typeOf[Option [Long ]]
437+ underlyingType =:= typeOf[Byte ] ||
438+ underlyingType =:= typeOf[Short ] ||
439+ underlyingType =:= typeOf[Int ] ||
440+ underlyingType =:= typeOf[Long ]
421441
422442 case SQLTypes .SmallInt =>
423- scalaType =:= typeOf[Short ] ||
424- scalaType =:= typeOf[Int ] ||
425- scalaType =:= typeOf[Long ] ||
426- scalaType =:= typeOf[Option [Short ]] ||
427- scalaType =:= typeOf[Option [Int ]] ||
428- scalaType =:= typeOf[Option [Long ]]
443+ underlyingType =:= typeOf[Short ] ||
444+ underlyingType =:= typeOf[Int ] ||
445+ underlyingType =:= typeOf[Long ]
429446
430447 case SQLTypes .Int =>
431- scalaType =:= typeOf[Int ] ||
432- scalaType =:= typeOf[Long ] ||
433- scalaType =:= typeOf[Option [Int ]] ||
434- scalaType =:= typeOf[Option [Long ]]
448+ underlyingType =:= typeOf[Int ] ||
449+ underlyingType =:= typeOf[Long ]
435450
436451 case SQLTypes .BigInt =>
437- scalaType =:= typeOf[Long ] ||
438- scalaType =:= typeOf[BigInt ] ||
439- scalaType =:= typeOf[Option [Long ]] ||
440- scalaType =:= typeOf[Option [BigInt ]]
452+ underlyingType =:= typeOf[Long ] ||
453+ underlyingType =:= typeOf[BigInt ]
441454
442455 case SQLTypes .Double | SQLTypes .Real =>
443- scalaType =:= typeOf[Double ] ||
444- scalaType =:= typeOf[Float ] ||
445- scalaType =:= typeOf[Option [Double ]] ||
446- scalaType =:= typeOf[Option [Float ]]
456+ underlyingType =:= typeOf[Double ] ||
457+ underlyingType =:= typeOf[Float ]
447458
448459 case SQLTypes .Char =>
449- scalaType =:= typeOf[String ] || // CHAR(n) → String
450- scalaType =:= typeOf[Char ] || // CHAR(1) → Char
451- scalaType =:= typeOf[Option [String ]] ||
452- scalaType =:= typeOf[Option [Char ]]
460+ underlyingType =:= typeOf[String ] || // CHAR(n) → String
461+ underlyingType =:= typeOf[Char ] // CHAR(1) → Char
453462
454463 case SQLTypes .Varchar =>
455- scalaType =:= typeOf[String ] ||
456- scalaType =:= typeOf[Option [String ]]
464+ underlyingType =:= typeOf[String ]
457465
458466 case SQLTypes .Boolean =>
459- scalaType =:= typeOf[Boolean ] ||
460- scalaType =:= typeOf[Option [Boolean ]]
467+ underlyingType =:= typeOf[Boolean ]
461468
462469 case SQLTypes .Time =>
463- scalaType .toString.contains(" Instant" ) ||
464- scalaType .toString.contains(" LocalTime" )
470+ underlyingType .toString.contains(" Instant" ) ||
471+ underlyingType .toString.contains(" LocalTime" )
465472
466473 case SQLTypes .Date =>
467- scalaType .toString.contains(" Date" ) ||
468- scalaType .toString.contains(" Instant" ) ||
469- scalaType .toString.contains(" LocalDate" )
474+ underlyingType .toString.contains(" Date" ) ||
475+ underlyingType .toString.contains(" Instant" ) ||
476+ underlyingType .toString.contains(" LocalDate" )
470477
471478 case SQLTypes .DateTime | SQLTypes .Timestamp =>
472- scalaType .toString.contains(" LocalDateTime" ) ||
473- scalaType .toString.contains(" ZonedDateTime" ) ||
474- scalaType .toString.contains(" Instant" )
479+ underlyingType .toString.contains(" LocalDateTime" ) ||
480+ underlyingType .toString.contains(" ZonedDateTime" ) ||
481+ underlyingType .toString.contains(" Instant" )
475482
476483 case SQLTypes .Struct =>
477- if (scalaType .typeSymbol.isClass && scalaType .typeSymbol.asClass.isCaseClass) {
478- // validateStructFields(c)(sqlField, scalaType )
484+ if (underlyingType .typeSymbol.isClass && underlyingType .typeSymbol.asClass.isCaseClass) {
485+ // TODO validateStructFields(c)(sqlField, underlyingType )
479486 true
480487 } else {
481488 false
@@ -555,4 +562,9 @@ trait SQLQueryValidator {
555562
556563object SQLQueryValidator {
557564 val DEBUG : Boolean = sys.props.get(" sql.macro.debug" ).contains(" true" )
565+ // ✅ Cache pour éviter les validations redondantes
566+ private val validationCache = scala.collection.mutable.Map [String , Boolean ]()
567+
568+ private def isCached (sql : String ): Boolean = validationCache.contains(sql.trim)
569+ private def markValidated (sql : String ): Unit = validationCache(sql.trim) = true
558570}
0 commit comments