@@ -18,72 +18,99 @@ package org.mongodb.scala.internal
1818
1919import org .mongodb .scala ._
2020
21+ import java .util .concurrent .atomic .AtomicReference
22+
23+ sealed trait State
24+ case object Init extends State
25+ case class WaitingOnChild (s : Subscription ) extends State
26+ case object LastChildNotified extends State
27+ case object LastChildResponded extends State
28+ case object Done extends State
29+ case object Error extends State
30+
2131private [scala] case class FlatMapObservable [T , S ](observable : Observable [T ], f : T => Observable [S ])
2232 extends Observable [S ] {
23-
2433 // scalastyle:off cyclomatic.complexity method.length
2534 override def subscribe (observer : Observer [_ >: S ]): Unit = {
2635 observable.subscribe(
2736 SubscriptionCheckingObserver (
2837 new Observer [T ] {
29-
30- @ volatile
31- private var outerSubscription : Option [Subscription ] = None
32- @ volatile
33- private var nestedSubscription : Option [Subscription ] = None
34- @ volatile
35- private var demand : Long = 0
36- @ volatile
37- private var onCompleteCalled : Boolean = false
38+ @ volatile private var outerSubscription : Option [Subscription ] = None
39+ @ volatile private var demand : Long = 0
40+ private val state = new AtomicReference [State ](Init )
3841
3942 override def onSubscribe (subscription : Subscription ): Unit = {
4043 val masterSub = new Subscription () {
4144 override def isUnsubscribed : Boolean = subscription.isUnsubscribed
42-
43- def request (n : Long ): Unit = {
45+ override def unsubscribe () : Unit = subscription.unsubscribe()
46+ override def request (n : Long ): Unit = {
4447 require(n > 0L , s " Number requested must be greater than zero: $n" )
4548 val localDemand = addDemand(n)
46- val (sub, num) = nestedSubscription.map((_, localDemand)).getOrElse((subscription, 1L ))
47- sub.request(num)
49+ state.get() match {
50+ case Init => subscription.request(1L )
51+ case WaitingOnChild (s) => s.request(localDemand)
52+ case _ => // noop
53+ }
4854 }
49-
50- override def unsubscribe (): Unit = subscription.unsubscribe()
5155 }
52-
5356 outerSubscription = Some (masterSub)
57+ state.set(Init )
5458 observer.onSubscribe(masterSub)
5559 }
5660
5761 override def onComplete (): Unit = {
58- if (! onCompleteCalled) {
59- onCompleteCalled = true
60- if (nestedSubscription.isEmpty) observer.onComplete()
62+ state.get() match {
63+ case Done => // ok
64+ case Error => // ok
65+ case Init if state.compareAndSet(Init , Done ) =>
66+ observer.onComplete()
67+ case w @ WaitingOnChild (_) if state.compareAndSet(w, LastChildNotified ) =>
68+ // letting the child know that we delegate onComplete call to it
69+ case LastChildNotified =>
70+ // wait for the child to do the delegated onCompleteCall
71+ case LastChildResponded if state.compareAndSet(LastChildResponded , Done ) =>
72+ observer.onComplete()
73+ case other =>
74+ // state machine is broken, let's fail
75+ // normally this won't happen
76+ throw new IllegalStateException (s " Unexpected state in FlatMapObservable `onComplete` handler: ${other}" )
6177 }
6278 }
6379
64- override def onError (throwable : Throwable ): Unit = observer.onError(throwable)
80+ override def onError (throwable : Throwable ): Unit = {
81+ observer.onError(throwable)
82+ }
6583
6684 override def onNext (tResult : T ): Unit = {
6785 f(tResult).subscribe(
6886 new Observer [S ]() {
6987 override def onError (throwable : Throwable ): Unit = {
70- nestedSubscription = None
88+ state.set( Error )
7189 observer.onError(throwable)
7290 }
7391
7492 override def onSubscribe (subscription : Subscription ): Unit = {
75- nestedSubscription = Some ( subscription)
93+ state.set( WaitingOnChild ( subscription) )
7694 if (demand > 0 ) subscription.request(demand)
7795 }
7896
7997 override def onComplete (): Unit = {
80- nestedSubscription = None
81- onCompleteCalled match {
82- case true => observer.onComplete()
83- case false if demand > 0 =>
98+ state.get() match {
99+ case Done => // no need to call parent's onComplete
100+ case Error => // no need to call parent's onComplete
101+ case LastChildNotified if state.compareAndSet(LastChildNotified , LastChildResponded ) =>
102+ // parent told us to call onComplete
103+ observer.onComplete()
104+ case _ if demand > 0 =>
105+ // otherwise we are not the last child, let's tell the parent
106+ // it's not dealing with us anymore.
107+ // Init -> * will be handled by possible later items in the stream
108+ state.set(Init )
84109 addDemand(- 1 ) // reduce demand by 1 as it will be incremented by the outerSubscription
85110 outerSubscription.foreach(_.request(1 ))
86- case false => // No more demand
111+ case _ =>
112+ // no demand
113+ state.set(Init )
87114 }
88115 }
89116
0 commit comments