Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
package com.squareup.wire

import com.squareup.wire.internal.RealGrpcCall
import com.squareup.wire.internal.RealGrpcServerStreamingCall
import com.squareup.wire.internal.RealGrpcStreamingCall
import com.squareup.wire.internal.asGrpcClientStreamingCall
import com.squareup.wire.internal.asGrpcServerStreamingCall
import java.util.concurrent.TimeUnit
import kotlin.reflect.KClass
import okhttp3.Call
Expand Down Expand Up @@ -194,6 +194,6 @@ internal class WireGrpcClient internal constructor(
}

override fun <S : Any, R : Any> newServerStreamingCall(method: GrpcMethod<S, R>): GrpcServerStreamingCall<S, R> {
return RealGrpcStreamingCall(this, method).asGrpcServerStreamingCall()
return RealGrpcServerStreamingCall(this, method)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import okio.IOException
* * Complete: enqueued when the stream completes normally.
*/
internal class BlockingMessageSource<R : Any>(
val grpcCall: RealGrpcStreamingCall<*, R>,
val onResponseMetadata: (Map<String, String>) -> Unit,
val responseAdapter: ProtoAdapter<R>,
val call: Call,
) : MessageSource<R> {
Expand Down Expand Up @@ -67,7 +67,7 @@ internal class BlockingMessageSource<R : Any>(

override fun onResponse(call: Call, response: Response) {
try {
grpcCall.responseMetadata = response.headers.toMap()
onResponseMetadata(response.headers.toMap())
response.use {
response.messageSource(responseAdapter).use { reader ->
while (true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,115 @@ import com.squareup.wire.GrpcMethod
import com.squareup.wire.GrpcServerStreamingCall
import com.squareup.wire.GrpcStreamingCall
import com.squareup.wire.MessageSource
import com.squareup.wire.WireGrpcClient
import java.util.concurrent.TimeUnit
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import okio.ForwardingTimeout
import okio.Timeout

/**
* A [GrpcServerStreamingCall] that sends a single non-duplex request and reads a streaming
* response. Using a non-duplex request body ensures the complete request (including END_STREAM) is
* sent to the server before responses are read, avoiding delays on servers that wait for the
* client's half-close before starting to stream responses.
*/
internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
private val grpcClient: WireGrpcClient,
override val method: GrpcMethod<S, R>,
) : GrpcServerStreamingCall<S, R> {

private var call: okhttp3.Call? = null
private var canceled = false

override val timeout: Timeout = ForwardingTimeout(Timeout())

init {
timeout.clearTimeout()
timeout.clearDeadline()
}

override var requestMetadata: Map<String, String> = mapOf()

override var responseMetadata: Map<String, String>? = null
internal set

override fun cancel() {
canceled = true
call?.cancel()
}

override fun isCanceled(): Boolean = canceled || call?.isCanceled() == true

override fun isExecuted(): Boolean = call?.isExecuted() ?: false

override fun clone(): GrpcServerStreamingCall<S, R> {
val result = RealGrpcServerStreamingCall(grpcClient, method)
val oldTimeout = this.timeout
result.timeout.also { newTimeout ->
newTimeout.timeout(oldTimeout.timeoutNanos(), TimeUnit.NANOSECONDS)
if (oldTimeout.hasDeadline()) {
newTimeout.deadlineNanoTime(oldTimeout.deadlineNanoTime())
} else {
newTimeout.clearDeadline()
}
}
result.requestMetadata += this.requestMetadata
return result
}

override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel<R> {
val responseChannel = Channel<R>(1)
val call = initCall(request)

responseChannel.invokeOnClose {
if (responseChannel.isClosedForReceive) {
call.cancel()
}
}

call.enqueue(
responseChannel.readFromResponseBodyCallback(
onResponseMetadata = { this.responseMetadata = it },
responseAdapter = method.responseAdapter,
),
)

return responseChannel
}

override fun executeBlocking(request: S): MessageSource<R> {
val call = initCall(request)
val messageSource = BlockingMessageSource(
onResponseMetadata = { this.responseMetadata = it },
responseAdapter = method.responseAdapter,
call = call,
)
call.enqueue(messageSource.readFromResponseBodyCallback())
return messageSource
}

private fun initCall(request: S): okhttp3.Call {
check(this.call == null) { "already executed" }
val requestBody = newRequestBody(
minMessageToCompress = grpcClient.minMessageToCompress,
requestAdapter = method.requestAdapter,
onlyMessage = request,
)
val result = grpcClient.newCall(method, requestMetadata, requestBody, timeout)
this.call = result
if (canceled) result.cancel()
(timeout as ForwardingTimeout).setDelegate(result.timeout())
return result
}
}

/**
* Wraps a [GrpcStreamingCall] as a [GrpcServerStreamingCall]. Used for test doubles created via
* [com.squareup.wire.GrpcServerStreamingCall] factory functions in GrpcCalls.
*/
internal class GrpcStreamingCallServerStreamingAdapter<S : Any, R : Any>(
private val callDelegate: GrpcStreamingCall<S, R>,
override val method: GrpcMethod<S, R>,
) : GrpcServerStreamingCall<S, R> {
Expand All @@ -46,7 +150,7 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(

override fun isExecuted() = callDelegate.isExecuted()

override fun clone() = RealGrpcServerStreamingCall(callDelegate.clone(), method)
override fun clone() = GrpcStreamingCallServerStreamingAdapter(callDelegate.clone(), method)

override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel<R> {
val (sendChannel, receiveChannel) = callDelegate.executeIn(scope)
Expand All @@ -65,5 +169,5 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
}
}

internal fun <S : Any, R : Any>GrpcStreamingCall<S, R>.asGrpcServerStreamingCall() =
RealGrpcServerStreamingCall(this, method)
internal fun <S : Any, R : Any> GrpcStreamingCall<S, R>.asGrpcServerStreamingCall() =
GrpcStreamingCallServerStreamingAdapter(this, method)
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,23 @@ internal class RealGrpcStreamingCall<S : Any, R : Any>(
callForCancel = call,
)
}
call.enqueue(responseChannel.readFromResponseBodyCallback(this, method.responseAdapter))
call.enqueue(
responseChannel.readFromResponseBodyCallback(
onResponseMetadata = { this.responseMetadata = it },
responseAdapter = method.responseAdapter,
),
)

return requestChannel to responseChannel
}

override fun executeBlocking(): Pair<MessageSink<S>, MessageSource<R>> {
val call = initCall()
val messageSource = BlockingMessageSource(this, method.responseAdapter, call)
val messageSource = BlockingMessageSource(
onResponseMetadata = { this.responseMetadata = it },
responseAdapter = method.responseAdapter,
call = call,
)
val messageSink = requestBody.messageSink(
minMessageToCompress = grpcClient.minMessageToCompress,
requestAdapter = method.requestAdapter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ internal fun <S : Any> PipeDuplexRequestBody.messageSink(
internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
grpcCall: RealGrpcStreamingCall<*, R>,
responseAdapter: ProtoAdapter<R>,
): Callback = readFromResponseBodyCallback(
onResponseMetadata = { grpcCall.responseMetadata = it },
responseAdapter = responseAdapter,
)

/** Sends the response messages to the channel. */
internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
onResponseMetadata: (Map<String, String>) -> Unit,
responseAdapter: ProtoAdapter<R>,
): Callback {
return object : Callback {
override fun onFailure(call: Call, e: IOException) {
Expand All @@ -95,7 +104,7 @@ internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
}

override fun onResponse(call: Call, response: Response) {
grpcCall.responseMetadata = response.headers.toMap()
onResponseMetadata(response.headers.toMap())
runBlocking {
response.use {
val messageSource = try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,21 @@ package com.squareup.wire
import assertk.assertThat
import assertk.assertions.containsExactly
import assertk.assertions.isEqualTo
import assertk.assertions.isNull
import com.squareup.wire.mockwebserver.GrpcDispatcher
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.ObsoleteCoroutinesApi
import kotlinx.coroutines.runBlocking
import okhttp3.Call
import okhttp3.Headers.Companion.headersOf
import okhttp3.Interceptor
import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer
import okio.Buffer
import org.junit.Before
import org.junit.Rule
import org.junit.Test
Expand Down Expand Up @@ -82,6 +86,42 @@ class GrpcOnMockWebServerTest {
routeGuideService = grpcClient.create(RouteGuideClient::class)
}

@Test
fun serverStreamingListFeatures() {
// MockWebServer only dispatches after receiving the complete request body including END_STREAM,
// mimicking some server behaviors that would cause hanging until timeout when
// GrpcServerStreamingCall used a duplex request body.
val responseBody = Buffer()
for (feature in listOf(Feature(name = "peak"), Feature(name = "valley"))) {
val encoded = Feature.ADAPTER.encodeByteString(feature)
responseBody.writeByte(0) // not compressed
responseBody.writeInt(encoded.size)
responseBody.write(encoded)
}
val grpcDispatcher = mockWebServer.dispatcher
mockWebServer.dispatcher = object : okhttp3.mockwebserver.Dispatcher() {
override fun dispatch(request: okhttp3.mockwebserver.RecordedRequest): MockResponse {
if (request.path == "/routeguide.RouteGuide/ListFeatures") {
return MockResponse()
.setHeader("Content-Type", "application/grpc")
.setTrailers(headersOf("grpc-status", "0"))
.setBody(responseBody)
}
return grpcDispatcher.dispatch(request)
}
}

runBlocking {
val responses = routeGuideService.ListFeatures().executeIn(
this,
Rectangle(lo = Point(latitude = 1, longitude = 2), hi = Point(latitude = 3, longitude = 4)),
)
assertThat(responses.receive()).isEqualTo(Feature(name = "peak"))
assertThat(responses.receive()).isEqualTo(Feature(name = "valley"))
assertThat(responses.receiveCatching().getOrNull()).isNull()
}
}

@Test
fun requestResponseSuspend() {
runBlocking {
Expand Down
Loading