Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.rpc.RpcCall
import kotlinx.rpc.RpcClient
import kotlinx.rpc.grpc.GrpcMetadata
import kotlinx.rpc.grpc.client.GrpcCallOptions
import kotlinx.rpc.grpc.client.internal.ManagedChannel
import kotlinx.rpc.grpc.client.internal.ManagedChannelBuilder
import kotlinx.rpc.grpc.client.internal.applyConfig
import kotlinx.rpc.grpc.client.internal.bidirectionalStreamingRpc
import kotlinx.rpc.grpc.client.internal.buildChannel
import kotlinx.rpc.grpc.client.internal.clientStreamingRpc
Expand All @@ -27,6 +27,7 @@ import kotlinx.rpc.grpc.descriptor.MethodType
import kotlinx.rpc.grpc.descriptor.methodType
import kotlinx.rpc.internal.utils.map.RpcInternalConcurrentHashMap
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

private typealias RequestClient = Any

Expand Down Expand Up @@ -180,9 +181,7 @@ private fun GrpcClient(
builder: ManagedChannelBuilder<*>,
config: GrpcClientConfiguration,
): GrpcClient {
val channel = builder.apply {
config.overrideAuthority?.let { overrideAuthority(it) }
}.buildChannel()
val channel = builder.applyConfig(config).buildChannel()
return GrpcClient(channel, config.messageCodecResolver, config.interceptors)
}

