Skip to content

Commit ffa31ca

Browse files
mostroverkhovrobertroeser
authored andcommitted
limit streams for request-stream and request-channel (#15)
tests for stream limiting
1 parent 897064a commit ffa31ca

File tree

6 files changed

+296
-94
lines changed

6 files changed

+296
-94
lines changed

rsocket-core/src/main/java/io/rsocket/android/RSocketClient.kt

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
package io.rsocket.android
1818

19-
import io.rsocket.android.util.ExceptionUtil.noStacktrace
20-
2119
import io.netty.buffer.Unpooled
2220
import io.netty.util.collection.IntObjectHashMap
2321
import io.reactivex.Completable
@@ -32,24 +30,34 @@ import io.reactivex.processors.UnicastProcessor
3230
import io.rsocket.android.exceptions.ConnectionException
3331
import io.rsocket.android.exceptions.Exceptions
3432
import io.rsocket.android.internal.LimitableRequestPublisher
33+
import io.rsocket.android.util.ExceptionUtil.noStacktrace
3534
import io.rsocket.android.util.PayloadImpl
36-
37-
import java.nio.channels.ClosedChannelException
38-
import java.util.concurrent.atomic.AtomicBoolean
39-
import java.util.concurrent.atomic.AtomicInteger
4035
import org.reactivestreams.Publisher
4136
import org.reactivestreams.Subscriber
37+
import java.nio.channels.ClosedChannelException
4238
import java.util.concurrent.CancellationException
4339
import java.util.concurrent.TimeUnit
40+
import java.util.concurrent.atomic.AtomicBoolean
41+
import java.util.concurrent.atomic.AtomicInteger
4442

4543
/** Client Side of a RSocket socket. Sends [Frame]s to a [RSocketServer] */
4644
internal class RSocketClient @JvmOverloads constructor(
4745
private val connection: DuplexConnection,
4846
private val errorConsumer: (Throwable) -> Unit,
4947
private val streamIdSupplier: StreamIdSupplier,
48+
private val streamDemandLimit: Int,
5049
tickPeriod: Duration = Duration.ZERO,
5150
ackTimeout: Duration = Duration.ZERO,
5251
missedAcks: Int = 0) : RSocket {
52+
53+
internal constructor(connection: DuplexConnection,
54+
errorConsumer: (Throwable) -> Unit,
55+
streamIdSupplier: StreamIdSupplier,
56+
tickPeriod: Duration = Duration.ZERO,
57+
ackTimeout: Duration = Duration.ZERO,
58+
missedAcks: Int = 0)
59+
: this(connection, errorConsumer, streamIdSupplier, DEFAULT_STREAM_WINDOW, tickPeriod, ackTimeout, missedAcks)
60+
5361
private val started: PublishProcessor<Void> = PublishProcessor.create()
5462
private val completeOnStart = started.ignoreElements()
5563
private val senders: IntObjectHashMap<LimitableRequestPublisher<*>> = IntObjectHashMap(256, 0.9f)
@@ -146,10 +154,13 @@ internal class RSocketClient @JvmOverloads constructor(
146154
handleRequestResponse(payload)
147155

148156
override fun requestStream(payload: Payload): Flowable<Payload> =
149-
handleRequestStream(payload)
157+
handleRequestStream(payload).rebatchRequests(streamDemandLimit)
150158

151159
override fun requestChannel(payloads: Publisher<Payload>): Flowable<Payload> =
152-
handleChannel(Flowable.fromPublisher(payloads), FrameType.REQUEST_CHANNEL)
160+
handleChannel(
161+
Flowable.fromPublisher(payloads).rebatchRequests(streamDemandLimit),
162+
FrameType.REQUEST_CHANNEL
163+
).rebatchRequests(streamDemandLimit)
153164

154165
override fun metadataPush(payload: Payload): Completable {
155166
val requestFrame = Frame.Request.from(
@@ -472,6 +483,7 @@ internal class RSocketClient @JvmOverloads constructor(
472483

473484
companion object {
474485
private val CLOSED_CHANNEL_EXCEPTION = noStacktrace(ClosedChannelException())
486+
private val DEFAULT_STREAM_WINDOW = 128
475487
}
476488
private fun <T> UnicastProcessor<T>.isTerminated(): Boolean = hasComplete() || hasThrowable()
477489
}

rsocket-core/src/main/java/io/rsocket/android/RSocketFactory.kt

Lines changed: 44 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -38,39 +38,23 @@ object RSocketFactory {
3838
*
3939
* @return a client factory
4040
*/
41-
fun connect(): ClientRSocketFactory {
42-
return ClientRSocketFactory()
43-
}
41+
fun connect(): ClientRSocketFactory = ClientRSocketFactory()
4442

4543
/**
4644
* Creates a factory that receives server connections from client RSockets.
4745
*
4846
* @return a server factory.
4947
*/
50-
fun receive(): ServerRSocketFactory {
51-
return ServerRSocketFactory()
52-
}
48+
fun receive(): ServerRSocketFactory = ServerRSocketFactory()
5349

5450
interface Start<T : Closeable> {
5551
fun start(): Single<T>
5652
}
5753

58-
interface SetupPayload<T> {
59-
fun setupPayload(payload: Payload): T
60-
}
61-
62-
interface Acceptor<T, A> {
63-
fun acceptor(acceptor: () -> A): T
64-
65-
fun acceptor(acceptor: A): T {
66-
return acceptor({ acceptor })
67-
}
68-
}
69-
7054
interface ClientTransportAcceptor {
7155
fun transport(transport: () -> ClientTransport): Start<RSocket>
7256

73-
fun transport(transport: ClientTransport): Start<RSocket> = transport({ transport })
57+
fun transport(transport: ClientTransport): Start<RSocket> = transport { transport }
7458

7559
}
7660

@@ -81,39 +65,7 @@ object RSocketFactory {
8165

8266
}
8367

84-
interface Fragmentation<T> {
85-
fun fragment(mtu: Int): T
86-
}
87-
88-
interface ErrorConsumer<T> {
89-
fun errorConsumer(errorConsumer: (Throwable) -> Unit): T
90-
}
91-
92-
interface KeepAlive<T> {
93-
fun keepAlive(): T
94-
95-
fun keepAlive(tickPeriod: Duration, ackTimeout: Duration, missedAcks: Int): T
96-
97-
fun keepAliveTickPeriod(tickPeriod: Duration): T
98-
99-
fun keepAliveAckTimeout(ackTimeout: Duration): T
100-
101-
fun keepAliveMissedAcks(missedAcks: Int): T
102-
}
103-
104-
interface MimeType<T> {
105-
fun mimeType(metadataMimeType: String, dataMimeType: String): T
106-
107-
fun dataMimeType(dataMimeType: String): T
108-
109-
fun metadataMimeType(metadataMimeType: String): T
110-
}
111-
112-
class ClientRSocketFactory : Acceptor<ClientTransportAcceptor, (RSocket) -> RSocket>,
113-
ClientTransportAcceptor, KeepAlive<ClientRSocketFactory>,
114-
MimeType<ClientRSocketFactory>, Fragmentation<ClientRSocketFactory>,
115-
ErrorConsumer<ClientRSocketFactory>,
116-
SetupPayload<ClientRSocketFactory> {
68+
class ClientRSocketFactory {
11769

11870
private var acceptor: () -> (RSocket) -> RSocket = { { rs -> rs } }
11971

@@ -131,6 +83,8 @@ object RSocketFactory {
13183
private var metadataMimeType = "application/binary"
13284
private var dataMimeType = "application/binary"
13385

86+
private var streamDemandLimit = 128
87+
13488
fun addConnectionPlugin(interceptor: DuplexConnectionInterceptor): ClientRSocketFactory {
13589
plugins.addConnectionPlugin(interceptor)
13690
return this
@@ -146,78 +100,82 @@ object RSocketFactory {
146100
return this
147101
}
148102

149-
override fun keepAlive(): ClientRSocketFactory {
103+
fun keepAlive(): ClientRSocketFactory {
150104
tickPeriod = Duration.ofSeconds(20)
151105
return this
152106
}
153107

154-
override fun keepAlive(
108+
fun keepAlive(
155109
tickPeriod: Duration, ackTimeout: Duration, missedAcks: Int): ClientRSocketFactory {
156110
this.tickPeriod = tickPeriod
157111
this.ackTimeout = ackTimeout
158112
this.missedAcks = missedAcks
159113
return this
160114
}
161115

162-
override fun keepAliveTickPeriod(tickPeriod: Duration): ClientRSocketFactory {
116+
fun keepAliveTickPeriod(tickPeriod: Duration): ClientRSocketFactory {
163117
this.tickPeriod = tickPeriod
164118
return this
165119
}
166120

167-
override fun keepAliveAckTimeout(ackTimeout: Duration): ClientRSocketFactory {
121+
fun keepAliveAckTimeout(ackTimeout: Duration): ClientRSocketFactory {
168122
this.ackTimeout = ackTimeout
169123
return this
170124
}
171125

172-
override fun keepAliveMissedAcks(missedAcks: Int): ClientRSocketFactory {
126+
fun keepAliveMissedAcks(missedAcks: Int): ClientRSocketFactory {
173127
this.missedAcks = missedAcks
174128
return this
175129
}
176130

177-
override fun mimeType(metadataMimeType: String, dataMimeType: String): ClientRSocketFactory {
131+
fun mimeType(metadataMimeType: String, dataMimeType: String): ClientRSocketFactory {
178132
this.dataMimeType = dataMimeType
179133
this.metadataMimeType = metadataMimeType
180134
return this
181135
}
182136

183-
override fun dataMimeType(dataMimeType: String): ClientRSocketFactory {
137+
fun dataMimeType(dataMimeType: String): ClientRSocketFactory {
184138
this.dataMimeType = dataMimeType
185139
return this
186140
}
187141

188-
override fun metadataMimeType(metadataMimeType: String): ClientRSocketFactory {
142+
fun metadataMimeType(metadataMimeType: String): ClientRSocketFactory {
189143
this.metadataMimeType = metadataMimeType
190144
return this
191145
}
192146

193-
override fun transport(transport: () -> ClientTransport): Start<RSocket> {
194-
return StartClient(transport)
195-
}
147+
fun transport(transport: () -> ClientTransport): Start<RSocket> = StartClient(transport)
196148

197-
override fun acceptor(acceptor: () -> (RSocket) -> RSocket): ClientTransportAcceptor {
149+
fun acceptor(acceptor: () -> (RSocket) -> RSocket): ClientTransportAcceptor {
198150
this.acceptor = acceptor
199151
return object : ClientTransportAcceptor {
200152
override fun transport(transport: () -> ClientTransport): Start<RSocket> = StartClient(transport)
201153

202154
}
203155
}
204156

205-
override fun fragment(mtu: Int): ClientRSocketFactory {
157+
fun fragment(mtu: Int): ClientRSocketFactory {
206158
this.mtu = mtu
207159
return this
208160
}
209161

210-
override fun errorConsumer(errorConsumer: (Throwable) -> Unit): ClientRSocketFactory {
162+
fun errorConsumer(errorConsumer: (Throwable) -> Unit): ClientRSocketFactory {
211163
this.errorConsumer = errorConsumer
212164
return this
213165
}
214166

215-
override fun setupPayload(payload: Payload): ClientRSocketFactory {
167+
fun setupPayload(payload: Payload): ClientRSocketFactory {
216168
this.setupPayload = payload
217169
return this
218170
}
219171

220-
protected inner class StartClient internal constructor(private val transportClient: () -> ClientTransport) : Start<RSocket> {
172+
fun streamDemandLimit(streamDemandLimit: Int): ClientRSocketFactory {
173+
this.streamDemandLimit = streamDemandLimit
174+
return this
175+
}
176+
177+
private inner class StartClient internal constructor(private val transportClient: () -> ClientTransport)
178+
: Start<RSocket> {
221179

222180
override fun start(): Single<RSocket> {
223181
return transportClient()
@@ -242,6 +200,7 @@ object RSocketFactory {
242200
multiplexer.asClientConnection(),
243201
errorConsumer,
244202
StreamIdSupplier.clientSupplier(),
203+
streamDemandLimit,
245204
tickPeriod,
246205
ackTimeout,
247206
missedAcks)
@@ -256,7 +215,7 @@ object RSocketFactory {
256215
wrappedRSocketServer
257216
.doAfterSuccess { rSocket ->
258217
RSocketServer(
259-
multiplexer.asServerConnection(), rSocket, errorConsumer)
218+
multiplexer.asServerConnection(), rSocket, errorConsumer, streamDemandLimit)
260219
}.flatMapCompletable { conn.sendOne(setupFrame) }
261220
.andThen (wrappedRSocketClient)
262221
}
@@ -265,14 +224,13 @@ object RSocketFactory {
265224
}
266225
}
267226

268-
class ServerRSocketFactory internal constructor() : Acceptor<ServerTransportAcceptor, SocketAcceptor>,
269-
Fragmentation<ServerRSocketFactory>,
270-
ErrorConsumer<ServerRSocketFactory> {
227+
class ServerRSocketFactory internal constructor() {
271228

272229
private var acceptor: (() -> SocketAcceptor)? = null
273230
private var errorConsumer: (Throwable) -> Unit = { it.printStackTrace() }
274231
private var mtu = 0
275232
private val plugins = PluginRegistry(Plugins.defaultPlugins())
233+
private var streamWindow = 20
276234

277235
fun addConnectionPlugin(interceptor: DuplexConnectionInterceptor): ServerRSocketFactory {
278236
plugins.addConnectionPlugin(interceptor)
@@ -289,26 +247,31 @@ object RSocketFactory {
289247
return this
290248
}
291249

292-
override fun acceptor(acceptor: () -> SocketAcceptor): ServerTransportAcceptor {
250+
fun acceptor(acceptor: () -> SocketAcceptor): ServerTransportAcceptor {
293251
this.acceptor = acceptor
294252
return object : ServerTransportAcceptor {
295-
override fun <T : Closeable> transport(transport: () -> ServerTransport<T>): Start<T> {
296-
return ServerStart(transport)
297-
}
253+
override fun <T : Closeable> transport(transport: () -> ServerTransport<T>): Start<T> =
254+
ServerStart(transport)
298255
}
299256
}
300257

301-
override fun fragment(mtu: Int): ServerRSocketFactory {
258+
fun fragment(mtu: Int): ServerRSocketFactory {
302259
this.mtu = mtu
303260
return this
304261
}
305262

306-
override fun errorConsumer(errorConsumer: (Throwable) -> Unit): ServerRSocketFactory {
263+
fun errorConsumer(errorConsumer: (Throwable) -> Unit): ServerRSocketFactory {
307264
this.errorConsumer = errorConsumer
308265
return this
309266
}
310267

311-
private inner class ServerStart<T : Closeable> internal constructor(private val transportServer: () -> ServerTransport<T>) : Start<T> {
268+
fun streamWindow(streamWindow: Int): ServerRSocketFactory {
269+
this.streamWindow = streamWindow
270+
return this
271+
}
272+
273+
private inner class ServerStart<T : Closeable> internal constructor(private val transportServer: () -> ServerTransport<T>)
274+
: Start<T> {
312275

313276
override fun start(): Single<T> {
314277
return transportServer()
@@ -345,13 +308,13 @@ object RSocketFactory {
345308
val setupPayload = ConnectionSetupPayload.create(setupFrame)
346309

347310
val rSocketClient = RSocketClient(
348-
multiplexer.asServerConnection(), errorConsumer, StreamIdSupplier.serverSupplier())
311+
multiplexer.asServerConnection(), errorConsumer, StreamIdSupplier.serverSupplier(), streamWindow)
349312

350313
val wrappedRSocketClient = Single.just(rSocketClient).map { plugins.applyClient(it) }
351314

352315
return wrappedRSocketClient
353316
.flatMap { sender -> acceptor?.let { it() }?.accept(setupPayload, sender)?.map { plugins.applyServer(it) } }
354-
.map { handler -> RSocketServer(multiplexer.asClientConnection(), handler, errorConsumer) }
317+
.map { handler -> RSocketServer(multiplexer.asClientConnection(), handler, errorConsumer,streamWindow) }
355318
.toCompletable()
356319
}
357320
}

0 commit comments

Comments
 (0)