Skip to content

Commit 31aac54

Browse files
committed
Moved MatchReducer
1 parent 88334c0 commit 31aac54

File tree

2 files changed

+384
-384
lines changed

2 files changed

+384
-384
lines changed
Lines changed: 384 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
object MatchReducer:
2+
import printing.*, Texts.*
3+
enum MatchResult extends Showable:
4+
case Reduced(tp: Type)
5+
case Disjoint
6+
case ReducedAndDisjoint
7+
case Stuck
8+
case NoInstance(fails: List[(Name, TypeBounds)])
9+
10+
def toText(p: Printer): Text = this match
11+
case Reduced(tp) => "Reduced(" ~ p.toText(tp) ~ ")"
12+
case Disjoint => "Disjoint"
13+
case ReducedAndDisjoint => "ReducedAndDisjoint"
14+
case Stuck => "Stuck"
15+
case NoInstance(fails) => "NoInstance(" ~ Text(fails.map(p.toText(_) ~ p.toText(_)), ", ") ~ ")"
16+
17+
/** A type comparer for reducing match types.
18+
* TODO: Not sure this needs to be a type comparer. Can we make it a
19+
* separate class?
20+
*/
21+
class MatchReducer(initctx: Context) extends TypeComparer(initctx) {
22+
import MatchReducer.*
23+
24+
init(initctx)
25+
26+
override def matchReducer = this
27+
28+
def matchCases(scrut: Type, cases: List[MatchTypeCaseSpec])(using Context): Type = {
29+
// a reference for the type parameters poisoned during matching
30+
// for use during the reduction step
31+
var poisoned: Set[TypeParamRef] = Set.empty
32+
33+
def paramInstances(canApprox: Boolean) = new TypeAccumulator[Array[Type]]:
34+
def apply(insts: Array[Type], t: Type) = t match
35+
case param @ TypeParamRef(b, n) if b eq caseLambda =>
36+
insts(n) =
37+
if canApprox then
38+
approximation(param, fromBelow = variance >= 0, Int.MaxValue).simplified
39+
else constraint.entry(param) match
40+
case entry: TypeBounds =>
41+
val lo = fullLowerBound(param)
42+
val hi = fullUpperBound(param)
43+
if !poisoned(param) && isSubType(hi, lo) then lo.simplified else Range(lo, hi)
44+
case inst =>
45+
assert(inst.exists, i"param = $param\nconstraint = $constraint")
46+
if !poisoned(param) then inst.simplified else Range(inst, inst)
47+
insts
48+
case _ =>
49+
foldOver(insts, t)
50+
51+
def instantiateParams(insts: Array[Type]) = new ApproximatingTypeMap {
52+
variance = 0
53+
54+
override def range(lo: Type, hi: Type): Type =
55+
if variance == 0 && (lo eq hi) then
56+
// override the default `lo eq hi` test, which removes the Range
57+
// which leads to a Reduced result, instead of NoInstance
58+
Range(lower(lo), upper(hi))
59+
else super.range(lo, hi)
60+
61+
def apply(t: Type) = t match {
62+
case t @ TypeParamRef(b, n) if b `eq` caseLambda => insts(n)
63+
case t: LazyRef => apply(t.ref)
64+
case _ => mapOver(t)
65+
}
66+
}
67+
68+
def instantiateParamsSpec(insts: Array[Type], caseLambda: HKTypeLambda) = new TypeMap {
69+
variance = 0
70+
71+
def apply(t: Type) = t match {
72+
case t @ TypeParamRef(b, n) if b `eq` caseLambda => insts(n)
73+
case t: LazyRef => apply(t.ref)
74+
case _ => mapOver(t)
75+
}
76+
}
77+
78+
/** Match a single case. */
79+
def matchCase(cas: MatchTypeCaseSpec): MatchResult = trace(i"$scrut match ${MatchTypeTrace.caseText(cas)}", matchTypes, show = true) {
80+
cas match
81+
case cas: MatchTypeCaseSpec.SubTypeTest => matchSubTypeTest(cas)
82+
case cas: MatchTypeCaseSpec.SpeccedPatMat => matchSpeccedPatMat(cas)
83+
case cas: MatchTypeCaseSpec.LegacyPatMat => matchLegacyPatMat(cas)
84+
case cas: MatchTypeCaseSpec.MissingCaptures => matchMissingCaptures(cas)
85+
}
86+
87+
def matchSubTypeTest(spec: MatchTypeCaseSpec.SubTypeTest): MatchResult =
88+
val disjoint = provablyDisjoint(scrut, spec.pattern)
89+
if necessarySubType(scrut, spec.pattern) then
90+
if disjoint then
91+
MatchResult.ReducedAndDisjoint
92+
else
93+
MatchResult.Reduced(spec.body)
94+
else if disjoint then
95+
MatchResult.Disjoint
96+
else
97+
MatchResult.Stuck
98+
end matchSubTypeTest
99+
100+
// See https://docs.scala-lang.org/sips/match-types-spec.html#matching
101+
def matchSpeccedPatMat(spec: MatchTypeCaseSpec.SpeccedPatMat): MatchResult =
102+
val instances = Array.fill[Type](spec.captureCount)(NoType)
103+
val noInstances = mutable.ListBuffer.empty[(TypeName, TypeBounds)]
104+
105+
def rec(pattern: MatchTypeCasePattern, scrut: Type, variance: Int, scrutIsWidenedAbstract: Boolean): Boolean =
106+
pattern match
107+
case MatchTypeCasePattern.Capture(num, /* isWildcard = */ true) =>
108+
// instantiate the wildcard in a way that the subtype test always succeeds
109+
instances(num) = variance match
110+
case 1 => scrut.hiBound // actually important if we are not in a class type constructor
111+
case -1 => scrut.loBound
112+
case 0 => scrut
113+
!instances(num).isError
114+
115+
case MatchTypeCasePattern.Capture(num, /* isWildcard = */ false) =>
116+
def failNotSpecific(bounds: TypeBounds): TypeBounds =
117+
noInstances += spec.origMatchCase.paramNames(num) -> bounds
118+
bounds
119+
120+
instances(num) = scrut match
121+
case scrut: TypeBounds =>
122+
if scrutIsWidenedAbstract then
123+
failNotSpecific(scrut)
124+
else
125+
variance match
126+
case 1 => scrut.hi
127+
case -1 => scrut.lo
128+
case 0 => failNotSpecific(scrut)
129+
case _ =>
130+
if scrutIsWidenedAbstract && variance != 0 then
131+
// fail as not specific
132+
// the Nothing and Any bounds are used so that they are not displayed; not for themselves in particular
133+
if variance > 0 then failNotSpecific(TypeBounds(defn.NothingType, scrut))
134+
else failNotSpecific(TypeBounds(scrut, defn.AnyType))
135+
else
136+
scrut
137+
!instances(num).isError
138+
139+
case MatchTypeCasePattern.TypeTest(tpe) =>
140+
// The actual type test is handled by `scrut <:< instantiatedPat`
141+
true
142+
143+
case MatchTypeCasePattern.BaseTypeTest(classType, argPatterns, needsConcreteScrut) =>
144+
val cls = classType.classSymbol.asClass
145+
scrut.baseType(cls) match
146+
case base @ AppliedType(baseTycon, baseArgs) =>
147+
// #19445 Don't check the prefix of baseTycon here; it is handled by `scrut <:< instantiatedPat`.
148+
val innerScrutIsWidenedAbstract =
149+
scrutIsWidenedAbstract
150+
|| (needsConcreteScrut && !isConcrete(scrut)) // no point in checking concreteness if it does not need to be concrete
151+
matchArgs(argPatterns, baseArgs, classType.typeParams, innerScrutIsWidenedAbstract)
152+
case _ =>
153+
false
154+
155+
case MatchTypeCasePattern.AbstractTypeConstructor(tycon, argPatterns) =>
156+
scrut.dealias match
157+
case scrutDealias @ AppliedType(scrutTycon, args) if scrutTycon =:= tycon =>
158+
matchArgs(argPatterns, args, tycon.typeParams, scrutIsWidenedAbstract)
159+
case _ =>
160+
false
161+
162+
case MatchTypeCasePattern.CompileTimeS(argPattern) =>
163+
natValue(scrut) match
164+
case Some(scrutValue) if scrutValue > 0 =>
165+
rec(argPattern, ConstantType(Constant(scrutValue - 1)), variance, scrutIsWidenedAbstract)
166+
case _ =>
167+
false
168+
169+
case MatchTypeCasePattern.TypeMemberExtractor(typeMemberName, capture) =>
170+
/** Try to remove references to `skolem` from a type in accordance with the spec.
171+
*
172+
* References to `skolem` occuring are avoided by following aliases and
173+
* singletons.
174+
*
175+
* If any reference to `skolem` remains in the result type,
176+
* `refersToSkolem` is set to true.
177+
*/
178+
class DropSkolemMap(skolem: SkolemType) extends TypeMap:
179+
var refersToSkolem = false
180+
def apply(tp: Type): Type =
181+
if refersToSkolem then
182+
return tp
183+
tp match
184+
case `skolem` =>
185+
refersToSkolem = true
186+
tp
187+
case tp: NamedType =>
188+
val pre1 = apply(tp.prefix)
189+
if refersToSkolem then
190+
tp match
191+
case tp: TermRef => tp.info.widenExpr.dealias match
192+
case info: SingletonType =>
193+
refersToSkolem = false
194+
apply(info)
195+
case _ =>
196+
tp.derivedSelect(pre1)
197+
case tp: TypeRef => tp.info match
198+
case info: AliasingBounds =>
199+
refersToSkolem = false
200+
apply(info.alias)
201+
case _ =>
202+
tp.derivedSelect(pre1)
203+
else
204+
tp.derivedSelect(pre1)
205+
case tp: LazyRef =>
206+
// By default, TypeMap maps LazyRefs lazily. We need to
207+
// force it for `refersToSkolem` to be correctly set.
208+
apply(tp.ref)
209+
case _ =>
210+
mapOver(tp)
211+
end DropSkolemMap
212+
/** Try to remove references to `skolem` from `u` in accordance with the spec.
213+
*
214+
* If any reference to `skolem` remains in the result type, return
215+
* NoType instead.
216+
*/
217+
def dropSkolem(u: Type, skolem: SkolemType): Type =
218+
val dmap = DropSkolemMap(skolem)
219+
val res = dmap(u)
220+
if dmap.refersToSkolem then NoType else res
221+
222+
val stableScrut: SingletonType = scrut match
223+
case scrut: SingletonType => scrut
224+
case _ => SkolemType(scrut)
225+
226+
stableScrut.member(typeMemberName) match
227+
case denot: SingleDenotation if denot.exists =>
228+
val info = stableScrut match
229+
case skolem: SkolemType =>
230+
/* If it is a skolem type, we cannot have class selections nor
231+
* abstract type selections. If it is an alias, we try to remove
232+
* any reference to the skolem from the right-hand-side. If that
233+
* succeeds, we take the result, otherwise we fail as not-specific.
234+
*/
235+
236+
def adaptToTriggerNotSpecific(info: Type): Type = info match
237+
case info: TypeBounds => info
238+
case _ => RealTypeBounds(info, info)
239+
240+
denot.info match
241+
case denotInfo: AliasingBounds =>
242+
val alias = denotInfo.alias
243+
dropSkolem(alias, skolem).orElse(adaptToTriggerNotSpecific(alias))
244+
case ClassInfo(prefix, cls, _, _, _) =>
245+
// for clean error messages
246+
adaptToTriggerNotSpecific(prefix.select(cls))
247+
case denotInfo =>
248+
adaptToTriggerNotSpecific(denotInfo)
249+
250+
case _ =>
251+
// The scrutinee type is truly stable. We select the type member directly on it.
252+
stableScrut.select(typeMemberName)
253+
end info
254+
255+
rec(capture, info, variance = 0, scrutIsWidenedAbstract)
256+
257+
case _ =>
258+
// The type member was not found; no match
259+
false
260+
end rec
261+
262+
def matchArgs(argPatterns: List[MatchTypeCasePattern], args: List[Type], tparams: List[TypeParamInfo], scrutIsWidenedAbstract: Boolean): Boolean =
263+
if argPatterns.isEmpty then
264+
true
265+
else
266+
rec(argPatterns.head, args.head, tparams.head.paramVarianceSign, scrutIsWidenedAbstract)
267+
&& matchArgs(argPatterns.tail, args.tail, tparams.tail, scrutIsWidenedAbstract)
268+
269+
// This might not be needed
270+
val constrainedCaseLambda = constrained(spec.origMatchCase, ast.tpd.EmptyTree)._1.asInstanceOf[HKTypeLambda]
271+
272+
val disjoint =
273+
val defn.MatchCase(origPattern, _) = constrainedCaseLambda.resultType: @unchecked
274+
provablyDisjoint(scrut, origPattern)
275+
276+
def tryDisjoint: MatchResult =
277+
if disjoint then
278+
MatchResult.Disjoint
279+
else
280+
MatchResult.Stuck
281+
282+
if rec(spec.pattern, scrut, variance = 1, scrutIsWidenedAbstract = false) then
283+
if noInstances.nonEmpty then
284+
MatchResult.NoInstance(noInstances.toList)
285+
else
286+
val defn.MatchCase(instantiatedPat, reduced) =
287+
instantiateParamsSpec(instances, constrainedCaseLambda)(constrainedCaseLambda.resultType): @unchecked
288+
if scrut <:< instantiatedPat then
289+
if disjoint then
290+
MatchResult.ReducedAndDisjoint
291+
else
292+
MatchResult.Reduced(reduced)
293+
else
294+
tryDisjoint
295+
else
296+
tryDisjoint
297+
end matchSpeccedPatMat
298+
299+
def matchLegacyPatMat(spec: MatchTypeCaseSpec.LegacyPatMat): MatchResult =
300+
val caseLambda = constrained(spec.origMatchCase, ast.tpd.EmptyTree)._1.asInstanceOf[HKTypeLambda]
301+
this.caseLambda = caseLambda
302+
303+
val defn.MatchCase(pat, body) = caseLambda.resultType: @unchecked
304+
305+
def matches(canWidenAbstract: Boolean): Boolean =
306+
val saved = this.canWidenAbstract
307+
val savedPoisoned = this.poisoned
308+
this.canWidenAbstract = canWidenAbstract
309+
this.poisoned = Set.empty
310+
try necessarySubType(scrut, pat)
311+
finally
312+
poisoned = this.poisoned
313+
this.poisoned = savedPoisoned
314+
this.canWidenAbstract = saved
315+
316+
val disjoint = provablyDisjoint(scrut, pat)
317+
318+
def redux(canApprox: Boolean): MatchResult =
319+
val instances = paramInstances(canApprox)(Array.fill(caseLambda.paramNames.length)(NoType), pat)
320+
instantiateParams(instances)(body) match
321+
case Range(lo, hi) =>
322+
MatchResult.NoInstance {
323+
caseLambda.paramNames.zip(instances).collect {
324+
case (name, Range(lo, hi)) => (name, TypeBounds(lo, hi))
325+
}
326+
}
327+
case redux =>
328+
if disjoint then
329+
MatchResult.ReducedAndDisjoint
330+
else
331+
MatchResult.Reduced(redux)
332+
333+
if matches(canWidenAbstract = false) then
334+
redux(canApprox = true)
335+
else if matches(canWidenAbstract = true) then
336+
redux(canApprox = false)
337+
else if (disjoint)
338+
// We found a proof that `scrut` and `pat` are incompatible.
339+
// The search continues.
340+
MatchResult.Disjoint
341+
else
342+
MatchResult.Stuck
343+
end matchLegacyPatMat
344+
345+
def matchMissingCaptures(spec: MatchTypeCaseSpec.MissingCaptures): MatchResult =
346+
MatchResult.Stuck
347+
348+
def recur(remaining: List[MatchTypeCaseSpec]): Type = remaining match
349+
case (cas: MatchTypeCaseSpec.LegacyPatMat) :: _ if sourceVersion.isAtLeast(SourceVersion.`3.4`) =>
350+
val errorText = MatchTypeTrace.illegalPatternText(scrut, cas)
351+
ErrorType(reporting.MatchTypeLegacyPattern(errorText))
352+
case cas :: remaining1 =>
353+
matchCase(cas) match
354+
case MatchResult.Disjoint =>
355+
recur(remaining1)
356+
case MatchResult.Stuck =>
357+
MatchTypeTrace.stuck(scrut, cas, remaining1)
358+
NoType
359+
case MatchResult.NoInstance(fails) =>
360+
MatchTypeTrace.noInstance(scrut, cas, fails)
361+
NoType
362+
case MatchResult.Reduced(tp) =>
363+
tp.simplified
364+
case MatchResult.ReducedAndDisjoint =>
365+
// Empty types break the basic assumption that if a scrutinee and a
366+
// pattern are disjoint it's OK to reduce passed that pattern. Indeed,
367+
// empty types viewed as a set of value is always a subset of any other
368+
// types. As a result, if a scrutinee both matches a pattern and is
369+
// probably disjoint from it, we prevent reduction.
370+
// See `tests/neg/6570.scala` and `6570-1.scala` for examples that
371+
// exploit emptiness to break match type soundness.
372+
MatchTypeTrace.emptyScrutinee(scrut)
373+
NoType
374+
case Nil =>
375+
/* TODO warn ? then re-enable warn/12974.scala:26
376+
val noCasesText = MatchTypeTrace.noMatchesText(scrut, cases)
377+
report.warning(reporting.MatchTypeNoCases(noCasesText), pos = ???)
378+
*/
379+
MatchTypeTrace.noMatches(scrut, cases)
380+
NoType
381+
382+
inFrozenConstraint(recur(cases))
383+
}
384+
}

0 commit comments

Comments
 (0)