Skip to content
This repository was archived by the owner on Sep 3, 2020. It is now read-only.

Commit 233cbdc

Browse files
committed
Merge pull request #74 from mlangc/various-simple-bugfixes
Treat locally defined types properly
2 parents 833712b + 10cc938 commit 233cbdc

File tree

3 files changed

+96
-13
lines changed

3 files changed

+96
-13
lines changed

org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/analysis/CompilationUnitDependencies.scala

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
package scala.tools.refactoring
66
package analysis
77

8-
trait CompilationUnitDependencies {
8+
import scala.tools.refactoring.common.CompilerApiExtensions
9+
10+
trait CompilationUnitDependencies extends CompilerApiExtensions {
911
// we need to interactive compiler because we work with RangePositions
1012
this: common.InteractiveScalaCompiler with common.TreeTraverser with common.TreeExtractors with common.PimpedTrees =>
1113

@@ -72,19 +74,9 @@ trait CompilationUnitDependencies {
7274
*/
7375
def neededImports(t: Tree): List[Select] = {
7476

75-
/**
76-
* Check if the definition is also a child of the outer `t`. In that case, we don't need
77-
* to add an import because the dependency is to a local definition.
78-
*/
79-
def isLocalDefinition(dependency: Tree) = t.exists {
80-
case t: DefTree => dependency.symbol == t.symbol
81-
case _ => false
82-
}
83-
8477
val deps = dependencies(t)
8578

8679
val neededDependencies = deps.flatMap {
87-
case t if isLocalDefinition(t) => None
8880
case t: Select if !t.pos.isRange => Some(t)
8981
case t => findDeepestNeededSelect(t)
9082
}.filter(isImportReallyNeeded).distinct
@@ -100,6 +92,20 @@ trait CompilationUnitDependencies {
10092
* because they are defined in the same compilation unit.
10193
*/
10294
def dependencies(t: Tree): List[Select] = {
95+
val wholeTree = t
96+
97+
def qualifierIsEnclosingPackage(t: Select) = {
98+
enclosingPackage(wholeTree, t.pos) match {
99+
case pkgDef: PackageDef =>
100+
t.qualifier.nameString == pkgDef.nameString
101+
case _ => false
102+
}
103+
}
104+
105+
def isDefinedLocally(t: Tree) = wholeTree.exists {
106+
case defTree: DefTree if t.symbol == defTree.symbol => true
107+
case _ => false
108+
}
103109

104110
val result = new collection.mutable.HashMap[String, Select]
105111

@@ -173,7 +179,6 @@ trait CompilationUnitDependencies {
173179
val language = newTermName("language")
174180

175181
override def traverse(tree: Tree) = tree match {
176-
177182
// Always add the SIP 18 language imports as required until we can handle them properly
178183
case Import(select @ Select(Ident(nme.scala_), `language`), feature) =>
179184
feature foreach (selector => addToResult(Select(select, selector.name)))
@@ -240,7 +245,8 @@ trait CompilationUnitDependencies {
240245
if (!isMethodCallFromExplicitReceiver(t)
241246
&& !isSelectFromInvisibleThis(qual)
242247
&& t.name != nme.WILDCARD
243-
&& hasStableQualifier(t)) {
248+
&& hasStableQualifier(t)
249+
&& !(isDefinedLocally(t) && qualifierIsEnclosingPackage(t))) {
244250
addToResult(t)
245251
}
246252

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package scala.tools.refactoring.common
2+
3+
import scala.tools.nsc.Global
4+
5+
/*
6+
* FIXME: This class duplicates functionality from org.scalaide.core.compiler.CompilerApiExtensions.
7+
*/
8+
trait CompilerApiExtensions {
9+
this: CompilerAccess =>
10+
import global._
11+
12+
/** Locate the smallest tree that encloses position.
13+
*
14+
* @param tree The tree in which to search `pos`
15+
* @param pos The position to look for
16+
* @param p An additional condition to be satisfied by the resulting tree
17+
* @return The innermost enclosing tree for which p is true, or `EmptyTree`
18+
* if the position could not be found.
19+
*/
20+
def locateIn(tree: Tree, pos: Position, p: Tree => Boolean = t => true): Tree =
21+
new FilteringLocator(pos, p) locateIn tree
22+
23+
def enclosingPackage(tree: Tree, pos: Position): Tree = {
24+
locateIn(tree, pos, _.isInstanceOf[PackageDef])
25+
}
26+
27+
private class FilteringLocator(pos: Position, p: Tree => Boolean) extends Locator(pos) {
28+
override def isEligible(t: Tree) = super.isEligible(t) && p(t)
29+
}
30+
31+
/*
32+
* For Scala-2.10 (see scala.reflect.internal.Positions.Locator in Scala-2.11).
33+
*/
34+
private class Locator(pos: Position) extends Traverser {
35+
var last: Tree = _
36+
def locateIn(root: Tree): Tree = {
37+
this.last = EmptyTree
38+
traverse(root)
39+
this.last
40+
}
41+
protected def isEligible(t: Tree) = !t.pos.isTransparent
42+
override def traverse(t: Tree) {
43+
t match {
44+
case tt : TypeTree if tt.original != null && (tt.pos includes tt.original.pos) =>
45+
traverse(tt.original)
46+
case _ =>
47+
if (t.pos includes pos) {
48+
if (isEligible(t)) last = t
49+
super.traverse(t)
50+
} else t match {
51+
case mdef: MemberDef =>
52+
traverseTrees(mdef.mods.annotations)
53+
case _ =>
54+
}
55+
}
56+
}
57+
}
58+
}

org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/analysis/CompilationUnitDependenciesTest.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,5 +834,24 @@ class CompilationUnitDependenciesTest extends TestHelper with CompilationUnitDep
834834
def xx(i: Int) = i
835835
}
836836
""")
837+
838+
@Test
839+
def importLocallyDefinedClass() = assertNeededImports(
840+
"""test.MyType""",
841+
"""import test.MyType
842+
package test {
843+
class MyType
844+
}
845+
class Test(myType: MyType)
846+
}""")
847+
848+
@Test
849+
def importLocallyDefniedClassNotNeeded = assertNeededImports(
850+
"""""",
851+
"""package test {
852+
class MyType
853+
class Test(myType: MyType)
854+
}
855+
}""")
837856
}
838857

0 commit comments

Comments
 (0)