From 60ff1427cdae689f0abee480d211b892989aaf90 Mon Sep 17 00:00:00 2001 From: MiletaA Date: Mon, 3 Nov 2025 11:47:03 +0100 Subject: [PATCH 1/3] feat(isFlow): keep Flow<...> as returnType and add returnsFlow flag to RpcCallable --- .../rpc/codegen/extension/RpcStubGenerator.kt | 15 +++++---------- .../rpc/descriptor/RpcServiceDescriptor.kt | 6 ++++++ .../rpc/descriptor/RpcServiceDescriptorDefault.kt | 1 + 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcStubGenerator.kt b/compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcStubGenerator.kt index 12b363b3c..753eb8c52 100644 --- a/compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcStubGenerator.kt +++ b/compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcStubGenerator.kt @@ -797,7 +797,7 @@ internal class RpcStubGenerator( type = ctx.rpcCallable.typeWith(declaration.serviceType), symbol = ctx.rpcCallableDefault.constructors.single(), typeArgumentsCount = 1, - valueArgumentsCount = 5, + valueArgumentsCount = 6, constructorTypeArgumentsCount = 1, ) }.apply { @@ -805,15 +805,7 @@ internal class RpcStubGenerator( callable as ServiceDeclaration.Method - val returnType = when { - callable.function.isNonSuspendingWithFlowReturn() -> { - (callable.function.returnType as IrSimpleType).arguments.single().typeOrFail - } - - else -> { - callable.function.returnType - } - } + val returnType = callable.function.returnType val invokator = invokators[callable.name] ?: error("Expected invokator for ${callable.name} in ${declaration.service.name}") @@ -889,6 +881,8 @@ internal class RpcStubGenerator( } } + val returnsFlowFlag = (callable.function.returnType.classOrNull == ctx.flow) + arguments { values { +stringConst(callable.name) @@ -900,6 +894,7 @@ internal class RpcStubGenerator( +arrayOfCall +booleanConst(!callable.function.isSuspend) + +booleanConst(returnsFlowFlag) } } } diff --git a/core/src/commonMain/kotlin/kotlinx/rpc/descriptor/RpcServiceDescriptor.kt b/core/src/commonMain/kotlin/kotlinx/rpc/descriptor/RpcServiceDescriptor.kt index 77dbe83f7..9473a52d3 100644 --- a/core/src/commonMain/kotlin/kotlinx/rpc/descriptor/RpcServiceDescriptor.kt +++ b/core/src/commonMain/kotlin/kotlinx/rpc/descriptor/RpcServiceDescriptor.kt @@ -56,6 +56,12 @@ public interface RpcCallable<@Rpc T : Any> { public val invokator: RpcInvokator public val parameters: Array public val isNonSuspendFunction: Boolean + /** + * True if the method returns Flow<...> and should be treated as a streaming return. + * The [returnType] remains the original declared KType (including Flow<...>), + * consumers can use this flag to branch logic without relying on type unwrapping. + */ + public val returnsFlow: Boolean } @ExperimentalRpcApi diff --git a/core/src/commonMain/kotlin/kotlinx/rpc/descriptor/RpcServiceDescriptorDefault.kt b/core/src/commonMain/kotlin/kotlinx/rpc/descriptor/RpcServiceDescriptorDefault.kt index f3209ba48..8348da67c 100644 --- a/core/src/commonMain/kotlin/kotlinx/rpc/descriptor/RpcServiceDescriptorDefault.kt +++ b/core/src/commonMain/kotlin/kotlinx/rpc/descriptor/RpcServiceDescriptorDefault.kt @@ -15,6 +15,7 @@ public class RpcCallableDefault<@Rpc T : Any>( override val invokator: RpcInvokator, override val parameters: Array, override val isNonSuspendFunction: Boolean, + override val returnsFlow: Boolean, ) : RpcCallable @InternalRpcApi From 8a437fba95e0c96d156d1041863048deed99ee67 Mon Sep 17 00:00:00 2001 From: MiletaA Date: Mon, 3 Nov 2025 12:05:38 +0100 Subject: [PATCH 2/3] Test for isFlow --- .../rpc/descriptor/IsFlowIntegrationTest.kt | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/descriptor/IsFlowIntegrationTest.kt diff --git a/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/descriptor/IsFlowIntegrationTest.kt b/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/descriptor/IsFlowIntegrationTest.kt new file mode 100644 index 000000000..6ef060e55 --- /dev/null +++ b/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/descriptor/IsFlowIntegrationTest.kt @@ -0,0 +1,38 @@ +package kotlinx.rpc.descriptor + +import kotlinx.coroutines.flow.Flow +import kotlinx.rpc.annotations.Rpc +import kotlin.reflect.typeOf +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +@Rpc +interface NewsServiceForFlowTest { + fun stream(): Flow + suspend fun greet(name: String): String +} + +class IsFlowIntegrationTest { + @Test + fun stream_isFlow_and_returnType_is_Flow() { + val descriptor = serviceDescriptorOf() + val stream = descriptor.callables["stream"] ?: error("stream not found") + + assertEquals(typeOf>().toString(), stream.returnType.kType.toString()) + assertTrue(stream.returnsFlow) + assertTrue(stream.isNonSuspendFunction) + } + + @Test + fun greet_notFlow_and_returnType_is_String() { + val descriptor = serviceDescriptorOf() + val greet = descriptor.callables["greet"] ?: error("greet not found") + + assertEquals(typeOf().toString(), greet.returnType.kType.toString()) + assertFalse(greet.returnsFlow) + // greet is suspend now + kotlin.test.assertFalse(greet.isNonSuspendFunction) + } +} From 183d0e91b67d243af8db9840eb8f92b2cf979879 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Tue, 4 Nov 2025 20:18:56 +0100 Subject: [PATCH 3/3] grpc: Add gzip Compression Support (#527) --- .../rpc/grpc/client/ClientInterceptor.kt | 136 +++++ .../rpc/grpc/client/GrpcCallOptions.kt | 52 ++ .../kotlinx/rpc/grpc/client/GrpcClient.kt | 297 +++++++++++ .../rpc/grpc/client/internal/GrpcChannel.kt | 18 + .../client/internal/suspendClientCalls.kt | 300 +++++++++++ .../rpc/grpc/client/GrpcCallOptions.jvm.kt | 22 + .../grpc/client/internal/GrpcChannel.jvm.kt | 22 + .../rpc/grpc/client/GrpcCallOptions.native.kt | 28 ++ .../client/internal/GrpcChannel.native.kt | 25 + .../grpc/client/internal/NativeClientCall.kt | 476 ++++++++++++++++++ .../client/internal/NativeManagedChannel.kt | 172 +++++++ .../kotlinx/rpc/grpc/GrpcCompression.kt | 55 ++ .../kotlinx/rpc/grpc/test/CoreClientTest.kt | 310 ++++++++++++ .../grpc/test/proto/GrpcCompressionTest.kt | 239 +++++++++ .../kotlin/kotlinx/rpc/grpc/test/utils.kt | 47 ++ .../kotlin/kotlinx/rpc/grpc/test/utils.jvm.kt | 32 ++ .../kotlinx/rpc/grpc/GrpcMetadata.native.kt | 328 ++++++++++++ .../kotlinx/rpc/grpc/test/utils.native.kt | 77 +++ 18 files changed, 2636 insertions(+) create mode 100644 grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/ClientInterceptor.kt create mode 100644 grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.kt create mode 100644 grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/GrpcClient.kt create mode 100644 grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.kt create mode 100644 grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/internal/suspendClientCalls.kt create mode 100644 grpc/grpc-client/src/jvmMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.jvm.kt create mode 100644 grpc/grpc-client/src/jvmMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.jvm.kt create mode 100644 grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.native.kt create mode 100644 grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.native.kt create mode 100644 grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/NativeClientCall.kt create mode 100644 grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/NativeManagedChannel.kt create mode 100644 grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcCompression.kt create mode 100644 grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CoreClientTest.kt create mode 100644 grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/GrpcCompressionTest.kt create mode 100644 grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/utils.kt create mode 100644 grpc/grpc-core/src/jvmTest/kotlin/kotlinx/rpc/grpc/test/utils.jvm.kt create mode 100644 grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.native.kt create mode 100644 grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/test/utils.native.kt diff --git a/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/ClientInterceptor.kt b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/ClientInterceptor.kt new file mode 100644 index 000000000..f8c4e318e --- /dev/null +++ b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/ClientInterceptor.kt @@ -0,0 +1,136 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.client + +import kotlinx.coroutines.flow.Flow +import kotlinx.rpc.grpc.GrpcMetadata +import kotlinx.rpc.grpc.Status +import kotlinx.rpc.grpc.client.GrpcCallOptions +import kotlinx.rpc.grpc.descriptor.MethodDescriptor + +/** + * The scope of a single outgoing gRPC client call observed by a [ClientInterceptor]. + * + * An interceptor receives this scope instance for every call and can: + * - Inspect the RPC [method] being invoked. + * - Read or populate [requestHeaders] before the request is sent. + * - Read [callOptions] that affect transport-level behavior. + * - Register callbacks with [onHeaders] and [onClose] to observe response metadata and final status. + * - Cancel the call early via [cancel]. + * - Continue the call by calling [proceed] with a (possibly transformed) request [Flow]. + * - Transform the response by modifying the returned [Flow]. + * + * ```kt + * val interceptor = object : ClientInterceptor { + * override fun ClientCallScope.intercept( + * request: Flow + * ): Flow { + * // Example: add a header before proceeding + * requestHeaders[MyKeys.Authorization] = token + * + * // Example: modify call options + * callOptions.timeout = 5.seconds + * + * // Example: observe response metadata + * onHeaders { headers -> /* inspect headers */ } + * onClose { status, trailers -> /* log status/trailers */ } + * + * // IMPORTANT: proceed forwards the call to the next interceptor/transport. + * // If you do not call proceed, no request will be sent and the call is short-circuited. + * return proceed(request) + * } + * } + * ``` + * + * @param Request the request message type of the RPC. + * @param Response the response message type of the RPC. + */ +public interface ClientCallScope { + /** Descriptor of the RPC method (name, marshalling, type) being invoked. */ + public val method: MethodDescriptor + + /** + * Outgoing request headers for this call. + * + * Interceptors may read and mutate this metadata + * before calling [proceed] so the headers are sent to the server. Headers added after + * the call has already been proceeded may not be reflected on the wire. + */ + public val requestHeaders: GrpcMetadata + + /** + * Transport/engine options used for this call (deadlines, compression, etc.). + * Modifying this object is only possible before the call is proceeded. + */ + public val callOptions: GrpcCallOptions + + /** + * Register a callback invoked when the initial response headers are received. + * Typical gRPC semantics guarantee headers are delivered at most once per call + * and before the first message is received. + */ + public fun onHeaders(block: (responseHeaders: GrpcMetadata) -> Unit) + + /** + * Register a callback invoked when the call completes, successfully or not. + * The final `status` and trailing `responseTrailers` are provided. + */ + public fun onClose(block: (status: Status, responseTrailers: GrpcMetadata) -> Unit) + + /** + * Cancel the call locally, providing a human-readable [message] and an optional [cause]. + * This method won't return and abort all further processing. + * + * We made cancel throw a [kotlinx.rpc.grpc.StatusException] instead of returning, so control flow is explicit and + * race conditions between interceptors and the transport layer are avoided. + */ + public fun cancel(message: String, cause: Throwable? = null): Nothing + + /** + * Continue the invocation by forwarding it to the next interceptor or to the underlying transport. + * + * This function is the heart of an interceptor: + * - It must be called to actually perform the RPC. If you never call [proceed], the request is not sent + * and the call is effectively short-circuited by the interceptor. + * - You may transform the [request] flow before passing it to [proceed] (e.g., logging, retry orchestration, + * compression, metrics). The returned [Flow] yields response messages and can also be transformed + * before being returned to the caller. + * - Call [proceed] at most once per intercepted call. Calling it multiple times or after cancellation + * is not supported. + */ + public fun proceed(request: Flow): Flow +} + +/** + * Client-side interceptor for gRPC calls. + * + * Implementations can observe and modify client calls in a structured way. The primary entry point is the + * [intercept] extension function on [ClientCallScope], which receives the inbound request [Flow] and must + * call [ClientCallScope.proceed] to forward the call. + * + * Common use-cases include: + * - Adding authentication or custom headers. + * - Implementing logging/metrics. + * - Observing headers/trailers and final status. + * - Transforming request/response flows (e.g., mapping, buffering, throttling). + */ +public interface ClientInterceptor { + /** + * Intercept a client call. + * + * You can: + * - Inspect [ClientCallScope.method] and [ClientCallScope.callOptions]. + * - Read or populate [ClientCallScope.requestHeaders]. + * - Register [ClientCallScope.onHeaders] and [ClientCallScope.onClose] callbacks. + * - Transform the [request] flow or wrap the resulting response flow. + * + * IMPORTANT: [ClientCallScope.proceed] must eventually be called to actually execute the RPC and obtain + * the response [Flow]. If [ClientCallScope.proceed] is omitted, the call will not reach the server. + */ + public fun ClientCallScope.intercept( + request: Flow, + ): Flow + +} \ No newline at end of file diff --git a/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.kt b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.kt new file mode 100644 index 000000000..8bf380bad --- /dev/null +++ b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.kt @@ -0,0 +1,52 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.client + +import kotlinx.rpc.grpc.GrpcCompression +import kotlin.time.Duration + +/** + * The collection of runtime options for a new gRPC call. + * + * This class allows configuring per-call behavior such as timeouts. + */ +public class GrpcCallOptions { + /** + * The maximum duration to wait for the RPC to complete. + * + * If set, the RPC will be canceled (with `DEADLINE_EXCEEDED`) + * if it does not complete within the specified duration. + * The timeout is measured from the moment the call is initiated. + * If `null`, no timeout is applied, and the call may run indefinitely. + * + * The default value is `null`. + * + * @see kotlin.time.Duration + */ + public var timeout: Duration? = null + + /** + * The compression algorithm to use for encoding outgoing messages in this call. + * + * When set to a value other than [GrpcCompression.None], the client will compress request messages + * using the specified algorithm before sending them to the server. The chosen compression algorithm + * is communicated to the server via the `grpc-encoding` header. + * + * ## Default Behavior + * Defaults to [GrpcCompression.None], meaning no compression is applied to messages. + * + * ## Server Compatibility + * **Important**: It is the caller's responsibility to ensure the server supports the chosen + * compression algorithm. There is no automatic negotiation performed. If the server does not + * support the requested compression, the call will fail. + * + * ## Available Algorithms + * - [GrpcCompression.None]: No compression (identity encoding) - **default** + * - [GrpcCompression.Gzip]: GZIP compression, widely supported + * + * @see GrpcCompression + */ + public var compression: GrpcCompression = GrpcCompression.None +} \ No newline at end of file diff --git a/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/GrpcClient.kt b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/GrpcClient.kt new file mode 100644 index 000000000..0e0426d6c --- /dev/null +++ b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/GrpcClient.kt @@ -0,0 +1,297 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.client + +import kotlinx.coroutines.flow.Flow +import kotlinx.rpc.RpcCall +import kotlinx.rpc.RpcClient +import kotlinx.rpc.grpc.GrpcMetadata +import kotlinx.rpc.grpc.client.GrpcCallOptions +import kotlinx.rpc.grpc.client.internal.ManagedChannel +import kotlinx.rpc.grpc.client.internal.ManagedChannelBuilder +import kotlinx.rpc.grpc.client.internal.bidirectionalStreamingRpc +import kotlinx.rpc.grpc.client.internal.buildChannel +import kotlinx.rpc.grpc.client.internal.clientStreamingRpc +import kotlinx.rpc.grpc.client.internal.serverStreamingRpc +import kotlinx.rpc.grpc.client.internal.unaryRpc +import kotlinx.rpc.grpc.codec.EmptyMessageCodecResolver +import kotlinx.rpc.grpc.codec.MessageCodecResolver +import kotlinx.rpc.grpc.codec.ThrowingMessageCodecResolver +import kotlinx.rpc.grpc.codec.plus +import kotlinx.rpc.grpc.descriptor.GrpcServiceDelegate +import kotlinx.rpc.grpc.descriptor.GrpcServiceDescriptor +import kotlinx.rpc.grpc.descriptor.MethodDescriptor +import kotlinx.rpc.grpc.descriptor.MethodType +import kotlinx.rpc.grpc.descriptor.methodType +import kotlinx.rpc.internal.utils.map.RpcInternalConcurrentHashMap +import kotlin.time.Duration + +private typealias RequestClient = Any + +/** + * GrpcClient manages gRPC communication by providing implementation for making asynchronous RPC calls. + * + * @field channel The [kotlinx.rpc.grpc.client.internal.ManagedChannel] used to communicate with remote gRPC services. + */ +public class GrpcClient internal constructor( + internal val channel: ManagedChannel, + messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver, + internal val interceptors: List, +) : RpcClient { + private val delegates = RpcInternalConcurrentHashMap() + private val messageCodecResolver = messageCodecResolver + ThrowingMessageCodecResolver + + public fun shutdown() { + delegates.clear() + channel.shutdown() + } + + public fun shutdownNow() { + delegates.clear() + channel.shutdownNow() + } + + public suspend fun awaitTermination(duration: Duration = Duration.INFINITE) { + channel.awaitTermination(duration) + } + + override suspend fun call(call: RpcCall): T = withGrpcCall(call) { methodDescriptor, request -> + val callOptions = GrpcCallOptions() + val trailers = GrpcMetadata() + + return when (methodDescriptor.methodType) { + MethodType.UNARY -> unaryRpc( + descriptor = methodDescriptor, + request = request, + callOptions = callOptions, + headers = trailers, + ) + + MethodType.CLIENT_STREAMING -> @Suppress("UNCHECKED_CAST") clientStreamingRpc( + descriptor = methodDescriptor, + requests = request as Flow, + callOptions = callOptions, + headers = trailers, + ) + + else -> error("Wrong method type ${methodDescriptor.methodType}") + } + } + + override fun callServerStreaming(call: RpcCall): Flow = withGrpcCall(call) { methodDescriptor, request -> + val callOptions = GrpcCallOptions() + val headers = GrpcMetadata() + + when (methodDescriptor.methodType) { + MethodType.SERVER_STREAMING -> serverStreamingRpc( + descriptor = methodDescriptor, + request = request, + callOptions = callOptions, + headers = headers, + ) + + MethodType.BIDI_STREAMING -> @Suppress("UNCHECKED_CAST") bidirectionalStreamingRpc( + descriptor = methodDescriptor, + requests = request as Flow, + callOptions = callOptions, + headers = headers, + ) + + else -> error("Wrong method type ${methodDescriptor.methodType}") + } + } + + private inline fun withGrpcCall(call: RpcCall, body: (MethodDescriptor, Any) -> R): R { + require(call.arguments.size <= 1) { + "Call parameter size must be 0 or 1, but ${call.arguments.size}" + } + + val delegate = delegates.computeIfAbsent(call.descriptor.fqName) { + val grpc = call.descriptor as? GrpcServiceDescriptor<*> + ?: error("Expected a gRPC service") + + grpc.delegate(messageCodecResolver) + } + + @Suppress("UNCHECKED_CAST") + val methodDescriptor = delegate.getMethodDescriptor(call.callableName) + as? MethodDescriptor + ?: error("Expected a gRPC method descriptor") + + val request = call.arguments.getOrNull(0) ?: Unit + + return body(methodDescriptor, request) + } +} + +/** + * Creates and configures a gRPC client instance. + * + * This function initializes a new gRPC client with the specified target server + * and allows optional customization of the client's configuration through a configuration block. + * + * @param hostname The gRPC server hostname to connect to. + * @param port The gRPC server port to connect to. + * @param configure An optional configuration block to customize the [GrpcClientConfiguration]. + * This can include setting up interceptors, specifying credentials, customizing message codec + * resolution, and overriding default authority. + * + * @return A new instance of [GrpcClient] configured with the specified target and options. + * + * @see [GrpcClientConfiguration] + */ +public fun GrpcClient( + hostname: String, + port: Int, + configure: GrpcClientConfiguration.() -> Unit = {}, +): GrpcClient { + val config = GrpcClientConfiguration().apply(configure) + return GrpcClient(ManagedChannelBuilder(hostname, port, config.credentials), config) +} + + +/** + * Creates and configures a gRPC client instance. + * + * This function initializes a new gRPC client with the specified target server + * and allows optional customization of the client's configuration through a configuration block. + * + * @param target The gRPC server endpoint to connect to, typically specified in + * the format `hostname:port`. + * @param configure An optional configuration block to customize the [GrpcClientConfiguration]. + * This can include setting up interceptors, specifying credentials, customizing message codec + * resolution, and overriding default authority. + * + * @return A new instance of [GrpcClient] configured with the specified target and options. + * + * @see [GrpcClientConfiguration] + */ +public fun GrpcClient( + target: String, + configure: GrpcClientConfiguration.() -> Unit = {}, +): GrpcClient { + val config = GrpcClientConfiguration().apply(configure) + return GrpcClient(ManagedChannelBuilder(target, config.credentials), config) +} + +private fun GrpcClient( + builder: ManagedChannelBuilder<*>, + config: GrpcClientConfiguration, +): GrpcClient { + val channel = builder.apply { + config.overrideAuthority?.let { overrideAuthority(it) } + }.buildChannel() + return GrpcClient(channel, config.messageCodecResolver, config.interceptors) +} + + +/** + * Configuration class for a gRPC client, providing customization options + * for client behavior, including interceptors, credentials, codec resolution, + * and authority overrides. + * + * @see credentials + * @see overrideAuthority + * @see intercept + */ +public class GrpcClientConfiguration internal constructor() { + internal val interceptors: MutableList = mutableListOf() + + /** + * Configurable resolver used to determine the appropriate codec for a given Kotlin type + * during message serialization and deserialization in gRPC calls. + * + * Custom implementations of [MessageCodecResolver] can be provided to handle specific serialization + * for arbitrary types. + * For custom types prefer using the [kotlinx.rpc.grpc.codec.WithCodec] annotation. + * + * @see MessageCodecResolver + * @see kotlinx.rpc.grpc.codec.SourcedMessageCodec + * @see kotlinx.rpc.grpc.codec.WithCodec + */ + public var messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver + + + /** + * Configures the client credentials used for secure gRPC requests made by the client. + * + * By default, the client uses default TLS credentials. + * To use custom TLS credentials, use the [tls] constructor function which returns a + * [TlsClientCredentials] instance. + * + * To use plaintext communication, use the [plaintext] constructor function. + * Should only be used for testing or for APIs where the use of such API or + * the data exchanged is not sensitive. + * + * ``` + * GrpcClient("localhost", 50051) { + * credentials = plaintext() // for testing purposes only! + * } + * ``` + * + * @see tls + * @see plaintext + */ + public var credentials: ClientCredentials? = null + + /** + * Overrides the authority used with TLS and HTTP virtual hosting. + * It does not change what the host is actually connected to. + * Is commonly in the form `host:port`. + */ + public var overrideAuthority: String? = null + + + /** + * Adds one or more client-side interceptors to the current gRPC client configuration. + * Interceptors enable extended customization of gRPC calls + * by observing or altering the behaviors of requests and responses. + * + * The order of interceptors added via this method is significant. + * Interceptors are executed in the order they are added, + * while one interceptor has to invoke the next interceptor to proceed with the call. + * + * @param interceptors Interceptors to be added to the current configuration. + * Each provided instance of [ClientInterceptor] may perform operations such as modifying headers, + * observing call metadata, logging, or transforming data flows. + * + * @see ClientInterceptor + * @see ClientCallScope + */ + public fun intercept(vararg interceptors: ClientInterceptor) { + this.interceptors.addAll(interceptors) + } + + /** + * Provides insecure client credentials for the gRPC client configuration. + * + * Typically, this would be used for local development, testing, or other + * environments where security is not a concern. + * + * @return An insecure [ClientCredentials] instance that must be passed to [credentials]. + */ + public fun plaintext(): ClientCredentials = createInsecureClientCredentials() + + /** + * Configures and creates secure client credentials for the gRPC client. + * + * This method takes a configuration block in which TLS-related parameters, + * such as trust managers and key managers, can be defined. The resulting + * credentials are used to establish secure communication between the gRPC client + * and server, ensuring encrypted transmission of data and mutual authentication + * if configured. + * + * Alternatively, you can use the [TlsClientCredentials] constructor. + * + * @param configure A configuration block that allows setting up the TLS parameters + * using the [TlsClientCredentialsBuilder]. + * @return A secure [ClientCredentials] instance that must be passed to [credentials]. + * + * @see credentials + */ + public fun tls(configure: TlsClientCredentialsBuilder.() -> Unit): ClientCredentials = + TlsClientCredentials(configure) + +} \ No newline at end of file diff --git a/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.kt b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.kt new file mode 100644 index 000000000..e86a3b1f8 --- /dev/null +++ b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.kt @@ -0,0 +1,18 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.client.internal + +import kotlinx.rpc.grpc.client.GrpcCallOptions +import kotlinx.rpc.grpc.descriptor.MethodDescriptor +import kotlinx.rpc.internal.utils.InternalRpcApi + +@InternalRpcApi +public expect abstract class GrpcChannel + +@InternalRpcApi +public expect fun GrpcChannel.createCall( + methodDescriptor: MethodDescriptor, + callOptions: GrpcCallOptions, +): ClientCall \ No newline at end of file diff --git a/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/internal/suspendClientCalls.kt b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/internal/suspendClientCalls.kt new file mode 100644 index 000000000..47fb7e60f --- /dev/null +++ b/grpc/grpc-client/src/commonMain/kotlin/kotlinx/rpc/grpc/client/internal/suspendClientCalls.kt @@ -0,0 +1,300 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.client.internal + +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.NonCancellable +import kotlinx.coroutines.cancel +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.onFailure +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.single +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import kotlinx.rpc.grpc.GrpcMetadata +import kotlinx.rpc.grpc.Status +import kotlinx.rpc.grpc.StatusCode +import kotlinx.rpc.grpc.StatusException +import kotlinx.rpc.grpc.client.ClientCallScope +import kotlinx.rpc.grpc.client.GrpcCallOptions +import kotlinx.rpc.grpc.client.GrpcClient +import kotlinx.rpc.grpc.descriptor.MethodDescriptor +import kotlinx.rpc.grpc.descriptor.MethodType +import kotlinx.rpc.grpc.descriptor.methodType +import kotlinx.rpc.grpc.internal.CallbackFuture +import kotlinx.rpc.grpc.internal.Ready +import kotlinx.rpc.grpc.internal.singleOrStatus +import kotlinx.rpc.grpc.statusCode +import kotlinx.rpc.internal.utils.InternalRpcApi + +// heavily inspired by +// https://github.com/grpc/grpc-kotlin/blob/master/stub/src/main/java/io/grpc/kotlin/ClientCalls.kt + +@InternalRpcApi +public suspend fun GrpcClient.unaryRpc( + descriptor: MethodDescriptor, + request: Request, + callOptions: GrpcCallOptions = GrpcCallOptions(), + headers: GrpcMetadata = GrpcMetadata(), +): Response { + val type = descriptor.methodType + require(type == MethodType.UNARY) { + "Expected a unary RPC method, but got $descriptor" + } + + return rpcImpl( + descriptor = descriptor, + callOptions = callOptions, + headers = headers, + request = flowOf(request) + ).singleOrStatus("request", descriptor) +} + +@InternalRpcApi +public fun GrpcClient.serverStreamingRpc( + descriptor: MethodDescriptor, + request: Request, + callOptions: GrpcCallOptions = GrpcCallOptions(), + headers: GrpcMetadata = GrpcMetadata(), +): Flow { + val type = descriptor.methodType + require(type == MethodType.SERVER_STREAMING) { + "Expected a server streaming RPC method, but got $type" + } + + return rpcImpl( + descriptor = descriptor, + callOptions = callOptions, + headers = headers, + request = flowOf(request) + ) +} + +@InternalRpcApi +public suspend fun GrpcClient.clientStreamingRpc( + descriptor: MethodDescriptor, + requests: Flow, + callOptions: GrpcCallOptions = GrpcCallOptions(), + headers: GrpcMetadata = GrpcMetadata(), +): Response { + val type = descriptor.methodType + require(type == MethodType.CLIENT_STREAMING) { + "Expected a client streaming RPC method, but got $type" + } + + return rpcImpl( + descriptor = descriptor, + callOptions = callOptions, + headers = headers, + request = requests + ).singleOrStatus("response", descriptor) +} + +@InternalRpcApi +public fun GrpcClient.bidirectionalStreamingRpc( + descriptor: MethodDescriptor, + requests: Flow, + callOptions: GrpcCallOptions = GrpcCallOptions(), + headers: GrpcMetadata = GrpcMetadata(), +): Flow { + val type = descriptor.methodType + check(type == MethodType.BIDI_STREAMING) { + "Expected a bidirectional streaming method, but got $type" + } + + return rpcImpl( + descriptor = descriptor, + callOptions = callOptions, + headers = headers, + request = requests + ) +} + +private sealed interface ClientRequest { + suspend fun sendTo( + clientCall: ClientCall, + ready: Ready, + ) + + class Unary(private val request: Request) : ClientRequest { + override suspend fun sendTo( + clientCall: ClientCall, + ready: Ready, + ) { + clientCall.sendMessage(request) + } + } + + class Flowing(private val requestFlow: Flow) : ClientRequest { + override suspend fun sendTo( + clientCall: ClientCall, + ready: Ready, + ) { + ready.suspendUntilReady() + requestFlow.collect { request -> + clientCall.sendMessage(request) + ready.suspendUntilReady() + } + } + } +} + +private fun GrpcClient.rpcImpl( + descriptor: MethodDescriptor, + callOptions: GrpcCallOptions, + headers: GrpcMetadata, + request: Flow, +): Flow { + val clientCallScope = ClientCallScopeImpl( + client = this, + method = descriptor, + requestHeaders = headers, + callOptions = callOptions, + ) + return clientCallScope.proceed(request) +} + +private class ClientCallScopeImpl( + val client: GrpcClient, + override val method: MethodDescriptor, + override val requestHeaders: GrpcMetadata, + override val callOptions: GrpcCallOptions, +) : ClientCallScope { + val interceptors = client.interceptors + val onHeadersFuture = CallbackFuture() + val onCloseFuture = CallbackFuture>() + + var interceptorIndex = 0 + + override fun onHeaders(block: (GrpcMetadata) -> Unit) { + onHeadersFuture.onComplete { block(it) } + } + + override fun onClose(block: (Status, GrpcMetadata) -> Unit) { + onCloseFuture.onComplete { block(it.first, it.second) } + } + + override fun cancel(message: String, cause: Throwable?): Nothing { + throw StatusException(Status(StatusCode.CANCELLED, message, cause)) + } + + override fun proceed(request: Flow): Flow { + return if (interceptorIndex < interceptors.size) { + with(interceptors[interceptorIndex++]) { + intercept(request) + } + } else { + // if the interceptor chain is exhausted, we start the actual call + doCall(request) + } + } + + private fun doCall(request: Flow): Flow = flow { + coroutineScope { + val call = client.channel.platformApi.createCall(method, callOptions) + + /* + * We maintain a buffer of size 1 so onMessage never has to block: it only gets called after + * we request a response from the server, which only happens when responses is empty and + * there is room in the buffer. + */ + val responses = Channel(1) + val ready = Ready { call.isReady() } + + call.start(channelResponseListener(call, responses, ready), requestHeaders) + + suspend fun Flow.send() { + if (method.methodType == MethodType.UNARY || method.methodType == MethodType.SERVER_STREAMING) { + call.sendMessage(single()) + } else { + ready.suspendUntilReady() + this.collect { request -> + call.sendMessage(request) + ready.suspendUntilReady() + } + } + } + + val fullMethodName = method.getFullMethodName() + val sender = launch(CoroutineName("grpc-send-message-$fullMethodName")) { + try { + request.send() + call.halfClose() + } catch (ex: Exception) { + call.cancel("Collection of requests completed exceptionally", ex) + throw ex // propagate failure upward + } + } + + try { + call.request(1) + for (response in responses) { + emit(response) + call.request(1) + } + } catch (e: Exception) { + withContext(NonCancellable) { + sender.cancel("Collection of responses completed exceptionally", e) + sender.join() + // we want the sender to be done cancelling before we cancel the handler, or it might try + // sending to a dead call, which results in ugly exception messages + call.cancel("Collection of responses completed exceptionally", e) + } + throw e + } + + if (!sender.isCompleted) { + sender.cancel("Collection of responses completed before collection of requests") + } + } + } + + private fun channelResponseListener( + call: ClientCall<*, Response>, + responses: Channel, + ready: Ready, + ) = clientCallListener( + onHeaders = { + try { + onHeadersFuture.complete(it) + } catch (e: StatusException) { + // if a client interceptor called cancel, we throw a StatusException. + // as the JVM implementation treats them differently, we need to catch them here. + call.cancel(e.message, e.cause) + } + }, + onMessage = { message: Response -> + responses.trySend(message).onFailure { e -> + throw e ?: AssertionError("onMessage should never be called until responses is ready") + } + }, + onClose = { status: Status, trailers: GrpcMetadata -> + var cause = when { + status.statusCode == StatusCode.OK -> null + status.getCause() is CancellationException -> status.getCause() + else -> StatusException(status, trailers) + } + + try { + onCloseFuture.complete(status to trailers) + } catch (exception: Throwable) { + cause = exception + if (exception !is StatusException) { + val status = Status(StatusCode.CANCELLED, "Interceptor threw an error", exception) + cause = StatusException(status) + } + } + + responses.close(cause = cause) + }, + onReady = { + ready.onReady() + }, + ) +} diff --git a/grpc/grpc-client/src/jvmMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.jvm.kt b/grpc/grpc-client/src/jvmMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.jvm.kt new file mode 100644 index 000000000..ba337ef58 --- /dev/null +++ b/grpc/grpc-client/src/jvmMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.jvm.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.client + +import io.grpc.CallOptions +import kotlinx.rpc.grpc.GrpcCompression +import kotlinx.rpc.internal.utils.InternalRpcApi +import java.util.concurrent.TimeUnit + +@InternalRpcApi +public fun GrpcCallOptions.toJvm(): CallOptions { + var default = CallOptions.DEFAULT + if (timeout != null) { + default = default.withDeadlineAfter(timeout!!.inWholeMilliseconds, TimeUnit.MILLISECONDS) + } + if (compression !is GrpcCompression.None) { + default = default.withCompression(compression.name) + } + return default +} \ No newline at end of file diff --git a/grpc/grpc-client/src/jvmMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.jvm.kt b/grpc/grpc-client/src/jvmMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.jvm.kt new file mode 100644 index 000000000..b5e1b016d --- /dev/null +++ b/grpc/grpc-client/src/jvmMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.jvm.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.client.internal + +import io.grpc.Channel +import kotlinx.rpc.grpc.client.GrpcCallOptions +import kotlinx.rpc.grpc.client.toJvm +import kotlinx.rpc.grpc.descriptor.MethodDescriptor +import kotlinx.rpc.internal.utils.InternalRpcApi + +@InternalRpcApi +public actual typealias GrpcChannel = Channel + +@InternalRpcApi +public actual fun GrpcChannel.createCall( + methodDescriptor: MethodDescriptor, + callOptions: GrpcCallOptions, +): ClientCall { + return this.newCall(methodDescriptor, callOptions.toJvm()) +} \ No newline at end of file diff --git a/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.native.kt b/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.native.kt new file mode 100644 index 000000000..3d0f05372 --- /dev/null +++ b/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/GrpcCallOptions.native.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +@file:OptIn(ExperimentalForeignApi::class) + +package kotlinx.rpc.grpc.client + +import kotlinx.cinterop.CValue +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.rpc.internal.utils.InternalRpcApi +import libkgrpc.GPR_CLOCK_REALTIME +import libkgrpc.GPR_TIMESPAN +import libkgrpc.gpr_inf_future +import libkgrpc.gpr_now +import libkgrpc.gpr_time_add +import libkgrpc.gpr_time_from_millis +import libkgrpc.gpr_timespec + +@InternalRpcApi +public fun GrpcCallOptions.rawDeadline(): CValue { + return timeout?.let { + gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_millis(it.inWholeMilliseconds, GPR_TIMESPAN) + ) + } ?: gpr_inf_future(GPR_CLOCK_REALTIME) +} \ No newline at end of file diff --git a/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.native.kt b/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.native.kt new file mode 100644 index 000000000..d14356296 --- /dev/null +++ b/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/GrpcChannel.native.kt @@ -0,0 +1,25 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.client.internal + +import kotlinx.rpc.grpc.client.GrpcCallOptions +import kotlinx.rpc.grpc.descriptor.MethodDescriptor +import kotlinx.rpc.internal.utils.InternalRpcApi + +@InternalRpcApi +public actual abstract class GrpcChannel { + public abstract fun newCall( + methodDescriptor: MethodDescriptor, + callOptions: GrpcCallOptions, + ): ClientCall +} + +@InternalRpcApi +public actual fun GrpcChannel.createCall( + methodDescriptor: MethodDescriptor, + callOptions: GrpcCallOptions, +): ClientCall { + return this.newCall(methodDescriptor, callOptions) +} \ No newline at end of file diff --git a/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/NativeClientCall.kt b/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/NativeClientCall.kt new file mode 100644 index 000000000..cfc95970d --- /dev/null +++ b/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/NativeClientCall.kt @@ -0,0 +1,476 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +@file:OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) + +package kotlinx.rpc.grpc.client.internal + +import cnames.structs.grpc_call +import kotlinx.atomicfu.atomic +import kotlinx.cinterop.Arena +import kotlinx.cinterop.ByteVar +import kotlinx.cinterop.CPointer +import kotlinx.cinterop.CPointerVar +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.alloc +import kotlinx.cinterop.allocArray +import kotlinx.cinterop.convert +import kotlinx.cinterop.get +import kotlinx.cinterop.ptr +import kotlinx.cinterop.readValue +import kotlinx.cinterop.toKString +import kotlinx.cinterop.value +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableJob +import kotlinx.rpc.grpc.GrpcMetadata +import kotlinx.rpc.grpc.Status +import kotlinx.rpc.grpc.StatusCode +import kotlinx.rpc.grpc.append +import kotlinx.rpc.grpc.descriptor.MethodDescriptor +import kotlinx.rpc.grpc.internal.BatchResult +import kotlinx.rpc.grpc.internal.CompletionQueue +import kotlinx.rpc.grpc.internal.destroyEntries +import kotlinx.rpc.grpc.internal.internalError +import kotlinx.rpc.grpc.internal.toByteArray +import kotlinx.rpc.grpc.internal.toGrpcByteBuffer +import kotlinx.rpc.grpc.internal.toKotlin +import kotlinx.rpc.protobuf.input.stream.asInputStream +import kotlinx.rpc.protobuf.input.stream.asSource +import kotlinx.rpc.grpc.GrpcCompression +import kotlinx.rpc.grpc.client.GrpcCallOptions +import libkgrpc.GRPC_OP_RECV_INITIAL_METADATA +import libkgrpc.GRPC_OP_RECV_MESSAGE +import libkgrpc.GRPC_OP_RECV_STATUS_ON_CLIENT +import libkgrpc.GRPC_OP_SEND_CLOSE_FROM_CLIENT +import libkgrpc.GRPC_OP_SEND_INITIAL_METADATA +import libkgrpc.GRPC_OP_SEND_MESSAGE +import libkgrpc.gpr_free +import libkgrpc.grpc_byte_buffer +import libkgrpc.grpc_byte_buffer_destroy +import libkgrpc.grpc_call_cancel_with_status +import libkgrpc.grpc_call_error +import libkgrpc.grpc_call_unref +import libkgrpc.grpc_metadata_array +import libkgrpc.grpc_metadata_array_destroy +import libkgrpc.grpc_metadata_array_init +import libkgrpc.grpc_op +import libkgrpc.grpc_slice +import libkgrpc.grpc_slice_unref +import libkgrpc.grpc_status_code +import kotlin.experimental.ExperimentalNativeApi +import kotlin.native.ref.createCleaner + + +internal class NativeClientCall( + private val cq: CompletionQueue, + internal val raw: CPointer, + private val methodDescriptor: MethodDescriptor, + private val callOptions: GrpcCallOptions, + private val callJob: CompletableJob, +) : ClientCall() { + + @Suppress("unused") + private val rawCleaner = createCleaner(raw) { + grpc_call_unref(it) + } + + init { + // cancel the call if the job is canceled. + callJob.invokeOnCompletion { + when (it) { + is CancellationException -> { + cancelInternal(grpc_status_code.GRPC_STATUS_UNAVAILABLE, "Channel shutdownNow invoked") + } + + is Throwable -> { + cancelInternal(grpc_status_code.GRPC_STATUS_INTERNAL, "Call failed: ${it.message}") + } + } + } + } + + private var listener: Listener? = null + private var halfClosed = false + private var cancelled = false + private val closed = atomic(false) + + // tracks how many operations are in flight (not yet completed by the listener). + // if 0 and we got a closeInfo (containing the status), there are no more ongoing operations. + // in this case, we can safely call onClose on the listener. + // we need this mechanism to ensure that onClose is not called while any other callback is still running + // on the listener. + private val inFlight = atomic(0) + + // holds the received status information returned by the RECV_STATUS_ON_CLIENT batch. + // if null, the call is still in progress. otherwise, the call can be closed as soon as inFlight is 0. + private val closeInfo = atomic?>(null) + + // we currently don't buffer messages, so after one `sendMessage` call, ready turns false. (KRPC-192) + private val ready = atomic(true) + + /** + * Increments the [inFlight] counter by one. + * This should be called before starting a batch. + */ + private fun beginOp() { + inFlight.incrementAndGet() + } + + /** + * Decrements the [inFlight] counter by one. + * This should be called after a batch has finished (in case of success AND error) + * AND the corresponding listener callback returned. + * + * If the counter reaches 0, no more listener callbacks are executed, and the call can be closed by + * calling [tryToCloseCall]. + */ + private fun endOp() { + if (inFlight.decrementAndGet() == 0) { + tryToCloseCall() + } + } + + /** + * Tries to close the call by invoking the listener's onClose callback. + * + * - If the call is already closed, this does nothing. + * - If the RECV_STATUS_ON_CLIENT batch is still in progress, this does nothing. + * - If the [inFlight] counter is not 0, this does nothing. + * - Otherwise, the listener's onClose callback is invoked and the call is closed. + */ + private fun tryToCloseCall() { + val info = closeInfo.value ?: return + if (inFlight.value == 0 && closed.compareAndSet(expect = false, update = true)) { + val lst = checkNotNull(listener) { internalError("Not yet started") } + // allows the managed channel to join for the call to finish. + callJob.complete() + safeUserCode("Failed to call onClose.") { + lst.onClose(info.first, info.second) + } + } + } + + /** + * Sets the [closeInfo] and calls [tryToCloseCall]. + * This is called as soon as the RECV_STATUS_ON_CLIENT batch (started with [startRecvStatus]) finished. + */ + private fun markClosePending(status: Status, trailers: GrpcMetadata) { + closeInfo.compareAndSet(null, Pair(status, trailers)) + tryToCloseCall() + } + + /** + * Sets the [ready] flag to true and calls the listener's onReady callback. + * This is called as soon as the RECV_MESSAGE batch is finished (or failed). + */ + private fun turnReady() { + if (ready.compareAndSet(expect = false, update = true)) { + safeUserCode("Failed to call onReady.") { + listener?.onReady() + } + } + } + + + override fun start( + responseListener: Listener, + headers: GrpcMetadata, + ) { + check(listener == null) { internalError("Already started") } + + listener = responseListener + + // start receiving the status from the completion queue, + // which is bound to the lifetime of the call. + val success = startRecvStatus() + if (!success) return + + // send and receive initial headers to/from the server + sendAndReceiveInitialMetadata(headers) + } + + /** + * Submits a batch operation to the [CompletionQueue] and handle the returned [kotlinx.rpc.grpc.internal.BatchResult]. + * If the batch was successfully submitted, [onSuccess] is called. + * In any case, [cleanup] is called. + */ + private fun runBatch( + ops: CPointer, + nOps: ULong, + cleanup: () -> Unit = {}, + onSuccess: () -> Unit = {}, + ) { + // we must not try to run a batch after the call is closed. + if (closed.value) return cleanup() + + // pre-book the batch, so onClose cannot be called before the batch finished. + beginOp() + + when (val callResult = cq.runBatch(this@NativeClientCall.raw, ops, nOps)) { + is BatchResult.Submitted -> { + callResult.future.onComplete { success -> + try { + if (success) { + // if the batch doesn't succeed, this is reflected in the recv status op batch. + onSuccess() + } + } finally { + // ignore failure, as it is reflected in the client status op + cleanup() + endOp() + } + } + } + + BatchResult.CQShutdown -> { + cleanup() + endOp() + cancelInternal(grpc_status_code.GRPC_STATUS_UNAVAILABLE, "Channel shutdown") + } + + is BatchResult.SubmitError -> { + cleanup() + endOp() + cancelInternal( + grpc_status_code.GRPC_STATUS_INTERNAL, + "Batch could not be submitted: ${callResult.error}" + ) + } + } + } + + /** + * Starts a batch operation to receive the status from the completion queue (RECV_STATUS_ON_CLIENT). + * This operation is bound to the lifetime of the call, so it will finish once all other operations are done. + * If this operation fails, it will call [markClosePending] with the corresponding error, as the entire call + * si considered failed. + * + * @return true if the batch was successfully submitted, false otherwise. + * In this case, the call is considered failed. + */ + @OptIn(ExperimentalStdlibApi::class) + private fun startRecvStatus(): Boolean { + checkNotNull(listener) { internalError("Not yet started") } + val arena = Arena() + val statusCode = arena.alloc() + val statusDetails = arena.alloc() + val errorStr = arena.alloc>() + + val trailingMetadata = arena.alloc() + grpc_metadata_array_init(trailingMetadata.ptr) + + val op = arena.alloc { + op = GRPC_OP_RECV_STATUS_ON_CLIENT + data.recv_status_on_client.status = statusCode.ptr + data.recv_status_on_client.status_details = statusDetails.ptr + data.recv_status_on_client.error_string = errorStr.ptr + data.recv_status_on_client.trailing_metadata = trailingMetadata.ptr + } + + when (val callResult = cq.runBatch(this@NativeClientCall.raw, op.ptr, 1u)) { + is BatchResult.Submitted -> { + callResult.future.onComplete { + val details = statusDetails.toByteArray().toKString() + val kStatusCode = statusCode.value.toKotlin() + val status = Status(kStatusCode, details, null) + val trailers = GrpcMetadata(trailingMetadata) + + // cleanup + grpc_slice_unref(statusDetails.readValue()) + if (errorStr.value != null) gpr_free(errorStr.value) + // the entries are owned by the call object, so we must only destroy the array + grpc_metadata_array_destroy(trailingMetadata.readValue()) + arena.clear() + + // set close info and try to close the call. + markClosePending(status, trailers) + } + return true + } + + BatchResult.CQShutdown -> { + arena.clear() + markClosePending(Status(StatusCode.UNAVAILABLE, "Channel shutdown"), GrpcMetadata()) + return false + } + + is BatchResult.SubmitError -> { + arena.clear() + markClosePending( + Status(StatusCode.INTERNAL, "Failed to start call: ${callResult.error}"), + GrpcMetadata() + ) + return false + } + } + } + + private fun sendAndReceiveInitialMetadata(headers: GrpcMetadata) { + // sending and receiving initial metadata + val arena = Arena() + val opsNum = 2uL + val ops = arena.allocArray(opsNum.convert()) + + // add compression algorithm to the call metadata. + // the gRPC core will read the header and perform the compression (compression_filter.cc). + if (callOptions.compression !is GrpcCompression.None) { + if (callOptions.compression !is GrpcCompression.Gzip) { + // to match the behavior of grpc-java, we throw an error if the compression algorithm is not supported. + cancelInternal(grpc_status_code.GRPC_STATUS_INTERNAL, "Unable to find compressor by name ${callOptions.compression.name}") + } + headers.append("grpc-internal-encoding-request", callOptions.compression.name) + } + + // turn given headers into a grpc_metadata_array. + val sendInitialMetadata: grpc_metadata_array = with(headers) { + arena.allocRawGrpcMetadata() + } + + // send initial meta data to server + ops[0].op = GRPC_OP_SEND_INITIAL_METADATA + ops[0].data.send_initial_metadata.count = sendInitialMetadata.count + ops[0].data.send_initial_metadata.metadata = sendInitialMetadata.metadata + + val recvInitialMetadata = arena.alloc() + grpc_metadata_array_init(recvInitialMetadata.ptr) + ops[1].op = GRPC_OP_RECV_INITIAL_METADATA + ops[1].data.recv_initial_metadata.recv_initial_metadata = recvInitialMetadata.ptr + + runBatch(ops, opsNum, cleanup = { + // we must not destroy the array itself, as it is cleared when clearing the arena. + sendInitialMetadata.destroyEntries() + // the entries are owned by the call object, so we must only destroy the array + grpc_metadata_array_destroy(recvInitialMetadata.readValue()) + arena.clear() + }) { + val headers = GrpcMetadata(recvInitialMetadata) + safeUserCode("Failed to call onHeaders.") { + listener?.onHeaders(headers) + } + } + } + + /** + * Requests [numMessages] messages from the server. + * This must only be called again after [numMessages] were received in the [Listener.onMessage] callback. + */ + override fun request(numMessages: Int) { + check(numMessages > 0) { internalError("numMessages must be > 0") } + // limit numMessages to prevent potential stack overflows + check(numMessages <= 16) { internalError("numMessages must be <= 16") } + val listener = checkNotNull(listener) { internalError("Not yet started") } + if (cancelled) { + // no need to send message if the call got already cancelled. + return + } + + var remainingMessages = numMessages + + // we need to request only one message at a time, so we use a recursive function that + // requests one message and then calls itself again. + fun post() { + if (remainingMessages-- <= 0) return + + val arena = Arena() + val recvPtr = arena.alloc>() + val op = arena.alloc { + op = GRPC_OP_RECV_MESSAGE + data.recv_message.recv_message = recvPtr.ptr + } + runBatch(op.ptr, 1u, cleanup = { + if (recvPtr.value != null) grpc_byte_buffer_destroy(recvPtr.value) + arena.clear() + }) { + // if the call was successful, but no message was received, we reached the end-of-stream. + val buf = recvPtr.value ?: return@runBatch + val msg = methodDescriptor.getResponseMarshaller() + .parse(buf.toKotlin().asInputStream()) + safeUserCode("Failed to call onClose.") { + listener.onMessage(msg) + } + post() + } + } + + // start requesting messages + post() + } + + override fun cancel(message: String?, cause: Throwable?) { + cancelled = true + val status = Status(StatusCode.CANCELLED, message ?: "Call cancelled", cause) + // user side cancellation must always win over any other status (even if the call is already completed). + // this will also preserve the cancellation cause, which cannot be passed to the grpc-core. + closeInfo.value = Pair(status, GrpcMetadata()) + cancelInternal( + grpc_status_code.GRPC_STATUS_CANCELLED, + message ?: "Call cancelled with cause: ${cause?.message}" + ) + } + + private fun cancelInternal(statusCode: grpc_status_code, message: String) { + val cancelResult = grpc_call_cancel_with_status(raw, statusCode, message, null) + if (cancelResult != grpc_call_error.GRPC_CALL_OK) { + markClosePending(Status(StatusCode.INTERNAL, "Failed to cancel call: $cancelResult"), GrpcMetadata()) + } + } + + override fun halfClose() { + check(!halfClosed) { internalError("Already half closed.") } + if (cancelled) return + halfClosed = true + + val arena = Arena() + val op = arena.alloc { + op = GRPC_OP_SEND_CLOSE_FROM_CLIENT + } + + runBatch(op.ptr, 1u, cleanup = { arena.clear() }) { + // nothing to do here + } + } + + override fun isReady(): Boolean = ready.value + + override fun sendMessage(message: Request) { + checkNotNull(listener) { internalError("Not yet started") } + check(!halfClosed) { internalError("Already half closed.") } + check(isReady()) { internalError("Not yet ready.") } + + if (cancelled) return + + // set ready false, as only one message can be sent at a time. + ready.value = false + + val arena = Arena() + val inputStream = methodDescriptor.getRequestMarshaller().stream(message) + val byteBuffer = inputStream.asSource().toGrpcByteBuffer() + + val op = arena.alloc { + op = GRPC_OP_SEND_MESSAGE + data.send_message.send_message = byteBuffer + } + + runBatch(op.ptr, 1u, cleanup = { + // actual cleanup + grpc_byte_buffer_destroy(byteBuffer) + arena.clear() + }) { + // set ready true, as we can now send another message. + turnReady() + } + } + + /** + * Safely executes the provided block of user code, catching any thrown exceptions or errors. + * If an exception is caught, it cancels the operation with the specified message and cause. + */ + private fun safeUserCode(cancelMsg: String, block: () -> Unit) { + try { + block() + } catch (e: Throwable) { + cancel(cancelMsg, e) + } + } +} diff --git a/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/NativeManagedChannel.kt b/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/NativeManagedChannel.kt new file mode 100644 index 000000000..5cf6563b9 --- /dev/null +++ b/grpc/grpc-client/src/nativeMain/kotlin/kotlinx/rpc/grpc/client/internal/NativeManagedChannel.kt @@ -0,0 +1,172 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +@file:OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) + +package kotlinx.rpc.grpc.client.internal + +import cnames.structs.grpc_channel +import kotlinx.atomicfu.atomic +import kotlinx.cinterop.CPointer +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.alloc +import kotlinx.cinterop.cstr +import kotlinx.cinterop.memScoped +import kotlinx.cinterop.ptr +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancelChildren +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.rpc.grpc.client.ClientCredentials +import kotlinx.rpc.grpc.client.GrpcCallOptions +import kotlinx.rpc.grpc.client.rawDeadline +import kotlinx.rpc.grpc.descriptor.MethodDescriptor +import kotlinx.rpc.grpc.internal.CompletionQueue +import kotlinx.rpc.grpc.internal.GrpcRuntime +import kotlinx.rpc.grpc.internal.internalError +import kotlinx.rpc.grpc.internal.toGrpcSlice +import libkgrpc.GRPC_PROPAGATE_DEFAULTS +import libkgrpc.grpc_arg +import libkgrpc.grpc_arg_type +import libkgrpc.grpc_channel_args +import libkgrpc.grpc_channel_create +import libkgrpc.grpc_channel_create_call +import libkgrpc.grpc_channel_destroy +import libkgrpc.grpc_slice_unref +import kotlin.coroutines.cancellation.CancellationException +import kotlin.experimental.ExperimentalNativeApi +import kotlin.native.ref.createCleaner +import kotlin.time.Duration + + +/** + * Native implementation of [ManagedChannel]. + * + * @param target The target address to connect to. + * @param credentials The credentials to use for the connection. + */ +internal class NativeManagedChannel( + target: String, + val authority: String?, + // we must store them, otherwise the credentials are getting released + credentials: ClientCredentials, +) : ManagedChannel, ManagedChannelPlatform() { + + // a reference to make sure the grpc_init() was called. (it is released after shutdown) + @Suppress("unused") + private val rt = GrpcRuntime.acquire() + + // job bundling all the call jobs created by this channel. + // this allows easy cancellation of ongoing calls. + private val callJobSupervisor = SupervisorJob() + + // the channel's completion queue, handling all request operations + private val cq = CompletionQueue() + + internal val raw: CPointer = memScoped { + val args = authority?.let { + // the C Core API doesn't have a way to override the authority (used for TLS SNI) as it + // is available in the Java gRPC implementation. + // instead, it can be done by setting the "grpc.ssl_target_name_override" argument. + val authorityOverride = alloc { + type = grpc_arg_type.GRPC_ARG_STRING + key = "grpc.ssl_target_name_override".cstr.ptr + value.string = authority.cstr.ptr + } + + alloc { + num_args = 1u + args = authorityOverride.ptr + } + } + grpc_channel_create(target, credentials.raw, args?.ptr) + ?: error("Failed to create channel") + } + + @Suppress("unused") + private val rawCleaner = createCleaner(raw) { + grpc_channel_destroy(it) + } + + override val platformApi: ManagedChannelPlatform = this + + private var isShutdownInternal = atomic(false) + override val isShutdown: Boolean + get() = isShutdownInternal.value + private val isTerminatedInternal = CompletableDeferred() + override val isTerminated: Boolean + get() = isTerminatedInternal.isCompleted + + override suspend fun awaitTermination(duration: Duration): Boolean { + withTimeoutOrNull(duration) { + isTerminatedInternal.await() + } ?: return false + return true + } + + override fun shutdown(): ManagedChannel { + shutdownInternal(false) + return this + } + + override fun shutdownNow(): ManagedChannel { + shutdownInternal(true) + return this + } + + private fun shutdownInternal(force: Boolean) { + isShutdownInternal.value = true + if (isTerminatedInternal.isCompleted) { + return + } + if (force) { + // cancel all jobs, such that the shutdown is completing faster (not immediate). + // TODO: replace jobs by custom pendingCallClass. + callJobSupervisor.cancelChildren(CancellationException("Channel is shutting down")) + } + + // wait for the completion queue to shut down. + // the completion queue will be shut down after all requests are completed. + // therefore, we don't have to wait for the callJobs to be completed. + cq.shutdown(force).onComplete { + if (isTerminatedInternal.complete(Unit)) { + // release the grpc runtime, so it might call grpc_shutdown() + rt.close() + } + } + } + + override fun newCall( + methodDescriptor: MethodDescriptor, + callOptions: GrpcCallOptions, + ): ClientCall { + check(!isShutdown) { internalError("Channel is shutdown") } + + val callJob = Job(callJobSupervisor) + + val methodFullName = methodDescriptor.getFullMethodName() + // to construct a valid HTTP/2 path, we must prepend the name with a slash. + // the user does not do this to align it with the java implementation. + val methodNameSlice = "/$methodFullName".toGrpcSlice() + + val rawCall = grpc_channel_create_call( + channel = raw, + parent_call = null, + propagation_mask = GRPC_PROPAGATE_DEFAULTS, + completion_queue = cq.raw, + method = methodNameSlice, + host = null, + deadline = callOptions.rawDeadline(), + reserved = null + ) ?: error("Failed to create call") + + grpc_slice_unref(methodNameSlice) + + return NativeClientCall( + cq, rawCall, methodDescriptor, callOptions, callJob + ) + } + +} diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcCompression.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcCompression.kt new file mode 100644 index 000000000..01a7a0ca3 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcCompression.kt @@ -0,0 +1,55 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc + +/** + * Represents a compression algorithm for gRPC message encoding. + * + * Compression can be applied to gRPC messages to reduce bandwidth usage during transmission. + * + * ## Supported Algorithms + * - [None] (identity): No compression is applied. + * - [Gzip]: GZIP compression algorithm, widely supported and provides good compression ratios. + * + * This interface is not meant to be implemented by users. + * + * @property name The compression algorithm identifier sent in the `grpc-encoding` header. + * + * @see kotlinx.rpc.grpc.client.GrpcCallOptions.compression + * @see GrpcCompression.None + * @see GrpcCompression.Gzip + */ +@OptIn(ExperimentalSubclassOptIn::class) +@SubclassOptInRequired +public interface GrpcCompression { + + /** + * The name of the compression algorithm as it appears in the `grpc-encoding` header. + */ + public val name: String + + /** + * Represents no compression (identity encoding). + * + * This is the default compression setting. When used, messages are transmitted without + * any compression applied. + */ + public object None : GrpcCompression { + override val name: String = "identity" + } + + /** + * Represents GZIP compression. + * + * GZIP is a widely supported compression algorithm that provides good compression ratios + * for most data types. + * + * **Note**: Ensure the server supports GZIP compression before using this option, + * as the call will fail if the server cannot handle the requested compression algorithm. + */ + public object Gzip : GrpcCompression { + override val name: String = "gzip" + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CoreClientTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CoreClientTest.kt new file mode 100644 index 000000000..d24d197f9 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CoreClientTest.kt @@ -0,0 +1,310 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ +package kotlinx.rpc.grpc.test + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.delay +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeout +import kotlinx.rpc.grpc.GrpcMetadata +import kotlinx.rpc.grpc.Status +import kotlinx.rpc.grpc.StatusCode +import kotlinx.rpc.grpc.client.createInsecureClientCredentials +import kotlinx.rpc.grpc.client.internal.ClientCall +import kotlinx.rpc.grpc.client.GrpcCallOptions +import kotlinx.rpc.grpc.client.internal.ManagedChannel +import kotlinx.rpc.grpc.client.internal.ManagedChannelBuilder +import kotlinx.rpc.grpc.client.internal.buildChannel +import kotlinx.rpc.grpc.client.internal.clientCallListener +import kotlinx.rpc.grpc.client.internal.createCall +import kotlinx.rpc.grpc.descriptor.MethodDescriptor +import kotlinx.rpc.grpc.descriptor.MethodType +import kotlinx.rpc.grpc.descriptor.methodDescriptor +import kotlinx.rpc.grpc.server.GrpcServer +import kotlinx.rpc.grpc.statusCode +import kotlinx.rpc.registerService +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFails +import kotlin.test.assertFailsWith + +private const val PORT = 50051 + +/** + * Client tests that use lower level API directly to test that it behaves correctly. + * Before executing the tests run [GreeterServiceImpl.runServer] on JVM. + */ +// TODO: Start external service server automatically (KRPC-208) +class GrpcCoreClientTest { + + private fun descriptorFor(fullName: String = "kotlinx.rpc.grpc.test.GreeterService/SayHello"): MethodDescriptor = + methodDescriptor( + fullMethodName = fullName, + requestCodec = HelloRequestInternal.CODEC, + responseCodec = HelloReplyInternal.CODEC, + type = MethodType.UNARY, + schemaDescriptor = Unit, + idempotent = true, + safe = true, + sampledToLocalTracing = true, + ) + + private fun ManagedChannel.newHelloCall(fullName: String = "kotlinx.rpc.grpc.test.GreeterService/SayHello"): ClientCall = + platformApi.createCall(descriptorFor(fullName), GrpcCallOptions()) + + private fun createChannel(): ManagedChannel = ManagedChannelBuilder( + target = "localhost:$PORT", + credentials = createInsecureClientCredentials() + ).buildChannel() + + + private fun helloReq(timeout: UInt = 0u): HelloRequest = HelloRequest { + name = "world" + this.timeout = timeout + } + + private fun shutdownAndWait(channel: ManagedChannel, now: Boolean = false) { + if (now) { + channel.shutdownNow() + } else { + channel.shutdown() + } + runBlocking { channel.awaitTermination() } + } + + @Test + fun normalUnaryCall_ok() = repeat(1000) { + val channel = createChannel() + val call = channel.newHelloCall() + val req = helloReq() + + val statusDeferred = CompletableDeferred() + val replyDeferred = CompletableDeferred() + val listener = createClientCallListener( + onMessage = { replyDeferred.complete(it) }, + onClose = { status, _ -> statusDeferred.complete(status) } + ) + + call.start(listener, GrpcMetadata()) + call.sendMessage(req) + call.halfClose() + call.request(1) + + runBlocking { + withTimeout(10000) { + val status = statusDeferred.await() + assertEquals(StatusCode.OK, status.statusCode) + val reply = replyDeferred.await() + assertEquals("Hello world", reply.message) + } + } + shutdownAndWait(channel) + } + + @Test + fun start_twice_throws() { + val channel = createChannel() + val call = channel.newHelloCall() + val statusDeferred = CompletableDeferred() + val listener = createClientCallListener( + onClose = { status, _ -> statusDeferred.complete(status) } + ) + call.start(listener, GrpcMetadata()) + assertFailsWith { call.start(listener, GrpcMetadata()) } + // cancel to finish the call quickly + call.cancel("Double start test", null) + runBlocking { withTimeout(5000) { statusDeferred.await() } } + shutdownAndWait(channel) + } + + @Test + fun send_afterHalfClose_throws() { + val channel = createChannel() + val call = channel.newHelloCall() + val req = helloReq() + val statusDeferred = CompletableDeferred() + val listener = createClientCallListener( + onClose = { status, _ -> statusDeferred.complete(status) } + ) + call.start(listener, GrpcMetadata()) + call.halfClose() + assertFailsWith { call.sendMessage(req) } + // Ensure call completes + call.cancel("cleanup", null) + runBlocking { withTimeout(5000) { statusDeferred.await() } } + shutdownAndWait(channel) + } + + @Test + fun request_negative_throws() { + val channel = createChannel() + val call = channel.newHelloCall() + val statusDeferred = CompletableDeferred() + val listener = createClientCallListener( + onClose = { status, _ -> statusDeferred.complete(status) } + ) + call.start(listener, GrpcMetadata()) + assertFails { call.request(-1) } + call.cancel("cleanup", null) + runBlocking { withTimeout(5000) { statusDeferred.await() } } + shutdownAndWait(channel) + } + + @Test + fun cancel_afterStart_resultsInCancelledStatus() { + val channel = createChannel() + val call = channel.newHelloCall() + val statusDeferred = CompletableDeferred() + val listener = createClientCallListener( + onClose = { status, _ -> statusDeferred.complete(status) } + ) + call.start(listener, GrpcMetadata()) + call.cancel("user cancel", null) + runBlocking { + withTimeout(10000) { + val status = statusDeferred.await() + assertEquals(StatusCode.CANCELLED, status.statusCode) + } + } + shutdownAndWait(channel) + } + + @Test + fun invalid_method_returnsNonOkStatus() { + val channel = createChannel() + val call = channel.newHelloCall("kotlinx.rpc.grpc.test.Greeter/NoSuchMethod") + val statusDeferred = CompletableDeferred() + val listener = createClientCallListener( + onClose = { status, _ -> statusDeferred.complete(status) } + ) + + call.start(listener, GrpcMetadata()) + call.sendMessage(helloReq()) + call.halfClose() + call.request(1) + runBlocking { + withTimeout(10000) { + val status = statusDeferred.await() + assertEquals(StatusCode.UNIMPLEMENTED, status.statusCode) + } + } + shutdownAndWait(channel) + } + + + @Test + fun halfCloseBeforeSendingMessage_errorWithoutCrashing() { + val channel = createChannel() + val call = channel.newHelloCall() + val listener = createClientCallListener() + assertFailsWith { + try { + call.start(listener, GrpcMetadata()) + call.halfClose() + call.sendMessage(helloReq()) + } finally { + shutdownAndWait(channel) + } + } + } + + @Test + fun invokeStartAfterShutdown() { + val channel = createChannel() + val call = channel.newHelloCall() + val statusDeferred = CompletableDeferred() + val listener = createClientCallListener( + onClose = { status, _ -> statusDeferred.complete(status) } + ) + + channel.shutdown() + runBlocking { channel.awaitTermination() } + call.start(listener, GrpcMetadata()) + call.sendMessage(helloReq()) + call.halfClose() + call.request(1) + + runBlocking { + withTimeout(10000) { + val status = statusDeferred.await() + assertEquals(StatusCode.UNAVAILABLE, status.statusCode) + } + } + } + + @Test + fun shutdownNowInMiddleOfCall() { + val channel = createChannel() + val call = channel.newHelloCall() + val statusDeferred = CompletableDeferred() + val listener = createClientCallListener( + onClose = { status, _ -> statusDeferred.complete(status) } + ) + + call.start(listener, GrpcMetadata()) + // set timeout on the server to 1000 ms, to simulate a long-running call + call.sendMessage(helloReq(1000u)) + call.halfClose() + call.request(1) + + runBlocking { + delay(100) + channel.shutdownNow() + withTimeout(10000) { + val status = statusDeferred.await() + assertEquals(StatusCode.UNAVAILABLE, status.statusCode) + } + } + } +} + +class GreeterServiceImpl : GreeterService { + + override suspend fun SayHello(message: HelloRequest): HelloReply { + delay(message.timeout?.toLong() ?: 0) + return HelloReply { + this.message = "Hello ${message.name}" + } + } + + + /** + * Run this on JVM before executing tests. + */ + @Test + fun runServer() = runTest { + val server = GrpcServer( + port = PORT, + ) { + services { + registerService { GreeterServiceImpl() } + } + } + + try { + server.start() + println("Server started") + server.awaitTermination() + } finally { + server.shutdown() + server.awaitTermination() + + } + } + +} + + +private fun createClientCallListener( + onHeaders: (headers: GrpcMetadata) -> Unit = {}, + onMessage: (message: T) -> Unit = {}, + onClose: (status: Status, trailers: GrpcMetadata) -> Unit = { _, _ -> }, + onReady: () -> Unit = {}, +) = clientCallListener( + onHeaders = onHeaders, + onMessage = onMessage, + onClose = onClose, + onReady = onReady, +) \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/GrpcCompressionTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/GrpcCompressionTest.kt new file mode 100644 index 000000000..e14276919 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/GrpcCompressionTest.kt @@ -0,0 +1,239 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.test.proto + +import kotlinx.coroutines.test.runTest +import kotlinx.rpc.RpcServer +import kotlinx.rpc.grpc.GrpcCompression +import kotlinx.rpc.grpc.GrpcMetadata +import kotlinx.rpc.grpc.Status +import kotlinx.rpc.grpc.StatusCode +import kotlinx.rpc.grpc.get +import kotlinx.rpc.grpc.keys +import kotlinx.rpc.grpc.test.EchoRequest +import kotlinx.rpc.grpc.test.EchoService +import kotlinx.rpc.grpc.test.EchoServiceImpl +import kotlinx.rpc.grpc.test.Runtime +import kotlinx.rpc.grpc.test.assertContainsAll +import kotlinx.rpc.grpc.test.assertGrpcFailure +import kotlinx.rpc.grpc.test.captureStdErr +import kotlinx.rpc.grpc.test.clearNativeEnv +import kotlinx.rpc.grpc.test.invoke +import kotlinx.rpc.grpc.test.runtime +import kotlinx.rpc.grpc.test.setNativeEnv +import kotlinx.rpc.registerService +import kotlinx.rpc.withService +import kotlin.collections.emptyList +import kotlin.test.Test +import kotlin.test.assertEquals + +/** + * Tests that the client can configure the compression of requests. + * + * This test is hard to realize on native, as the gRPC-Core doesn't expose internal headers like + * `grpc-encoding` to the user application. This means we cannot verify that the client or sever + * actually sent those headers on native. Instead, we capture the grpc trace output (written to stderr) + * and verify that the client and server actually used the compression algorithm. + */ +class GrpcCompressionTest : GrpcProtoTest() { + override fun RpcServer.registerServices() { + return registerService { EchoServiceImpl() } + } + + @Test + fun `test gzip client compression - should succeed`() = runTest { + testCompression( + clientCompression = GrpcCompression.Gzip, + expectedEncoding = "gzip", + expectedRequestCompressionAlg = 2, + expectedRequestDecompressionAlg = 2 + ) + } + + @Test + fun `test identity compression - should not compress`() = runTest { + testCompression( + clientCompression = GrpcCompression.None, + expectedEncoding = null, + expectedRequestCompressionAlg = 0, + expectedRequestDecompressionAlg = 0 + ) + } + + @Test + fun `test no compression set - should not compress`() = runTest { + testCompression( + clientCompression = null, + expectedEncoding = null, + expectedRequestCompressionAlg = 0, + expectedRequestDecompressionAlg = 0 + ) + } + + @Test + fun `test unknown compression - should fail`() = assertGrpcFailure( + StatusCode.INTERNAL, + "Unable to find compressor by name unknownCompressionName" + ) { + runGrpcTest( + clientInterceptors = clientInterceptor { + callOptions.compression = object : GrpcCompression { + override val name: String + get() = "unknownCompressionName" + + } + proceed(it) + } + ) { client -> + client.withService().UnaryEcho(EchoRequest.invoke { message = "Unknown compression" }) + } + } + + private suspend fun testCompression( + clientCompression: GrpcCompression?, + expectedEncoding: String?, + expectedRequestCompressionAlg: Int, + expectedRequestDecompressionAlg: Int, + expectedResponseCompressionAlg: Int = 0, + expectedResponseDecompressionAlg: Int = 0 + ) { + var reqHeaders = emptyMap() + var respHeaders = emptyMap() + val logs = captureNativeGrpcLogs { + runGrpcTest( + clientInterceptors = clientInterceptor { + clientCompression?.let { compression -> + callOptions.compression = compression + } + onHeaders { headers -> respHeaders = headers.toMap() } + proceed(it) + }, + serverInterceptors = serverInterceptor { + reqHeaders = requestHeaders.toMap() + proceed(it) + } + ) { + val message = "Echo with ${clientCompression?.name}" + val response = it.withService().UnaryEcho(EchoRequest.invoke { this.message = message }) + + // Verify the call succeeded and data is correct + assertEquals(message, response.message) + } + } + + if (runtime == Runtime.NATIVE) { + // if we are on native, we need to parse the logs manually to get the `grpc-` prefixed headers + val traceHeaders = HeadersTrace.fromTrace(logs) + reqHeaders = traceHeaders.requestHeaders + respHeaders = traceHeaders.responseHeaders + + // verify that the client and server actually used the expected compression algorithm + val compression = CompressionTrace.fromTrace(logs) + assertEquals(expectedRequestCompressionAlg, compression.requestCompressionAlg) + assertEquals(expectedRequestDecompressionAlg, compression.requestDecompressionAlg) + assertEquals(expectedResponseCompressionAlg, compression.responseCompressionAlg) + assertEquals(expectedResponseDecompressionAlg, compression.responseDecompressionAlg) + } + + fun Map.grpcAcceptEncoding() = + this["grpc-accept-encoding"]?.split(",")?.map { it.trim() } ?: emptyList() + + // check request headers + if (expectedEncoding != null) { + assertEquals(expectedEncoding, reqHeaders["grpc-encoding"]) + } + assertContainsAll(listOf("gzip"), reqHeaders.grpcAcceptEncoding()) + + assertContainsAll(listOf("gzip"), respHeaders.grpcAcceptEncoding()) + } + + private suspend fun captureNativeGrpcLogs(block: suspend () -> Unit): String { + try { + return captureStdErr { + setNativeEnv("GRPC_TRACE", "compression,http") + block() + } + } finally { + clearNativeEnv("GRPC_GRACE") + } + } + + private fun GrpcMetadata.toMap(): Map { + return keys().mapNotNull { key -> + if (!key.endsWith("-bin")) { + key to this@toMap[key]!! + } else null + }.toMap() + } + + data class CompressionTrace( + val requestCompressionAlg: Int, + val requestDecompressionAlg: Int, + val responseCompressionAlg: Int, + val responseDecompressionAlg: Int + ) { + companion object { + fun fromTrace(logs: String): CompressionTrace { + val compressMessageRegex = Regex("""CompressMessage: len=\d+ alg=(\d+)""") + val decompressMessageRegex = Regex("""DecompressMessage: len=\d+ max=\d+ alg=(\d+)""") + + val compressions = compressMessageRegex.findAll(logs).map { it.groupValues[1].toInt() }.toList() + val decompressions = decompressMessageRegex.findAll(logs).map { it.groupValues[1].toInt() }.toList() + + require(compressions.size == 2) { + "Expected exactly 2 CompressMessage entries, but found ${compressions.size}" + } + require(decompressions.size == 2) { + "Expected exactly 2 DecompressMessage entries, but found ${decompressions.size}" + } + + return CompressionTrace( + requestCompressionAlg = compressions[0], + requestDecompressionAlg = decompressions[0], + responseCompressionAlg = compressions[1], + responseDecompressionAlg = decompressions[1] + ) + } + } + } + + data class HeadersTrace( + val requestHeaders: Map, + val responseHeaders: Map + ) { + companion object { + fun fromTrace(logs: String): HeadersTrace { + val metadataRegex = Regex( + """perform_stream_op\[.*SEND_INITIAL_METADATA\{([^}]+)\}""", + RegexOption.MULTILINE + ) + + val metadataBlocks = metadataRegex.findAll(logs).map { it.groupValues[1] }.toList() + + require(metadataBlocks.size == 2) { + "Expected exactly 2 SEND_INITIAL_METADATA entries, but found ${metadataBlocks.size}" + } + + return HeadersTrace( + requestHeaders = parseHeaders(metadataBlocks[0]), + responseHeaders = parseHeaders(metadataBlocks[1]) + ) + } + + private fun parseHeaders(metadataBlock: String): Map { + val headers = mutableMapOf() + val headerRegex = Regex("""([^:,]+):\s*([^,]+(?:,\s*[^:,]+)*)(?=,\s+[^:,]+:|${'$'})""") + + for (match in headerRegex.findAll(metadataBlock)) { + val key = match.groupValues[1].trim() + val value = match.groupValues[2].trim() + headers[key] = value + } + + return headers + } + } + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/utils.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/utils.kt new file mode 100644 index 000000000..f4b0ea9cf --- /dev/null +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/utils.kt @@ -0,0 +1,47 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.test + +import kotlinx.rpc.grpc.StatusCode +import kotlinx.rpc.grpc.StatusException +import kotlinx.rpc.grpc.statusCode +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +fun assertGrpcFailure(statusCode: StatusCode, message: String? = null, block: () -> Unit) { + val exc = assertFailsWith(message) { block() } + assertEquals(statusCode, exc.getStatus().statusCode) + if (message != null) { + assertContains(message, exc.getStatus().getDescription() ?: "") + } +} + +fun assertContainsAll(actual: Iterable, expected: Iterable) { + val expectedSet = expected.toSet() + for (element in actual) { + require(element in expectedSet) { + "Actual element '$element' not found in expected collection" + } + } +} + +enum class Runtime { + JVM, + NATIVE +} +expect val runtime: Runtime + +expect fun setNativeEnv(key: String, value: String) +expect fun clearNativeEnv(key: String) + +/** + * Captures the standard error output written during the execution of the provided suspending block. + * + * @param block A suspending lambda function whose standard error output will be captured. + * @return A string containing the captured standard error output. + */ +expect suspend fun captureStdErr(block: suspend () -> Unit): String + diff --git a/grpc/grpc-core/src/jvmTest/kotlin/kotlinx/rpc/grpc/test/utils.jvm.kt b/grpc/grpc-core/src/jvmTest/kotlin/kotlinx/rpc/grpc/test/utils.jvm.kt new file mode 100644 index 000000000..ef56db05d --- /dev/null +++ b/grpc/grpc-core/src/jvmTest/kotlin/kotlinx/rpc/grpc/test/utils.jvm.kt @@ -0,0 +1,32 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.test + +import java.io.ByteArrayOutputStream +import java.io.PrintStream + +actual val runtime: Runtime + get() = Runtime.JVM + +actual fun setNativeEnv(key: String, value: String) { + // Nothing to do on JVM +} + +actual fun clearNativeEnv(key: String) { + // Nothing to do on JVM +} + +actual suspend fun captureStdErr(block: suspend () -> Unit): String { + val orig = System.out + val baos = ByteArrayOutputStream() + System.setOut(PrintStream(baos)) + try { + block() + return baos.toString() + } finally { + System.setOut(orig) + } +} + diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.native.kt new file mode 100644 index 000000000..6f9437776 --- /dev/null +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.native.kt @@ -0,0 +1,328 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +@file:OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class, ExperimentalEncodingApi::class) + +package kotlinx.rpc.grpc + +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.NativePlacement +import kotlinx.cinterop.addressOf +import kotlinx.cinterop.alloc +import kotlinx.cinterop.allocArray +import kotlinx.cinterop.convert +import kotlinx.cinterop.get +import kotlinx.cinterop.ptr +import kotlinx.cinterop.usePinned +import kotlinx.io.Buffer +import kotlinx.io.Source +import kotlinx.io.readByteArray +import kotlinx.rpc.grpc.codec.MessageCodec +import kotlinx.rpc.grpc.codec.SourcedMessageCodec +import kotlinx.rpc.grpc.internal.toByteArray +import kotlinx.rpc.internal.utils.InternalRpcApi +import kotlinx.rpc.protobuf.input.stream.asInputStream +import libkgrpc.grpc_metadata +import libkgrpc.grpc_metadata_array +import libkgrpc.grpc_metadata_array_init +import libkgrpc.grpc_slice_from_copied_buffer +import libkgrpc.grpc_slice_from_copied_string +import libkgrpc.grpc_slice_ref +import libkgrpc.grpc_slice_unref +import libkgrpc.kgrpc_metadata_array_append +import kotlin.experimental.ExperimentalNativeApi +import kotlin.io.encoding.Base64 +import kotlin.io.encoding.ExperimentalEncodingApi + +public actual class GrpcMetadataKey actual constructor(name: String, public val codec: MessageCodec) { + public val name: String = name.lowercase() + internal val isBinary get() = name.endsWith("-bin") + + internal fun encode(value: T): ByteArray = codec.encode(value).buffer.readByteArray() + internal fun decode(value: ByteArray): T = Buffer().let { buffer -> + buffer.write(value) + codec.decode(buffer.asInputStream()) + } + + internal fun validateForString() { + validateName() + require(!isBinary) { "String header is named ${name}. It must not end with '-bin'" } + } + + internal fun validateForBinary() { + validateName() + require(isBinary) { "Binary header is named ${name}. It must end with '-bin'" } + } + + internal companion object +} + +@Suppress(names = ["RedundantConstructorKeyword"]) +public actual class GrpcMetadata actual constructor() { + internal val map: LinkedHashMap> = linkedMapOf() + + public constructor(raw: grpc_metadata_array) : this() { + for (i in 0 until raw.count.toInt()) { + val metadata = raw.metadata?.get(i) + if (metadata != null) { + val key = metadata.key.toByteArray().toAsciiString() + val value = metadata.value.toByteArray() + map.getOrPut(key) { mutableListOf() }.add(value) + } + } + } + + @InternalRpcApi + public fun NativePlacement.allocRawGrpcMetadata(): grpc_metadata_array { + val raw = alloc() + grpc_metadata_array_init(raw.ptr) + + // the sum of all values + val entryCount = map.entries.sumOf { it.value.size } + + raw.count = 0u + raw.capacity = entryCount.convert() + raw.metadata = allocArray(entryCount) + + map.entries.forEach { (key, values) -> + val keySlice = grpc_slice_from_copied_string(key) + + for (entry in values) { + val size = entry.size.toULong() + val valSlice = entry.usePinned { pinned -> + grpc_slice_from_copied_buffer(pinned.addressOf(0), size) + } + // we create a fresh reference for each entry + val keySliceRef = grpc_slice_ref(keySlice) + + check(kgrpc_metadata_array_append(raw.ptr, keySliceRef, valSlice)) { + "Failed to append metadata to array" + } + } + + // we unref/drop the original keySlice, as it isn't used anymore + grpc_slice_unref(keySlice) + } + + return raw + } + + @OptIn(ExperimentalEncodingApi::class) + override fun toString(): String { + val sb = StringBuilder("Metadata(") + var first = true + for ((key, values) in map) { + for (value in values) { + if (!first) { + sb.append(',') + } + first = false + sb.append(key).append('=') + if (key.endsWith("-bin")) { + sb.append(Base64.encode(value)) + } else { + sb.append(value.toAsciiString()) + } + } + } + return sb.append(')').toString() + } + +} + +public actual operator fun GrpcMetadata.get(key: String): String? { + return get(key.toAsciiKey()) +} + +public actual operator fun GrpcMetadata.get(key: GrpcMetadataKey): T? { + key.validateForString() + return map[key.name]?.lastOrNull()?.let { + key.decode(it) + } +} + +public actual fun GrpcMetadata.getBinary(key: String): ByteArray? { + val key = key.toBinaryKey() + key.validateForBinary() + return map[key.name]?.lastOrNull() +} + +public actual fun GrpcMetadata.getBinary(key: GrpcMetadataKey): T? { + key.validateForBinary() + return map[key.name]?.lastOrNull()?.let { + key.decode(it) + } +} + +public actual fun GrpcMetadata.getAll(key: String): List { + return getAll(key.toAsciiKey()) +} + +public actual fun GrpcMetadata.getAll(key: GrpcMetadataKey): List { + key.validateForString() + return map[key.name]?.map { key.decode(it) } ?: emptyList() +} + +public actual fun GrpcMetadata.getAllBinary(key: String): List { + val key = key.toBinaryKey() + key.validateForBinary() + return map[key.name] ?: emptyList() +} + +public actual fun GrpcMetadata.getAllBinary(key: GrpcMetadataKey): List { + key.validateForBinary() + return map[key.name]?.map { key.decode(it) } ?: emptyList() +} + +public actual operator fun GrpcMetadata.contains(key: String): Boolean { + return map.containsKey(key.lowercase()) +} + +public actual fun GrpcMetadata.keys(): Set { + return map.entries.filter { it.value.isNotEmpty() }.mapTo(mutableSetOf()) { it.key } +} + +public actual fun GrpcMetadata.append(key: String, value: String) { + append(key.toAsciiKey(), value) +} + +public actual fun GrpcMetadata.append(key: GrpcMetadataKey, value: T) { + key.validateForString() + map.getOrPut(key.name) { mutableListOf() }.add(key.encode(value)) +} + +public actual fun GrpcMetadata.appendBinary(key: String, value: ByteArray) { + val key = key.toBinaryKey() + key.validateForBinary() + map.getOrPut(key.name) { mutableListOf() }.add(value) +} + +public actual fun GrpcMetadata.appendBinary(key: GrpcMetadataKey, value: T) { + key.validateForBinary() + map.getOrPut(key.name) { mutableListOf() }.add(key.encode(value)) +} + +public actual fun GrpcMetadata.remove(key: String, value: String): Boolean { + return remove(key.toAsciiKey(), value) +} + +public actual fun GrpcMetadata.remove(key: GrpcMetadataKey, value: T): Boolean { + key.validateForString() + val index = getAll(key).indexOf(value) + if (index == -1) return false + map[key.name]!!.removeAt(index) + return true +} + +public actual fun GrpcMetadata.removeBinary(key: String, value: ByteArray): Boolean { + val keyObj = key.toBinaryKey() + keyObj.validateForBinary() + val index = getAllBinary(key).indexOf(value) + if (index == -1) return false + map[keyObj.name]!!.removeAt(index) + return true +} + +public actual fun GrpcMetadata.removeBinary(key: GrpcMetadataKey, value: T): Boolean { + key.validateForBinary() + val index = getAllBinary(key).indexOf(value) + if (index == -1) return false + map[key.name]!!.removeAt(index) + return true +} + +public actual fun GrpcMetadata.removeAll(key: String): List { + return removeAll(key.toAsciiKey()) +} + +public actual fun GrpcMetadata.removeAll(key: GrpcMetadataKey): List { + key.validateForString() + return map.remove(key.name)?.map { key.decode(it) } ?: emptyList() +} + +public actual fun GrpcMetadata.removeAllBinary(key: String): List { + return removeAllBinary(key.toBinaryKey()) +} + +public actual fun GrpcMetadata.removeAllBinary(key: GrpcMetadataKey): List { + key.validateForBinary() + return map.remove(key.name)?.map { key.decode(it) } ?: emptyList() +} + +public actual fun GrpcMetadata.merge(other: GrpcMetadata) { + for ((key, values) in other.map) { + map.getOrPut(key) { mutableListOf() }.addAll(values) + } +} + +/** + * Converts the ByteArray to a string containing only ASCII characters. + * For bytes within the ASCII range (0x00 to 0x7F), the corresponding character is used. + * For bytes outside this range, the replacement character '�' (`\uFFFD`) is used. + * + * @return A string representation of the ByteArray, + * where non-ASCII bytes are replaced with '�' (`\uFFFD`). + */ +private fun ByteArray.toAsciiString(): String { + return buildString(size) { + for (b in this@toAsciiString) { + val ub = b.toInt() and 0xFF + append(if (ub in 0..0x7F) ub.toChar() else '\uFFFD') + } + } +} + +/** + * Converts the string to a byte array encoded in US-ASCII. + * Characters outside the ASCII range are replaced with the '?' character. + * + * @return a byte array representing the ASCII-encoded version of the string + */ +private fun String.toAsciiBytes(): ByteArray { + // encode as US_ASCII bytes, replacing non-ASCII chars with '?' + return ByteArray(length) { idx -> + val c = this[idx] + if (c.code in 0..0x7F) c.code.toByte() else '?'.code.toByte() + } +} + +@OptIn(ObsoleteNativeApi::class) +private val VALID_KEY_CHARS by lazy { + BitSet(0x7f).apply { + set('-'.code) + set('_'.code) + set('.'.code) + set('0'.code..'9'.code) + set('a'.code..'z'.code) + } +} + +@OptIn(ObsoleteNativeApi::class) +private fun GrpcMetadataKey.validateName() { + for (char in name) { + require(VALID_KEY_CHARS[char.code]) { "Header is named $name. It contains illegal character $char." } + } +} + +private val AsciiCodec = object : SourcedMessageCodec { + override fun encodeToSource(value: String): Source = Buffer().apply { + write(value.toAsciiBytes()) + } + + override fun decodeFromSource(stream: Source): String = stream.use { buffer -> + buffer.readByteArray().toAsciiString() + } +} + +private val BinaryCodec = object : SourcedMessageCodec { + override fun encodeToSource(value: ByteArray): Source = Buffer().apply { + write(value) + } + + override fun decodeFromSource(stream: Source): ByteArray = stream.readByteArray() +} + +private fun String.toAsciiKey() = GrpcMetadataKey(this, AsciiCodec) +private fun String.toBinaryKey() = GrpcMetadataKey(this, BinaryCodec) + diff --git a/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/test/utils.native.kt b/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/test/utils.native.kt new file mode 100644 index 000000000..17565693a --- /dev/null +++ b/grpc/grpc-core/src/nativeTest/kotlin/kotlinx/rpc/grpc/test/utils.native.kt @@ -0,0 +1,77 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +@file:OptIn(ExperimentalForeignApi::class) + +package kotlinx.rpc.grpc.test + +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.IntVar +import kotlinx.cinterop.allocArray +import kotlinx.cinterop.convert +import kotlinx.cinterop.get +import kotlinx.cinterop.memScoped +import kotlinx.cinterop.refTo +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.IO +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch +import platform.posix.STDERR_FILENO +import platform.posix.close +import platform.posix.dup +import platform.posix.dup2 +import platform.posix.fflush +import platform.posix.fprintf +import platform.posix.pipe +import platform.posix.read +import platform.posix.stderr + +actual val runtime: Runtime + get() = Runtime.NATIVE + +actual fun setNativeEnv(key: String, value: String) { + platform.posix.setenv(key, value, 1) +} + +actual fun clearNativeEnv(key: String) { + platform.posix.unsetenv(key) +} + +actual suspend fun captureStdErr(block: suspend () -> Unit): String = coroutineScope { + memScoped { + val pipeErr = allocArray(2) + check(pipe(pipeErr) == 0) { "pipe stderr failed" } + + val savedStderr = dup(STDERR_FILENO) + + // redirect stderr write end + check(dup2(pipeErr[1], STDERR_FILENO) != -1) { "dup2 stderr failed" } + close(pipeErr[1]) + + val outputBuf = StringBuilder() + val readJob = launch(Dispatchers.IO) { + val buf = ByteArray(4096) + var r: Long + do { + r = read(pipeErr[0], buf.refTo(0), buf.size.convert()) + if (r > 0) outputBuf.append(buf.decodeToString(0, r.convert())) + } while (r > 0) + close(pipeErr[0]) + } + + try { + block() + } finally { + fflush(stderr) + // restore stderr + dup2(savedStderr, STDERR_FILENO) + close(savedStderr) + } + + // wait reading to finish + readJob.join() + outputBuf.toString() + } +} +