From c11403d3d1241f52cc111209c9bc9ab72354e29f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Quenaudon?= Date: Thu, 16 Apr 2026 10:29:02 +0100 Subject: [PATCH] Avoid race in duplex pipe for streaming calls --- .../kotlin/com/squareup/wire/GrpcClient.kt | 4 +- .../wire/internal/BlockingMessageSource.kt | 4 +- .../internal/RealGrpcServerStreamingCall.kt | 110 +++++++++++++++++- .../wire/internal/RealGrpcStreamingCall.kt | 13 ++- .../kotlin/com/squareup/wire/internal/grpc.kt | 11 +- .../squareup/wire/GrpcOnMockWebServerTest.kt | 40 +++++++ 6 files changed, 172 insertions(+), 10 deletions(-) diff --git a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/GrpcClient.kt b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/GrpcClient.kt index 31c28a6d1b..3b2dcc02f3 100644 --- a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/GrpcClient.kt +++ b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/GrpcClient.kt @@ -16,9 +16,9 @@ package com.squareup.wire import com.squareup.wire.internal.RealGrpcCall +import com.squareup.wire.internal.RealGrpcServerStreamingCall import com.squareup.wire.internal.RealGrpcStreamingCall import com.squareup.wire.internal.asGrpcClientStreamingCall -import com.squareup.wire.internal.asGrpcServerStreamingCall import java.util.concurrent.TimeUnit import kotlin.reflect.KClass import okhttp3.Call @@ -194,6 +194,6 @@ internal class WireGrpcClient internal constructor( } override fun newServerStreamingCall(method: GrpcMethod): GrpcServerStreamingCall { - return RealGrpcStreamingCall(this, method).asGrpcServerStreamingCall() + return RealGrpcServerStreamingCall(this, method) } } diff --git a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/BlockingMessageSource.kt b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/BlockingMessageSource.kt index 6c9e81606a..67df3332d7 100644 --- a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/BlockingMessageSource.kt +++ b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/BlockingMessageSource.kt @@ -33,7 +33,7 @@ import okio.IOException * * Complete: enqueued when the stream completes normally. */ internal class BlockingMessageSource( - val grpcCall: RealGrpcStreamingCall<*, R>, + val onResponseMetadata: (Map) -> Unit, val responseAdapter: ProtoAdapter, val call: Call, ) : MessageSource { @@ -67,7 +67,7 @@ internal class BlockingMessageSource( override fun onResponse(call: Call, response: Response) { try { - grpcCall.responseMetadata = response.headers.toMap() + onResponseMetadata(response.headers.toMap()) response.use { response.messageSource(responseAdapter).use { reader -> while (true) { diff --git a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcServerStreamingCall.kt b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcServerStreamingCall.kt index b74aace69c..eb12aa6632 100644 --- a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcServerStreamingCall.kt +++ b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcServerStreamingCall.kt @@ -19,11 +19,115 @@ import com.squareup.wire.GrpcMethod import com.squareup.wire.GrpcServerStreamingCall import com.squareup.wire.GrpcStreamingCall import com.squareup.wire.MessageSource +import com.squareup.wire.WireGrpcClient +import java.util.concurrent.TimeUnit import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ReceiveChannel +import okio.ForwardingTimeout import okio.Timeout +/** + * A [GrpcServerStreamingCall] that sends a single non-duplex request and reads a streaming + * response. Using a non-duplex request body ensures the complete request (including END_STREAM) is + * sent to the server before responses are read, avoiding delays on servers that wait for the + * client's half-close before starting to stream responses. + */ internal class RealGrpcServerStreamingCall( + private val grpcClient: WireGrpcClient, + override val method: GrpcMethod, +) : GrpcServerStreamingCall { + + private var call: okhttp3.Call? = null + private var canceled = false + + override val timeout: Timeout = ForwardingTimeout(Timeout()) + + init { + timeout.clearTimeout() + timeout.clearDeadline() + } + + override var requestMetadata: Map = mapOf() + + override var responseMetadata: Map? = null + internal set + + override fun cancel() { + canceled = true + call?.cancel() + } + + override fun isCanceled(): Boolean = canceled || call?.isCanceled() == true + + override fun isExecuted(): Boolean = call?.isExecuted() ?: false + + override fun clone(): GrpcServerStreamingCall { + val result = RealGrpcServerStreamingCall(grpcClient, method) + val oldTimeout = this.timeout + result.timeout.also { newTimeout -> + newTimeout.timeout(oldTimeout.timeoutNanos(), TimeUnit.NANOSECONDS) + if (oldTimeout.hasDeadline()) { + newTimeout.deadlineNanoTime(oldTimeout.deadlineNanoTime()) + } else { + newTimeout.clearDeadline() + } + } + result.requestMetadata += this.requestMetadata + return result + } + + override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel { + val responseChannel = Channel(1) + val call = initCall(request) + + responseChannel.invokeOnClose { + if (responseChannel.isClosedForReceive) { + call.cancel() + } + } + + call.enqueue( + responseChannel.readFromResponseBodyCallback( + onResponseMetadata = { this.responseMetadata = it }, + responseAdapter = method.responseAdapter, + ), + ) + + return responseChannel + } + + override fun executeBlocking(request: S): MessageSource { + val call = initCall(request) + val messageSource = BlockingMessageSource( + onResponseMetadata = { this.responseMetadata = it }, + responseAdapter = method.responseAdapter, + call = call, + ) + call.enqueue(messageSource.readFromResponseBodyCallback()) + return messageSource + } + + private fun initCall(request: S): okhttp3.Call { + check(this.call == null) { "already executed" } + val requestBody = newRequestBody( + minMessageToCompress = grpcClient.minMessageToCompress, + requestAdapter = method.requestAdapter, + onlyMessage = request, + ) + val result = grpcClient.newCall(method, requestMetadata, requestBody, timeout) + this.call = result + if (canceled) result.cancel() + (timeout as ForwardingTimeout).setDelegate(result.timeout()) + return result + } +} + +/** + * Wraps a [GrpcStreamingCall] as a [GrpcServerStreamingCall]. Used for test doubles created via + * [com.squareup.wire.GrpcServerStreamingCall] factory functions in GrpcCalls. + */ +internal class GrpcStreamingCallServerStreamingAdapter( private val callDelegate: GrpcStreamingCall, override val method: GrpcMethod, ) : GrpcServerStreamingCall { @@ -46,7 +150,7 @@ internal class RealGrpcServerStreamingCall( override fun isExecuted() = callDelegate.isExecuted() - override fun clone() = RealGrpcServerStreamingCall(callDelegate.clone(), method) + override fun clone() = GrpcStreamingCallServerStreamingAdapter(callDelegate.clone(), method) override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel { val (sendChannel, receiveChannel) = callDelegate.executeIn(scope) @@ -65,5 +169,5 @@ internal class RealGrpcServerStreamingCall( } } -internal fun GrpcStreamingCall.asGrpcServerStreamingCall() = - RealGrpcServerStreamingCall(this, method) +internal fun GrpcStreamingCall.asGrpcServerStreamingCall() = + GrpcStreamingCallServerStreamingAdapter(this, method) diff --git a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcStreamingCall.kt b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcStreamingCall.kt index ea7ca3b504..108a00175f 100644 --- a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcStreamingCall.kt +++ b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcStreamingCall.kt @@ -90,14 +90,23 @@ internal class RealGrpcStreamingCall( callForCancel = call, ) } - call.enqueue(responseChannel.readFromResponseBodyCallback(this, method.responseAdapter)) + call.enqueue( + responseChannel.readFromResponseBodyCallback( + onResponseMetadata = { this.responseMetadata = it }, + responseAdapter = method.responseAdapter, + ), + ) return requestChannel to responseChannel } override fun executeBlocking(): Pair, MessageSource> { val call = initCall() - val messageSource = BlockingMessageSource(this, method.responseAdapter, call) + val messageSource = BlockingMessageSource( + onResponseMetadata = { this.responseMetadata = it }, + responseAdapter = method.responseAdapter, + call = call, + ) val messageSink = requestBody.messageSink( minMessageToCompress = grpcClient.minMessageToCompress, requestAdapter = method.requestAdapter, diff --git a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/grpc.kt b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/grpc.kt index 94f69f4c62..342b5b6864 100644 --- a/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/grpc.kt +++ b/wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/grpc.kt @@ -87,6 +87,15 @@ internal fun PipeDuplexRequestBody.messageSink( internal fun SendChannel.readFromResponseBodyCallback( grpcCall: RealGrpcStreamingCall<*, R>, responseAdapter: ProtoAdapter, +): Callback = readFromResponseBodyCallback( + onResponseMetadata = { grpcCall.responseMetadata = it }, + responseAdapter = responseAdapter, +) + +/** Sends the response messages to the channel. */ +internal fun SendChannel.readFromResponseBodyCallback( + onResponseMetadata: (Map) -> Unit, + responseAdapter: ProtoAdapter, ): Callback { return object : Callback { override fun onFailure(call: Call, e: IOException) { @@ -95,7 +104,7 @@ internal fun SendChannel.readFromResponseBodyCallback( } override fun onResponse(call: Call, response: Response) { - grpcCall.responseMetadata = response.headers.toMap() + onResponseMetadata(response.headers.toMap()) runBlocking { response.use { val messageSource = try { diff --git a/wire-grpc-tests/src/test/java/com/squareup/wire/GrpcOnMockWebServerTest.kt b/wire-grpc-tests/src/test/java/com/squareup/wire/GrpcOnMockWebServerTest.kt index 3578ff85c0..eae31e4ef5 100644 --- a/wire-grpc-tests/src/test/java/com/squareup/wire/GrpcOnMockWebServerTest.kt +++ b/wire-grpc-tests/src/test/java/com/squareup/wire/GrpcOnMockWebServerTest.kt @@ -18,6 +18,7 @@ package com.squareup.wire import assertk.assertThat import assertk.assertions.containsExactly import assertk.assertions.isEqualTo +import assertk.assertions.isNull import com.squareup.wire.mockwebserver.GrpcDispatcher import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference @@ -25,10 +26,13 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.ObsoleteCoroutinesApi import kotlinx.coroutines.runBlocking import okhttp3.Call +import okhttp3.Headers.Companion.headersOf import okhttp3.Interceptor import okhttp3.OkHttpClient import okhttp3.Protocol +import okhttp3.mockwebserver.MockResponse import okhttp3.mockwebserver.MockWebServer +import okio.Buffer import org.junit.Before import org.junit.Rule import org.junit.Test @@ -82,6 +86,42 @@ class GrpcOnMockWebServerTest { routeGuideService = grpcClient.create(RouteGuideClient::class) } + @Test + fun serverStreamingListFeatures() { + // MockWebServer only dispatches after receiving the complete request body including END_STREAM, + // mimicking some server behaviors that would cause hanging until timeout when + // GrpcServerStreamingCall used a duplex request body. + val responseBody = Buffer() + for (feature in listOf(Feature(name = "peak"), Feature(name = "valley"))) { + val encoded = Feature.ADAPTER.encodeByteString(feature) + responseBody.writeByte(0) // not compressed + responseBody.writeInt(encoded.size) + responseBody.write(encoded) + } + val grpcDispatcher = mockWebServer.dispatcher + mockWebServer.dispatcher = object : okhttp3.mockwebserver.Dispatcher() { + override fun dispatch(request: okhttp3.mockwebserver.RecordedRequest): MockResponse { + if (request.path == "/routeguide.RouteGuide/ListFeatures") { + return MockResponse() + .setHeader("Content-Type", "application/grpc") + .setTrailers(headersOf("grpc-status", "0")) + .setBody(responseBody) + } + return grpcDispatcher.dispatch(request) + } + } + + runBlocking { + val responses = routeGuideService.ListFeatures().executeIn( + this, + Rectangle(lo = Point(latitude = 1, longitude = 2), hi = Point(latitude = 3, longitude = 4)), + ) + assertThat(responses.receive()).isEqualTo(Feature(name = "peak")) + assertThat(responses.receive()).isEqualTo(Feature(name = "valley")) + assertThat(responses.receiveCatching().getOrNull()).isNull() + } + } + @Test fun requestResponseSuspend() { runBlocking {