From 972abbadfa02e6591eeda7bc7ad35583514a2f80 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 4 Jun 2025 13:33:43 +0000 Subject: [PATCH 01/20] Add syntax for qualified types Co-Authored-By: Quentin Bernet <28290641+Sporarum@users.noreply.github.com> --- .../src/dotty/tools/dotc/ast/Desugar.scala | 28 +++++- compiler/src/dotty/tools/dotc/ast/untpd.scala | 15 +++ .../src/dotty/tools/dotc/config/Feature.scala | 6 ++ .../src/dotty/tools/dotc/core/StdNames.scala | 1 + .../dotty/tools/dotc/parsing/Parsers.scala | 42 +++++++- .../tools/dotc/printing/RefinedPrinter.scala | 4 + .../dotty/tools/dotc/typer/ImportInfo.scala | 2 +- library/src/scala/annotation/qualified.scala | 4 + .../runtime/stdLibPatches/language.scala | 4 + project/Build.scala | 2 + project/MiMaFilters.scala | 3 + tests/printing/qualified-types.check | 98 +++++++++++++++++++ tests/printing/qualified-types.flags | 1 + tests/printing/qualified-types.scala | 59 +++++++++++ .../stdlibExperimentalDefinitions.scala | 3 + 15 files changed, 265 insertions(+), 7 deletions(-) create mode 100644 library/src/scala/annotation/qualified.scala create mode 100644 tests/printing/qualified-types.check create mode 100644 tests/printing/qualified-types.flags create mode 100644 tests/printing/qualified-types.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 8471b06b7e97..5b28ea2b49bd 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -8,7 +8,7 @@ import Symbols.*, StdNames.*, Trees.*, ContextOps.* import Decorators.* import Annotations.Annotation import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, DefaultGetterName, WildcardParamName} -import typer.{Namer, Checking} +import typer.{Namer, Checking, ErrorReporting} import util.{Property, SourceFile, SourcePosition, SrcPos, Chars} import config.{Feature, Config} import config.Feature.{sourceVersion, migrateTo3, enabled} @@ -213,9 +213,10 @@ object desugar { def valDef(vdef0: ValDef)(using Context): Tree = val vdef @ ValDef(_, tpt, rhs) = vdef0 val valName = normalizeName(vdef, tpt).asTermName + val tpt1 = qualifiedType(tpt, valName) var mods1 = vdef.mods - val vdef1 = cpy.ValDef(vdef)(name = valName).withMods(mods1) + val vdef1 = cpy.ValDef(vdef)(name = valName, tpt = tpt1).withMods(mods1) if isSetterNeeded(vdef) then val setterParam = makeSyntheticParameter(tpt = SetterParamTree().watching(vdef)) @@ -2349,6 +2350,10 @@ object desugar { case PatDef(mods, pats, tpt, rhs) => val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt)) flatTree(pats1 map (makePatDef(tree, mods, _, rhs))) + case QualifiedTypeTree(parent, None, qualifier) => + ErrorReporting.errorTree(parent, em"missing parameter name in qualified type", tree.srcPos) + case QualifiedTypeTree(parent, Some(paramName), qualifier) => + qualifiedType(parent, paramName, qualifier, tree.span) case ext: ExtMethods => Block(List(ext), syntheticUnitLiteral.withSpan(ext.span)) case f: FunctionWithMods if f.hasErasedParams => makeFunctionWithValDefs(f, pt) @@ -2527,4 +2532,23 @@ object desugar { collect(tree) buf.toList } + + /** If `tree` is a `QualifiedTypeTree`, then desugars it using `paramName` as + * the qualified parameter name. Otherwise, returns `tree` unchanged. + */ + def qualifiedType(tree: Tree, paramName: TermName)(using Context): Tree = tree match + case QualifiedTypeTree(parent, None, qualifier) => qualifiedType(parent, paramName, qualifier, tree.span) + case _ => tree + + /** Returns the annotated type used to represent the qualified type with the + * given components: + * `parent @qualified[parent]((paramName: parent) => qualifier)`. + */ + def qualifiedType(parent: Tree, paramName: TermName, qualifier: Tree, span: Span)(using Context): Tree = + val param = makeParameter(paramName, parent, EmptyModifiers) // paramName: parent + val predicate = WildcardFunction(List(param), qualifier) // (paramName: parent) => qualifier + val qualifiedAnnot = scalaAnnotationDot(nme.qualified) + val annot = Apply(TypeApply(qualifiedAnnot, List(parent)), predicate).withSpan(span) // @qualified[parent](predicate) + Annotated(parent, annot).withSpan(span) // parent @qualified[parent](predicate) + } diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index bad70cb3a01c..a323600a49b7 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -156,6 +156,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { */ case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree + /** `{ x: parent with qualifier }` if `paramName == Some(x)`, + * `parent with qualifier` otherwise. + * + * Only relevant under `qualifiedTypes`. + */ + case class QualifiedTypeTree(parent: Tree, paramName: Option[TermName], qualifier: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree + /** A type tree appearing somewhere in the untyped DefDef of a lambda, it will be typed using `tpFun`. * * @param isResult Is this the result type of the lambda? This is handled specially in `Namer#valOrDefDefSig`. @@ -735,6 +742,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case tree: CapturesAndResult if (refs eq tree.refs) && (parent eq tree.parent) => tree case _ => finalize(tree, untpd.CapturesAndResult(refs, parent)) + def QualifiedTypeTree(tree: Tree)(parent: Tree, paramName: Option[TermName], qualifier: Tree)(using Context): Tree = tree match + case tree: QualifiedTypeTree if (parent eq tree.parent) && (paramName eq tree.paramName) && (qualifier eq tree.qualifier) => tree + case _ => finalize(tree, untpd.QualifiedTypeTree(parent, paramName, qualifier)(using tree.source)) + def TypedSplice(tree: Tree)(splice: tpd.Tree)(using Context): ProxyTree = tree match { case tree: TypedSplice if splice `eq` tree.splice => tree case _ => finalize(tree, untpd.TypedSplice(splice)(using ctx)) @@ -798,6 +809,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { cpy.MacroTree(tree)(transform(expr)) case CapturesAndResult(refs, parent) => cpy.CapturesAndResult(tree)(transform(refs), transform(parent)) + case QualifiedTypeTree(parent, paramName, qualifier) => + cpy.QualifiedTypeTree(tree)(transform(parent), paramName, transform(qualifier)) case _ => super.transformMoreCases(tree) } @@ -857,6 +870,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { this(x, expr) case CapturesAndResult(refs, parent) => this(this(x, refs), parent) + case QualifiedTypeTree(parent, paramName, qualifier) => + this(this(x, parent), qualifier) case _ => super.foldMoreCases(x, tree) } diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 5aa7ed84f72d..3c34d4213727 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -34,6 +34,7 @@ object Feature: val pureFunctions = experimental("pureFunctions") val captureChecking = experimental("captureChecking") val separationChecking = experimental("separationChecking") + val qualifiedTypes = experimental("qualifiedTypes") val into = experimental("into") val modularity = experimental("modularity") val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions") @@ -67,6 +68,7 @@ object Feature: (pureFunctions, "Enable pure functions for capture checking"), (captureChecking, "Enable experimental capture checking"), (separationChecking, "Enable experimental separation checking (requires captureChecking)"), + (qualifiedTypes, "Enable experimental qualified types"), (into, "Allow into modifier on parameter types"), (modularity, "Enable experimental modularity features"), (packageObjectValues, "Enable experimental package objects as values"), @@ -153,6 +155,10 @@ object Feature: if ctx.run != null then ctx.run.nn.ccEnabledSomewhere else enabledBySetting(captureChecking) + /** Is qualifiedTypes enabled for this compilation unit? */ + def qualifiedTypesEnabled(using Context) = + enabledBySetting(qualifiedTypes) + def sourceVersionSetting(using Context): SourceVersion = SourceVersion.valueOf(ctx.settings.source.value) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 3213be389c9d..b5ec2de8568f 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -590,6 +590,7 @@ object StdNames { val productElementName: N = "productElementName" val productIterator: N = "productIterator" val productPrefix: N = "productPrefix" + val qualified : N = "qualified" val quotes : N = "quotes" val raw_ : N = "raw" val reachCapability: N = "reachCapability" diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 147a650370ac..760cbf99d588 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -448,6 +448,13 @@ object Parsers { finally inMatchPattern = saved } + private var inQualifiedType = false + private def fromWithinQualifiedType[T](body: => T): T = + val saved = inQualifiedType + inQualifiedType = true + try body + finally inQualifiedType = saved + private var staged = StageKind.None def withinStaged[T](kind: StageKind)(op: => T): T = { val saved = staged @@ -1941,12 +1948,22 @@ object Parsers { t } - /** WithType ::= AnnotType {`with' AnnotType} (deprecated) - */ + /** With qualifiedTypes enabled: + * WithType ::= AnnotType [`with' PostfixExpr] + * + * Otherwise: + * WithType ::= AnnotType {`with' AnnotType} (deprecated) + */ def withType(): Tree = withTypeRest(annotType()) def withTypeRest(t: Tree): Tree = - if in.token == WITH then + if in.featureEnabled(Feature.qualifiedTypes) && in.token == WITH then + if inQualifiedType then t + else + in.nextToken() + val qualifier = postfixExpr() + QualifiedTypeTree(t, None, qualifier).withSpan(Span(t.span.start, qualifier.span.end)) + else if in.token == WITH then val withOffset = in.offset in.nextToken() if in.token == LBRACE || in.token == INDENT then @@ -2098,6 +2115,7 @@ object Parsers { * | ‘(’ ArgTypes ‘)’ * | ‘(’ NamesAndTypes ‘)’ * | Refinement + * | QualifiedType -- under qualifiedTypes * | TypeSplice -- deprecated syntax (since 3.0.0) * | SimpleType1 TypeArgs * | SimpleType1 `#' id @@ -2108,7 +2126,10 @@ object Parsers { makeTupleOrParens(inParensWithCommas(argTypes(namedOK = false, wildOK = true, tupleOK = true))) } else if in.token == LBRACE then - atSpan(in.offset) { RefinedTypeTree(EmptyTree, refinement(indentOK = false)) } + if in.featureEnabled(Feature.qualifiedTypes) && in.lookahead.token == IDENTIFIER then + qualifiedType() + else + atSpan(in.offset) { RefinedTypeTree(EmptyTree, refinement(indentOK = false)) } else if (isSplice) splice(isType = true) else @@ -2272,6 +2293,19 @@ object Parsers { else inBraces(refineStatSeq()) + /** QualifiedType ::= `{` Ident `:` Type `with` Block `}` + */ + def qualifiedType(): Tree = + val startOffset = in.offset + accept(LBRACE) + val id = ident() + accept(COLONfollow) + val tp = fromWithinQualifiedType(typ()) + accept(WITH) + val qualifier = block(simplify = true) + accept(RBRACE) + QualifiedTypeTree(tp, Some(id), qualifier).withSpan(Span(startOffset, qualifier.span.end)) + /** TypeBounds ::= [`>:' TypeBound ] [`<:' TypeBound ] * TypeBound ::= Type * | CaptureSet -- under captureChecking diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index b8de2c7a9115..07349f631d4f 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -843,6 +843,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { prefix ~~ idx.toString ~~ "|" ~~ tpeText ~~ "|" ~~ argsText ~~ "|" ~~ contentText ~~ postfix case CapturesAndResult(refs, parent) => changePrec(GlobalPrec)("^{" ~ Text(refs.map(toText), ", ") ~ "}" ~ toText(parent)) + case QualifiedTypeTree(parent, paramName, predicate) => + paramName match + case Some(name) => "{" ~ toText(name) ~ ": " ~ toText(parent) ~ " with " ~ toText(predicate) ~ "}" + case None => toText(parent) ~ " with " ~ toText(predicate) case ContextBoundTypeTree(tycon, pname, ownName) => toText(pname) ~ " : " ~ toText(tycon) ~ (" as " ~ toText(ownName) `provided` !ownName.isEmpty) case _ => diff --git a/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala b/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala index e8698baa46ac..1915155f5aca 100644 --- a/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala +++ b/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala @@ -206,7 +206,7 @@ class ImportInfo(symf: Context ?=> Symbol, /** Does this import clause or a preceding import clause enable `feature`? * - * @param feature a possibly quailified name, e.g. + * @param feature a possibly qualified name, e.g. * strictEquality * experimental.genericNumberLiterals * diff --git a/library/src/scala/annotation/qualified.scala b/library/src/scala/annotation/qualified.scala new file mode 100644 index 000000000000..2fae020be762 --- /dev/null +++ b/library/src/scala/annotation/qualified.scala @@ -0,0 +1,4 @@ +package scala.annotation + +/** Annotation for qualified types. */ +@experimental class qualified[T](predicate: T => Boolean) extends StaticAnnotation diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index 34227259de0d..2ad830ecfd9e 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -100,6 +100,10 @@ object language: @compileTimeOnly("`separationChecking` can only be used at compile time in import statements") object separationChecking + /** Experimental support for qualified types */ + @compileTimeOnly("`qualifiedTypes` is only be used at compile time") + object qualifiedTypes + /** Experimental support for automatic conversions of arguments, without requiring * a language import `import scala.language.implicitConversions`. * diff --git a/project/Build.scala b/project/Build.scala index cddf225f3475..05af00f45a7d 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -1119,6 +1119,7 @@ object Build { file(s"${baseDirectory.value}/src/scala/annotation/MacroAnnotation.scala"), file(s"${baseDirectory.value}/src/scala/annotation/alpha.scala"), file(s"${baseDirectory.value}/src/scala/annotation/publicInBinary.scala"), + file(s"${baseDirectory.value}/src/scala/annotation/qualified.scala"), file(s"${baseDirectory.value}/src/scala/annotation/init.scala"), file(s"${baseDirectory.value}/src/scala/annotation/unroll.scala"), file(s"${baseDirectory.value}/src/scala/annotation/targetName.scala"), @@ -1258,6 +1259,7 @@ object Build { file(s"${baseDirectory.value}/src/scala/annotation/MacroAnnotation.scala"), file(s"${baseDirectory.value}/src/scala/annotation/alpha.scala"), file(s"${baseDirectory.value}/src/scala/annotation/publicInBinary.scala"), + file(s"${baseDirectory.value}/src/scala/annotation/qualified.scala"), file(s"${baseDirectory.value}/src/scala/annotation/init.scala"), file(s"${baseDirectory.value}/src/scala/annotation/unroll.scala"), file(s"${baseDirectory.value}/src/scala/annotation/targetName.scala"), diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index 42ba23cac480..f316165e120b 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -32,6 +32,9 @@ object MiMaFilters { ProblemFilters.exclude[MissingClassProblem]("scala.caps.Classifier"), ProblemFilters.exclude[MissingClassProblem]("scala.caps.SharedCapability"), ProblemFilters.exclude[MissingClassProblem]("scala.caps.Control"), + + ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language#experimental.qualifiedTypes"), + ProblemFilters.exclude[MissingClassProblem]("scala.runtime.stdLibPatches.language$experimental$qualifiedTypes$"), ), // Additions since last LTS diff --git a/tests/printing/qualified-types.check b/tests/printing/qualified-types.check new file mode 100644 index 000000000000..11e32e759298 --- /dev/null +++ b/tests/printing/qualified-types.check @@ -0,0 +1,98 @@ +[[syntax trees at end of typer]] // tests/printing/qualified-types.scala +package example { + class Foo() extends Object() { + val x: Int @qualified[Int]((x: Int) => x > 0) = 1 + } + trait A() extends Object {} + final lazy module val qualified-types$package: example.qualified-types$package + = new example.qualified-types$package() + final module class qualified-types$package() extends Object() { + this: example.qualified-types$package.type => + type Neg = Int @qualified[Int]((x: Int) => x < 0) + type Pos = Int @qualified[Int]((x: Int) => x > 0) + type Pos2 = Int @qualified[Int]((x: Int) => x > 0) + type Pos3 = Int @qualified[Int]((x: Int) => x > 0) + type Pos4 = Int @qualified[Int]((x: Int) => x > 0) + type Pos5 = + Int @qualified[Int]((x: Int) => + { + val res: Boolean = x > 0 + res:Boolean + } + ) + type Nested = + Int @qualified[Int]((x: Int) => + { + val y: Int @qualified[Int]((z: Int) => z > 0) = ??? + x > y + } + ) + type Intersection = Int & Int @qualified[Int]((x: Int) => x > 0) + type ValRefinement = + Object + { + val x: Int @qualified[Int]((x: Int) => x > 0) + } + def id[T >: Nothing <: Any](x: T): T = x + def test(): Unit = + { + val x: example.Pos = 1 + val x2: Int @qualified[Int]((x: Int) => x > 0) = 1 + val x3: Int @qualified[Int]((x: Int) => x > 0) = 1 + val x4: Int @qualified[Int]((x: Int) => x > 0) = 1 + val x5: Int @qualified[Int]((x5: Int) => x > 0) = 1 + val x6: Int = + example.id[Int @qualified[Int]((x: Int) => x < 0)](1) + + example.id[example.Neg](-1) + () + } + def bar(x: Int @qualified[Int]((x: Int) => x > 0)): Nothing = ??? + def secondGreater1(x: Int, y: Int)(z: Int @qualified[Int]((w: Int) => x > y) + ): Nothing = ??? + def secondGreater2(x: Int, y: Int)(z: Int @qualified[Int]((z: Int) => x > y) + ): Nothing = ??? + final lazy module given val given_A: example.given_A = new example.given_A() + final module class given_A() extends Object(), example.A { + this: example.given_A.type => + val b: Boolean = false + example.id[Boolean](true) + } + type T1 = + Object + { + val x: Int + } + type T2 = + Object + { + val x: Int + } + type T3 = + Object + { + type T = Int + } + type T4 = + Object + { + def x: Int + } + type T5 = + Object + { + def x: Int + def x_=(x$1: Int): _root_.scala.Unit + } + type T6 = + Object + { + val x: Int + } + type T7 = + Object + { + val x: Int + } + } +} + diff --git a/tests/printing/qualified-types.flags b/tests/printing/qualified-types.flags new file mode 100644 index 000000000000..0a6f384436cd --- /dev/null +++ b/tests/printing/qualified-types.flags @@ -0,0 +1 @@ +-language:experimental.qualifiedTypes diff --git a/tests/printing/qualified-types.scala b/tests/printing/qualified-types.scala new file mode 100644 index 000000000000..f09f260d4962 --- /dev/null +++ b/tests/printing/qualified-types.scala @@ -0,0 +1,59 @@ +package example + +type Neg = {x: Int with x < 0} +type Pos = {x: Int with x > 0} +type Pos2 = {x: Int + with x > 0 +} +type Pos3 = {x: Int with + x > 0 +} +type Pos4 = + {x: Int with x > 0} +type Pos5 = {x: Int with + val res = x > 0 + res +} + +type Nested = {x: Int with { val y: {z: Int with z > 0} = ??? ; x > y }} +type Intersection = Int & {x: Int with x > 0} +type ValRefinement = {val x: Int with x > 0} + +def id[T](x: T): T = x + +def test() = + val x: Pos = 1 + val x2: {x: Int with x > 0} = 1 + val x3: { + x: Int with x > 0 + } = 1 + val x4: {x: Int with + x > 0 + } = 1 + val x5: Int with x > 0 = 1 + val x6: Int = id[{x: Int with x < 0}](1) + id[Neg](-1) + +def bar(x: Int with x > 0) = ??? +def secondGreater1(x: Int, y: Int)(z: {w: Int with x > y}) = ??? +def secondGreater2(x: Int, y: Int)(z: Int with x > y) = ??? + +class Foo: + val x: Int with x > 0 = 1 + +trait A +// Not a qualified type: +given A with + val b = false + id(true) + +// Also not qualified types: +type T1 = {val x: Int} +type T2 = { + val x: Int +} +type T3 = {type T = Int} +type T4 = {def x: Int} +type T5 = {var x: Int} +type T6 = Object {val x: Int} +type T7 = Object: + val x: Int diff --git a/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala b/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala index c249721f6a6d..45736281f57f 100644 --- a/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala +++ b/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala @@ -49,6 +49,9 @@ val experimentalDefinitionInLibrary = Set( "scala.caps.package$package$.Exclusive", "scala.caps.package$package$.Shared", + // Experimental feature: qualified types + "scala.annotation.qualified", + //// New feature: Macro annotations "scala.annotation.MacroAnnotation", From 9d80b83f2e9e8e0567319629a92918d3eb39f130 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 4 Jun 2025 14:51:05 +0000 Subject: [PATCH 02/20] Allow nested qualified types with implicit argument name --- .../src/dotty/tools/dotc/ast/Desugar.scala | 31 ++++++++++++---- tests/printing/qualified-types.check | 36 ++++++++++++++++++- tests/printing/qualified-types.scala | 16 ++++++++- 3 files changed, 75 insertions(+), 8 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 5b28ea2b49bd..e26ceda30ba6 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -213,7 +213,7 @@ object desugar { def valDef(vdef0: ValDef)(using Context): Tree = val vdef @ ValDef(_, tpt, rhs) = vdef0 val valName = normalizeName(vdef, tpt).asTermName - val tpt1 = qualifiedType(tpt, valName) + val tpt1 = desugarQualifiedTypes(tpt, valName) var mods1 = vdef.mods val vdef1 = cpy.ValDef(vdef)(name = valName, tpt = tpt1).withMods(mods1) @@ -2533,12 +2533,31 @@ object desugar { buf.toList } - /** If `tree` is a `QualifiedTypeTree`, then desugars it using `paramName` as - * the qualified parameter name. Otherwise, returns `tree` unchanged. + /** Desugar subtrees that are `QualifiedTypeTree`s using `outerParamName` as + * the qualified parameter name. */ - def qualifiedType(tree: Tree, paramName: TermName)(using Context): Tree = tree match - case QualifiedTypeTree(parent, None, qualifier) => qualifiedType(parent, paramName, qualifier, tree.span) - case _ => tree + private def desugarQualifiedTypes(tpt: Tree, outerParamName: TermName)(using Context): Tree = + def transform(tree: Tree): Tree = + tree match + case QualifiedTypeTree(parent, None, qualifier) => + qualifiedType(transform(parent), outerParamName, qualifier, tree.span) + case QualifiedTypeTree(parent, paramName, qualifier) => + cpy.QualifiedTypeTree(tree)(transform(parent), paramName, qualifier) + case TypeApply(fn, args) => + cpy.TypeApply(tree)(transform(fn), args) + case AppliedTypeTree(fn, args) => + cpy.AppliedTypeTree(tree)(transform(fn), args) + case InfixOp(left, op, right) => + cpy.InfixOp(tree)(transform(left), op, transform(right)) + case Parens(arg) => + cpy.Parens(tree)(transform(arg)) + case _ => + tree + + if Feature.qualifiedTypesEnabled then + transform(tpt) + else + tpt /** Returns the annotated type used to represent the qualified type with the * given components: diff --git a/tests/printing/qualified-types.check b/tests/printing/qualified-types.check index 11e32e759298..6cdfb1a2cdfa 100644 --- a/tests/printing/qualified-types.check +++ b/tests/printing/qualified-types.check @@ -42,10 +42,44 @@ package example { val x4: Int @qualified[Int]((x: Int) => x > 0) = 1 val x5: Int @qualified[Int]((x5: Int) => x > 0) = 1 val x6: Int = - example.id[Int @qualified[Int]((x: Int) => x < 0)](1) + + example.id[Int @qualified[Int]((x: Int) => x < 0)](-1) + example.id[example.Neg](-1) () } + def implicitArgumentName(): Unit = + { + val x0: + Int @qualified[Int]((x0: Int) => x0 > 0) | + String @qualified[String]((x0: String) => x0 == "foo") + = ??? + val x1: Int @qualified[Int]((x1: Int) => x1 > 0) = ??? + val x2: Int @qualified[Int]((x2: Int) => x2 > 0) = ??? + val x3: + Int @qualified[Int]((x3: Int) => x3 > 0) & + Int @qualified[Int]((x3: Int) => x3 < 10) + = ??? + val x4: + Int @qualified[Int]((x4: Int) => x4 > 0) & + Int @qualified[Int]((x4: Int) => x4 < 10) + = ??? + val x5: Int & String @qualified[String]((x5: String) => false) = ??? + val x6: + (Int @qualified[Int]((x6: Int) => x6 > 0) & Int) @qualified[ + Int @qualified[Int]((x6: Int) => x6 > 0) & Int](( + x6: Int @qualified[Int]((x6: Int) => x6 > 0) & Int) => x5 < 10) + = ??? + val x7: + Int @qualified[Int]((x7: Int) => x7 > 0) @qualified[ + Int @qualified[Int]((x7: Int) => x7 > 0)](( + x7: Int @qualified[Int]((x7: Int) => x7 > 0)) => x6 < 10) + = ??? + val x8: + Int @qualified[Int]((x8: Int) => x8 > 0) @qualified[ + Int @qualified[Int]((x8: Int) => x8 > 0)](( + x8: Int @qualified[Int]((x8: Int) => x8 > 0)) => x7 < 10) + = ??? + () + } def bar(x: Int @qualified[Int]((x: Int) => x > 0)): Nothing = ??? def secondGreater1(x: Int, y: Int)(z: Int @qualified[Int]((w: Int) => x > y) ): Nothing = ??? diff --git a/tests/printing/qualified-types.scala b/tests/printing/qualified-types.scala index f09f260d4962..4dc28463612e 100644 --- a/tests/printing/qualified-types.scala +++ b/tests/printing/qualified-types.scala @@ -31,7 +31,21 @@ def test() = x > 0 } = 1 val x5: Int with x > 0 = 1 - val x6: Int = id[{x: Int with x < 0}](1) + id[Neg](-1) + val x6: Int = id[{x: Int with x < 0}](-1) + id[Neg](-1) + +// `val x: Int with x > 0` is desugared to `val x: {x: Int with x > 0}`: if the +// name of a qualifier argument is not specified, it is assumed to be the same +// as the parent `val` definition. +def implicitArgumentName() = + val x0: (Int with x0 > 0) | (String with x0 == "foo") = ??? + val x1: Int with x1 > 0 = ??? + val x2: (Int with x2 > 0) = ??? + val x3: (Int with x3 > 0) & (Int with x3 < 10) = ??? + val x4: (Int with x4 > 0) & Int with x4 < 10 = ??? + val x5: Int & String with false = ??? + val x6: ((Int with x6 > 0) & Int) with x5 < 10 = ??? + val x7: (Int with x7 > 0) with x6 < 10 = ??? + val x8: ((Int with x8 > 0) with x7 < 10) = ??? def bar(x: Int with x > 0) = ??? def secondGreater1(x: Int, y: Int)(z: {w: Int with x > y}) = ??? From 23266d9f1ef07d28cd27c0ea3dac996adcb0e244 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 4 Jun 2025 15:51:39 +0000 Subject: [PATCH 03/20] Allow qualified types without argument name --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 6 ++---- tests/printing/qualified-types.check | 1 + tests/printing/qualified-types.scala | 2 ++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index e26ceda30ba6..fadde0376e12 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -2350,10 +2350,8 @@ object desugar { case PatDef(mods, pats, tpt, rhs) => val pats1 = if (tpt.isEmpty) pats else pats map (Typed(_, tpt)) flatTree(pats1 map (makePatDef(tree, mods, _, rhs))) - case QualifiedTypeTree(parent, None, qualifier) => - ErrorReporting.errorTree(parent, em"missing parameter name in qualified type", tree.srcPos) - case QualifiedTypeTree(parent, Some(paramName), qualifier) => - qualifiedType(parent, paramName, qualifier, tree.span) + case QualifiedTypeTree(parent, paramName, qualifier) => + qualifiedType(parent, paramName.getOrElse(nme.WILDCARD), qualifier, tree.span) case ext: ExtMethods => Block(List(ext), syntheticUnitLiteral.withSpan(ext.span)) case f: FunctionWithMods if f.hasErasedParams => makeFunctionWithValDefs(f, pt) diff --git a/tests/printing/qualified-types.check b/tests/printing/qualified-types.check index 6cdfb1a2cdfa..67a55918b31a 100644 --- a/tests/printing/qualified-types.check +++ b/tests/printing/qualified-types.check @@ -20,6 +20,7 @@ package example { res:Boolean } ) + type UninhabitedInt = Int @qualified[Int]((_: Int) => false) type Nested = Int @qualified[Int]((x: Int) => { diff --git a/tests/printing/qualified-types.scala b/tests/printing/qualified-types.scala index 4dc28463612e..e1b46b43eab9 100644 --- a/tests/printing/qualified-types.scala +++ b/tests/printing/qualified-types.scala @@ -15,6 +15,8 @@ type Pos5 = {x: Int with res } +type UninhabitedInt = Int with false + type Nested = {x: Int with { val y: {z: Int with z > 0} = ??? ; x > y }} type Intersection = Int & {x: Int with x > 0} type ValRefinement = {val x: Int with x > 0} From 8e2ab02cf8e95079cd81c1b3ffa4a09ba0613f2e Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 4 Jun 2025 16:41:48 +0000 Subject: [PATCH 04/20] Allow qualified types with implicit argument name in patterns --- .../src/dotty/tools/dotc/ast/Desugar.scala | 16 ++++++++++++ .../src/dotty/tools/dotc/typer/Typer.scala | 3 ++- tests/printing/qualified-types.check | 26 ++++++++++++++++++- tests/printing/qualified-types.scala | 11 ++++++++ 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index fadde0376e12..b34867918397 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -233,6 +233,14 @@ object desugar { else vdef1 end valDef + def caseDef(cdef: CaseDef)(using Context): CaseDef = + if Feature.qualifiedTypesEnabled then + val CaseDef(pat, guard, body) = cdef + val pat1 = DesugarQualifiedTypesInPatternMap().transform(pat) + cpy.CaseDef(cdef)(pat1, guard, body) + else + cdef + def mapParamss(paramss: List[ParamClause]) (mapTypeParam: TypeDef => TypeDef) (mapTermParam: ValDef => ValDef)(using Context): List[ParamClause] = @@ -2557,6 +2565,14 @@ object desugar { else tpt + private class DesugarQualifiedTypesInPatternMap extends UntypedTreeMap: + override def transform(tree: Tree)(using Context): Tree = + tree match + case Typed(ident @ Ident(name: TermName), tpt) => + cpy.Typed(tree)(ident, desugarQualifiedTypes(tpt, name)) + case _ => + super.transform(tree) + /** Returns the annotated type used to represent the qualified type with the * given components: * `parent @qualified[parent]((paramName: parent) => qualifier)`. diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 7432a0900ac5..2868dc45a193 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2343,9 +2343,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer } /** Type a case. */ - def typedCase(tree: untpd.CaseDef, sel: Tree, wideSelType: Type, pt: Type)(using Context): CaseDef = { + def typedCase(tree0: untpd.CaseDef, sel: Tree, wideSelType: Type, pt: Type)(using Context): CaseDef = { val originalCtx = ctx val gadtCtx: Context = ctx.fresh.setFreshGADTBounds + val tree = desugar.caseDef(tree0) def caseRest(pat: Tree)(using Context) = { val pt1 = instantiateMatchTypeProto(pat, pt) match { diff --git a/tests/printing/qualified-types.check b/tests/printing/qualified-types.check index 67a55918b31a..080a7cd2cbd7 100644 --- a/tests/printing/qualified-types.check +++ b/tests/printing/qualified-types.check @@ -79,7 +79,31 @@ package example { Int @qualified[Int]((x8: Int) => x8 > 0)](( x8: Int @qualified[Int]((x8: Int) => x8 > 0)) => x7 < 10) = ??? - () + val x9: Any = 42 + x9 match + { + case y @ _:Int @qualified[Int]((y: Int) => y > 0) => + println( + _root_.scala.StringContext.apply([""," is positive" : String]*). + s([y : Any]*) + ) + case _ => + () + } + Tuple2.apply[Int, Int](42, 42) match + { + case + Tuple2.unapply[Int, Int]( + y @ _:Int @qualified[Int]((y: Int) => y > 0), + z @ _:Int @qualified[Int]((z: Int) => z > 0)) + => + println( + _root_.scala.StringContext.apply( + [""," and "," are both positive" : String]*).s([y,z : Any]*) + ) + case _ => + () + } } def bar(x: Int @qualified[Int]((x: Int) => x > 0)): Nothing = ??? def secondGreater1(x: Int, y: Int)(z: Int @qualified[Int]((w: Int) => x > y) diff --git a/tests/printing/qualified-types.scala b/tests/printing/qualified-types.scala index e1b46b43eab9..bfa4dad9d0b2 100644 --- a/tests/printing/qualified-types.scala +++ b/tests/printing/qualified-types.scala @@ -49,6 +49,17 @@ def implicitArgumentName() = val x7: (Int with x7 > 0) with x6 < 10 = ??? val x8: ((Int with x8 > 0) with x7 < 10) = ??? + val x9: Any = 42 + x9 match + case y: Int with y > 0 => + println(s"$y is positive") + case _ => () + + (42, 42) match + case (y: Int with y > 0, z: Int with z > 0) => + println(s"$y and $z are both positive") + case _ => () + def bar(x: Int with x > 0) = ??? def secondGreater1(x: Int, y: Int)(z: {w: Int with x > y}) = ??? def secondGreater2(x: Int, y: Int)(z: Int with x > y) = ??? From 77ac0939faad98e8d6999cc3d8bba17f9364c00c Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 4 Jun 2025 17:31:50 +0000 Subject: [PATCH 05/20] Lower precedence of qualified types' `with` --- .../dotty/tools/dotc/parsing/Parsers.scala | 52 ++++++++++++------- tests/printing/qualified-types.check | 9 ++-- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 760cbf99d588..246a20f545b8 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1691,6 +1691,7 @@ object Parsers { * | TypTypeParamClause ‘=>>’ Type * | FunParamClause ‘=>>’ Type * | MatchType + * | QualifiedType2 -- under qualifiedTypes * | InfixType * FunType ::= (MonoFunType | PolyFunType) * MonoFunType ::= FunTypeArgs (‘=>’ | ‘?=>’) Type @@ -1701,6 +1702,11 @@ object Parsers { * | `(' [ FunArgType {`,' FunArgType } ] `)' * | '(' [ TypedFunParam {',' TypedFunParam } ')' * MatchType ::= InfixType `match` <<< TypeCaseClauses >>> + * QualifiedType2 ::= InfixType `with` PostfixExprf + * IntoType ::= [‘into’] IntoTargetType + * | ‘( IntoType ‘)’ + * IntoTargetType ::= Type + * | FunTypeArgs (‘=>’ | ‘?=>’) IntoType */ def typ(inContextBound: Boolean = false): Tree = val start = in.offset @@ -1760,6 +1766,8 @@ object Parsers { functionRest(t :: Nil) case MATCH => matchType(t) + case WITH if in.featureEnabled(Feature.qualifiedTypes) => + qualifiedTypeShort(t) case FORSOME => syntaxError(ExistentialTypesNoLongerSupported()) t @@ -1894,6 +1902,7 @@ object Parsers { def funParamClauses(): List[List[ValDef]] = if in.token == LPAREN then funParamClause() :: funParamClauses() else Nil + /** InfixType ::= RefinedType {id [nl] RefinedType} * | RefinedType `^` -- under captureChecking */ @@ -1948,22 +1957,12 @@ object Parsers { t } - /** With qualifiedTypes enabled: - * WithType ::= AnnotType [`with' PostfixExpr] - * - * Otherwise: - * WithType ::= AnnotType {`with' AnnotType} (deprecated) - */ + /** WithType ::= AnnotType {`with' AnnotType} (deprecated) + */ def withType(): Tree = withTypeRest(annotType()) def withTypeRest(t: Tree): Tree = - if in.featureEnabled(Feature.qualifiedTypes) && in.token == WITH then - if inQualifiedType then t - else - in.nextToken() - val qualifier = postfixExpr() - QualifiedTypeTree(t, None, qualifier).withSpan(Span(t.span.start, qualifier.span.end)) - else if in.token == WITH then + if in.token == WITH && !in.featureEnabled(Feature.qualifiedTypes) then val withOffset = in.offset in.nextToken() if in.token == LBRACE || in.token == INDENT then @@ -2306,6 +2305,17 @@ object Parsers { accept(RBRACE) QualifiedTypeTree(tp, Some(id), qualifier).withSpan(Span(startOffset, qualifier.span.end)) + /** `with` PostfixExpr + */ + def qualifiedTypeShort(t: Tree): Tree = + if inQualifiedType then + t + else + accept(WITH) + val qualifier = postfixExpr() + QualifiedTypeTree(t, None, qualifier).withSpan(Span(t.span.start, qualifier.span.end)) + + /** TypeBounds ::= [`>:' TypeBound ] [`<:' TypeBound ] * TypeBound ::= Type * | CaptureSet -- under captureChecking @@ -2382,7 +2392,12 @@ object Parsers { def typeDependingOn(location: Location): Tree = if location.inParens then typ() - else if location.inPattern then rejectWildcardType(refinedType()) + else if location.inPattern then + val t = rejectWildcardType(refinedType()) + if in.featureEnabled(Feature.qualifiedTypes) && in.token == WITH then + qualifiedTypeShort(t) + else + t else infixType() /* ----------- EXPRESSIONS ------------------------------------------------ */ @@ -3300,10 +3315,11 @@ object Parsers { if (isIdent(nme.raw.BAR)) { in.nextToken(); pattern1(location) :: patternAlts(location) } else Nil - /** Pattern1 ::= PatVar `:` RefinedType - * | [‘-’] integerLiteral `:` RefinedType - * | [‘-’] floatingPointLiteral `:` RefinedType - * | Pattern2 + /** Pattern1 ::= PatVar `:` QualifiedType3 + * | [‘-’] integerLiteral `:` QualifiedType3 + * | [‘-’] floatingPointLiteral `:` QualifiedType3 + * | Pattern2 + * QualifiedType3 ::= RefinedType [`with` PostfixExpr] */ def pattern1(location: Location = Location.InPattern): Tree = val p = pattern2(location) diff --git a/tests/printing/qualified-types.check b/tests/printing/qualified-types.check index 080a7cd2cbd7..29ac7b28ea21 100644 --- a/tests/printing/qualified-types.check +++ b/tests/printing/qualified-types.check @@ -60,10 +60,13 @@ package example { Int @qualified[Int]((x3: Int) => x3 < 10) = ??? val x4: - Int @qualified[Int]((x4: Int) => x4 > 0) & - Int @qualified[Int]((x4: Int) => x4 < 10) + (Int @qualified[Int]((x4: Int) => x4 > 0) & Int) @qualified[ + Int @qualified[Int]((x4: Int) => x4 > 0) & Int](( + x4: Int @qualified[Int]((x4: Int) => x4 > 0) & Int) => x4 < 10) = ??? - val x5: Int & String @qualified[String]((x5: String) => false) = ??? + val x5: + (Int & String) @qualified[Int & String]((x5: Int & String) => false) + = ??? val x6: (Int @qualified[Int]((x6: Int) => x6 > 0) & Int) @qualified[ Int @qualified[Int]((x6: Int) => x6 > 0) & Int](( From 53c9685a2f6aaae57828611c457c0ea653098b46 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 11 Jun 2025 16:09:27 +0000 Subject: [PATCH 06/20] Basic subtyping and solver for qualified types --- compiler/src/dotty/tools/dotc/ast/tpd.scala | 13 ++ .../dotty/tools/dotc/config/Printers.scala | 1 + .../dotty/tools/dotc/core/Definitions.scala | 1 + .../dotty/tools/dotc/core/TypeComparer.scala | 3 + .../src/dotty/tools/dotc/core/Types.scala | 17 ++- .../dotc/qualified_types/QualifiedType.scala | 38 +++++ .../dotc/qualified_types/QualifiedTypes.scala | 103 +++++++++++++ .../qualified_types/QualifierComparer.scala | 102 +++++++++++++ .../qualified_types/QualifierEvaluator.scala | 109 +++++++++++++ .../qualified_types/QualifierNormalizer.scala | 143 ++++++++++++++++++ .../qualified_types/QualifierSolver.scala | 82 ++++++++++ .../tools/dotc/transform/BetaReduce.scala | 12 +- .../dotty/tools/dotc/transform/Erasure.scala | 2 +- .../tools/dotc/transform/InlinePatterns.scala | 7 +- .../dotty/tools/dotc/typer/ConstFold.scala | 11 -- .../src/dotty/tools/dotc/typer/Typer.scala | 8 +- .../dotty/tools/dotc/CompilationTests.scala | 2 + library/src/scala/annotation/qualified.scala | 2 +- .../qualified-types/adapt_neg.scala | 22 +++ .../subtyping_singletons_neg.scala | 8 + .../subtyping_unfolding_neg.scala | 18 +++ .../qualified-types/syntax_unnamed_neg.scala | 7 + .../qualified-types/adapt.scala | 22 +++ .../qualified-types/avoidance.scala | 7 + .../qualified-types/sized_lists.scala | 18 +++ .../subtyping_normalization.scala | 29 ++++ .../subtyping_singletons.scala | 10 ++ .../qualified-types/subtyping_unfolding.scala | 54 +++++++ .../qualified-types/syntax_basics.scala | 55 +++++++ 29 files changed, 878 insertions(+), 28 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala create mode 100644 tests/neg-custom-args/qualified-types/adapt_neg.scala create mode 100644 tests/neg-custom-args/qualified-types/subtyping_singletons_neg.scala create mode 100644 tests/neg-custom-args/qualified-types/subtyping_unfolding_neg.scala create mode 100644 tests/neg-custom-args/qualified-types/syntax_unnamed_neg.scala create mode 100644 tests/pos-custom-args/qualified-types/adapt.scala create mode 100644 tests/pos-custom-args/qualified-types/avoidance.scala create mode 100644 tests/pos-custom-args/qualified-types/sized_lists.scala create mode 100644 tests/pos-custom-args/qualified-types/subtyping_normalization.scala create mode 100644 tests/pos-custom-args/qualified-types/subtyping_singletons.scala create mode 100644 tests/pos-custom-args/qualified-types/subtyping_unfolding.scala create mode 100644 tests/pos-custom-args/qualified-types/syntax_basics.scala diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 909387bbb809..380605d75acb 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1439,6 +1439,19 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { def unapply(ts: List[Tree]): Option[List[Tree]] = if ts.nonEmpty && ts.head.isType then Some(ts) else None + + /** An extractor for trees that are constant values. */ + object ConstantTree: + def unapply(tree: Tree)(using Context): Option[Constant] = + tree match + case Inlined(_, Nil, expr) => unapply(expr) + case Typed(expr, _) => unapply(expr) + case Literal(c) if c.tag == Constants.NullTag => Some(c) + case _ => + tree.tpe.widenTermRefExpr.normalized.simplified match + case ConstantType(c) => Some(c) + case _ => None + /** Split argument clauses into a leading type argument clause if it exists and * remaining clauses */ diff --git a/compiler/src/dotty/tools/dotc/config/Printers.scala b/compiler/src/dotty/tools/dotc/config/Printers.scala index 4c66e1cdf833..2ace7b1f402a 100644 --- a/compiler/src/dotty/tools/dotc/config/Printers.scala +++ b/compiler/src/dotty/tools/dotc/config/Printers.scala @@ -51,6 +51,7 @@ object Printers { val overload = noPrinter val patmatch = noPrinter val pickling = noPrinter + val qualifiedTypes = noPrinter val quotePickling = noPrinter val plugins = noPrinter val recheckr = noPrinter diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 372c6994e655..088c34fda66e 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1052,6 +1052,7 @@ class Definitions { @tu lazy val DeprecatedAnnot: ClassSymbol = requiredClass("scala.deprecated") @tu lazy val DeprecatedOverridingAnnot: ClassSymbol = requiredClass("scala.deprecatedOverriding") @tu lazy val DeprecatedInheritanceAnnot: ClassSymbol = requiredClass("scala.deprecatedInheritance") + @tu lazy val QualifiedAnnot: ClassSymbol = requiredClass("scala.annotation.qualified") @tu lazy val ImplicitAmbiguousAnnot: ClassSymbol = requiredClass("scala.annotation.implicitAmbiguous") @tu lazy val ImplicitNotFoundAnnot: ClassSymbol = requiredClass("scala.annotation.implicitNotFound") @tu lazy val InferredDepFunAnnot: ClassSymbol = requiredClass("scala.caps.internal.inferredDepFun") diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 77fdc24a01cc..91d6c6f6de19 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -28,6 +28,7 @@ import NameKinds.WildcardParamName import MatchTypes.isConcrete import reporting.Message.Note import scala.util.boundary, boundary.break +import qualified_types.{QualifiedType, QualifiedTypes} /** Provides methods to compare types. */ @@ -887,6 +888,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling println(i"assertion failed while compare captured $tp1 <:< $tp2") throw ex compareCapturing || fourthTry + case QualifiedType(parent2, qualifier2) => + QualifiedTypes.typeImplies(tp1, qualifier2) && recur(tp1, parent2) case tp2: AnnotatedType if tp2.isRefining => (tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || tp1.isBottomType) && recur(tp1, tp2.parent) diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 6bf19c5a27a1..6636da640485 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -44,6 +44,7 @@ import CaptureSet.IdentityCaptRefMap import Capabilities.* import transform.Recheck.currentRechecker +import qualified_types.QualifiedType import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -58,7 +59,7 @@ object Types extends TypeUtils { * The principal subclasses and sub-objects are as follows: * * ```none - * Type -+- ProxyType --+- NamedType ----+--- TypeRef + * Type -+- TypeProxy --+- NamedType ----+--- TypeRef * | | \ * | +- SingletonType-+-+- TermRef * | | | @@ -193,9 +194,10 @@ object Types extends TypeUtils { /** Is this type a (possibly refined, applied, aliased or annotated) type reference * to the given type symbol? - * @sym The symbol to compare to. It must be a class symbol or abstract type. + * @param sym The symbol to compare to. It must be a class symbol or abstract type. * It makes no sense for it to be an alias type because isRef would always * return false in that case. + * @param skipRefined If true, skip refinements, annotated types and applied types. */ def isRef(sym: Symbol, skipRefined: Boolean = true)(using Context): Boolean = this match { case this1: TypeRef => @@ -213,7 +215,7 @@ object Types extends TypeUtils { else this1.underlying.isRef(sym, skipRefined) case this1: TypeVar => this1.instanceOpt.isRef(sym, skipRefined) - case this1: AnnotatedType => + case this1: AnnotatedType if (!this1.isRefining || skipRefined) => this1.parent.isRef(sym, skipRefined) case _ => false } @@ -1616,6 +1618,7 @@ object Types extends TypeUtils { def apply(tp: Type) = /*trace(i"deskolemize($tp) at $variance", show = true)*/ tp match { case tp: SkolemType => range(defn.NothingType, atVariance(1)(apply(tp.info))) + case QualifiedType(_, _) => tp case _ => mapOver(tp) } } @@ -2150,7 +2153,7 @@ object Types extends TypeUtils { /** Is `this` isomorphic to `that`, assuming pairs of matching binders `bs`? * It is assumed that `this.ne(that)`. */ - protected def iso(that: Any, bs: BinderPairs): Boolean = this.equals(that) + def iso(that: Any, bs: BinderPairs): Boolean = this.equals(that) /** Equality used for hash-consing; uses `eq` on all recursive invocations, * except where a BindingType is involved. The latter demand a deep isomorphism check. @@ -3595,7 +3598,7 @@ object Types extends TypeUtils { case _ => false } - override protected def iso(that: Any, bs: BinderPairs) = that match + override def iso(that: Any, bs: BinderPairs) = that match case that: AndType => tp1.equals(that.tp1, bs) && tp2.equals(that.tp2, bs) case _ => false } @@ -3749,7 +3752,7 @@ object Types extends TypeUtils { case _ => false } - override protected def iso(that: Any, bs: BinderPairs) = that match + override def iso(that: Any, bs: BinderPairs) = that match case that: OrType => tp1.equals(that.tp1, bs) && tp2.equals(that.tp2, bs) && isSoft == that.isSoft case _ => false } @@ -5081,7 +5084,7 @@ object Types extends TypeUtils { * anymore, or NoType if the variable can still be further constrained or a provisional * instance type in the constraint can be retracted. */ - private[core] def permanentInst = inst + def permanentInst = inst private[core] def setPermanentInst(tp: Type): Unit = inst = tp if tp.exists && owningState != null then diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala new file mode 100644 index 000000000000..71eec80753fb --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala @@ -0,0 +1,38 @@ +package dotty.tools.dotc.qualified_types + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Annotations.Annotation +import dotty.tools.dotc.core.Contexts.{ctx, Context} +import dotty.tools.dotc.core.Types.{AnnotatedType, Type} + +/** A qualified type is internally represented as a type annotated with a + * `@qualified` annotation. + */ +object QualifiedType: + /** Extractor for qualified types. + * + * @param tp + * the type to deconstruct + * @return + * a pair containing the parent type and the qualifier tree (a lambda) on + * success, [[None]] otherwise + */ + def unapply(tp: Type)(using Context): Option[(Type, tpd.Tree)] = + tp match + case AnnotatedType(parent, annot) if annot.symbol == ctx.definitions.QualifiedAnnot => + Some((parent, annot.argument(0).get)) + case _ => + None + + /** Factory method to create a qualified type. + * + * @param parent + * the parent type + * @param qualifier + * the qualifier tree (a lambda) + * @return + * a qualified type + */ + def apply(parent: Type, qualifier: tpd.Tree)(using Context): Type = + val annotTp = ctx.definitions.QualifiedAnnot.typeRef.appliedTo(parent) + AnnotatedType(parent, Annotation(tpd.New(annotTp, List(qualifier)))) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala new file mode 100644 index 000000000000..5ca445bfc303 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala @@ -0,0 +1,103 @@ +package dotty.tools.dotc.qualified_types + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.{ + Apply, + Block, + EmptyTree, + Ident, + If, + Lambda, + Literal, + New, + Select, + SeqLiteral, + This, + Throw, + Tree, + TypeApply, + Typed, + given +} +import dotty.tools.dotc.core.Atoms +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.{ctx, Context} +import dotty.tools.dotc.core.Decorators.{i, em, toTermName} +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools.dotc.core.Symbols.{defn, Symbol} +import dotty.tools.dotc.core.Types.{AndType, ConstantType, SkolemType, ErrorType, MethodType, OrType, TermRef, Type, TypeProxy} + +import dotty.tools.dotc.report +import dotty.tools.dotc.reporting.trace +import dotty.tools.dotc.config.Printers + +object QualifiedTypes: + /** Does the type `tp1` imply the qualifier `qualifier2`? + * + * Used by [[dotty.tools.dotc.core.TypeComparer]] to compare qualified types. + * + * Note: the logic here is similar to [[Type#derivesAnnotWith]] but + * additionally handle comparisons with [[SingletonType]]s. + */ + def typeImplies(tp1: Type, qualifier2: Tree)(using Context): Boolean = + trace(i"typeImplies $tp1 --> $qualifier2", Printers.qualifiedTypes): + tp1 match + case QualifiedType(parent1, qualifier1) => + QualifierSolver().implies(qualifier1, qualifier2) + case tp1: (ConstantType | TermRef) => + QualifierSolver().implies(equalToPredicate(tpd.singleton(tp1)), qualifier2) + || typeImplies(tp1.underlying, qualifier2) + case tp1: TypeProxy => + typeImplies(tp1.underlying, qualifier2) + case AndType(tp11, tp12) => + typeImplies(tp11, qualifier2) || typeImplies(tp12, qualifier2) + case OrType(tp11, tp12) => + typeImplies(tp11, qualifier2) && typeImplies(tp12, qualifier2) + case _ => + false + // QualifierSolver().implies(truePredicate(), qualifier2) + + /** Try to adapt the tree to the given type `pt` + * + * Returns [[EmptyTree]] if `pt` does not contain qualifiers or if the tree + * cannot be adapted, or the adapted tree otherwise. + * + * Used by [[dotty.tools.dotc.core.Typer]]. + */ + def adapt(tree: Tree, pt: Type)(using Context): Tree = + trace(i"adapt $tree to $pt", Printers.qualifiedTypes): + if containsQualifier(pt) && isSimple(tree) then + val selfifiedTp = QualifiedType(tree.tpe, equalToPredicate(tree)) + if selfifiedTp <:< pt then tree.cast(selfifiedTp) else EmptyTree + else + EmptyTree + + def isSimple(tree: Tree)(using Context): Boolean = + tree match + case Apply(fn, args) => isSimple(fn) && args.forall(isSimple) + case TypeApply(fn, args) => isSimple(fn) + case SeqLiteral(elems, _) => elems.forall(isSimple) + case Typed(expr, _) => isSimple(expr) + case Block(Nil, expr) => isSimple(expr) + case _ => tpd.isIdempotentExpr(tree) + + def containsQualifier(tp: Type)(using Context): Boolean = + tp match + case QualifiedType(_, _) => true + case tp: TypeProxy => containsQualifier(tp.underlying) + case AndType(tp1, tp2) => containsQualifier(tp1) || containsQualifier(tp2) + case OrType(tp1, tp2) => containsQualifier(tp1) || containsQualifier(tp2) + case _ => false + + + private def equalToPredicate(tree: Tree)(using Context): Tree = + Lambda( + MethodType(List("v".toTermName))(_ => List(tree.tpe), _ => defn.BooleanType), + (args) => Ident(args(0).symbol.termRef).equal(tree) + ) + + private def truePredicate()(using Context): Tree = + Lambda( + MethodType(List("v".toTermName))(_ => List(defn.AnyType), _ => defn.BooleanType), + (args) => Literal(Constant(true)) + ) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala new file mode 100644 index 000000000000..97270de2cdfe --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala @@ -0,0 +1,102 @@ +package dotty.tools.dotc.qualified_types + +import scala.util.hashing.MurmurHash3 as hashing + +import dotty.tools.dotc.ast.tpd.{closureDef, Apply, Block, DefDef, Ident, Literal, New, Select, Tree, TreeOps, TypeApply, Typed, TypeTree} +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.core.Types.{MethodType, TermRef, Type, TypeVar} +import dotty.tools.dotc.core.Symbols.defn + +import dotty.tools.dotc.reporting.trace +import dotty.tools.dotc.config.Printers + +private abstract class QualifierComparer: + private def typeIso(tp1: Type, tp2: Type) = + val tp1stripped = stripPermanentTypeVar(tp1) + val tp2stripped = stripPermanentTypeVar(tp2) + tp1stripped.equals(tp2stripped) + + /** Structural equality for trees. + * + * This implementation is _not_ alpha-equivalence aware, which + * 1. allows it not to rely on a [[Context]] and, + * 2. allows the corresponding [[hash]] method to reuse [[Type#hashCode]] + * instead of defining an other hash code for types. + */ + def iso(tree1: Tree, tree2: Tree): Boolean = + (tree1, tree2) match + case (Literal(_) | Ident(_), _) => + typeIso(tree1.tpe, tree2.tpe) + case (Select(qual1, name1), Select(qual2, name2)) => + name1 == name2 && iso(qual1, qual2) + case (Apply(fun1, args1), Apply(fun2, args2)) => + iso(fun1, fun2) && args1.corresponds(args2)(iso) + case (TypeApply(fun1, args1), TypeApply(fun2, args2)) => + iso(fun1, fun2) && args1.corresponds(args2)((arg1, arg2) => typeIso(arg1.tpe, arg2.tpe)) + case (tpt1: TypeTree, tpt2: TypeTree) => + typeIso(tpt1.tpe, tpt2.tpe) + case (Typed(expr1, tpt1), Typed(expr2, tpt2)) => + iso(expr1, expr2) && typeIso(tpt1.tpe, tpt2.tpe) + case (New(tpt1), New(tpt2)) => + typeIso(tpt1.tpe, tpt2.tpe) + case (Block(stats1, expr1), Block(stats2, expr2)) => + stats1.corresponds(stats2)(iso) && iso(expr1, expr2) + case _ => + tree1.equals(tree2) + + protected def stripPermanentTypeVar(tp: Type): Type = + tp match + case tp: TypeVar if tp.isPermanentlyInstantiated => tp.permanentInst + case tp => tp + +private[qualified_types] object QualifierStructuralComparer extends QualifierComparer: + /** A hash code for trees that corresponds to `iso(tree1, tree2)`. */ + def hash(tree: Tree): Int = + tree match + case Literal(_) | Ident(_) => + hashType(tree.tpe) + case Select(qual, name) => + hashing.mix(name.hashCode, hash(qual)) + case Apply(fun, args) => + hashing.mix(hash(fun), hashList(args)) + case TypeApply(fun, args) => + hashing.mix(hash(fun), hashList(args)) + case tpt: TypeTree => + hashType(tpt.tpe) + case Typed(expr, tpt) => + hashing.mix(hash(expr), hashType(tpt.tpe)) + case New(tpt1) => + hashType(tpt1.tpe) + case Block(stats, expr) => + hashing.mix(hashList(stats), hash(expr)) + case _ => + tree.hashCode + + private def hashList(trees: List[Tree]): Int = + trees.map(hash).foldLeft(0)(hashing.mix) + + private def hashType(tp: Type): Int = + stripPermanentTypeVar(tp).hashCode + + /** A box for trees that implements structural equality using [[iso]] and + * [[hash]]. This enables using trees as keys in hash maps. + */ + final class TreeBox(val tree: Tree) extends AnyVal: + override def equals(that: Any): Boolean = that match + case that: TreeBox => iso(tree, that.tree) + case _ => false + + override def hashCode: Int = hash(tree) + +private[qualified_types] final class QualifierAlphaComparer(using Context) extends QualifierComparer: + override def iso(tree1: Tree, tree2: Tree): Boolean = + trace(i"iso $tree1 ; $tree2"): + (tree1, tree2) match + case (closureDef(def1), closureDef(def2)) => + val def2substituted = def2.rhs.subst(def2.symbol.paramSymss.flatten, def1.symbol.paramSymss.flatten) + val def2normalized = QualifierNormalizer.normalize(def2substituted) + iso(def1.rhs, def2normalized) + case _ => + super.iso(tree1, tree2) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala new file mode 100644 index 000000000000..3c260a5c86e2 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala @@ -0,0 +1,109 @@ +package dotty.tools.dotc.qualified_types + +import scala.annotation.tailrec + +import dotty.tools.dotc.ast.tpd.{ + Apply, + Block, + ConstantTree, + isIdempotentExpr, + EmptyTree, + Literal, + Ident, + Match, + Select, + This, + Tree, + TreeMap, + ValDef, + given +} +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Mode.Type +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} +import dotty.tools.dotc.core.SymDenotations.given +import dotty.tools.dotc.core.Types.{ConstantType, NoPrefix, TermRef} +import dotty.tools.dotc.inlines.InlineReducer +import dotty.tools.dotc.transform.TreeExtractors.BinaryOp +import dotty.tools.dotc.transform.patmat.{Empty as EmptySpace, SpaceEngine} +import dotty.tools.dotc.typer.Typer +import scala.util.boundary +import scala.util.boundary.break + + +import dotty.tools.dotc.reporting.trace +import dotty.tools.dotc.config.Printers + +private[qualified_types] object QualifierEvaluator: + /** Reduces a tree by constant folding, simplification and unfolding of simple + * references. + * + * This is more aggressive than [[dotty.tools.dotc.transform.BetaReduce]] and + * [[dotty.tools.dotc.typer.ConstFold]] (which is used under the hood by + * `BetaReduce` through [[dotty.tools.dotc.ast.tpd.cpy]]), as it also unfolds + * non-constant expressions. + */ + def evaluate(tree: Tree, args: Map[Symbol, Tree] = Map.empty)(using Context): Tree = + trace(i"evaluate $tree", Printers.qualifiedTypes): + QualifierEvaluator(args).transform(tree) + +private class QualifierEvaluator(args: Map[Symbol, Tree]) extends TreeMap: + import QualifierEvaluator.* + + override def transform(tree: Tree)(using Context): Tree = + unfold(reduce(tree)) + + private def reduce(tree: Tree)(using Context): Tree = + tree match + case tree: Apply => + val treeTransformed = super.transform(tree) + constFold(treeTransformed).orElse(reduceBinaryOp(treeTransformed)).orElse(treeTransformed) + case tree: Select => + val treeTransformed = super.transform(tree) + constFold(treeTransformed).orElse(treeTransformed) + case Block(Nil, expr) => + transform(expr) + case tree => + super.transform(tree) + + private def constFold(tree: Tree)(using Context): Tree = + tree match + case ConstantTree(c: Constant) => Literal(c) + case _ => EmptyTree + + private def reduceBinaryOp(tree: Tree)(using Context): Tree = + val d = defn // Need a stable path to match on `defn` members + tree match + case BinaryOp(a, d.Int_== | d.Any_== | d.Boolean_==, b) => + val aNormalized = QualifierNormalizer.normalize(a) + val bNormalized = QualifierNormalizer.normalize(b) + if QualifierAlphaComparer().iso(aNormalized, bNormalized) then + Literal(Constant(true)) + else + EmptyTree + case _ => + EmptyTree + + private def unfold(tree: Tree)(using Context): Tree = + args.get(tree.symbol) match + case Some(tree2) => + return transform(tree2) + case None => () + + tree match + case tree: Ident => + trace(s"unfold $tree", Printers.qualifiedTypes): + tree.symbol.defTree match + case valDef: ValDef + if !valDef.rhs.isEmpty + && !valDef.symbol.is(Flags.Lazy) + && QualifiedTypes.isSimple(valDef.rhs) => + transform(valDef.rhs) + case _ => + tree + case _ => + tree diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala new file mode 100644 index 000000000000..fa8669965c46 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala @@ -0,0 +1,143 @@ +package dotty.tools.dotc.qualified_types + +import dotty.tools.dotc.ast.tpd.{singleton, Apply, Block, Literal, Select, Tree, TreeMap, given} +import dotty.tools.dotc.core.Atoms +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Symbols.{defn, Symbol} +import dotty.tools.dotc.core.Types.{ConstantType, TermRef} +import dotty.tools.dotc.config.Printers + +import dotty.tools.dotc.reporting.trace +import dotty.tools.dotc.config.Printers + +private[qualified_types] object QualifierNormalizer: + def normalize(tree: Tree)(using Context): Tree = + trace(i"normalize $tree", Printers.qualifiedTypes): + QualifierNormalizer().transform(tree) + +/** A [[TreeMap]] that normalizes trees by applying algebraic simplifications + * and by ordering operands. + * + * Entry point: [[QualifierNormalizer.normalize]]. + */ +private class QualifierNormalizer extends TreeMap: + override def transform(tree: Tree)(using Context): Tree = + val d = defn // Need a stable path to match on `defn` members + tree match + case Apply(method, _) => + method.symbol match + case d.Int_+ => normalizeIntSum(tree) + case d.Int_* => normalizeIntProduct(tree) + case d.Int_== => normalizeEquality(tree) + case d.Any_== => normalizeEquality(tree) + case d.Boolean_== => normalizeEquality(tree) + case _ => super.transform(tree) + case _ => super.transform(tree) + + /** Normalizes a tree representing an integer sum. + * + * The normalization consists in: + * - Grouping summands which have the same non-constant factors, such that + * `3x + x` becomes `4x` for example. + * - Sorting the summands, so that for example `x + y` and `y + x` are + * normalized to the same tree. + * - Simplifying `0 + x` to `x`. + * - Normalizing each summand using [[normalizeIntProduct]]. + */ + private def normalizeIntSum(tree: Tree)(using Context): Tree = + val (summands, const) = decomposeIntSum(tree) + makeIntSum(summands, const) + + /** Decomposes a tree representing an integer sum into a list of non-constant + * summands `s_i` and a constant `c`. The summands are grouped and sorted as + * described in [[normalizeIntSum]]. + */ + private def decomposeIntSum(tree: Tree)(using Context): (List[Tree], Int) = + val groups: Map[List[QualifierStructuralComparer.TreeBox], Int] = + getAllArguments(tree, defn.Int_+) + .map(decomposeIntProduct) + .groupMapReduce(_._1.map(QualifierStructuralComparer.TreeBox.apply))(_._2)(_ + _) + val const = groups.getOrElse(Nil, 0) + val summands = + groups + .filter((args, c) => c != 0 && !args.isEmpty) + .toList + .sortBy((pair: (List[QualifierStructuralComparer.TreeBox], Int)) => pair.hashCode()) + .map((args, c) => makeIntProduct(args.map(_.tree), c)) + (summands, const) + + /** Constructs a tree representing an integer sum from a list of non-constant + * summands `summands` and a constant `const`. + */ + private def makeIntSum(summands: List[Tree], const: Int)(using Context): Tree = + if summands.isEmpty then + Literal(Constant(const)) + else + val summandsTree = summands.reduce(_.select(defn.Int_+).appliedTo(_)) + if const == 0 then summandsTree + else Literal(Constant(const)).select(defn.Int_+).appliedTo(summandsTree) + + /** Normalizes a tree representing an integer product. + * + * The normalization consists in: + * - Sorting the factors, so that for example `x * y` and `y * x` are + * normalized to the same tree. + * - Simplifying `0 * x` to `0`. + * - Simplifying `1 * x` to `x`. + */ + private def normalizeIntProduct(tree: Tree)(using Context): Tree = + val (factors, const) = decomposeIntProduct(tree) + makeIntProduct(factors, const) + + /** Decomposes a tree representing an integer product into a sorted list of + * non-constant factors `f_i` and a constant `c`. + */ + private def decomposeIntProduct(tree: Tree)(using Context): (List[Tree], Int) = + val (consts, factors) = + getAllArguments(tree, defn.Int_*) + .map(transform) + .partitionMap: + case Literal(Constant(n: Int)) => Left(n) + case arg => Right(arg) + (factors.sortBy(QualifierStructuralComparer.hash), consts.product) + + /** Constructs a tree representing an integer product from a sorted list of + * non-constant factors `factors` and a constant `const`. + */ + private def makeIntProduct(factors: List[Tree], const: Int)(using Context): Tree = + if const == 0 then + Literal(Constant(0)) + else if factors.isEmpty then + Literal(Constant(const)) + else + val factorsTree = factors.reduce(_.select(defn.Int_*).appliedTo(_)) + if const == 1 then factorsTree + else Literal(Constant(const)).select(defn.Int_*).appliedTo(factorsTree) + + /** Recursively collects all arguments of an n-ary operation. + * + * For example, given the tree `(a + (b * c)) + (d + e)`, the method returns + * the list `[a, b * c, d, e]` when called with the `+` operator. + */ + private def getAllArguments(tree: Tree, op: Symbol)(using Context): List[Tree] = + tree match + case Apply(method @ Select(qual, _), List(arg)) if method.symbol == op => + getAllArguments(qual, op) ::: getAllArguments(arg, op) + case Block(Nil, expr) => + getAllArguments(expr, op) + case _ => + List(transform(tree)) + + private def normalizeEquality(tree: Tree)(using Context): Tree = + tree match + case Apply(select @ Select(lhs, name), List(rhs)) => + val lhsNormalized = transform(lhs) + val rhsNormalized = transform(rhs) + if QualifierStructuralComparer.hash(lhsNormalized) > QualifierStructuralComparer.hash(rhsNormalized) then + cpy.Apply(tree)(cpy.Select(select)(rhsNormalized, name), List(lhsNormalized)) + else + cpy.Apply(tree)(cpy.Select(select)(lhsNormalized, name), List(rhsNormalized)) + case _ => + throw new IllegalArgumentException("Unexpected tree passed to normalizeEquality: " + tree) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala new file mode 100644 index 000000000000..4845f975d212 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala @@ -0,0 +1,82 @@ +package dotty.tools.dotc.qualified_types + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.tpd.{closureDef, singleton, Apply, Ident, Literal, Select, Tree, given} +import dotty.tools.dotc.config.Printers +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} +import dotty.tools.dotc.core.Types.{TermRef} +import dotty.tools.dotc.transform.BetaReduce + +import dotty.tools.dotc.reporting.trace +import dotty.tools.dotc.config.Printers + +class QualifierSolver(using Context): + private val litTrue = Literal(Constant(true)) + private val litFalse = Literal(Constant(false)) + + val d = defn // Need a stable path to match on `defn` members + + def implies(tree1: Tree, tree2: Tree) = + trace(i"implies $tree1 -> $tree2", Printers.qualifiedTypes): + (tree1, tree2) match + case (closureDef(defDef1), closureDef(defDef2)) => + val tree1ArgSym = defDef1.symbol.paramSymss.head.head + val tree2ArgSym = defDef2.symbol.paramSymss.head.head + val rhs = defDef1.rhs + val lhs = defDef2.rhs + if tree1ArgSym.info frozen_<:< tree2ArgSym.info then + impliesRec(rhs, lhs.subst(List(tree2ArgSym), List(tree1ArgSym))) + else if tree2ArgSym.info frozen_<:< tree1ArgSym.info then + impliesRec(rhs.subst(List(tree1ArgSym), List(tree2ArgSym)), lhs) + else + false + case _ => + throw IllegalArgumentException("Qualifiers must be closures") + + private def impliesRec(tree1: Tree, tree2: Tree): Boolean = + // tree1 = lhs || rhs + tree1 match + case Apply(select @ Select(lhs, name), List(rhs)) => + select.symbol match + case d.Boolean_|| => + return impliesRec(lhs, tree2) && impliesRec(rhs, tree2) + case _ => () + case _ => () + + // tree2 = lhs && rhs, or tree2 = lhs || rhs + tree2 match + case Apply(select @ Select(lhs, name), List(rhs)) => + select.symbol match + case d.Boolean_&& => + return impliesRec(tree1, lhs) && impliesRec(tree1, rhs) + case d.Boolean_|| => + return impliesRec(tree1, lhs) || impliesRec(tree1, rhs) + case _ => () + case _ => () + + // tree1 = lhs && rhs + tree1 match + case Apply(select @ Select(lhs, name), List(rhs)) => + select.symbol match + case d.Boolean_&& => + return impliesRec(lhs, tree2) || impliesRec(rhs, tree2) + case _ => () + case _ => () + + val tree1Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree1)) + val tree2Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree2)) + + tree2Normalized match + case Literal(Constant(true)) => + return true + case _ => () + + tree1Normalized match + case Literal(Constant(false)) => + return true + case _ => () + + QualifierAlphaComparer().iso(tree1Normalized, tree2Normalized) diff --git a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala index 16219055b8c0..30095b99b1c5 100644 --- a/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala +++ b/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala @@ -175,4 +175,14 @@ object BetaReduce: Some(expansion1) else None end reduceApplication -end BetaReduce \ No newline at end of file + + def reduceApplication(ddef: DefDef, argss: List[List[Tree]])(using Context): Option[Tree] = + val bindings = new ListBuffer[DefTree]() + reduceApplication(ddef, argss, bindings) match + case Some(expansion1) => + val bindings1 = bindings.result() + Some(seq(bindings1, expansion1)) + case None => + None + +end BetaReduce diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index 9a8f5596471f..884b323654e0 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -602,7 +602,7 @@ object Erasure { } override def promote(tree: untpd.Tree)(using Context): tree.ThisTree[Type] = { - assert(tree.hasType) + assert(tree.hasType, i"promote called on tree without type: ${tree.show}") val erasedTp = erasedType(tree) report.log(s"promoting ${tree.show}: ${erasedTp.showWithUnderlying()}") tree.withType(erasedTp) diff --git a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala index d2a72e10fcfc..aa470a0e9ea8 100644 --- a/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala +++ b/compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala @@ -59,12 +59,7 @@ class InlinePatterns extends MiniPhase: case Block(TypeDef(_, template: Template) :: Nil, Apply(Select(New(_),_), Nil)) if template.constr.rhs.isEmpty => template.body match case List(ddef @ DefDef(`name`, _, _, _)) => - val bindings = new ListBuffer[DefTree]() - BetaReduce.reduceApplication(ddef, argss, bindings) match - case Some(expansion1) => - val bindings1 = bindings.result() - seq(bindings1, expansion1) - case None => tree + BetaReduce.reduceApplication(ddef, argss).getOrElse(tree) case _ => tree case _ => tree diff --git a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala index bd726afe5bba..76c36bfe5c2c 100644 --- a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala +++ b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala @@ -61,17 +61,6 @@ object ConstFold: tree.withFoldedType(Constant(targ.tpe)) case _ => tree - private object ConstantTree: - def unapply(tree: Tree)(using Context): Option[Constant] = - tree match - case Inlined(_, Nil, expr) => unapply(expr) - case Typed(expr, _) => unapply(expr) - case Literal(c) if c.tag == Constants.NullTag => Some(c) - case _ => - tree.tpe.widenTermRefExpr.normalized.simplified match - case ConstantType(c) => Some(c) - case _ => None - extension [T <: Tree](tree: T)(using Context) private def withFoldedType(c: Constant | Null): T = if c == null then tree else tree.withType(ConstantType(c)).asInstanceOf[T] diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 2868dc45a193..f94aedfa3aa8 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -48,6 +48,7 @@ import reporting.* import Nullables.* import NullOpsDecorator.* import cc.{CheckCaptures, isRetainsLike} +import qualified_types.QualifiedTypes import config.Config import config.MigrationVersion import transform.CheckUnused.OriginalName @@ -4862,7 +4863,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer return readapt(tree.cast(captured)) // drop type if prototype is Unit - if pt.isRef(defn.UnitClass) then + if pt.isRef(defn.UnitClass, false) then // local adaptation makes sure every adapted tree conforms to its pt // so will take the code path that decides on inlining val tree1 = adapt(tree, WildcardType, locked) @@ -4914,6 +4915,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case _ => case _ => + // Try to adapt to a qualified type + val adapted = QualifiedTypes.adapt(tree, pt) + if !adapted.isEmpty then + return readapt(adapted) + def recover(failure: SearchFailureType) = if canDefineFurther(wtp) || canDefineFurther(pt) then readapt(tree) else diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index 2af2f834db9a..43705415945e 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -38,6 +38,7 @@ class CompilationTests { compileFile("tests/pos-special/sourcepath/outer/nested/Test4.scala", defaultOptions.and("-sourcepath", "tests/pos-special/sourcepath")), compileFilesInDir("tests/pos-scala2", defaultOptions.and("-source", "3.0-migration")), compileFilesInDir("tests/pos-custom-args/captures", defaultOptions.and("-language:experimental.captureChecking", "-language:experimental.separationChecking", "-source", "3.8")), + compileFilesInDir("tests/pos-custom-args/qualified-types", defaultOptions.and("-language:experimental.qualifiedTypes")), compileFile("tests/pos-special/utf8encoded.scala", defaultOptions.and("-encoding", "UTF8")), compileFile("tests/pos-special/utf16encoded.scala", defaultOptions.and("-encoding", "UTF16")), compileDir("tests/pos-special/i18589", defaultOptions.and("-Wsafe-init").without("-Ycheck:all")), @@ -151,6 +152,7 @@ class CompilationTests { compileFilesInDir("tests/neg", defaultOptions, FileFilter.exclude(TestSources.negScala2LibraryTastyExcludelisted)), compileFilesInDir("tests/neg-deep-subtype", allowDeepSubtypes), compileFilesInDir("tests/neg-custom-args/captures", defaultOptions.and("-language:experimental.captureChecking", "-language:experimental.separationChecking", "-source", "3.8")), + compileFilesInDir("tests/neg-custom-args/qualified-types", defaultOptions.and("-language:experimental.qualifiedTypes")), compileFile("tests/neg-custom-args/sourcepath/outer/nested/Test1.scala", defaultOptions.and("-sourcepath", "tests/neg-custom-args/sourcepath")), compileDir("tests/neg-custom-args/sourcepath2/hi", defaultOptions.and("-sourcepath", "tests/neg-custom-args/sourcepath2", "-Werror")), compileList("duplicate source", List( diff --git a/library/src/scala/annotation/qualified.scala b/library/src/scala/annotation/qualified.scala index 2fae020be762..0c0b6532dd43 100644 --- a/library/src/scala/annotation/qualified.scala +++ b/library/src/scala/annotation/qualified.scala @@ -1,4 +1,4 @@ package scala.annotation /** Annotation for qualified types. */ -@experimental class qualified[T](predicate: T => Boolean) extends StaticAnnotation +@experimental class qualified[T](predicate: T => Boolean) extends RefiningAnnotation diff --git a/tests/neg-custom-args/qualified-types/adapt_neg.scala b/tests/neg-custom-args/qualified-types/adapt_neg.scala new file mode 100644 index 000000000000..70fe964cebbc --- /dev/null +++ b/tests/neg-custom-args/qualified-types/adapt_neg.scala @@ -0,0 +1,22 @@ +def f(x: Int): Int = ??? +case class IntBox(x: Int) +case class Box[T](x: T) + +def test: Unit = + val x: Int = ??? + val y: Int = ??? + def g(x: Int): Int = ??? + + val v1: {v: Int with v == 1} = 2 // error + val v2: {v: Int with v == x} = y // error + val v3: {v: Int with v == x + 1} = x + 2 // error + val v4: {v: Int with v == f(x)} = g(x) // error + val v5: {v: Int with v == g(x)} = f(x) // error + //val v6: {v: Int with v == IntBox(x)} = IntBox(x) // Not implemented + //val v7: {v: Int with v == Box(x)} = Box(x) // Not implemented + val v8: {v: Int with v == x + f(x)} = x + g(x) // error + val v9: {v: Int with v == x + g(x)} = x + f(x) // error + val v10: {v: Int with v == f(x + 1)} = f(x + 2) // error + val v11: {v: Int with v == g(x + 1)} = g(x + 2) // error + //val v12: {v: Int with v == IntBox(x + 1)} = IntBox(x + 1) // Not implemented + //val v13: {v: Int with v == Box(x + 1)} = Box(x + 1) // Not implemented diff --git a/tests/neg-custom-args/qualified-types/subtyping_singletons_neg.scala b/tests/neg-custom-args/qualified-types/subtyping_singletons_neg.scala new file mode 100644 index 000000000000..421cca0789b2 --- /dev/null +++ b/tests/neg-custom-args/qualified-types/subtyping_singletons_neg.scala @@ -0,0 +1,8 @@ +def f(x: Int): Int = ??? + +def test: Unit = + val x: Int = ??? + val y: Int = ??? + summon[2 <:< {v: Int with v == 1}] // error + summon[x.type <:< {v: Int with v == 1}] // error + //summon[y.type <:< {v: Int with v == x}] // FIXME diff --git a/tests/neg-custom-args/qualified-types/subtyping_unfolding_neg.scala b/tests/neg-custom-args/qualified-types/subtyping_unfolding_neg.scala new file mode 100644 index 000000000000..2af09733af41 --- /dev/null +++ b/tests/neg-custom-args/qualified-types/subtyping_unfolding_neg.scala @@ -0,0 +1,18 @@ +def tp[T](): T = ??? + +abstract class C: + type T + val x: T + +def test: Unit = + val x: Int = ??? + val z: Int = ??? + val c1: C = ??? + val c2: C = ??? + + summon[{v: Int with v == x} <:< {v: Int with v == z}] // error + summon[{v: C with v == c1} <:< {v: C with v == c2}] // error + + // summon[{v: Int with v == (??? : Int)} <:< {v: Int with v == (??? : Int)}] // TODO(mbovel): should not compare some impure applications? + + summon[{v: Int with v == tp[c1.T]()} <:< {v: Int with v == tp[c2.T]()}] // error diff --git a/tests/neg-custom-args/qualified-types/syntax_unnamed_neg.scala b/tests/neg-custom-args/qualified-types/syntax_unnamed_neg.scala new file mode 100644 index 000000000000..9191cf5db2db --- /dev/null +++ b/tests/neg-custom-args/qualified-types/syntax_unnamed_neg.scala @@ -0,0 +1,7 @@ +case class Box[T](x: T) +def id[T](x: T): T = x + +abstract class Test: + val v1: Box[Int with v1 > 0] // error: Cyclic reference + val v2: {v: Int with id[Int with v2 > 0](???) > 0} // error: Cyclic reference + val v3: {v: Int with (??? : Int with v3 == 2) > 0} // error: Cyclic reference diff --git a/tests/pos-custom-args/qualified-types/adapt.scala b/tests/pos-custom-args/qualified-types/adapt.scala new file mode 100644 index 000000000000..12fc30e9c80c --- /dev/null +++ b/tests/pos-custom-args/qualified-types/adapt.scala @@ -0,0 +1,22 @@ +def f(x: Int): Int = ??? +case class IntBox(x: Int) +case class Box[T](x: T) + + +def f(x: Int, y: Int): {r: Int with r == x + y} = x + y + +def test: Unit = + val x: Int = ??? + def g(x: Int): Int = ??? + + val v1: {v: Int with v == x + 1} = x + 1 + val v2: {v: Int with v == f(x)} = f(x) + val v3: {v: Int with v == g(x)} = g(x) + //val v6: {v: Int with v == IntBox(x)} = IntBox(x) // Not implemented + //val v7: {v: Int with v == Box(x)} = Box(x) // Not implemented + val v4: {v: Int with v == x + f(x)} = x + f(x) + val v5: {v: Int with v == x + g(x)} = x + g(x) + val v6: {v: Int with v == f(x + 1)} = f(x + 1) + val v7: {v: Int with v == g(x + 1)} = g(x + 1) + //val v12: {v: Int with v == IntBox(x + 1)} = IntBox(x + 1) // Not implemented + //val v13: {v: Int with v == Box(x + 1)} = Box(x + 1) // Not implemented diff --git a/tests/pos-custom-args/qualified-types/avoidance.scala b/tests/pos-custom-args/qualified-types/avoidance.scala new file mode 100644 index 000000000000..b595557924cf --- /dev/null +++ b/tests/pos-custom-args/qualified-types/avoidance.scala @@ -0,0 +1,7 @@ +def Test = () + /* + val x = + val y = 1 + y: {v: Int with v == y} + */ + // TODO(mbovel): proper avoidance for qualified types diff --git a/tests/pos-custom-args/qualified-types/sized_lists.scala b/tests/pos-custom-args/qualified-types/sized_lists.scala new file mode 100644 index 000000000000..88b60db58563 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/sized_lists.scala @@ -0,0 +1,18 @@ + + +def size(v: Vec): Int = ??? +type Vec + + +def vec(s: Int): {v: Vec with size(v) == s} = ??? +def concat(v1: Vec, v2: Vec): {v: Vec with size(v) == size(v1) + size(v2)} = ??? +def sum(v1: Vec, v2: Vec with size(v1) == size(v2)): {v: Vec with size(v) == size(v1)} = ??? + +@main def Test = + + val v3: {v: Vec with size(v) == 3} = vec(3) + val v4: {v: Vec with size(v) == 4} = vec(4) + /* + val v7: {v: Vec with size(v) == 7} = concat(v3, v4) + */ + // TODO(mbovel): need constraints of referred term refs diff --git a/tests/pos-custom-args/qualified-types/subtyping_normalization.scala b/tests/pos-custom-args/qualified-types/subtyping_normalization.scala new file mode 100644 index 000000000000..986e1f618e9b --- /dev/null +++ b/tests/pos-custom-args/qualified-types/subtyping_normalization.scala @@ -0,0 +1,29 @@ +def f(x: Int): Int = ??? +def id[T](x: T): T = x +def opaqueSize[T](l: List[T]): Int = ??? + +def test: Unit = + val x: Int = ??? + val y: Int = ??? + val z: Int = ??? + + summon[{v: Int with v == 2 + (x * y * y * z)} <:< {v: Int with v == (x * y * z * y) + 2}] + summon[{v: Int with v == x + 1} <:< {v: Int with v == 1 + x}] + summon[{v: Int with v == y + x} <:< {v: Int with v == x + y}] + summon[{v: Int with v == x + 2} <:< {v: Int with v == 1 + x + 1}] + summon[{v: Int with v == x + 2} <:< {v: Int with v == 1 + (x + 1)}] + summon[{v: Int with v == x + 2 * y} <:< {v: Int with v == y + x + y}] + summon[{v: Int with v == x + 2 * y} <:< {v: Int with v == y + (x + y)}] + summon[{v: Int with v == x + 3 * y} <:< {v: Int with v == 2 * y + x + y}] + summon[{v: Int with v == x + 3 * y} <:< {v: Int with v == 2 * y + (x + y)}] + summon[{v: Int with v == 0} <:< {v: Int with v == 1 - 1}] + // summon[{v: Int with v == 0} <:< {v: Int with v == x - x}] // TODO(mbovel): handle subtraction + summon[{v: Int with v == 0} <:< {v: Int with v == x + (x * -1)}] + // summon[{v: Int with v == x} <:< {v: Int with v == 1 + x - 1}] // TODO(mbovel): handle subtraction + summon[{v: Int with v == 4 * (x + 1)} <:< {v: Int with v == 2 * (x + 1) + 2 * (1 + x)}] + summon[{v: Int with v == 4 * (x / 2)} <:< {v: Int with v == 2 * (x / 2) + 2 * (x / 2)}] + + summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(1 + x)}] + summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(x + 1)}] + + summon[{v: List[Int] with opaqueSize(v) == 2 * x} <:< {v: List[Int] with opaqueSize(v) == x + x}] diff --git a/tests/pos-custom-args/qualified-types/subtyping_singletons.scala b/tests/pos-custom-args/qualified-types/subtyping_singletons.scala new file mode 100644 index 000000000000..984674da225c --- /dev/null +++ b/tests/pos-custom-args/qualified-types/subtyping_singletons.scala @@ -0,0 +1,10 @@ +type Pos = {v: Int with v > 0} + +def test: Unit = + val x: Int = ??? + val one: Int = 1 + summon[1 <:< {v: Int with v == 1}] + summon[1 <:< {v: Int with v > 0}] + summon[1 <:< Pos] + summon[x.type <:< {v: Int with v == x}] + summon[one.type <:< {v: Int with v == 1}] diff --git a/tests/pos-custom-args/qualified-types/subtyping_unfolding.scala b/tests/pos-custom-args/qualified-types/subtyping_unfolding.scala new file mode 100644 index 000000000000..a488cb9ef1aa --- /dev/null +++ b/tests/pos-custom-args/qualified-types/subtyping_unfolding.scala @@ -0,0 +1,54 @@ +def id[T](x: T): T = x + +abstract class C: + type T + val x: T + +def test: Unit = + val x: Int = ??? + val x2: Int = x + val x3: Int = x2 + val y: Int = x + 1 + val c1: C = ??? + val c2: C = ??? + + summon[{v: Int with v == x} <:< {v: Int with v == x2}] + summon[{v: Int with v == x2} <:< {v: Int with v == x}] + summon[{v: Int with v == x} <:< {v: Int with v == x3}] + summon[{v: Int with v == x3} <:< {v: Int with v == x}] + summon[{v: Int with v == x2} <:< {v: Int with v == x3}] + summon[{v: Int with v == x3} <:< {v: Int with v == x2}] + + summon[{v: Int with v == y} <:< {v: Int with v == x + 1}] + summon[{v: Int with v == x + 1} <:< {v: Int with v == y}] + + summon[{v: Int with v == id(x)} <:< {v: Int with v == id(x2)}] + summon[{v: Int with v == id(x2)} <:< {v: Int with v == id(x)}] + //summon[{v: Int with v == id(y)} <:< {v: Int with v == id(x + 1)}] // TODO(mbovel): needs normaliazion of type hashes + //summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(y)}] // TODO(mbovel): needs normaliazion of type hashes + + summon[{v: Int with v == y + 2} <:< {v: Int with v == x + 1 + 2}] + summon[{v: Int with v == x + 1 + 2} <:< {v: Int with v == y + 2}] + + summon[{v: Int with v == id[c1.T](c1.x)} <:< {v: Int with v == id[c1.T](c1.x)}] + + def innerScope() = + summon[{v: Int with v == x} <:< {v: Int with v == x2}] + summon[{v: Int with v == x2} <:< {v: Int with v == x}] + summon[{v: Int with v == x} <:< {v: Int with v == x3}] + summon[{v: Int with v == x3} <:< {v: Int with v == x}] + summon[{v: Int with v == x2} <:< {v: Int with v == x3}] + summon[{v: Int with v == x3} <:< {v: Int with v == x2}] + + summon[{v: Int with v == y} <:< {v: Int with v == x + 1}] + summon[{v: Int with v == x + 1} <:< {v: Int with v == y}] + + summon[{v: Int with v == id(x)} <:< {v: Int with v == id(x2)}] + summon[{v: Int with v == id(x2)} <:< {v: Int with v == id(x)}] + //summon[{v: Int with v == id(y)} <:< {v: Int with v == id(x + 1)}] // TODO(mbovel): needs normaliazion of type hashes + //summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(y)}] // TODO(mbovel): needs normaliazion of type hashes + + summon[{v: Int with v == y + 2} <:< {v: Int with v == x + 1 + 2}] + summon[{v: Int with v == x + 1 + 2} <:< {v: Int with v == y + 2}] + + summon[{v: Int with v == id[c1.T](c1.x)} <:< {v: Int with v == id[c1.T](c1.x)}] diff --git a/tests/pos-custom-args/qualified-types/syntax_basics.scala b/tests/pos-custom-args/qualified-types/syntax_basics.scala new file mode 100644 index 000000000000..8395c5bf670c --- /dev/null +++ b/tests/pos-custom-args/qualified-types/syntax_basics.scala @@ -0,0 +1,55 @@ +abstract class BoolExprs: + val c: Boolean + def f(b: Boolean): Boolean + val v01: {b: Boolean with true} + val v02: {b: Boolean with false} + val v03: {b: Boolean with !b} + val v04: {b: Boolean with b && b} + val v05: {b: Boolean with b || b} + val v06: {b: Boolean with c} + val v07: {b: Boolean with !c} + val v08: {b: Boolean with b && c} + val v09: {b: Boolean with b || c} + val v10: {b: Boolean with f(b)} + val w01 = v01 + val w02 = v02 + val w03 = v03 + val w04 = v04 + val w05 = v05 + val w06 = v06 + val w07 = v07 + val w08 = v08 + val w09 = v09 + val w10 = v10 + +abstract class IntExprs: + val c: Int + def f(n: Int): Int + val v01: {n: Int with n == 0} + val v02: {n: Int with -n == 0} + val v03: {n: Int with n != 0} + val v04: {n: Int with n > 0} + val v05: {n: Int with n >= 0} + val v06: {n: Int with n < 0} + val v07: {n: Int with n <= 0} + val v08: {n: Int with n == c} + val v09: {n: Int with n != c} + val v10: {n: Int with n > c} + val v11: {n: Int with n >= c} + val v12: {n: Int with n < c} + val v13: {n: Int with n <= c} + val v14: {n: Int with n == f(n)} + val w01 = v01 + val w02 = v02 + val w03 = v03 + val w04 = v04 + val w05 = v05 + val w06 = v06 + val w07 = v07 + val w08 = v08 + val w09 = v09 + val w10 = v10 + val w11 = v11 + val w12 = v12 + val w13 = v13 + val w14 = v14 From 8169c9317ade2466b0cf00b8d3e565ee2263b48c Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 11 Jun 2025 17:12:08 +0000 Subject: [PATCH 07/20] Runtime checks for qualified types Co-Authored-By: Valentin Schneeberger <23651312+Valentin889@users.noreply.github.com> Co-Authored-By: Quentin Bernet <28290641+Sporarum@users.noreply.github.com> --- .../dotc/qualified_types/QualifiedTypes.scala | 50 ++- .../tools/dotc/transform/TypeTestsCasts.scala | 10 +- .../dotty/tools/dotc/CompilationTests.scala | 1 + .../runtimeChecked_dependent_neg.scala | 8 + .../qualified-types/evalOnce.check | 4 + .../qualified-types/evalOnce.scala | 15 + .../qualified-types/isInstanceOf.check | 320 ++++++++++++++++++ .../qualified-types/isInstanceOf.scala | 36 ++ .../qualified-types/pattern_matching.check | 8 + .../qualified-types/pattern_matching.scala | 31 ++ .../pattern_matching_alternative.check | 2 + .../pattern_matching_alternative.scala | 14 + .../qualified-types/runtimeChecked.scala | 18 + .../runtimeChecked_dependent.scala | 14 + 14 files changed, 522 insertions(+), 9 deletions(-) create mode 100644 tests/neg-custom-args/qualified-types/runtimeChecked_dependent_neg.scala create mode 100644 tests/run-custom-args/qualified-types/evalOnce.check create mode 100644 tests/run-custom-args/qualified-types/evalOnce.scala create mode 100644 tests/run-custom-args/qualified-types/isInstanceOf.check create mode 100644 tests/run-custom-args/qualified-types/isInstanceOf.scala create mode 100644 tests/run-custom-args/qualified-types/pattern_matching.check create mode 100644 tests/run-custom-args/qualified-types/pattern_matching.scala create mode 100644 tests/run-custom-args/qualified-types/pattern_matching_alternative.check create mode 100644 tests/run-custom-args/qualified-types/pattern_matching_alternative.scala create mode 100644 tests/run-custom-args/qualified-types/runtimeChecked.scala create mode 100644 tests/run-custom-args/qualified-types/runtimeChecked_dependent.scala diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala index 5ca445bfc303..ceddec843393 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala @@ -19,17 +19,26 @@ import dotty.tools.dotc.ast.tpd.{ Typed, given } +import dotty.tools.dotc.config.Printers import dotty.tools.dotc.core.Atoms import dotty.tools.dotc.core.Constants.Constant import dotty.tools.dotc.core.Contexts.{ctx, Context} -import dotty.tools.dotc.core.Decorators.{i, em, toTermName} +import dotty.tools.dotc.core.Decorators.{em, i, toTermName} import dotty.tools.dotc.core.StdNames.nme import dotty.tools.dotc.core.Symbols.{defn, Symbol} -import dotty.tools.dotc.core.Types.{AndType, ConstantType, SkolemType, ErrorType, MethodType, OrType, TermRef, Type, TypeProxy} - +import dotty.tools.dotc.core.Types.{ + AndType, + ConstantType, + ErrorType, + MethodType, + OrType, + SkolemType, + TermRef, + Type, + TypeProxy +} import dotty.tools.dotc.report import dotty.tools.dotc.reporting.trace -import dotty.tools.dotc.config.Printers object QualifiedTypes: /** Does the type `tp1` imply the qualifier `qualifier2`? @@ -66,9 +75,22 @@ object QualifiedTypes: */ def adapt(tree: Tree, pt: Type)(using Context): Tree = trace(i"adapt $tree to $pt", Printers.qualifiedTypes): - if containsQualifier(pt) && isSimple(tree) then - val selfifiedTp = QualifiedType(tree.tpe, equalToPredicate(tree)) - if selfifiedTp <:< pt then tree.cast(selfifiedTp) else EmptyTree + if containsQualifier(pt) then + if tree.tpe.hasAnnotation(defn.RuntimeCheckedAnnot) then + if checkContainsSkolem(pt) then + tpd.evalOnce(tree): e => + If( + e.isInstance(pt), + e.asInstance(pt), + Throw(New(defn.IllegalArgumentExceptionType, List())) + ) + else + tree.withType(ErrorType(em"")) + else if isSimple(tree) then + val selfifiedTp = QualifiedType(tree.tpe, equalToPredicate(tree)) + if selfifiedTp <:< pt then tree.cast(selfifiedTp) else EmptyTree + else + EmptyTree else EmptyTree @@ -79,7 +101,7 @@ object QualifiedTypes: case SeqLiteral(elems, _) => elems.forall(isSimple) case Typed(expr, _) => isSimple(expr) case Block(Nil, expr) => isSimple(expr) - case _ => tpd.isIdempotentExpr(tree) + case _ => tpd.isIdempotentExpr(tree) def containsQualifier(tp: Type)(using Context): Boolean = tp match @@ -89,6 +111,18 @@ object QualifiedTypes: case OrType(tp1, tp2) => containsQualifier(tp1) || containsQualifier(tp2) case _ => false + def checkContainsSkolem(tp: Type)(using Context): Boolean = + var res = true + tp.foreachPart: + case QualifiedType(_, qualifier) => + qualifier.foreachSubTree: subTree => + subTree.tpe.foreachPart: + case tp: SkolemType => + report.error(em"The qualified type $qualifier cannot be checked at runtime", qualifier.srcPos) + res = false + case _ => () + case _ => () + res private def equalToPredicate(tree: Tree)(using Context): Tree = Lambda( diff --git a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala index a8c8ec8ce1d8..8fd3ed4373f4 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala @@ -18,6 +18,8 @@ import config.Printers.{ transforms => debug } import patmat.Typ import dotty.tools.dotc.util.SrcPos +import qualified_types.QualifiedType + /** This transform normalizes type tests and type casts, * also replacing type tests with singleton argument type with reference equality check * Any remaining type tests @@ -323,7 +325,7 @@ object TypeTestsCasts { * The transform happens before erasure of `testType`, thus cannot be merged * with `transformIsInstanceOf`, which depends on erased type of `testType`. */ - def transformTypeTest(expr: Tree, testType: Type, flagUnrelated: Boolean): Tree = testType.dealias match { + def transformTypeTest(expr: Tree, testType: Type, flagUnrelated: Boolean): Tree = testType.dealiasKeepRefiningAnnots match { case tref: TermRef if tref.symbol == defn.EmptyTupleModule => ref(defn.RuntimeTuples_isInstanceOfEmptyTuple).appliedTo(expr) case _: SingletonType => @@ -352,6 +354,12 @@ object TypeTestsCasts { ref(defn.RuntimeTuples_isInstanceOfNonEmptyTuple).appliedTo(expr) case AppliedType(tref: TypeRef, _) if tref.symbol == defn.PairClass => ref(defn.RuntimeTuples_isInstanceOfNonEmptyTuple).appliedTo(expr) + case QualifiedType(parent, closureDef(qualifierDef)) => + evalOnce(expr): e => + // e.isInstanceOf[baseType] && qualifier(e.asInstanceOf[baseType]) + val arg = e.asInstance(parent) + val qualifierTest = BetaReduce.reduceApplication(qualifierDef, List(List(arg))).get + transformTypeTest(e, parent, flagUnrelated).and(qualifierTest) case _ => val testWidened = testType.widen defn.untestableClasses.find(testWidened.isRef(_)) match diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index 43705415945e..bb478561b1e9 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -176,6 +176,7 @@ class CompilationTests { compileFilesInDir("tests/run", defaultOptions.and("-Wsafe-init")), compileFilesInDir("tests/run-deep-subtype", allowDeepSubtypes), compileFilesInDir("tests/run-custom-args/captures", allowDeepSubtypes.and("-language:experimental.captureChecking", "-language:experimental.separationChecking", "-source", "3.8")), + compileFilesInDir("tests/run-custom-args/qualified-types", defaultOptions.and("-language:experimental.qualifiedTypes")), // Run tests for legacy lazy vals. compileFilesInDir("tests/run", defaultOptions.and("-Wsafe-init", "-Ylegacy-lazy-vals", "-Ycheck-constraint-deps"), FileFilter.include(TestSources.runLazyValsAllowlist)), ).checkRuns() diff --git a/tests/neg-custom-args/qualified-types/runtimeChecked_dependent_neg.scala b/tests/neg-custom-args/qualified-types/runtimeChecked_dependent_neg.scala new file mode 100644 index 000000000000..f8fcc5c576f0 --- /dev/null +++ b/tests/neg-custom-args/qualified-types/runtimeChecked_dependent_neg.scala @@ -0,0 +1,8 @@ +def foo(x: Int, y: {v: Int with v > x}): y.type = y + +def getInt(): Int = + println("getInt called") + 42 + +@main def Test = + val res = foo(getInt(), 2.runtimeChecked) // error diff --git a/tests/run-custom-args/qualified-types/evalOnce.check b/tests/run-custom-args/qualified-types/evalOnce.check new file mode 100644 index 000000000000..520fc9973c80 --- /dev/null +++ b/tests/run-custom-args/qualified-types/evalOnce.check @@ -0,0 +1,4 @@ +Hello +succeed +65 +Hello diff --git a/tests/run-custom-args/qualified-types/evalOnce.scala b/tests/run-custom-args/qualified-types/evalOnce.scala new file mode 100644 index 000000000000..fd050e1f5399 --- /dev/null +++ b/tests/run-custom-args/qualified-types/evalOnce.scala @@ -0,0 +1,15 @@ +@main def Test: Unit = + + ({println ("Hello"); 4}).isInstanceOf[{ y : Int with y > 0}] + + class Person(val name: String, var age: Int) + def incAge(p: Person): Int = + p.age += 1 + p.age + + val p = new Person("Alice", 64) + + if incAge(p).isInstanceOf[{v:Int with v == 65}] then println("succeed") + println(p.age) + + ({println ("Hello"); 4}).isInstanceOf[Int & { y : Int with y > 0}] diff --git a/tests/run-custom-args/qualified-types/isInstanceOf.check b/tests/run-custom-args/qualified-types/isInstanceOf.check new file mode 100644 index 000000000000..bba4b7cd4694 --- /dev/null +++ b/tests/run-custom-args/qualified-types/isInstanceOf.check @@ -0,0 +1,320 @@ +call id +-1 is instance of Pos: false +call id +-1 is instance of {v: Int with v == 2}: false +call id +-1 is instance of NonEmptyString: false +call id +-1 is instance of PoliteString: false +call id +-1 is instance of Pos & Int: false +call id +-1 is instance of Pos | Int: true +call id +-1 is instance of Pos & String: false +call id +-1 is instance of Pos | String: false +call id +-1 is instance of {v: Int with v == 2} & Int: false +call id +-1 is instance of {v: Int with v == 2} | Int: true +call id +-1 is instance of {v: Int with v == 2} & String: false +call id +-1 is instance of {v: Int with v == 2} | String: false +call id +-1 is instance of NonEmptyString & String: false +call id +-1 is instance of NonEmptyString | String: false +call id +-1 is instance of NonEmptyString & Int: false +call id +-1 is instance of NonEmptyString | Int: true +call id +-1 is instance of PoliteString & Int: false +call id +-1 is instance of PoliteString | Int: true +call id +-1 is instance of PoliteString & String: false +call id +-1 is instance of PoliteString | String: false +call id +1 is instance of Pos: true +call id +1 is instance of {v: Int with v == 2}: false +call id +1 is instance of NonEmptyString: false +call id +1 is instance of PoliteString: false +call id +1 is instance of Pos & Int: true +call id +1 is instance of Pos | Int: true +call id +1 is instance of Pos & String: false +call id +1 is instance of Pos | String: true +call id +1 is instance of {v: Int with v == 2} & Int: false +call id +1 is instance of {v: Int with v == 2} | Int: true +call id +1 is instance of {v: Int with v == 2} & String: false +call id +1 is instance of {v: Int with v == 2} | String: false +call id +1 is instance of NonEmptyString & String: false +call id +1 is instance of NonEmptyString | String: false +call id +1 is instance of NonEmptyString & Int: false +call id +1 is instance of NonEmptyString | Int: true +call id +1 is instance of PoliteString & Int: false +call id +1 is instance of PoliteString | Int: true +call id +1 is instance of PoliteString & String: false +call id +1 is instance of PoliteString | String: false +call id +2 is instance of Pos: true +call id +2 is instance of {v: Int with v == 2}: true +call id +2 is instance of NonEmptyString: false +call id +2 is instance of PoliteString: false +call id +2 is instance of Pos & Int: true +call id +2 is instance of Pos | Int: true +call id +2 is instance of Pos & String: false +call id +2 is instance of Pos | String: true +call id +2 is instance of {v: Int with v == 2} & Int: true +call id +2 is instance of {v: Int with v == 2} | Int: true +call id +2 is instance of {v: Int with v == 2} & String: false +call id +2 is instance of {v: Int with v == 2} | String: true +call id +2 is instance of NonEmptyString & String: false +call id +2 is instance of NonEmptyString | String: false +call id +2 is instance of NonEmptyString & Int: false +call id +2 is instance of NonEmptyString | Int: true +call id +2 is instance of PoliteString & Int: false +call id +2 is instance of PoliteString | Int: true +call id +2 is instance of PoliteString & String: false +call id +2 is instance of PoliteString | String: false +call id +"" is instance of Pos: false +call id +"" is instance of {v: Int with v == 2}: false +call id +"" is instance of NonEmptyString: false +call id +"" is instance of PoliteString: false +call id +"" is instance of Pos & Int: false +call id +"" is instance of Pos | Int: false +call id +"" is instance of Pos & String: false +call id +"" is instance of Pos | String: true +call id +"" is instance of {v: Int with v == 2} & Int: false +call id +"" is instance of {v: Int with v == 2} | Int: false +call id +"" is instance of {v: Int with v == 2} & String: false +call id +"" is instance of {v: Int with v == 2} | String: true +call id +"" is instance of NonEmptyString & String: false +call id +"" is instance of NonEmptyString | String: true +call id +"" is instance of NonEmptyString & Int: false +call id +"" is instance of NonEmptyString | Int: false +call id +"" is instance of PoliteString & Int: false +call id +"" is instance of PoliteString | Int: false +call id +"" is instance of PoliteString & String: false +call id +"" is instance of PoliteString | String: true +call id +"Do it please" is instance of Pos: false +call id +"Do it please" is instance of {v: Int with v == 2}: false +call id +"Do it please" is instance of NonEmptyString: true +call id +"Do it please" is instance of PoliteString: true +call id +"Do it please" is instance of Pos & Int: false +call id +"Do it please" is instance of Pos | Int: false +call id +"Do it please" is instance of Pos & String: false +call id +"Do it please" is instance of Pos | String: true +call id +"Do it please" is instance of {v: Int with v == 2} & Int: false +call id +"Do it please" is instance of {v: Int with v == 2} | Int: false +call id +"Do it please" is instance of {v: Int with v == 2} & String: false +call id +"Do it please" is instance of {v: Int with v == 2} | String: true +call id +"Do it please" is instance of NonEmptyString & String: true +call id +"Do it please" is instance of NonEmptyString | String: true +call id +"Do it please" is instance of NonEmptyString & Int: false +call id +"Do it please" is instance of NonEmptyString | Int: true +call id +"Do it please" is instance of PoliteString & Int: false +call id +"Do it please" is instance of PoliteString | Int: true +call id +"Do it please" is instance of PoliteString & String: true +call id +"Do it please" is instance of PoliteString | String: true +call id +"do it already" is instance of Pos: false +call id +"do it already" is instance of {v: Int with v == 2}: false +call id +"do it already" is instance of NonEmptyString: true +call id +"do it already" is instance of PoliteString: false +call id +"do it already" is instance of Pos & Int: false +call id +"do it already" is instance of Pos | Int: false +call id +"do it already" is instance of Pos & String: false +call id +"do it already" is instance of Pos | String: true +call id +"do it already" is instance of {v: Int with v == 2} & Int: false +call id +"do it already" is instance of {v: Int with v == 2} | Int: false +call id +"do it already" is instance of {v: Int with v == 2} & String: false +call id +"do it already" is instance of {v: Int with v == 2} | String: true +call id +"do it already" is instance of NonEmptyString & String: true +call id +"do it already" is instance of NonEmptyString | String: true +call id +"do it already" is instance of NonEmptyString & Int: false +call id +"do it already" is instance of NonEmptyString | Int: true +call id +"do it already" is instance of PoliteString & Int: false +call id +"do it already" is instance of PoliteString | Int: false +call id +"do it already" is instance of PoliteString & String: false +call id +"do it already" is instance of PoliteString | String: true +call id +false is instance of Pos: false +call id +false is instance of {v: Int with v == 2}: false +call id +false is instance of NonEmptyString: false +call id +false is instance of PoliteString: false +call id +false is instance of Pos & Int: false +call id +false is instance of Pos | Int: false +call id +false is instance of Pos & String: false +call id +false is instance of Pos | String: false +call id +false is instance of {v: Int with v == 2} & Int: false +call id +false is instance of {v: Int with v == 2} | Int: false +call id +false is instance of {v: Int with v == 2} & String: false +call id +false is instance of {v: Int with v == 2} | String: false +call id +false is instance of NonEmptyString & String: false +call id +false is instance of NonEmptyString | String: false +call id +false is instance of NonEmptyString & Int: false +call id +false is instance of NonEmptyString | Int: false +call id +false is instance of PoliteString & Int: false +call id +false is instance of PoliteString | Int: false +call id +false is instance of PoliteString & String: false +call id +false is instance of PoliteString | String: false +call id +null is instance of Pos: false +call id +null is instance of {v: Int with v == 2}: false +call id +null is instance of NonEmptyString: false +call id +null is instance of PoliteString: false +call id +null is instance of Pos & Int: false +call id +null is instance of Pos | Int: false +call id +null is instance of Pos & String: false +call id +null is instance of Pos | String: false +call id +null is instance of {v: Int with v == 2} & Int: false +call id +null is instance of {v: Int with v == 2} | Int: false +call id +null is instance of {v: Int with v == 2} & String: false +call id +null is instance of {v: Int with v == 2} | String: false +call id +null is instance of NonEmptyString & String: false +call id +null is instance of NonEmptyString | String: false +call id +null is instance of NonEmptyString & Int: false +call id +null is instance of NonEmptyString | Int: false +call id +null is instance of PoliteString & Int: false +call id +null is instance of PoliteString | Int: false +call id +null is instance of PoliteString & String: false +call id +null is instance of PoliteString | String: false diff --git a/tests/run-custom-args/qualified-types/isInstanceOf.scala b/tests/run-custom-args/qualified-types/isInstanceOf.scala new file mode 100644 index 000000000000..a893a5b7a3f0 --- /dev/null +++ b/tests/run-custom-args/qualified-types/isInstanceOf.scala @@ -0,0 +1,36 @@ +type Pos = {x: Int with x > 0} + +type NonEmptyString = {s: String with !s.isEmpty} +type PoliteString = {s: NonEmptyString with s.head.isUpper && s.takeRight(6) == "please"} + +def id[T](x: T): T = + println("call id") + x + +@main +def Test = + for v <- List[Any](-1, 1, 2, "", "Do it please", "do it already", false, null) do + val vStr = + if v.isInstanceOf[String] then s""""$v"""" + else if v == null then "null" + else v.toString + println(s"$vStr is instance of Pos: ${id(v).isInstanceOf[Pos]}") + println(s"$vStr is instance of {v: Int with v == 2}: ${id(v).isInstanceOf[{v: Int with v == 2}]}") + println(s"$vStr is instance of NonEmptyString: ${id(v).isInstanceOf[NonEmptyString]}") + println(s"$vStr is instance of PoliteString: ${id(v).isInstanceOf[PoliteString]}") + println(s"$vStr is instance of Pos & Int: ${id(v).isInstanceOf[Pos & Int]}") + println(s"$vStr is instance of Pos | Int: ${id(v).isInstanceOf[Pos | Int]}") + println(s"$vStr is instance of Pos & String: ${id(v).isInstanceOf[Pos & String]}") + println(s"$vStr is instance of Pos | String: ${id(v).isInstanceOf[Pos | String]}") + println(s"$vStr is instance of {v: Int with v == 2} & Int: ${id(v).isInstanceOf[{v: Int with v == 2} & Int]}") + println(s"$vStr is instance of {v: Int with v == 2} | Int: ${id(v).isInstanceOf[{v: Int with v == 2} | Int]}") + println(s"$vStr is instance of {v: Int with v == 2} & String: ${id(v).isInstanceOf[{v: Int with v == 2} & String]}") + println(s"$vStr is instance of {v: Int with v == 2} | String: ${id(v).isInstanceOf[{v: Int with v == 2} | String]}") + println(s"$vStr is instance of NonEmptyString & String: ${id(v).isInstanceOf[NonEmptyString & String]}") + println(s"$vStr is instance of NonEmptyString | String: ${id(v).isInstanceOf[NonEmptyString | String]}") + println(s"$vStr is instance of NonEmptyString & Int: ${id(v).isInstanceOf[NonEmptyString & Int]}") + println(s"$vStr is instance of NonEmptyString | Int: ${id(v).isInstanceOf[NonEmptyString | Int]}") + println(s"$vStr is instance of PoliteString & Int: ${id(v).isInstanceOf[PoliteString & Int]}") + println(s"$vStr is instance of PoliteString | Int: ${id(v).isInstanceOf[PoliteString | Int]}") + println(s"$vStr is instance of PoliteString & String: ${id(v).isInstanceOf[PoliteString & String]}") + println(s"$vStr is instance of PoliteString | String: ${id(v).isInstanceOf[PoliteString | String]}") diff --git a/tests/run-custom-args/qualified-types/pattern_matching.check b/tests/run-custom-args/qualified-types/pattern_matching.check new file mode 100644 index 000000000000..612472cd9501 --- /dev/null +++ b/tests/run-custom-args/qualified-types/pattern_matching.check @@ -0,0 +1,8 @@ +-1 is none of the above +1 is Pos +2 is {v: Int with v == 2} +"" is none of the above +"Do it please" is PoliteString +"do it already" is NonEmptyString +false is none of the above +null is none of the above diff --git a/tests/run-custom-args/qualified-types/pattern_matching.scala b/tests/run-custom-args/qualified-types/pattern_matching.scala new file mode 100644 index 000000000000..43dc87b9dbe9 --- /dev/null +++ b/tests/run-custom-args/qualified-types/pattern_matching.scala @@ -0,0 +1,31 @@ +type Pos = {x: Int with x > 0} + +type NonEmptyString = {s: String with !s.isEmpty} +type PoliteString = {s: NonEmptyString with s.head.isUpper && s.takeRight(6) == "please"} + +def id[T](x: T): T = + println("call id") + x + +def rec(x: NonEmptyString): List[Char] = + val rest = + x.tail match + case xs: NonEmptyString => rec(xs) + case _ => Nil + + x.head :: rest + +@main def Test = + for v <- List[Any](-1, 1, 2, "", "Do it please", "do it already", false, null) do + val vStr = + if v.isInstanceOf[String] then s""""$v"""" + else if v == null then "null" + else v.toString + + v match + case _: {v: Int with v == 2} => println(s"$vStr is {v: Int with v == 2}") + case _: Pos => println(s"$vStr is Pos") + case _: PoliteString => println(s"$vStr is PoliteString") + case _: NonEmptyString => println(s"$vStr is NonEmptyString") + case _ => println(s"$vStr is none of the above") + diff --git a/tests/run-custom-args/qualified-types/pattern_matching_alternative.check b/tests/run-custom-args/qualified-types/pattern_matching_alternative.check new file mode 100644 index 000000000000..fddaa098404b --- /dev/null +++ b/tests/run-custom-args/qualified-types/pattern_matching_alternative.check @@ -0,0 +1,2 @@ +Hello is an Int or a non-empty String +Hello is an Int or a non-empty String diff --git a/tests/run-custom-args/qualified-types/pattern_matching_alternative.scala b/tests/run-custom-args/qualified-types/pattern_matching_alternative.scala new file mode 100644 index 000000000000..5f3feed1ed4b --- /dev/null +++ b/tests/run-custom-args/qualified-types/pattern_matching_alternative.scala @@ -0,0 +1,14 @@ +@main def Test = + val x: Any = "Hello" + + x match + case (_: Int) | (_: {s: String with s.nonEmpty}) => + println(s"$x is an Int or a non-empty String") + case _ => + () + + x match + case _: (Int | {s: String with s.nonEmpty}) => + println(s"$x is an Int or a non-empty String") + case _ => + () diff --git a/tests/run-custom-args/qualified-types/runtimeChecked.scala b/tests/run-custom-args/qualified-types/runtimeChecked.scala new file mode 100644 index 000000000000..a709a0089c09 --- /dev/null +++ b/tests/run-custom-args/qualified-types/runtimeChecked.scala @@ -0,0 +1,18 @@ + +def getInt(): Int = 1 + +@main def Test = + val x: {v: Int with v == 1} = getInt().runtimeChecked + + assertThrows[IllegalArgumentException]: + val v1: {v: Int with v == 2} = x.runtimeChecked + +def assertThrows[T <: Throwable](block: => Unit): Unit = + try + block + catch + case e: Throwable if e.isInstanceOf[T] => + return + case _ => + throw new AssertionError("Unexpected exception") + throw new AssertionError("Expected exception not thrown") diff --git a/tests/run-custom-args/qualified-types/runtimeChecked_dependent.scala b/tests/run-custom-args/qualified-types/runtimeChecked_dependent.scala new file mode 100644 index 000000000000..26ed21311a89 --- /dev/null +++ b/tests/run-custom-args/qualified-types/runtimeChecked_dependent.scala @@ -0,0 +1,14 @@ +def foo(x: Int, y: {v: Int with v > x}): y.type = y + +def getInt(): Int = 1 + +type Pos = {v: Int with v > 0} +type Neg = {v: Int with v < 0} + +import scala.reflect.TypeTest + +@main def Test = + val v1= foo(1, 2.runtimeChecked) + + val p: Int = 1 + val v2 = foo(p, 2.runtimeChecked) From 3fd81dc4e2b8ef2af3c0c3cf869db8ae087c61db Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 12 Jun 2025 09:46:13 +0000 Subject: [PATCH 08/20] Add E-Graph-based rewriting to the QualifierSolver --- .../qualified_types/QualifierEGraph.scala | 263 ++++++++++++++++++ .../qualified_types/QualifierNormalizer.scala | 15 - .../qualified_types/QualifierSolver.scala | 62 ++++- .../subtyping_equalities.scala | 31 +++ 4 files changed, 341 insertions(+), 30 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifierEGraph.scala create mode 100644 tests/pos-custom-args/qualified-types/subtyping_equalities.scala diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierEGraph.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierEGraph.scala new file mode 100644 index 000000000000..66074d845011 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierEGraph.scala @@ -0,0 +1,263 @@ +package dotty.tools.dotc.qualified_types + +import scala.collection.mutable + +import dotty.tools.dotc.ast.tpd.{ + Apply, + ConstantTree, + Ident, + Literal, + New, + Select, + Tree, + TreeMap, + TreeOps, + TypeApply, + TypeTree +} +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Names.Designator +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} +import dotty.tools.dotc.core.Types.{ConstantType, NoPrefix, SingletonType, TermRef, Type} +import dotty.tools.dotc.transform.TreeExtractors.BinaryOp +import dotty.tools.dotc.util.Spans.Span + +private enum ENode: + case Const(value: Constant) + case Ref(tp: TermRef) + case Object(clazz: Symbol, args: List[ENode]) + case Select(qual: ENode, member: Symbol) + case App(fn: ENode, args: List[ENode]) + case TypeApp(fn: ENode, args: List[Type]) + + override def toString(): String = + this match + case Const(value) => value.toString + case Ref(tp) => termRefToString(tp) + case Object(clazz, args) => s"#$clazz(${args.mkString(", ")})" + case Select(qual, member) => s"$qual..$member" + case App(fn, args) => s"$fn(${args.mkString(", ")})" + case TypeApp(fn, args) => s"$fn[${args.mkString(", ")}]" + + private def designatorToString(d: Designator): String = + d match + case d: Symbol => d.lastKnownDenotation.name.toString + case _ => d.toString + + private def termRefToString(tp: Type): String = + tp match + case tp: TermRef => + val pre = if tp.prefix == NoPrefix then "" else termRefToString(tp.prefix) + "." + pre + designatorToString(tp.designator) + case _ => + tp.toString + +final class QualifierEGraph: + private val represententOf = mutable.Map.empty[ENode, ENode] + + private def representent(node: ENode): ENode = + represententOf.get(node) match + case None => node + case Some(repr) => + val res = representent(repr) // avoid tailrec optimization + res + + /** Map from child nodes to their parent nodes */ + private val usedBy = mutable.Map.empty[ENode, mutable.Set[ENode]] + + private def uses(node: ENode): mutable.Set[ENode] = + usedBy.getOrElseUpdate(node, mutable.Set.empty) + + /** Map used for hash-consing nodes, keys and values are the same */ + private val index = mutable.Map.empty[ENode, ENode] + + private val worklist = mutable.Queue.empty[ENode] + + final def union(tree1: Tree, tree2: Tree)(using Context): Unit = + for node1 <- toNode(tree1); node2 <- toNode(tree2) do + merge(node1, node2) + + private def unique(node: ENode): node.type = + index.getOrElseUpdate( + node, { + node match + case ENode.Const(value) => + () + case ENode.Ref(tp) => + () + case ENode.Object(clazz, args) => + args.foreach(uses(_) += node) + case ENode.Select(qual, member) => + uses(qual) += node + case ENode.App(fn, args) => + uses(fn) += node + args.foreach(uses(_) += node) + case ENode.TypeApp(fn, args) => + uses(fn) += node + node + } + ).asInstanceOf[node.type] + + private val toNodeCache = mutable.WeakHashMap.empty[Tree, Option[ENode]] + + private def toNode(tree: Tree)(using Context): Option[ENode] = + toNodeCache.getOrElseUpdate(tree, computeToNode(tree).map(n => representent(unique(n)))) + + private def computeToNode(tree: Tree)(using Context): Option[ENode] = + tree match + case ConstantTree(constant) => + Some(ENode.Const(constant)) + case Ident(_) => + tree.tpe match + case tp: TermRef => Some(ENode.Ref(tp)) + case _ => None + case Apply(Select(clazz, nme.CONSTRUCTOR), args) if isCaseClass(clazz.symbol) => + for argsNodes <- args.map(toNode).sequence yield ENode.Object(clazz.symbol, argsNodes) + case Select(qual, name) if isCaseClassField(tree.symbol) => + for qualNode <- toNode(qual) yield qualNode match + case ENode.Object(_, args) => args(caseClassFieldIndex(tree.symbol)) + case qualNode => ENode.Select(qualNode, tree.symbol) + case Apply(fun, args) => + for funNode <- toNode(fun); argsNodes <- args.map(toNode).sequence yield ENode.App(funNode, argsNodes) + case TypeApply(fun, args) => + for funNode <- toNode(fun) yield ENode.TypeApp(funNode, args.map(_.tpe)) + case _ => + return None + + private object RefTypeTree: + def unapply(tree: Tree): Option[TermRef] = + tree.tpe match + case tp: TermRef => Some(tp) + case _ => None + + private def isCaseClass(sym: Symbol): Boolean = + // TODO(mbovel) + false + + private def isCaseClassField(sym: Symbol): Boolean = + // TODO(mbovel) + false + + private def caseClassFieldIndex(sym: Symbol): Int = + // TODO(mbovel) + ??? + + private def canonicalize(node: ENode): ENode = + representent(unique( + node match + case ENode.Const(value) => + node + case ENode.Ref(tp) => + node + case ENode.Object(clazz, args) => + val argsNodes = args.map(representent) + ENode.Object(clazz, argsNodes) + case ENode.Select(qual, member) => + representent(qual) match + case ENode.Object(_, args) => + args(caseClassFieldIndex(member)) + case qualRepr => + ENode.Select(qualRepr, member) + case ENode.App(fn, args) => + val fnNode = representent(fn) + val argsNodes = args.map(representent) + ENode.App(fnNode, argsNodes) + case ENode.TypeApp(fn, args) => + val fnNode = representent(fn) + ENode.TypeApp(fnNode, args) + )) + + private def order(a: ENode, b: ENode): (ENode, ENode) = + (a, b) match + case (_: ENode.Const, _) => (a, b) + case (_, _: ENode.Const) => (b, a) + case (_: ENode.Ref, _) => (a, b) + case (_, _: ENode.Ref) => (b, a) + case (_: ENode.Object, _) => (a, b) + case (_, _: ENode.Object) => (b, a) + case (_: ENode.Select, _) => (a, b) + case (_, _: ENode.Select) => (b, a) + case (_: ENode.App, _) => (a, b) + case (_, _: ENode.App) => (b, a) + case _ => (a, b) + + private def merge(a: ENode, b: ENode): Unit = + val aRepr = representent(a) + val bRepr = representent(b) + if aRepr eq bRepr then return + + // If both nodes are objects, recursively merge their arguments + (aRepr, bRepr) match + case (ENode.Object(clazzA, argsA), ENode.Object(clazzB, argsB)) if clazzA == clazzB => + argsA.zip(argsB).foreach(merge) + case _ => () + + /// Update represententOf and usedBy maps + val (newRepr, oldRepr) = order(aRepr, bRepr) + represententOf(oldRepr) = newRepr + uses(newRepr) ++= uses(oldRepr) + val oldUses = uses(oldRepr) + usedBy.remove(oldRepr) + + // Enqueue all nodes that use the oldRepr for repair + worklist.enqueueAll(oldUses) + + def repair(): Unit = + while !worklist.isEmpty do + val head = worklist.dequeue() + val headRepr = representent(head) + val headCanonical = canonicalize(head) + if headRepr ne headCanonical then + merge(headRepr, headCanonical) + + // Rewrite equivalent nodes in the tree to their canonical form + def rewrite(tree: Tree)(using Context): Tree = + Rewriter().transform(tree) + + private class Rewriter extends TreeMap: + override def transform(tree: Tree)(using Context): Tree = + toNode(tree) match + case Some(n) => toTree(representent(n)) + case None => + val d = defn + tree match + case BinaryOp(a, d.Int_== | d.Any_== | d.Boolean_==, b) => + (toNode(a), toNode(b)) match + case (Some(aNode), Some(bNode)) => + if representent(aNode) eq representent(bNode) then Literal(Constant(true)) + else super.transform(tree) + case _ => + super.transform(tree) + case _ => + super.transform(tree) + + private def toTree(node: ENode)(using Context): Tree = + node match + case ENode.Const(value) => + Literal(value) + case ENode.Ref(tp) => + Ident(tp) + case ENode.Object(clazz, args) => + New(clazz.typeRef, args.map(toTree)) + case ENode.Select(qual, member) => + toTree(qual).select(member) + case ENode.App(fn, args) => + Apply(toTree(fn), args.map(toTree)) + case ENode.TypeApp(fn, args) => + TypeApply(toTree(fn), args.map(TypeTree(_, false))) + + extension [T](xs: List[Option[T]]) + def sequence: Option[List[T]] = + var result = List.newBuilder[T] + var current = xs + while current.nonEmpty do + current.head match + case Some(x) => + result += x + current = current.tail + case None => + return None + Some(result.result()) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala index fa8669965c46..7a0560a38006 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala @@ -30,9 +30,6 @@ private class QualifierNormalizer extends TreeMap: method.symbol match case d.Int_+ => normalizeIntSum(tree) case d.Int_* => normalizeIntProduct(tree) - case d.Int_== => normalizeEquality(tree) - case d.Any_== => normalizeEquality(tree) - case d.Boolean_== => normalizeEquality(tree) case _ => super.transform(tree) case _ => super.transform(tree) @@ -129,15 +126,3 @@ private class QualifierNormalizer extends TreeMap: getAllArguments(expr, op) case _ => List(transform(tree)) - - private def normalizeEquality(tree: Tree)(using Context): Tree = - tree match - case Apply(select @ Select(lhs, name), List(rhs)) => - val lhsNormalized = transform(lhs) - val rhsNormalized = transform(rhs) - if QualifierStructuralComparer.hash(lhsNormalized) > QualifierStructuralComparer.hash(rhsNormalized) then - cpy.Apply(tree)(cpy.Select(select)(rhsNormalized, name), List(lhsNormalized)) - else - cpy.Apply(tree)(cpy.Select(select)(lhsNormalized, name), List(rhsNormalized)) - case _ => - throw new IllegalArgumentException("Unexpected tree passed to normalizeEquality: " + tree) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala index 4845f975d212..bc0be46dbff6 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala @@ -28,21 +28,21 @@ class QualifierSolver(using Context): val rhs = defDef1.rhs val lhs = defDef2.rhs if tree1ArgSym.info frozen_<:< tree2ArgSym.info then - impliesRec(rhs, lhs.subst(List(tree2ArgSym), List(tree1ArgSym))) + impliesRec1(rhs, lhs.subst(List(tree2ArgSym), List(tree1ArgSym))) else if tree2ArgSym.info frozen_<:< tree1ArgSym.info then - impliesRec(rhs.subst(List(tree1ArgSym), List(tree2ArgSym)), lhs) + impliesRec1(rhs.subst(List(tree1ArgSym), List(tree2ArgSym)), lhs) else false case _ => throw IllegalArgumentException("Qualifiers must be closures") - private def impliesRec(tree1: Tree, tree2: Tree): Boolean = + private def impliesRec1(tree1: Tree, tree2: Tree): Boolean = // tree1 = lhs || rhs tree1 match case Apply(select @ Select(lhs, name), List(rhs)) => select.symbol match case d.Boolean_|| => - return impliesRec(lhs, tree2) && impliesRec(rhs, tree2) + return impliesRec1(lhs, tree2) && impliesRec1(rhs, tree2) case _ => () case _ => () @@ -51,32 +51,64 @@ class QualifierSolver(using Context): case Apply(select @ Select(lhs, name), List(rhs)) => select.symbol match case d.Boolean_&& => - return impliesRec(tree1, lhs) && impliesRec(tree1, rhs) + return impliesRec1(tree1, lhs) && impliesRec1(tree1, rhs) case d.Boolean_|| => - return impliesRec(tree1, lhs) || impliesRec(tree1, rhs) + return impliesRec1(tree1, lhs) || impliesRec1(tree1, rhs) case _ => () case _ => () + val tree1Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree1)) + val tree2Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree2)) + + val eqs = topLevelEqualities(tree1Normalized) + if !eqs.isEmpty then + val (tree1Rewritten, tree2Rewritten) = rewriteEquivalences(tree1Normalized, tree2Normalized, eqs) + return impliesRec2(QualifierNormalizer.normalize(tree1Rewritten), QualifierNormalizer.normalize(tree2Rewritten)) + + impliesRec2(tree1Normalized, tree2Normalized) + + def impliesRec2(tree1: Tree, tree2: Tree): Boolean = // tree1 = lhs && rhs tree1 match case Apply(select @ Select(lhs, name), List(rhs)) => select.symbol match case d.Boolean_&& => - return impliesRec(lhs, tree2) || impliesRec(rhs, tree2) + return impliesRec2(lhs, tree2) || impliesRec2(rhs, tree2) case _ => () case _ => () - val tree1Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree1)) - val tree2Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree2)) - - tree2Normalized match - case Literal(Constant(true)) => + tree1 match + case Literal(Constant(false)) => return true case _ => () - tree1Normalized match - case Literal(Constant(false)) => + tree2 match + case Literal(Constant(true)) => return true case _ => () - QualifierAlphaComparer().iso(tree1Normalized, tree2Normalized) + QualifierAlphaComparer().iso(tree1, tree2) + + private def topLevelEqualities(tree: Tree): List[(Tree, Tree)] = + trace(i"topLevelEqualities $tree", Printers.qualifiedTypes): + topLevelEqualitiesImpl(tree) + + private def topLevelEqualitiesImpl(tree: Tree): List[(Tree, Tree)] = + val d = defn + tree match + case Apply(select @ Select(lhs, name), List(rhs)) => + select.symbol match + case d.Int_== | d.Any_== | d.Boolean_== => List((lhs, rhs)) + case d.Boolean_&& => topLevelEqualitiesImpl(lhs) ++ topLevelEqualitiesImpl(rhs) + case _ => Nil + case _ => + Nil + + private def rewriteEquivalences(tree1: Tree, tree2: Tree, eqs: List[(Tree, Tree)]): (Tree, Tree) = + trace(i"rewriteEquivalences $tree1, $tree2, $eqs", Printers.qualifiedTypes): + val egraph = QualifierEGraph() + for (lhs, rhs) <- eqs do + egraph.union(lhs, rhs) + egraph.repair() + (egraph.rewrite(tree1), egraph.rewrite(tree2)) + diff --git a/tests/pos-custom-args/qualified-types/subtyping_equalities.scala b/tests/pos-custom-args/qualified-types/subtyping_equalities.scala new file mode 100644 index 000000000000..d34cad165f0a --- /dev/null +++ b/tests/pos-custom-args/qualified-types/subtyping_equalities.scala @@ -0,0 +1,31 @@ +def f(x: Int): Int = ??? +def g(x: Int): Int = ??? +def f2(x: Int, y: Int): Int = ??? +def g2(x: Int, y: Int): Int = ??? + +def test: Unit = + val a: Int = ??? + val b: Int = ??? + val c: Int = ??? + val d: Int = ??? + + // Equality is reflexive, symmetric and transitive + summon[{v: Int with v == v} <:< {v: Int with true}] + summon[{v: Int with v == a} <:< {v: Int with v == a}] + summon[{v: Int with v == a} <:< {v: Int with a == v}] + summon[{v: Int with a == b} <:< {v: Int with b == a}] + summon[{v: Int with v == a && a > 3} <:< {v: Int with v > 3}] + summon[{v: Int with v == a && a == b} <:< {v: Int with v == b}] + summon[{v: Int with a == b && b == c} <:< {v: Int with a == c}] + summon[{v: Int with a == b && c == b} <:< {v: Int with a == c}] + summon[{v: Int with a == b && c == d && b == d} <:< {v: Int with b == d}] + summon[{v: Int with a == b && c == d && b == d} <:< {v: Int with a == c}] + + // Equality is congruent over functions + summon[{v: Int with a == b} <:< {v: Int with f(a) == f(b)}] + summon[{v: Int with a == b} <:< {v: Int with f(f(a)) == f(f(b))}] + summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with c == d}] + // the two first equalities in the premises are just used to test the behavior + // of the e-graph when `f(a)` and `f(b)` are inserted before `a == b`. + summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with f(a) == f(b)}] + summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with f(f(a)) == f(f(b))}] From 7757e22257aff5b8d0a2b37d557c340cf48f5875 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 12 Jun 2025 14:50:45 +0000 Subject: [PATCH 09/20] Normalize TermRefs in QualifierComparer --- .../dotc/qualified_types/QualifierComparer.scala | 16 +++++++++++++--- .../qualified-types/subtyping_comparisons.scala | 13 +++++++++++++ .../qualified-types/subtyping_unfolding.scala | 8 ++++---- 3 files changed, 30 insertions(+), 7 deletions(-) create mode 100644 tests/pos-custom-args/qualified-types/subtyping_comparisons.scala diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala index 97270de2cdfe..ce85eacd0888 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala @@ -3,17 +3,17 @@ package dotty.tools.dotc.qualified_types import scala.util.hashing.MurmurHash3 as hashing import dotty.tools.dotc.ast.tpd.{closureDef, Apply, Block, DefDef, Ident, Literal, New, Select, Tree, TreeOps, TypeApply, Typed, TypeTree} -import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Contexts.{ctx, Context} import dotty.tools.dotc.core.Decorators.i import dotty.tools.dotc.core.Symbols.Symbol -import dotty.tools.dotc.core.Types.{MethodType, TermRef, Type, TypeVar} +import dotty.tools.dotc.core.Types.{MethodType, NamedType, TermRef, Type, TypeVar} import dotty.tools.dotc.core.Symbols.defn import dotty.tools.dotc.reporting.trace import dotty.tools.dotc.config.Printers private abstract class QualifierComparer: - private def typeIso(tp1: Type, tp2: Type) = + protected def typeIso(tp1: Type, tp2: Type) = val tp1stripped = stripPermanentTypeVar(tp1) val tp2stripped = stripPermanentTypeVar(tp2) tp1stripped.equals(tp2stripped) @@ -91,6 +91,16 @@ private[qualified_types] object QualifierStructuralComparer extends QualifierCom override def hashCode: Int = hash(tree) private[qualified_types] final class QualifierAlphaComparer(using Context) extends QualifierComparer: + override protected def typeIso(tp1: Type, tp2: Type): Boolean = + def normalizeType(tp: Type): Type = + tp match + case tp: TypeVar if tp.isPermanentlyInstantiated => tp.permanentInst + case tp: NamedType => + if tp.symbol.isStatic then tp.symbol.termRef + else normalizeType(tp.prefix).select(tp.symbol) + case tp => tp + super.typeIso(normalizeType(tp1), normalizeType(tp2)) + override def iso(tree1: Tree, tree2: Tree): Boolean = trace(i"iso $tree1 ; $tree2"): (tree1, tree2) match diff --git a/tests/pos-custom-args/qualified-types/subtyping_comparisons.scala b/tests/pos-custom-args/qualified-types/subtyping_comparisons.scala new file mode 100644 index 000000000000..5bed71babd11 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/subtyping_comparisons.scala @@ -0,0 +1,13 @@ +def tp[T](): Boolean = ??? + +class Outer: + class Inner: + class D + summon[{v: Boolean with tp[Inner.this.D]()} =:= {v: Boolean with tp[D]()}] + +object OuterO: + object InnerO: + class D + summon[{v: Boolean with tp[InnerO.this.D]()} =:= {v: Boolean with tp[D]()}] + summon[{v: Boolean with tp[InnerO.D]()} =:= {v: Boolean with tp[D]()}] + summon[{v: Boolean with tp[OuterO.InnerO.D]()} =:= {v: Boolean with tp[D]()}] diff --git a/tests/pos-custom-args/qualified-types/subtyping_unfolding.scala b/tests/pos-custom-args/qualified-types/subtyping_unfolding.scala index a488cb9ef1aa..99ced156cdf5 100644 --- a/tests/pos-custom-args/qualified-types/subtyping_unfolding.scala +++ b/tests/pos-custom-args/qualified-types/subtyping_unfolding.scala @@ -24,8 +24,8 @@ def test: Unit = summon[{v: Int with v == id(x)} <:< {v: Int with v == id(x2)}] summon[{v: Int with v == id(x2)} <:< {v: Int with v == id(x)}] - //summon[{v: Int with v == id(y)} <:< {v: Int with v == id(x + 1)}] // TODO(mbovel): needs normaliazion of type hashes - //summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(y)}] // TODO(mbovel): needs normaliazion of type hashes + summon[{v: Int with v == id(y)} <:< {v: Int with v == id(x + 1)}] + summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(y)}] summon[{v: Int with v == y + 2} <:< {v: Int with v == x + 1 + 2}] summon[{v: Int with v == x + 1 + 2} <:< {v: Int with v == y + 2}] @@ -45,8 +45,8 @@ def test: Unit = summon[{v: Int with v == id(x)} <:< {v: Int with v == id(x2)}] summon[{v: Int with v == id(x2)} <:< {v: Int with v == id(x)}] - //summon[{v: Int with v == id(y)} <:< {v: Int with v == id(x + 1)}] // TODO(mbovel): needs normaliazion of type hashes - //summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(y)}] // TODO(mbovel): needs normaliazion of type hashes + summon[{v: Int with v == id(y)} <:< {v: Int with v == id(x + 1)}] + summon[{v: Int with v == id(x + 1)} <:< {v: Int with v == id(y)}] summon[{v: Int with v == y + 2} <:< {v: Int with v == x + 1 + 2}] summon[{v: Int with v == x + 1 + 2} <:< {v: Int with v == y + 2}] From b38c5b1313eb84857d642ec76ce2a62f20966d99 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 25 Jun 2025 12:23:26 +0000 Subject: [PATCH 10/20] Move normalization and comparison logic to the E-Graph --- .../tools/dotc/qualified_types/EGraph.scala | 400 ++++++++++++++++++ .../tools/dotc/qualified_types/ENode.scala | 104 +++++ .../qualified_types/QualifierComparer.scala | 112 ----- .../qualified_types/QualifierEGraph.scala | 263 ------------ .../qualified_types/QualifierEvaluator.scala | 15 +- .../qualified_types/QualifierNormalizer.scala | 128 ------ .../qualified_types/QualifierSolver.scala | 65 +-- .../subtyping_comparisons.scala | 9 +- .../subtyping_equalities.scala | 4 + .../subtyping_normalization.scala | 4 +- 10 files changed, 533 insertions(+), 571 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/ENode.scala delete mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala delete mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifierEGraph.scala delete mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala diff --git a/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala new file mode 100644 index 000000000000..7263e732589c --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala @@ -0,0 +1,400 @@ +package dotty.tools.dotc.qualified_types + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import dotty.tools.dotc.ast.tpd.{ + closureDef, + singleton, + Apply, + ConstantTree, + Ident, + Lambda, + Literal, + New, + Select, + This, + Tree, + TreeMap, + TreeOps, + TypeApply, + TypeTree +} +import dotty.tools.dotc.config.Printers +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Contexts.ctx +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Hashable.Binders +import dotty.tools.dotc.core.Names.Designator +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} +import dotty.tools.dotc.core.Types.{ + CachedProxyType, + ConstantType, + MethodType, + NamedType, + NoPrefix, + SingletonType, + SkolemType, + TermParamRef, + TermRef, + Type, + TypeVar, + ValueType +} +import dotty.tools.dotc.qualified_types.ENode.Op +import dotty.tools.dotc.reporting.trace +import dotty.tools.dotc.transform.TreeExtractors.BinaryOp +import dotty.tools.dotc.util.Spans.Span + +final class EGraph(rootCtx: Context): + + private val represententOf = mutable.Map.empty[ENode, ENode] + + private def representent(node: ENode): ENode = + represententOf.get(node) match + case None => node + case Some(repr) => + assert(repr ne node, s"Node $node has itself as representent") + representent(repr) + + /** Map from child nodes to their parent nodes */ + private val usedBy = mutable.Map.empty[ENode, mutable.Set[ENode]] + + private def uses(node: ENode): mutable.Set[ENode] = + usedBy.getOrElseUpdate(node, mutable.Set.empty) + + private def addUse(node: ENode, parent: ENode): Unit = + require(!represententOf.contains(node), s"Reference $node is not normalized") + uses(node) += parent + + /** Map used for hash-consing nodes, keys and values are the same */ + private val index = mutable.Map.empty[ENode, ENode] + + val trueNode: ENode.Atom = ENode.Atom(ConstantType(Constant(true))(using rootCtx)) + index(trueNode) = trueNode + + val falseNode: ENode.Atom = ENode.Atom(ConstantType(Constant(false))(using rootCtx)) + index(falseNode) = falseNode + + val minusOneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(-1))(using rootCtx)) + index(minusOneIntNode) = minusOneIntNode + + val zeroIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(0))(using rootCtx)) + index(zeroIntNode) = zeroIntNode + + val oneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(1))(using rootCtx)) + index(oneIntNode) = oneIntNode + + val d = defn(using rootCtx) // Need a stable path to match on `defn` members + val builtinOps = Map( + d.Int_== -> Op.Equal, + d.Boolean_== -> Op.Equal, + d.Any_== -> Op.Equal, + d.Boolean_&& -> Op.And, + d.Boolean_|| -> Op.Or, + d.Boolean_! -> Op.Not, + d.Int_+ -> Op.IntSum, + d.Int_* -> Op.IntProduct + ) + + private val worklist = mutable.Queue.empty[ENode] + + override def toString(): String = + val represententsString = represententOf.map((node, repr) => s" $node -> $repr").mkString("\n") + s"EGraph:\n$represententsString\n" + + def equiv(node1: ENode, node2: ENode)(using Context): Boolean = + trace(i"EGraph.equiv", Printers.qualifiedTypes): + val margin = ctx.base.indentTab * (ctx.base.indent) + // println(s"$margin node1: $node1\n$margin node2: $node2") + // Check if the representents of both nodes are the same + val repr1 = representent(node1) + val repr2 = representent(node2) + repr1 eq repr2 + + private def unique(node: ENode): node.type = + index.getOrElseUpdate( + node, { + node match + case ENode.Atom(tp) => + () + case ENode.New(clazz) => + addUse(clazz, node) + case ENode.Select(qual, member) => + addUse(qual, node) + case ENode.Apply(fn, args) => + addUse(fn, node) + args.foreach(addUse(_, node)) + case ENode.OpApply(op, args) => + args.foreach(addUse(_, node)) + case ENode.TypeApply(fn, args) => + addUse(fn, node) + case ENode.Lambda(paramTps, retTp, body) => + addUse(body, node) + node + } + ).asInstanceOf[node.type] + + def toNode(tree: Tree, paramSyms: List[Symbol] = Nil, paramNodes: List[ENode.ArgRefType] = Nil)(using + Context + ): Option[ENode] = + trace(i"EGraph.toNode $tree", Printers.qualifiedTypes): + computeToNode(tree, paramSyms, paramNodes).map(node => representent(unique(node))) + + private def computeToNode( + tree: Tree, + paramSyms: List[Symbol] = Nil, + paramNodes: List[ENode.ArgRefType] = Nil + )(using currentCtx: Context): Option[ENode] = + trace(i"ENode.computeToNode $tree", Printers.qualifiedTypes): + def normalizeType(tp: Type): Type = + tp match + case tp: TypeVar if tp.isPermanentlyInstantiated => + tp.permanentInst + case tp: NamedType => + if tp.symbol.isStatic then tp.symbol.termRef + else normalizeType(tp.prefix).select(tp.symbol) + case tp => tp + + def mapType(tp: Type): Type = + normalizeType(tp.subst(paramSyms, paramNodes)) + + tree match + case Literal(_) | Ident(_) | This(_) if tree.tpe.isInstanceOf[SingletonType] => + Some(ENode.Atom(mapType(tree.tpe).asInstanceOf[SingletonType])) + case New(clazz) => + for clazzNode <- toNode(clazz, paramSyms, paramNodes) yield ENode.New(clazzNode) + case Select(qual, name) => + for qualNode <- toNode(qual, paramSyms, paramNodes) yield ENode.Select(qualNode, tree.symbol) + case BinaryOp(lhs, op, rhs) if builtinOps.contains(op) => + for + lhsNode <- toNode(lhs, paramSyms, paramNodes) + rhsNode <- toNode(rhs, paramSyms, paramNodes) + yield normalizeOp(builtinOps(op), List(lhsNode, rhsNode)) + case BinaryOp(lhs, d.Int_-, rhs) if lhs.tpe.isInstanceOf[ValueType] && rhs.tpe.isInstanceOf[ValueType] => + for + lhsNode <- toNode(lhs, paramSyms, paramNodes) + rhsNode <- toNode(rhs, paramSyms, paramNodes) + yield normalizeOp(Op.IntSum, List(lhsNode, normalizeOp(Op.IntProduct, List(minusOneIntNode, rhsNode)))) + case Apply(fun, args) => + for + funNode <- toNode(fun, paramSyms, paramNodes) + argsNodes <- args.map(toNode(_, paramSyms, paramNodes)).sequence + yield ENode.Apply(funNode, argsNodes) + case TypeApply(fun, args) => + for funNode <- toNode(fun, paramSyms, paramNodes) + yield ENode.TypeApply(funNode, args.map(tp => mapType(tp.tpe))) + case closureDef(defDef) => + defDef.symbol.info.dealias match + case mt: MethodType => + assert(defDef.termParamss.size == 1, "closures have a single parameter list, right?") + val params = defDef.termParamss.head + val myParamSyms = params.map(_.symbol) + + val myParamTps: ArrayBuffer[Type] = ArrayBuffer.empty + ??? + + val myRetTp = ??? + + val myParamNodes = myParamTps.zipWithIndex.map((tp, i) => ENode.ArgRefType(i, tp)).toList + + for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamNodes ::: paramNodes) + yield ENode.Lambda(myParamTps.toList, myRetTp, body) + case _ => None + case _ => + None + + private def canonicalize(node: ENode): ENode = + representent(unique( + node match + case ENode.Atom(tp) => + node + case ENode.New(clazz) => + ENode.New(representent(clazz)) + case ENode.Select(qual, member) => + ENode.Select(representent(qual), member) + case ENode.Apply(fn, args) => + ENode.Apply(representent(fn), args.map(representent)) + case ENode.OpApply(op, args) => + normalizeOp(op, args.map(representent)) + case ENode.TypeApply(fn, args) => + ENode.TypeApply(representent(fn), args) + case ENode.Lambda(paramTps, retTp, body) => + + ENode.Lambda(paramTps, retTp, representent(body)) + )) + + private def normalizeOp(op: ENode.Op, args: List[ENode]): ENode = + op match + case Op.Equal => + assert(args.size == 2, s"Expected 2 arguments for equality, got $args") + if args(0) eq args(1) then trueNode + else ENode.OpApply(op, args.sortBy(_.hashCode())) + case Op.And => + assert(args.size == 2, s"Expected 2 arguments for conjunction, got $args") + if (args(0) eq falseNode) || (args(1) eq falseNode) then falseNode + else if args(0) eq trueNode then args(1) + else if args(1) eq trueNode then args(0) + else ENode.OpApply(op, args) + case Op.Or => + assert(args.size == 2, s"Expected 2 arguments for disjunction, got $args") + if (args(0) eq trueNode) || (args(1) eq trueNode) then trueNode + else if args(0) eq falseNode then args(1) + else if args(1) eq falseNode then args(0) + else ENode.OpApply(op, args) + case Op.IntProduct => + val (consts, nonConsts) = decomposeIntProduct(args) + makeIntProduct(consts, nonConsts) + case Op.IntSum => + val (const, nonConsts) = decomposeIntSum(args) + makeIntSum(const, nonConsts) + case _ => + ENode.OpApply(op, args) + + private def decomposeIntProduct(args: List[ENode]): (Int, List[ENode]) = + val factors = + args.flatMap: + case ENode.OpApply(Op.IntProduct, innerFactors) => innerFactors + case arg => List(arg) + val (consts, nonConsts) = + factors.partitionMap: + case ENode.Atom(ConstantType(Constant(c: Int))) => Left(c) + case factor => Right(factor) + (consts.product, nonConsts.sortBy(_.hashCode())) + + private def makeIntProduct(const: Int, nonConsts: List[ENode]): ENode = + if const == 0 then + zeroIntNode + else if const == 1 then + if nonConsts.isEmpty then oneIntNode + else if nonConsts.size == 1 then nonConsts.head + else ENode.OpApply(Op.IntProduct, nonConsts) + else + val constNode = unique(ENode.Atom(ConstantType(Constant(const))(using rootCtx))) + nonConsts match + case Nil => + constNode + case List(ENode.OpApply(Op.IntSum, summands)) => + ENode.OpApply(Op.IntSum, summands.map(summand => normalizeOp(Op.IntProduct, List(constNode, summand)))) + case _ => + ENode.OpApply(Op.IntProduct, constNode :: nonConsts) + + private def decomposeIntSum(args: List[ENode]): (Int, List[ENode]) = + val summands: List[ENode] = + args.flatMap: + case ENode.OpApply(Op.IntSum, innerSummands) => innerSummands + case arg => List(arg) + val decomposed: List[(Int, List[ENode])] = + summands.map: + case ENode.OpApply(Op.IntProduct, args) => + args match + case ENode.Atom(ConstantType(Constant(const: Int))) :: nonConsts => (const, nonConsts) + case nonConsts => (1, nonConsts) + case ENode.Atom(ConstantType(Constant(const: Int))) => (const, Nil) + case other => (1, List(other)) + val grouped = decomposed.groupMapReduce(_._2)(_._1)(_ + _) + val const = grouped.getOrElse(Nil, 0) + val nonConsts = + grouped + .toList + .filter((nonConsts, const) => const != 0 && !nonConsts.isEmpty) + .sortBy((nonConsts, const) => nonConsts.hashCode()) + .map((nonConsts, const) => makeIntProduct(const, nonConsts)) + (const, nonConsts) + + private def makeIntSum(const: Int, nonConsts: List[ENode]): ENode = + if const == 0 then + if nonConsts.isEmpty then zeroIntNode + else if nonConsts.size == 1 then nonConsts.head + else ENode.OpApply(Op.IntSum, nonConsts) + else + val constNode = unique(ENode.Atom(ConstantType(Constant(const))(using rootCtx))) + if nonConsts.isEmpty then constNode + else ENode.OpApply(Op.IntSum, constNode :: nonConsts) + + private def order(a: ENode, b: ENode): (ENode, ENode) = + (a, b) match + case (ENode.Atom(_: ConstantType), _) => (a, b) + case (_, ENode.Atom(_: ConstantType)) => (b, a) + case (ENode.Atom(_: SkolemType), _) => (a, b) + case (_, ENode.Atom(_: SkolemType)) => (b, a) + case (_: ENode.Atom, _) => (a, b) + case (_, _: ENode.Atom) => (b, a) + case (_: ENode.New, _) => (a, b) + case (_, _: ENode.New) => (b, a) + case (_: ENode.Select, _) => (a, b) + case (_, _: ENode.Select) => (b, a) + case (_: ENode.Apply, _) => (a, b) + case (_, _: ENode.Apply) => (b, a) + case (_: ENode.TypeApply, _) => (a, b) + case (_, _: ENode.TypeApply) => (b, a) + case _ => (a, b) + + def merge(a: ENode, b: ENode): Unit = + val aRepr = representent(a) + val bRepr = representent(b) + if aRepr eq bRepr then return + assert(aRepr != bRepr, s"$aRepr and $bRepr are `equals` but not `eq`") + + // TODO(mbovel): if both nodes are objects, recursively merge their arguments + + /// Update represententOf and usedBy maps + val (newRepr, oldRepr) = order(aRepr, bRepr) + represententOf(oldRepr) = newRepr + uses(newRepr) ++= uses(oldRepr) + val oldUses = uses(oldRepr) + usedBy.remove(oldRepr) + + // Propagate truth values over disjunctions, conjunctions and equalities + oldRepr match + case ENode.OpApply(Op.And, args) if newRepr eq trueNode => + args.foreach(merge(_, trueNode)) + case ENode.OpApply(Op.Or, args) if newRepr eq falseNode => + args.foreach(merge(_, falseNode)) + case ENode.OpApply(Op.Equal, args) if newRepr eq trueNode => + args.foreach(arg => merge(args(0), args(1))) + case _ => + () + + // Enqueue all nodes that use the oldRepr for repair + worklist.enqueueAll(oldUses) + + def repair(): Unit = + while !worklist.isEmpty do + val head = worklist.dequeue() + val headRepr = representent(head) + val headCanonical = canonicalize(head) + if headRepr ne headCanonical then + merge(headRepr, headCanonical) + + private def toTree(node: ENode, paramRefs: List[Tree])(using Context): Tree = + node match + case ENode.Atom(tp) => + singleton(tp) + case ENode.New(clazz) => + New(toTree(clazz, paramRefs)) + case ENode.Select(qual, member) => + toTree(qual, paramRefs).select(member) + case ENode.Apply(fn, args) => + Apply(toTree(fn, paramRefs), args.map(toTree(_, paramRefs))) + case ENode.OpApply(op, args) => + ??? + case ENode.TypeApply(fn, args) => + TypeApply(toTree(fn, paramRefs), args.map(TypeTree(_, false))) + case ENode.Lambda(paramTps, retTp, body) => + ??? + + extension [T](xs: List[Option[T]]) + private def sequence: Option[List[T]] = + var result = List.newBuilder[T] + var current = xs + while current.nonEmpty do + current.head match + case Some(x) => + result += x + current = current.tail + case None => + return None + Some(result.result()) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala new file mode 100644 index 000000000000..acbd4b1e5ca2 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala @@ -0,0 +1,104 @@ +package dotty.tools.dotc.qualified_types + +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Hashable.Binders +import dotty.tools.dotc.core.Names.Designator +import dotty.tools.dotc.core.Names.Name +import dotty.tools.dotc.core.StdNames.nme +import dotty.tools.dotc.core.StdNames.tpnme +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.core.Types.{ + CachedProxyType, + ConstantType, + MethodType, + NamedType, + NoPrefix, + SingletonType, + TermRef, + ThisType, + Type +} + +enum ENode: + import ENode.* + + case Atom(tp: SingletonType) + case New(clazz: ENode) + case Select(qual: ENode, member: Symbol) + case Apply(fn: ENode, args: List[ENode]) + case OpApply(fn: ENode.Op, args: List[ENode]) + case TypeApply(fn: ENode, args: List[Type]) + case Lambda(paramTps: List[Type], retTp: Type, body: ENode) + + override def toString(): String = + this match + case Atom(tp) => typeToString(tp) + case New(clazz) => s"new $clazz" + case Select(qual, member) => s"$qual.${designatorToString(member)}" + case Apply(fn, args) => s"$fn(${args.mkString(", ")})" + case OpApply(op, args) => s"(${args.mkString(op.operatorString())})" + case TypeApply(fn, args) => s"$fn[${args.map(typeToString).mkString(", ")}]" + case Lambda(paramTps, retTp, body) => + s"(${paramTps.map(typeToString).mkString(", ")}): ${{ typeToString(retTp) }} => $body" + +object ENode: + private def typeToString(tp: Type): String = + tp match + case tp: NamedType => + val prefixString = if isEmptyPrefix(tp.prefix) then "" else typeToString(tp.prefix) + "." + prefixString + designatorToString(tp.designator) + case tp: ConstantType => + tp.value.value.toString + case tp: ThisType => + typeToString(tp.tref) + ".this" + case _ => + tp.toString + + private def isEmptyPrefix(tp: Type): Boolean = + tp match + case tp: NoPrefix.type => + true + case tp: ThisType => + tp.tref.designator match + case d: Symbol => d.lastKnownDenotation.name.toTermName == nme.EMPTY_PACKAGE + case _ => false + case _ => false + + private def designatorToString(d: Designator): String = + d match + case d: Symbol => d.lastKnownDenotation.name.toString + case _ => d.toString + + enum Op: + case IntSum + case IntProduct + case LongSum + case LongProduct + case Equal + case Not + case And + case Or + case LessThan + + def operatorString(): String = + this match + case IntSum => "+" + case IntProduct => "*" + case LongSum => "+" + case LongProduct => "*" + case Equal => "==" + case Not => "!" + case And => "&&" + case Or => "||" + case LessThan => "<" + + /** Reference to the argument of an [[ENode.Lambda]]. + * + * @param index  + * Debruijn index of the argument, starting from 0 + * @param underyling + * Underlying type of the argument + */ + final case class ArgRefType(index: Int, underlying: Type) extends CachedProxyType, SingletonType: + override def underlying(using Context): Type = underlying + override def computeHash(bs: Binders): Int = doHash(bs, index, underlying) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala deleted file mode 100644 index ce85eacd0888..000000000000 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierComparer.scala +++ /dev/null @@ -1,112 +0,0 @@ -package dotty.tools.dotc.qualified_types - -import scala.util.hashing.MurmurHash3 as hashing - -import dotty.tools.dotc.ast.tpd.{closureDef, Apply, Block, DefDef, Ident, Literal, New, Select, Tree, TreeOps, TypeApply, Typed, TypeTree} -import dotty.tools.dotc.core.Contexts.{ctx, Context} -import dotty.tools.dotc.core.Decorators.i -import dotty.tools.dotc.core.Symbols.Symbol -import dotty.tools.dotc.core.Types.{MethodType, NamedType, TermRef, Type, TypeVar} -import dotty.tools.dotc.core.Symbols.defn - -import dotty.tools.dotc.reporting.trace -import dotty.tools.dotc.config.Printers - -private abstract class QualifierComparer: - protected def typeIso(tp1: Type, tp2: Type) = - val tp1stripped = stripPermanentTypeVar(tp1) - val tp2stripped = stripPermanentTypeVar(tp2) - tp1stripped.equals(tp2stripped) - - /** Structural equality for trees. - * - * This implementation is _not_ alpha-equivalence aware, which - * 1. allows it not to rely on a [[Context]] and, - * 2. allows the corresponding [[hash]] method to reuse [[Type#hashCode]] - * instead of defining an other hash code for types. - */ - def iso(tree1: Tree, tree2: Tree): Boolean = - (tree1, tree2) match - case (Literal(_) | Ident(_), _) => - typeIso(tree1.tpe, tree2.tpe) - case (Select(qual1, name1), Select(qual2, name2)) => - name1 == name2 && iso(qual1, qual2) - case (Apply(fun1, args1), Apply(fun2, args2)) => - iso(fun1, fun2) && args1.corresponds(args2)(iso) - case (TypeApply(fun1, args1), TypeApply(fun2, args2)) => - iso(fun1, fun2) && args1.corresponds(args2)((arg1, arg2) => typeIso(arg1.tpe, arg2.tpe)) - case (tpt1: TypeTree, tpt2: TypeTree) => - typeIso(tpt1.tpe, tpt2.tpe) - case (Typed(expr1, tpt1), Typed(expr2, tpt2)) => - iso(expr1, expr2) && typeIso(tpt1.tpe, tpt2.tpe) - case (New(tpt1), New(tpt2)) => - typeIso(tpt1.tpe, tpt2.tpe) - case (Block(stats1, expr1), Block(stats2, expr2)) => - stats1.corresponds(stats2)(iso) && iso(expr1, expr2) - case _ => - tree1.equals(tree2) - - protected def stripPermanentTypeVar(tp: Type): Type = - tp match - case tp: TypeVar if tp.isPermanentlyInstantiated => tp.permanentInst - case tp => tp - -private[qualified_types] object QualifierStructuralComparer extends QualifierComparer: - /** A hash code for trees that corresponds to `iso(tree1, tree2)`. */ - def hash(tree: Tree): Int = - tree match - case Literal(_) | Ident(_) => - hashType(tree.tpe) - case Select(qual, name) => - hashing.mix(name.hashCode, hash(qual)) - case Apply(fun, args) => - hashing.mix(hash(fun), hashList(args)) - case TypeApply(fun, args) => - hashing.mix(hash(fun), hashList(args)) - case tpt: TypeTree => - hashType(tpt.tpe) - case Typed(expr, tpt) => - hashing.mix(hash(expr), hashType(tpt.tpe)) - case New(tpt1) => - hashType(tpt1.tpe) - case Block(stats, expr) => - hashing.mix(hashList(stats), hash(expr)) - case _ => - tree.hashCode - - private def hashList(trees: List[Tree]): Int = - trees.map(hash).foldLeft(0)(hashing.mix) - - private def hashType(tp: Type): Int = - stripPermanentTypeVar(tp).hashCode - - /** A box for trees that implements structural equality using [[iso]] and - * [[hash]]. This enables using trees as keys in hash maps. - */ - final class TreeBox(val tree: Tree) extends AnyVal: - override def equals(that: Any): Boolean = that match - case that: TreeBox => iso(tree, that.tree) - case _ => false - - override def hashCode: Int = hash(tree) - -private[qualified_types] final class QualifierAlphaComparer(using Context) extends QualifierComparer: - override protected def typeIso(tp1: Type, tp2: Type): Boolean = - def normalizeType(tp: Type): Type = - tp match - case tp: TypeVar if tp.isPermanentlyInstantiated => tp.permanentInst - case tp: NamedType => - if tp.symbol.isStatic then tp.symbol.termRef - else normalizeType(tp.prefix).select(tp.symbol) - case tp => tp - super.typeIso(normalizeType(tp1), normalizeType(tp2)) - - override def iso(tree1: Tree, tree2: Tree): Boolean = - trace(i"iso $tree1 ; $tree2"): - (tree1, tree2) match - case (closureDef(def1), closureDef(def2)) => - val def2substituted = def2.rhs.subst(def2.symbol.paramSymss.flatten, def1.symbol.paramSymss.flatten) - val def2normalized = QualifierNormalizer.normalize(def2substituted) - iso(def1.rhs, def2normalized) - case _ => - super.iso(tree1, tree2) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierEGraph.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierEGraph.scala deleted file mode 100644 index 66074d845011..000000000000 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierEGraph.scala +++ /dev/null @@ -1,263 +0,0 @@ -package dotty.tools.dotc.qualified_types - -import scala.collection.mutable - -import dotty.tools.dotc.ast.tpd.{ - Apply, - ConstantTree, - Ident, - Literal, - New, - Select, - Tree, - TreeMap, - TreeOps, - TypeApply, - TypeTree -} -import dotty.tools.dotc.core.Constants.Constant -import dotty.tools.dotc.core.Contexts.Context -import dotty.tools.dotc.core.Decorators.i -import dotty.tools.dotc.core.Names.Designator -import dotty.tools.dotc.core.StdNames.nme -import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} -import dotty.tools.dotc.core.Types.{ConstantType, NoPrefix, SingletonType, TermRef, Type} -import dotty.tools.dotc.transform.TreeExtractors.BinaryOp -import dotty.tools.dotc.util.Spans.Span - -private enum ENode: - case Const(value: Constant) - case Ref(tp: TermRef) - case Object(clazz: Symbol, args: List[ENode]) - case Select(qual: ENode, member: Symbol) - case App(fn: ENode, args: List[ENode]) - case TypeApp(fn: ENode, args: List[Type]) - - override def toString(): String = - this match - case Const(value) => value.toString - case Ref(tp) => termRefToString(tp) - case Object(clazz, args) => s"#$clazz(${args.mkString(", ")})" - case Select(qual, member) => s"$qual..$member" - case App(fn, args) => s"$fn(${args.mkString(", ")})" - case TypeApp(fn, args) => s"$fn[${args.mkString(", ")}]" - - private def designatorToString(d: Designator): String = - d match - case d: Symbol => d.lastKnownDenotation.name.toString - case _ => d.toString - - private def termRefToString(tp: Type): String = - tp match - case tp: TermRef => - val pre = if tp.prefix == NoPrefix then "" else termRefToString(tp.prefix) + "." - pre + designatorToString(tp.designator) - case _ => - tp.toString - -final class QualifierEGraph: - private val represententOf = mutable.Map.empty[ENode, ENode] - - private def representent(node: ENode): ENode = - represententOf.get(node) match - case None => node - case Some(repr) => - val res = representent(repr) // avoid tailrec optimization - res - - /** Map from child nodes to their parent nodes */ - private val usedBy = mutable.Map.empty[ENode, mutable.Set[ENode]] - - private def uses(node: ENode): mutable.Set[ENode] = - usedBy.getOrElseUpdate(node, mutable.Set.empty) - - /** Map used for hash-consing nodes, keys and values are the same */ - private val index = mutable.Map.empty[ENode, ENode] - - private val worklist = mutable.Queue.empty[ENode] - - final def union(tree1: Tree, tree2: Tree)(using Context): Unit = - for node1 <- toNode(tree1); node2 <- toNode(tree2) do - merge(node1, node2) - - private def unique(node: ENode): node.type = - index.getOrElseUpdate( - node, { - node match - case ENode.Const(value) => - () - case ENode.Ref(tp) => - () - case ENode.Object(clazz, args) => - args.foreach(uses(_) += node) - case ENode.Select(qual, member) => - uses(qual) += node - case ENode.App(fn, args) => - uses(fn) += node - args.foreach(uses(_) += node) - case ENode.TypeApp(fn, args) => - uses(fn) += node - node - } - ).asInstanceOf[node.type] - - private val toNodeCache = mutable.WeakHashMap.empty[Tree, Option[ENode]] - - private def toNode(tree: Tree)(using Context): Option[ENode] = - toNodeCache.getOrElseUpdate(tree, computeToNode(tree).map(n => representent(unique(n)))) - - private def computeToNode(tree: Tree)(using Context): Option[ENode] = - tree match - case ConstantTree(constant) => - Some(ENode.Const(constant)) - case Ident(_) => - tree.tpe match - case tp: TermRef => Some(ENode.Ref(tp)) - case _ => None - case Apply(Select(clazz, nme.CONSTRUCTOR), args) if isCaseClass(clazz.symbol) => - for argsNodes <- args.map(toNode).sequence yield ENode.Object(clazz.symbol, argsNodes) - case Select(qual, name) if isCaseClassField(tree.symbol) => - for qualNode <- toNode(qual) yield qualNode match - case ENode.Object(_, args) => args(caseClassFieldIndex(tree.symbol)) - case qualNode => ENode.Select(qualNode, tree.symbol) - case Apply(fun, args) => - for funNode <- toNode(fun); argsNodes <- args.map(toNode).sequence yield ENode.App(funNode, argsNodes) - case TypeApply(fun, args) => - for funNode <- toNode(fun) yield ENode.TypeApp(funNode, args.map(_.tpe)) - case _ => - return None - - private object RefTypeTree: - def unapply(tree: Tree): Option[TermRef] = - tree.tpe match - case tp: TermRef => Some(tp) - case _ => None - - private def isCaseClass(sym: Symbol): Boolean = - // TODO(mbovel) - false - - private def isCaseClassField(sym: Symbol): Boolean = - // TODO(mbovel) - false - - private def caseClassFieldIndex(sym: Symbol): Int = - // TODO(mbovel) - ??? - - private def canonicalize(node: ENode): ENode = - representent(unique( - node match - case ENode.Const(value) => - node - case ENode.Ref(tp) => - node - case ENode.Object(clazz, args) => - val argsNodes = args.map(representent) - ENode.Object(clazz, argsNodes) - case ENode.Select(qual, member) => - representent(qual) match - case ENode.Object(_, args) => - args(caseClassFieldIndex(member)) - case qualRepr => - ENode.Select(qualRepr, member) - case ENode.App(fn, args) => - val fnNode = representent(fn) - val argsNodes = args.map(representent) - ENode.App(fnNode, argsNodes) - case ENode.TypeApp(fn, args) => - val fnNode = representent(fn) - ENode.TypeApp(fnNode, args) - )) - - private def order(a: ENode, b: ENode): (ENode, ENode) = - (a, b) match - case (_: ENode.Const, _) => (a, b) - case (_, _: ENode.Const) => (b, a) - case (_: ENode.Ref, _) => (a, b) - case (_, _: ENode.Ref) => (b, a) - case (_: ENode.Object, _) => (a, b) - case (_, _: ENode.Object) => (b, a) - case (_: ENode.Select, _) => (a, b) - case (_, _: ENode.Select) => (b, a) - case (_: ENode.App, _) => (a, b) - case (_, _: ENode.App) => (b, a) - case _ => (a, b) - - private def merge(a: ENode, b: ENode): Unit = - val aRepr = representent(a) - val bRepr = representent(b) - if aRepr eq bRepr then return - - // If both nodes are objects, recursively merge their arguments - (aRepr, bRepr) match - case (ENode.Object(clazzA, argsA), ENode.Object(clazzB, argsB)) if clazzA == clazzB => - argsA.zip(argsB).foreach(merge) - case _ => () - - /// Update represententOf and usedBy maps - val (newRepr, oldRepr) = order(aRepr, bRepr) - represententOf(oldRepr) = newRepr - uses(newRepr) ++= uses(oldRepr) - val oldUses = uses(oldRepr) - usedBy.remove(oldRepr) - - // Enqueue all nodes that use the oldRepr for repair - worklist.enqueueAll(oldUses) - - def repair(): Unit = - while !worklist.isEmpty do - val head = worklist.dequeue() - val headRepr = representent(head) - val headCanonical = canonicalize(head) - if headRepr ne headCanonical then - merge(headRepr, headCanonical) - - // Rewrite equivalent nodes in the tree to their canonical form - def rewrite(tree: Tree)(using Context): Tree = - Rewriter().transform(tree) - - private class Rewriter extends TreeMap: - override def transform(tree: Tree)(using Context): Tree = - toNode(tree) match - case Some(n) => toTree(representent(n)) - case None => - val d = defn - tree match - case BinaryOp(a, d.Int_== | d.Any_== | d.Boolean_==, b) => - (toNode(a), toNode(b)) match - case (Some(aNode), Some(bNode)) => - if representent(aNode) eq representent(bNode) then Literal(Constant(true)) - else super.transform(tree) - case _ => - super.transform(tree) - case _ => - super.transform(tree) - - private def toTree(node: ENode)(using Context): Tree = - node match - case ENode.Const(value) => - Literal(value) - case ENode.Ref(tp) => - Ident(tp) - case ENode.Object(clazz, args) => - New(clazz.typeRef, args.map(toTree)) - case ENode.Select(qual, member) => - toTree(qual).select(member) - case ENode.App(fn, args) => - Apply(toTree(fn), args.map(toTree)) - case ENode.TypeApp(fn, args) => - TypeApply(toTree(fn), args.map(TypeTree(_, false))) - - extension [T](xs: List[Option[T]]) - def sequence: Option[List[T]] = - var result = List.newBuilder[T] - var current = xs - while current.nonEmpty do - current.head match - case Some(x) => - result += x - current = current.tail - case None => - return None - Some(result.result()) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala index 3c260a5c86e2..ce20a38aee4a 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala @@ -61,7 +61,7 @@ private class QualifierEvaluator(args: Map[Symbol, Tree]) extends TreeMap: tree match case tree: Apply => val treeTransformed = super.transform(tree) - constFold(treeTransformed).orElse(reduceBinaryOp(treeTransformed)).orElse(treeTransformed) + constFold(treeTransformed).orElse(treeTransformed) case tree: Select => val treeTransformed = super.transform(tree) constFold(treeTransformed).orElse(treeTransformed) @@ -75,19 +75,6 @@ private class QualifierEvaluator(args: Map[Symbol, Tree]) extends TreeMap: case ConstantTree(c: Constant) => Literal(c) case _ => EmptyTree - private def reduceBinaryOp(tree: Tree)(using Context): Tree = - val d = defn // Need a stable path to match on `defn` members - tree match - case BinaryOp(a, d.Int_== | d.Any_== | d.Boolean_==, b) => - val aNormalized = QualifierNormalizer.normalize(a) - val bNormalized = QualifierNormalizer.normalize(b) - if QualifierAlphaComparer().iso(aNormalized, bNormalized) then - Literal(Constant(true)) - else - EmptyTree - case _ => - EmptyTree - private def unfold(tree: Tree)(using Context): Tree = args.get(tree.symbol) match case Some(tree2) => diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala deleted file mode 100644 index 7a0560a38006..000000000000 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierNormalizer.scala +++ /dev/null @@ -1,128 +0,0 @@ -package dotty.tools.dotc.qualified_types - -import dotty.tools.dotc.ast.tpd.{singleton, Apply, Block, Literal, Select, Tree, TreeMap, given} -import dotty.tools.dotc.core.Atoms -import dotty.tools.dotc.core.Constants.Constant -import dotty.tools.dotc.core.Contexts.Context -import dotty.tools.dotc.core.Decorators.i -import dotty.tools.dotc.core.Symbols.{defn, Symbol} -import dotty.tools.dotc.core.Types.{ConstantType, TermRef} -import dotty.tools.dotc.config.Printers - -import dotty.tools.dotc.reporting.trace -import dotty.tools.dotc.config.Printers - -private[qualified_types] object QualifierNormalizer: - def normalize(tree: Tree)(using Context): Tree = - trace(i"normalize $tree", Printers.qualifiedTypes): - QualifierNormalizer().transform(tree) - -/** A [[TreeMap]] that normalizes trees by applying algebraic simplifications - * and by ordering operands. - * - * Entry point: [[QualifierNormalizer.normalize]]. - */ -private class QualifierNormalizer extends TreeMap: - override def transform(tree: Tree)(using Context): Tree = - val d = defn // Need a stable path to match on `defn` members - tree match - case Apply(method, _) => - method.symbol match - case d.Int_+ => normalizeIntSum(tree) - case d.Int_* => normalizeIntProduct(tree) - case _ => super.transform(tree) - case _ => super.transform(tree) - - /** Normalizes a tree representing an integer sum. - * - * The normalization consists in: - * - Grouping summands which have the same non-constant factors, such that - * `3x + x` becomes `4x` for example. - * - Sorting the summands, so that for example `x + y` and `y + x` are - * normalized to the same tree. - * - Simplifying `0 + x` to `x`. - * - Normalizing each summand using [[normalizeIntProduct]]. - */ - private def normalizeIntSum(tree: Tree)(using Context): Tree = - val (summands, const) = decomposeIntSum(tree) - makeIntSum(summands, const) - - /** Decomposes a tree representing an integer sum into a list of non-constant - * summands `s_i` and a constant `c`. The summands are grouped and sorted as - * described in [[normalizeIntSum]]. - */ - private def decomposeIntSum(tree: Tree)(using Context): (List[Tree], Int) = - val groups: Map[List[QualifierStructuralComparer.TreeBox], Int] = - getAllArguments(tree, defn.Int_+) - .map(decomposeIntProduct) - .groupMapReduce(_._1.map(QualifierStructuralComparer.TreeBox.apply))(_._2)(_ + _) - val const = groups.getOrElse(Nil, 0) - val summands = - groups - .filter((args, c) => c != 0 && !args.isEmpty) - .toList - .sortBy((pair: (List[QualifierStructuralComparer.TreeBox], Int)) => pair.hashCode()) - .map((args, c) => makeIntProduct(args.map(_.tree), c)) - (summands, const) - - /** Constructs a tree representing an integer sum from a list of non-constant - * summands `summands` and a constant `const`. - */ - private def makeIntSum(summands: List[Tree], const: Int)(using Context): Tree = - if summands.isEmpty then - Literal(Constant(const)) - else - val summandsTree = summands.reduce(_.select(defn.Int_+).appliedTo(_)) - if const == 0 then summandsTree - else Literal(Constant(const)).select(defn.Int_+).appliedTo(summandsTree) - - /** Normalizes a tree representing an integer product. - * - * The normalization consists in: - * - Sorting the factors, so that for example `x * y` and `y * x` are - * normalized to the same tree. - * - Simplifying `0 * x` to `0`. - * - Simplifying `1 * x` to `x`. - */ - private def normalizeIntProduct(tree: Tree)(using Context): Tree = - val (factors, const) = decomposeIntProduct(tree) - makeIntProduct(factors, const) - - /** Decomposes a tree representing an integer product into a sorted list of - * non-constant factors `f_i` and a constant `c`. - */ - private def decomposeIntProduct(tree: Tree)(using Context): (List[Tree], Int) = - val (consts, factors) = - getAllArguments(tree, defn.Int_*) - .map(transform) - .partitionMap: - case Literal(Constant(n: Int)) => Left(n) - case arg => Right(arg) - (factors.sortBy(QualifierStructuralComparer.hash), consts.product) - - /** Constructs a tree representing an integer product from a sorted list of - * non-constant factors `factors` and a constant `const`. - */ - private def makeIntProduct(factors: List[Tree], const: Int)(using Context): Tree = - if const == 0 then - Literal(Constant(0)) - else if factors.isEmpty then - Literal(Constant(const)) - else - val factorsTree = factors.reduce(_.select(defn.Int_*).appliedTo(_)) - if const == 1 then factorsTree - else Literal(Constant(const)).select(defn.Int_*).appliedTo(factorsTree) - - /** Recursively collects all arguments of an n-ary operation. - * - * For example, given the tree `(a + (b * c)) + (d + e)`, the method returns - * the list `[a, b * c, d, e]` when called with the `+` operator. - */ - private def getAllArguments(tree: Tree, op: Symbol)(using Context): List[Tree] = - tree match - case Apply(method @ Select(qual, _), List(arg)) if method.symbol == op => - getAllArguments(qual, op) ::: getAllArguments(arg, op) - case Block(Nil, expr) => - getAllArguments(expr, op) - case _ => - List(transform(tree)) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala index bc0be46dbff6..75e7fd6afa0d 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala @@ -4,19 +4,14 @@ import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.ast.tpd.{closureDef, singleton, Apply, Ident, Literal, Select, Tree, given} import dotty.tools.dotc.config.Printers import dotty.tools.dotc.core.Constants.Constant -import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Contexts.{ctx, Context} import dotty.tools.dotc.core.Decorators.i import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} -import dotty.tools.dotc.core.Types.{TermRef} -import dotty.tools.dotc.transform.BetaReduce - +import dotty.tools.dotc.core.Types.TermRef import dotty.tools.dotc.reporting.trace -import dotty.tools.dotc.config.Printers +import dotty.tools.dotc.transform.BetaReduce class QualifierSolver(using Context): - private val litTrue = Literal(Constant(true)) - private val litFalse = Literal(Constant(false)) - val d = defn // Need a stable path to match on `defn` members def implies(tree1: Tree, tree2: Tree) = @@ -57,37 +52,16 @@ class QualifierSolver(using Context): case _ => () case _ => () - val tree1Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree1)) - val tree2Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree2)) - - val eqs = topLevelEqualities(tree1Normalized) - if !eqs.isEmpty then - val (tree1Rewritten, tree2Rewritten) = rewriteEquivalences(tree1Normalized, tree2Normalized, eqs) - return impliesRec2(QualifierNormalizer.normalize(tree1Rewritten), QualifierNormalizer.normalize(tree2Rewritten)) - - impliesRec2(tree1Normalized, tree2Normalized) - - def impliesRec2(tree1: Tree, tree2: Tree): Boolean = - // tree1 = lhs && rhs - tree1 match - case Apply(select @ Select(lhs, name), List(rhs)) => - select.symbol match - case d.Boolean_&& => - return impliesRec2(lhs, tree2) || impliesRec2(rhs, tree2) - case _ => () - case _ => () - - tree1 match - case Literal(Constant(false)) => - return true - case _ => () - - tree2 match - case Literal(Constant(true)) => - return true - case _ => () - - QualifierAlphaComparer().iso(tree1, tree2) + val egraph = EGraph(ctx) + //println(s"tree implies $tree1 -> $tree2") + (egraph.toNode(QualifierEvaluator.evaluate(tree1)), egraph.toNode(QualifierEvaluator.evaluate(tree2))) match + case (Some(node1), Some(node2)) => + //println(s"node implies $node1 -> $node2") + egraph.merge(node1, egraph.trueNode) + egraph.repair() + egraph.equiv(node2, egraph.trueNode) + case _ => + false private def topLevelEqualities(tree: Tree): List[(Tree, Tree)] = trace(i"topLevelEqualities $tree", Printers.qualifiedTypes): @@ -99,16 +73,7 @@ class QualifierSolver(using Context): case Apply(select @ Select(lhs, name), List(rhs)) => select.symbol match case d.Int_== | d.Any_== | d.Boolean_== => List((lhs, rhs)) - case d.Boolean_&& => topLevelEqualitiesImpl(lhs) ++ topLevelEqualitiesImpl(rhs) - case _ => Nil + case d.Boolean_&& => topLevelEqualitiesImpl(lhs) ++ topLevelEqualitiesImpl(rhs) + case _ => Nil case _ => Nil - - private def rewriteEquivalences(tree1: Tree, tree2: Tree, eqs: List[(Tree, Tree)]): (Tree, Tree) = - trace(i"rewriteEquivalences $tree1, $tree2, $eqs", Printers.qualifiedTypes): - val egraph = QualifierEGraph() - for (lhs, rhs) <- eqs do - egraph.union(lhs, rhs) - egraph.repair() - (egraph.rewrite(tree1), egraph.rewrite(tree2)) - diff --git a/tests/pos-custom-args/qualified-types/subtyping_comparisons.scala b/tests/pos-custom-args/qualified-types/subtyping_comparisons.scala index 5bed71babd11..153dfb2bee64 100644 --- a/tests/pos-custom-args/qualified-types/subtyping_comparisons.scala +++ b/tests/pos-custom-args/qualified-types/subtyping_comparisons.scala @@ -2,12 +2,17 @@ def tp[T](): Boolean = ??? class Outer: class Inner: - class D + type D summon[{v: Boolean with tp[Inner.this.D]()} =:= {v: Boolean with tp[D]()}] object OuterO: object InnerO: - class D + type D summon[{v: Boolean with tp[InnerO.this.D]()} =:= {v: Boolean with tp[D]()}] + + // Before normalization: + // lhs: .this.OuterO$.this.InnerO$.this.D + // rhs: .this.OuterO$.this.InnerO.D summon[{v: Boolean with tp[InnerO.D]()} =:= {v: Boolean with tp[D]()}] + summon[{v: Boolean with tp[OuterO.InnerO.D]()} =:= {v: Boolean with tp[D]()}] diff --git a/tests/pos-custom-args/qualified-types/subtyping_equalities.scala b/tests/pos-custom-args/qualified-types/subtyping_equalities.scala index d34cad165f0a..0ea2cc578678 100644 --- a/tests/pos-custom-args/qualified-types/subtyping_equalities.scala +++ b/tests/pos-custom-args/qualified-types/subtyping_equalities.scala @@ -9,8 +9,12 @@ def test: Unit = val c: Int = ??? val d: Int = ??? + summon[{v: Int with v == 2} <:< {v: Int with v == 2}] + summon[{v: Int with v == f(a)} <:< {v: Int with v == f(a)}] + // Equality is reflexive, symmetric and transitive summon[{v: Int with v == v} <:< {v: Int with true}] + summon[{v: Int with a == b} <:< {v: Int with true}] summon[{v: Int with v == a} <:< {v: Int with v == a}] summon[{v: Int with v == a} <:< {v: Int with a == v}] summon[{v: Int with a == b} <:< {v: Int with b == a}] diff --git a/tests/pos-custom-args/qualified-types/subtyping_normalization.scala b/tests/pos-custom-args/qualified-types/subtyping_normalization.scala index 986e1f618e9b..bb961712510a 100644 --- a/tests/pos-custom-args/qualified-types/subtyping_normalization.scala +++ b/tests/pos-custom-args/qualified-types/subtyping_normalization.scala @@ -17,9 +17,9 @@ def test: Unit = summon[{v: Int with v == x + 3 * y} <:< {v: Int with v == 2 * y + x + y}] summon[{v: Int with v == x + 3 * y} <:< {v: Int with v == 2 * y + (x + y)}] summon[{v: Int with v == 0} <:< {v: Int with v == 1 - 1}] - // summon[{v: Int with v == 0} <:< {v: Int with v == x - x}] // TODO(mbovel): handle subtraction + summon[{v: Int with v == 0} <:< {v: Int with v == x - x}] summon[{v: Int with v == 0} <:< {v: Int with v == x + (x * -1)}] - // summon[{v: Int with v == x} <:< {v: Int with v == 1 + x - 1}] // TODO(mbovel): handle subtraction + summon[{v: Int with v == x} <:< {v: Int with v == 1 + x - 1}] summon[{v: Int with v == 4 * (x + 1)} <:< {v: Int with v == 2 * (x + 1) + 2 * (1 + x)}] summon[{v: Int with v == 4 * (x / 2)} <:< {v: Int with v == 2 * (x / 2) + 2 * (x / 2)}] From 57a335877003909089a878696b836ddee8134f86 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 3 Jul 2025 14:54:50 +0000 Subject: [PATCH 11/20] Implement lambdas to E-Node conversion --- .../tools/dotc/qualified_types/EGraph.scala | 64 +++++++++---------- .../subtyping_lambdas_neg.scala.scala | 9 +++ .../qualified-types/subtyping_lambdas.scala | 17 +++++ ...omparisons.scala => subtyping_paths.scala} | 0 4 files changed, 57 insertions(+), 33 deletions(-) create mode 100644 tests/neg-custom-args/qualified-types/subtyping_lambdas_neg.scala.scala create mode 100644 tests/pos-custom-args/qualified-types/subtyping_lambdas.scala rename tests/pos-custom-args/qualified-types/{subtyping_comparisons.scala => subtyping_paths.scala} (100%) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala index 7263e732589c..b4cc339ddf5a 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala @@ -47,6 +47,7 @@ import dotty.tools.dotc.qualified_types.ENode.Op import dotty.tools.dotc.reporting.trace import dotty.tools.dotc.transform.TreeExtractors.BinaryOp import dotty.tools.dotc.util.Spans.Span +import scala.collection.mutable.ListBuffer final class EGraph(rootCtx: Context): @@ -72,23 +73,23 @@ final class EGraph(rootCtx: Context): /** Map used for hash-consing nodes, keys and values are the same */ private val index = mutable.Map.empty[ENode, ENode] - val trueNode: ENode.Atom = ENode.Atom(ConstantType(Constant(true))(using rootCtx)) + final val trueNode: ENode.Atom = ENode.Atom(ConstantType(Constant(true))(using rootCtx)) index(trueNode) = trueNode - val falseNode: ENode.Atom = ENode.Atom(ConstantType(Constant(false))(using rootCtx)) + final val falseNode: ENode.Atom = ENode.Atom(ConstantType(Constant(false))(using rootCtx)) index(falseNode) = falseNode - val minusOneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(-1))(using rootCtx)) + final val minusOneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(-1))(using rootCtx)) index(minusOneIntNode) = minusOneIntNode - val zeroIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(0))(using rootCtx)) + final val zeroIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(0))(using rootCtx)) index(zeroIntNode) = zeroIntNode - val oneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(1))(using rootCtx)) + final val oneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(1))(using rootCtx)) index(oneIntNode) = oneIntNode - val d = defn(using rootCtx) // Need a stable path to match on `defn` members - val builtinOps = Map( + private val d = defn(using rootCtx) // Need a stable path to match on `defn` members + private val builtinOps = Map( d.Int_== -> Op.Equal, d.Boolean_== -> Op.Equal, d.Any_== -> Op.Equal, @@ -137,16 +138,16 @@ final class EGraph(rootCtx: Context): } ).asInstanceOf[node.type] - def toNode(tree: Tree, paramSyms: List[Symbol] = Nil, paramNodes: List[ENode.ArgRefType] = Nil)(using + def toNode(tree: Tree, paramSyms: List[Symbol] = Nil, paramTps: List[ENode.ArgRefType] = Nil)(using Context ): Option[ENode] = trace(i"EGraph.toNode $tree", Printers.qualifiedTypes): - computeToNode(tree, paramSyms, paramNodes).map(node => representent(unique(node))) + computeToNode(tree, paramSyms, paramTps).map(node => representent(unique(node))) private def computeToNode( tree: Tree, paramSyms: List[Symbol] = Nil, - paramNodes: List[ENode.ArgRefType] = Nil + paramTps: List[ENode.ArgRefType] = Nil )(using currentCtx: Context): Option[ENode] = trace(i"ENode.computeToNode $tree", Printers.qualifiedTypes): def normalizeType(tp: Type): Type = @@ -159,48 +160,45 @@ final class EGraph(rootCtx: Context): case tp => tp def mapType(tp: Type): Type = - normalizeType(tp.subst(paramSyms, paramNodes)) + normalizeType(tp.subst(paramSyms, paramTps)) tree match case Literal(_) | Ident(_) | This(_) if tree.tpe.isInstanceOf[SingletonType] => Some(ENode.Atom(mapType(tree.tpe).asInstanceOf[SingletonType])) case New(clazz) => - for clazzNode <- toNode(clazz, paramSyms, paramNodes) yield ENode.New(clazzNode) + for clazzNode <- toNode(clazz, paramSyms, paramTps) yield ENode.New(clazzNode) case Select(qual, name) => - for qualNode <- toNode(qual, paramSyms, paramNodes) yield ENode.Select(qualNode, tree.symbol) + for qualNode <- toNode(qual, paramSyms, paramTps) yield ENode.Select(qualNode, tree.symbol) case BinaryOp(lhs, op, rhs) if builtinOps.contains(op) => for - lhsNode <- toNode(lhs, paramSyms, paramNodes) - rhsNode <- toNode(rhs, paramSyms, paramNodes) + lhsNode <- toNode(lhs, paramSyms, paramTps) + rhsNode <- toNode(rhs, paramSyms, paramTps) yield normalizeOp(builtinOps(op), List(lhsNode, rhsNode)) case BinaryOp(lhs, d.Int_-, rhs) if lhs.tpe.isInstanceOf[ValueType] && rhs.tpe.isInstanceOf[ValueType] => for - lhsNode <- toNode(lhs, paramSyms, paramNodes) - rhsNode <- toNode(rhs, paramSyms, paramNodes) + lhsNode <- toNode(lhs, paramSyms, paramTps) + rhsNode <- toNode(rhs, paramSyms, paramTps) yield normalizeOp(Op.IntSum, List(lhsNode, normalizeOp(Op.IntProduct, List(minusOneIntNode, rhsNode)))) case Apply(fun, args) => for - funNode <- toNode(fun, paramSyms, paramNodes) - argsNodes <- args.map(toNode(_, paramSyms, paramNodes)).sequence + funNode <- toNode(fun, paramSyms, paramTps) + argsNodes <- args.map(toNode(_, paramSyms, paramTps)).sequence yield ENode.Apply(funNode, argsNodes) case TypeApply(fun, args) => - for funNode <- toNode(fun, paramSyms, paramNodes) + for funNode <- toNode(fun, paramSyms, paramTps) yield ENode.TypeApply(funNode, args.map(tp => mapType(tp.tpe))) case closureDef(defDef) => defDef.symbol.info.dealias match case mt: MethodType => assert(defDef.termParamss.size == 1, "closures have a single parameter list, right?") - val params = defDef.termParamss.head - val myParamSyms = params.map(_.symbol) - - val myParamTps: ArrayBuffer[Type] = ArrayBuffer.empty - ??? - - val myRetTp = ??? - - val myParamNodes = myParamTps.zipWithIndex.map((tp, i) => ENode.ArgRefType(i, tp)).toList - - for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamNodes ::: paramNodes) + val myParamSyms: List[Symbol] = defDef.termParamss.head.map(_.symbol) + val myParamTps: ListBuffer[ENode.ArgRefType] = ListBuffer.empty + val paramTpsSize = paramTps.size + for myParamSym <- myParamSyms do + val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList)) + myParamTps += ENode.ArgRefType(paramTpsSize + myParamTps.size, underlying) + val myRetTp = mapType(defDef.tpt.tpe.subst(myParamSyms, myParamTps.toList)) + for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamTps.toList ::: paramTps) yield ENode.Lambda(myParamTps.toList, myRetTp, body) case _ => None case _ => @@ -222,7 +220,6 @@ final class EGraph(rootCtx: Context): case ENode.TypeApply(fn, args) => ENode.TypeApply(representent(fn), args) case ENode.Lambda(paramTps, retTp, body) => - ENode.Lambda(paramTps, retTp, representent(body)) )) @@ -230,7 +227,8 @@ final class EGraph(rootCtx: Context): op match case Op.Equal => assert(args.size == 2, s"Expected 2 arguments for equality, got $args") - if args(0) eq args(1) then trueNode + if args(0) eq args(1) then + trueNode else ENode.OpApply(op, args.sortBy(_.hashCode())) case Op.And => assert(args.size == 2, s"Expected 2 arguments for conjunction, got $args") diff --git a/tests/neg-custom-args/qualified-types/subtyping_lambdas_neg.scala.scala b/tests/neg-custom-args/qualified-types/subtyping_lambdas_neg.scala.scala new file mode 100644 index 000000000000..c36d13c0941f --- /dev/null +++ b/tests/neg-custom-args/qualified-types/subtyping_lambdas_neg.scala.scala @@ -0,0 +1,9 @@ +def toBool[T](x: T): Boolean = ??? +def tp[T](): Any = ??? + +def test: Unit = + val x: {l: List[Int] with toBool((x: String, y: x.type) => x.length > 0)} = ??? // error: cannot turn method type into closure because it has internal parameter dependencies + summon[{l: List[Int] with toBool((x: String, y: String) => tp[x.type]())} =:= {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())}] // error + + summon[{l: List[Int] with toBool((x: Double) => (y: Int) => x == y)} =:= {l: List[Int] with toBool((a: Double) => (b: Int) => a == a)}] // error + summon[{l: List[Int] with toBool((x: Int) => (y: Int) => x == y)} =:= {l: List[Int] with toBool((a: Int) => (b: Int) => a == a)}] // error diff --git a/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala b/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala new file mode 100644 index 000000000000..204dde4ac55e --- /dev/null +++ b/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala @@ -0,0 +1,17 @@ +def toBool[T](x: T): Boolean = ??? +def tp[T](): Any = ??? + + +def test: Unit = + summon[{l: List[Int] with l.forall(x => x > 0)} =:= {l: List[Int] with l.forall(x => x > 0)}] + summon[{l: List[Int] with l.forall(x => x > 0)} =:= {l: List[Int] with l.forall(y => y > 0)}] + summon[{l: List[Int] with l.forall(x => x > 0)} =:= {l: List[Int] with l.forall(_ > 0)}] + + summon[{l: List[Int] with toBool((x: String) => x.length > 0)} =:= {l: List[Int] with toBool((y: String) => y.length > 0)}] + + summon[{l: List[Int] with toBool((x: String) => tp[x.type]())} =:= {l: List[Int] with toBool((y: String) => tp[y.type]())}] + summon[{l: List[Int] with toBool((x: String, y: String) => tp[x.type]())} =:= {l: List[Int] with toBool((x: String, y: String) => tp[x.type]())}] + summon[{l: List[Int] with toBool((x: String) => tp[x.type]())} =:= {l: List[Int] with toBool((y: String) => tp[y.type]())}] + summon[{l: List[Int] with toBool((x: String, y: String) => tp[y.type]())} =:= {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())}] + + summon[{l: List[Int] with toBool((x: String) => (y: String) => x == y)} =:= {l: List[Int] with toBool((a: String) => (b: String) => a == b)}] diff --git a/tests/pos-custom-args/qualified-types/subtyping_comparisons.scala b/tests/pos-custom-args/qualified-types/subtyping_paths.scala similarity index 100% rename from tests/pos-custom-args/qualified-types/subtyping_comparisons.scala rename to tests/pos-custom-args/qualified-types/subtyping_paths.scala From 8483d6227d3c8500fcba9e3dbfe0d1cca3ea4ea0 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 3 Jul 2025 14:55:05 +0000 Subject: [PATCH 12/20] Add support for field accesses and constructors --- compiler/src/dotty/tools/dotc/ast/untpd.scala | 2 +- .../dotty/tools/dotc/core/Definitions.scala | 1 + .../tools/dotc/qualified_types/EGraph.scala | 97 ++++++++++++++----- .../tools/dotc/qualified_types/ENode.scala | 4 +- .../qualified_types/QualifierSolver.scala | 4 +- .../qualified-types/adapt_neg.scala | 8 +- .../subtyping_objects_neg.scala | 52 ++++++++++ .../qualified-types/adapt.scala | 17 ++-- .../qualified-types/sized_lists.scala | 3 - .../qualified-types/subtyping_objects.scala | 57 +++++++++++ 10 files changed, 202 insertions(+), 43 deletions(-) create mode 100644 tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala create mode 100644 tests/pos-custom-args/qualified-types/subtyping_objects.scala diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index a323600a49b7..1b452b252d0d 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -473,7 +473,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def New(tpt: Tree, argss: List[List[Tree]])(using Context): Tree = ensureApplied(argss.foldLeft(makeNew(tpt))(Apply(_, _))) - /** A new expression with constrictor and possibly type arguments. See + /** A new expression with constructor and possibly type arguments. See * `New(tpt, argss)` for details. */ def makeNew(tpt: Tree)(using Context): Tree = { diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 088c34fda66e..9aa1591c4f85 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -670,6 +670,7 @@ class Definitions { @tu lazy val StringClass: ClassSymbol = requiredClass("java.lang.String") def StringType: Type = StringClass.typeRef @tu lazy val StringModule: Symbol = StringClass.linkedClass + @tu lazy val String_== : TermSymbol = enterMethod(StringClass, nme.EQ, methOfAnyRef(BooleanType), Final) @tu lazy val String_+ : TermSymbol = enterMethod(StringClass, nme.raw.PLUS, methOfAny(StringType), Final) @tu lazy val String_valueOf_Object: Symbol = StringModule.info.member(nme.valueOf).suchThat(_.info.firstParamTypes match { case List(pt) => pt.isAny || pt.stripNull().isAnyRef diff --git a/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala index b4cc339ddf5a..d3c90589f0f7 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala @@ -2,6 +2,7 @@ package dotty.tools.dotc.qualified_types import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer import dotty.tools.dotc.ast.tpd.{ closureDef, @@ -25,13 +26,16 @@ import dotty.tools.dotc.core.Constants.Constant import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.core.Contexts.ctx import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Flags import dotty.tools.dotc.core.Hashable.Binders import dotty.tools.dotc.core.Names.Designator import dotty.tools.dotc.core.StdNames.nme import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} import dotty.tools.dotc.core.Types.{ + AppliedType, CachedProxyType, ConstantType, + LambdaType, MethodType, NamedType, NoPrefix, @@ -40,6 +44,7 @@ import dotty.tools.dotc.core.Types.{ TermParamRef, TermRef, Type, + TypeRef, TypeVar, ValueType } @@ -47,7 +52,6 @@ import dotty.tools.dotc.qualified_types.ENode.Op import dotty.tools.dotc.reporting.trace import dotty.tools.dotc.transform.TreeExtractors.BinaryOp import dotty.tools.dotc.util.Spans.Span -import scala.collection.mutable.ListBuffer final class EGraph(rootCtx: Context): @@ -92,7 +96,7 @@ final class EGraph(rootCtx: Context): private val builtinOps = Map( d.Int_== -> Op.Equal, d.Boolean_== -> Op.Equal, - d.Any_== -> Op.Equal, + d.String_== -> Op.Equal, d.Boolean_&& -> Op.And, d.Boolean_|| -> Op.Or, d.Boolean_! -> Op.Not, @@ -108,9 +112,8 @@ final class EGraph(rootCtx: Context): def equiv(node1: ENode, node2: ENode)(using Context): Boolean = trace(i"EGraph.equiv", Printers.qualifiedTypes): - val margin = ctx.base.indentTab * (ctx.base.indent) + // val margin = ctx.base.indentTab * (ctx.base.indent) // println(s"$margin node1: $node1\n$margin node2: $node2") - // Check if the representents of both nodes are the same val repr1 = representent(node1) val repr2 = representent(node2) repr1 eq repr2 @@ -121,8 +124,8 @@ final class EGraph(rootCtx: Context): node match case ENode.Atom(tp) => () - case ENode.New(clazz) => - addUse(clazz, node) + case ENode.Constructor(sym) => + () case ENode.Select(qual, member) => addUse(qual, node) case ENode.Apply(fn, args) => @@ -138,6 +141,7 @@ final class EGraph(rootCtx: Context): } ).asInstanceOf[node.type] + // TODO(mbovel): Memoize this def toNode(tree: Tree, paramSyms: List[Symbol] = Nil, paramTps: List[ENode.ArgRefType] = Nil)(using Context ): Option[ENode] = @@ -165,16 +169,18 @@ final class EGraph(rootCtx: Context): tree match case Literal(_) | Ident(_) | This(_) if tree.tpe.isInstanceOf[SingletonType] => Some(ENode.Atom(mapType(tree.tpe).asInstanceOf[SingletonType])) - case New(clazz) => - for clazzNode <- toNode(clazz, paramSyms, paramTps) yield ENode.New(clazzNode) + case Select(New(_), nme.CONSTRUCTOR) => + constructorNode(tree.symbol) + case tree: Select if isCaseClassApply(tree.symbol) => + constructorNode(tree.symbol.owner.linkedClass.primaryConstructor) case Select(qual, name) => - for qualNode <- toNode(qual, paramSyms, paramTps) yield ENode.Select(qualNode, tree.symbol) + for qualNode <- toNode(qual, paramSyms, paramTps) yield normalizeSelect(qualNode, tree.symbol) case BinaryOp(lhs, op, rhs) if builtinOps.contains(op) => for lhsNode <- toNode(lhs, paramSyms, paramTps) rhsNode <- toNode(rhs, paramSyms, paramTps) yield normalizeOp(builtinOps(op), List(lhsNode, rhsNode)) - case BinaryOp(lhs, d.Int_-, rhs) if lhs.tpe.isInstanceOf[ValueType] && rhs.tpe.isInstanceOf[ValueType] => + case BinaryOp(lhs, d.Int_-, rhs) => for lhsNode <- toNode(lhs, paramSyms, paramTps) rhsNode <- toNode(rhs, paramSyms, paramTps) @@ -192,7 +198,7 @@ final class EGraph(rootCtx: Context): case mt: MethodType => assert(defDef.termParamss.size == 1, "closures have a single parameter list, right?") val myParamSyms: List[Symbol] = defDef.termParamss.head.map(_.symbol) - val myParamTps: ListBuffer[ENode.ArgRefType] = ListBuffer.empty + val myParamTps: ListBuffer[ENode.ArgRefType] = ListBuffer.empty val paramTpsSize = paramTps.size for myParamSym <- myParamSyms do val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList)) @@ -204,15 +210,38 @@ final class EGraph(rootCtx: Context): case _ => None + // TODO(mbovel): Memoize this + private def constructorNode(constr: Symbol)(using Context): Option[ENode.Constructor] = + val clazz = constr.owner + if hasStructuralEquality(clazz) then + val isPrimaryConstructor = constr.denot.isPrimaryConstructor + val fieldsRaw = clazz.denot.asClass.paramAccessors.filter(isPrimaryConstructor && _.isStableMember) + val constrParams = constr.paramSymss.flatten.filter(_.isTerm) + val fields = constrParams.map(p => fieldsRaw.find(_.name == p.name).getOrElse(NoSymbol)) + Some(ENode.Constructor(constr)(fields)) + else + None + + private def hasStructuralEquality(clazz: Symbol)(using Context): Boolean = + val equalsMethod = clazz.info.decls.lookup(nme.equals_) + val equalsNotOverriden = !equalsMethod.exists || equalsMethod.is(Flags.Synthetic) + clazz.isClass && clazz.is(Flags.Case) && equalsNotOverriden + + private def isCaseClassApply(meth: Symbol)(using Context): Boolean = + meth.name == nme.apply + && meth.flags.is(Flags.Synthetic) + && meth.owner.linkedClass.is(Flags.Case) + private def canonicalize(node: ENode): ENode = + // println(s"canonicalize $node") representent(unique( node match case ENode.Atom(tp) => node - case ENode.New(clazz) => - ENode.New(representent(clazz)) + case ENode.Constructor(sym) => + node case ENode.Select(qual, member) => - ENode.Select(representent(qual), member) + normalizeSelect(representent(qual), member) case ENode.Apply(fn, args) => ENode.Apply(representent(fn), args.map(representent)) case ENode.OpApply(op, args) => @@ -223,6 +252,33 @@ final class EGraph(rootCtx: Context): ENode.Lambda(paramTps, retTp, representent(body)) )) + private def normalizeSelect(qual: ENode, member: Symbol): ENode = + getAppliedConstructor(qual) match + case Some(constr) => + val memberIndex = constr.fields.indexOf(member) + + if memberIndex >= 0 then + val args = getTermArguments(qual) + assert(args.size == constr.fields.size) + args(memberIndex) + else + ENode.Select(qual, member) + case None => + ENode.Select(qual, member) + + private def getAppliedConstructor(node: ENode): Option[ENode.Constructor] = + node match + case ENode.Apply(fn, args) => getAppliedConstructor(fn) + case ENode.TypeApply(fn, args) => getAppliedConstructor(fn) + case node: ENode.Constructor => Some(node) + case _ => None + + private def getTermArguments(node: ENode): List[ENode] = + node match + case ENode.Apply(fn, args) => getTermArguments(fn) ::: args + case ENode.TypeApply(fn, args) => getTermArguments(fn) + case _ => Nil + private def normalizeOp(op: ENode.Op, args: List[ENode]): ENode = op match case Op.Equal => @@ -316,12 +372,10 @@ final class EGraph(rootCtx: Context): (a, b) match case (ENode.Atom(_: ConstantType), _) => (a, b) case (_, ENode.Atom(_: ConstantType)) => (b, a) - case (ENode.Atom(_: SkolemType), _) => (a, b) - case (_, ENode.Atom(_: SkolemType)) => (b, a) + case (_: ENode.Constructor, _) => (a, b) + case (_, _: ENode.Constructor) => (b, a) case (_: ENode.Atom, _) => (a, b) case (_, _: ENode.Atom) => (b, a) - case (_: ENode.New, _) => (a, b) - case (_, _: ENode.New) => (b, a) case (_: ENode.Select, _) => (a, b) case (_, _: ENode.Select) => (b, a) case (_: ENode.Apply, _) => (a, b) @@ -336,8 +390,6 @@ final class EGraph(rootCtx: Context): if aRepr eq bRepr then return assert(aRepr != bRepr, s"$aRepr and $bRepr are `equals` but not `eq`") - // TODO(mbovel): if both nodes are objects, recursively merge their arguments - /// Update represententOf and usedBy maps val (newRepr, oldRepr) = order(aRepr, bRepr) represententOf(oldRepr) = newRepr @@ -371,8 +423,9 @@ final class EGraph(rootCtx: Context): node match case ENode.Atom(tp) => singleton(tp) - case ENode.New(clazz) => - New(toTree(clazz, paramRefs)) + case ENode.Constructor(sym) => + val tycon = sym.owner.info.typeConstructor + New(tycon).select(TermRef(tycon, sym)) case ENode.Select(qual, member) => toTree(qual, paramRefs).select(member) case ENode.Apply(fn, args) => diff --git a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala index acbd4b1e5ca2..059ca8066093 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala @@ -23,7 +23,7 @@ enum ENode: import ENode.* case Atom(tp: SingletonType) - case New(clazz: ENode) + case Constructor(constr: Symbol)(val fields: List[Symbol]) case Select(qual: ENode, member: Symbol) case Apply(fn: ENode, args: List[ENode]) case OpApply(fn: ENode.Op, args: List[ENode]) @@ -33,7 +33,7 @@ enum ENode: override def toString(): String = this match case Atom(tp) => typeToString(tp) - case New(clazz) => s"new $clazz" + case Constructor(constr) => s"new ${designatorToString(constr.lastKnownDenotation.owner)}" case Select(qual, member) => s"$qual.${designatorToString(member)}" case Apply(fn, args) => s"$fn(${args.mkString(", ")})" case OpApply(op, args) => s"(${args.mkString(op.operatorString())})" diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala index 75e7fd6afa0d..ae4155e7409f 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala @@ -53,10 +53,10 @@ class QualifierSolver(using Context): case _ => () val egraph = EGraph(ctx) - //println(s"tree implies $tree1 -> $tree2") + // println(s"tree implies $tree1 -> $tree2") (egraph.toNode(QualifierEvaluator.evaluate(tree1)), egraph.toNode(QualifierEvaluator.evaluate(tree2))) match case (Some(node1), Some(node2)) => - //println(s"node implies $node1 -> $node2") + // println(s"node implies $node1 -> $node2") egraph.merge(node1, egraph.trueNode) egraph.repair() egraph.equiv(node2, egraph.trueNode) diff --git a/tests/neg-custom-args/qualified-types/adapt_neg.scala b/tests/neg-custom-args/qualified-types/adapt_neg.scala index 70fe964cebbc..8722ffcbb7b8 100644 --- a/tests/neg-custom-args/qualified-types/adapt_neg.scala +++ b/tests/neg-custom-args/qualified-types/adapt_neg.scala @@ -12,11 +12,11 @@ def test: Unit = val v3: {v: Int with v == x + 1} = x + 2 // error val v4: {v: Int with v == f(x)} = g(x) // error val v5: {v: Int with v == g(x)} = f(x) // error - //val v6: {v: Int with v == IntBox(x)} = IntBox(x) // Not implemented - //val v7: {v: Int with v == Box(x)} = Box(x) // Not implemented + val v6: {v: IntBox with v == IntBox(x)} = IntBox(x + 1) // error + val v7: {v: Box[Int] with v == Box(x)} = Box(x + 1) // error val v8: {v: Int with v == x + f(x)} = x + g(x) // error val v9: {v: Int with v == x + g(x)} = x + f(x) // error val v10: {v: Int with v == f(x + 1)} = f(x + 2) // error val v11: {v: Int with v == g(x + 1)} = g(x + 2) // error - //val v12: {v: Int with v == IntBox(x + 1)} = IntBox(x + 1) // Not implemented - //val v13: {v: Int with v == Box(x + 1)} = Box(x + 1) // Not implemented + val v12: {v: IntBox with v == IntBox(x + 1)} = IntBox(x) // error + val v13: {v: Box[Int] with v == Box(x + 1)} = Box(x) // error diff --git a/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala b/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala new file mode 100644 index 000000000000..b1283f4e807b --- /dev/null +++ b/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala @@ -0,0 +1,52 @@ +class Box[T](val x: T) + +class BoxMutable[T](var x: T) + +class Foo(val id: String): + def this(x: Int) = this(x.toString) + +class Person(val name: String, val age: Int) + +class PersonCurried(val name: String)(val age: Int) + +class PersonMutable(val name: String, var age: Int) + +case class PersonCaseMutable(name: String, var age: Int) + +case class PersonCaseSecondary(name: String, age: Int): + def this(name: String) = this(name, 0) + +case class PersonCaseEqualsOverriden(name: String, age: Int): + override def equals(that: Object): Boolean = this eq that + +def test: Unit = + summon[{b: Box[Int] with b == Box(1)} =:= {b: Box[Int] with b == Box(1)}] // error + + summon[{b: BoxMutable[Int] with b == BoxMutable(1)} =:= {b: BoxMutable[Int] with b == BoxMutable(1)}] // error + // TODO(mbovel): restrict selection to stable members + //summon[{b: BoxMutable[Int] with b.x == 3} =:= {b: BoxMutable[Int] with b.x == 3}] + + summon[{f: Foo with f == Foo("hello")} =:= {f: Foo with f == Foo("hello")}] // error + summon[{f: Foo with f == Foo(1)} =:= {f: Foo with f == Foo(1)}] // error + summon[{s: String with Foo("hello").id == s} =:= {s: String with s == "hello"}] // error + + summon[{p: Person with p == Person("Alice", 30)} =:= {p: Person with p == Person("Alice", 30)}] // error + summon[{s: String with Person("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error + summon[{n: Int with Person("Alice", 30).age == n} =:= {n: Int with n == 30}] // error + + summon[{p: PersonCurried with p == PersonCurried("Alice")(30)} =:= {p: PersonCurried with p == PersonCurried("Alice")(30)}] // error + summon[{s: String with PersonCurried("Alice")(30).name == s} =:= {s: String with s == "Alice"}] // error + summon[{n: Int with PersonCurried("Alice")(30).age == n} =:= {n: Int with n == 30}] // error + + summon[{p: PersonMutable with p == PersonMutable("Alice", 30)} =:= {p: PersonMutable with p == PersonMutable("Alice", 30)}] // error + summon[{s: String with PersonMutable("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error + summon[{n: Int with PersonMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error + + summon[{n: Int with PersonCaseMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error + + summon[{s: String with new PersonCaseSecondary("Alice").name == s} =:= {s: String with s == "Alice"}] // error + summon[{n: Int with new PersonCaseSecondary("Alice").age == n} =:= {n: Int with n == 0}] // error + + summon[{p: PersonCaseEqualsOverriden with PersonCaseEqualsOverriden("Alice", 30) == p} =:= {p: PersonCaseEqualsOverriden with p == PersonCaseEqualsOverriden("Alice", 30)}] // error + summon[{s: String with PersonCaseEqualsOverriden("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error + summon[{n: Int with PersonCaseEqualsOverriden("Alice", 30).age == n} =:= {n: Int with n == 30}] // error diff --git a/tests/pos-custom-args/qualified-types/adapt.scala b/tests/pos-custom-args/qualified-types/adapt.scala index 12fc30e9c80c..597c6906b19d 100644 --- a/tests/pos-custom-args/qualified-types/adapt.scala +++ b/tests/pos-custom-args/qualified-types/adapt.scala @@ -2,7 +2,6 @@ def f(x: Int): Int = ??? case class IntBox(x: Int) case class Box[T](x: T) - def f(x: Int, y: Int): {r: Int with r == x + y} = x + y def test: Unit = @@ -12,11 +11,11 @@ def test: Unit = val v1: {v: Int with v == x + 1} = x + 1 val v2: {v: Int with v == f(x)} = f(x) val v3: {v: Int with v == g(x)} = g(x) - //val v6: {v: Int with v == IntBox(x)} = IntBox(x) // Not implemented - //val v7: {v: Int with v == Box(x)} = Box(x) // Not implemented - val v4: {v: Int with v == x + f(x)} = x + f(x) - val v5: {v: Int with v == x + g(x)} = x + g(x) - val v6: {v: Int with v == f(x + 1)} = f(x + 1) - val v7: {v: Int with v == g(x + 1)} = g(x + 1) - //val v12: {v: Int with v == IntBox(x + 1)} = IntBox(x + 1) // Not implemented - //val v13: {v: Int with v == Box(x + 1)} = Box(x + 1) // Not implemented + val v4: {v: IntBox with v == IntBox(x)} = IntBox(x) + val v5: {v: Box[Int] with v == Box(x)} = Box(x) + val v6: {v: Int with v == x + f(x)} = x + f(x) + val v7: {v: Int with v == x + g(x)} = x + g(x) + val v8: {v: Int with v == f(x + 1)} = f(x + 1) + val v9: {v: Int with v == g(x + 1)} = g(x + 1) + val v12: {v: IntBox with v == IntBox(x + 1)} = IntBox(x + 1) + val v13: {v: Box[Int] with v == Box(x + 1)} = Box(x + 1) diff --git a/tests/pos-custom-args/qualified-types/sized_lists.scala b/tests/pos-custom-args/qualified-types/sized_lists.scala index 88b60db58563..385434c253f7 100644 --- a/tests/pos-custom-args/qualified-types/sized_lists.scala +++ b/tests/pos-custom-args/qualified-types/sized_lists.scala @@ -1,9 +1,6 @@ - - def size(v: Vec): Int = ??? type Vec - def vec(s: Int): {v: Vec with size(v) == s} = ??? def concat(v1: Vec, v2: Vec): {v: Vec with size(v) == size(v1) + size(v2)} = ??? def sum(v1: Vec, v2: Vec with size(v1) == size(v2)): {v: Vec with size(v) == size(v1)} = ??? diff --git a/tests/pos-custom-args/qualified-types/subtyping_objects.scala b/tests/pos-custom-args/qualified-types/subtyping_objects.scala new file mode 100644 index 000000000000..da3afae0ebd8 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/subtyping_objects.scala @@ -0,0 +1,57 @@ +class Box[T](val x: T) + +class Foo(val id: String): + def this(x: Int) = this(x.toString) + +case class PersonCase(name: String, age: Int) + +case class PersonCaseCurried(name: String)(val age: Int) + +case class PersonCaseMutable(name: String, var age: Int) + +case class PersonCaseSecondary(name: String, age: Int): + def this(name: String) = this(name, 0) + +def test: Unit = + summon[{b: Box[Int] with 3 == b.x} =:= {b: Box[Int] with b.x == 3}] + summon[{f: Foo with f.id == "hello"} =:= {f: Foo with "hello" == f.id}] + + // new PersonCase + summon[{p: PersonCase with p == new PersonCase("Alice", 30)} =:= {p: PersonCase with p == new PersonCase("Alice", 30)}] + summon[{s: String with new PersonCase("Alice", 30).name == s} =:= {s: String with s == "Alice"}] + summon[{n: Int with new PersonCase("Alice", 30).age == n} =:= {n: Int with n == 30}] + + // PersonCase + summon[{p: PersonCase with p == PersonCase("Alice", 30)} =:= {p: PersonCase with p == PersonCase("Alice", 30)}] + summon[{s: String with PersonCase("Alice", 30).name == s} =:= {s: String with s == "Alice"}] + summon[{n: Int with PersonCase("Alice", 30).age == n} =:= {n: Int with n == 30}] + + // new PersonCaseCurried + summon[{p: PersonCaseCurried with p == new PersonCaseCurried("Alice")(30)} =:= {p: PersonCaseCurried with p == new PersonCaseCurried("Alice")(30)}] + summon[{s: String with new PersonCaseCurried("Alice")(30).name == s} =:= {s: String with s == "Alice"}] + summon[{n: Int with new PersonCaseCurried("Alice")(30).age == n} =:= {n: Int with n == 30}] + + // PersonCaseCurried + summon[{p: PersonCaseCurried with p == PersonCaseCurried("Alice")(30)} =:= {p: PersonCaseCurried with p == PersonCaseCurried("Alice")(30)}] + summon[{s: String with PersonCaseCurried("Alice")(30).name == s} =:= {s: String with s == "Alice"}] + summon[{n: Int with PersonCaseCurried("Alice")(30).age == n} =:= {n: Int with n == 30}] + + // new PersonCaseMutable + summon[{p: PersonCaseMutable with p == new PersonCaseMutable("Alice", 30)} =:= {p: PersonCaseMutable with p == new PersonCaseMutable("Alice", 30)}] + summon[{s: String with new PersonCaseMutable("Alice", 30).name == s} =:= {s: String with s == "Alice"}] + //summon[{n: Int with new PersonCaseMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error + + // PersonCaseMutable + summon[{p: PersonCaseMutable with p == PersonCaseMutable("Alice", 30)} =:= {p: PersonCaseMutable with p == PersonCaseMutable("Alice", 30)}] + summon[{s: String with PersonCaseMutable("Alice", 30).name == s} =:= {s: String with s == "Alice"}] + //summon[{n: Int with PersonCaseMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error + + // new PersonCaseSecondary + summon[{p: PersonCaseSecondary with p == new PersonCaseSecondary("Alice")} =:= {p: PersonCaseSecondary with p == new PersonCaseSecondary("Alice")}] + //summon[{s: String with new PersonCaseSecondary("Alice").name == s} =:= {s: String with s == "Alice"}] // error + //summon[{n: Int with new PersonCaseSecondary("Alice").age == n} =:= {n: Int with n == 0}] // error + + // PersonCaseSecondary + summon[{p: PersonCaseSecondary with p == PersonCaseSecondary("Alice", 30)} =:= {p: PersonCaseSecondary with p == PersonCaseSecondary("Alice", 30)}] + //summon[{s: String with PersonCaseSecondary("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error + //summon[{n: Int with PersonCaseSecondary("Alice", 30).age == n} =:= {n: Int with n == 30}] // error From d367dd95ca6beccdb4116d732964f49febdd8d93 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 21 Aug 2025 08:29:39 +0000 Subject: [PATCH 13/20] Encode qualifier arguments as E-Nodes --- .../src/dotty/tools/dotc/ast/Desugar.scala | 3 +- .../dotty/tools/dotc/core/Definitions.scala | 2 + .../dotty/tools/dotc/core/TypeComparer.scala | 9 +- .../src/dotty/tools/dotc/core/Types.scala | 6 +- .../dotty/tools/dotc/parsing/Parsers.scala | 2 +- .../tools/dotc/printing/PlainPrinter.scala | 5 + .../tools/dotc/qualified_types/EGraph.scala | 584 ++++++++-------- .../tools/dotc/qualified_types/ENode.scala | 636 ++++++++++++++++-- .../dotc/qualified_types/ENodeParamRef.scala | 23 + .../qualified_types/QualifiedAnnotation.scala | 39 ++ .../dotc/qualified_types/QualifiedType.scala | 35 +- .../dotc/qualified_types/QualifiedTypes.scala | 82 ++- .../qualified_types/QualifierEvaluator.scala | 96 --- .../qualified_types/QualifierSolver.scala | 116 ++-- .../tools/dotc/transform/PostTyper.scala | 15 +- .../tools/dotc/transform/TreeExtractors.scala | 9 + .../tools/dotc/transform/TypeTestsCasts.scala | 16 +- .../dotty/tools/dotc/typer/TypeAssigner.scala | 6 +- .../src/dotty/tools/dotc/typer/Typer.scala | 4 +- .../dotc/qualified_types/EGraphTest.scala | 235 +++++++ .../dotc/qualified_types/ENodeTest.scala | 60 ++ .../qualified_types/QualifiedTypesTest.scala | 41 ++ .../qualified-types/list_apply_neg.scala | 4 + .../subtyping_egraph_state.scala | 7 + .../subtyping_objects_neg.scala | 16 +- .../qualified-types/avoidance.scala | 9 +- .../qualified-types/class_constraints.scala | 3 + .../qualified-types/list_map.scala | 9 + .../qualified-types/sized_lists.scala | 4 - .../qualified-types/sized_lists2.scala | 22 + ...alities.scala => subtyping_equality.scala} | 13 + .../qualified-types/subtyping_lambdas.scala | 18 +- .../subtyping_reflectivity.scala | 4 + .../subtyping_singletons.scala | 2 - .../typing_type_variables.scala | 9 + tests/printing/qualified-types.check | 16 +- tests/printing/qualified-types.scala | 9 +- .../qualified-types/pattern_matching.scala | 9 - 38 files changed, 1551 insertions(+), 627 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/ENodeParamRef.scala create mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala delete mode 100644 compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala create mode 100644 compiler/test/dotty/tools/dotc/qualified_types/EGraphTest.scala create mode 100644 compiler/test/dotty/tools/dotc/qualified_types/ENodeTest.scala create mode 100644 compiler/test/dotty/tools/dotc/qualified_types/QualifiedTypesTest.scala create mode 100644 tests/neg-custom-args/qualified-types/list_apply_neg.scala create mode 100644 tests/neg-custom-args/qualified-types/subtyping_egraph_state.scala create mode 100644 tests/pos-custom-args/qualified-types/class_constraints.scala create mode 100644 tests/pos-custom-args/qualified-types/list_map.scala create mode 100644 tests/pos-custom-args/qualified-types/sized_lists2.scala rename tests/pos-custom-args/qualified-types/{subtyping_equalities.scala => subtyping_equality.scala} (73%) create mode 100644 tests/pos-custom-args/qualified-types/subtyping_reflectivity.scala create mode 100644 tests/pos-custom-args/qualified-types/typing_type_variables.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index b34867918397..c028dd896bb7 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -2561,7 +2561,8 @@ object desugar { tree if Feature.qualifiedTypesEnabled then - transform(tpt) + trace(i"desugar qualified types in pattern: $tpt", Printers.qualifiedTypes): + transform(tpt) else tpt diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 9aa1591c4f85..9eb869af6efc 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -625,9 +625,11 @@ class Definitions { @tu lazy val Int_/ : Symbol = IntClass.requiredMethod(nme.DIV, List(IntType)) @tu lazy val Int_* : Symbol = IntClass.requiredMethod(nme.MUL, List(IntType)) @tu lazy val Int_== : Symbol = IntClass.requiredMethod(nme.EQ, List(IntType)) + @tu lazy val Int_!= : Symbol = IntClass.requiredMethod(nme.NE, List(IntType)) @tu lazy val Int_>= : Symbol = IntClass.requiredMethod(nme.GE, List(IntType)) @tu lazy val Int_<= : Symbol = IntClass.requiredMethod(nme.LE, List(IntType)) @tu lazy val Int_> : Symbol = IntClass.requiredMethod(nme.GT, List(IntType)) + @tu lazy val Int_< : Symbol = IntClass.requiredMethod(nme.LT, List(IntType)) @tu lazy val LongType: TypeRef = valueTypeRef("scala.Long", java.lang.Long.TYPE, LongEnc, nme.specializedTypeNames.Long) def LongClass(using Context): ClassSymbol = LongType.symbol.asClass @tu lazy val Long_+ : Symbol = LongClass.requiredMethod(nme.PLUS, List(LongType)) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 91d6c6f6de19..e62eedb5176c 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -889,7 +889,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling throw ex compareCapturing || fourthTry case QualifiedType(parent2, qualifier2) => - QualifiedTypes.typeImplies(tp1, qualifier2) && recur(tp1, parent2) + recur(tp1, parent2) && QualifiedTypes.typeImplies(tp1, qualifier2, qualifierSolver()) case tp2: AnnotatedType if tp2.isRefining => (tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || tp1.isBottomType) && recur(tp1, tp2.parent) @@ -3310,6 +3310,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling protected def explainingTypeComparer(short: Boolean) = ExplainingTypeComparer(comparerContext, short) protected def matchReducer = MatchReducer(comparerContext) + protected def qualifierSolver() = qualified_types.QualifierSolver(using comparerContext) private def inSubComparer[T, Cmp <: TypeComparer](comparer: Cmp)(op: Cmp => T): T = val saved = myInstance @@ -3965,7 +3966,7 @@ class ExplainingTypeComparer(initctx: Context, short: Boolean) extends TypeCompa lastForwardGoal = null override def traceIndented[T](str: String)(op: => T): T = - val str1 = str.replace('\n', ' ') + val str1 = str if short && str1 == lastForwardGoal then op // repeated goal, skip for clarity else @@ -4034,5 +4035,9 @@ class ExplainingTypeComparer(initctx: Context, short: Boolean) extends TypeCompa super.subCaptures(refs1, refs2, vs) } + override def qualifierSolver() = + val traceIndented0 = [T] => (message: String) => traceIndented[T](message) + qualified_types.ExplainingQualifierSolver(traceIndented0)(using comparerContext) + def lastTrace(header: String): String = header + { try b.toString finally b.clear() } } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 6636da640485..468a8b0009a0 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -43,7 +43,6 @@ import cc.* import CaptureSet.IdentityCaptRefMap import Capabilities.* import transform.Recheck.currentRechecker - import qualified_types.QualifiedType import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -6320,6 +6319,8 @@ object Types extends TypeUtils { tp.derivedAnnotatedType(underlying, annot) protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type = tp.derivedCapturingType(parent, refs) + protected def derivedENodeParamRef(tp: qualified_types.ENodeParamRef, index: Int, underlying: Type): Type = + tp.derivedENodeParamRef(index, underlying) protected def derivedWildcardType(tp: WildcardType, bounds: Type): Type = tp.derivedWildcardType(bounds) protected def derivedSkolemType(tp: SkolemType, info: Type): Type = @@ -6544,6 +6545,9 @@ object Types extends TypeUtils { case tp: JavaArrayType => derivedJavaArrayType(tp, this(tp.elemType)) + case tp: qualified_types.ENodeParamRef => + derivedENodeParamRef(tp, tp.index, this(tp.underlying)) + case _ => tp } diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 246a20f545b8..84ab0d70c545 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1702,7 +1702,7 @@ object Parsers { * | `(' [ FunArgType {`,' FunArgType } ] `)' * | '(' [ TypedFunParam {',' TypedFunParam } ')' * MatchType ::= InfixType `match` <<< TypeCaseClauses >>> - * QualifiedType2 ::= InfixType `with` PostfixExprf + * QualifiedType2 ::= InfixType `with` PostfixExpr * IntoType ::= [‘into’] IntoTargetType * | ‘( IntoType ‘)’ * IntoTargetType ::= Type diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 90dfb72ef010..9dee04158318 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -469,6 +469,11 @@ class PlainPrinter(_ctx: Context) extends Printer { "<" ~ reprStr ~ ":" ~ toText(tp.info) ~ ">" else reprStr + case qualified_types.ENodeParamRef(index, underlying) => + if ctx.settings.XprintTypes.value then + "<" ~ "eparam" ~ index.toString ~ ":" ~ toText(underlying) ~ ">" + else + "eparam" ~ index.toString } } diff --git a/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala index d3c90589f0f7..1efbac3058dc 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala @@ -33,6 +33,7 @@ import dotty.tools.dotc.core.StdNames.nme import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} import dotty.tools.dotc.core.Types.{ AppliedType, + CachedConstantType, CachedProxyType, ConstantType, LambdaType, @@ -48,215 +49,300 @@ import dotty.tools.dotc.core.Types.{ TypeVar, ValueType } +import dotty.tools.dotc.core.Uniques import dotty.tools.dotc.qualified_types.ENode.Op -import dotty.tools.dotc.reporting.trace import dotty.tools.dotc.transform.TreeExtractors.BinaryOp +import dotty.tools.dotc.util.{EqHashMap, HashMap} import dotty.tools.dotc.util.Spans.Span +import dotty.tools.dotc.reporting +import dotty.tools.dotc.config.Printers -final class EGraph(rootCtx: Context): +import annotation.threadUnsafe as tu +import reflect.ClassTag + +final class EGraph(_ctx: Context, checksEnabled: Boolean = true): + + /** Cache for unique E-Nodes + * + * Invariant: Each key is `eq` to its associated value. + * + * Invariant: If a node is in this map, then its children also are. + */ + private val index: HashMap[ENode, ENode] = HashMap() + + + private val idOf: EqHashMap[ENode, Int] = EqHashMap() + + /** Map from nodes to their unique, canonical representations. + * + * Invariant: After a call to [[repair]], if a node is in the index but not + * in this map, then it is its own representant and it is canonical. + * + * Invariant: After a call to [[repair]], values of this map are canonical. + */ + private val representantOf: EqHashMap[ENode, ENode] = EqHashMap() + + /** Map from child nodes to their parent nodes + * + * Invariant: After a call to [[repair]], values of this map are canonical. + */ + private val usedBy: EqHashMap[ENode, mutable.Set[ENode]] = EqHashMap() + + /** Worklist for nodes that need to be repairedConstantType(Constant(value). + * + * This queue is filled by [[merge]] and processed by [[repair]]. + * + * Invariant: After a call to [[repair]], this queue is empty. + */ + private val worklist = mutable.Queue.empty[ENode] - private val represententOf = mutable.Map.empty[ENode, ENode] + val trueNode: ENode.Atom = constant(true) + val falseNode: ENode.Atom = constant(false) + val zeroIntNode: ENode.Atom = constant(0) + val minusOneIntNode: ENode.Atom = constant(-1) + val oneIntNode: ENode.Atom = constant(1) + + /** Returns the canonical node for the given constant value */ + def constant(value: Any): ENode.Atom = + val node = ENode.Atom(ConstantType(Constant(value))(using _ctx)) + idOf.getOrElseUpdate(node, idOf.size) + index.getOrElseUpdate(node, node).asInstanceOf[ENode.Atom] + + /** Adds the given node to the E-Graph, returning its canonical representant. + * + * Pre-condition: The node must be normalized, and its children must be + * canonical. + */ + private def unique(node: ENode): ENode = + if index.contains(node) then + representant(index(node)) + else + index.update(node, node) + idOf.update(node, idOf.size) + node match + case ENode.Atom(tp) => + () + case ENode.Constructor(sym) => + () + case ENode.Select(qual, member) => + addUse(qual, node) + case ENode.Apply(fn, args) => + addUse(fn, node) + for arg <- args do + addUse(arg, node) + case ENode.OpApply(op, args) => + for arg <- args do + addUse(arg, node) + case ENode.TypeApply(fn, args) => + addUse(fn, node) + case ENode.Lambda(paramTps, retTp, body) => + addUse(body, node) + node + node - private def representent(node: ENode): ENode = - represententOf.get(node) match + private def representant(node: ENode): ENode = + representantOf.get(node) match case None => node case Some(repr) => - assert(repr ne node, s"Node $node has itself as representent") - representent(repr) + // There must be no cycles in the `representantOf` map. + // If a node is canonical, it must have no representant. + assert(repr ne node, s"Node $node has itself as representant ($repr)") + representant(repr) + + def assertCanonical(node: ENode): Unit = + if checksEnabled then + // By the invariants, if a node is in the index (meaning it is tracked by + // this E-Graph), and has no representant, then it is itself a canonical + // node. We double-check by forcing a deep canonicalization. + assert(index.contains(node) && index(node) == node, s"Node $node is not unique in this E-Graph") + assert(!representantOf.contains(node), s"Node $node has a representant: ${representantOf(node)}") + val canonical = canonicalize(node) + assert(node eq canonical, s"Recanonicalization of $node did not return itself, but $canonical") + + private def addUse(child: ENode, parent: ENode): Unit = + usedBy.getOrElseUpdate(child, mutable.Set.empty) += parent - /** Map from child nodes to their parent nodes */ - private val usedBy = mutable.Map.empty[ENode, mutable.Set[ENode]] + override def toString(): String = + s"EGraph{\nindex = $index,\nrepresentantOf = $representantOf,\nusedBy = $usedBy,\nworklist = $worklist}\n" + + def toDot()(using Context): String = + val sb = new StringBuilder() + sb.append("digraph EGraph {\nnode [height=.1 shape=record]\n") + for node <- index.valuesIterator do + sb.append(node.toDot()) + for (node, repr) <- representantOf.iterator do + sb.append(s"${node.dotId()} -> ${repr.dotId()} [style=dotted]\n") + for (child, parents) <- usedBy.iterator do + for parent <- parents do + sb.append(s"${child.dotId()} -> ${parent.dotId()} [style=dashed]\n") + sb.append("}\n") + sb.toString() + + def debugString()(using _ctx: Context): String = + given Context = _ctx.withoutColors + index + .valuesIterator + .toList + .groupBy(representant) + .toList + .sortBy((repr, members) => repr.showNoBreak) + .map((repr, members) => repr.showNoBreak + ": " + members.filter(_ ne repr).map(_.showNoBreak).sorted.mkString("{", ", ", "}")) + .mkString("", "\n", "\n") + + + private inline def show(enode: ENode): String = + enode.showNoBreak(using _ctx) + + private inline def trace[T](inline message: String)(inline f: T): T = + reporting.trace(message, Printers.qualifiedTypes)(f)(using _ctx) + + def equiv(node1: ENode, node2: ENode): Boolean = + trace(s"equiv ${show(node1)}, ${show(node2)}"): + val repr1 = representant(node1) + val repr2 = representant(node2) + repr1 eq repr2 - private def uses(node: ENode): mutable.Set[ENode] = - usedBy.getOrElseUpdate(node, mutable.Set.empty) + def merge(a: ENode, b: ENode): Unit = + if checksEnabled then + assert(index.contains(a) && index(a) == a, s"Node $a is not unique in this E-Graph") + assert(index.contains(b) && index(b) == b, s"Node $b is not unique in this E-Graph") - private def addUse(node: ENode, parent: ENode): Unit = - require(!represententOf.contains(node), s"Reference $node is not normalized") - uses(node) += parent + val aRepr = representant(a) + val bRepr = representant(b) - /** Map used for hash-consing nodes, keys and values are the same */ - private val index = mutable.Map.empty[ENode, ENode] + if aRepr eq bRepr then return - final val trueNode: ENode.Atom = ENode.Atom(ConstantType(Constant(true))(using rootCtx)) - index(trueNode) = trueNode + if checksEnabled then + assert(aRepr != bRepr, s"$aRepr and $bRepr are `equals` but not `eq`") - final val falseNode: ENode.Atom = ENode.Atom(ConstantType(Constant(false))(using rootCtx)) - index(falseNode) = falseNode + // Update representantOf and usedBy maps + val (newRepr, oldRepr) = order(aRepr, bRepr) + representantOf(oldRepr) = newRepr + val oldusages = usedBy.getOrElse(oldRepr, mutable.Set.empty) + usedBy.getOrElseUpdate(newRepr, mutable.Set.empty) ++= oldusages + usedBy.remove(oldRepr) - final val minusOneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(-1))(using rootCtx)) - index(minusOneIntNode) = minusOneIntNode + trace(s"merge ${show(newRepr)} <-- ${show(oldRepr)}"): + // Propagate truth values over disjunctions, conjunctions and equalities + oldRepr match + case ENode.OpApply(Op.And, args) if newRepr eq trueNode => + args.foreach(merge(_, trueNode)) + case ENode.OpApply(Op.Or, args) if newRepr eq falseNode => + args.foreach(merge(_, falseNode)) + case ENode.OpApply(Op.Equal, args) if newRepr eq trueNode => + merge(args(0), args(1)) + case _ => + () - final val zeroIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(0))(using rootCtx)) - index(zeroIntNode) = zeroIntNode + // Enqueue all nodes that use the oldRepr for repair + trace(s"enqueue ${oldusages.map(show).mkString(", ")}"): + worklist.enqueueAll(oldusages) + () - final val oneIntNode: ENode.Atom = ENode.Atom(ConstantType(Constant(1))(using rootCtx)) - index(oneIntNode) = oneIntNode + private def order(a: ENode, b: ENode): (ENode, ENode) = + if a.contains(b) then + (b, a) + else if b.contains(a) then + (a, b) + else + (a, b) match + case (ENode.Atom(_: ConstantType), _) => (a, b) + case (_, ENode.Atom(_: ConstantType)) => (b, a) + case (_: ENode.OpApply, _) => (a, b) + case (_, _: ENode.OpApply) => (b, a) + case (_: ENode.Constructor, _) => (a, b) + case (_, _: ENode.Constructor) => (b, a) + case (_: ENode.Select, _) => (a, b) + case (_, _: ENode.Select) => (b, a) + case (_: ENode.Apply, _) => (a, b) + case (_, _: ENode.Apply) => (b, a) + case (_: ENode.TypeApply, _) => (a, b) + case (_, _: ENode.TypeApply) => (b, a) + case (_: ENode.Atom, _) => (a, b) + case (_, _: ENode.Atom) => (b, a) + case _ => (a, b) - private val d = defn(using rootCtx) // Need a stable path to match on `defn` members - private val builtinOps = Map( - d.Int_== -> Op.Equal, - d.Boolean_== -> Op.Equal, - d.String_== -> Op.Equal, - d.Boolean_&& -> Op.And, - d.Boolean_|| -> Op.Or, - d.Boolean_! -> Op.Not, - d.Int_+ -> Op.IntSum, - d.Int_* -> Op.IntProduct - ) + def repair(): Unit = + var i = 0 + trace(s"repair (queue: ${worklist.map(show).mkString(", ")})"): + while !worklist.isEmpty do + val head = worklist.dequeue() + val headRepr = representant(head) + val headCanonical = canonicalize(head, deep = false) + if headRepr ne headCanonical then + trace(s"repair ${show(headCanonical)}, ${show(headRepr)}"): + merge(headCanonical, headRepr) + i += 1 + if i > 100 then + throw new RuntimeException("EGraph.repair: too many iterations, possible infinite loop") + + assertInvariants() + + def assertInvariants(): Unit = + if checksEnabled then + assert(worklist.isEmpty, "Worklist is not empty") + + // Check that all nodes in the index are canonical + for (node, node2) <- index.iterator do + assert(node eq node2, s"Key and value in index are not equal: $node ne $node2") + + val repr = representant(node) + assertCanonical(repr) + + def usages(node: ENode): mutable.Set[ENode] = + usedBy.getOrElse(node, mutable.Set.empty) - private val worklist = mutable.Queue.empty[ENode] + node match + case ENode.Atom(tp) => () + case ENode.Constructor(sym) => () + case ENode.Select(qual, member) => + index.contains(qual) && usages(qual).contains(node) + case ENode.Apply(fn, args) => + index.contains(fn) && usages(fn).contains(node) + args.forall(arg => index.contains(arg) && usages(arg).contains(node)) + case ENode.OpApply(op, args) => + args.forall(arg => index.contains(arg) && usages(arg).contains(node)) + case ENode.TypeApply(fn, args) => + index.contains(fn) && usages(fn).contains(node) + case ENode.Lambda(paramTps, retTp, body) => + index.contains(body) && usages(body).contains(node) - override def toString(): String = - val represententsString = represententOf.map((node, repr) => s" $node -> $repr").mkString("\n") - s"EGraph:\n$represententsString\n" - - def equiv(node1: ENode, node2: ENode)(using Context): Boolean = - trace(i"EGraph.equiv", Printers.qualifiedTypes): - // val margin = ctx.base.indentTab * (ctx.base.indent) - // println(s"$margin node1: $node1\n$margin node2: $node2") - val repr1 = representent(node1) - val repr2 = representent(node2) - repr1 eq repr2 + for (node, repr) <- representantOf.iterator do + assert(index.contains(node), s"Node $node is not in the index") + + for (child, parents) <- usedBy.iterator do + assertCanonical(child) - private def unique(node: ENode): node.type = - index.getOrElseUpdate( - node, { + // ----------------------------------- + // Canonicalization + // ----------------------------------- + + def canonicalize(node: ENode, deep: Boolean = true): ENode = + def recur(node: ENode): ENode = + if deep then canonicalize(node, deep) else representant(node) + trace(s"canonicalize ${show(node)}"): + representant(unique( node match case ENode.Atom(tp) => - () + node case ENode.Constructor(sym) => - () + node case ENode.Select(qual, member) => - addUse(qual, node) + normalizeSelect(recur(qual), member) case ENode.Apply(fn, args) => - addUse(fn, node) - args.foreach(addUse(_, node)) + ENode.Apply(recur(fn), args.map(recur)) case ENode.OpApply(op, args) => - args.foreach(addUse(_, node)) + normalizeOp(op, args.map(recur)) case ENode.TypeApply(fn, args) => - addUse(fn, node) + ENode.TypeApply(recur(fn), args) case ENode.Lambda(paramTps, retTp, body) => - addUse(body, node) - node - } - ).asInstanceOf[node.type] - - // TODO(mbovel): Memoize this - def toNode(tree: Tree, paramSyms: List[Symbol] = Nil, paramTps: List[ENode.ArgRefType] = Nil)(using - Context - ): Option[ENode] = - trace(i"EGraph.toNode $tree", Printers.qualifiedTypes): - computeToNode(tree, paramSyms, paramTps).map(node => representent(unique(node))) - - private def computeToNode( - tree: Tree, - paramSyms: List[Symbol] = Nil, - paramTps: List[ENode.ArgRefType] = Nil - )(using currentCtx: Context): Option[ENode] = - trace(i"ENode.computeToNode $tree", Printers.qualifiedTypes): - def normalizeType(tp: Type): Type = - tp match - case tp: TypeVar if tp.isPermanentlyInstantiated => - tp.permanentInst - case tp: NamedType => - if tp.symbol.isStatic then tp.symbol.termRef - else normalizeType(tp.prefix).select(tp.symbol) - case tp => tp - - def mapType(tp: Type): Type = - normalizeType(tp.subst(paramSyms, paramTps)) - - tree match - case Literal(_) | Ident(_) | This(_) if tree.tpe.isInstanceOf[SingletonType] => - Some(ENode.Atom(mapType(tree.tpe).asInstanceOf[SingletonType])) - case Select(New(_), nme.CONSTRUCTOR) => - constructorNode(tree.symbol) - case tree: Select if isCaseClassApply(tree.symbol) => - constructorNode(tree.symbol.owner.linkedClass.primaryConstructor) - case Select(qual, name) => - for qualNode <- toNode(qual, paramSyms, paramTps) yield normalizeSelect(qualNode, tree.symbol) - case BinaryOp(lhs, op, rhs) if builtinOps.contains(op) => - for - lhsNode <- toNode(lhs, paramSyms, paramTps) - rhsNode <- toNode(rhs, paramSyms, paramTps) - yield normalizeOp(builtinOps(op), List(lhsNode, rhsNode)) - case BinaryOp(lhs, d.Int_-, rhs) => - for - lhsNode <- toNode(lhs, paramSyms, paramTps) - rhsNode <- toNode(rhs, paramSyms, paramTps) - yield normalizeOp(Op.IntSum, List(lhsNode, normalizeOp(Op.IntProduct, List(minusOneIntNode, rhsNode)))) - case Apply(fun, args) => - for - funNode <- toNode(fun, paramSyms, paramTps) - argsNodes <- args.map(toNode(_, paramSyms, paramTps)).sequence - yield ENode.Apply(funNode, argsNodes) - case TypeApply(fun, args) => - for funNode <- toNode(fun, paramSyms, paramTps) - yield ENode.TypeApply(funNode, args.map(tp => mapType(tp.tpe))) - case closureDef(defDef) => - defDef.symbol.info.dealias match - case mt: MethodType => - assert(defDef.termParamss.size == 1, "closures have a single parameter list, right?") - val myParamSyms: List[Symbol] = defDef.termParamss.head.map(_.symbol) - val myParamTps: ListBuffer[ENode.ArgRefType] = ListBuffer.empty - val paramTpsSize = paramTps.size - for myParamSym <- myParamSyms do - val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList)) - myParamTps += ENode.ArgRefType(paramTpsSize + myParamTps.size, underlying) - val myRetTp = mapType(defDef.tpt.tpe.subst(myParamSyms, myParamTps.toList)) - for body <- toNode(defDef.rhs, myParamSyms ::: paramSyms, myParamTps.toList ::: paramTps) - yield ENode.Lambda(myParamTps.toList, myRetTp, body) - case _ => None - case _ => - None - - // TODO(mbovel): Memoize this - private def constructorNode(constr: Symbol)(using Context): Option[ENode.Constructor] = - val clazz = constr.owner - if hasStructuralEquality(clazz) then - val isPrimaryConstructor = constr.denot.isPrimaryConstructor - val fieldsRaw = clazz.denot.asClass.paramAccessors.filter(isPrimaryConstructor && _.isStableMember) - val constrParams = constr.paramSymss.flatten.filter(_.isTerm) - val fields = constrParams.map(p => fieldsRaw.find(_.name == p.name).getOrElse(NoSymbol)) - Some(ENode.Constructor(constr)(fields)) - else - None - - private def hasStructuralEquality(clazz: Symbol)(using Context): Boolean = - val equalsMethod = clazz.info.decls.lookup(nme.equals_) - val equalsNotOverriden = !equalsMethod.exists || equalsMethod.is(Flags.Synthetic) - clazz.isClass && clazz.is(Flags.Case) && equalsNotOverriden - - private def isCaseClassApply(meth: Symbol)(using Context): Boolean = - meth.name == nme.apply - && meth.flags.is(Flags.Synthetic) - && meth.owner.linkedClass.is(Flags.Case) - - private def canonicalize(node: ENode): ENode = - // println(s"canonicalize $node") - representent(unique( - node match - case ENode.Atom(tp) => - node - case ENode.Constructor(sym) => - node - case ENode.Select(qual, member) => - normalizeSelect(representent(qual), member) - case ENode.Apply(fn, args) => - ENode.Apply(representent(fn), args.map(representent)) - case ENode.OpApply(op, args) => - normalizeOp(op, args.map(representent)) - case ENode.TypeApply(fn, args) => - ENode.TypeApply(representent(fn), args) - case ENode.Lambda(paramTps, retTp, body) => - ENode.Lambda(paramTps, retTp, representent(body)) - )) + ENode.Lambda(paramTps, retTp, recur(body)) + )) private def normalizeSelect(qual: ENode, member: Symbol): ENode = getAppliedConstructor(qual) match case Some(constr) => val memberIndex = constr.fields.indexOf(member) - if memberIndex >= 0 then val args = getTermArguments(qual) assert(args.size == constr.fields.size) @@ -268,24 +354,24 @@ final class EGraph(rootCtx: Context): private def getAppliedConstructor(node: ENode): Option[ENode.Constructor] = node match - case ENode.Apply(fn, args) => getAppliedConstructor(fn) + case ENode.Apply(fn, args) => getAppliedConstructor(fn) case ENode.TypeApply(fn, args) => getAppliedConstructor(fn) - case node: ENode.Constructor => Some(node) - case _ => None + case node: ENode.Constructor => Some(node) + case _ => None private def getTermArguments(node: ENode): List[ENode] = node match - case ENode.Apply(fn, args) => getTermArguments(fn) ::: args + case ENode.Apply(fn, args) => getTermArguments(fn) ::: args case ENode.TypeApply(fn, args) => getTermArguments(fn) - case _ => Nil + case _ => Nil private def normalizeOp(op: ENode.Op, args: List[ENode]): ENode = - op match + val res = op match case Op.Equal => assert(args.size == 2, s"Expected 2 arguments for equality, got $args") if args(0) eq args(1) then trueNode - else ENode.OpApply(op, args.sortBy(_.hashCode())) + else ENode.OpApply(op, args.sortBy(idOf.apply)) case Op.And => assert(args.size == 2, s"Expected 2 arguments for conjunction, got $args") if (args(0) eq falseNode) || (args(1) eq falseNode) then falseNode @@ -298,12 +384,31 @@ final class EGraph(rootCtx: Context): else if args(0) eq falseNode then args(1) else if args(1) eq falseNode then args(0) else ENode.OpApply(op, args) - case Op.IntProduct => - val (consts, nonConsts) = decomposeIntProduct(args) - makeIntProduct(consts, nonConsts) case Op.IntSum => val (const, nonConsts) = decomposeIntSum(args) makeIntSum(const, nonConsts) + case Op.IntMinus => + assert(args.size == 2, s"Expected 2 arguments for subtraction, got $args") + // Rewrite a - b as a + (-1) * b + val lhs = args(0) + val rhs = args(1) + val negativeRhs = unique(normalizeOp(Op.IntProduct, List(minusOneIntNode, rhs))) + normalizeOp(Op.IntSum, List(lhs, negativeRhs)) + case Op.IntProduct => + val (consts, nonConsts) = decomposeIntProduct(args) + makeIntProduct(consts, nonConsts) + case Op.IntLessThan => constFoldBinaryOp[Int, Boolean](op, args, _ < _) + case Op.IntLessEqual => constFoldBinaryOp[Int, Boolean](op, args, _ <= _) + case Op.IntGreaterThan => constFoldBinaryOp[Int, Boolean](op, args, _ > _) + case Op.IntGreaterEqual => constFoldBinaryOp[Int, Boolean](op, args, _ >= _) + case _ => + ENode.OpApply(op, args) + res + + private def constFoldBinaryOp[T: ClassTag, S](op: ENode.Op, args: List[ENode], fn: (T, T) => S): ENode = + args match + case List(ENode.Atom(ConstantType(Constant(c1: T))), ENode.Atom(ConstantType(Constant(c2: T)))) => + constant(fn(c1, c2)) case _ => ENode.OpApply(op, args) @@ -311,12 +416,12 @@ final class EGraph(rootCtx: Context): val factors = args.flatMap: case ENode.OpApply(Op.IntProduct, innerFactors) => innerFactors - case arg => List(arg) + case arg => List(arg) val (consts, nonConsts) = factors.partitionMap: case ENode.Atom(ConstantType(Constant(c: Int))) => Left(c) - case factor => Right(factor) - (consts.product, nonConsts.sortBy(_.hashCode())) + case factor => Right(factor) + (consts.product, nonConsts.sortBy(idOf.apply)) private def makeIntProduct(const: Int, nonConsts: List[ENode]): ENode = if const == 0 then @@ -326,12 +431,15 @@ final class EGraph(rootCtx: Context): else if nonConsts.size == 1 then nonConsts.head else ENode.OpApply(Op.IntProduct, nonConsts) else - val constNode = unique(ENode.Atom(ConstantType(Constant(const))(using rootCtx))) + val constNode = constant(const) nonConsts match case Nil => constNode - case List(ENode.OpApply(Op.IntSum, summands)) => - ENode.OpApply(Op.IntSum, summands.map(summand => normalizeOp(Op.IntProduct, List(constNode, summand)))) + //case List(ENode.OpApply(Op.IntSum, summands)) => + // ENode.OpApply( + // Op.IntSum, + // summands.map(summand => unique(makeIntProduct(const, List(summand)))) + // ) case _ => ENode.OpApply(Op.IntProduct, constNode :: nonConsts) @@ -339,23 +447,23 @@ final class EGraph(rootCtx: Context): val summands: List[ENode] = args.flatMap: case ENode.OpApply(Op.IntSum, innerSummands) => innerSummands - case arg => List(arg) + case arg => List(arg) val decomposed: List[(Int, List[ENode])] = summands.map: case ENode.OpApply(Op.IntProduct, args) => args match case ENode.Atom(ConstantType(Constant(const: Int))) :: nonConsts => (const, nonConsts) - case nonConsts => (1, nonConsts) + case nonConsts => (1, nonConsts) case ENode.Atom(ConstantType(Constant(const: Int))) => (const, Nil) - case other => (1, List(other)) + case other => (1, List(other)) val grouped = decomposed.groupMapReduce(_._2)(_._1)(_ + _) val const = grouped.getOrElse(Nil, 0) val nonConsts = grouped .toList .filter((nonConsts, const) => const != 0 && !nonConsts.isEmpty) - .sortBy((nonConsts, const) => nonConsts.hashCode()) - .map((nonConsts, const) => makeIntProduct(const, nonConsts)) + .sortBy((nonConsts, const) => idOf(nonConsts.head)) + .map((nonConsts, const) => unique(makeIntProduct(const, nonConsts))) (const, nonConsts) private def makeIntSum(const: Int, nonConsts: List[ENode]): ENode = @@ -364,88 +472,6 @@ final class EGraph(rootCtx: Context): else if nonConsts.size == 1 then nonConsts.head else ENode.OpApply(Op.IntSum, nonConsts) else - val constNode = unique(ENode.Atom(ConstantType(Constant(const))(using rootCtx))) + val constNode = constant(const) if nonConsts.isEmpty then constNode else ENode.OpApply(Op.IntSum, constNode :: nonConsts) - - private def order(a: ENode, b: ENode): (ENode, ENode) = - (a, b) match - case (ENode.Atom(_: ConstantType), _) => (a, b) - case (_, ENode.Atom(_: ConstantType)) => (b, a) - case (_: ENode.Constructor, _) => (a, b) - case (_, _: ENode.Constructor) => (b, a) - case (_: ENode.Atom, _) => (a, b) - case (_, _: ENode.Atom) => (b, a) - case (_: ENode.Select, _) => (a, b) - case (_, _: ENode.Select) => (b, a) - case (_: ENode.Apply, _) => (a, b) - case (_, _: ENode.Apply) => (b, a) - case (_: ENode.TypeApply, _) => (a, b) - case (_, _: ENode.TypeApply) => (b, a) - case _ => (a, b) - - def merge(a: ENode, b: ENode): Unit = - val aRepr = representent(a) - val bRepr = representent(b) - if aRepr eq bRepr then return - assert(aRepr != bRepr, s"$aRepr and $bRepr are `equals` but not `eq`") - - /// Update represententOf and usedBy maps - val (newRepr, oldRepr) = order(aRepr, bRepr) - represententOf(oldRepr) = newRepr - uses(newRepr) ++= uses(oldRepr) - val oldUses = uses(oldRepr) - usedBy.remove(oldRepr) - - // Propagate truth values over disjunctions, conjunctions and equalities - oldRepr match - case ENode.OpApply(Op.And, args) if newRepr eq trueNode => - args.foreach(merge(_, trueNode)) - case ENode.OpApply(Op.Or, args) if newRepr eq falseNode => - args.foreach(merge(_, falseNode)) - case ENode.OpApply(Op.Equal, args) if newRepr eq trueNode => - args.foreach(arg => merge(args(0), args(1))) - case _ => - () - - // Enqueue all nodes that use the oldRepr for repair - worklist.enqueueAll(oldUses) - - def repair(): Unit = - while !worklist.isEmpty do - val head = worklist.dequeue() - val headRepr = representent(head) - val headCanonical = canonicalize(head) - if headRepr ne headCanonical then - merge(headRepr, headCanonical) - - private def toTree(node: ENode, paramRefs: List[Tree])(using Context): Tree = - node match - case ENode.Atom(tp) => - singleton(tp) - case ENode.Constructor(sym) => - val tycon = sym.owner.info.typeConstructor - New(tycon).select(TermRef(tycon, sym)) - case ENode.Select(qual, member) => - toTree(qual, paramRefs).select(member) - case ENode.Apply(fn, args) => - Apply(toTree(fn, paramRefs), args.map(toTree(_, paramRefs))) - case ENode.OpApply(op, args) => - ??? - case ENode.TypeApply(fn, args) => - TypeApply(toTree(fn, paramRefs), args.map(TypeTree(_, false))) - case ENode.Lambda(paramTps, retTp, body) => - ??? - - extension [T](xs: List[Option[T]]) - private def sequence: Option[List[T]] = - var result = List.newBuilder[T] - var current = xs - while current.nonEmpty do - current.head match - case Some(x) => - result += x - current = current.tail - case None => - return None - Some(result.result()) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala index 059ca8066093..7b88b6e2b9e6 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala @@ -1,25 +1,56 @@ package dotty.tools.dotc.qualified_types +import scala.collection.mutable.ListBuffer + +import dotty.tools.dotc.ast.{tpd, untpd} +import dotty.tools.dotc.ast.tpd.TreeOps +import dotty.tools.dotc.config.Printers +import dotty.tools.dotc.config.Settings.Setting.value +import dotty.tools.dotc.core.Constants.Constant import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Contexts.ctx +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Flags.EmptyFlags import dotty.tools.dotc.core.Hashable.Binders +import dotty.tools.dotc.core.Names.{termName, Name} import dotty.tools.dotc.core.Names.Designator -import dotty.tools.dotc.core.Names.Name import dotty.tools.dotc.core.StdNames.nme -import dotty.tools.dotc.core.StdNames.tpnme +import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} import dotty.tools.dotc.core.Symbols.Symbol import dotty.tools.dotc.core.Types.{ + AndType, + AppliedType, CachedProxyType, + ClassInfo, ConstantType, + LambdaType, MethodType, NamedType, NoPrefix, + ParamRef, SingletonType, + SkolemType, + TermParamRef, TermRef, ThisType, - Type + Type, + TypeMap, + TypeProxy, + TypeRef, + TypeVar, + ValueType } +import dotty.tools.dotc.parsing +import dotty.tools.dotc.printing.{Printer, Showable} +import dotty.tools.dotc.printing.GlobalPrec +import dotty.tools.dotc.printing.Texts.{stringToText, Text} +import dotty.tools.dotc.qualified_types.ENode.Op +import dotty.tools.dotc.reporting.trace +import dotty.tools.dotc.transform.TreeExtractors.{BinaryOp, UnaryOp} +import dotty.tools.dotc.util.Spans.Span -enum ENode: +enum ENode extends Showable: import ENode.* case Atom(tp: SingletonType) @@ -30,30 +61,327 @@ enum ENode: case TypeApply(fn: ENode, args: List[Type]) case Lambda(paramTps: List[Type], retTp: Type, body: ENode) - override def toString(): String = + require( this match - case Atom(tp) => typeToString(tp) - case Constructor(constr) => s"new ${designatorToString(constr.lastKnownDenotation.owner)}" - case Select(qual, member) => s"$qual.${designatorToString(member)}" - case Apply(fn, args) => s"$fn(${args.mkString(", ")})" - case OpApply(op, args) => s"(${args.mkString(op.operatorString())})" - case TypeApply(fn, args) => s"$fn[${args.map(typeToString).mkString(", ")}]" + case Constructor(constr) => + constr.lastKnownDenotation.isConstructor case Lambda(paramTps, retTp, body) => - s"(${paramTps.map(typeToString).mkString(", ")}): ${{ typeToString(retTp) }} => $body" + paramTps.zipWithIndex.forall: (tp, index) => + tp match + case ENodeParamRef(i, _) => i < index + case _ => true + case _ => true + ) -object ENode: - private def typeToString(tp: Type): String = - tp match - case tp: NamedType => - val prefixString = if isEmptyPrefix(tp.prefix) then "" else typeToString(tp.prefix) + "." - prefixString + designatorToString(tp.designator) - case tp: ConstantType => - tp.value.value.toString - case tp: ThisType => - typeToString(tp.tref) + ".this" - case _ => - tp.toString + def prettyString(printFullPaths: Boolean = false): String = + + def rec(n: ENode): String = + n.prettyString(printFullPaths) + + def printTp(tp: Type): String = + tp match + case tp: NamedType => + val prefixString = if isEmptyPrefix(tp.prefix) then "" else printTp(tp.prefix) + "." + prefixString + printDesignator(tp.designator) // + s"#${System.identityHashCode(tp).toHexString}" + case tp: ConstantType => + tp.value.value.toString // + s"#${System.identityHashCode(tp).toHexString}" + case tp: SkolemType => + "(?" + tp.hashCode + ": " + printTp(tp.info) + ")" + case tp: ThisType => + printTp(tp.tref) + ".this" + case tp: TypeVar => + tp.origin.paramName.toString() + case tp: ENodeParamRef => + s"arg${tp.index}" + case tp: AppliedType => + val argsString = tp.args.map(printTp).mkString(", ") + s"${printTp(tp.tycon)}[$argsString]" + case _ => + tp.toString + + def printDesignator(d: Designator): String = + d match + case d: Symbol => d.lastKnownDenotation.name.toString + case _ => d.toString + + this match + case Atom(tp) => + printTp(tp) + case Constructor(constr) => + s"new ${printDesignator(constr.lastKnownDenotation.owner)}" + case Select(qual, member) => + s"${rec(qual)}.${printDesignator(member)}" + case Apply(fn, args) => + s"${rec(fn)}(${args.map(rec).mkString(", ")})" + case OpApply(op, args) => + s"(${args.map(rec).mkString(" " + op.operatorName().toString() + " ")})" + case TypeApply(fn, args) => + s"${rec(fn)}[${args.map(printTp).mkString(", ")}]" + case Lambda(paramTps, retTp, body) => + val paramsString = paramTps.map(p => "_: " + printTp(p)).mkString(", ") + s"($paramsString): ${printTp(retTp)} => ${rec(body)}" + + + override def toText(p: Printer): Text = toText(p, false) + + def toText(p: Printer, printAddresses: Boolean): Text = + given Context = p.printerContext + + def withAddress(obj: Any, text: Text): Text = + if printAddresses then + "<" ~ text ~ s"#${System.identityHashCode(obj).toHexString}" ~ ">" + else + text + + def listToText[T](xs: List[T], fn: T => Text, sep: Text): Text = + xs.map(fn).reduceLeftOption(_ ~ sep ~ _).getOrElse("") + + withAddress( + this, + this match + case Atom(tp) => + p.toTextRef(tp) + case Constructor(constr) => + "new" ~ p.toText(constr.lastKnownDenotation.owner) + case Select(qual, member) => + qual.toText(p) ~ "." ~ p.toText(member.name) + case Apply(fn, args) => + fn.toText(p) ~ "(" ~ listToText(args, arg => p.atPrec(GlobalPrec)(arg.toText(p)), ", ") ~ ")" + case OpApply(op, args) => + assert(args.nonEmpty) + op match + // All operators with arity >= 2 + case Op.IntSum | Op.IntMinus | Op.IntProduct | + Op.IntLessThan | Op.IntLessEqual | Op.IntGreaterThan | Op.IntGreaterEqual | + Op.LongSum | Op.LongMinus | Op.LongProduct | + Op.And | Op.Or | Op.Equal | Op.NotEqual => + val opPrec = parsing.precedence(op.operatorName()) + val isRightAssoc = false + val leftPrec = if isRightAssoc then opPrec + 1 else opPrec + val rightPrec = if !isRightAssoc then opPrec + 1 else opPrec + p.changePrec(opPrec): + args.map(_.toText(p)).reduceLeft: (l, r) => + p.atPrec(leftPrec)(l) ~ " " ~ p.toText(op.operatorName()) ~ " " ~ p.atPrec(rightPrec)(r) + // Unary operators + case _ => + assert(args.length == 1) + val opPrec = parsing.precedence(op.operatorName()) + p.changePrec(opPrec): + p.toText(op.operatorName()) ~ p.atPrec(opPrec + 1)(args.head.toText(p)) + case TypeApply(fn, args) => + fn.toText(p) ~ "[" ~ listToText(args, p.toText, ",") ~ "]" + case Lambda(paramTps, retTp, body) => + val paramsText = listToText(paramTps, "_: " ~ p.toText(_), ", ") + "(" ~ paramsText ~ ")" ~ " => " ~ p.atPrec(GlobalPrec)(body.toText(p)) + ) + + def showNoBreak(using Context): String = + toText(ctx.printer).mkString() + + def dotId() = + "n" + System.identityHashCode(this).toHexString.substring(1) + + def toDot()(using _ctx: Context): String = + given Context = _ctx.withoutColors + val id = dotId() + val fields: List[ENode | String] = + this match + case Atom(tp) => this.showNoBreak :: Nil + case Constructor(constr) => this.showNoBreak :: Nil + case Select(qual, member) => qual :: member.name.show :: Nil + case Apply(fn, args) => fn :: args + case OpApply(op, args) => op.operatorName().toString() :: args + case TypeApply(fn, args) => fn :: args.map(_.show) + case Lambda(paramTps, retTp, body) => + val paramsString = paramTps.map(p => "_:" + p.show).mkString(", ") + "(" + paramsString + ") => " + retTp.show :: body :: Nil + val fieldStrings = + fields.zipWithIndex.map: (field, i) => + field match + case child: ENode => s"" + case str: String => str.replace("<", "\\<").replace(">", "\\>") + val nodeString = s"$id [label=\"${fieldStrings.mkString("|")}\"];\n" + val edgesString = + fields.zipWithIndex.map: (field, i) => + field match + case child: ENode => s"$id:p$i -> ${child.dotId()};\n" + case _ => "" + nodeString + edgesString.mkString + + def mapTypes(f: Type => Type)(using Context): ENode = + this match + case Atom(tp) => + val mappedTp = f(tp) + if mappedTp eq tp then + this + else + mappedTp match + case mappedTp: SingletonType => Atom(mappedTp) + case _ => Atom(SkolemType(mappedTp)) + case Constructor(constr) => + this + case node @ Select(qual, member) => + node.derived(qual.mapTypes(f), member) + case node @ Apply(fn, args) => + node.derived(fn.mapTypes(f), args.mapConserve(_.mapTypes(f))) + case node @ OpApply(op, args) => + node.derived(op, args.mapConserve(_.mapTypes(f))) + case node @ TypeApply(fn, args) => + node.derived(fn.mapTypes(f), args.mapConserve(f)) + case node @ Lambda(paramTps, retTp, body) => + node.derived(paramTps.mapConserve(f), f(retTp), body.mapTypes(f)) + + def foreachType(f: Type => Unit)(using Context): Unit = + this match + case Atom(tp) => f(tp) + case Constructor(_) => () + case Select(qual, _) => qual.foreachType(f) + case Apply(fn, args) => + fn.foreachType(f) + args.foreach(_.foreachType(f)) + case OpApply(_, args) => args.foreach(_.foreachType(f)) + case TypeApply(fn, args) => + fn.foreachType(f) + args.foreach(f) + case Lambda(paramTps, retTp, body) => + paramTps.foreach(f) + f(retTp) + body.foreachType(f) + + def normalizeTypes()(using Context): ENode = + mapTypes(NormalizeMap()) + + private class NormalizeMap(using Context) extends TypeMap: + def apply(tp: Type): Type = + tp match + case tp: TypeVar if tp.isPermanentlyInstantiated => + apply(tp.permanentInst) + case tp: NamedType => + val dealiased = tp.dealias + if dealiased ne tp then + apply(dealiased) + else if tp.symbol.isStatic then + if tp.isInstanceOf[TermRef] then tp.symbol.termRef + else tp.symbol.typeRef + else + derivedSelect(tp, apply(tp.prefix)) + case _ => + mapOver(tp) + + def substEParamRefs(from: Int, to: List[Type])(using Context): ENode = + this match + case Atom(tp) => + mapTypes(SubstEParamsMap(from, to)) + case Constructor(_) => + this + case node @ Select(qual, member) => + node.derived(qual.substEParamRefs(from, to), member) + case node @ Apply(fn, args) => + node.derived(fn.substEParamRefs(from, to), args.mapConserve(_.substEParamRefs(from, to))) + case node @ OpApply(op, args) => + node.derived(op, args.mapConserve(_.substEParamRefs(from, to))) + case node @ TypeApply(fn, args) => + node.derived(fn.substEParamRefs(from, to), args.mapConserve(SubstEParamsMap(from, to))) + case node @ Lambda(paramTps, retTp, body) => + node.derived(paramTps.mapConserve(SubstEParamsMap(from, to)), SubstEParamsMap(from, to)(retTp), body.substEParamRefs(from + paramTps.length, to)) + + private class SubstEParamsMap(from: Int, to: List[Type])(using Context) extends TypeMap: + override def apply(tp: Type): Type = + tp match + case ENodeParamRef(i, _) if i >= from && i < from + to.length => to(i - from) + case _ => mapOver(tp) + + def foreach(f: ENode => Unit): Unit = + f(this) + this match + case Atom(_) => () + case Constructor(_) => () + case Select(qual, _) => + qual.foreach(f) + case Apply(fn, args) => + fn.foreach(f) + args.foreach(_.foreach(f)) + case OpApply(_, args) => + args.foreach(_.foreach(f)) + case TypeApply(fn, args) => + fn.foreach(f) + case Lambda(_, _, body) => + body.foreach(f) + def contains(that: ENode): Boolean = + var found = false + foreach: node => + if node eq that then + found = true + found + + // ----------------------------------- + // Conversion from E-Nodes to Trees + // ----------------------------------- + + def toTree(paramRefs: List[Type] = Nil)(using Context): tpd.Tree = + def mapType(tp: Type): Type = SubstEParamsMap(0, paramRefs)(tp) + + trace(i"ENode.toTree $this, paramRefs: $paramRefs", Printers.qualifiedTypes): + this match + case Atom(tp) => + mapType(tp) match + case tp1: TermParamRef => untpd.Ident(tp1.paramName).withType(tp1) + case tp1 => tpd.singleton(tp1) + case Constructor(sym) => + val tycon = sym.owner.asClass.classDenot.classInfo.selfType + tpd.New(tycon).select(TermRef(tycon, sym)) + case Select(qual, member) => + qual.toTree(paramRefs).select(member) + case Apply(fn, args) => + tpd.Apply(fn.toTree(paramRefs), args.map(_.toTree(paramRefs))) + case OpApply(op, args) => + def unaryOp(symbol: Symbol): tpd.Tree = + require(args.length == 1) + args(0).toTree(paramRefs).select(symbol).appliedToNone + def binaryOp(symbol: Symbol): tpd.Tree = + require(args.length == 2) + args(0).toTree(paramRefs).select(symbol).appliedTo(args(1).toTree(paramRefs)) + op match + case Op.IntSum => + args.map(_.toTree(paramRefs)).reduceLeft(_.select(defn.Int_+).appliedTo(_)) + case Op.IntMinus => + binaryOp(defn.Int_-) + case Op.IntProduct => + args.map(_.toTree(paramRefs)).reduceLeft(_.select(defn.Int_*).appliedTo(_)) + case Op.LongSum => + ??? + case Op.LongMinus => + ??? + case Op.LongProduct => + ??? + case Op.Equal => + args(0).toTree(paramRefs).equal(args(1).toTree(paramRefs)) + case Op.NotEqual => + val lhs = args(0).toTree(paramRefs) + val rhs = args(1).toTree(paramRefs) + tpd.applyOverloaded(lhs, nme.NE, rhs :: Nil, Nil, defn.BooleanType) + case Op.Not => unaryOp(defn.Boolean_!) + case Op.And => binaryOp(defn.Boolean_&&) + case Op.Or => binaryOp(defn.Boolean_||) + case Op.IntLessThan => binaryOp(defn.Int_<) + case Op.IntLessEqual => binaryOp(defn.Int_<=) + case Op.IntGreaterThan => binaryOp(defn.Int_>) + case Op.IntGreaterEqual => binaryOp(defn.Int_>=) + case TypeApply(fn, args) => + tpd.TypeApply(fn.toTree(paramRefs), args.map(tp => tpd.TypeTree(mapType(tp), false))) + case Lambda(paramTps, retTp, body) => + val myParamNames = paramTps.zipWithIndex.map((tp, i) => termName("param" + (paramRefs.size + i))) + def computeParamTypes(mt: MethodType) = + val reversedParamRefs = mt.paramRefs.reverse + paramTps.zipWithIndex.map((tp, i) => SubstEParamsMap(0, reversedParamRefs.take(i) ::: paramRefs)(tp)) + val mt = MethodType(myParamNames)(computeParamTypes, _ => retTp) + tpd.Lambda(mt, myParamRefTrees => + val myParamRefs = myParamRefTrees.map(_.tpe).reverse + body.toTree(myParamRefs ::: paramRefs) + ) + +object ENode: private def isEmptyPrefix(tp: Type): Boolean = tp match case tp: NoPrefix.type => @@ -61,44 +389,250 @@ object ENode: case tp: ThisType => tp.tref.designator match case d: Symbol => d.lastKnownDenotation.name.toTermName == nme.EMPTY_PACKAGE - case _ => false + case _ => false case _ => false - private def designatorToString(d: Designator): String = - d match - case d: Symbol => d.lastKnownDenotation.name.toString - case _ => d.toString enum Op: case IntSum + case IntMinus case IntProduct case LongSum + case LongMinus case LongProduct case Equal + case NotEqual case Not case And case Or - case LessThan + case IntLessThan + case IntLessEqual + case IntGreaterThan + case IntGreaterEqual - def operatorString(): String = + def operatorName(): Name = this match - case IntSum => "+" - case IntProduct => "*" - case LongSum => "+" - case LongProduct => "*" - case Equal => "==" - case Not => "!" - case And => "&&" - case Or => "||" - case LessThan => "<" - - /** Reference to the argument of an [[ENode.Lambda]]. - * - * @param index  - * Debruijn index of the argument, starting from 0 - * @param underyling - * Underlying type of the argument - */ - final case class ArgRefType(index: Int, underlying: Type) extends CachedProxyType, SingletonType: - override def underlying(using Context): Type = underlying - override def computeHash(bs: Binders): Int = doHash(bs, index, underlying) + case IntSum => nme.Plus + case IntMinus => nme.Minus + case IntProduct => nme.Times + case LongSum => nme.Plus + case LongMinus => nme.Minus + case LongProduct => nme.Times + case Equal => nme.Equals + case NotEqual => nme.NotEquals + case Not => nme.Not + case And => nme.And + case Or => nme.Or + case IntLessThan => nme.Le + case IntLessEqual => nme.Lt + case IntGreaterThan => nme.Gt + case IntGreaterEqual => nme.Ge + + // ----------------------------------- + // Conversion from Trees to E-Nodes + // ----------------------------------- + + def fromTree( + tree: tpd.Tree, + paramSyms: List[Symbol] = Nil, + paramTps: List[Type] = Nil + )(using Context): Option[ENode] = + val d = defn // Need a stable path to match on `defn` members + + def binaryOpNode(op: ENode.Op, lhs: tpd.Tree, rhs: tpd.Tree): Option[ENode] = + for + lhsNode <- fromTree(lhs, paramSyms, paramTps) + rhsNode <- fromTree(rhs, paramSyms, paramTps) + yield OpApply(op, List(lhsNode, rhsNode)) + + def unaryOpNode(op: ENode.Op, arg: tpd.Tree): Option[ENode] = + for argNode <- fromTree(arg, paramSyms, paramTps) yield OpApply(op, List(argNode)) + + def isValidEqual(sym: Symbol, lhs: tpd.Tree, rhs: tpd.Tree): Boolean = + def lhsClass = lhs.tpe.classSymbol + sym == defn.Int_== + || sym == defn.Boolean_== + || sym == defn.Any_== && lhsClass == defn.StringClass + || sym.name == nme.EQ && lhsClass.exists && hasCaseClassEquals(lhsClass) + + trace(s"ENode.fromTree $tree", Printers.qualifiedTypes): + tree match + case tpd.Literal(_) | tpd.Ident(_) | tpd.This(_) + if tree.tpe.isInstanceOf[SingletonType] && tpd.isIdempotentExpr(tree) => + Some(Atom(substParamRefs(tree.tpe, paramSyms, paramTps).asInstanceOf[SingletonType])) + case tpd.Select(tpd.New(_), nme.CONSTRUCTOR) => + constructorNode(tree.symbol) + case tree: tpd.Select if isCaseClassApply(tree.symbol) => + constructorNode(tree.symbol.owner.linkedClass.primaryConstructor) + case tpd.Select(qual, name) => + for qualNode <- fromTree(qual, paramSyms, paramTps) yield Select(qualNode, tree.symbol) + case BinaryOp(lhs, sym, rhs) if isValidEqual(sym, lhs, rhs) => binaryOpNode(ENode.Op.Equal, lhs, rhs) + case BinaryOp(lhs, d.Int_!= | d.Boolean_!=, rhs) => binaryOpNode(ENode.Op.NotEqual, lhs, rhs) + case UnaryOp(d.Boolean_!, arg) => unaryOpNode(ENode.Op.Not, arg) + case BinaryOp(lhs, d.Boolean_&&, rhs) => binaryOpNode(ENode.Op.And, lhs, rhs) + case BinaryOp(lhs, d.Boolean_||, rhs) => binaryOpNode(ENode.Op.Or, lhs, rhs) + case BinaryOp(lhs, d.Int_+, rhs) => binaryOpNode(ENode.Op.IntSum, lhs, rhs) + case BinaryOp(lhs, d.Int_-, rhs) => binaryOpNode(ENode.Op.IntMinus, lhs, rhs) + case BinaryOp(lhs, d.Int_*, rhs) => binaryOpNode(ENode.Op.IntProduct, lhs, rhs) + case BinaryOp(lhs, d.Int_<, rhs) => binaryOpNode(ENode.Op.IntLessThan, lhs, rhs) + case BinaryOp(lhs, d.Int_<=, rhs) => binaryOpNode(ENode.Op.IntLessEqual, lhs, rhs) + case BinaryOp(lhs, d.Int_>, rhs) => binaryOpNode(ENode.Op.IntGreaterThan, lhs, rhs) + case BinaryOp(lhs, d.Int_>=, rhs) => binaryOpNode(ENode.Op.IntGreaterEqual, lhs, rhs) + case tpd.Apply(fun, args) => + for + funNode <- fromTree(fun, paramSyms, paramTps) + argsNodes <- args.map(fromTree(_, paramSyms, paramTps)).sequence + yield ENode.Apply(funNode, argsNodes) + case tpd.TypeApply(fun, args) => + for funNode <- fromTree(fun, paramSyms, paramTps) + yield ENode.TypeApply(funNode, args.map(tp => substParamRefs(tp.tpe, paramSyms, paramTps))) + case tpd.closureDef(defDef) => + defDef.symbol.info.dealias match + case mt: MethodType => + assert(defDef.termParamss.size == 1, "closure is expected to have a single parameter list") + var newParamSyms: List[Symbol] = paramSyms + var newParamTps: List[Type] = paramTps + val myParamSyms: List[Symbol] = defDef.termParamss.head.map(_.symbol) + val myParamTps: List[Type] = mt.paramInfos + for (myParamSym, myParamTp) <- myParamSyms.zip(myParamTps) do + newParamTps = substParamRefs(myParamTp, newParamSyms, newParamTps) :: newParamTps + newParamSyms = myParamSym :: newParamSyms + val myRetTp = substParamRefs(mt.resType, newParamSyms, newParamTps) + for body <- fromTree(defDef.rhs, newParamSyms, newParamTps) + yield ENode.Lambda(newParamTps.take(myParamTps.size), myRetTp, body) + case _ => None + case _ => + None + + private def constructorNode(constr: Symbol)(using Context): Option[ENode.Constructor] = + val clazz = constr.owner + if hasCaseClassEquals(clazz) then + val isPrimaryConstructor = constr.denot.isPrimaryConstructor + val fieldsRaw = clazz.denot.asClass.paramAccessors.filter(isPrimaryConstructor && _.isStableMember) + val constrParams = constr.paramSymss.flatten.filter(_.isTerm) + val fields = constrParams.map(p => fieldsRaw.find(_.name == p.name).getOrElse(NoSymbol)) + Some(ENode.Constructor(constr)(fields)) + else + None + + private def hasCaseClassEquals(clazz: Symbol)(using Context): Boolean = + val equalsMethod = clazz.info.decls.lookup(nme.equals_) + val equalsNotOverriden = !equalsMethod.exists || equalsMethod.is(Flags.Synthetic) + clazz.isClass && clazz.is(Flags.Case) && equalsNotOverriden + + private def isCaseClassApply(meth: Symbol)(using Context): Boolean = + meth.name == nme.apply + && meth.flags.is(Flags.Synthetic) + && meth.owner.linkedClass.is(Flags.Case) + + def substParamRefs(tp: Type, paramSyms: List[Symbol], paramTps: List[Type])(using Context): Type = + trace(i"substParamRefs($tp, $paramSyms, $paramTps)", Printers.qualifiedTypes): + tp.subst(paramSyms, paramTps.zipWithIndex.map((tp, i) => ENodeParamRef(i, tp)).toList) + + def selfify(tree: tpd.Tree)(using Context): Option[ENode.Lambda] = + trace(i"ENode.selfify $tree", Printers.qualifiedTypes): + fromTree(tree) match + case Some(treeNode) => + Some(ENode.Lambda( + List(tree.tpe), + defn.BooleanType, + OpApply(ENode.Op.Equal, List(treeNode, ENode.Atom(ENodeParamRef(0, tree.tpe)))) + )) + case None => None + + // ----------------------------------- + // Assumptions retrieval + // ----------------------------------- + + def assumptions(node: ENode)(using Context): List[ENode] = + trace(i"assumptions($node)", Printers.qualifiedTypes): + node match + case Atom(tp: SingletonType) => termAssumptions(tp) ++ typeAssumptions(tp) + case n: Constructor => Nil + case n: Select => assumptions(n.qual) + case n: Apply => assumptions(n.fn) ++ n.args.flatMap(assumptions) + case n: OpApply => n.args.flatMap(assumptions) + case n: TypeApply => assumptions(n.fn) + case n: Lambda => Nil + + private def termAssumptions(tp: SingletonType)(using Context): List[ENode] = + trace(i"termAssumptions($tp)", Printers.qualifiedTypes): + tp match + case tp: TermRef => + tp.symbol.info match + case QualifiedType(_, _) => Nil + case _ => + tp.symbol.defTree match + case valDef: tpd.ValDef if !valDef.rhs.isEmpty && !valDef.symbol.is(Flags.Lazy) => + fromTree(valDef.rhs) match + case Some(treeNode) => OpApply(ENode.Op.Equal, List(treeNode, Atom(tp))) :: assumptions(treeNode) + case None => Nil + case _ => Nil + case _ => Nil + + private def typeAssumptions(rootTp: SingletonType)(using Context): List[ENode] = + def rec(tp: Type): List[ENode] = + tp match + case QualifiedType(parent, qualifier) => qualifier.body.substEParamRefs(0, List(rootTp)) :: assumptions(qualifier.body) ::: rec(parent) + case tp: SingletonType if tp ne rootTp => List(OpApply(ENode.Op.Equal, List(Atom(tp), Atom(rootTp)))) + case tp: TypeProxy => rec(tp.underlying) + case AndType(tp1, tp2) => rec(tp1) ++ rec(tp2) + case _ => Nil + trace(i"typeAssumptions($rootTp)", Printers.qualifiedTypes): + rec(rootTp) + + // ----------------------------------- + // Utils + // ----------------------------------- + + extension (n: Atom) + def derived(tp: SingletonType): ENode.Atom = + if n.tp eq tp then n + else ENode.Atom(tp) + + extension (n: Constructor) + def derived(constr: Symbol): ENode.Constructor = + if n.constr eq constr then n + else ENode.Constructor(constr)(n.fields) + + extension (n: Select) + def derived(qual: ENode, member: Symbol): ENode.Select = + if (n.qual eq qual) && (n.member eq member) then n + else ENode.Select(qual, member) + + extension (n: Apply) + def derived(fn: ENode, args: List[ENode]): ENode.Apply = + if (n.fn eq fn) && (n.args eq args) then n + else ENode.Apply(fn, args) + + extension (n: OpApply) + def derived(op: ENode.Op, args: List[ENode]): ENode.OpApply = + if (n.fn eq op) && (n.args eq args) then n + else ENode.OpApply(op, args) + + extension (n: TypeApply) + def derived(fn: ENode, args: List[Type]): ENode.TypeApply = + if (n.fn eq fn) && (n.args eq args) then n + else ENode.TypeApply(fn, args) + + extension (n: Lambda) + def derived(paramTps: List[Type], retTp: Type, body: ENode): ENode.Lambda = + if (n.paramTps eq paramTps) && (n.retTp eq retTp) && (n.body eq body) then n + else ENode.Lambda(paramTps, retTp, body) + + // ----------------------------------- + // Utils + // ----------------------------------- + + extension [T](xs: List[Option[T]]) + private def sequence: Option[List[T]] = + var result = List.newBuilder[T] + var current = xs + while current.nonEmpty do + current.head match + case Some(x) => + result += x + current = current.tail + case None => + return None + Some(result.result()) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/ENodeParamRef.scala b/compiler/src/dotty/tools/dotc/qualified_types/ENodeParamRef.scala new file mode 100644 index 000000000000..fdf4513d7d97 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/ENodeParamRef.scala @@ -0,0 +1,23 @@ +package dotty.tools.dotc.qualified_types +import dotty.tools.dotc.core.Types.{ + SingletonType, + CachedProxyType, + Type +} +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Hashable.Binders + + +/** Reference to the argument of an [[ENode.Lambda]]. + * + * @param index  + * Debruijn index of the argument, starting from 0 + * @param underyling + * Underlying type of the argument + */ +final case class ENodeParamRef(index: Int, underlying: Type) extends CachedProxyType, SingletonType: + override def underlying(using Context): Type = underlying + override def computeHash(bs: Binders): Int = doHash(bs, index, underlying) + def derivedENodeParamRef(index: Int, underlying: Type): ENodeParamRef = + if index == this.index && (underlying eq this.underlying) then this + else ENodeParamRef(index, underlying) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala new file mode 100644 index 000000000000..83738a188ced --- /dev/null +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala @@ -0,0 +1,39 @@ +package dotty.tools.dotc.qualified_types + +import dotty.tools.dotc.ast.tpd.Tree +import dotty.tools.dotc.core.Annotations.Annotation +import dotty.tools.dotc.core.Contexts.{ctx, Context} +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.core.Symbols.defn +import dotty.tools.dotc.core.Types.{TermLambda, TermParamRef, Type, ConstantType, TypeMap} +import dotty.tools.dotc.printing.Printer +import dotty.tools.dotc.printing.Texts.Text +import dotty.tools.dotc.printing.Texts.stringToText +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.report + +case class QualifiedAnnotation(qualifier: ENode.Lambda) extends Annotation: + + override def tree(using Context): Tree = qualifier.toTree() + + override def symbol(using Context) = defn.QualifiedAnnot + + override def derivedAnnotation(tree: Tree)(using Context): Annotation = ??? + + private def derivedAnnotation(qualifier: ENode.Lambda)(using Context): Annotation = + if qualifier eq this.qualifier then this + else QualifiedAnnotation(qualifier) + + override def toText(printer: Printer): Text = + "with " ~ qualifier.body.toText(printer) + + override def mapWith(tm: TypeMap)(using Context): Annotation = + derivedAnnotation(qualifier.mapTypes(tm).asInstanceOf[ENode.Lambda]) + + override def refersToParamOf(tl: TermLambda)(using Context): Boolean = + var res = false + qualifier.foreachType: tp => + tp.stripped match + case TermParamRef(tl1, _) if tl eq tl1 => res = true + case _ => () + res diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala index 71eec80753fb..22394db100f2 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedType.scala @@ -3,7 +3,9 @@ package dotty.tools.dotc.qualified_types import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.core.Annotations.Annotation import dotty.tools.dotc.core.Contexts.{ctx, Context} -import dotty.tools.dotc.core.Types.{AnnotatedType, Type} +import dotty.tools.dotc.core.Types.{AnnotatedType, Type, ErrorType} +import dotty.tools.dotc.core.Decorators.em +import dotty.tools.dotc.typer.ErrorReporting.errorType /** A qualified type is internally represented as a type annotated with a * `@qualified` annotation. @@ -17,22 +19,23 @@ object QualifiedType: * a pair containing the parent type and the qualifier tree (a lambda) on * success, [[None]] otherwise */ - def unapply(tp: Type)(using Context): Option[(Type, tpd.Tree)] = + def unapply(tp: Type)(using Context): Option[(Type, ENode.Lambda)] = tp match - case AnnotatedType(parent, annot) if annot.symbol == ctx.definitions.QualifiedAnnot => - Some((parent, annot.argument(0).get)) + case AnnotatedType(parent, QualifiedAnnotation(qualifier)) => + Some((parent, qualifier)) case _ => None - /** Factory method to create a qualified type. - * - * @param parent - * the parent type - * @param qualifier - * the qualifier tree (a lambda) - * @return - * a qualified type - */ - def apply(parent: Type, qualifier: tpd.Tree)(using Context): Type = - val annotTp = ctx.definitions.QualifiedAnnot.typeRef.appliedTo(parent) - AnnotatedType(parent, Annotation(tpd.New(annotTp, List(qualifier)))) + def apply(parent: Type, qualifier: ENode.Lambda)(using Context): Type = + AnnotatedType(parent, QualifiedAnnotation(qualifier)) + + def apply(parent: Type, annot: Annotation)(using Context): Type = + annot match + case annot: QualifiedAnnotation => AnnotatedType(parent, annot) + case _ => apply(parent, annot.arguments(0)) + + def apply(parent: Type, annotTree: tpd.Tree)(using Context): Type = + val arg = tpd.allTermArguments(annotTree)(0) + ENode.fromTree(arg) match + case Some(qualifier: ENode.Lambda) => apply(parent, qualifier) + case _ => errorType(em"Invalid qualifier: $arg", annotTree.srcPos) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala index ceddec843393..dec4e3f02af0 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedTypes.scala @@ -37,6 +37,7 @@ import dotty.tools.dotc.core.Types.{ Type, TypeProxy } +import dotty.tools.dotc.util.SrcPos import dotty.tools.dotc.report import dotty.tools.dotc.reporting.trace @@ -48,23 +49,37 @@ object QualifiedTypes: * Note: the logic here is similar to [[Type#derivesAnnotWith]] but * additionally handle comparisons with [[SingletonType]]s. */ - def typeImplies(tp1: Type, qualifier2: Tree)(using Context): Boolean = - trace(i"typeImplies $tp1 --> $qualifier2", Printers.qualifiedTypes): + def typeImplies(tp1: Type, qualifier2: ENode.Lambda, solver: QualifierSolver)(using Context): Boolean = + def trySelfifyType() = + val ENode.Lambda(List(paramTp), _, _) = qualifier2: @unchecked + ENode.selfify(tpd.singleton(tp1)) match + case Some(qualifier1) => solver.implies(qualifier1, qualifier2) + case None => false + trace(i"typeImplies $tp1 --> ${qualifier2.body}", Printers.qualifiedTypes): tp1 match case QualifiedType(parent1, qualifier1) => - QualifierSolver().implies(qualifier1, qualifier2) - case tp1: (ConstantType | TermRef) => - QualifierSolver().implies(equalToPredicate(tpd.singleton(tp1)), qualifier2) - || typeImplies(tp1.underlying, qualifier2) + solver.implies(qualifier1, qualifier2) + case tp1: TermRef => + def trySelfifyRef() = + tp1.underlying match + case QualifiedType(_, _) => false + case _ => trySelfifyType() + typeImplies(tp1.underlying, qualifier2, solver) || trySelfifyRef() + case tp1: ConstantType => + trySelfifyType() case tp1: TypeProxy => - typeImplies(tp1.underlying, qualifier2) + typeImplies(tp1.underlying, qualifier2, solver) case AndType(tp11, tp12) => - typeImplies(tp11, qualifier2) || typeImplies(tp12, qualifier2) + typeImplies(tp11, qualifier2, solver) || typeImplies(tp12, qualifier2, solver) case OrType(tp11, tp12) => - typeImplies(tp11, qualifier2) && typeImplies(tp12, qualifier2) + typeImplies(tp11, qualifier2, solver) && typeImplies(tp12, qualifier2, solver) case _ => - false - // QualifierSolver().implies(truePredicate(), qualifier2) + val trueQualifier: ENode.Lambda = ENode.Lambda( + List(defn.AnyType), + defn.BooleanType, + ENode.Atom(ConstantType(Constant(true))) + ) + solver.implies(trueQualifier, qualifier2) /** Try to adapt the tree to the given type `pt` * @@ -74,10 +89,10 @@ object QualifiedTypes: * Used by [[dotty.tools.dotc.core.Typer]]. */ def adapt(tree: Tree, pt: Type)(using Context): Tree = - trace(i"adapt $tree to $pt", Printers.qualifiedTypes): - if containsQualifier(pt) then + if containsQualifier(pt) then + trace(i"adapt $tree to qualified type $pt", Printers.qualifiedTypes): if tree.tpe.hasAnnotation(defn.RuntimeCheckedAnnot) then - if checkContainsSkolem(pt) then + if checkContainsSkolem(pt, tree.srcPos) then tpd.evalOnce(tree): e => If( e.isInstance(pt), @@ -86,23 +101,16 @@ object QualifiedTypes: ) else tree.withType(ErrorType(em"")) - else if isSimple(tree) then - val selfifiedTp = QualifiedType(tree.tpe, equalToPredicate(tree)) - if selfifiedTp <:< pt then tree.cast(selfifiedTp) else EmptyTree else - EmptyTree + ENode.selfify(tree) match + case Some(qualifier) => + val selfifiedTp = QualifiedType(tree.tpe, qualifier) + if selfifiedTp <:< pt then tree.cast(selfifiedTp) else EmptyTree + case None => + EmptyTree else EmptyTree - def isSimple(tree: Tree)(using Context): Boolean = - tree match - case Apply(fn, args) => isSimple(fn) && args.forall(isSimple) - case TypeApply(fn, args) => isSimple(fn) - case SeqLiteral(elems, _) => elems.forall(isSimple) - case Typed(expr, _) => isSimple(expr) - case Block(Nil, expr) => isSimple(expr) - case _ => tpd.isIdempotentExpr(tree) - def containsQualifier(tp: Type)(using Context): Boolean = tp match case QualifiedType(_, _) => true @@ -111,27 +119,15 @@ object QualifiedTypes: case OrType(tp1, tp2) => containsQualifier(tp1) || containsQualifier(tp2) case _ => false - def checkContainsSkolem(tp: Type)(using Context): Boolean = + def checkContainsSkolem(tp: Type, pos: SrcPos)(using Context): Boolean = var res = true tp.foreachPart: case QualifiedType(_, qualifier) => - qualifier.foreachSubTree: subTree => - subTree.tpe.foreachPart: + qualifier.foreachType: rootTp => + rootTp.foreachPart: case tp: SkolemType => - report.error(em"The qualified type $qualifier cannot be checked at runtime", qualifier.srcPos) + report.error(em"The qualified type $qualifier cannot be checked at runtime", pos) res = false case _ => () case _ => () res - - private def equalToPredicate(tree: Tree)(using Context): Tree = - Lambda( - MethodType(List("v".toTermName))(_ => List(tree.tpe), _ => defn.BooleanType), - (args) => Ident(args(0).symbol.termRef).equal(tree) - ) - - private def truePredicate()(using Context): Tree = - Lambda( - MethodType(List("v".toTermName))(_ => List(defn.AnyType), _ => defn.BooleanType), - (args) => Literal(Constant(true)) - ) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala deleted file mode 100644 index ce20a38aee4a..000000000000 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierEvaluator.scala +++ /dev/null @@ -1,96 +0,0 @@ -package dotty.tools.dotc.qualified_types - -import scala.annotation.tailrec - -import dotty.tools.dotc.ast.tpd.{ - Apply, - Block, - ConstantTree, - isIdempotentExpr, - EmptyTree, - Literal, - Ident, - Match, - Select, - This, - Tree, - TreeMap, - ValDef, - given -} -import dotty.tools.dotc.core.Constants.Constant -import dotty.tools.dotc.core.Contexts.Context -import dotty.tools.dotc.core.Decorators.i -import dotty.tools.dotc.core.Flags -import dotty.tools.dotc.core.Mode.Type -import dotty.tools.dotc.core.StdNames.nme -import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} -import dotty.tools.dotc.core.SymDenotations.given -import dotty.tools.dotc.core.Types.{ConstantType, NoPrefix, TermRef} -import dotty.tools.dotc.inlines.InlineReducer -import dotty.tools.dotc.transform.TreeExtractors.BinaryOp -import dotty.tools.dotc.transform.patmat.{Empty as EmptySpace, SpaceEngine} -import dotty.tools.dotc.typer.Typer -import scala.util.boundary -import scala.util.boundary.break - - -import dotty.tools.dotc.reporting.trace -import dotty.tools.dotc.config.Printers - -private[qualified_types] object QualifierEvaluator: - /** Reduces a tree by constant folding, simplification and unfolding of simple - * references. - * - * This is more aggressive than [[dotty.tools.dotc.transform.BetaReduce]] and - * [[dotty.tools.dotc.typer.ConstFold]] (which is used under the hood by - * `BetaReduce` through [[dotty.tools.dotc.ast.tpd.cpy]]), as it also unfolds - * non-constant expressions. - */ - def evaluate(tree: Tree, args: Map[Symbol, Tree] = Map.empty)(using Context): Tree = - trace(i"evaluate $tree", Printers.qualifiedTypes): - QualifierEvaluator(args).transform(tree) - -private class QualifierEvaluator(args: Map[Symbol, Tree]) extends TreeMap: - import QualifierEvaluator.* - - override def transform(tree: Tree)(using Context): Tree = - unfold(reduce(tree)) - - private def reduce(tree: Tree)(using Context): Tree = - tree match - case tree: Apply => - val treeTransformed = super.transform(tree) - constFold(treeTransformed).orElse(treeTransformed) - case tree: Select => - val treeTransformed = super.transform(tree) - constFold(treeTransformed).orElse(treeTransformed) - case Block(Nil, expr) => - transform(expr) - case tree => - super.transform(tree) - - private def constFold(tree: Tree)(using Context): Tree = - tree match - case ConstantTree(c: Constant) => Literal(c) - case _ => EmptyTree - - private def unfold(tree: Tree)(using Context): Tree = - args.get(tree.symbol) match - case Some(tree2) => - return transform(tree2) - case None => () - - tree match - case tree: Ident => - trace(s"unfold $tree", Printers.qualifiedTypes): - tree.symbol.defTree match - case valDef: ValDef - if !valDef.rhs.isEmpty - && !valDef.symbol.is(Flags.Lazy) - && QualifiedTypes.isSimple(valDef.rhs) => - transform(valDef.rhs) - case _ => - tree - case _ => - tree diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala index ae4155e7409f..10792bb5cf3a 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala @@ -1,79 +1,57 @@ package dotty.tools.dotc.qualified_types -import dotty.tools.dotc.ast.tpd -import dotty.tools.dotc.ast.tpd.{closureDef, singleton, Apply, Ident, Literal, Select, Tree, given} +import ENode.{Lambda, OpApply, Op} + import dotty.tools.dotc.config.Printers -import dotty.tools.dotc.core.Constants.Constant import dotty.tools.dotc.core.Contexts.{ctx, Context} +import dotty.tools.dotc.core.Symbols.defn +import dotty.tools.dotc.core.Types.{Type, TypeVar, TypeMap} import dotty.tools.dotc.core.Decorators.i -import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol} -import dotty.tools.dotc.core.Types.TermRef -import dotty.tools.dotc.reporting.trace -import dotty.tools.dotc.transform.BetaReduce +import dotty.tools.dotc.printing.Showable class QualifierSolver(using Context): - val d = defn // Need a stable path to match on `defn` members - - def implies(tree1: Tree, tree2: Tree) = - trace(i"implies $tree1 -> $tree2", Printers.qualifiedTypes): - (tree1, tree2) match - case (closureDef(defDef1), closureDef(defDef2)) => - val tree1ArgSym = defDef1.symbol.paramSymss.head.head - val tree2ArgSym = defDef2.symbol.paramSymss.head.head - val rhs = defDef1.rhs - val lhs = defDef2.rhs - if tree1ArgSym.info frozen_<:< tree2ArgSym.info then - impliesRec1(rhs, lhs.subst(List(tree2ArgSym), List(tree1ArgSym))) - else if tree2ArgSym.info frozen_<:< tree1ArgSym.info then - impliesRec1(rhs.subst(List(tree1ArgSym), List(tree2ArgSym)), lhs) - else - false - case _ => - throw IllegalArgumentException("Qualifiers must be closures") - - private def impliesRec1(tree1: Tree, tree2: Tree): Boolean = - // tree1 = lhs || rhs - tree1 match - case Apply(select @ Select(lhs, name), List(rhs)) => - select.symbol match - case d.Boolean_|| => - return impliesRec1(lhs, tree2) && impliesRec1(rhs, tree2) - case _ => () - case _ => () - // tree2 = lhs && rhs, or tree2 = lhs || rhs - tree2 match - case Apply(select @ Select(lhs, name), List(rhs)) => - select.symbol match - case d.Boolean_&& => - return impliesRec1(tree1, lhs) && impliesRec1(tree1, rhs) - case d.Boolean_|| => - return impliesRec1(tree1, lhs) || impliesRec1(tree1, rhs) - case _ => () + def implies(node1: ENode.Lambda, node2: ENode.Lambda) = + require(node1.paramTps.length == 1) + require(node2.paramTps.length == 1) + val node1Inst = node1.normalizeTypes().asInstanceOf[ENode.Lambda] + val node2Inst = node2.normalizeTypes().asInstanceOf[ENode.Lambda] + val paramTp1 = node1Inst.paramTps.head + val paramTp2 = node2Inst.paramTps.head + if paramTp1 frozen_<:< paramTp2 then + impliesRec(subsParamRefTps(node1Inst.body, node2Inst), node2Inst.body) + else if paramTp2 frozen_<:< paramTp1 then + impliesRec(node1Inst.body, subsParamRefTps(node2Inst.body, node1Inst)) + else + false + + private def subsParamRefTps(node1Body: ENode, node2: ENode.Lambda): ENode = + val paramRefs = node2.paramTps.zipWithIndex.map((tp, i) => ENodeParamRef(i, tp)) + node1Body.substEParamRefs(0, paramRefs) + + private def impliesRec(node1: ENode, node2: ENode): Boolean = + node1 match + case OpApply(Op.Or, List(lhs, rhs)) => + return impliesRec(lhs, node2) && impliesRec(rhs, node2) case _ => () - val egraph = EGraph(ctx) - // println(s"tree implies $tree1 -> $tree2") - (egraph.toNode(QualifierEvaluator.evaluate(tree1)), egraph.toNode(QualifierEvaluator.evaluate(tree2))) match - case (Some(node1), Some(node2)) => - // println(s"node implies $node1 -> $node2") - egraph.merge(node1, egraph.trueNode) - egraph.repair() - egraph.equiv(node2, egraph.trueNode) - case _ => - false - - private def topLevelEqualities(tree: Tree): List[(Tree, Tree)] = - trace(i"topLevelEqualities $tree", Printers.qualifiedTypes): - topLevelEqualitiesImpl(tree) - - private def topLevelEqualitiesImpl(tree: Tree): List[(Tree, Tree)] = - val d = defn - tree match - case Apply(select @ Select(lhs, name), List(rhs)) => - select.symbol match - case d.Int_== | d.Any_== | d.Boolean_== => List((lhs, rhs)) - case d.Boolean_&& => topLevelEqualitiesImpl(lhs) ++ topLevelEqualitiesImpl(rhs) - case _ => Nil - case _ => - Nil + val assumptions = ENode.assumptions(node1) ++ ENode.assumptions(node2) + val node1WithAssumptions = assumptions.foldLeft(node1)((acc, a) => OpApply(Op.And, List(acc, a.normalizeTypes()))) + impliesLeaf(EGraph(ctx), node1WithAssumptions, node2) + + protected def impliesLeaf(egraph: EGraph, enode1: ENode, enode2: ENode): Boolean = + val node1Canonical = egraph.canonicalize(enode1) + val node2Canonical = egraph.canonicalize(enode2) + egraph.assertInvariants() + egraph.merge(node1Canonical, egraph.trueNode) + egraph.repair() + egraph.equiv(node2Canonical, egraph.trueNode) + +final class ExplainingQualifierSolver( + traceIndented: [T] => (String) => (=> T) => T)(using Context) extends QualifierSolver: + + override protected def impliesLeaf(egraph: EGraph, enode1: ENode, enode2: ENode): Boolean = + traceIndented(s"${enode1.showNoBreak} --> ${enode2.showNoBreak}"): + val res = super.impliesLeaf(egraph, enode1, enode2) + if !res then println(egraph.debugString()) + res diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 29baf816da5e..0da6f34795ec 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -26,6 +26,7 @@ import cc.* import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation import dotty.tools.dotc.core.NameKinds.DefaultGetterName import ast.TreeInfo +import dotty.tools.dotc.qualified_types.QualifiedAnnotation object PostTyper { val name: String = "posttyper" @@ -208,11 +209,15 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => } private def transformAnnot(annot: Annotation)(using Context): Annotation = - val tree1 = - annot match - case _: BodyAnnotation => annot.tree - case _ => copySymbols(annot.tree) - annot.derivedAnnotation(transformAnnotTree(tree1)) + annot match + case _: QualifiedAnnotation => + annot + case _ => + val tree1 = + annot match + case _: BodyAnnotation => annot.tree + case _ => copySymbols(annot.tree) + annot.derivedAnnotation(transformAnnotTree(tree1)) /** Transforms all annotations in the given type. */ private def transformAnnotsIn(using Context) = diff --git a/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala b/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala index 8d5b7c28bbbc..9ec91c79abad 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala @@ -19,6 +19,15 @@ object TreeExtractors { } } + /** Match arg.op() and extract (arg, op.symbol) */ + object UnaryOp: + def unapply(t: Tree)(using Context): Option[(Symbol, Tree)] = + t match + case Apply(sel @ Select(arg, _), Nil) => + Some((sel.symbol, arg)) + case _ => + None + /** Match new C(args) and extract (C, args). * Also admit new C(args): T and {new C(args)}. */ diff --git a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala index 8fd3ed4373f4..26364298e41f 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala @@ -354,12 +354,16 @@ object TypeTestsCasts { ref(defn.RuntimeTuples_isInstanceOfNonEmptyTuple).appliedTo(expr) case AppliedType(tref: TypeRef, _) if tref.symbol == defn.PairClass => ref(defn.RuntimeTuples_isInstanceOfNonEmptyTuple).appliedTo(expr) - case QualifiedType(parent, closureDef(qualifierDef)) => - evalOnce(expr): e => - // e.isInstanceOf[baseType] && qualifier(e.asInstanceOf[baseType]) - val arg = e.asInstance(parent) - val qualifierTest = BetaReduce.reduceApplication(qualifierDef, List(List(arg))).get - transformTypeTest(e, parent, flagUnrelated).and(qualifierTest) + case QualifiedType(parent, qualifier) => + qualifier.toTree() match + case closureDef(qualifierDef) => + evalOnce(expr): e => + // e.isInstanceOf[baseType] && qualifier(e.asInstanceOf[baseType]) + val arg = e.asInstance(parent) + val qualifierTest = BetaReduce.reduceApplication(qualifierDef, List(List(arg))).get + transformTypeTest(e, parent, flagUnrelated).and(qualifierTest) + case tree => + throw new IllegalStateException("Malformed qualifier tree: $tree, expected a closure definition") case _ => val testWidened = testType.widen defn.untestableClasses.find(testWidened.isRef(_)) match diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 55b6384c9e89..47e343b9ddc6 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -14,6 +14,7 @@ import Checking.{checkNoPrivateLeaks, checkNoWildcard} import cc.CaptureSet import util.Property import transform.Splicer +import qualified_types.QualifiedType trait TypeAssigner { import tpd.* @@ -572,7 +573,10 @@ trait TypeAssigner { def assignType(tree: untpd.Annotated, arg: Tree, annot: Tree)(using Context): Annotated = { assert(tree.isType) // annotating a term is done via a Typed node, can't use Annotate directly - tree.withType(AnnotatedType(arg.tpe, Annotation(annot))) + if Annotations.annotClass(annot) == defn.QualifiedAnnot then + tree.withType(QualifiedType(arg.tpe, annot)) + else + tree.withType(AnnotatedType(arg.tpe, Annotation(annot))) } def assignType(tree: untpd.PackageDef, pid: Tree)(using Context): PackageDef = diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index f94aedfa3aa8..0df1dfbf04c2 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -48,7 +48,7 @@ import reporting.* import Nullables.* import NullOpsDecorator.* import cc.{CheckCaptures, isRetainsLike} -import qualified_types.QualifiedTypes +import qualified_types.{QualifiedTypes, QualifiedType} import config.Config import config.MigrationVersion import transform.CheckUnused.OriginalName @@ -2581,7 +2581,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // untyped tree is no longer accessed after all // accesses with typedTypeTree are done. case None => - errorTree(tree, em"Something's wrong: missing original symbol for type tree") + errorTree(tree, em"Something's wrong: missing original symbol for type tree ${tree}") } case _ => completeTypeTree(InferredTypeTree(), pt, tree) diff --git a/compiler/test/dotty/tools/dotc/qualified_types/EGraphTest.scala b/compiler/test/dotty/tools/dotc/qualified_types/EGraphTest.scala new file mode 100644 index 000000000000..0162da177cd1 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/qualified_types/EGraphTest.scala @@ -0,0 +1,235 @@ +package dotty.tools.dotc.qualified_types + +import dotty.tools.DottyTest +import dotty.tools.dotc.core.Decorators.i +import dotty.tools.dotc.ast.tpd + +import org.junit.Assert.assertEquals +import org.junit.Test + +class EGraphTest extends QualifiedTypesTest: + + def checkImplies(fromString: String, toString: String, egraphString: String, expected: Boolean = true): Unit = + val src = s""" + |def test = { + | val b1: Boolean = ??? + | val b2: Boolean = ??? + | val b3: Boolean = ??? + | val w: Int = ??? + | val x: Int = ??? + | val y: Int = ??? + | val z: Int = ??? + | def f(a: Int): Boolean = ??? + | def g(a: Int): Int = ??? + | def h(a: Int, b: Int): Int = ??? + | def id[T](a: T): T = a + | type Vec[T] + | type Pos = {v: Int with v > 0} + | def len[T](v: Vec[T]): Pos = ??? + | val v1: Vec[Int] = ??? + | val v2: Vec[Int] = ??? + | val v3: Vec[Int] = ??? + | val from: Boolean = $fromString + | val to: Boolean = $toString + |}""".stripMargin + checkCompileExpr(src): stats => + val testTree = getDefDef(stats, "test") + val body = testTree.rhs.asInstanceOf[tpd.Block] + val fromTree = getValDef(body.stats, "from").rhs + val toTree = getValDef(body.stats, "to").rhs + val egraph = EGraph(ctx, checksEnabled = true) + val from = ENode.fromTree(fromTree).get.normalizeTypes() + val to = ENode.fromTree(toTree).get.normalizeTypes() + val fromCanonical = egraph.canonicalize(from) + val toCanonical = egraph.canonicalize(to) + egraph.merge(fromCanonical, egraph.trueNode) + egraph.repair() + assertStringEquals(egraphString, egraph.debugString()) + val res = egraph.equiv(toCanonical, egraph.trueNode) + assertEquals(s"Expected $fromString --> $toString to be $expected", expected, res) + + def checkNotImplies(fromString: String, toString: String, egraphString: String): Unit = + checkImplies(fromString, toString, egraphString, expected = false) + + @Test def test1() = + checkImplies( + "true", + "true", + """-1: {} + |0: {} + |1: {} + |false: {} + |true: {} + |""".stripMargin + ) + + @Test def test2() = + checkImplies( + "b1", + "b1", + """-1: {} + |0: {} + |1: {} + |false: {} + |true: {b1} + |""".stripMargin + ) + + @Test def test3() = + checkNotImplies( + "b1", + "b2", + """-1: {} + |0: {} + |1: {} + |b2: {} + |false: {} + |true: {b1} + |""".stripMargin + ) + + @Test def test4() = + checkImplies( + "b1 && b2", + "b2", + """-1: {} + |0: {} + |1: {} + |false: {} + |true: {b1, b1 && b2, b2} + |""".stripMargin + ) + + @Test def test5() = + checkNotImplies( + "b1 || b2", + "b2", + """-1: {} + |0: {} + |1: {} + |b1: {} + |b2: {} + |false: {} + |true: {b1 || b2} + |""".stripMargin + ) + + @Test def test6() = + checkImplies( + "b1 && b2 && b3", + "b3", + """-1: {} + |0: {} + |1: {} + |false: {} + |true: {b1, b1 && b2, b1 && b2 && b3, b2, b3} + |""".stripMargin + ) + + @Test def test7() = + checkImplies( + "b1 && b1 == b2", + "b2", + """-1: {} + |0: {} + |1: {} + |false: {} + |true: {b1, b1 && b1 == b2, b1 == b2, b2} + |""".stripMargin + ) + + @Test def test8() = + checkImplies( + "b1 && b1 == b2 && b2 == b3", + "b3", + """-1: {} + |0: {} + |1: {} + |false: {} + |true: {b1, b1 && b1 == b2, b1 && b1 == b2 && b2 == b3, b1 == b2, b2, b2 == b3, b3} + |""".stripMargin + ) + + @Test def test9() = + checkImplies( + "f(x) && x == y", + "f(y)", + """-1: {} + |0: {} + |1: {} + |f: {} + |false: {} + |true: {f(x), f(x) && x == y, f(y), x == y} + |x: {y} + |""".stripMargin + ) + + @Test def nestedFunctions() = + checkImplies( + "f(g(x)) && g(x) == g(y)", + "f(g(y))", + """-1: {} + |0: {} + |1: {} + |f: {} + |false: {} + |g: {} + |g(x): {g(y)} + |true: {f(g(x)), f(g(x)) && g(x) == g(y), f(g(y)), g(x) == g(y)} + |x: {} + |y: {} + |""".stripMargin + ) + + @Test def multipleArgs() = + checkImplies( + "y == z", + "h(x, y) == h(x, z)", + """-1: {} + |0: {} + |1: {} + |false: {} + |h: {} + |h(x, y): {h(x, z)} + |true: {h(x, y) == h(x, z), y == z} + |x: {} + |y: {z} + |""".stripMargin + ) + + @Test def multipleArgsDeep() = + checkImplies( + "f(h(x, y)) && y == z", + "f(h(x, z))", + """-1: {} + |0: {} + |1: {} + |f: {} + |false: {} + |h: {} + |h(x, y): {h(x, z)} + |true: {f(h(x, y)), f(h(x, y)) && y == z, f(h(x, z)), y == z} + |x: {} + |y: {z} + |""".stripMargin + ) + + @Test def sizeSum() = + checkImplies( + "len(v1) == len(v2) + len(v3) && len(v2) == 3 && len(v3) == 4", + "len(v1) == 7", + """-1: {} + |0: {} + |1: {} + |3: {len[Int](v2)} + |4: {len[Int](v3)} + |7: {len[Int](v1), len[Int](v2) + len[Int](v3)} + |false: {} + |len: {} + |len[Int]: {} + |true: {len[Int](v1) == 7, len[Int](v1) == len[Int](v2) + len[Int](v3), len[Int](v1) == len[Int](v2) + len[Int](v3) && len[Int](v2) == 3, len[Int](v1) == len[Int](v2) + len[Int](v3) && len[Int](v2) == 3 && len[Int](v3) == 4, len[Int](v2) + len[Int](v3) == 7, len[Int](v2) == 3, len[Int](v3) == 4} + |v1: {} + |v2: {} + |v3: {} + |""".stripMargin + ) diff --git a/compiler/test/dotty/tools/dotc/qualified_types/ENodeTest.scala b/compiler/test/dotty/tools/dotc/qualified_types/ENodeTest.scala new file mode 100644 index 000000000000..24051d09a718 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/qualified_types/ENodeTest.scala @@ -0,0 +1,60 @@ +package dotty.tools.dotc.qualified_types + +import dotty.tools.DottyTest +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts.Context + +import org.junit.Assert.assertEquals +import org.junit.Test + +class ENodeTest extends QualifiedTypesTest: + + def checkFromToTree(exprString: String, resultString: String): Unit = + checkCompileExpr(s"val v = $exprString"): stats => + val tree1: tpd.Tree = getValDef(stats, "v").rhs + val enode: ENode = ENode.fromTree(tree1).get + assertStringEquals(resultString, enode.show) + val tree2: tpd.Tree = enode.toTree() + assertStringEquals(tree1.show, tree2.show) + + @Test def testFromToTree1() = + checkFromToTree( + "(param0: Int) => param0", + "(_: Int) => eparam0" + ) + + @Test def testFromToTree2() = + checkFromToTree( + "(param0: Int) => param0 + 1", + "(_: Int) => eparam0 + 1" + ) + + @Test def testFromToTree3() = + // ENode.fromTree and ENode#toTree do not perform constant folding or + // normalization. This is only done when adding E-Nodes to an E-Graph. + checkFromToTree( + "(param0: Int) => param0 + 1 + 1", + "(_: Int) => eparam0 + 1 + 1" + ) + + @Test def testFromToTree4() = + checkFromToTree( + "(param0: Int) => (param1: Int) => param0 + param1", + // In De Bruijn notation the outermost parameter is param1 and the + // innermost param0 + "(_: Int) => (_: Int) => eparam1 + eparam0" + ) + + @Test def testFromToTree5() = + checkFromToTree( + "(param0: Int, param1: Int) => param0 + param1", + // Same for paramter lists with multiple parameters: the outermost + // parameter is param1 and the innermost param0 + "(_: Int, _: Int) => eparam1 + eparam0" + ) + + @Test def testFromToTree6() = + checkFromToTree( + "(param0: Int, param1: Int) => (param2: Int, param3: Int) => param0 + param1 + param2 + param3", + "(_: Int, _: Int) => (_: Int, _: Int) => eparam3 + eparam2 + eparam1 + eparam0" + ) diff --git a/compiler/test/dotty/tools/dotc/qualified_types/QualifiedTypesTest.scala b/compiler/test/dotty/tools/dotc/qualified_types/QualifiedTypesTest.scala new file mode 100644 index 000000000000..5d0fbeb6849d --- /dev/null +++ b/compiler/test/dotty/tools/dotc/qualified_types/QualifiedTypesTest.scala @@ -0,0 +1,41 @@ +package dotty.tools.dotc.qualified_types + +import dotty.tools.DottyTest +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts.{Context, FreshContext} +import dotty.tools.dotc.core.Decorators.i + +import org.junit.Assert.assertEquals +import org.junit.runners.MethodSorters +import org.junit.FixMethodOrder + +@FixMethodOrder(MethodSorters.JVM) +abstract class QualifiedTypesTest extends DottyTest: + + override protected def initializeCtx(fc: FreshContext): Unit = + super.initializeCtx(fc) + fc.setSetting(fc.settings.XnoEnrichErrorMessages, true) + fc.setSetting(fc.settings.color, "never") + fc.setSetting(fc.settings.language, List("experimental.qualifiedTypes").asInstanceOf) + + def checkCompileExpr(statsString: String)(assertion: List[tpd.Tree] => Context ?=> Unit): Unit = + checkCompile("typer", s"object Test { $statsString }"): (pkg, context) => + given Context = context + val packageStats = pkg.asInstanceOf[tpd.PackageDef].stats + val clazz = getTypeDef(packageStats, "Test$") + val clazzStats = clazz.rhs.asInstanceOf[tpd.Template].body + assertion(clazzStats)(using context) + + def getTypeDef(trees: List[tpd.Tree], name: String)(using Context): tpd.TypeDef = + trees.collectFirst { case td: tpd.TypeDef if td.name.toString() == name => td }.get + + def getValDef(trees: List[tpd.Tree], name: String)(using Context): tpd.ValDef = + trees.collectFirst { case vd: tpd.ValDef if vd.name.toString() == name => vd }.get + + def getDefDef(trees: List[tpd.Tree], name: String)(using Context): tpd.DefDef = + trees.collectFirst { case vd: tpd.DefDef if vd.name.toString() == name => vd }.get + + def assertStringEquals(expected: String, found: String)(using Context): Unit = + val formattedExpected = if expected.contains('\n') then "\n" + expected.linesIterator.map(" " + _).mkString("\n") else expected + val formattedFound = if found.contains('\n') then "\n" + found.linesIterator.map(" " + _).mkString("\n") else found + assertEquals(s"\n Expected: $formattedExpected\n Found: $formattedFound\n", expected, found) diff --git a/tests/neg-custom-args/qualified-types/list_apply_neg.scala b/tests/neg-custom-args/qualified-types/list_apply_neg.scala new file mode 100644 index 000000000000..1c3eedda0291 --- /dev/null +++ b/tests/neg-custom-args/qualified-types/list_apply_neg.scala @@ -0,0 +1,4 @@ +type PosInt = {v: Int with v > 0} + +@main def Test = + val l: List[PosInt] = List(1,-2,3) // error // error diff --git a/tests/neg-custom-args/qualified-types/subtyping_egraph_state.scala b/tests/neg-custom-args/qualified-types/subtyping_egraph_state.scala new file mode 100644 index 000000000000..a40f6db5331c --- /dev/null +++ b/tests/neg-custom-args/qualified-types/subtyping_egraph_state.scala @@ -0,0 +1,7 @@ +def test: Unit = + val b: Boolean = ??? + val b2: Boolean = ??? + summon[{u: Unit with b && b2} <:< {u: Unit with b}] + // Checks that E-Graph state is reset after the implication check: b is no + // longer true + summon[{u: Unit with true} <:< {u: Unit with b}] // error diff --git a/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala b/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala index b1283f4e807b..9a0e444232b2 100644 --- a/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala +++ b/tests/neg-custom-args/qualified-types/subtyping_objects_neg.scala @@ -20,25 +20,25 @@ case class PersonCaseEqualsOverriden(name: String, age: Int): override def equals(that: Object): Boolean = this eq that def test: Unit = - summon[{b: Box[Int] with b == Box(1)} =:= {b: Box[Int] with b == Box(1)}] // error + summon[{b: Box[Int] with b == Box(1)} =:= {b: Box[Int] with b == Box(1)}] // error // error - summon[{b: BoxMutable[Int] with b == BoxMutable(1)} =:= {b: BoxMutable[Int] with b == BoxMutable(1)}] // error + summon[{b: BoxMutable[Int] with b == BoxMutable(1)} =:= {b: BoxMutable[Int] with b == BoxMutable(1)}] // error // error // TODO(mbovel): restrict selection to stable members //summon[{b: BoxMutable[Int] with b.x == 3} =:= {b: BoxMutable[Int] with b.x == 3}] - summon[{f: Foo with f == Foo("hello")} =:= {f: Foo with f == Foo("hello")}] // error - summon[{f: Foo with f == Foo(1)} =:= {f: Foo with f == Foo(1)}] // error + summon[{f: Foo with f == Foo("hello")} =:= {f: Foo with f == Foo("hello")}] // error // error + summon[{f: Foo with f == Foo(1)} =:= {f: Foo with f == Foo(1)}] // error // error summon[{s: String with Foo("hello").id == s} =:= {s: String with s == "hello"}] // error - summon[{p: Person with p == Person("Alice", 30)} =:= {p: Person with p == Person("Alice", 30)}] // error + summon[{p: Person with p == Person("Alice", 30)} =:= {p: Person with p == Person("Alice", 30)}] // error // error summon[{s: String with Person("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error summon[{n: Int with Person("Alice", 30).age == n} =:= {n: Int with n == 30}] // error - summon[{p: PersonCurried with p == PersonCurried("Alice")(30)} =:= {p: PersonCurried with p == PersonCurried("Alice")(30)}] // error + summon[{p: PersonCurried with p == PersonCurried("Alice")(30)} =:= {p: PersonCurried with p == PersonCurried("Alice")(30)}] // error // error summon[{s: String with PersonCurried("Alice")(30).name == s} =:= {s: String with s == "Alice"}] // error summon[{n: Int with PersonCurried("Alice")(30).age == n} =:= {n: Int with n == 30}] // error - summon[{p: PersonMutable with p == PersonMutable("Alice", 30)} =:= {p: PersonMutable with p == PersonMutable("Alice", 30)}] // error + summon[{p: PersonMutable with p == PersonMutable("Alice", 30)} =:= {p: PersonMutable with p == PersonMutable("Alice", 30)}] // error // error summon[{s: String with PersonMutable("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error summon[{n: Int with PersonMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error @@ -47,6 +47,6 @@ def test: Unit = summon[{s: String with new PersonCaseSecondary("Alice").name == s} =:= {s: String with s == "Alice"}] // error summon[{n: Int with new PersonCaseSecondary("Alice").age == n} =:= {n: Int with n == 0}] // error - summon[{p: PersonCaseEqualsOverriden with PersonCaseEqualsOverriden("Alice", 30) == p} =:= {p: PersonCaseEqualsOverriden with p == PersonCaseEqualsOverriden("Alice", 30)}] // error + summon[{p: PersonCaseEqualsOverriden with PersonCaseEqualsOverriden("Alice", 30) == p} =:= {p: PersonCaseEqualsOverriden with p == PersonCaseEqualsOverriden("Alice", 30)}] // error // error summon[{s: String with PersonCaseEqualsOverriden("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error summon[{n: Int with PersonCaseEqualsOverriden("Alice", 30).age == n} =:= {n: Int with n == 30}] // error diff --git a/tests/pos-custom-args/qualified-types/avoidance.scala b/tests/pos-custom-args/qualified-types/avoidance.scala index b595557924cf..bff6decb61a2 100644 --- a/tests/pos-custom-args/qualified-types/avoidance.scala +++ b/tests/pos-custom-args/qualified-types/avoidance.scala @@ -1,7 +1,6 @@ def Test = () - /* - val x = - val y = 1 - y: {v: Int with v == y} - */ + + //val x = + // val y: Int = ??? + // y: {v: Int with v == y} // TODO(mbovel): proper avoidance for qualified types diff --git a/tests/pos-custom-args/qualified-types/class_constraints.scala b/tests/pos-custom-args/qualified-types/class_constraints.scala new file mode 100644 index 000000000000..71085de63d44 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/class_constraints.scala @@ -0,0 +1,3 @@ +/*class foo(elem: Int with elem > 0)*/ + +@main def Test = () diff --git a/tests/pos-custom-args/qualified-types/list_map.scala b/tests/pos-custom-args/qualified-types/list_map.scala new file mode 100644 index 000000000000..38b4a73af647 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/list_map.scala @@ -0,0 +1,9 @@ +type PosInt = {v: Int with v > 0} + + +def inc(x: PosInt): PosInt = (x + 1).runtimeChecked + +@main def Test = + val l: List[PosInt] = List(1,2,3) + val l2: List[PosInt] = l.map(inc) + () diff --git a/tests/pos-custom-args/qualified-types/sized_lists.scala b/tests/pos-custom-args/qualified-types/sized_lists.scala index 385434c253f7..529079debdde 100644 --- a/tests/pos-custom-args/qualified-types/sized_lists.scala +++ b/tests/pos-custom-args/qualified-types/sized_lists.scala @@ -6,10 +6,6 @@ def concat(v1: Vec, v2: Vec): {v: Vec with size(v) == size(v1) + size(v2)} = ??? def sum(v1: Vec, v2: Vec with size(v1) == size(v2)): {v: Vec with size(v) == size(v1)} = ??? @main def Test = - val v3: {v: Vec with size(v) == 3} = vec(3) val v4: {v: Vec with size(v) == 4} = vec(4) - /* val v7: {v: Vec with size(v) == 7} = concat(v3, v4) - */ - // TODO(mbovel): need constraints of referred term refs diff --git a/tests/pos-custom-args/qualified-types/sized_lists2.scala b/tests/pos-custom-args/qualified-types/sized_lists2.scala new file mode 100644 index 000000000000..5b4b1576bada --- /dev/null +++ b/tests/pos-custom-args/qualified-types/sized_lists2.scala @@ -0,0 +1,22 @@ +type Vec[T] +object Vec: + def fill[T](n: Int, v: T): + {r: Vec[T] with r.len == n} + = ??? +extension [T](a: Vec[T]) + def len: {r: Int with r >= 0} = ??? + def concat(b: Vec[T]): + {r: Vec[T] with r.len == a.len + b.len} + = ??? + def zip[S](b: Vec[S] with b.len == a.len): + {r: Vec[(T, S)] with r.len == a.len} + = ??? + +@main def Test = + val n: Int with n >= 0 = ??? + val m: Int with m >= 0 = ??? + val v1 = Vec.fill(n, 0) + val v2 = Vec.fill(m, 1) + val v3 = v1.concat(v2) + val mPlusN = m + n + val v4: {r: Vec[(String, Int)] with r.len == mPlusN} = Vec.fill(mPlusN, "").zip(v3) diff --git a/tests/pos-custom-args/qualified-types/subtyping_equalities.scala b/tests/pos-custom-args/qualified-types/subtyping_equality.scala similarity index 73% rename from tests/pos-custom-args/qualified-types/subtyping_equalities.scala rename to tests/pos-custom-args/qualified-types/subtyping_equality.scala index 0ea2cc578678..4b4c8d5e8989 100644 --- a/tests/pos-custom-args/qualified-types/subtyping_equalities.scala +++ b/tests/pos-custom-args/qualified-types/subtyping_equality.scala @@ -3,6 +3,9 @@ def g(x: Int): Int = ??? def f2(x: Int, y: Int): Int = ??? def g2(x: Int, y: Int): Int = ??? +case class IntBox(x: Int) +case class Box[T](x: T) + def test: Unit = val a: Int = ??? val b: Int = ??? @@ -33,3 +36,13 @@ def test: Unit = // of the e-graph when `f(a)` and `f(b)` are inserted before `a == b`. summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with f(a) == f(b)}] summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with f(f(a)) == f(f(b))}] + + // Equality is supported on Strings + summon[{v: String with v == "hello"} <:< {v: String with v == "hello"}] + summon[{v: String with v == "hello"} <:< {v: String with "hello" == v}] + + // Equality is supported on case classes + summon[{v: IntBox with v == IntBox(3)} <:< {v: IntBox with v == IntBox(3)}] + summon[{v: IntBox with v == IntBox(3)} <:< {v: IntBox with IntBox(3) == v}] + summon[{v: Box[Int] with v == Box(3)} <:< {v: Box[Int] with v == Box(3)}] + summon[{v: Box[Int] with v == Box(3)} <:< {v: Box[Int] with Box(3) == v}] diff --git a/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala b/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala index 204dde4ac55e..39807c28a840 100644 --- a/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala +++ b/tests/pos-custom-args/qualified-types/subtyping_lambdas.scala @@ -3,15 +3,15 @@ def tp[T](): Any = ??? def test: Unit = - summon[{l: List[Int] with l.forall(x => x > 0)} =:= {l: List[Int] with l.forall(x => x > 0)}] - summon[{l: List[Int] with l.forall(x => x > 0)} =:= {l: List[Int] with l.forall(y => y > 0)}] - summon[{l: List[Int] with l.forall(x => x > 0)} =:= {l: List[Int] with l.forall(_ > 0)}] + val v1: {l: List[Int] with l.forall(x => x > 0)} = ??? : {l: List[Int] with l.forall(x => x > 0)} + val v2: {l: List[Int] with l.forall(x => x > 0)} = ??? : {l: List[Int] with l.forall(y => y > 0)} + val v3: {l: List[Int] with l.forall(x => x > 0)} = ??? : {l: List[Int] with l.forall(_ > 0)} - summon[{l: List[Int] with toBool((x: String) => x.length > 0)} =:= {l: List[Int] with toBool((y: String) => y.length > 0)}] + val v4: {l: List[Int] with toBool((x: String) => x.length > 0)} = ??? : {l: List[Int] with toBool((y: String) => y.length > 0)} - summon[{l: List[Int] with toBool((x: String) => tp[x.type]())} =:= {l: List[Int] with toBool((y: String) => tp[y.type]())}] - summon[{l: List[Int] with toBool((x: String, y: String) => tp[x.type]())} =:= {l: List[Int] with toBool((x: String, y: String) => tp[x.type]())}] - summon[{l: List[Int] with toBool((x: String) => tp[x.type]())} =:= {l: List[Int] with toBool((y: String) => tp[y.type]())}] - summon[{l: List[Int] with toBool((x: String, y: String) => tp[y.type]())} =:= {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())}] + val v5: {l: List[Int] with toBool((x: String) => tp[x.type]())} = ??? : {l: List[Int] with toBool((y: String) => tp[y.type]())} + val v6: {l: List[Int] with toBool((x: String, y: String) => tp[x.type]())} = ??? : {l: List[Int] with toBool((x: String, y: String) => tp[x.type]())} + val v7: {l: List[Int] with toBool((x: String) => tp[x.type]())} = ??? : {l: List[Int] with toBool((y: String) => tp[y.type]())} + val v8: {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())} = ??? : {l: List[Int] with toBool((x: String, y: String) => tp[y.type]())} - summon[{l: List[Int] with toBool((x: String) => (y: String) => x == y)} =:= {l: List[Int] with toBool((a: String) => (b: String) => a == b)}] + val v9: {l: List[Int] with toBool((x: String) => (y: String) => x == y)} = ??? : {l: List[Int] with toBool((a: String) => (b: String) => a == b)} diff --git a/tests/pos-custom-args/qualified-types/subtyping_reflectivity.scala b/tests/pos-custom-args/qualified-types/subtyping_reflectivity.scala new file mode 100644 index 000000000000..8cace392c1f6 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/subtyping_reflectivity.scala @@ -0,0 +1,4 @@ +@main def Test = + val n: Int = ??? + summon[{v: Int with v == 2} <:< {v: Int with v == 2}] + summon[{v: Int with v == n} <:< {v: Int with v == n}] diff --git a/tests/pos-custom-args/qualified-types/subtyping_singletons.scala b/tests/pos-custom-args/qualified-types/subtyping_singletons.scala index 984674da225c..f2f83e8dd8c4 100644 --- a/tests/pos-custom-args/qualified-types/subtyping_singletons.scala +++ b/tests/pos-custom-args/qualified-types/subtyping_singletons.scala @@ -2,9 +2,7 @@ type Pos = {v: Int with v > 0} def test: Unit = val x: Int = ??? - val one: Int = 1 summon[1 <:< {v: Int with v == 1}] summon[1 <:< {v: Int with v > 0}] summon[1 <:< Pos] summon[x.type <:< {v: Int with v == x}] - summon[one.type <:< {v: Int with v == 1}] diff --git a/tests/pos-custom-args/qualified-types/typing_type_variables.scala b/tests/pos-custom-args/qualified-types/typing_type_variables.scala new file mode 100644 index 000000000000..647ec1203af7 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/typing_type_variables.scala @@ -0,0 +1,9 @@ +class Box[T](val value: T) +def makeBox[T](x: T): Box[T] with true = Box(x) + +def test: Unit = + val box1 = Box(3) + val box1b: Box[Int] with true = Box(3) + + val box2 = makeBox(3) + val box2b: Box[Int] with true = makeBox(3) diff --git a/tests/printing/qualified-types.check b/tests/printing/qualified-types.check index 29ac7b28ea21..2b5d798060d3 100644 --- a/tests/printing/qualified-types.check +++ b/tests/printing/qualified-types.check @@ -13,20 +13,13 @@ package example { type Pos2 = Int @qualified[Int]((x: Int) => x > 0) type Pos3 = Int @qualified[Int]((x: Int) => x > 0) type Pos4 = Int @qualified[Int]((x: Int) => x > 0) - type Pos5 = - Int @qualified[Int]((x: Int) => - { - val res: Boolean = x > 0 - res:Boolean - } - ) + type Pos5 = Int @qualified[Int]((x: Int) => (_$1: Int) => (x > 0).apply(42)) type UninhabitedInt = Int @qualified[Int]((_: Int) => false) + def id[T >: Nothing <: Any](x: T): T = x type Nested = Int @qualified[Int]((x: Int) => - { - val y: Int @qualified[Int]((z: Int) => z > 0) = ??? - x > y - } + example.id[Boolean @qualified[Boolean]((b: Boolean) => b == x > 42)]( + (x > 42).$asInstanceOf[Boolean with x > 42 == eparam0]) ) type Intersection = Int & Int @qualified[Int]((x: Int) => x > 0) type ValRefinement = @@ -34,7 +27,6 @@ package example { { val x: Int @qualified[Int]((x: Int) => x > 0) } - def id[T >: Nothing <: Any](x: T): T = x def test(): Unit = { val x: example.Pos = 1 diff --git a/tests/printing/qualified-types.scala b/tests/printing/qualified-types.scala index bfa4dad9d0b2..97314ac8af77 100644 --- a/tests/printing/qualified-types.scala +++ b/tests/printing/qualified-types.scala @@ -11,18 +11,17 @@ type Pos3 = {x: Int with type Pos4 = {x: Int with x > 0} type Pos5 = {x: Int with - val res = x > 0 - res + ((_: Int) => x > 0)(42) } type UninhabitedInt = Int with false -type Nested = {x: Int with { val y: {z: Int with z > 0} = ??? ; x > y }} +def id[T](x: T): T = x + +type Nested = {x: Int with id[{b: Boolean with b == x > 42}](x > 42) } type Intersection = Int & {x: Int with x > 0} type ValRefinement = {val x: Int with x > 0} -def id[T](x: T): T = x - def test() = val x: Pos = 1 val x2: {x: Int with x > 0} = 1 diff --git a/tests/run-custom-args/qualified-types/pattern_matching.scala b/tests/run-custom-args/qualified-types/pattern_matching.scala index 43dc87b9dbe9..2e1229c75296 100644 --- a/tests/run-custom-args/qualified-types/pattern_matching.scala +++ b/tests/run-custom-args/qualified-types/pattern_matching.scala @@ -1,5 +1,4 @@ type Pos = {x: Int with x > 0} - type NonEmptyString = {s: String with !s.isEmpty} type PoliteString = {s: NonEmptyString with s.head.isUpper && s.takeRight(6) == "please"} @@ -7,14 +6,6 @@ def id[T](x: T): T = println("call id") x -def rec(x: NonEmptyString): List[Char] = - val rest = - x.tail match - case xs: NonEmptyString => rec(xs) - case _ => Nil - - x.head :: rest - @main def Test = for v <- List[Any](-1, 1, 2, "", "Do it please", "do it already", false, null) do val vStr = From b46413f952e55c988fb437441ca93dbfdbfb279a Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Tue, 30 Sep 2025 11:35:04 +0000 Subject: [PATCH 14/20] Re-add `qualifiedTypes` in `object language` --- library/src/scala/language.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/src/scala/language.scala b/library/src/scala/language.scala index 5f39ce2017e4..3fc006149999 100644 --- a/library/src/scala/language.scala +++ b/library/src/scala/language.scala @@ -293,6 +293,10 @@ object language { @compileTimeOnly("`separationChecking` can only be used at compile time in import statements") object separationChecking + /** Experimental support for qualified types */ + @compileTimeOnly("`qualifiedTypes` is only be used at compile time") + object qualifiedTypes + /** Experimental support for automatic conversions of arguments, without requiring * a language import `import scala.language.implicitConversions`. * From 24eeb587d141ae4085f64ca870b5889e74101bd7 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Sun, 12 Oct 2025 14:03:12 +0000 Subject: [PATCH 15/20] Fix Param references equality --- .../qualified_types/QualifierSolver.scala | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala index 10792bb5cf3a..dc988c0f2463 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala @@ -8,26 +8,25 @@ import dotty.tools.dotc.core.Symbols.defn import dotty.tools.dotc.core.Types.{Type, TypeVar, TypeMap} import dotty.tools.dotc.core.Decorators.i import dotty.tools.dotc.printing.Showable +import dotty.tools.dotc.reporting.trace class QualifierSolver(using Context): - def implies(node1: ENode.Lambda, node2: ENode.Lambda) = - require(node1.paramTps.length == 1) - require(node2.paramTps.length == 1) - val node1Inst = node1.normalizeTypes().asInstanceOf[ENode.Lambda] - val node2Inst = node2.normalizeTypes().asInstanceOf[ENode.Lambda] - val paramTp1 = node1Inst.paramTps.head - val paramTp2 = node2Inst.paramTps.head - if paramTp1 frozen_<:< paramTp2 then - impliesRec(subsParamRefTps(node1Inst.body, node2Inst), node2Inst.body) - else if paramTp2 frozen_<:< paramTp1 then - impliesRec(node1Inst.body, subsParamRefTps(node2Inst.body, node1Inst)) - else - false - - private def subsParamRefTps(node1Body: ENode, node2: ENode.Lambda): ENode = - val paramRefs = node2.paramTps.zipWithIndex.map((tp, i) => ENodeParamRef(i, tp)) - node1Body.substEParamRefs(0, paramRefs) + def implies(node1: ENode.Lambda, node2: ENode.Lambda): Boolean = + trace(i"implie ${node1.showNoBreak} --> ${node2.showNoBreak}", Printers.qualifiedTypes): + require(node1.paramTps.length == 1) + require(node2.paramTps.length == 1) + val node1Inst = node1.normalizeTypes().asInstanceOf[ENode.Lambda] + val node2Inst = node2.normalizeTypes().asInstanceOf[ENode.Lambda] + val paramTp1 = node1Inst.paramTps.head + val paramTp2 = node2Inst.paramTps.head + if paramTp1 frozen_<:< paramTp2 then impliesCommonParams(node1Inst, node2Inst, node1Inst) + else if paramTp2 frozen_<:< paramTp1 then impliesCommonParams(node1Inst, node2Inst, node2Inst) + else false + + private def impliesCommonParams(node1: ENode.Lambda, node2: ENode.Lambda, mostPreciseNode: ENode.Lambda): Boolean = + val paramRefs = mostPreciseNode.paramTps.zipWithIndex.map((tp, i) => ENodeParamRef(i, tp)) + impliesRec(node1.body.substEParamRefs(0, paramRefs), node2.body.substEParamRefs(0, paramRefs)) private def impliesRec(node1: ENode, node2: ENode): Boolean = node1 match @@ -42,10 +41,11 @@ class QualifierSolver(using Context): protected def impliesLeaf(egraph: EGraph, enode1: ENode, enode2: ENode): Boolean = val node1Canonical = egraph.canonicalize(enode1) val node2Canonical = egraph.canonicalize(enode2) - egraph.assertInvariants() - egraph.merge(node1Canonical, egraph.trueNode) - egraph.repair() - egraph.equiv(node2Canonical, egraph.trueNode) + trace(i"impliesLeaf ${node1Canonical.showNoBreak} --> ${node2Canonical.showNoBreak}", Printers.qualifiedTypes): + egraph.assertInvariants() + egraph.merge(node1Canonical, egraph.trueNode) + egraph.repair() + egraph.equiv(node2Canonical, egraph.trueNode) final class ExplainingQualifierSolver( traceIndented: [T] => (String) => (=> T) => T)(using Context) extends QualifierSolver: @@ -53,5 +53,5 @@ final class ExplainingQualifierSolver( override protected def impliesLeaf(egraph: EGraph, enode1: ENode, enode2: ENode): Boolean = traceIndented(s"${enode1.showNoBreak} --> ${enode2.showNoBreak}"): val res = super.impliesLeaf(egraph, enode1, enode2) - if !res then println(egraph.debugString()) + //if !res then println(egraph.debugString()) res From cf00bffa3acfd748c8c864459f0e4f24dc6dda8f Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Sun, 12 Oct 2025 14:10:03 +0000 Subject: [PATCH 16/20] Handle `QualifiedAnnotation` in `TypeAccumulator` --- compiler/src/dotty/tools/dotc/core/Types.scala | 8 ++++++-- compiler/src/dotty/tools/dotc/qualified_types/ENode.scala | 8 ++++---- .../tools/dotc/qualified_types/QualifiedAnnotation.scala | 6 ++++++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 468a8b0009a0..ef3ae67670ed 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -43,7 +43,7 @@ import cc.* import CaptureSet.IdentityCaptRefMap import Capabilities.* import transform.Recheck.currentRechecker -import qualified_types.QualifiedType +import qualified_types.{QualifiedType, QualifiedAnnotation} import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -5018,6 +5018,7 @@ object Types extends TypeUtils { private var myRepr: Name | Null = null def repr(using Context): Name = { + //if (myRepr == null) myRepr = s"?$id".toString.toTermName if (myRepr == null) myRepr = SkolemName.fresh() myRepr.nn } @@ -6944,7 +6945,10 @@ object Types extends TypeUtils { def apply(x: T, tp: Type): T - protected def applyToAnnot(x: T, annot: Annotation): T = x // don't go into annotations + protected def applyToAnnot(x: T, annot: Annotation): T = + annot match + case annot: QualifiedAnnotation => annot.foldOverTypes(x, this) + case _ => x // don't go into other annotations /** A prefix is never contravariant. Even if say `p.A` is used in a contravariant * context, we cannot assume contravariance for `p` because `p`'s lower diff --git a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala index 7b88b6e2b9e6..058c5f9988bf 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala @@ -249,7 +249,8 @@ enum ENode extends Showable: body.foreachType(f) def normalizeTypes()(using Context): ENode = - mapTypes(NormalizeMap()) + trace(i"normalizeTypes($this)", Printers.qualifiedTypes): + mapTypes(NormalizeMap()) private class NormalizeMap(using Context) extends TypeMap: def apply(tp: Type): Type = @@ -257,7 +258,7 @@ enum ENode extends Showable: case tp: TypeVar if tp.isPermanentlyInstantiated => apply(tp.permanentInst) case tp: NamedType => - val dealiased = tp.dealias + val dealiased = tp.dealiasKeepAnnotsAndOpaques if dealiased ne tp then apply(dealiased) else if tp.symbol.isStatic then @@ -547,7 +548,7 @@ object ENode: def assumptions(node: ENode)(using Context): List[ENode] = trace(i"assumptions($node)", Printers.qualifiedTypes): node match - case Atom(tp: SingletonType) => termAssumptions(tp) ++ typeAssumptions(tp) + case n: Atom => termAssumptions(n.tp) ++ typeAssumptions(n.tp) case n: Constructor => Nil case n: Select => assumptions(n.qual) case n: Apply => assumptions(n.fn) ++ n.args.flatMap(assumptions) @@ -560,7 +561,6 @@ object ENode: tp match case tp: TermRef => tp.symbol.info match - case QualifiedType(_, _) => Nil case _ => tp.symbol.defTree match case valDef: tpd.ValDef if !valDef.rhs.isEmpty && !valDef.symbol.is(Flags.Lazy) => diff --git a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala index 83738a188ced..b651e764a10d 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/QualifiedAnnotation.scala @@ -37,3 +37,9 @@ case class QualifiedAnnotation(qualifier: ENode.Lambda) extends Annotation: case TermParamRef(tl1, _) if tl eq tl1 => res = true case _ => () res + + def foldOverTypes[A](z: A, f: (A, Type) => A)(using Context): A = + var acc = z + qualifier.foreachType: tp => + acc = f(acc, tp) + acc From a76b30296ead7699db79131a711fecf5f8b0f777 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Tue, 14 Oct 2025 03:20:10 +0000 Subject: [PATCH 17/20] Add Scala Workshop 25 examples --- .../qualified-types/list_collect_neg.scala | 5 + .../qualified-types/scala-workshop-25.scala | 185 ++++++++++++++++++ .../qualified-types/implicits.scala | 17 ++ .../qualified-types/list_collect.scala | 5 + .../qualified-types/list_map.scala | 10 +- .../qualified-types/matrices.scala | 21 ++ 6 files changed, 238 insertions(+), 5 deletions(-) create mode 100644 tests/neg-custom-args/qualified-types/list_collect_neg.scala create mode 100644 tests/neg-custom-args/qualified-types/scala-workshop-25.scala create mode 100644 tests/pos-custom-args/qualified-types/implicits.scala create mode 100644 tests/pos-custom-args/qualified-types/list_collect.scala create mode 100644 tests/pos-custom-args/qualified-types/matrices.scala diff --git a/tests/neg-custom-args/qualified-types/list_collect_neg.scala b/tests/neg-custom-args/qualified-types/list_collect_neg.scala new file mode 100644 index 000000000000..3d5afabfc22a --- /dev/null +++ b/tests/neg-custom-args/qualified-types/list_collect_neg.scala @@ -0,0 +1,5 @@ +type Pos = { v: Int with v >= 0 } + +@main def main = + val xs = List(-1,2,-2,1) + xs.collect { case x: Int => x } : List[Pos] // error diff --git a/tests/neg-custom-args/qualified-types/scala-workshop-25.scala b/tests/neg-custom-args/qualified-types/scala-workshop-25.scala new file mode 100644 index 000000000000..864e8db6aa5a --- /dev/null +++ b/tests/neg-custom-args/qualified-types/scala-workshop-25.scala @@ -0,0 +1,185 @@ +@main def main = + + // Specification using qualified types + { + def zip[A, B](as: List[A], bs: List[B] with bs.size == as.size): + {l: List[(A, B)] with l.size == as.size} + = ??? + + def concat[T](as: List[T], bs: List[T]): + {rs: List[T] with rs.size == as.size + bs.size} + = ??? + + val xs: List[Int] = ??? + val ys: List[Int] = ??? + zip(concat(xs, ys), concat(ys, xs)) + zip(concat(xs, ys), concat(xs, xs)) // error + } + + // Syntax + { + { + type NonEmptyList[A] = { l: List[A] with l.nonEmpty } + + // Not to be confused with structural types + case class Box(value: Any) + type IntBox = Box { val value: Int } + } + + // Shortand + { + def zip[A, B](as: List[A], bs: List[B] with bs.size == as.size) = ??? + } + } + + // Valid/invalid predicates + { + { + var x = 3 + val y: Int with y == 3 = x // error: ⛔️ x is mutable + } + + { + val x = 3 + val y: Int with y == 3 = x // okay + } + + { + class Box(val value: Int) + val b: Box with b == Box(3) = Box(3) // error: ⛔️ Box has equality by reference + } + + { + case class Box(value: Int) + val b: Box with b == Box(3) = Box(3) // okay + } + } + + // Selfification + { + val x: Int = ??? + val y: Int with (y == x + 1) = x + 1 + + def f(x: Int): Int = ??? + val z: Int with (z == x + f(x)) = x + f(x) + } + + // Runtime checks + { + val idRegex = "^[a-zA-Z_][a-zA-Z0-9_]*$" + type ID = {s: String with s.matches(idRegex)} + + { + "a2e7-e89b" match + case _: ID => // matched + case _ => // didn't match + } + + { + val id: ID = "a2e7-e89b".runtimeChecked + } + + { + val id: ID = + if ("a2e7-e89b".matches(idRegex)) "a2e7-e89b".asInstanceOf[ID] + else throw new IllegalArgumentException() + } + + { + type Pos = { v: Int with v >= 0 } + + val xs = List(-1,2,-2,1) + xs.collect { case x: Pos => x } : List[Pos] + } + } + + // Subtyping + { + { + val x: Int = ??? + val y: Int = ??? + summon[{v: Int with v == 1 + 1} =:= {v: Int with v == 2}] + summon[{v: Int with v == x + 1} =:= {v: Int with v == 1 + x}] + summon[{v: Int with v == y + x} =:= {v: Int with v == x + y}] + summon[{v: Int with v == x + 3 * y} =:= {v: Int with v == 2 * y + (x + y)}] + } + + { + val x: Int = ??? + val y: Int = x + 1 + summon[{v: Int with v == y} =:= {v: Int with v == x + 1}] + } + + { + val a: Int = ??? + val b: Int = ??? + summon[{v: Int with v == a && a == b} <:< {v: Int with v == b}] + def f(x: Int): Int = ??? + summon[{v: Int with a == b} <:< {v: Int with f(a) == f(b)}] + } + + { + summon[3 <:< {v: Int with v == 3}] + } + } + + // Backup slides about type-level programming with existing Scala features + { + def checkSame(dimA: Int, dimB: dimA.type): Unit = () + checkSame(3, 3) // ok + checkSame(3, 4) // error + + { + val x = 3 + val y = 3 + checkSame(x, y) // error + } + + { + val x2: 3 = 3 + val y3: 3 = 3 + checkSame(x2, y3) // ok + } + + def readInt(): Int = ??? + + { + val x: Int = readInt() + val y = x + val z = y + checkSame(y, z) // error + } + + { + val x: Int = readInt() + val y: x.type = x + val z: x.type = x + checkSame(y, z) // okay + } + + + { + val x: Int = readInt() + val y: Int = readInt() + val z = x + y + val a = y + x + checkSame(z, a) // error + } + + { + import scala.compiletime.ops.int.+ + val x: 3 = 3 + val y: 5 = 5 + val z: x.type + y.type = x + y + val a: y.type + x.type = y + x + } + + { + import scala.compiletime.ops.int.+ + val x: Int = readInt() + val y: Int = readInt() + val z: x.type + y.type = x + y // error + val a: y.type + x.type = y + x // error + checkSame(z, a) // error + } + } diff --git a/tests/pos-custom-args/qualified-types/implicits.scala b/tests/pos-custom-args/qualified-types/implicits.scala new file mode 100644 index 000000000000..abadb5fb8129 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/implicits.scala @@ -0,0 +1,17 @@ +type Pos = { v: Int with v >= 0 } +type Neg = { v: Int with v < 0 } + +trait Show[-A]: + def apply(a: A): String + +given show1: Show[Pos] with + def apply(a: Pos): String = "I am a positive integer!" + +given show2: Show[Neg] with + def apply(a: Neg): String = "I am a negative integer!" + +def show[A](a: A)(using s: Show[A]): String = s.apply(a) + +def f(x: Int with x == 42, y: Int with y == -42): Unit = + println(show(x)) // I am a positive integer! + println(show(y)) // I am a negative integer! diff --git a/tests/pos-custom-args/qualified-types/list_collect.scala b/tests/pos-custom-args/qualified-types/list_collect.scala new file mode 100644 index 000000000000..e1291ec349d7 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/list_collect.scala @@ -0,0 +1,5 @@ +type Pos = { v: Int with v >= 0 } + +@main def main = + val xs = List(-1,2,-2,1) + xs.collect { case x: Pos => x } : List[Pos] diff --git a/tests/pos-custom-args/qualified-types/list_map.scala b/tests/pos-custom-args/qualified-types/list_map.scala index 38b4a73af647..2d2478b4fb31 100644 --- a/tests/pos-custom-args/qualified-types/list_map.scala +++ b/tests/pos-custom-args/qualified-types/list_map.scala @@ -1,9 +1,9 @@ -type PosInt = {v: Int with v > 0} +type Pos = {v: Int with v > 0} - -def inc(x: PosInt): PosInt = (x + 1).runtimeChecked +def inc(x: Pos): Pos = (x + 1).runtimeChecked @main def Test = - val l: List[PosInt] = List(1,2,3) - val l2: List[PosInt] = l.map(inc) + val l: List[Pos] = List(1,2,3) + val l2 = l.map(inc) + l2: List[Pos] () diff --git a/tests/pos-custom-args/qualified-types/matrices.scala b/tests/pos-custom-args/qualified-types/matrices.scala new file mode 100644 index 000000000000..51aa52ecfa46 --- /dev/null +++ b/tests/pos-custom-args/qualified-types/matrices.scala @@ -0,0 +1,21 @@ +type Pos = {v: Int with v >= 0} +type Matrix[T] = List[List[T]] + +def length[T](v: List[T]): Pos = + v.length.runtimeChecked + +def width[T](m: Matrix[T]): Pos = + if m.isEmpty then 0 else length(m.head) + +def height[T](m: Matrix[T]): Pos = + length(m) + +def tabulate[T](rows: Pos, cols: Pos, f: (r: Pos with r <= rows, c: Pos with c <= cols) => T) + : {r: Matrix[T] with width(r) == cols && height(r) == rows} = + List.tabulate(rows, cols)((r, c) => f(r.runtimeChecked, c.runtimeChecked)).runtimeChecked + +def transpose[T](m: Matrix[T]): {r: Matrix[T] with width(r) == height(m) && height(r) == width(m)} = + val newWidth = height(m) + val newHeight = width(m) + tabulate(newHeight, newWidth, (r, c) => m(r)(c)) + From c38b6b0c088f4a0b41618bdfe5822549586fb05b Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 30 Oct 2025 12:16:01 +0000 Subject: [PATCH 18/20] Fix desugaring of fields with qualified types --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 12 ++++++++++-- .../qualified-types/class_constraints.scala | 6 +++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index c028dd896bb7..83e28edccea0 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -213,7 +213,11 @@ object desugar { def valDef(vdef0: ValDef)(using Context): Tree = val vdef @ ValDef(_, tpt, rhs) = vdef0 val valName = normalizeName(vdef, tpt).asTermName - val tpt1 = desugarQualifiedTypes(tpt, valName) + val tpt1 = + if Feature.qualifiedTypesEnabled then + desugarQualifiedTypes(tpt, valName) + else + tpt var mods1 = vdef.mods val vdef1 = cpy.ValDef(vdef)(name = valName, tpt = tpt1).withMods(mods1) @@ -747,7 +751,11 @@ object desugar { report.error(CaseClassMissingNonImplicitParamList(cdef), namePos) ListOfNil } - else originalVparamss.nestedMap(toMethParam(_, KeepAnnotations.All, keepDefault = true)) + else + originalVparamss.nestedMap: param => + val methParam = toMethParam(param, KeepAnnotations.All, keepDefault = true) + valDef(methParam).asInstanceOf[ValDef] // desugar early to handle qualified types + val derivedTparams = constrTparams.zipWithConserve(impliedTparams)((tparam, impliedParam) => derivedTypeParam(tparam).withAnnotations(impliedParam.mods.annotations)) diff --git a/tests/pos-custom-args/qualified-types/class_constraints.scala b/tests/pos-custom-args/qualified-types/class_constraints.scala index 71085de63d44..f8dcfeee6124 100644 --- a/tests/pos-custom-args/qualified-types/class_constraints.scala +++ b/tests/pos-custom-args/qualified-types/class_constraints.scala @@ -1,3 +1,7 @@ -/*class foo(elem: Int with elem > 0)*/ +class foo(elem: Int with elem > 0) + +class Multi(x: Int with x > 0)(y: String with y.length > 2, z: Double with z >= 0.0) + +case class MultiCase(x: Int with x > 0)(val y: String with y.length > 2, val z: Double with z >= 0.0) @main def Test = () From e40a37111e576a86e72e030c37efb91bd6f0055c Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 30 Oct 2025 14:29:47 +0000 Subject: [PATCH 19/20] Fix handling of `null` in qualifiers --- compiler/src/dotty/tools/dotc/qualified_types/ENode.scala | 2 ++ tests/pos-custom-args/qualified-types/class_constraints.scala | 2 ++ 2 files changed, 4 insertions(+) diff --git a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala index 058c5f9988bf..ea144122ec1d 100644 --- a/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala +++ b/compiler/src/dotty/tools/dotc/qualified_types/ENode.scala @@ -461,6 +461,8 @@ object ENode: case tpd.Literal(_) | tpd.Ident(_) | tpd.This(_) if tree.tpe.isInstanceOf[SingletonType] && tpd.isIdempotentExpr(tree) => Some(Atom(substParamRefs(tree.tpe, paramSyms, paramTps).asInstanceOf[SingletonType])) + case tpd.Literal(Constant(null)) => // null does not have a SingletonType + Some(Atom(ConstantType(Constant(null)))) case tpd.Select(tpd.New(_), nme.CONSTRUCTOR) => constructorNode(tree.symbol) case tree: tpd.Select if isCaseClassApply(tree.symbol) => diff --git a/tests/pos-custom-args/qualified-types/class_constraints.scala b/tests/pos-custom-args/qualified-types/class_constraints.scala index f8dcfeee6124..53d466b8bfe6 100644 --- a/tests/pos-custom-args/qualified-types/class_constraints.scala +++ b/tests/pos-custom-args/qualified-types/class_constraints.scala @@ -4,4 +4,6 @@ class Multi(x: Int with x > 0)(y: String with y.length > 2, z: Double with z >= case class MultiCase(x: Int with x > 0)(val y: String with y.length > 2, val z: Double with z >= 0.0) +class TParam[T](x: T with x != null) + @main def Test = () From 9d6121eed261bacf4627cae1ae5c97d6f3ac40af Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Tue, 11 Nov 2025 16:15:53 +0000 Subject: [PATCH 20/20] Add mergeSort example --- .../qualified-types/mergeSort.check | 4 ++ .../qualified-types/mergeSort.scala | 45 +++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 tests/run-custom-args/qualified-types/mergeSort.check create mode 100644 tests/run-custom-args/qualified-types/mergeSort.scala diff --git a/tests/run-custom-args/qualified-types/mergeSort.check b/tests/run-custom-args/qualified-types/mergeSort.check new file mode 100644 index 000000000000..2a0f91aee9ed --- /dev/null +++ b/tests/run-custom-args/qualified-types/mergeSort.check @@ -0,0 +1,4 @@ +Unsorted: ArraySeq(5, 3, 8, 1, 2, 7, 4, 6) +Sorted: ArraySeq(1, 2, 3, 4, 5, 6, 7, 8) +Unsorted: ArraySeq(7, 4, 5, 3, 2, 6, 1) +Sorted: ArraySeq(1, 2, 3, 4, 5, 6, 7) diff --git a/tests/run-custom-args/qualified-types/mergeSort.scala b/tests/run-custom-args/qualified-types/mergeSort.scala new file mode 100644 index 000000000000..147e8e82ca40 --- /dev/null +++ b/tests/run-custom-args/qualified-types/mergeSort.scala @@ -0,0 +1,45 @@ +type Pos = {x: Int with x >= 0} + +def safeDiv(x: Pos, y: Pos with y > 1): {res: Pos with res < x} = + (x / y).asInstanceOf[{res: Pos with res < x}] + +object SafeSeqs: + opaque type SafeSeq[T] = Seq[T] + object SafeSeq: + def fromSeq[T](seq: Seq[T]): SafeSeq[T] = seq + def apply[T](elems: T*): SafeSeq[T] = fromSeq(elems) + extension [T](a: SafeSeq[T]) + def len: Pos = a.length.runtimeChecked + def apply(i: Pos with i < a.len): T = a(i) + def splitAt(i: Pos with i < a.len): (SafeSeq[T], SafeSeq[T]) = a.splitAt(i) + def head: T = a(0) + def tail: SafeSeq[T] = a.tail + def ++(that: SafeSeq[T]): SafeSeq[T] = a ++ that + +import SafeSeqs.* + +def merge[T: Ordering as ord](left: SafeSeq[T], right: SafeSeq[T]): SafeSeq[T] = + if left.len == 0 then right + else if right.len == 0 then left + else + if ord.lt(left.head, right.head) then SafeSeq(left.head) ++ merge(left.tail, right) + else SafeSeq(right.head) ++ merge(left, right.tail) + +def mergeSort[T: Ordering](list: SafeSeq[T]): SafeSeq[T] = + val len = list.len + val middle = safeDiv(len, 2) + if middle == 0 then list + else + val (left, right) = list.splitAt(middle) + merge(mergeSort(left), mergeSort(right)) + +@main def Test = + val nums = SafeSeq(5, 3, 8, 1, 2, 7, 4, 6) + val sortedNums = mergeSort(nums) + println(s"Unsorted: $nums") + println(s"Sorted: $sortedNums") + + val nums2 = SafeSeq(7, 4, 5, 3, 2, 6, 1) + val sortedNums2 = mergeSort(nums2) + println(s"Unsorted: $nums2") + println(s"Sorted: $sortedNums2")