From 5625389adcab9d89812ab7f0cac9155d3c065d33 Mon Sep 17 00:00:00 2001 From: Simon Schaefer Date: Mon, 22 Sep 2014 21:08:04 +0200 Subject: [PATCH 1/6] Add override keywords to ReusingPrinter where necessary --- .../tools/refactoring/sourcegen/ReusingPrinter.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala index 90f57ec9..acdf1324 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala @@ -202,22 +202,22 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { override def CaseDef(tree: CaseDef, pat: Tree, guard: Tree, body: Tree)(implicit ctx: PrintingContext) = { val arrowReq = new Requisite { - def isRequired(l: Layout, r: Layout) = { + override def isRequired(l: Layout, r: Layout) = { !(l.contains("=>") || r.contains("=>") || p(body).asText.startsWith("=>")) } // It's just nice to have a whitespace before and after the arrow - def getLayout = Layout(" => ") + override def getLayout = Layout(" => ") } val ifReq = new Requisite { - def isRequired(l: Layout, r: Layout) = { + override def isRequired(l: Layout, r: Layout) = { !(l.contains("if") || r.contains("if")) } // Leading and trailing whitespace is required in some cases! // e.g. `case i if i > 0 => ???` becomes `case iifi > 0 => ???` otherwise - def getLayout = Layout(" if ") + override def getLayout = Layout(" if ") } body match { From 97f1b14dba9d59544d8ce966a1293a0bddba55ef Mon Sep 17 00:00:00 2001 From: Simon Schaefer Date: Mon, 22 Sep 2014 21:10:47 +0200 Subject: [PATCH 2/6] Generate return type for defs with single expressions in braces In case def foo = { 0 } should be transformed to def foo: Int = { 0 } by adding its return type, the result was def foo: Int = 0 } Fixes #1002268 --- .../sourcegen/ReusingPrinter.scala | 19 +++- .../tests/sourcegen/ReusingPrinterTest.scala | 88 +++++++++++++++++++ 2 files changed, 105 insertions(+), 2 deletions(-) diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala index acdf1324..bb1966ac 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala @@ -921,10 +921,12 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { val body = p(rhs) val noEqualNeeded = body == EmptyFragment || rhs.tpe == null || (rhs.tpe != null && rhs.tpe.toString == "Unit") + def openingBrace = keepOpeningBrace(tree, tpt, rhs) + if (noEqualNeeded) l ++ mods_ ++ resultType ++ body ++ r else - l ++ mods_ ++ resultType ++ Requisite.anywhere("=", " = ") ++ body ++ r + l ++ mods_ ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r } } @@ -1004,12 +1006,25 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { body == EmptyFragment || rhs.tpe == null || (rhs.tpe != null && rhs.tpe.toString == "Unit") } + def openingBrace = keepOpeningBrace(tree, tpt, rhs) + if (noEqualNeeded && !hasEqualInSource) { l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ body ++ r } else { - l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ body ++ r + l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r } } + + private def keepOpeningBrace(tree: Tree, tpt: Tree, rhs: Tree): String = tpt match { + case tpt: TypeTree if tpt.original != null && tree.pos != NoPosition && rhs.pos != NoPosition => + val OpeningBrace = "(?s).*(\\{.*)".r + Layout(tree.pos.source, tree.pos.point, rhs.pos.start).asText match { + case OpeningBrace(brace) => brace + case _ => "" + } + case _ => + "" + } } trait SuperPrinters { diff --git a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala index 021b47ee..9a1be52f 100644 --- a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala +++ b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala @@ -91,6 +91,94 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { d.copy(tpt = newTpt) replaces d }}} + @Test + def add_return_type_to_val_with_single_expression_in_braces() = """ + package add_return_type_to_val_with_single_expression_in_braces + object X { + val foo = { + 0 + } + } + """ becomes """ + package add_return_type_to_val_with_single_expression_in_braces + object X { + val foo: Int = { + 0 + } + } + """ after topdown { matchingChildren { transform { + case d @ ValDef(_, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + + @Test + def add_return_type_to_val_with_multiple_expressions_in_braces() = """ + package add_return_type_to_val_with_multiple_expressions_in_braces + object X { + val foo = { + val a = 0 + a + } + } + """ becomes """ + package add_return_type_to_val_with_multiple_expressions_in_braces + object X { + val foo: Int = { + val a: Int = 0 + a + } + } + """ after topdown { matchingChildren { transform { + case d @ ValDef(_, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + + @Test + def add_return_type_to_def_with_single_expression_in_braces() = """ + package add_return_type_to_def_with_single_expression_in_braces + object X { + def foo = { + 0 + } + } + """ becomes """ + package add_return_type_to_def_with_single_expression_in_braces + object X { + def foo: Int = { + 0 + } + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + + @Test + def add_return_type_to_def_with_multiple_expressions_in_braces() = """ + package add_return_type_to_def_with_multiple_expressions_in_braces + object X { + def foo = { + def a = 0 + a + } + } + """ becomes """ + package add_return_type_to_def_with_multiple_expressions_in_braces + object X { + def foo: Int = { + def a: Int = 0 + a + } + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + @Test def add_override_flag() = """ package add_override_flag From 5827cdf1489ec1ca960daae25bb3e68806d008fb Mon Sep 17 00:00:00 2001 From: Simon Schaefer Date: Fri, 26 Sep 2014 12:14:36 +0200 Subject: [PATCH 3/6] Fix return type printing for methods that return Unit --- .../sourcegen/ReusingPrinter.scala | 26 ++++++++++++++----- .../tests/sourcegen/ReusingPrinterTest.scala | 25 ++++++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala index bb1966ac..736f4a73 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala @@ -987,8 +987,9 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { case _ => false } + val isAbstract = body == EmptyFragment val resultType = - if (body == EmptyFragment && !existsTptInFile) + if (isAbstract && !existsTptInFile) EmptyFragment else p(tpt, before = Requisite.allowSurroundingWhitespace(":", ": ")) @@ -1002,19 +1003,30 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { } } - val noEqualNeeded = { - body == EmptyFragment || rhs.tpe == null || (rhs.tpe != null && rhs.tpe.toString == "Unit") - } - - def openingBrace = keepOpeningBrace(tree, tpt, rhs) + val noEqualNeeded = resultType == EmptyFragment || isAbstract if (noEqualNeeded && !hasEqualInSource) { l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ body ++ r } else { - l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r + val openingBrace = keepOpeningBrace(tree, tpt, rhs) + // In case a Unit return type is added to a method like `def f {}`, we + // need to remove the whitespace between name and rhs, otherwise the + // result would be `def f : Unit = {}`. + val modsAndName2 = + if (modsAndName.trailing.asText.trim.isEmpty) + Fragment(modsAndName.leading, modsAndName.center, NoLayout) + else + modsAndName + + l ++ modsAndName2 ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r } } + /** + * In case a definition like `def f = {0}` contains a single expression in + * braces, we need to find the braces manually because they are no part of + * the tree. + */ private def keepOpeningBrace(tree: Tree, tpt: Tree, rhs: Tree): String = tpt match { case tpt: TypeTree if tpt.original != null && tree.pos != NoPosition && rhs.pos != NoPosition => val OpeningBrace = "(?s).*(\\{.*)".r diff --git a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala index 9a1be52f..3b3fe2dd 100644 --- a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala +++ b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala @@ -179,6 +179,31 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { d.copy(tpt = newTpt) replaces d }}} + @Test + def add_Unit_return_type_to_def_with_single_expression_in_braces() = """ + package add_Unit_return_type_to_def_with_single_expression_in_braces + object X { + def foo { + println + } + def bar {} + def baz = () + } + """ becomes """ + package add_Unit_return_type_to_def_with_single_expression_in_braces + object X { + def foo: Unit = { + println + } + def bar: Unit = {} + def baz: Unit = () + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + @Test def add_override_flag() = """ package add_override_flag From afafa18c82473468c902ad063291190ab0133b84 Mon Sep 17 00:00:00 2001 From: Simon Schaefer Date: Fri, 26 Sep 2014 12:45:26 +0200 Subject: [PATCH 4/6] Add space before return type of method that ends with special sign Fixes #1002267 --- .../sourcegen/ReusingPrinter.scala | 9 +++++++-- .../tests/sourcegen/ReusingPrinterTest.scala | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala index 736f4a73..553cc5c2 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala @@ -1005,8 +1005,13 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { val noEqualNeeded = resultType == EmptyFragment || isAbstract + val resultType2 = { + def addLeadingSpace = name.isOperatorName || name.endsWith('_') + if (resultType != EmptyFragment && addLeadingSpace) Layout(" ") ++ resultType else resultType + } + if (noEqualNeeded && !hasEqualInSource) { - l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ body ++ r + l ++ modsAndName ++ typeParameters ++ parameters ++ resultType2 ++ body ++ r } else { val openingBrace = keepOpeningBrace(tree, tpt, rhs) // In case a Unit return type is added to a method like `def f {}`, we @@ -1018,7 +1023,7 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { else modsAndName - l ++ modsAndName2 ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r + l ++ modsAndName2 ++ typeParameters ++ parameters ++ resultType2 ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r } } diff --git a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala index 3b3fe2dd..2cb40569 100644 --- a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala +++ b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala @@ -204,6 +204,25 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { d.copy(tpt = newTpt) replaces d }}} + @Test + def add_space_before_return_type_of_def_when_it_ends_with_special_sign() = """ + package add_space_before_return_type_of_def_when_it_ends_with_special_sign + object X { + def foo_ = 0 + def ++ = 0 + } + """ becomes """ + package add_space_before_return_type_of_def_when_it_ends_with_special_sign + object X { + def foo_ : Int = 0 + def ++ : Int = 0 + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + @Test def add_override_flag() = """ package add_override_flag From 4f5dbd893cd0b1d49c075affd6ffc8e695232b16 Mon Sep 17 00:00:00 2001 From: Simon Schaefer Date: Fri, 26 Sep 2014 16:36:51 +0200 Subject: [PATCH 5/6] Print type keyword when the type of an object should be printed I would prefer to fix this somewhere in the printing logic and not in the construction logic of the tree but I didn't find a way to differentiate the Ident trees correctly in the printer. For object X { def x = 0 } the Ident trees of `X.type` and `X.x` seem to have the same shape at the beginning. Fixes #1002233 --- .../transformation/TreeFactory.scala | 6 ++++-- .../tests/sourcegen/ReusingPrinterTest.scala | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/transformation/TreeFactory.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/transformation/TreeFactory.scala index a39de45a..244b7e68 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/transformation/TreeFactory.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/transformation/TreeFactory.scala @@ -64,7 +64,9 @@ trait TreeFactory { def mkReturn(s: List[Symbol]): Tree = s match { case Nil => EmptyTree - case x :: Nil => Ident(x) setType x.tpe + case x :: Nil => + val ident = if (x.isModuleClass) Ident(newTermName(s"${x.name}.type")) else Ident(x) + ident setType x.tpe case xs => typer.typed(gen.mkTuple(xs map (s => Ident(s) setType s.tpe))) match { case t: Apply => @@ -114,7 +116,7 @@ trait TreeFactory { } if (mods != NoMods) valOrVarDef setSymbol NoSymbol.newValue(name, newFlags = mods.flags) else valOrVarDef - } + } def mkParam(name: String, tpe: Type, defaultVal: Tree = EmptyTree): ValDef = { ValDef(Modifiers(Flags.PARAM), newTermName(name), TypeTree(tpe), defaultVal) diff --git a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala index 2cb40569..3756a14e 100644 --- a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala +++ b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala @@ -223,6 +223,23 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { d.copy(tpt = newTpt) replaces d }}} + @Test + def add_type_keyword_to_return_type_when_it_represents_an_object() = """ + package add_type_keyword_to_return_type_when_it_represents_an_object + object X { + def o = X + } + """ becomes """ + package add_type_keyword_to_return_type_when_it_represents_an_object + object X { + def o: X.type = X + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + @Test def add_override_flag() = """ package add_override_flag From 98c0a5549d51bf6bc9f9110ad768835bd50ad3e7 Mon Sep 17 00:00:00 2001 From: Simon Schaefer Date: Mon, 20 Oct 2014 21:09:03 +0200 Subject: [PATCH 6/6] Add return type to ValOrDefDef with single expressions in braces Fixing this was complicated because the compiler removes braces from the tree when the code would compile without them. Therefore we manually have to restore them. To get the braces (and whitespace) back, we search for an equal sign in the area of the DefDef (which indicates the start of the rhs). Once it is found, all contents between the equal sign and the start and the expression of the rhs are put back into the tree. Because an equal sign can also occur in a comment, we have to parse the region instead of simply looking for the equal sign. This adds some overhead, but hopefully not too much. Furthermore, as reviewers suggested, some variable names are renamed. This also fixes a bug in the test suite where the source file is passed to the refactoring logic. --- .../common/CompilerApiExtensions.scala | 61 ++++++++++++++++- .../sourcegen/AbstractPrinter.scala | 9 ++- .../refactoring/sourcegen/PrettyPrinter.scala | 4 +- .../sourcegen/ReusingPrinter.scala | 39 ++++++----- .../sourcegen/SourceGenerator.scala | 3 +- .../tests/sourcegen/ReusingPrinterTest.scala | 66 +++++++++++++++++-- 6 files changed, 151 insertions(+), 31 deletions(-) diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala index 3beadc3a..2a5a6aea 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala @@ -1,9 +1,12 @@ package scala.tools.refactoring.common -import scala.tools.nsc.Global +import scala.collection.immutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.internal.util.SourceFile +import scala.tools.nsc.ast.parser.Tokens /* - * FIXME: This class duplicates functionality from org.scalaide.core.compiler.CompilerApiExtensions. + * FIXME: This class duplicates functionality from [[org.scalaide.core.compiler.CompilerApiExtensions]]. */ trait CompilerApiExtensions { this: CompilerAccess => @@ -55,4 +58,58 @@ trait CompilerApiExtensions { } } } + + /** A helper class to access the lexical tokens of `source`. + * + * Once constructed, instances of this class are thread-safe. + */ + class LexicalStructure(source: SourceFile) { + private val token = new ArrayBuffer[Int] + private val startOffset = new ArrayBuffer[Int] + private val endOffset = new ArrayBuffer[Int] + private val scanner = new syntaxAnalyzer.UnitScanner(new CompilationUnit(source)) + scanner.init() + + while (scanner.token != Tokens.EOF) { + startOffset += scanner.offset + token += scanner.token + scanner.nextToken + endOffset += scanner.lastOffset + } + + /** Return the index of the token that covers `offset`. + */ + private def locateIndex(offset: Int): Int = { + var lo = 0 + var hi = token.length - 1 + while (lo < hi) { + val mid = (lo + hi + 1) / 2 + if (startOffset(mid) <= offset) lo = mid + else hi = mid - 1 + } + lo + } + + /** Return all tokens between start and end offsets. + * + * The first token may start before `start` and the last token may span after `end`. + */ + def tokensBetween(start: Int, end: Int): immutable.Seq[Token] = { + val startIndex = locateIndex(start) + val endIndex = locateIndex(end) + + val tmp = for (i <- startIndex to endIndex) + yield Token(token(i), startOffset(i), endOffset(i)) + + tmp.toSeq + } + } + + /** A Scala token covering [start, end) + * + * @param tokenId one of scala.tools.nsc.ast.parser.Tokens identifiers + * @param start the offset of the first character in this token + * @param end the offset of the first character after this token + */ + case class Token(tokenId: Int, start: Int, end: Int) } diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala index 11052248..32e5258a 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala @@ -9,7 +9,7 @@ import scala.reflect.internal.util.SourceFile trait AbstractPrinter extends CommonPrintUtils { - this: common.Tracing with common.PimpedTrees with Indentations with common.CompilerAccess with Formatting => + this: common.Tracing with common.PimpedTrees with Indentations with common.CompilerAccess with common.CompilerApiExtensions with Formatting => import global._ @@ -18,6 +18,11 @@ trait AbstractPrinter extends CommonPrintUtils { * the context or environment for the current printing. */ case class PrintingContext(ind: Indentation, changeSet: ChangeSet, parent: Tree, file: Option[SourceFile]) { + private lazy val lexical = file map (new LexicalStructure(_)) + + def tokensBetween(start: Int, end: Int): Seq[Token] = + lexical.map(_.tokensBetween(start, end)).getOrElse(Seq()) + lazy val newline: String = { if(file.exists(_.content.containsSlice("\r\n"))) "\r\n" @@ -32,4 +37,4 @@ trait AbstractPrinter extends CommonPrintUtils { def print(t: Tree, ctx: PrintingContext): Fragment -} \ No newline at end of file +} diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/PrettyPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/PrettyPrinter.scala index ae82c698..7ca80d92 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/PrettyPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/PrettyPrinter.scala @@ -13,7 +13,7 @@ import language.implicitConversions trait PrettyPrinter extends TreePrintingTraversals with AbstractPrinter { - outer: common.PimpedTrees with common.CompilerAccess with common.Tracing with Indentations with LayoutHelper with Formatting => + outer: common.PimpedTrees with common.CompilerAccess with common.Tracing with common.CompilerApiExtensions with Indentations with LayoutHelper with Formatting => import global._ @@ -130,7 +130,7 @@ trait PrettyPrinter extends TreePrintingTraversals with AbstractPrinter { case Some(patP(patStr)) if guard == EmptyTree => Fragment(patStr) case _ => p(pat) } - + val arrowReq = new Requisite { def isRequired(l: Layout, r: Layout) = { !(l.contains("=>") || r.contains("=>")) diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala index 553cc5c2..ff7d75e7 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala @@ -11,7 +11,7 @@ import language.implicitConversions trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { - outer: LayoutHelper with common.Tracing with common.PimpedTrees with common.CompilerAccess with Formatting with Indentations => + outer: LayoutHelper with common.Tracing with common.PimpedTrees with common.CompilerAccess with common.CompilerApiExtensions with Formatting with Indentations => import global._ @@ -988,7 +988,7 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { } val isAbstract = body == EmptyFragment - val resultType = + val rawResultType = if (isAbstract && !existsTptInFile) EmptyFragment else @@ -1003,42 +1003,47 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { } } - val noEqualNeeded = resultType == EmptyFragment || isAbstract + val noEqualNeeded = rawResultType == EmptyFragment || isAbstract - val resultType2 = { + val resultType = { def addLeadingSpace = name.isOperatorName || name.endsWith('_') - if (resultType != EmptyFragment && addLeadingSpace) Layout(" ") ++ resultType else resultType + if (rawResultType != EmptyFragment && addLeadingSpace) Layout(" ") ++ rawResultType else rawResultType } if (noEqualNeeded && !hasEqualInSource) { - l ++ modsAndName ++ typeParameters ++ parameters ++ resultType2 ++ body ++ r + l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ body ++ r } else { val openingBrace = keepOpeningBrace(tree, tpt, rhs) // In case a Unit return type is added to a method like `def f {}`, we // need to remove the whitespace between name and rhs, otherwise the // result would be `def f : Unit = {}`. - val modsAndName2 = + val modsAndNameTrimmed = if (modsAndName.trailing.asText.trim.isEmpty) Fragment(modsAndName.leading, modsAndName.center, NoLayout) else modsAndName - l ++ modsAndName2 ++ typeParameters ++ parameters ++ resultType2 ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r + l ++ modsAndNameTrimmed ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r } } /** - * In case a definition like `def f = {0}` contains a single expression in - * braces, we need to find the braces manually because they are no part of - * the tree. + * In case a `ValOrDefDef` like `def f = {0}` contains a single expression + * in braces, we are screwed. + * In such cases the compiler removes the opening and closing braces from + * the tree. It also removes all the whitespace between the equal sign and + * the opening brace. If a refactoring edits such a definition, we need to + * get the whitespace+opening brace back. */ - private def keepOpeningBrace(tree: Tree, tpt: Tree, rhs: Tree): String = tpt match { + private def keepOpeningBrace(tree: Tree, tpt: Tree, rhs: Tree)(implicit ctx: PrintingContext): String = tpt match { case tpt: TypeTree if tpt.original != null && tree.pos != NoPosition && rhs.pos != NoPosition => - val OpeningBrace = "(?s).*(\\{.*)".r - Layout(tree.pos.source, tree.pos.point, rhs.pos.start).asText match { - case OpeningBrace(brace) => brace - case _ => "" - } + val tokens = ctx.tokensBetween(tree.pos.point, rhs.pos.start) + tokens find (_.tokenId == tools.nsc.ast.parser.Tokens.EQUALS) map { + case Token(_, start, _) => + val c = tree.pos.source.content + val skippedWsStart = if (c(start+1).isWhitespace) start+2 else start+1 + c.slice(skippedWsStart, rhs.pos.start).mkString + } getOrElse "" case _ => "" } diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/SourceGenerator.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/SourceGenerator.scala index 675b798f..689eeead 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/SourceGenerator.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/SourceGenerator.scala @@ -8,10 +8,11 @@ package sourcegen import common.Tracing import common.Change import common.PimpedTrees +import common.CompilerApiExtensions import scala.tools.refactoring.common.TextChange import scala.reflect.internal.util.SourceFile -trait SourceGenerator extends PrettyPrinter with Indentations with ReusingPrinter with PimpedTrees with LayoutHelper with Formatting with TreeChangesDiscoverer { +trait SourceGenerator extends PrettyPrinter with Indentations with ReusingPrinter with PimpedTrees with CompilerApiExtensions with LayoutHelper with Formatting with TreeChangesDiscoverer { self: Tracing with common.CompilerAccess => diff --git a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala index 3756a14e..5ccbb361 100644 --- a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala +++ b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala @@ -25,8 +25,8 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { final implicit class ImplicitTreeHelper(original: Tree) { /** Needs to be executed on the PC thread. */ - def printsTo(expectedOutput: String): Unit = { - val sourceFile = new BatchSourceFile("noname", expectedOutput) + def printsTo(input: String, expectedOutput: String): Unit = { + val sourceFile = new BatchSourceFile("textInput", input) val expected = stripWhitespacePreservers(expectedOutput).trim() val actual = generate(original, sourceFile = Some(sourceFile)).asText.trim() if (actual != expected) @@ -40,7 +40,7 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { def after(trans: Transformation[Tree, Tree]): Unit = ask { () => val t = trans(treeFrom(input._1)) require(t.isDefined, "transformation was not successful") - t foreach (_.printsTo(input._2)) + t foreach (_.printsTo(input._1, input._2)) } } @@ -95,16 +95,42 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { def add_return_type_to_val_with_single_expression_in_braces() = """ package add_return_type_to_val_with_single_expression_in_braces object X { - val foo = { + val a = { + 0 + } + val b = /* {str */ { + 0 + } + val c = { 0 match { + case i => i + }} + val d = { // {str} + 0 + } + val e = /* {str */ { // {str 0 } + val f={0} } """ becomes """ package add_return_type_to_val_with_single_expression_in_braces object X { - val foo: Int = { + val a: Int = { + 0 + } + val b: Int = /* {str */ { + 0 + } + val c: Int = { 0 match { + case i => i + }} + val d: Int = { // {str} + 0 + } + val e: Int = /* {str */ { // {str 0 } + val f: Int = {0} } """ after topdown { matchingChildren { transform { case d @ ValDef(_, _, tpt: TypeTree, _) => @@ -139,16 +165,42 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { def add_return_type_to_def_with_single_expression_in_braces() = """ package add_return_type_to_def_with_single_expression_in_braces object X { - def foo = { + def a = { + 0 + } + def b = /* {str */ { + 0 + } + def c = { 0 match { + case i => i + }} + def d = { // {str} + 0 + } + def e = /* {str */ { // {str 0 } + def f={0} } """ becomes """ package add_return_type_to_def_with_single_expression_in_braces object X { - def foo: Int = { + def a: Int = { + 0 + } + def b: Int = /* {str */ { + 0 + } + def c: Int = { 0 match { + case i => i + }} + def d: Int = { // {str} + 0 + } + def e: Int = /* {str */ { // {str 0 } + def f: Int = {0} } """ after topdown { matchingChildren { transform { case d @ DefDef(_, _, _, _, tpt: TypeTree, _) =>