@@ -47,6 +47,7 @@ import dotty.tools.dotc.qualified_types.ENode.Op
4747import dotty .tools .dotc .reporting .trace
4848import dotty .tools .dotc .transform .TreeExtractors .BinaryOp
4949import dotty .tools .dotc .util .Spans .Span
50+ import scala .collection .mutable .ListBuffer
5051
5152final class EGraph (rootCtx : Context ):
5253
@@ -72,23 +73,23 @@ final class EGraph(rootCtx: Context):
7273 /** Map used for hash-consing nodes, keys and values are the same */
7374 private val index = mutable.Map .empty[ENode , ENode ]
7475
75- val trueNode : ENode .Atom = ENode .Atom (ConstantType (Constant (true ))(using rootCtx))
76+ final val trueNode : ENode .Atom = ENode .Atom (ConstantType (Constant (true ))(using rootCtx))
7677 index(trueNode) = trueNode
7778
78- val falseNode : ENode .Atom = ENode .Atom (ConstantType (Constant (false ))(using rootCtx))
79+ final val falseNode : ENode .Atom = ENode .Atom (ConstantType (Constant (false ))(using rootCtx))
7980 index(falseNode) = falseNode
8081
81- val minusOneIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (- 1 ))(using rootCtx))
82+ final val minusOneIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (- 1 ))(using rootCtx))
8283 index(minusOneIntNode) = minusOneIntNode
8384
84- val zeroIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (0 ))(using rootCtx))
85+ final val zeroIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (0 ))(using rootCtx))
8586 index(zeroIntNode) = zeroIntNode
8687
87- val oneIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (1 ))(using rootCtx))
88+ final val oneIntNode : ENode .Atom = ENode .Atom (ConstantType (Constant (1 ))(using rootCtx))
8889 index(oneIntNode) = oneIntNode
8990
90- val d = defn(using rootCtx) // Need a stable path to match on `defn` members
91- val builtinOps = Map (
91+ private val d = defn(using rootCtx) // Need a stable path to match on `defn` members
92+ private val builtinOps = Map (
9293 d.Int_== -> Op .Equal ,
9394 d.Boolean_== -> Op .Equal ,
9495 d.Any_== -> Op .Equal ,
@@ -137,16 +138,16 @@ final class EGraph(rootCtx: Context):
137138 }
138139 ).asInstanceOf [node.type ]
139140
140- def toNode (tree : Tree , paramSyms : List [Symbol ] = Nil , paramNodes : List [ENode .ArgRefType ] = Nil )(using
141+ def toNode (tree : Tree , paramSyms : List [Symbol ] = Nil , paramTps : List [ENode .ArgRefType ] = Nil )(using
141142 Context
142143 ): Option [ENode ] =
143144 trace(i " EGraph.toNode $tree" , Printers .qualifiedTypes):
144- computeToNode(tree, paramSyms, paramNodes ).map(node => representent(unique(node)))
145+ computeToNode(tree, paramSyms, paramTps ).map(node => representent(unique(node)))
145146
146147 private def computeToNode (
147148 tree : Tree ,
148149 paramSyms : List [Symbol ] = Nil ,
149- paramNodes : List [ENode .ArgRefType ] = Nil
150+ paramTps : List [ENode .ArgRefType ] = Nil
150151 )(using currentCtx : Context ): Option [ENode ] =
151152 trace(i " ENode.computeToNode $tree" , Printers .qualifiedTypes):
152153 def normalizeType (tp : Type ): Type =
@@ -159,48 +160,45 @@ final class EGraph(rootCtx: Context):
159160 case tp => tp
160161
161162 def mapType (tp : Type ): Type =
162- normalizeType(tp.subst(paramSyms, paramNodes ))
163+ normalizeType(tp.subst(paramSyms, paramTps ))
163164
164165 tree match
165166 case Literal (_) | Ident (_) | This (_) if tree.tpe.isInstanceOf [SingletonType ] =>
166167 Some (ENode .Atom (mapType(tree.tpe).asInstanceOf [SingletonType ]))
167168 case New (clazz) =>
168- for clazzNode <- toNode(clazz, paramSyms, paramNodes ) yield ENode .New (clazzNode)
169+ for clazzNode <- toNode(clazz, paramSyms, paramTps ) yield ENode .New (clazzNode)
169170 case Select (qual, name) =>
170- for qualNode <- toNode(qual, paramSyms, paramNodes ) yield ENode .Select (qualNode, tree.symbol)
171+ for qualNode <- toNode(qual, paramSyms, paramTps ) yield ENode .Select (qualNode, tree.symbol)
171172 case BinaryOp (lhs, op, rhs) if builtinOps.contains(op) =>
172173 for
173- lhsNode <- toNode(lhs, paramSyms, paramNodes )
174- rhsNode <- toNode(rhs, paramSyms, paramNodes )
174+ lhsNode <- toNode(lhs, paramSyms, paramTps )
175+ rhsNode <- toNode(rhs, paramSyms, paramTps )
175176 yield normalizeOp(builtinOps(op), List (lhsNode, rhsNode))
176177 case BinaryOp (lhs, d.Int_- , rhs) if lhs.tpe.isInstanceOf [ValueType ] && rhs.tpe.isInstanceOf [ValueType ] =>
177178 for
178- lhsNode <- toNode(lhs, paramSyms, paramNodes )
179- rhsNode <- toNode(rhs, paramSyms, paramNodes )
179+ lhsNode <- toNode(lhs, paramSyms, paramTps )
180+ rhsNode <- toNode(rhs, paramSyms, paramTps )
180181 yield normalizeOp(Op .IntSum , List (lhsNode, normalizeOp(Op .IntProduct , List (minusOneIntNode, rhsNode))))
181182 case Apply (fun, args) =>
182183 for
183- funNode <- toNode(fun, paramSyms, paramNodes )
184- argsNodes <- args.map(toNode(_, paramSyms, paramNodes )).sequence
184+ funNode <- toNode(fun, paramSyms, paramTps )
185+ argsNodes <- args.map(toNode(_, paramSyms, paramTps )).sequence
185186 yield ENode .Apply (funNode, argsNodes)
186187 case TypeApply (fun, args) =>
187- for funNode <- toNode(fun, paramSyms, paramNodes )
188+ for funNode <- toNode(fun, paramSyms, paramTps )
188189 yield ENode .TypeApply (funNode, args.map(tp => mapType(tp.tpe)))
189190 case closureDef(defDef) =>
190191 defDef.symbol.info.dealias match
191192 case mt : MethodType =>
192193 assert(defDef.termParamss.size == 1 , " closures have a single parameter list, right?" )
193- val params = defDef.termParamss.head
194- val myParamSyms = params.map(_.symbol)
195-
196- val myParamTps : ArrayBuffer [Type ] = ArrayBuffer .empty
197- ???
198-
199- val myRetTp = ???
200-
201- val myParamNodes = myParamTps.zipWithIndex.map((tp, i) => ENode .ArgRefType (i, tp)).toList
202-
203- for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamNodes ::: paramNodes)
194+ val myParamSyms : List [Symbol ] = defDef.termParamss.head.map(_.symbol)
195+ val myParamTps : ListBuffer [ENode .ArgRefType ] = ListBuffer .empty
196+ val paramTpsSize = paramTps.size
197+ for myParamSym <- myParamSyms do
198+ val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList))
199+ myParamTps += ENode .ArgRefType (paramTpsSize + myParamTps.size, underlying)
200+ val myRetTp = mapType(defDef.tpt.tpe.subst(myParamSyms, myParamTps.toList))
201+ for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamTps.toList ::: paramTps)
204202 yield ENode .Lambda (myParamTps.toList, myRetTp, body)
205203 case _ => None
206204 case _ =>
@@ -222,15 +220,15 @@ final class EGraph(rootCtx: Context):
222220 case ENode .TypeApply (fn, args) =>
223221 ENode .TypeApply (representent(fn), args)
224222 case ENode .Lambda (paramTps, retTp, body) =>
225-
226223 ENode .Lambda (paramTps, retTp, representent(body))
227224 ))
228225
229226 private def normalizeOp (op : ENode .Op , args : List [ENode ]): ENode =
230227 op match
231228 case Op .Equal =>
232229 assert(args.size == 2 , s " Expected 2 arguments for equality, got $args" )
233- if args(0 ) eq args(1 ) then trueNode
230+ if args(0 ) eq args(1 ) then
231+ trueNode
234232 else ENode .OpApply (op, args.sortBy(_.hashCode()))
235233 case Op .And =>
236234 assert(args.size == 2 , s " Expected 2 arguments for conjunction, got $args" )
0 commit comments