Expand All @@ -198,6 +197,7 @@ private fun GrpcClient(
*/
public class GrpcClientConfiguration internal constructor() {
internal val interceptors: MutableList<ClientInterceptor> = mutableListOf()
internal var keepAlive: KeepAlive? = null

/**
* Configurable resolver used to determine the appropriate codec for a given Kotlin type
Expand Down Expand Up @@ -294,4 +294,65 @@ public class GrpcClientConfiguration internal constructor() {
public fun tls(configure: TlsClientCredentialsBuilder.() -> Unit): ClientCredentials =
TlsClientCredentials(configure)

}
/**
* Configures keep-alive settings for the gRPC client.
*
* Keep-alive allows you to fine-tune the behavior of the client to ensure the connection
* between the client and server remains active according to specific parameters.
*
* By default, keep-alive is disabled.
*
* ```
* GrpcClient("localhost", 50051) {
* keepAlive {
* time = 10.seconds
* timeout = 20.seconds
* withoutCalls = false
* }
* }
* ```
*
* @param configure A lambda to apply custom configurations to the [KeepAlive] instance.
* The [KeepAlive] settings include:
* - `time`: The maximum amount of time that the channel can be idle before a keep-alive
* ping is sent.
* - `timeout`: The time allowed for a keep-alive ping to complete.
* - `withoutCalls`: Whether to send keep-alive pings even when there are no outstanding
* RPCs on the connection.
*
* @see KeepAlive
*/
public fun keepAlive(configure: KeepAlive.() -> Unit) {
keepAlive = KeepAlive().apply(configure)
}

/**
* Represents keep-alive settings for a gRPC client connection.
*
* Keep-alive ensures that the connection between the client and the server remains active.
* It helps detect connection issues proactively before a request is made and facilitates
* maintaining long-lived idle connections.
*
* Client authors must coordinate with service owners for whether a particular client-side
* setting is acceptable.
*
* @property time Specifies the maximum amount of time the channel can remain idle before a
* keep-alive ping is sent to the server to check the connection state.
* The default value is `Duration.INFINITE`, which disables keep-alive pings when idle.
*
* @property timeout Sets the amount of time to wait for a keep-alive ping response.
* If the server does not respond within this timeout, the connection will be considered broken.
* The default value is 20 seconds.
*
* @property withoutCalls Defines whether keep-alive pings will be sent even when there
* are no active RPCs on the connection. If set to `true`, pings will be sent regardless
* of ongoing calls; otherwise, pings are only sent during active RPCs.
* The default value is `false`.
*/
public class KeepAlive internal constructor() {
public var time: Duration = Duration.INFINITE
public var timeout: Duration = 20.seconds
public var withoutCalls: Boolean = false
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package kotlinx.rpc.grpc.client.internal

import kotlinx.rpc.grpc.client.ClientCredentials
import kotlinx.rpc.grpc.client.GrpcClientConfiguration
import kotlinx.rpc.internal.utils.InternalRpcApi
import kotlin.time.Duration

Expand Down Expand Up @@ -71,9 +72,7 @@ public interface ManagedChannel {
* Builder class for [ManagedChannel].
*/
@InternalRpcApi
public expect abstract class ManagedChannelBuilder<T : ManagedChannelBuilder<T>> {
public abstract fun overrideAuthority(authority: String): T
}
public expect abstract class ManagedChannelBuilder<T : ManagedChannelBuilder<T>>

@InternalRpcApi
public expect fun ManagedChannelBuilder(
Expand All @@ -88,5 +87,7 @@ public expect fun ManagedChannelBuilder(
credentials: ClientCredentials? = null,
): ManagedChannelBuilder<*>

internal expect fun ManagedChannelBuilder<*>.applyConfig(config: GrpcClientConfiguration): ManagedChannelBuilder<*>

@InternalRpcApi
public expect fun ManagedChannelBuilder<*>.buildChannel(): ManagedChannel
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.grpc.Grpc
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.rpc.grpc.client.ClientCredentials
import kotlinx.rpc.grpc.client.GrpcClientConfiguration
import kotlinx.rpc.internal.utils.InternalRpcApi
import java.util.concurrent.TimeUnit
import kotlin.time.Duration
Expand Down Expand Up @@ -80,3 +81,14 @@ private class JvmManagedChannel(private val channel: io.grpc.ManagedChannel) : M
override val platformApi: ManagedChannelPlatform
get() = channel
}

internal actual fun ManagedChannelBuilder<*>.applyConfig(config: GrpcClientConfiguration): ManagedChannelBuilder<*> {
config.keepAlive?.let {
keepAliveTime(it.time.inWholeMilliseconds, TimeUnit.MILLISECONDS)
keepAliveTimeout(it.timeout.inWholeMilliseconds, TimeUnit.MILLISECONDS)
keepAliveWithoutCalls(it.withoutCalls)
}

config.overrideAuthority?.let { overrideAuthority(it) }
return this
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package kotlinx.rpc.grpc.client.internal

import kotlinx.rpc.grpc.client.ClientCredentials
import kotlinx.rpc.grpc.client.GrpcClientConfiguration
import kotlinx.rpc.grpc.client.TlsClientCredentials
import kotlinx.rpc.grpc.internal.internalError
import kotlinx.rpc.internal.utils.InternalRpcApi
Expand All @@ -22,25 +23,23 @@ public actual abstract class ManagedChannelPlatform : GrpcChannel()
*/
@InternalRpcApi
public actual abstract class ManagedChannelBuilder<T : ManagedChannelBuilder<T>> {
public actual abstract fun overrideAuthority(authority: String): T
internal var config: GrpcClientConfiguration? = null
}

internal class NativeManagedChannelBuilder(
private val target: String,
private var credentials: Lazy<ClientCredentials>,
) : ManagedChannelBuilder<NativeManagedChannelBuilder>() {

private var authority: String? = null

override fun overrideAuthority(authority: String): NativeManagedChannelBuilder {
this.authority = authority
return this
}

fun buildChannel(): NativeManagedChannel {
val keepAlive = config?.keepAlive
keepAlive?.run {
require(time.isPositive()) { "keepalive time must be positive" }
require(timeout.isPositive()) { "keepalive timeout must be positive" }
}
return NativeManagedChannel(
target,
authority = authority,
authority = config?.overrideAuthority,
keepAlive = config?.keepAlive,
credentials = credentials.value,
)
}
Expand Down Expand Up @@ -69,4 +68,7 @@ public actual fun ManagedChannelBuilder(target: String, credentials: ClientCrede
return NativeManagedChannelBuilder(target, credentials)
}


internal actual fun ManagedChannelBuilder<*>.applyConfig(config: GrpcClientConfiguration): ManagedChannelBuilder<*> {
this.config = config
return this
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import cnames.structs.grpc_channel
import kotlinx.atomicfu.atomic
import kotlinx.cinterop.CPointer
import kotlinx.cinterop.ExperimentalForeignApi
import kotlinx.cinterop.MemScope
import kotlinx.cinterop.alloc
import kotlinx.cinterop.allocArray
import kotlinx.cinterop.convert
import kotlinx.cinterop.cstr
import kotlinx.cinterop.memScoped
import kotlinx.cinterop.ptr
Expand All @@ -21,6 +24,7 @@ import kotlinx.coroutines.cancelChildren
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.rpc.grpc.client.ClientCredentials
import kotlinx.rpc.grpc.client.GrpcCallOptions
import kotlinx.rpc.grpc.client.GrpcClientConfiguration
import kotlinx.rpc.grpc.client.rawDeadline
import kotlinx.rpc.grpc.descriptor.MethodDescriptor
import kotlinx.rpc.grpc.internal.CompletionQueue
Expand Down Expand Up @@ -50,6 +54,7 @@ import kotlin.time.Duration
internal class NativeManagedChannel(
target: String,
val authority: String?,
val keepAlive: GrpcClientConfiguration.KeepAlive?,
// we must store them, otherwise the credentials are getting released
credentials: ClientCredentials,
) : ManagedChannel, ManagedChannelPlatform() {
Expand All @@ -66,22 +71,36 @@ internal class NativeManagedChannel(
private val cq = CompletionQueue()

internal val raw: CPointer<grpc_channel> = memScoped {
val args = authority?.let {
val args = mutableListOf<GrpcArg>()

authority?.let {
// the C Core API doesn't have a way to override the authority (used for TLS SNI) as it
// is available in the Java gRPC implementation.
// instead, it can be done by setting the "grpc.ssl_target_name_override" argument.
val authorityOverride = alloc<grpc_arg> {
type = grpc_arg_type.GRPC_ARG_STRING
key = "grpc.ssl_target_name_override".cstr.ptr
value.string = authority.cstr.ptr
}
args.add(GrpcArg.Str(
key = "grpc.ssl_target_name_override",
value = it
))
}

alloc<grpc_channel_args> {
num_args = 1u
args = authorityOverride.ptr
}
keepAlive?.let {
args.add(GrpcArg.Integer(
key = "grpc.keepalive_time_ms",
value = it.time.inWholeMilliseconds.convert()
))
args.add(GrpcArg.Integer(
key = "grpc.keepalive_timeout_ms",
value = it.timeout.inWholeMilliseconds.convert()
))
args.add(GrpcArg.Integer(
key = "grpc.keepalive_permit_without_calls",
value = if (it.withoutCalls) 1 else 0
))
}
grpc_channel_create(target, credentials.raw, args?.ptr)

var rawArgs = if (args.isNotEmpty()) args.toRaw(this) else null

grpc_channel_create(target, credentials.raw, rawArgs?.ptr)
?: error("Failed to create channel")
}

Expand Down Expand Up @@ -170,3 +189,33 @@ internal class NativeManagedChannel(
}

}

internal sealed class GrpcArg(val key: String) {
internal class Str(key: String, val value: String) : GrpcArg(key)
internal class Integer(key: String, val value: Int) : GrpcArg(key)

internal val rawType: grpc_arg_type
get() = when (this) {
is Str -> grpc_arg_type.GRPC_ARG_STRING
is Integer -> grpc_arg_type.GRPC_ARG_INTEGER
}
}

private fun List<GrpcArg>.toRaw(memScope: MemScope): grpc_channel_args {
with(memScope) {
val arr = allocArray<grpc_arg>(size) {
val arg = get(it)
type = arg.rawType
key = arg.key.cstr.ptr
when (arg) {
is GrpcArg.Str -> value.string = arg.value.cstr.ptr
is GrpcArg.Integer -> value.integer = arg.value.convert()
}
}

return alloc<grpc_channel_args> {
num_args = size.convert()
args = arr
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import kotlinx.coroutines.test.runTest
import kotlinx.rpc.RpcServer
import kotlinx.rpc.grpc.GrpcCompression
import kotlinx.rpc.grpc.GrpcMetadata
import kotlinx.rpc.grpc.Status
import kotlinx.rpc.grpc.StatusCode
import kotlinx.rpc.grpc.get
import kotlinx.rpc.grpc.keys
Expand All @@ -18,11 +17,9 @@ import kotlinx.rpc.grpc.test.EchoServiceImpl
import kotlinx.rpc.grpc.test.Runtime
import kotlinx.rpc.grpc.test.assertContainsAll
import kotlinx.rpc.grpc.test.assertGrpcFailure
import kotlinx.rpc.grpc.test.captureStdErr
import kotlinx.rpc.grpc.test.clearNativeEnv
import kotlinx.rpc.grpc.test.captureGrpcLogs
import kotlinx.rpc.grpc.test.invoke
import kotlinx.rpc.grpc.test.runtime
import kotlinx.rpc.grpc.test.setNativeEnv
import kotlinx.rpc.registerService
import kotlinx.rpc.withService
import kotlin.collections.emptyList
Expand Down Expand Up @@ -101,7 +98,10 @@ class GrpcCompressionTest : GrpcProtoTest() {
) {
var reqHeaders = emptyMap<String, String>()
var respHeaders = emptyMap<String, String>()
val logs = captureNativeGrpcLogs {
val logs = captureGrpcLogs(
nativeTracers = listOf("compression", "http"),
jvmLoggers = emptyList(),
) {
runGrpcTest(
clientInterceptors = clientInterceptor {
clientCompression?.let { compression ->
Expand Down Expand Up @@ -149,17 +149,6 @@ class GrpcCompressionTest : GrpcProtoTest() {
assertContainsAll(listOf("gzip"), respHeaders.grpcAcceptEncoding())
}

private suspend fun captureNativeGrpcLogs(block: suspend () -> Unit): String {
try {
return captureStdErr {
setNativeEnv("GRPC_TRACE", "compression,http")
block()
}
} finally {
clearNativeEnv("GRPC_GRACE")
}
}

private fun GrpcMetadata.toMap(): Map<String, String> {
return keys().mapNotNull { key ->
if (!key.endsWith("-bin")) {
Expand Down
Loading
Loading