Skip to content

Commit 05f4d59

Browse files
committed
Mark trailing map receives body to inspect
1 parent 52b6eb2 commit 05f4d59

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2234,7 +2234,7 @@ object desugar {
22342234
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
22352235
case _ => false
22362236

2237-
def markTrailingMap(aply: Apply, gen: GenFrom, selectName: TermName): Unit =
2237+
def markTrailingMap(aply: Apply, gen: GenFrom, selectName: TermName, body: Tree): Unit =
22382238
if sourceVersion.enablesBetterFors
22392239
&& selectName == mapName
22402240
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
@@ -2245,9 +2245,8 @@ object desugar {
22452245
enums match {
22462246
case Nil if sourceVersion.enablesBetterFors => body
22472247
case (gen: GenFrom) :: Nil =>
2248-
val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2249-
markTrailingMap(aply, gen, mapName)
2250-
aply
2248+
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2249+
.tap(markTrailingMap(_, gen, mapName, body))
22512250
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
22522251
val cont = makeFor(mapName, flatMapName, rest, body)
22532252
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
@@ -2264,7 +2263,7 @@ object desugar {
22642263
if suffix.exists(_.isInstanceOf[GenFrom]) then flatMapName
22652264
else mapName
22662265
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
2267-
.tap(markTrailingMap(_, gen, selectName))
2266+
.tap(markTrailingMap(_, gen, selectName, cont))
22682267
else
22692268
val (pats, rhss) = valeqs.map { case GenAlias(pat, rhs) => (pat, rhs) }.unzip
22702269
val (defpat0, id0) = makeIdPat(gen.pat)

tests/run/i24673.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
@main def Test = {
2+
def result = for {
3+
a <- Option(2)
4+
_ = if (true) {
5+
sys.error("err")
6+
}
7+
} yield a
8+
9+
try
10+
result
11+
???
12+
catch case e: RuntimeException => assert(e.getMessage == "err")
13+
}

0 commit comments

Comments
 (0)