@@ -2,6 +2,7 @@ package dotty.tools.dotc.qualified_types
22
33import scala .collection .mutable
44import scala .collection .mutable .ArrayBuffer
5+ import scala .collection .mutable .ListBuffer
56
67import dotty .tools .dotc .ast .tpd .{
78 closureDef ,
@@ -25,13 +26,16 @@ import dotty.tools.dotc.core.Constants.Constant
2526import dotty .tools .dotc .core .Contexts .Context
2627import dotty .tools .dotc .core .Contexts .ctx
2728import dotty .tools .dotc .core .Decorators .i
29+ import dotty .tools .dotc .core .Flags
2830import dotty .tools .dotc .core .Hashable .Binders
2931import dotty .tools .dotc .core .Names .Designator
3032import dotty .tools .dotc .core .StdNames .nme
3133import dotty .tools .dotc .core .Symbols .{defn , NoSymbol , Symbol }
3234import dotty .tools .dotc .core .Types .{
35+ AppliedType ,
3336 CachedProxyType ,
3437 ConstantType ,
38+ LambdaType ,
3539 MethodType ,
3640 NamedType ,
3741 NoPrefix ,
@@ -40,14 +44,14 @@ import dotty.tools.dotc.core.Types.{
4044 TermParamRef ,
4145 TermRef ,
4246 Type ,
47+ TypeRef ,
4348 TypeVar ,
4449 ValueType
4550}
4651import dotty .tools .dotc .qualified_types .ENode .Op
4752import dotty .tools .dotc .reporting .trace
4853import dotty .tools .dotc .transform .TreeExtractors .BinaryOp
4954import dotty .tools .dotc .util .Spans .Span
50- import scala .collection .mutable .ListBuffer
5155
5256final class EGraph (rootCtx : Context ):
5357
@@ -92,7 +96,7 @@ final class EGraph(rootCtx: Context):
9296 private val builtinOps = Map (
9397 d.Int_== -> Op .Equal ,
9498 d.Boolean_== -> Op .Equal ,
95- d.Any_ == -> Op .Equal ,
99+ d.String_ == -> Op .Equal ,
96100 d.Boolean_&& -> Op .And ,
97101 d.Boolean_|| -> Op .Or ,
98102 d.Boolean_! -> Op .Not ,
@@ -108,9 +112,8 @@ final class EGraph(rootCtx: Context):
108112
109113 def equiv (node1 : ENode , node2 : ENode )(using Context ): Boolean =
110114 trace(i " EGraph.equiv " , Printers .qualifiedTypes):
111- val margin = ctx.base.indentTab * (ctx.base.indent)
115+ // val margin = ctx.base.indentTab * (ctx.base.indent)
112116 // println(s"$margin node1: $node1\n$margin node2: $node2")
113- // Check if the representents of both nodes are the same
114117 val repr1 = representent(node1)
115118 val repr2 = representent(node2)
116119 repr1 eq repr2
@@ -121,8 +124,8 @@ final class EGraph(rootCtx: Context):
121124 node match
122125 case ENode .Atom (tp) =>
123126 ()
124- case ENode .New (clazz ) =>
125- addUse(clazz, node )
127+ case ENode .Constructor (sym ) =>
128+ ( )
126129 case ENode .Select (qual, member) =>
127130 addUse(qual, node)
128131 case ENode .Apply (fn, args) =>
@@ -138,6 +141,7 @@ final class EGraph(rootCtx: Context):
138141 }
139142 ).asInstanceOf [node.type ]
140143
144+ // TODO(mbovel): Memoize this
141145 def toNode (tree : Tree , paramSyms : List [Symbol ] = Nil , paramTps : List [ENode .ArgRefType ] = Nil )(using
142146 Context
143147 ): Option [ENode ] =
@@ -165,16 +169,18 @@ final class EGraph(rootCtx: Context):
165169 tree match
166170 case Literal (_) | Ident (_) | This (_) if tree.tpe.isInstanceOf [SingletonType ] =>
167171 Some (ENode .Atom (mapType(tree.tpe).asInstanceOf [SingletonType ]))
168- case New (clazz) =>
169- for clazzNode <- toNode(clazz, paramSyms, paramTps) yield ENode .New (clazzNode)
172+ case Select (New (_), nme.CONSTRUCTOR ) =>
173+ constructorNode(tree.symbol)
174+ case tree : Select if isCaseClassApply(tree.symbol) =>
175+ constructorNode(tree.symbol.owner.linkedClass.primaryConstructor)
170176 case Select (qual, name) =>
171- for qualNode <- toNode(qual, paramSyms, paramTps) yield ENode . Select (qualNode, tree.symbol)
177+ for qualNode <- toNode(qual, paramSyms, paramTps) yield normalizeSelect (qualNode, tree.symbol)
172178 case BinaryOp (lhs, op, rhs) if builtinOps.contains(op) =>
173179 for
174180 lhsNode <- toNode(lhs, paramSyms, paramTps)
175181 rhsNode <- toNode(rhs, paramSyms, paramTps)
176182 yield normalizeOp(builtinOps(op), List (lhsNode, rhsNode))
177- case BinaryOp (lhs, d.Int_- , rhs) if lhs.tpe. isInstanceOf [ ValueType ] && rhs.tpe. isInstanceOf [ ValueType ] =>
183+ case BinaryOp (lhs, d.Int_- , rhs) =>
178184 for
179185 lhsNode <- toNode(lhs, paramSyms, paramTps)
180186 rhsNode <- toNode(rhs, paramSyms, paramTps)
@@ -192,7 +198,7 @@ final class EGraph(rootCtx: Context):
192198 case mt : MethodType =>
193199 assert(defDef.termParamss.size == 1 , " closures have a single parameter list, right?" )
194200 val myParamSyms : List [Symbol ] = defDef.termParamss.head.map(_.symbol)
195- val myParamTps : ListBuffer [ENode .ArgRefType ] = ListBuffer .empty
201+ val myParamTps : ListBuffer [ENode .ArgRefType ] = ListBuffer .empty
196202 val paramTpsSize = paramTps.size
197203 for myParamSym <- myParamSyms do
198204 val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList))
@@ -204,15 +210,38 @@ final class EGraph(rootCtx: Context):
204210 case _ =>
205211 None
206212
213+ // TODO(mbovel): Memoize this
214+ private def constructorNode (constr : Symbol )(using Context ): Option [ENode .Constructor ] =
215+ val clazz = constr.owner
216+ if hasStructuralEquality(clazz) then
217+ val isPrimaryConstructor = constr.denot.isPrimaryConstructor
218+ val fieldsRaw = clazz.denot.asClass.paramAccessors.filter(isPrimaryConstructor && _.isStableMember)
219+ val constrParams = constr.paramSymss.flatten.filter(_.isTerm)
220+ val fields = constrParams.map(p => fieldsRaw.find(_.name == p.name).getOrElse(NoSymbol ))
221+ Some (ENode .Constructor (constr)(fields))
222+ else
223+ None
224+
225+ private def hasStructuralEquality (clazz : Symbol )(using Context ): Boolean =
226+ val equalsMethod = clazz.info.decls.lookup(nme.equals_)
227+ val equalsNotOverriden = ! equalsMethod.exists || equalsMethod.is(Flags .Synthetic )
228+ clazz.isClass && clazz.is(Flags .Case ) && equalsNotOverriden
229+
230+ private def isCaseClassApply (meth : Symbol )(using Context ): Boolean =
231+ meth.name == nme.apply
232+ && meth.flags.is(Flags .Synthetic )
233+ && meth.owner.linkedClass.is(Flags .Case )
234+
207235 private def canonicalize (node : ENode ): ENode =
236+ // println(s"canonicalize $node")
208237 representent(unique(
209238 node match
210239 case ENode .Atom (tp) =>
211240 node
212- case ENode .New (clazz ) =>
213- ENode . New (representent(clazz))
241+ case ENode .Constructor (sym ) =>
242+ node
214243 case ENode .Select (qual, member) =>
215- ENode . Select (representent(qual), member)
244+ normalizeSelect (representent(qual), member)
216245 case ENode .Apply (fn, args) =>
217246 ENode .Apply (representent(fn), args.map(representent))
218247 case ENode .OpApply (op, args) =>
@@ -223,6 +252,33 @@ final class EGraph(rootCtx: Context):
223252 ENode .Lambda (paramTps, retTp, representent(body))
224253 ))
225254
255+ private def normalizeSelect (qual : ENode , member : Symbol ): ENode =
256+ getAppliedConstructor(qual) match
257+ case Some (constr) =>
258+ val memberIndex = constr.fields.indexOf(member)
259+
260+ if memberIndex >= 0 then
261+ val args = getTermArguments(qual)
262+ assert(args.size == constr.fields.size)
263+ args(memberIndex)
264+ else
265+ ENode .Select (qual, member)
266+ case None =>
267+ ENode .Select (qual, member)
268+
269+ private def getAppliedConstructor (node : ENode ): Option [ENode .Constructor ] =
270+ node match
271+ case ENode .Apply (fn, args) => getAppliedConstructor(fn)
272+ case ENode .TypeApply (fn, args) => getAppliedConstructor(fn)
273+ case node : ENode .Constructor => Some (node)
274+ case _ => None
275+
276+ private def getTermArguments (node : ENode ): List [ENode ] =
277+ node match
278+ case ENode .Apply (fn, args) => getTermArguments(fn) ::: args
279+ case ENode .TypeApply (fn, args) => getTermArguments(fn)
280+ case _ => Nil
281+
226282 private def normalizeOp (op : ENode .Op , args : List [ENode ]): ENode =
227283 op match
228284 case Op .Equal =>
@@ -316,12 +372,10 @@ final class EGraph(rootCtx: Context):
316372 (a, b) match
317373 case (ENode .Atom (_ : ConstantType ), _) => (a, b)
318374 case (_, ENode .Atom (_ : ConstantType )) => (b, a)
319- case (ENode . Atom ( _ : SkolemType ) , _) => (a, b)
320- case (_, ENode . Atom ( _ : SkolemType )) => (b, a)
375+ case (_ : ENode . Constructor , _) => (a, b)
376+ case (_, _ : ENode . Constructor ) => (b, a)
321377 case (_ : ENode .Atom , _) => (a, b)
322378 case (_, _ : ENode .Atom ) => (b, a)
323- case (_ : ENode .New , _) => (a, b)
324- case (_, _ : ENode .New ) => (b, a)
325379 case (_ : ENode .Select , _) => (a, b)
326380 case (_, _ : ENode .Select ) => (b, a)
327381 case (_ : ENode .Apply , _) => (a, b)
@@ -336,8 +390,6 @@ final class EGraph(rootCtx: Context):
336390 if aRepr eq bRepr then return
337391 assert(aRepr != bRepr, s " $aRepr and $bRepr are `equals` but not `eq` " )
338392
339- // TODO(mbovel): if both nodes are objects, recursively merge their arguments
340-
341393 // / Update represententOf and usedBy maps
342394 val (newRepr, oldRepr) = order(aRepr, bRepr)
343395 represententOf(oldRepr) = newRepr
@@ -371,8 +423,9 @@ final class EGraph(rootCtx: Context):
371423 node match
372424 case ENode .Atom (tp) =>
373425 singleton(tp)
374- case ENode .New (clazz) =>
375- New (toTree(clazz, paramRefs))
426+ case ENode .Constructor (sym) =>
427+ val tycon = sym.owner.info.typeConstructor
428+ New (tycon).select(TermRef (tycon, sym))
376429 case ENode .Select (qual, member) =>
377430 toTree(qual, paramRefs).select(member)
378431 case ENode .Apply (fn, args) =>
0 commit comments