From 744bf2fd812cca1e97718d84a21c6290f5fb4f5c Mon Sep 17 00:00:00 2001 From: irving ou Date: Thu, 2 Apr 2026 15:39:26 -0400 Subject: [PATCH 1/2] =?UTF-8?q?feat:=20QUIC=20agent=20tunnel=20=E2=80=94?= =?UTF-8?q?=20protocol,=20listener,=20agent=20client?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add QUIC-based agent tunnel core infrastructure. Agents in private networks connect outbound to Gateway via QUIC/mTLS, advertise reachable subnets and domains, and proxy TCP connections on behalf of Gateway. Protocol (agent-tunnel-proto crate): - RouteAdvertise with subnets + domain advertisements - ConnectMessage/ConnectResponse for session stream setup - Heartbeat/HeartbeatAck for liveness detection - Protocol version negotiation (v2) Gateway (agent_tunnel module): - QUIC listener with mTLS authentication - Agent registry with subnet/domain tracking - Certificate authority for agent enrollment - Enrollment token store (one-time tokens) - Bidirectional proxy stream multiplexing Agent (devolutions-agent): - QUIC client with auto-reconnect and exponential backoff - Agent enrollment with config merge (preserves existing settings) - Domain auto-detection (Windows: USERDNSDOMAIN, Linux: resolv.conf) - Subnet validation on incoming connections - Certificate file permissions (0o600 on Unix) API endpoints: - POST /jet/agent-tunnel/enroll — agent enrollment - GET /jet/agent-tunnel/agents — list agents - GET /jet/agent-tunnel/agents/{id} — get agent - DELETE /jet/agent-tunnel/agents/{id} — delete agent - POST /jet/agent-tunnel/agents/resolve-target — routing diagnostics Co-Authored-By: Claude Opus 4.6 (1M context) --- Cargo.lock | 393 +++++++-- crates/agent-tunnel-proto/Cargo.toml | 21 + crates/agent-tunnel-proto/src/control.rs | 308 +++++++ crates/agent-tunnel-proto/src/error.rs | 15 + crates/agent-tunnel-proto/src/lib.rs | 24 + crates/agent-tunnel-proto/src/session.rs | 215 +++++ crates/agent-tunnel-proto/src/version.rs | 37 + devolutions-agent/Cargo.toml | 16 +- devolutions-agent/src/config.rs | 75 +- devolutions-agent/src/domain_detect.rs | 112 +++ devolutions-agent/src/enrollment.rs | 192 +++++ devolutions-agent/src/lib.rs | 3 + devolutions-agent/src/main.rs | 235 +++++- devolutions-agent/src/service.rs | 7 +- devolutions-agent/src/tunnel.rs | 512 ++++++++++++ devolutions-gateway/Cargo.toml | 15 + devolutions-gateway/src/agent_tunnel/cert.rs | 433 ++++++++++ .../src/agent_tunnel/enrollment_store.rs | 126 +++ .../src/agent_tunnel/listener.rs | 336 ++++++++ devolutions-gateway/src/agent_tunnel/mod.rs | 15 + .../src/agent_tunnel/registry.rs | 773 ++++++++++++++++++ .../src/agent_tunnel/stream.rs | 37 + .../src/api/agent_enrollment.rs | 302 +++++++ devolutions-gateway/src/api/mod.rs | 2 + devolutions-gateway/src/api/webapp.rs | 1 + devolutions-gateway/src/config.rs | 39 + devolutions-gateway/src/extract.rs | 58 ++ devolutions-gateway/src/generic_client.rs | 74 ++ devolutions-gateway/src/lib.rs | 3 + devolutions-gateway/src/listener.rs | 1 + devolutions-gateway/src/middleware/auth.rs | 8 + devolutions-gateway/src/ngrok.rs | 1 + devolutions-gateway/src/rd_clean_path.rs | 4 +- devolutions-gateway/src/service.rs | 32 +- devolutions-gateway/src/token.rs | 16 +- devolutions-gateway/tests/config.rs | 6 + 36 files changed, 4348 insertions(+), 99 deletions(-) create mode 100644 crates/agent-tunnel-proto/Cargo.toml create mode 100644 crates/agent-tunnel-proto/src/control.rs create mode 100644 crates/agent-tunnel-proto/src/error.rs create mode 100644 crates/agent-tunnel-proto/src/lib.rs create mode 100644 crates/agent-tunnel-proto/src/session.rs create mode 100644 crates/agent-tunnel-proto/src/version.rs create mode 100644 devolutions-agent/src/domain_detect.rs create mode 100644 devolutions-agent/src/enrollment.rs create mode 100644 devolutions-agent/src/tunnel.rs create mode 100644 devolutions-gateway/src/agent_tunnel/cert.rs create mode 100644 devolutions-gateway/src/agent_tunnel/enrollment_store.rs create mode 100644 devolutions-gateway/src/agent_tunnel/listener.rs create mode 100644 devolutions-gateway/src/agent_tunnel/mod.rs create mode 100644 devolutions-gateway/src/agent_tunnel/registry.rs create mode 100644 devolutions-gateway/src/agent_tunnel/stream.rs create mode 100644 devolutions-gateway/src/api/agent_enrollment.rs diff --git a/Cargo.lock b/Cargo.lock index 46b826db1..6b1ba94b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,7 +36,7 @@ checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" dependencies = [ "cfg-if", "cipher 0.4.4", - "cpufeatures 0.2.17", + "cpufeatures", ] [[package]] @@ -47,7 +47,7 @@ checksum = "7e713c57c2a2b19159e7be83b9194600d7e8eb3b7c2cd67e671adf47ce189a05" dependencies = [ "cfg-if", "cipher 0.5.0-rc.1", - "cpufeatures 0.2.17", + "cpufeatures", ] [[package]] @@ -74,6 +74,19 @@ dependencies = [ "const-oid 0.10.2", ] +[[package]] +name = "agent-tunnel-proto" +version = "0.0.0" +dependencies = [ + "bincode", + "ipnetwork", + "proptest", + "serde", + "thiserror 2.0.18", + "tokio 1.49.0", + "uuid", +] + [[package]] name = "ahash" version = "0.8.12" @@ -167,7 +180,7 @@ checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" dependencies = [ "base64ct", "blake2", - "cpufeatures 0.2.17", + "cpufeatures", "password-hash", ] @@ -177,13 +190,29 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive 0.5.1", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "asn1-rs" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56624a96882bb8c26d61312ae18cb45868e5a9992ea73c58e45c3101e56a1e60" dependencies = [ - "asn1-rs-derive", + "asn1-rs-derive 0.6.0", "asn1-rs-impl", "displaydoc", "nom", @@ -192,6 +221,18 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2 1.0.106", + "quote 1.0.44", + "syn 2.0.114", + "synstructure", +] + [[package]] name = "asn1-rs-derive" version = "0.6.0" @@ -837,6 +878,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "ceviche" version = "0.7.0" @@ -885,18 +932,7 @@ checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" dependencies = [ "cfg-if", "cipher 0.4.4", - "cpufeatures 0.2.17", -] - -[[package]] -name = "chacha20" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" -dependencies = [ - "cfg-if", - "cpufeatures 0.3.0", - "rand_core 0.10.0", + "cpufeatures", ] [[package]] @@ -906,7 +942,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" dependencies = [ "aead 0.5.2", - "chacha20 0.9.1", + "chacha20", "cipher 0.4.4", "poly1305", "zeroize", @@ -1020,6 +1056,16 @@ dependencies = [ "cc", ] +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes 1.11.1", + "memchr", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -1082,15 +1128,6 @@ dependencies = [ "libc", ] -[[package]] -name = "cpufeatures" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" -dependencies = [ - "libc", -] - [[package]] name = "crc32fast" version = "1.5.0" @@ -1310,7 +1347,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f9200d1d13637f15a6acb71e758f64624048d85b31a5fdbfd8eca1e2687d0b7" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "curve25519-dalek-derive", "digest 0.11.0-rc.3", "fiat-crypto", @@ -1330,6 +1367,20 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -1380,13 +1431,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs 0.6.2", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "der-parser" version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6" dependencies = [ - "asn1-rs", + "asn1-rs 0.7.1", "displaydoc", "nom", "num-traits", @@ -1438,9 +1503,12 @@ dependencies = [ name = "devolutions-agent" version = "2026.1.1" dependencies = [ + "agent-tunnel-proto", "anyhow", "async-trait", "aws-lc-rs", + "base64 0.22.1", + "bincode", "bytes 1.11.1", "camino", "ceviche", @@ -1454,16 +1522,22 @@ dependencies = [ "futures", "hex", "http-client-proxy", + "ipnetwork", "ironrdp", "notify-debouncer-mini", "parking_lot", + "quinn", "rand 0.8.5", + "rcgen", "reqwest", + "rustls 0.23.37", "rustls-pemfile 2.2.0", + "rustls-pki-types", "serde", "serde_json", "sha2 0.10.9", "tap", + "tempfile", "thiserror 2.0.18", "tokio 1.49.0", "tokio-rustls", @@ -1492,12 +1566,15 @@ dependencies = [ name = "devolutions-gateway" version = "2026.1.1" dependencies = [ + "agent-tunnel-proto", "anyhow", "argon2", "async-trait", "axum 0.8.8", "axum-extra", "backoff", + "base64 0.22.1", + "bincode", "bitflags 2.10.0", "bytes 1.11.1", "cadeau", @@ -1505,6 +1582,7 @@ dependencies = [ "ceviche", "cfg-if", "chacha20poly1305", + "dashmap", "devolutions-agent-shared", "devolutions-gateway-generators", "devolutions-gateway-task", @@ -1521,6 +1599,7 @@ dependencies = [ "http-client-proxy", "hyper 1.8.1", "hyper-util", + "ipnetwork", "ironrdp-acceptor", "ironrdp-connector", "ironrdp-core", @@ -1539,15 +1618,22 @@ dependencies = [ "nonempty", "parking_lot", "pcap-file", + "pem", "picky", "picky-krb", "pin-project-lite 0.2.17", "portpicker", "proptest", + "quinn", + "rand 0.8.5", + "rcgen", "reqwest", "rstest", + "rustls 0.23.37", "rustls-cng", "rustls-native-certs", + "rustls-pemfile 2.2.0", + "rustls-pki-types", "secrecy 0.10.3", "secure-memory", "serde", @@ -1586,6 +1672,7 @@ dependencies = [ "video-streamer", "windows-sys 0.61.2", "x509-cert", + "x509-parser", "zeroize", ] @@ -2002,16 +2089,16 @@ dependencies = [ [[package]] name = "embed-resource" -version = "3.0.8" +version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63a1d0de4f2249aa0ff5884d7080814f446bb241a559af6c170a41e878ed2d45" +checksum = "55a075fc573c64510038d7ee9abc7990635863992f83ebc52c8b433b8411a02e" dependencies = [ "cc", "memchr", "rustc_version", "toml", "vswhom", - "winreg", + "winreg 0.55.0", ] [[package]] @@ -2090,6 +2177,18 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" +[[package]] +name = "fastbloom" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7f34442dbe69c60fe8eaf58a8cafff81a1f278816d8ab4db255b3bef4ac3c4" +dependencies = [ + "getrandom 0.3.4", + "libm", + "rand 0.9.2", + "siphasher", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -2377,7 +2476,6 @@ dependencies = [ "cfg-if", "libc", "r-efi", - "rand_core 0.10.0", "wasip2", "wasip3", ] @@ -3119,15 +3217,14 @@ dependencies = [ [[package]] name = "ipconfig" -version = "0.3.4" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d40460c0ce33d6ce4b0630ad68ff63d6661961c48b6dba35e5a4d81cfb48222" +checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" dependencies = [ - "socket2 0.6.2", + "socket2 0.5.10", "widestring 1.2.1", - "windows-registry 0.6.1", - "windows-result 0.4.1", - "windows-sys 0.61.2", + "windows-sys 0.48.0", + "winreg 0.50.0", ] [[package]] @@ -3136,6 +3233,15 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "ipnetwork" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +dependencies = [ + "serde", +] + [[package]] name = "iri-string" version = "0.7.10" @@ -3295,7 +3401,7 @@ dependencies = [ "bit_field", "bitflags 2.10.0", "byteorder", - "der-parser", + "der-parser 10.0.0", "ironrdp-core", "ironrdp-error", "md-5 0.10.6", @@ -3523,6 +3629,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "job-queue" version = "0.0.0" @@ -3593,7 +3721,7 @@ version = "0.2.0-rc.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d546793a04a1d3049bd192856f804cfe96356e2cf36b54b4e575155babe9f41" dependencies = [ - "cpufeatures 0.2.17", + "cpufeatures", ] [[package]] @@ -3734,9 +3862,9 @@ dependencies = [ [[package]] name = "libsql" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30fe980ac5693ed1f3db490559fb578885e913a018df64af8a1a46e1959a78df" +checksum = "2329faffc510cc3c6b4f00169a39177cc7099d3ed7647fc92f7cf26e53a8d976" dependencies = [ "anyhow", "async-stream", @@ -3773,9 +3901,9 @@ dependencies = [ [[package]] name = "libsql-ffi" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0be1da6f123ceb2cd23f469883415cab9ee963286a85d61e22afb8b12e15e681" +checksum = "6cd1c1662822495393327856774f6803be25d85bfdcd5b9d4af35458f5daaf75" dependencies = [ "bindgen", "cc", @@ -3785,9 +3913,9 @@ dependencies = [ [[package]] name = "libsql-hrana" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3358538b52cfcf9af4fe7aeb57d6843aafed2e8af80807bd636fd1448e94ea7" +checksum = "646d0aa75e412769018422f0da798f72e93bd51964f0b2ddad4317aa779ae444" dependencies = [ "base64 0.21.7", "bytes 1.11.1", @@ -3797,9 +3925,9 @@ dependencies = [ [[package]] name = "libsql-rusqlite" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b646f94fc1d266e481c38a2d44d6d9d1be3ad04b56b90457acfb310dc450030e" +checksum = "5a4ce3a78c6e3c2b23b02ab6272df8340e1c53380497979d456882254f348d5f" dependencies = [ "bitflags 2.10.0", "fallible-iterator 0.2.0", @@ -3829,9 +3957,9 @@ dependencies = [ [[package]] name = "libsql-sys" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90725458cc4461bc82f8f7983e80b002ea4f64b5184e1462f252d0dd74b122f5" +checksum = "2a3c326fcfc36fe7578238d5ee6b58c529f8c76372acd61ec50267529cdaff95" dependencies = [ "bytes 1.11.1", "libsql-ffi", @@ -3843,9 +3971,9 @@ dependencies = [ [[package]] name = "libsql_replication" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bba5c9b3a26aca06d70f6a3646ba341cf574a548355353fe135af524b1b77cc" +checksum = "1d9a2e469ac8400659bd31f81a745908bcc5cb6b40be2f2ff8de90b15bec5501" dependencies = [ "aes 0.8.4", "async-stream", @@ -4541,6 +4669,15 @@ dependencies = [ "serde", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs 0.6.2", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -4565,9 +4702,9 @@ checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" [[package]] name = "openssl" -version = "0.10.76" +version = "0.10.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" dependencies = [ "bitflags 2.10.0", "cfg-if", @@ -4603,9 +4740,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" -version = "0.9.112" +version = "0.9.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" dependencies = [ "cc", "libc", @@ -4731,6 +4868,16 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64 0.22.1", + "serde_core", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -5111,7 +5258,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" dependencies = [ - "cpufeatures 0.2.17", + "cpufeatures", "opaque-debug", "universal-hash 0.5.1", ] @@ -5123,7 +5270,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ffd40cc99d0fbb02b4b3771346b811df94194bc103983efa0203c8893755085" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "universal-hash 0.6.0-rc.2", ] @@ -5162,9 +5309,9 @@ dependencies = [ [[package]] name = "postgres-types" -version = "0.2.13" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dc729a129e682e8d24170cd30ae1aa01b336b096cbb56df6d534ffec133d186" +checksum = "54b858f82211e84682fecd373f68e1ceae642d8d751a1ebd13f33de6257b3e20" dependencies = [ "bytes 1.11.1", "chrono", @@ -5412,7 +5559,7 @@ dependencies = [ "system-configuration-sys 0.6.0", "url", "windows-sys 0.61.2", - "winreg", + "winreg 0.55.0", ] [[package]] @@ -5457,6 +5604,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ "bytes 1.11.1", + "fastbloom", "getrandom 0.3.4", "lru-slab", "rand 0.9.2", @@ -5464,6 +5612,7 @@ dependencies = [ "rustc-hash 2.1.1", "rustls 0.23.37", "rustls-pki-types", + "rustls-platform-verifier", "slab", "thiserror 2.0.18", "tinyvec", @@ -5536,17 +5685,6 @@ dependencies = [ "rand_core 0.9.5", ] -[[package]] -name = "rand" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" -dependencies = [ - "chacha20 0.10.0", - "getrandom 0.4.1", - "rand_core 0.10.0", -] - [[package]] name = "rand_chacha" version = "0.3.1" @@ -5585,12 +5723,6 @@ dependencies = [ "getrandom 0.3.4", ] -[[package]] -name = "rand_core" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" - [[package]] name = "rand_xorshift" version = "0.4.0" @@ -5629,6 +5761,20 @@ dependencies = [ "cipher 0.5.0-rc.1", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring 0.17.14", + "rustls-pki-types", + "time", + "x509-parser", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -6003,6 +6149,33 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-platform-verifier" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" +dependencies = [ + "core-foundation 0.10.1", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls 0.23.37", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + [[package]] name = "rustls-webpki" version = "0.103.9" @@ -6359,7 +6532,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.10.7", ] @@ -6370,7 +6543,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5e046edf639aa2e7afb285589e5405de2ef7e61d4b0ac1e30256e3eab911af9" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.11.0-rc.3", ] @@ -6381,7 +6554,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.10.7", ] @@ -6392,7 +6565,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d1e3878ab0f98e35b2df35fe53201d088299b41a6bb63e3e34dada2ac4abd924" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.11.0-rc.3", ] @@ -7067,9 +7240,9 @@ dependencies = [ [[package]] name = "tokio-postgres" -version = "0.7.17" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dd8df5ef180f6364759a6f00f7aadda4fbbac86cdee37480826a6ff9f3574ce" +checksum = "dcea47c8f71744367793f16c2db1f11cb859d28f436bdb4ca9193eb1f787ee42" dependencies = [ "async-trait", "byteorder", @@ -7084,7 +7257,7 @@ dependencies = [ "pin-project-lite 0.2.17", "postgres-protocol", "postgres-types", - "rand 0.10.0", + "rand 0.9.2", "socket2 0.6.2", "tokio 1.49.0", "tokio-util", @@ -7571,9 +7744,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.23" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", @@ -8164,6 +8337,15 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webpki-roots" version = "0.26.11" @@ -8791,6 +8973,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "winreg" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + [[package]] name = "winreg" version = "0.55.0" @@ -8951,6 +9143,24 @@ dependencies = [ "tls_codec", ] +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs 0.6.2", + "data-encoding", + "der-parser 9.0.0", + "lazy_static", + "nom", + "oid-registry", + "ring 0.17.14", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "xmf-sys" version = "0.4.0" @@ -8960,6 +9170,15 @@ dependencies = [ "dlib", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.8.1" diff --git a/crates/agent-tunnel-proto/Cargo.toml b/crates/agent-tunnel-proto/Cargo.toml new file mode 100644 index 000000000..5822f5908 --- /dev/null +++ b/crates/agent-tunnel-proto/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "agent-tunnel-proto" +version = "0.0.0" +authors = ["Devolutions Inc. "] +edition = "2024" +publish = false + +[lints] +workspace = true + +[dependencies] +bincode = "1.3" +ipnetwork = "0.20" +serde = { version = "1", features = ["derive"] } +thiserror = "2.0" +tokio = { version = "1.45", features = ["io-util"] } +uuid = { version = "1.17", features = ["v4", "serde"] } + +[dev-dependencies] +proptest = "1.7" +tokio = { version = "1.45", features = ["rt", "macros"] } diff --git a/crates/agent-tunnel-proto/src/control.rs b/crates/agent-tunnel-proto/src/control.rs new file mode 100644 index 000000000..3fe35358b --- /dev/null +++ b/crates/agent-tunnel-proto/src/control.rs @@ -0,0 +1,308 @@ +use ipnetwork::Ipv4Network; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _}; + +use crate::error::ProtoError; +use crate::version::CURRENT_PROTOCOL_VERSION; + +/// Maximum encoded message size (1 MiB) to prevent denial-of-service via oversized frames. +pub const MAX_CONTROL_MESSAGE_SIZE: u32 = 1024 * 1024; + +/// A DNS domain advertisement with its source. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct DomainAdvertisement { + /// The DNS domain (e.g., "contoso.local"). + pub domain: String, + /// Whether this domain was auto-detected (`true`) or explicitly configured (`false`). + pub auto_detected: bool, +} + +/// Control-plane messages exchanged over the dedicated control stream (stream ID 0). +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum ControlMessage { + /// Agent advertises subnets and domains it can reach. + RouteAdvertise { + protocol_version: u16, + /// Monotonically increasing epoch within this agent process lifetime. + epoch: u64, + /// Reachable IPv4 subnets. + subnets: Vec, + /// DNS domains this agent can resolve, with source tracking. + domains: Vec, + }, + + /// Periodic liveness probe. + Heartbeat { + protocol_version: u16, + /// Milliseconds since UNIX epoch (sender's wall clock). + timestamp_ms: u64, + /// Number of currently active proxy streams on this connection. + active_stream_count: u32, + }, + + /// Acknowledgement to a Heartbeat. + HeartbeatAck { + protocol_version: u16, + /// Echoed timestamp from the corresponding Heartbeat. + timestamp_ms: u64, + }, +} + +impl ControlMessage { + /// Create a new RouteAdvertise with the current protocol version. + pub fn route_advertise(epoch: u64, subnets: Vec, domains: Vec) -> Self { + Self::RouteAdvertise { + protocol_version: CURRENT_PROTOCOL_VERSION, + epoch, + subnets, + domains, + } + } + + /// Create a new Heartbeat with the current protocol version. + pub fn heartbeat(timestamp_ms: u64, active_stream_count: u32) -> Self { + Self::Heartbeat { + protocol_version: CURRENT_PROTOCOL_VERSION, + timestamp_ms, + active_stream_count, + } + } + + /// Create a new HeartbeatAck with the current protocol version. + pub fn heartbeat_ack(timestamp_ms: u64) -> Self { + Self::HeartbeatAck { + protocol_version: CURRENT_PROTOCOL_VERSION, + timestamp_ms, + } + } + + /// Length-prefixed bincode encode and write to an async writer. + pub async fn encode(&self, writer: &mut W) -> Result<(), ProtoError> { + let payload = bincode::serialize(self)?; + let len = u32::try_from(payload.len()).map_err(|_| ProtoError::MessageTooLarge { + size: u32::MAX, + max: MAX_CONTROL_MESSAGE_SIZE, + })?; + if MAX_CONTROL_MESSAGE_SIZE < len { + return Err(ProtoError::MessageTooLarge { + size: len, + max: MAX_CONTROL_MESSAGE_SIZE, + }); + } + writer.write_all(&len.to_be_bytes()).await?; + writer.write_all(&payload).await?; + writer.flush().await?; + Ok(()) + } + + /// Read and decode a length-prefixed bincode message from an async reader. + pub async fn decode(reader: &mut R) -> Result { + let mut len_buf = [0u8; 4]; + reader.read_exact(&mut len_buf).await?; + let len = u32::from_be_bytes(len_buf); + + if MAX_CONTROL_MESSAGE_SIZE < len { + return Err(ProtoError::MessageTooLarge { + size: len, + max: MAX_CONTROL_MESSAGE_SIZE, + }); + } + + let mut payload = vec![0u8; len as usize]; + reader.read_exact(&mut payload).await?; + let msg: Self = bincode::deserialize(&payload)?; + Ok(msg) + } + + /// Extract the protocol version from any variant. + pub fn protocol_version(&self) -> u16 { + match self { + Self::RouteAdvertise { protocol_version, .. } + | Self::Heartbeat { protocol_version, .. } + | Self::HeartbeatAck { protocol_version, .. } => *protocol_version, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn roundtrip_route_advertise() { + let msg = ControlMessage::route_advertise( + 42, + vec![ + "10.0.0.0/8".parse().expect("valid CIDR"), + "192.168.1.0/24".parse().expect("valid CIDR"), + ], + vec![], + ); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_route_advertise_with_domains() { + let msg = ControlMessage::route_advertise( + 42, + vec!["10.0.0.0/8".parse().expect("valid CIDR")], + vec![ + DomainAdvertisement { + domain: "contoso.local".to_owned(), + auto_detected: false, + }, + DomainAdvertisement { + domain: "finance.contoso.local".to_owned(), + auto_detected: true, + }, + ], + ); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + + match &decoded { + ControlMessage::RouteAdvertise { domains, .. } => { + assert_eq!(domains.len(), 2); + assert_eq!(domains[0].domain, "contoso.local"); + assert!(!domains[0].auto_detected); + assert_eq!(domains[1].domain, "finance.contoso.local"); + assert!(domains[1].auto_detected); + } + _ => panic!("expected RouteAdvertise"), + } + } + + #[tokio::test] + async fn roundtrip_route_advertise_empty_domains() { + let msg = ControlMessage::route_advertise(1, vec!["192.168.1.0/24".parse().expect("valid CIDR")], vec![]); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_heartbeat() { + let msg = ControlMessage::heartbeat(1_700_000_000_000, 5); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_heartbeat_ack() { + let msg = ControlMessage::heartbeat_ack(1_700_000_000_000); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn reject_oversized_message() { + // Craft a length prefix that exceeds the maximum + let bad_len = (MAX_CONTROL_MESSAGE_SIZE + 1).to_be_bytes(); + let mut buf = bad_len.to_vec(); + buf.extend_from_slice(&[0u8; 32]); // dummy payload + + let result = ControlMessage::decode(&mut buf.as_slice()).await; + assert!(result.is_err()); + } +} + +#[cfg(test)] +mod proptests { + use proptest::prelude::*; + + use super::*; + use crate::version::CURRENT_PROTOCOL_VERSION; + + fn arb_ipv4_network() -> impl Strategy { + (any::<[u8; 4]>(), 0u8..=32).prop_map(|(octets, prefix)| { + let ip = std::net::Ipv4Addr::from(octets); + // Use network() to normalize the address for the given prefix + Ipv4Network::new(ip, prefix) + .map(|n| Ipv4Network::new(n.network(), prefix).expect("normalized network should be valid")) + .unwrap_or_else(|_| Ipv4Network::new(std::net::Ipv4Addr::UNSPECIFIED, 0).expect("0.0.0.0/0 is valid")) + }) + } + + fn arb_domain_advertisement() -> impl Strategy { + ("[a-z]{3,10}\\.[a-z]{2,5}", any::()) + .prop_map(|(domain, auto_detected)| DomainAdvertisement { domain, auto_detected }) + } + + fn arb_control_message() -> impl Strategy { + prop_oneof![ + ( + any::(), + proptest::collection::vec(arb_ipv4_network(), 0..50), + proptest::collection::vec(arb_domain_advertisement(), 0..5), + ) + .prop_map(|(epoch, subnets, domains)| { + ControlMessage::RouteAdvertise { + protocol_version: CURRENT_PROTOCOL_VERSION, + epoch, + subnets, + domains, + } + }), + (any::(), any::()).prop_map(|(timestamp_ms, active_stream_count)| { + ControlMessage::Heartbeat { + protocol_version: CURRENT_PROTOCOL_VERSION, + timestamp_ms, + active_stream_count, + } + }), + any::().prop_map(|timestamp_ms| ControlMessage::HeartbeatAck { + protocol_version: CURRENT_PROTOCOL_VERSION, + timestamp_ms, + }), + ] + } + + proptest! { + #[test] + fn control_message_roundtrip(msg in arb_control_message()) { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().expect("tokio runtime"); + rt.block_on(async { + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + let decoded = ControlMessage::decode(&mut buf.as_slice()).await.expect("decode should succeed"); + prop_assert_eq!(msg, decoded); + Ok(()) + })?; + } + } +} diff --git a/crates/agent-tunnel-proto/src/error.rs b/crates/agent-tunnel-proto/src/error.rs new file mode 100644 index 000000000..e7c10e92e --- /dev/null +++ b/crates/agent-tunnel-proto/src/error.rs @@ -0,0 +1,15 @@ +/// Protocol-level errors for the agent tunnel. +#[derive(Debug, thiserror::Error)] +pub enum ProtoError { + #[error("unsupported protocol version {received} (supported: {min}..={max})")] + UnsupportedVersion { received: u16, min: u16, max: u16 }, + + #[error("message too large: {size} bytes (max: {max})")] + MessageTooLarge { size: u32, max: u32 }, + + #[error("bincode encode/decode error: {0}")] + Bincode(#[from] bincode::Error), + + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), +} diff --git a/crates/agent-tunnel-proto/src/lib.rs b/crates/agent-tunnel-proto/src/lib.rs new file mode 100644 index 000000000..2ae1853bc --- /dev/null +++ b/crates/agent-tunnel-proto/src/lib.rs @@ -0,0 +1,24 @@ +//! Protocol definitions for the QUIC-based agent tunnel. +//! +//! This crate defines the binary protocol exchanged between Gateway and Agent +//! over QUIC streams. All messages use length-prefixed bincode encoding and +//! carry a `protocol_version` field for forward compatibility. +//! +//! ## Stream model +//! +//! - **Control stream** (QUIC stream 0): carries [`ControlMessage`] variants +//! (route advertisements, heartbeats). +//! - **Session streams** (QUIC streams 1..N): each stream proxies one TCP +//! connection. The first message is a [`ConnectMessage`] from Gateway, +//! followed by a [`ConnectResponse`] from Agent. After a successful +//! response, raw TCP bytes flow bidirectionally. + +pub mod control; +pub mod error; +pub mod session; +pub mod version; + +pub use control::{ControlMessage, DomainAdvertisement, MAX_CONTROL_MESSAGE_SIZE}; +pub use error::ProtoError; +pub use session::{ConnectMessage, ConnectResponse, MAX_SESSION_MESSAGE_SIZE}; +pub use version::{CURRENT_PROTOCOL_VERSION, MIN_SUPPORTED_VERSION, validate_protocol_version}; diff --git a/crates/agent-tunnel-proto/src/session.rs b/crates/agent-tunnel-proto/src/session.rs new file mode 100644 index 000000000..d202f0f1a --- /dev/null +++ b/crates/agent-tunnel-proto/src/session.rs @@ -0,0 +1,215 @@ +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _}; +use uuid::Uuid; + +use crate::error::ProtoError; +use crate::version::CURRENT_PROTOCOL_VERSION; + +/// Maximum encoded session message size (64 KiB). +pub const MAX_SESSION_MESSAGE_SIZE: u32 = 64 * 1024; + +/// Length-prefixed bincode encode and write to an async writer. +async fn encode_framed(msg: &T, writer: &mut W) -> Result<(), ProtoError> { + let payload = bincode::serialize(msg)?; + let len = u32::try_from(payload.len()).map_err(|_| ProtoError::MessageTooLarge { + size: u32::MAX, + max: MAX_SESSION_MESSAGE_SIZE, + })?; + if MAX_SESSION_MESSAGE_SIZE < len { + return Err(ProtoError::MessageTooLarge { + size: len, + max: MAX_SESSION_MESSAGE_SIZE, + }); + } + writer.write_all(&len.to_be_bytes()).await?; + writer.write_all(&payload).await?; + writer.flush().await?; + Ok(()) +} + +/// Read and decode a length-prefixed bincode message from an async reader. +async fn decode_framed(reader: &mut R) -> Result { + let mut len_buf = [0u8; 4]; + reader.read_exact(&mut len_buf).await?; + let len = u32::from_be_bytes(len_buf); + + if MAX_SESSION_MESSAGE_SIZE < len { + return Err(ProtoError::MessageTooLarge { + size: len, + max: MAX_SESSION_MESSAGE_SIZE, + }); + } + + let mut payload = vec![0u8; len as usize]; + reader.read_exact(&mut payload).await?; + let msg: T = bincode::deserialize(&payload)?; + Ok(msg) +} + +/// Request from Gateway to Agent to open a TCP connection to a target. +/// +/// Sent as the first message on a newly opened QUIC bidirectional stream. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ConnectMessage { + pub protocol_version: u16, + /// Association/session ID from the Gateway. + pub session_id: Uuid, + /// Target address in `host:port` form (e.g., `"192.168.1.100:3389"`). + pub target: String, +} + +/// Agent's response to a ConnectMessage. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub enum ConnectResponse { + Success { protocol_version: u16 }, + Error { protocol_version: u16, reason: String }, +} + +impl ConnectMessage { + pub fn new(session_id: Uuid, target: String) -> Self { + Self { + protocol_version: CURRENT_PROTOCOL_VERSION, + session_id, + target, + } + } + + /// Length-prefixed bincode encode and write to an async writer. + pub async fn encode(&self, writer: &mut W) -> Result<(), ProtoError> { + encode_framed(self, writer).await + } + + /// Read and decode a length-prefixed bincode message from an async reader. + pub async fn decode(reader: &mut R) -> Result { + decode_framed(reader).await + } +} + +impl ConnectResponse { + pub fn success() -> Self { + Self::Success { + protocol_version: CURRENT_PROTOCOL_VERSION, + } + } + + pub fn error(reason: impl Into) -> Self { + Self::Error { + protocol_version: CURRENT_PROTOCOL_VERSION, + reason: reason.into(), + } + } + + pub fn is_success(&self) -> bool { + matches!(self, Self::Success { .. }) + } + + /// Length-prefixed bincode encode and write to an async writer. + pub async fn encode(&self, writer: &mut W) -> Result<(), ProtoError> { + encode_framed(self, writer).await + } + + /// Read and decode a length-prefixed bincode message from an async reader. + pub async fn decode(reader: &mut R) -> Result { + decode_framed(reader).await + } + + /// Extract the protocol version from any variant. + pub fn protocol_version(&self) -> u16 { + match self { + Self::Success { protocol_version } | Self::Error { protocol_version, .. } => *protocol_version, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn roundtrip_connect_message() { + let msg = ConnectMessage::new(Uuid::new_v4(), "192.168.1.100:3389".to_owned()); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ConnectMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_connect_response_success() { + let msg = ConnectResponse::success(); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ConnectResponse::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_connect_response_error() { + let msg = ConnectResponse::error("connection refused"); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ConnectResponse::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } +} + +#[cfg(test)] +mod proptests { + use proptest::prelude::*; + + use super::*; + + fn arb_connect_message() -> impl Strategy { + ("[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}:[0-9]{1,5}") + .prop_map(|target| ConnectMessage::new(Uuid::new_v4(), target)) + } + + fn arb_connect_response() -> impl Strategy { + prop_oneof![Just(ConnectResponse::success()), ".*".prop_map(ConnectResponse::error),] + } + + proptest! { + #[test] + fn connect_message_roundtrip(msg in arb_connect_message()) { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().expect("tokio runtime"); + rt.block_on(async { + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + let decoded = ConnectMessage::decode(&mut buf.as_slice()).await.expect("decode should succeed"); + // Compare fields individually because UUID is generated fresh + prop_assert_eq!(&msg.target, &decoded.target); + prop_assert_eq!(msg.protocol_version, decoded.protocol_version); + prop_assert_eq!(msg.session_id, decoded.session_id); + Ok(()) + })?; + } + + #[test] + fn connect_response_roundtrip(msg in arb_connect_response()) { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().expect("tokio runtime"); + rt.block_on(async { + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + let decoded = ConnectResponse::decode(&mut buf.as_slice()).await.expect("decode should succeed"); + prop_assert_eq!(msg, decoded); + Ok(()) + })?; + } + } +} diff --git a/crates/agent-tunnel-proto/src/version.rs b/crates/agent-tunnel-proto/src/version.rs new file mode 100644 index 000000000..dcae0852b --- /dev/null +++ b/crates/agent-tunnel-proto/src/version.rs @@ -0,0 +1,37 @@ +/// Current protocol version. +pub const CURRENT_PROTOCOL_VERSION: u16 = 2; + +/// Minimum protocol version that is still accepted. +pub const MIN_SUPPORTED_VERSION: u16 = 2; + +/// Validate that a received protocol version is within the supported range. +pub fn validate_protocol_version(version: u16) -> Result<(), crate::error::ProtoError> { + if version < MIN_SUPPORTED_VERSION || CURRENT_PROTOCOL_VERSION < version { + return Err(crate::error::ProtoError::UnsupportedVersion { + received: version, + min: MIN_SUPPORTED_VERSION, + max: CURRENT_PROTOCOL_VERSION, + }); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn accept_current_version() { + assert!(validate_protocol_version(CURRENT_PROTOCOL_VERSION).is_ok()); + } + + #[test] + fn reject_zero_version() { + assert!(validate_protocol_version(0).is_err()); + } + + #[test] + fn reject_future_version() { + assert!(validate_protocol_version(CURRENT_PROTOCOL_VERSION + 1).is_err()); + } +} diff --git a/devolutions-agent/Cargo.toml b/devolutions-agent/Cargo.toml index a93df50c8..1c2ef7f4f 100644 --- a/devolutions-agent/Cargo.toml +++ b/devolutions-agent/Cargo.toml @@ -12,8 +12,11 @@ publish = false workspace = true [dependencies] +agent-tunnel-proto = { path = "../crates/agent-tunnel-proto" } anyhow = "1" async-trait = "0.1" +bincode = "1.3" +base64 = "0.22" bytes = "1" camino = { version = "1.1", features = ["serde1"] } ceviche = "0.7" @@ -23,15 +26,22 @@ devolutions-gateway-task = { path = "../crates/devolutions-gateway-task" } devolutions-log = { path = "../crates/devolutions-log" } futures = "0.3" http-client-proxy = { path = "../crates/http-client-proxy" } +ipnetwork = "0.20" parking_lot = "0.12" +quinn = "0.11" rand = "0.8" # FIXME(@CBenoit): maybe we don't need this crate -rustls-pemfile = "2.2" # FIXME(@CBenoit): maybe we don't need this crate +rcgen = { version = "0.13", features = ["pem"] } +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "http2", "socks", "json"] } +rustls = { version = "0.23", default-features = false, features = ["std", "ring"] } +rustls-pemfile = "2.2" +rustls-pki-types = "1" serde_json = "1" serde = { version = "1", features = ["derive"] } tap = "1.0" tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12", "ring"] } tracing = "0.1" url = { version = "2.5", features = ["serde"] } +uuid = { version = "1.17", features = ["v4", "serde"] } [dependencies.ironrdp] version = "0.14" @@ -72,6 +82,7 @@ features = [ "Win32_Foundation", "Win32_Storage_FileSystem", "Win32_Security", + "Win32_System_SystemInformation", "Win32_System_Threading", "Win32_Security_Cryptography", "Win32_Security_Authorization", @@ -82,5 +93,8 @@ features = [ [target.'cfg(windows)'.build-dependencies] embed-resource = "3.0" +[dev-dependencies] +tempfile = "3" + [target.'cfg(windows)'.dev-dependencies] expect-test = "1.5" diff --git a/devolutions-agent/src/config.rs b/devolutions-agent/src/config.rs index 1632fcc60..826117208 100644 --- a/devolutions-agent/src/config.rs +++ b/devolutions-agent/src/config.rs @@ -20,6 +20,7 @@ pub struct Conf { pub remote_desktop: RemoteDesktopConf, pub pedm: dto::PedmConf, pub session: dto::SessionConf, + pub tunnel: dto::TunnelConf, pub proxy: dto::ProxyConf, pub debug: dto::DebugConf, } @@ -48,6 +49,7 @@ impl Conf { remote_desktop, pedm: conf_file.pedm.clone().unwrap_or_default(), session: conf_file.session.clone().unwrap_or_default(), + tunnel: conf_file.tunnel.clone().unwrap_or_default(), proxy: conf_file.proxy.clone().unwrap_or_default(), debug: conf_file.debug.clone().unwrap_or_default(), }) @@ -143,14 +145,14 @@ impl ConfHandle { } } -fn save_config(conf: &dto::ConfFile) -> anyhow::Result<()> { +pub fn save_config(conf: &dto::ConfFile) -> anyhow::Result<()> { let conf_file_path = get_conf_file_path(); let json = serde_json::to_string_pretty(conf).context("failed JSON serialization of configuration")?; std::fs::write(&conf_file_path, json).with_context(|| format!("failed to write file at {conf_file_path}"))?; Ok(()) } -fn get_conf_file_path() -> Utf8PathBuf { +pub fn get_conf_file_path() -> Utf8PathBuf { get_data_dir().join("agent.json") } @@ -273,6 +275,70 @@ pub mod dto { } } + #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] + #[serde(rename_all = "PascalCase")] + pub struct TunnelConf { + /// Enable tunnel module + pub enabled: bool, + + /// Gateway QUIC endpoint (e.g., "gateway.example.com:4433") + #[serde(default, skip_serializing_if = "String::is_empty")] + pub gateway_endpoint: String, + + /// Client certificate path (issued during enrollment) + #[serde(skip_serializing_if = "Option::is_none")] + pub client_cert_path: Option, + + /// Client private key path + #[serde(skip_serializing_if = "Option::is_none")] + pub client_key_path: Option, + + /// Gateway CA certificate path + #[serde(skip_serializing_if = "Option::is_none")] + pub gateway_ca_cert_path: Option, + + /// Subnets to advertise (e.g., ["10.0.0.0/8", "192.168.1.0/24"]) + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub advertise_subnets: Vec, + + /// DNS domains to advertise (e.g., ["contoso.local"]). Auto-detected if omitted. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub advertise_domains: Vec, + + /// Whether to auto-detect the machine's DNS domain and add it to advertise_domains (default: true) + #[serde(default = "default_true")] + pub auto_detect_domain: bool, + + /// Heartbeat interval in seconds (default: 60) + #[serde(skip_serializing_if = "Option::is_none")] + pub heartbeat_interval_secs: Option, + + /// Route advertise interval in seconds (default: 30) + #[serde(skip_serializing_if = "Option::is_none")] + pub route_advertise_interval_secs: Option, + } + + fn default_true() -> bool { + true + } + + impl Default for TunnelConf { + fn default() -> Self { + Self { + enabled: false, + gateway_endpoint: String::new(), + client_cert_path: None, + client_key_path: None, + gateway_ca_cert_path: None, + advertise_subnets: Vec::new(), + advertise_domains: Vec::new(), + auto_detect_domain: true, + heartbeat_interval_secs: Some(60), + route_advertise_interval_secs: Some(30), + } + } + } + /// Source of truth for Agent configuration /// /// This struct represents the JSON file used for configuration as close as possible @@ -304,6 +370,10 @@ pub mod dto { #[serde(default, skip_serializing_if = "Option::is_none")] pub session: Option, + /// Agent Tunnel configuration + #[serde(skip_serializing_if = "Option::is_none")] + pub tunnel: Option, + /// HTTP/SOCKS proxy configuration for outbound requests #[serde(skip_serializing_if = "Option::is_none")] pub proxy: Option, @@ -330,6 +400,7 @@ pub mod dto { proxy: None, debug: None, session: Some(SessionConf { enabled: false }), + tunnel: None, rest: serde_json::Map::new(), } } diff --git a/devolutions-agent/src/domain_detect.rs b/devolutions-agent/src/domain_detect.rs new file mode 100644 index 000000000..33f4f8423 --- /dev/null +++ b/devolutions-agent/src/domain_detect.rs @@ -0,0 +1,112 @@ +//! Auto-detection of the machine's DNS domain for agent tunnel domain advertisement. + +/// Attempts to detect the DNS domain this machine belongs to. +/// +/// Returns `None` if detection fails or the result is clearly not a valid domain +/// (e.g., ISP domain, empty string, single-label name). +pub fn detect_domain() -> Option { + let raw = detect_domain_raw()?; + let trimmed = raw.trim().trim_end_matches('.').to_ascii_lowercase(); + if is_plausible_domain(&trimmed) { + Some(trimmed) + } else { + None + } +} + +/// Returns `true` if the detected domain looks like a legitimate internal domain +/// (not a TLD, has at least two labels, all labels non-empty). +fn is_plausible_domain(domain: &str) -> bool { + let trimmed = domain.trim_end_matches('.'); + if trimmed.is_empty() { + return false; + } + let mut parts = trimmed.split('.'); + parts.next().is_some_and(|l| !l.is_empty()) && parts.next().is_some_and(|l| !l.is_empty()) +} + +#[cfg(target_os = "windows")] +fn detect_domain_raw() -> Option { + // Try USERDNSDOMAIN first (available in user logon sessions) + if let Ok(domain) = std::env::var("USERDNSDOMAIN") + && !domain.is_empty() + { + return Some(domain); + } + + // Fallback: GetComputerNameExW(ComputerNameDnsDomain) + // This works in SYSTEM service context where USERDNSDOMAIN is empty. + detect_domain_via_computer_name() +} + +#[cfg(target_os = "windows")] +fn detect_domain_via_computer_name() -> Option { + use windows::Win32::System::SystemInformation::{ComputerNameDnsDomain, GetComputerNameExW}; + use windows::core::PWSTR; + + // First call: get required buffer size. Expected to fail with ERROR_MORE_DATA. + let mut size = 0u32; + + // SAFETY: Passing null buffer with zero size to query required length. + // GetComputerNameExW writes the required size to `size` and returns ERROR_MORE_DATA. + let _ = unsafe { GetComputerNameExW(ComputerNameDnsDomain, None, &mut size) }; + + if size == 0 { + return None; + } + + let mut buf = vec![0u16; size as usize]; + + // SAFETY: `buf` is allocated with `size` elements. GetComputerNameExW writes at most + // `size` wide chars and updates `size` to the actual length (excluding null terminator). + let result = unsafe { GetComputerNameExW(ComputerNameDnsDomain, Some(PWSTR(buf.as_mut_ptr())), &mut size) }; + + if result.is_err() { + return None; + } + + let domain = String::from_utf16_lossy(&buf[..size as usize]); + + if domain.is_empty() { None } else { Some(domain) } +} + +#[cfg(not(target_os = "windows"))] +fn detect_domain_raw() -> Option { + let content = std::fs::read_to_string("/etc/resolv.conf").ok()?; + for line in content.lines() { + let line = line.trim(); + if let Some(rest) = line.strip_prefix("search ").or_else(|| line.strip_prefix("domain ")) + && let Some(domain) = rest.split_whitespace().next() + && !domain.is_empty() + { + return Some(domain.to_owned()); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn plausible_domain_accepts_typical_ad_domain() { + assert!(is_plausible_domain("contoso.local")); + assert!(is_plausible_domain("corp.contoso.com")); + assert!(is_plausible_domain("ad.it-help.ninja")); + } + + #[test] + fn plausible_domain_rejects_garbage() { + assert!(!is_plausible_domain("")); + assert!(!is_plausible_domain("local")); + assert!(!is_plausible_domain("com")); + assert!(!is_plausible_domain(".")); + assert!(!is_plausible_domain("..")); + } + + #[test] + fn plausible_domain_handles_trailing_dot() { + assert!(is_plausible_domain("contoso.local.")); + } +} diff --git a/devolutions-agent/src/enrollment.rs b/devolutions-agent/src/enrollment.rs new file mode 100644 index 000000000..98d1a9a3a --- /dev/null +++ b/devolutions-agent/src/enrollment.rs @@ -0,0 +1,192 @@ +//! Agent enrollment logic for QUIC tunnel. +//! +//! This module handles the enrollment process where an agent registers with +//! the Gateway and receives its client certificate and configuration. + +use anyhow::Context as _; +use camino::{Utf8Path, Utf8PathBuf}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::config; + +/// Request body for enrollment API +#[derive(Serialize)] +struct EnrollRequest { + /// Friendly name for the agent + agent_name: String, + /// PEM-encoded Certificate Signing Request + csr_pem: String, +} + +/// Response from enrollment API +#[derive(Deserialize)] +struct EnrollResponse { + agent_id: Uuid, + agent_name: String, + client_cert_pem: String, + gateway_ca_cert_pem: String, + quic_endpoint: String, +} + +#[derive(Debug, Clone)] +pub struct PersistedEnrollment { + pub agent_id: Uuid, + pub agent_name: String, + pub client_cert_path: Utf8PathBuf, + pub client_key_path: Utf8PathBuf, + pub gateway_ca_path: Utf8PathBuf, + pub quic_endpoint: String, +} + +/// Enroll an agent with the Gateway and save the configuration. +/// +/// # Arguments +/// * `gateway_url` - Base Gateway URL (e.g., "https://gateway.example.com:7171") +/// * `enrollment_token` - JWT token for enrollment +/// * `agent_name` - Friendly name for this agent +/// * `advertise_subnets` - List of subnets to advertise (e.g., ["10.0.0.0/8"]) +pub async fn enroll_agent( + gateway_url: &str, + enrollment_token: &str, + agent_name: &str, + advertise_subnets: Vec, +) -> anyhow::Result<()> { + bootstrap_and_persist(gateway_url, enrollment_token, agent_name, advertise_subnets).await?; + Ok(()) +} + +pub async fn bootstrap_and_persist( + gateway_url: &str, + enrollment_token: &str, + agent_name: &str, + advertise_subnets: Vec, +) -> anyhow::Result { + // Generate key pair and CSR locally — the private key never leaves this machine. + let (key_pem, csr_pem) = generate_key_and_csr(agent_name)?; + + let enroll_response = request_enrollment(gateway_url, enrollment_token, agent_name, &csr_pem).await?; + persist_enrollment_response(advertise_subnets, enroll_response, &key_pem) +} + +/// Generate an ECDSA P-256 key pair and a CSR containing the agent name as CN. +/// +/// Returns `(key_pem, csr_pem)`. The private key stays on the agent; only the +/// CSR is sent to the gateway. +fn generate_key_and_csr(agent_name: &str) -> anyhow::Result<(String, String)> { + let key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).context("generate agent key pair")?; + let key_pem = key_pair.serialize_pem(); + + let mut params = rcgen::CertificateParams::default(); + params.distinguished_name.push(rcgen::DnType::CommonName, agent_name); + + let csr = params.serialize_request(&key_pair).context("generate CSR")?; + let csr_pem = csr.pem().context("encode CSR to PEM")?; + + Ok((key_pem, csr_pem)) +} + +async fn request_enrollment( + gateway_url: &str, + enrollment_token: &str, + agent_name: &str, + csr_pem: &str, +) -> anyhow::Result { + let client = reqwest::Client::new(); + let enroll_url = format!("{}/jet/agent-tunnel/enroll", gateway_url.trim_end_matches('/')); + + let response = client + .post(&enroll_url) + .bearer_auth(enrollment_token) + .json(&EnrollRequest { + agent_name: agent_name.to_owned(), + csr_pem: csr_pem.to_owned(), + }) + .send() + .await + .context("failed to send enrollment request")?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + anyhow::bail!("enrollment failed with status {}: {}", status, error_text); + } + + response.json().await.context("failed to parse enrollment response") +} + +fn persist_enrollment_response( + advertise_subnets: Vec, + enroll_response: EnrollResponse, + key_pem: &str, +) -> anyhow::Result { + let config_path = config::get_conf_file_path(); + let config_dir = config_path + .parent() + .filter(|path| !path.as_str().is_empty()) + .map(Utf8Path::to_owned) + .unwrap_or_else(|| Utf8PathBuf::from(".")); + let cert_dir = config_dir.join("certs"); + + std::fs::create_dir_all(&cert_dir) + .with_context(|| format!("failed to create certificate directory: {}", cert_dir))?; + + let client_cert_path = cert_dir.join(format!("{}-cert.pem", enroll_response.agent_id)); + let client_key_path = cert_dir.join(format!("{}-key.pem", enroll_response.agent_id)); + let gateway_ca_path = cert_dir.join("gateway-ca.pem"); + + // Write the locally-generated private key first (before cert/CA from the network). + std::fs::write(&client_key_path, key_pem) + .with_context(|| format!("failed to write client private key: {}", client_key_path))?; + + std::fs::write(&client_cert_path, &enroll_response.client_cert_pem) + .with_context(|| format!("failed to write client certificate: {}", client_cert_path))?; + + std::fs::write(&gateway_ca_path, &enroll_response.gateway_ca_cert_pem) + .with_context(|| format!("failed to write gateway CA certificate: {}", gateway_ca_path))?; + + // Restrict permissions on cert/key files (owner-only on Unix). + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + let restricted = std::fs::Permissions::from_mode(0o600); + for path in [&client_cert_path, &client_key_path, &gateway_ca_path] { + std::fs::set_permissions(path, restricted.clone()) + .with_context(|| format!("failed to set permissions on {path}"))?; + } + } + + // Load existing config and update only the Tunnel section. + // This preserves other settings (Updater, Session, PEDM, etc.) that may have been + // configured by the MSI installer or admin. + let mut conf_file = config::load_conf_file_or_generate_new().context("failed to load existing configuration")?; + + // Preserve existing domain config from previous enrollment/manual configuration. + let existing_tunnel = conf_file.tunnel.as_ref(); + + let tunnel_conf = config::dto::TunnelConf { + enabled: true, + gateway_endpoint: enroll_response.quic_endpoint.clone(), + client_cert_path: Some(client_cert_path.clone()), + client_key_path: Some(client_key_path.clone()), + gateway_ca_cert_path: Some(gateway_ca_path.clone()), + advertise_subnets, + advertise_domains: existing_tunnel.map(|t| t.advertise_domains.clone()).unwrap_or_default(), + auto_detect_domain: existing_tunnel.map(|t| t.auto_detect_domain).unwrap_or(true), + heartbeat_interval_secs: Some(60), + route_advertise_interval_secs: Some(30), + }; + + conf_file.tunnel = Some(tunnel_conf); + + config::save_config(&conf_file)?; + + Ok(PersistedEnrollment { + agent_id: enroll_response.agent_id, + agent_name: enroll_response.agent_name, + client_cert_path, + client_key_path, + gateway_ca_path, + quic_endpoint: enroll_response.quic_endpoint, + }) +} diff --git a/devolutions-agent/src/lib.rs b/devolutions-agent/src/lib.rs index 71c328f48..304192ef1 100644 --- a/devolutions-agent/src/lib.rs +++ b/devolutions-agent/src/lib.rs @@ -6,8 +6,11 @@ use ctrlc as _; extern crate tracing; pub mod config; +pub mod domain_detect; +pub mod enrollment; pub mod log; pub mod remote_desktop; +pub mod tunnel; #[cfg(windows)] pub mod session_manager; diff --git a/devolutions-agent/src/main.rs b/devolutions-agent/src/main.rs index 9005dd7b3..b2796e0d3 100644 --- a/devolutions-agent/src/main.rs +++ b/devolutions-agent/src/main.rs @@ -2,25 +2,35 @@ #![allow(clippy::print_stdout)] // Used by devolutions-agent library. +use agent_tunnel_proto as _; use anyhow as _; use async_trait as _; +use bincode as _; use camino as _; use devolutions_agent_shared as _; use devolutions_gateway_task as _; use devolutions_log as _; use futures as _; +use http_client_proxy as _; +use ipnetwork as _; use ironrdp as _; use parking_lot as _; +use quinn as _; use rand as _; +use reqwest as _; +use rustls as _; use rustls_pemfile as _; +use rustls_pki_types as _; use serde as _; use serde_json as _; use tap as _; use tokio as _; use tokio_rustls as _; +use url as _; +use uuid as _; #[cfg(windows)] use { - devolutions_pedm as _, hex as _, notify_debouncer_mini as _, reqwest as _, sha2 as _, thiserror as _, uuid as _, + aws_lc_rs as _, devolutions_pedm as _, hex as _, notify_debouncer_mini as _, sha2 as _, thiserror as _, win_api_wrappers as _, windows as _, }; @@ -32,6 +42,8 @@ mod service; use std::env; use std::sync::mpsc; +use anyhow::{Context as _, Result, bail}; +use base64::Engine as _; use ceviche::Service; use ceviche::controller::*; use devolutions_agent::AgentServiceEvent; @@ -42,6 +54,23 @@ use self::service::{AgentService, DESCRIPTION, DISPLAY_NAME, SERVICE_NAME}; const BAD_CONFIG_ERR_CODE: u32 = 1; const START_FAILED_ERR_CODE: u32 = 2; +#[derive(Debug, PartialEq, Eq)] +struct UpCommand { + gateway_url: String, + enrollment_token: String, + agent_name: String, + advertise_subnets: Vec, +} + +#[derive(Debug, serde::Deserialize)] +struct EnrollmentStringPayload { + version: u64, + api_base_url: String, + enrollment_token: String, + #[serde(default)] + name: Option, +} + fn agent_service_main( rx: mpsc::Receiver, _tx: mpsc::Sender, @@ -110,6 +139,85 @@ fn agent_service_main( Service!("agent", agent_service_main); +fn parse_required_value(args: &[String], index: &mut usize, flag: &str) -> Result { + *index += 1; + args.get(*index) + .cloned() + .with_context(|| format!("missing value for {flag}")) +} + +fn parse_advertise_subnets(value: &str) -> Vec { + value + .split(',') + .map(str::trim) + .filter(|subnet| !subnet.is_empty()) + .map(ToOwned::to_owned) + .collect() +} + +fn parse_enrollment_string(value: &str) -> Result { + const PREFIX: &str = "dgw-enroll:v1:"; + + let encoded = value.strip_prefix(PREFIX).context("invalid enrollment string prefix")?; + + let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(encoded) + .context("invalid base64 enrollment string")?; + + let payload: EnrollmentStringPayload = + serde_json::from_slice(&decoded).context("invalid enrollment string payload")?; + + if payload.version != 1 { + bail!("unsupported enrollment string version: {}", payload.version); + } + + Ok(payload) +} + +fn parse_up_command_args(args: &[String]) -> Result { + let mut gateway_url = None; + let mut enrollment_token = None; + let mut agent_name = None; + let mut enrollment_string = None; + let mut advertise_subnets = Vec::new(); + + let mut index = 0; + while index < args.len() { + let arg = args[index].as_str(); + + match arg { + "--gateway" => gateway_url = Some(parse_required_value(args, &mut index, "--gateway")?), + "--token" | "--enrollment-token" => enrollment_token = Some(parse_required_value(args, &mut index, arg)?), + "--name" | "--agent-name" => agent_name = Some(parse_required_value(args, &mut index, arg)?), + "--enrollment-string" => enrollment_string = Some(parse_required_value(args, &mut index, arg)?), + "--advertise-routes" | "--advertise-subnets" => { + advertise_subnets.extend(parse_advertise_subnets(&parse_required_value(args, &mut index, arg)?)) + } + unexpected => bail!("unknown argument for up: {unexpected}"), + } + + index += 1; + } + + if let Some(enrollment_string) = enrollment_string { + let payload = parse_enrollment_string(&enrollment_string)?; + + gateway_url.get_or_insert(payload.api_base_url); + enrollment_token.get_or_insert(payload.enrollment_token); + + if agent_name.is_none() { + agent_name = payload.name; + } + } + + Ok(UpCommand { + gateway_url: gateway_url.context("missing required --gateway")?, + enrollment_token: enrollment_token.context("missing required --token")?, + agent_name: agent_name.context("missing required --name")?, + advertise_subnets, + }) +} + fn main() { let mut controller = Controller::new(SERVICE_NAME, DISPLAY_NAME, DESCRIPTION); @@ -152,6 +260,61 @@ fn main() { eprintln!("[ERROR] Agent configuration failed: {e}"); } } + "enroll" => { + let gateway_url = env::args() + .nth(2) + .expect("missing gateway URL (e.g., https://gateway.example.com:7171)"); + let enrollment_token = env::args().nth(3).expect("missing enrollment token"); + let agent_name = env::args().nth(4).expect("missing agent name"); + let subnets_arg = env::args().nth(5).unwrap_or_default(); + + let advertise_subnets: Vec = if subnets_arg.is_empty() { + Vec::new() + } else { + subnets_arg.split(',').map(|s| s.trim().to_owned()).collect() + }; + + let rt = tokio::runtime::Runtime::new().expect("failed to create tokio runtime"); + rt.block_on(async { + if let Err(e) = devolutions_agent::enrollment::enroll_agent( + &gateway_url, + &enrollment_token, + &agent_name, + advertise_subnets, + ) + .await + { + eprintln!("[ERROR] Enrollment failed: {e:#}"); + std::process::exit(1); + } + }); + } + "up" => { + let args: Vec = env::args().skip(2).collect(); + let command = match parse_up_command_args(&args) { + Ok(command) => command, + Err(error) => { + eprintln!("[ERROR] Invalid up arguments: {error:#}"); + std::process::exit(1); + } + }; + + let rt = tokio::runtime::Runtime::new().expect("failed to create tokio runtime"); + let result = rt.block_on(async { + devolutions_agent::enrollment::bootstrap_and_persist( + &command.gateway_url, + &command.enrollment_token, + &command.agent_name, + command.advertise_subnets, + ) + .await + }); + + if let Err(error) = result { + eprintln!("[ERROR] Bootstrap failed: {error:#}"); + std::process::exit(1); + } + } _ => { eprintln!("[ERROR] Invalid command: {cmd}"); } @@ -160,3 +323,73 @@ fn main() { let _result = controller.register(service_main_wrapper); } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_up_command_args_uses_default_config_path() { + let args = vec![ + "--gateway".to_owned(), + "https://gateway.example.com:7171".to_owned(), + "--token".to_owned(), + "bootstrap-token".to_owned(), + "--name".to_owned(), + "site-a-agent".to_owned(), + "--advertise-routes".to_owned(), + "10.0.0.0/8,192.168.1.0/24".to_owned(), + ]; + + let parsed = parse_up_command_args(&args).expect("parse up args"); + + assert_eq!( + parsed, + UpCommand { + gateway_url: "https://gateway.example.com:7171".to_owned(), + enrollment_token: "bootstrap-token".to_owned(), + agent_name: "site-a-agent".to_owned(), + advertise_subnets: vec!["10.0.0.0/8".to_owned(), "192.168.1.0/24".to_owned()], + } + ); + } + + #[test] + fn parse_up_command_args_accepts_aliases() { + let args = vec![ + "--gateway".to_owned(), + "https://gateway.example.com:7171".to_owned(), + "--enrollment-token".to_owned(), + "bootstrap-token".to_owned(), + "--agent-name".to_owned(), + "site-a-agent".to_owned(), + "--advertise-subnets".to_owned(), + "10.0.0.0/8".to_owned(), + ]; + + let parsed = parse_up_command_args(&args).expect("parse up args"); + + assert_eq!(parsed.advertise_subnets, vec!["10.0.0.0/8".to_owned()]); + } + + #[test] + fn parse_up_command_args_accepts_enrollment_string() { + let payload = serde_json::json!({ + "version": 1, + "api_base_url": "https://gateway.example.com:7171", + "enrollment_token": "bootstrap-token", + "name": "site-a-agent", + }); + let enrollment_string = format!( + "dgw-enroll:v1:{}", + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string()) + ); + let args = vec!["--enrollment-string".to_owned(), enrollment_string]; + + let parsed = parse_up_command_args(&args).expect("parse up args"); + + assert_eq!(parsed.gateway_url, "https://gateway.example.com:7171"); + assert_eq!(parsed.enrollment_token, "bootstrap-token"); + assert_eq!(parsed.agent_name, "site-a-agent"); + } +} diff --git a/devolutions-agent/src/service.rs b/devolutions-agent/src/service.rs index 90dd20f58..276a2e4f6 100644 --- a/devolutions-agent/src/service.rs +++ b/devolutions-agent/src/service.rs @@ -7,6 +7,7 @@ use devolutions_agent::log::AgentLog; use devolutions_agent::remote_desktop::RemoteDesktopTask; #[cfg(windows)] use devolutions_agent::session_manager::SessionManager; +use devolutions_agent::tunnel::TunnelTask; #[cfg(windows)] use devolutions_agent::updater::UpdaterTask; use devolutions_gateway_task::{ChildTask, ShutdownHandle, ShutdownSignal}; @@ -227,7 +228,11 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result { let service_event_tx = None; if conf.debug.enable_unstable && conf.remote_desktop.enabled { - tasks.register(RemoteDesktopTask::new(conf_handle)); + tasks.register(RemoteDesktopTask::new(conf_handle.clone())); + } + + if conf.tunnel.enabled { + tasks.register(TunnelTask::new(conf_handle)); } Ok(TasksCtx { diff --git a/devolutions-agent/src/tunnel.rs b/devolutions-agent/src/tunnel.rs new file mode 100644 index 000000000..53e4077e0 --- /dev/null +++ b/devolutions-agent/src/tunnel.rs @@ -0,0 +1,512 @@ +//! QUIC-based Agent Tunnel client implementation (Quinn). +//! +//! This module implements a QUIC client that connects to the Gateway's agent tunnel +//! endpoint, advertises reachable subnets, and handles incoming TCP proxy requests. + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use agent_tunnel_proto::{ConnectMessage, ConnectResponse, ControlMessage}; +use anyhow::{Context as _, bail}; +use async_trait::async_trait; +use devolutions_gateway_task::{ShutdownSignal, Task}; +use ipnetwork::Ipv4Network; +use tokio::net::TcpStream; + +use crate::config::ConfHandle; + +// --------------------------------------------------------------------------- +// Custom TLS verifier: verify cert chain against CA, skip hostname check +// --------------------------------------------------------------------------- + +/// Wraps a `WebPkiServerVerifier` but skips the hostname verification step. +/// +/// For our private PKI, the agent may connect by IP address (e.g., `127.0.0.1`) +/// while the server cert has the gateway's hostname (e.g., `devolutions432`). +/// The cert chain is still validated against our private CA — only the +/// hostname-to-SAN matching is bypassed. +#[derive(Debug)] +struct SkipHostnameVerification(Arc); + +impl rustls::client::danger::ServerCertVerifier for SkipHostnameVerification { + fn verify_server_cert( + &self, + end_entity: &rustls_pki_types::CertificateDer<'_>, + intermediates: &[rustls_pki_types::CertificateDer<'_>], + _server_name: &rustls_pki_types::ServerName<'_>, + ocsp_response: &[u8], + now: rustls_pki_types::UnixTime, + ) -> Result { + // Verify the cert chain against our CA, skipping hostname verification. + // We call the inner verifier with a dummy name; if it fails specifically + // because of hostname mismatch (CertNotValidForName), we accept it. + // All other errors (expired cert, unknown CA, bad signature) propagate. + self.0 + .verify_server_cert( + end_entity, + intermediates, + &rustls_pki_types::ServerName::try_from("dummy.local").expect("valid dummy server name"), + ocsp_response, + now, + ) + .or_else(|e| match e { + rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName) => { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + other => Err(other), + }) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &rustls_pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.0.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &rustls_pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.0.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.supported_verify_schemes() + } +} + +// --------------------------------------------------------------------------- +// TunnelTask — service task with auto-reconnect +// --------------------------------------------------------------------------- + +pub struct TunnelTask { + conf_handle: ConfHandle, +} + +impl TunnelTask { + pub fn new(conf_handle: ConfHandle) -> Self { + Self { conf_handle } + } +} + +#[async_trait] +impl Task for TunnelTask { + type Output = anyhow::Result<()>; + const NAME: &'static str = "tunnel"; + + async fn run(self, mut shutdown_signal: ShutdownSignal) -> anyhow::Result<()> { + const INITIAL_BACKOFF: Duration = Duration::from_secs(1); + const MAX_BACKOFF: Duration = Duration::from_secs(60); + const CONNECTED_THRESHOLD: Duration = Duration::from_secs(30); + + info!("Starting QUIC agent tunnel (with auto-reconnect)"); + + let mut backoff = INITIAL_BACKOFF; + + loop { + let start = std::time::Instant::now(); + + match run_single_connection(&self.conf_handle, &mut shutdown_signal).await { + Ok(()) => { + info!("Tunnel task stopped"); + return Ok(()); + } + Err(error) => { + warn!(error = %format!("{error:#}"), "Tunnel connection lost"); + } + } + + if CONNECTED_THRESHOLD < start.elapsed() { + backoff = INITIAL_BACKOFF; + } + + info!(?backoff, "Reconnecting after backoff"); + + tokio::select! { + _ = shutdown_signal.wait() => { + info!("Shutdown during reconnect backoff"); + return Ok(()); + } + _ = tokio::time::sleep(backoff) => {} + } + + let jitter_factor = rand::Rng::gen_range(&mut rand::thread_rng(), 0.75..1.25); + backoff = + Duration::from_secs_f64((backoff.as_secs_f64() * 2.0 * jitter_factor).min(MAX_BACKOFF.as_secs_f64())); + } + } +} + +// --------------------------------------------------------------------------- +// Single connection lifetime +// --------------------------------------------------------------------------- + +/// Run a single QUIC tunnel connection lifetime: config → connect → event loop. +/// +/// Returns `Ok(())` on graceful shutdown (shutdown signal received). +/// Returns `Err(...)` on any failure — the caller should retry with backoff. +async fn run_single_connection(conf_handle: &ConfHandle, shutdown_signal: &mut ShutdownSignal) -> anyhow::Result<()> { + // Ensure rustls crypto provider is installed (ring). + let _ = rustls::crypto::ring::default_provider().install_default(); + + let agent_conf = conf_handle.get_conf(); + let tunnel_conf = &agent_conf.tunnel; + + let cert_path = tunnel_conf + .client_cert_path + .as_ref() + .context("client_cert_path not configured")?; + let key_path = tunnel_conf + .client_key_path + .as_ref() + .context("client_key_path not configured")?; + let ca_path = tunnel_conf + .gateway_ca_cert_path + .as_ref() + .context("gateway_ca_cert_path not configured")?; + + let advertise_subnets: Vec = tunnel_conf + .advertise_subnets + .iter() + .map(|subnet| subnet.parse()) + .collect::, _>>() + .context("failed to parse advertise_subnets")?; + + if advertise_subnets.is_empty() { + warn!("No subnets configured to advertise"); + } + + // Build domain advertisement list: explicit config + auto-detection. + let mut advertise_domains: Vec = tunnel_conf + .advertise_domains + .iter() + .map(|d| agent_tunnel_proto::DomainAdvertisement { + domain: d.clone(), + auto_detected: false, + }) + .collect(); + + if tunnel_conf.auto_detect_domain { + if let Some(detected) = crate::domain_detect::detect_domain() { + if !advertise_domains + .iter() + .any(|d| d.domain.eq_ignore_ascii_case(&detected)) + { + info!(domain = %detected, "Auto-detected DNS domain"); + advertise_domains.push(agent_tunnel_proto::DomainAdvertisement { + domain: detected, + auto_detected: true, + }); + } + } else if tunnel_conf.advertise_domains.is_empty() { + warn!( + "Domain auto-detection found nothing and no advertise_domains configured. \ + Set advertise_domains in agent config." + ); + } + } + + info!( + subnet_count = advertise_subnets.len(), + domain_count = advertise_domains.len(), + domains = ?advertise_domains.iter().map(|d| { + let source = if d.auto_detected { "auto" } else { "explicit" }; + format!("{} ({})", d.domain, source) + }).collect::>(), + "Advertising subnets and domains" + ); + + // -- Build rustls ClientConfig -- + + let certs: Vec> = rustls_pemfile::certs(&mut std::io::BufReader::new( + std::fs::File::open(cert_path.as_str()).context("open client cert file")?, + )) + .collect::, _>>() + .context("parse client certificates")?; + + let key = rustls_pemfile::private_key(&mut std::io::BufReader::new( + std::fs::File::open(key_path.as_str()).context("open client key file")?, + )) + .context("parse private key file")? + .context("no private key found in file")?; + + let mut roots = rustls::RootCertStore::empty(); + let ca_certs: Vec> = rustls_pemfile::certs(&mut std::io::BufReader::new( + std::fs::File::open(ca_path.as_str()).context("open CA cert file")?, + )) + .collect::, _>>() + .context("parse CA certificates")?; + for cert in ca_certs { + roots.add(cert)?; + } + + // Use a custom verifier that validates the cert chain against our private CA + // but skips hostname verification. This is correct for a private PKI where the + // agent connects by IP address but the server cert has the gateway's hostname. + let verifier = rustls::client::WebPkiServerVerifier::builder(Arc::new(roots)) + .build() + .context("build server cert verifier")?; + + let mut client_crypto = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(SkipHostnameVerification(verifier))) + .with_client_auth_cert(certs, key) + .context("build rustls client config with client auth")?; + client_crypto.alpn_protocols = vec![b"devolutions-agent-tunnel".to_vec()]; + + let client_config = quinn::ClientConfig::new(Arc::new( + quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto) + .context("build QuicClientConfig from rustls config")?, + )); + + // -- Transport config -- + + let mut transport = quinn::TransportConfig::default(); + transport.max_idle_timeout(Some( + Duration::from_secs(120).try_into().context("idle timeout conversion")?, + )); + transport.keep_alive_interval(Some(Duration::from_secs(15))); + transport.max_concurrent_bidi_streams(100u32.into()); + + let mut client_config = client_config; + client_config.transport_config(Arc::new(transport)); + + // -- DNS resolve -- + + let gateway_addr = tokio::net::lookup_host(&tunnel_conf.gateway_endpoint) + .await + .context("failed to resolve gateway endpoint")? + .next() + .context("no addresses resolved for gateway endpoint")?; + + info!(gateway_addr = %gateway_addr, "Connecting to gateway"); + + // -- Connect -- + + let mut endpoint = + quinn::Endpoint::client("0.0.0.0:0".parse().context("parse bind address")?).context("create QUIC endpoint")?; + endpoint.set_default_client_config(client_config); + + let connection = endpoint + .connect(gateway_addr, "gateway") + .context("initiate QUIC connection")? + .await + .context("QUIC handshake")?; + info!("QUIC connection established"); + + // -- Open control stream -- + + let (mut ctrl_send, mut ctrl_recv) = connection.open_bi().await.context("open control stream")?; + + // Send initial RouteAdvertise. + let epoch = 1u64; + let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone()); + msg.encode(&mut ctrl_send) + .await + .context("encode initial RouteAdvertise")?; + info!(epoch, "Sent initial RouteAdvertise"); + + // Spawn control stream reader. + tokio::spawn(async move { + let _ = handle_control_recv(&mut ctrl_recv) + .await + .inspect_err(|e| error!(%e, "Control recv stream failed")); + }); + + // -- Main loop: accept incoming session streams + periodic tasks -- + + let route_interval = tunnel_conf.route_advertise_interval_secs.unwrap_or(30); + let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs.unwrap_or(60); + let mut route_tick = tokio::time::interval(Duration::from_secs(route_interval)); + let mut heartbeat_tick = tokio::time::interval(Duration::from_secs(heartbeat_interval_secs)); + // Skip the first immediate tick (we already sent the initial RouteAdvertise). + route_tick.tick().await; + heartbeat_tick.tick().await; + + loop { + tokio::select! { + biased; + + _ = shutdown_signal.wait() => { + info!("Tunnel task shutting down"); + connection.close(0u32.into(), b"shutting down"); + return Ok(()); + } + + _ = route_tick.tick() => { + let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone()); + let _ = msg.encode(&mut ctrl_send).await + .inspect(|_| trace!(epoch, "Sent RouteAdvertise (refresh)")) + .inspect_err(|e| error!(%e, "Failed to send RouteAdvertise")); + } + + _ = heartbeat_tick.tick() => { + let msg = ControlMessage::heartbeat(current_time_millis(), 0); + let _ = msg.encode(&mut ctrl_send).await + .inspect(|_| trace!("Sent Heartbeat")) + .inspect_err(|e| error!(%e, "Failed to send Heartbeat")); + } + + result = connection.accept_bi() => { + let (send, recv) = result.context("accept incoming bidi stream")?; + let subnets = advertise_subnets.clone(); + tokio::spawn(async move { + let _ = handle_session_stream(&subnets, send, recv).await + .inspect_err(|e| error!(%e, "Session stream failed")); + }); + } + } + } +} + +// --------------------------------------------------------------------------- +// Control stream reader +// --------------------------------------------------------------------------- + +async fn handle_control_recv(recv: &mut quinn::RecvStream) -> anyhow::Result<()> { + loop { + let message = ControlMessage::decode(recv).await.context("decode control message")?; + + match message { + ControlMessage::HeartbeatAck { + protocol_version, + timestamp_ms, + } => { + if let Err(e) = agent_tunnel_proto::validate_protocol_version(protocol_version) { + warn!(%protocol_version, %e, "Ignoring HeartbeatAck: unsupported protocol version"); + continue; + } + let rtt = current_time_millis().saturating_sub(timestamp_ms); + debug!(rtt_ms = rtt, "Received HeartbeatAck"); + } + unexpected => { + warn!(message = ?unexpected, "Unexpected control message from gateway"); + } + } + } +} + +// --------------------------------------------------------------------------- +// Session stream handler +// --------------------------------------------------------------------------- + +async fn handle_session_stream( + advertise_subnets: &[Ipv4Network], + mut send: quinn::SendStream, + mut recv: quinn::RecvStream, +) -> anyhow::Result<()> { + // Read ConnectMessage (length-prefixed) directly from the Quinn stream. + let connect_msg = ConnectMessage::decode(&mut recv) + .await + .context("decode ConnectMessage")?; + + info!( + session_id = %connect_msg.session_id, + target = %connect_msg.target, + "Received ConnectMessage" + ); + + if let Err(e) = agent_tunnel_proto::validate_protocol_version(connect_msg.protocol_version) { + warn!( + protocol_version = %connect_msg.protocol_version, + %e, + "Rejecting ConnectMessage: unsupported protocol version" + ); + let response = ConnectResponse::error(format!("unsupported protocol version: {e}")); + response + .encode(&mut send) + .await + .context("send ConnectResponse error for unsupported version")?; + bail!("unsupported protocol version in ConnectMessage"); + } + + // Validate and connect to target. + let candidates = resolve_target_candidates(&connect_msg.target, advertise_subnets).await?; + let (tcp_stream, selected_target) = connect_to_target(&candidates).await?; + info!(target = %selected_target, "TCP connection established"); + + // Send ConnectResponse::Success. + ConnectResponse::success() + .encode(&mut send) + .await + .context("send ConnectResponse")?; + info!("Sent ConnectResponse::Success"); + + // Bidirectional proxy using tokio::io::copy. + let (mut tcp_read, mut tcp_write) = tcp_stream.into_split(); + + let quic_to_tcp = tokio::io::copy(&mut recv, &mut tcp_write); + let tcp_to_quic = tokio::io::copy(&mut tcp_read, &mut send); + + tokio::select! { + r = quic_to_tcp => { + r.inspect_err(|e| debug!(%e, "QUIC->TCP copy ended"))?; + } + r = tcp_to_quic => { + r.inspect_err(|e| debug!(%e, "TCP->QUIC copy ended"))?; + } + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Utilities (no QUIC involvement) +// --------------------------------------------------------------------------- + +async fn resolve_target_candidates(target: &str, advertise_subnets: &[Ipv4Network]) -> anyhow::Result> { + let resolved: Vec = tokio::net::lookup_host(target) + .await + .with_context(|| format!("resolve target {target}"))? + .collect(); + + if resolved.is_empty() { + bail!("no addresses resolved for target {target}"); + } + + let reachable: Vec = resolved + .into_iter() + .filter(|addr| match addr.ip() { + std::net::IpAddr::V4(ipv4) => advertise_subnets.iter().any(|subnet| subnet.contains(ipv4)), + // TODO: Support IPv6. + std::net::IpAddr::V6(_) => false, + }) + .collect(); + + if reachable.is_empty() { + bail!("target {target} is not in advertised subnets"); + } + + Ok(reachable) +} + +async fn connect_to_target(candidates: &[SocketAddr]) -> anyhow::Result<(TcpStream, SocketAddr)> { + let mut last_error = None; + + for candidate in candidates { + match TcpStream::connect(candidate).await { + Ok(stream) => return Ok((stream, *candidate)), + Err(error) => last_error = Some((candidate, error)), + } + } + + let Some((candidate, error)) = last_error else { + bail!("no target candidates available"); + }; + + Err(error).with_context(|| format!("TCP connect failed for {candidate}")) +} + +fn current_time_millis() -> u64 { + let elapsed = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system time should be after unix epoch"); + + u64::try_from(elapsed.as_millis()).expect("millisecond timestamp should fit in u64") +} diff --git a/devolutions-gateway/Cargo.toml b/devolutions-gateway/Cargo.toml index 40226fb8b..c3a9acff3 100644 --- a/devolutions-gateway/Cargo.toml +++ b/devolutions-gateway/Cargo.toml @@ -35,6 +35,7 @@ terminal-streamer.path = "../crates/terminal-streamer" network-monitor.path = "../crates/network-monitor" sysevent.path = "../crates/sysevent" sysevent-codes.path = "../crates/sysevent-codes" +agent-tunnel-proto.path = "../crates/agent-tunnel-proto" ironrdp-pdu = { version = "0.7", features = ["std"] } ironrdp-core = { version = "0.1", features = ["std"] } ironrdp-rdcleanpath = "0.2" @@ -69,6 +70,11 @@ thiserror = "2" typed-builder = "0.21" backoff = "0.4" bitflags = "2.9" +base64 = "0.22" +bincode = "1.3" +ipnetwork = "0.20" +dashmap = "6.1" +rand = "0.8" # Security, crypto… picky = { version = "7.0.0-rc.15", default-features = false, features = ["jose", "x509", "pkcs12", "time_conversion"] } @@ -81,6 +87,15 @@ x509-cert = { version = "0.2", default-features = false, features = ["std"] } sha2 = "0.10" hex = "0.4" rustls-native-certs = "0.8" +pem = "3.0" +rcgen = { version = "0.13", features = ["pem", "x509-parser"] } +x509-parser = "0.16" + +# QUIC (agent tunnel) +quinn = "0.11" +rustls = { version = "0.23", default-features = false, features = ["ring", "logging", "std", "tls12"] } +rustls-pki-types = "1" +rustls-pemfile = "2" # Logging tracing = "0.1" diff --git a/devolutions-gateway/src/agent_tunnel/cert.rs b/devolutions-gateway/src/agent_tunnel/cert.rs new file mode 100644 index 000000000..8fdda3645 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/cert.rs @@ -0,0 +1,433 @@ +//! CA certificate management for the QUIC agent tunnel. +//! +//! Manages a self-signed CA that issues client certificates to agents during enrollment, +//! and a server certificate for the QUIC listener. + +use std::time::Duration; + +use anyhow::Context as _; +use camino::{Utf8Path, Utf8PathBuf}; +use rcgen::{CertificateParams, DnType, ExtendedKeyUsagePurpose, IsCa, KeyPair, KeyUsagePurpose, SanType}; +use sha2::{Digest, Sha256}; +use uuid::Uuid; + +const CA_CERT_FILENAME: &str = "agent-tunnel-ca-cert.pem"; +const CA_KEY_FILENAME: &str = "agent-tunnel-ca-key.pem"; +const SERVER_CERT_FILENAME: &str = "agent-tunnel-server-cert.pem"; +const SERVER_KEY_FILENAME: &str = "agent-tunnel-server-key.pem"; +const CA_VALIDITY_DAYS: u32 = 3650; // ~10 years +const SERVER_CERT_VALIDITY_DAYS: u32 = 365; // 1 year +const AGENT_CERT_VALIDITY_DAYS: u32 = 365; // 1 year + +const CA_COMMON_NAME: &str = "Devolutions Gateway Agent Tunnel CA"; +const CA_ORG_NAME: &str = "Devolutions Inc."; + +/// Build the standard CA `CertificateParams` (same DN every time so that +/// reconstructed certificates match the on-disk CA for chain validation). +fn make_ca_params() -> CertificateParams { + let mut params = CertificateParams::default(); + params.distinguished_name.push(DnType::CommonName, CA_COMMON_NAME); + params.distinguished_name.push(DnType::OrganizationName, CA_ORG_NAME); + params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + params.key_usages.push(KeyUsagePurpose::KeyCertSign); + params.key_usages.push(KeyUsagePurpose::CrlSign); + params.not_before = time::OffsetDateTime::now_utc(); + params.not_after = time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(CA_VALIDITY_DAYS) * 86400); + params +} + +/// Manages the CA used to sign agent client certificates and the QUIC server certificate. +pub struct CaManager { + ca_cert_pem: String, + ca_key_pair: KeyPair, + data_dir: Utf8PathBuf, +} + +/// Bundle returned to a newly enrolled agent (private key never leaves the agent). +pub struct SignedAgentCert { + pub client_cert_pem: String, + pub ca_cert_pem: String, +} + +impl CaManager { + /// Load an existing CA from disk, or generate a new one. + pub fn load_or_generate(data_dir: &Utf8Path) -> anyhow::Result { + let cert_path = data_dir.join(CA_CERT_FILENAME); + let key_path = data_dir.join(CA_KEY_FILENAME); + + if cert_path.exists() && key_path.exists() { + info!(%cert_path, "Loading existing agent tunnel CA"); + let ca_cert_pem = + std::fs::read_to_string(&cert_path).with_context(|| format!("read CA cert from {cert_path}"))?; + let ca_key_pem = + std::fs::read_to_string(&key_path).with_context(|| format!("read CA key from {key_path}"))?; + let ca_key_pair = KeyPair::from_pem(&ca_key_pem).context("parse CA key pair from PEM")?; + Ok(Self { + ca_cert_pem, + ca_key_pair, + data_dir: data_dir.to_owned(), + }) + } else { + info!("Generating new agent tunnel CA certificate"); + let ca_key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).context("generate CA key pair")?; + + let ca_params = make_ca_params(); + let ca_cert = ca_params + .self_signed(&ca_key_pair) + .context("self-sign CA certificate")?; + let ca_cert_pem = ca_cert.pem(); + + std::fs::create_dir_all(data_dir).with_context(|| format!("create data directory {data_dir}"))?; + std::fs::write(&cert_path, &ca_cert_pem).with_context(|| format!("write CA cert to {cert_path}"))?; + std::fs::write(&key_path, ca_key_pair.serialize_pem()) + .with_context(|| format!("write CA key to {key_path}"))?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600)) + .with_context(|| format!("set permissions on {key_path}"))?; + } + + info!(%cert_path, "Agent tunnel CA certificate generated and saved"); + + Ok(Self { + ca_cert_pem, + ca_key_pair, + data_dir: data_dir.to_owned(), + }) + } + } + + /// Reconstruct a `Certificate` object from the stored key pair. + /// + /// The reconstructed cert uses the same DN as the original CA, so the + /// issuer field in signed certificates will match the on-disk CA cert. + fn reconstruct_ca_cert(&self) -> anyhow::Result { + make_ca_params() + .self_signed(&self.ca_key_pair) + .context("reconstruct CA certificate for signing") + } + + /// Sign an agent's CSR, producing a client certificate. + /// + /// The agent generates its own key pair and sends only the CSR. + /// The private key never leaves the agent. + pub fn sign_agent_csr(&self, agent_id: Uuid, agent_name: &str, csr_pem: &str) -> anyhow::Result { + // Parse and verify the CSR (signature check included). + let csr_params = rcgen::CertificateSigningRequestParams::from_pem(csr_pem) + .map_err(|e| anyhow::anyhow!("invalid CSR: {e}"))?; + + // Build our own cert params (we control CN, SAN, EKU, validity — not the CSR). + let mut agent_params = CertificateParams::default(); + agent_params.distinguished_name.push(DnType::CommonName, agent_name); + agent_params + .distinguished_name + .push(DnType::OrganizationName, CA_ORG_NAME); + agent_params.subject_alt_names.push(SanType::Rfc822Name( + format!("urn:uuid:{agent_id}").try_into().context("SAN URI")?, + )); + agent_params + .extended_key_usages + .push(ExtendedKeyUsagePurpose::ClientAuth); + agent_params.not_before = time::OffsetDateTime::now_utc(); + agent_params.not_after = + time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(AGENT_CERT_VALIDITY_DAYS) * 86400); + + // Sign with the CA, embedding the public key from the CSR. + let ca_cert = self.reconstruct_ca_cert()?; + let agent_cert = agent_params + .signed_by(&csr_params.public_key, &ca_cert, &self.ca_key_pair) + .context("sign agent certificate with CA")?; + + info!(%agent_id, %agent_name, "Signed agent CSR and issued client certificate"); + + Ok(SignedAgentCert { + client_cert_pem: agent_cert.pem(), + ca_cert_pem: self.ca_cert_pem.clone(), + }) + } + + /// Ensure a server certificate exists for the QUIC listener (signed by our CA). + /// + /// Returns `(cert_path, key_path)` on disk. + pub fn ensure_server_cert(&self, hostname: &str) -> anyhow::Result<(Utf8PathBuf, Utf8PathBuf)> { + let cert_path = self.data_dir.join(SERVER_CERT_FILENAME); + let key_path = self.data_dir.join(SERVER_KEY_FILENAME); + + if cert_path.exists() && key_path.exists() { + // TODO: check cert expiry and regenerate if near/past expiration (365-day validity). + info!(%cert_path, "Using existing agent tunnel server certificate"); + return Ok((cert_path, key_path)); + } + + info!(%hostname, "Generating agent tunnel server certificate"); + + let server_key_pair = + KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).context("generate server key pair")?; + + let mut server_params = CertificateParams::default(); + server_params.distinguished_name.push(DnType::CommonName, hostname); + server_params + .distinguished_name + .push(DnType::OrganizationName, CA_ORG_NAME); + server_params + .subject_alt_names + .push(SanType::DnsName(hostname.try_into().context("DNS SAN")?)); + server_params + .extended_key_usages + .push(ExtendedKeyUsagePurpose::ServerAuth); + server_params.not_before = time::OffsetDateTime::now_utc(); + server_params.not_after = + time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(SERVER_CERT_VALIDITY_DAYS) * 86400); + + let ca_cert = self.reconstruct_ca_cert()?; + + let server_cert = server_params + .signed_by(&server_key_pair, &ca_cert, &self.ca_key_pair) + .context("sign server certificate with CA")?; + + std::fs::write(&cert_path, server_cert.pem()).with_context(|| format!("write server cert to {cert_path}"))?; + std::fs::write(&key_path, server_key_pair.serialize_pem()) + .with_context(|| format!("write server key to {key_path}"))?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600)) + .with_context(|| format!("set permissions on {key_path}"))?; + } + + info!(%cert_path, %hostname, "Server certificate generated and saved"); + + Ok((cert_path, key_path)) + } + + /// Get the CA certificate in PEM format. + pub fn ca_cert_pem(&self) -> &str { + &self.ca_cert_pem + } + + /// Get the CA certificate file path on disk. + pub fn ca_cert_path(&self) -> Utf8PathBuf { + self.data_dir.join(CA_CERT_FILENAME) + } + + /// Build a `rustls::ServerConfig` for the QUIC listener with mTLS client verification. + /// + /// The server certificate is signed by our CA; clients must present a certificate + /// also signed by our CA (mutual TLS). + pub fn build_server_tls_config(&self, hostname: &str) -> anyhow::Result { + use std::io::BufReader; + + use rustls::pki_types::{CertificateDer, PrivateKeyDer}; + + // Ensure rustls crypto provider is installed (ring). + let _ = rustls::crypto::ring::default_provider().install_default(); + + let (cert_path, key_path) = self.ensure_server_cert(hostname)?; + + // Load server certificate chain (server cert + CA cert). + let cert_file = std::fs::File::open(cert_path.as_std_path()) + .with_context(|| format!("open server cert file {cert_path}"))?; + let server_certs: Vec> = rustls_pemfile::certs(&mut BufReader::new(cert_file)) + .collect::, _>>() + .context("parse server certificate PEM")?; + + // Also include CA cert in chain. + let ca_cert_path = self.ca_cert_path(); + let ca_file = std::fs::File::open(ca_cert_path.as_std_path()) + .with_context(|| format!("open CA cert file {ca_cert_path}"))?; + let ca_certs: Vec> = rustls_pemfile::certs(&mut BufReader::new(ca_file)) + .collect::, _>>() + .context("parse CA certificate PEM")?; + + let mut cert_chain = server_certs; + cert_chain.extend(ca_certs.clone()); + + // Load server private key. + let key_file = + std::fs::File::open(key_path.as_std_path()).with_context(|| format!("open server key file {key_path}"))?; + let private_key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut BufReader::new(key_file)) + .context("parse server private key PEM")? + .context("no private key found in PEM file")?; + + // Build root cert store with our CA for client verification. + let mut roots = rustls::RootCertStore::empty(); + for ca_cert in &ca_certs { + roots.add(ca_cert.clone()).context("add CA cert to root store")?; + } + + // Require client certificates signed by our CA. + let client_verifier = rustls::server::WebPkiClientVerifier::builder(roots.into()) + .build() + .context("build client certificate verifier")?; + + let mut tls_config = rustls::ServerConfig::builder() + .with_client_cert_verifier(client_verifier) + .with_single_cert(cert_chain, private_key) + .context("build rustls ServerConfig")?; + + tls_config.alpn_protocols = vec![b"devolutions-agent-tunnel".to_vec()]; + + Ok(tls_config) + } +} + +/// Compute SHA-256 fingerprint of a PEM-encoded certificate (hex string). +pub fn cert_fingerprint_from_pem(pem_str: &str) -> anyhow::Result { + let pem = pem::parse(pem_str).context("parse PEM for fingerprint")?; + let digest = Sha256::digest(pem.contents()); + Ok(hex::encode(digest)) +} + +/// Compute SHA-256 fingerprint of a DER-encoded certificate (hex string). +pub fn cert_fingerprint_from_der(der_bytes: &[u8]) -> String { + let digest = Sha256::digest(der_bytes); + hex::encode(digest) +} + +/// Extract agent_id from a PEM-encoded certificate's SAN (urn:uuid:{id}). +pub fn extract_agent_id_from_pem(pem_str: &str) -> anyhow::Result { + let pem = pem::parse(pem_str).context("parse PEM for agent ID extraction")?; + extract_agent_id_from_der(pem.contents()) +} + +/// Extract the Common Name (CN) from a DER-encoded certificate. +pub fn extract_agent_name_from_der(cert_der: &[u8]) -> anyhow::Result { + let (_, cert) = + x509_parser::parse_x509_certificate(cert_der).map_err(|e| anyhow::anyhow!("parse certificate: {e}"))?; + + for attr in cert.subject().iter_common_name() { + if let Ok(cn) = attr.as_str() { + return Ok(cn.to_owned()); + } + } + + anyhow::bail!("no Common Name found in certificate") +} + +/// Extract agent_id from a DER-encoded certificate's SAN (urn:uuid:{id}). +pub fn extract_agent_id_from_der(der_bytes: &[u8]) -> anyhow::Result { + let (_, cert) = x509_parser::parse_x509_certificate(der_bytes).context("parse X.509 certificate")?; + + for ext in cert.extensions() { + if let x509_parser::extensions::ParsedExtension::SubjectAlternativeName(san) = ext.parsed_extension() { + for name in &san.general_names { + if let x509_parser::extensions::GeneralName::RFC822Name(val) = name + && let Some(uuid_str) = val.strip_prefix("urn:uuid:") + { + return uuid_str.parse().context("parse UUID from SAN"); + } + } + } + } + + anyhow::bail!("no urn:uuid: SAN found in certificate") +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper: generate a CSR PEM for testing. + fn generate_test_csr(cn: &str) -> (KeyPair, String) { + let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).expect("generate key pair"); + let mut params = CertificateParams::default(); + params.distinguished_name.push(DnType::CommonName, cn); + let csr = params.serialize_request(&key_pair).expect("serialize CSR"); + (key_pair, csr.pem().expect("CSR to PEM")) + } + + #[test] + fn generate_ca_and_sign_agent_csr() { + let temp_dir = std::env::temp_dir().join(format!("dgw-cert-test-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("temp path should be UTF-8"); + + let ca = CaManager::load_or_generate(&data_dir).expect("CA generation should succeed"); + assert!(ca.ca_cert_pem().contains("BEGIN CERTIFICATE")); + + let agent_id = Uuid::new_v4(); + let (_key_pair, csr_pem) = generate_test_csr("test-agent"); + let signed = ca + .sign_agent_csr(agent_id, "test-agent", &csr_pem) + .expect("sign CSR should succeed"); + + assert!(signed.client_cert_pem.contains("BEGIN CERTIFICATE")); + assert_eq!(signed.ca_cert_pem, ca.ca_cert_pem()); + + // Reload CA from disk. + let ca2 = CaManager::load_or_generate(&data_dir).expect("CA reload should succeed"); + assert_eq!(ca2.ca_cert_pem(), ca.ca_cert_pem()); + + // Sign a CSR from the reloaded CA and verify it works. + let (_key_pair2, csr_pem2) = generate_test_csr("test-agent-2"); + let signed2 = ca2 + .sign_agent_csr(Uuid::new_v4(), "test-agent-2", &csr_pem2) + .expect("sign CSR from reloaded CA should succeed"); + assert!(signed2.client_cert_pem.contains("BEGIN CERTIFICATE")); + + // Fingerprint. + let fp = cert_fingerprint_from_pem(&signed.client_cert_pem).expect("fingerprint should work"); + assert_eq!(fp.len(), 64); // SHA-256 hex = 64 chars + + // Extract agent_id from PEM. + let extracted_id = extract_agent_id_from_pem(&signed.client_cert_pem).expect("agent ID extraction should work"); + assert_eq!(extracted_id, agent_id); + + // Server certificate. + let (server_cert_path, server_key_path) = ca + .ensure_server_cert("test-gateway.local") + .expect("server cert should succeed"); + assert!(server_cert_path.exists()); + assert!(server_key_path.exists()); + + // Cleanup. + let _ = std::fs::remove_dir_all(&temp_dir); + } + + #[test] + fn sign_csr_produces_valid_cert() { + let temp_dir = std::env::temp_dir().join(format!("dgw-cert-test-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("temp path should be UTF-8"); + let ca = CaManager::load_or_generate(&data_dir).expect("CA generation should succeed"); + + let agent_id = Uuid::new_v4(); + let (_key_pair, csr_pem) = generate_test_csr("csr-test-agent"); + + let signed = ca + .sign_agent_csr(agent_id, "csr-test-agent", &csr_pem) + .expect("sign CSR should succeed"); + + assert!(signed.client_cert_pem.contains("BEGIN CERTIFICATE")); + + // Verify the cert contains the agent UUID in SAN. + let extracted_id = + extract_agent_id_from_pem(&signed.client_cert_pem).expect("should extract agent ID from signed cert"); + assert_eq!(extracted_id, agent_id); + + let _ = std::fs::remove_dir_all(&temp_dir); + } + + #[test] + fn sign_csr_rejects_tampered_csr() { + let temp_dir = std::env::temp_dir().join(format!("dgw-cert-test-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("temp path should be UTF-8"); + let ca = CaManager::load_or_generate(&data_dir).expect("CA generation should succeed"); + + let (_key_pair, csr_pem) = generate_test_csr("tampered-agent"); + + // Decode PEM, flip a byte in the DER, re-encode. + let parsed = pem::parse(&csr_pem).expect("parse CSR PEM"); + let mut der_bytes = parsed.contents().to_vec(); + // Flip a byte near the end (in the signature area). + let len = der_bytes.len(); + der_bytes[len - 2] ^= 0xFF; + let tampered_pem = pem::encode(&pem::Pem::new("CERTIFICATE REQUEST", der_bytes)); + + let result = ca.sign_agent_csr(Uuid::new_v4(), "tampered-agent", &tampered_pem); + assert!(result.is_err(), "tampered CSR should be rejected"); + + let _ = std::fs::remove_dir_all(&temp_dir); + } +} diff --git a/devolutions-gateway/src/agent_tunnel/enrollment_store.rs b/devolutions-gateway/src/agent_tunnel/enrollment_store.rs new file mode 100644 index 000000000..6a27b50e2 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/enrollment_store.rs @@ -0,0 +1,126 @@ +//! In-memory store for one-time enrollment tokens. +//! +//! Tokens are generated by the webapp enrollment-string endpoint and consumed +//! by the agent enrollment endpoint. Each token is single-use and has an expiry. + +use std::time::{SystemTime, UNIX_EPOCH}; + +use dashmap::DashMap; + +/// Default token lifetime if not specified: 1 hour. +const DEFAULT_TOKEN_LIFETIME_SECS: u64 = 3600; + +/// A single enrollment token entry. +#[derive(Debug, Clone)] +pub struct EnrollmentTokenEntry { + /// When this token expires (UNIX timestamp in seconds). + pub expires_at: u64, + /// Optional agent name hint associated with this token. + pub agent_name: Option, +} + +/// Thread-safe in-memory store for one-time enrollment tokens. +/// +/// Tokens are stored in a `DashMap` keyed by the token string. +/// They are removed on consumption (one-time use) or on explicit cleanup. +#[derive(Debug)] +pub struct EnrollmentTokenStore { + tokens: DashMap, +} + +impl EnrollmentTokenStore { + /// Creates a new, empty token store. + pub fn new() -> Self { + Self { tokens: DashMap::new() } + } + + /// Inserts a new enrollment token. + /// + /// Also cleans up any expired tokens to prevent unbounded growth. + pub fn insert(&self, token: String, agent_name: Option, lifetime_secs: Option) { + self.cleanup_expired(); + + let lifetime = lifetime_secs.unwrap_or(DEFAULT_TOKEN_LIFETIME_SECS); + let now = current_time_secs(); + let expires_at = now + lifetime; + + self.tokens + .insert(token, EnrollmentTokenEntry { expires_at, agent_name }); + } + + /// Consumes a token if it exists and is not expired. + /// + /// Returns `true` if the token was valid and has been consumed (removed). + /// Returns `false` if the token doesn't exist or is expired. + pub fn consume(&self, token: &str) -> bool { + let now = current_time_secs(); + + if let Some((_, entry)) = self.tokens.remove(token) + && entry.expires_at > now + { + return true; + } + + false + } + + /// Removes all expired tokens from the store. + pub fn cleanup_expired(&self) { + let now = current_time_secs(); + self.tokens.retain(|_, entry| entry.expires_at > now); + } +} + +impl Default for EnrollmentTokenStore { + fn default() -> Self { + Self::new() + } +} + +fn current_time_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn insert_and_consume() { + let store = EnrollmentTokenStore::new(); + store.insert("tok-123".to_owned(), Some("my-agent".to_owned()), Some(3600)); + + assert!(store.consume("tok-123")); + // Second consume should fail (one-time use). + assert!(!store.consume("tok-123")); + } + + #[test] + fn consume_nonexistent_returns_false() { + let store = EnrollmentTokenStore::new(); + assert!(!store.consume("does-not-exist")); + } + + #[test] + fn expired_token_not_consumable() { + let store = EnrollmentTokenStore::new(); + // Insert with 0 lifetime → already expired. + store.insert("expired-tok".to_owned(), None, Some(0)); + assert!(!store.consume("expired-tok")); + } + + #[test] + fn cleanup_removes_expired() { + let store = EnrollmentTokenStore::new(); + store.insert("expired".to_owned(), None, Some(0)); + store.insert("valid".to_owned(), None, Some(3600)); + + store.cleanup_expired(); + + assert!(!store.consume("expired")); + assert!(store.consume("valid")); + } +} diff --git a/devolutions-gateway/src/agent_tunnel/listener.rs b/devolutions-gateway/src/agent_tunnel/listener.rs new file mode 100644 index 000000000..eea89ee77 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/listener.rs @@ -0,0 +1,336 @@ +//! QUIC listener for agent tunnel connections (Quinn-based). +//! +//! Manages a QUIC endpoint using Quinn, accepts connections from agents with mTLS, +//! processes control messages (route advertisements, heartbeats), and +//! creates proxy streams on demand. + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use agent_tunnel_proto::{ConnectMessage, ConnectResponse, ControlMessage}; +use anyhow::Context as _; +use async_trait::async_trait; +use dashmap::DashMap; +use uuid::Uuid; + +use super::cert::CaManager; +use super::enrollment_store::EnrollmentTokenStore; +use super::registry::{AgentPeer, AgentRegistry}; +use super::stream::TunnelStream; + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Handle for external code to interact with the running agent tunnel. +/// +/// Cloneable and safe to share across tasks. +#[derive(Clone)] +pub struct AgentTunnelHandle { + registry: Arc, + /// Map of agent_id → live Quinn connection, used for opening new streams. + agent_connections: Arc>, + ca_manager: Arc, + enrollment_token_store: Arc, +} + +impl AgentTunnelHandle { + pub fn registry(&self) -> &AgentRegistry { + &self.registry + } + + pub fn ca_manager(&self) -> &CaManager { + &self.ca_manager + } + + pub fn enrollment_token_store(&self) -> &EnrollmentTokenStore { + &self.enrollment_token_store + } + + /// Open a proxy stream through a connected agent. + pub async fn connect_via_agent( + &self, + agent_id: Uuid, + session_id: Uuid, + target: &str, + ) -> anyhow::Result { + let conn = self + .agent_connections + .get(&agent_id) + .map(|entry| entry.value().clone()) + .ok_or_else(|| anyhow::anyhow!("agent {} not connected", agent_id))?; + + let (mut send, mut recv) = conn.open_bi().await.context("open bidirectional stream to agent")?; + + // Send ConnectMessage. + let connect_msg = ConnectMessage::new(session_id, target.to_owned()); + connect_msg + .encode(&mut send) + .await + .map_err(|e| anyhow::anyhow!("encode ConnectMessage: {e}"))?; + + // Read ConnectResponse. + let response = ConnectResponse::decode(&mut recv) + .await + .map_err(|e| anyhow::anyhow!("decode ConnectResponse: {e}"))?; + + if !response.is_success() { + let reason = match &response { + ConnectResponse::Error { reason, .. } => reason.clone(), + _ => "unknown".to_owned(), + }; + anyhow::bail!("agent refused connection: {reason}"); + } + + info!( + %agent_id, + %session_id, + %target, + "Proxy stream established via agent tunnel" + ); + + Ok(TunnelStream { send, recv }) + } +} + +// --------------------------------------------------------------------------- +// Listener task +// --------------------------------------------------------------------------- + +pub struct AgentTunnelListener { + endpoint: quinn::Endpoint, + registry: Arc, + agent_connections: Arc>, +} + +impl AgentTunnelListener { + pub async fn bind( + listen_addr: SocketAddr, + ca_manager: Arc, + hostname: &str, + ) -> anyhow::Result<(Self, AgentTunnelHandle)> { + let tls_config = ca_manager + .build_server_tls_config(hostname) + .context("build server TLS config")?; + + let quic_server_config = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config)) + .context("create QUIC server config from TLS config")?; + + let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_server_config)); + + // Configure transport parameters. + let mut transport = quinn::TransportConfig::default(); + transport.max_idle_timeout(Some( + Duration::from_secs(120) + .try_into() + .expect("120s should be a valid idle timeout"), + )); + transport.keep_alive_interval(Some(Duration::from_secs(15))); + transport.max_concurrent_bidi_streams(100u32.into()); + server_config.transport_config(Arc::new(transport)); + + let endpoint = quinn::Endpoint::server(server_config, listen_addr) + .with_context(|| format!("bind QUIC endpoint on {listen_addr}"))?; + + info!(%listen_addr, "Agent tunnel QUIC endpoint bound"); + + let registry = Arc::new(AgentRegistry::new()); + let agent_connections: Arc> = Arc::new(DashMap::new()); + let enrollment_token_store = Arc::new(EnrollmentTokenStore::new()); + + let handle = AgentTunnelHandle { + registry: Arc::clone(®istry), + agent_connections: Arc::clone(&agent_connections), + ca_manager, + enrollment_token_store, + }; + + let listener = Self { + endpoint, + registry, + agent_connections, + }; + + Ok((listener, handle)) + } +} + +#[async_trait] +impl devolutions_gateway_task::Task for AgentTunnelListener { + type Output = anyhow::Result<()>; + const NAME: &'static str = "agent-tunnel-listener"; + + async fn run(self, mut shutdown_signal: devolutions_gateway_task::ShutdownSignal) -> anyhow::Result<()> { + let local_addr = self.endpoint.local_addr()?; + info!(%local_addr, "Agent tunnel listener started"); + + loop { + tokio::select! { + biased; + + _ = shutdown_signal.wait() => { + info!("Agent tunnel listener shutting down"); + self.endpoint.close(0u32.into(), b"shutdown"); + break; + } + + incoming = self.endpoint.accept() => { + let Some(incoming) = incoming else { + info!("QUIC endpoint closed"); + break; + }; + + let registry = Arc::clone(&self.registry); + let agent_connections = Arc::clone(&self.agent_connections); + + tokio::spawn(async move { + if let Err(e) = handle_agent_connection(registry, agent_connections, incoming).await { + warn!(error = format!("{e:#}"), "Agent connection handler failed"); + } + }); + } + } + } + + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Per-connection handler +// --------------------------------------------------------------------------- + +async fn handle_agent_connection( + registry: Arc, + agent_connections: Arc>, + incoming: quinn::Incoming, +) -> anyhow::Result<()> { + let peer_addr = incoming.remote_address(); + info!(%peer_addr, "Accepting new QUIC connection"); + + let conn = incoming.await.context("QUIC handshake failed")?; + + // Extract peer certificate to identify the agent. + let peer_identity = conn.peer_identity().context("no peer identity after handshake")?; + + let peer_certs = peer_identity + .downcast::>>() + .map_err(|_| anyhow::anyhow!("unexpected peer identity type"))?; + + let peer_cert_der = peer_certs.first().context("no peer certificate in chain")?; + + let agent_id = + super::cert::extract_agent_id_from_der(peer_cert_der).context("extract agent_id from peer certificate")?; + + let agent_name = + super::cert::extract_agent_name_from_der(peer_cert_der).unwrap_or_else(|_| format!("agent-{agent_id}")); + + let fingerprint = super::cert::cert_fingerprint_from_der(peer_cert_der); + + info!(%agent_id, %agent_name, %peer_addr, "Agent authenticated via mTLS"); + + let peer = Arc::new(AgentPeer::new(agent_id, agent_name, fingerprint)); + registry.register(Arc::clone(&peer)); + agent_connections.insert(agent_id, conn.clone()); + + // Accept the first bidirectional stream as the control stream. + let control_result = handle_control_stream(&conn, agent_id, ®istry).await; + + // Agent disconnected — clean up. + info!(%agent_id, "Agent QUIC connection closed"); + registry.unregister(&agent_id); + agent_connections.remove(&agent_id); + + control_result +} + +async fn handle_control_stream( + conn: &quinn::Connection, + agent_id: Uuid, + registry: &AgentRegistry, +) -> anyhow::Result<()> { + let (mut control_send, mut control_recv) = conn.accept_bi().await.context("accept control stream")?; + + info!(%agent_id, "Control stream accepted"); + + loop { + tokio::select! { + // Read control messages from the agent. + msg_result = ControlMessage::decode(&mut control_recv) => { + let msg = match msg_result { + Ok(msg) => msg, + Err(agent_tunnel_proto::ProtoError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + debug!(%agent_id, "Control stream EOF"); + break; + } + Err(e) => { + warn!(%agent_id, error = %e, "Control stream decode error"); + break; + } + }; + + handle_control_message(registry, agent_id, &mut control_send, msg).await; + } + + // Detect connection close. + reason = conn.closed() => { + debug!(%agent_id, ?reason, "QUIC connection closed"); + break; + } + } + } + + Ok(()) +} + +async fn handle_control_message( + registry: &AgentRegistry, + agent_id: Uuid, + control_send: &mut quinn::SendStream, + msg: ControlMessage, +) { + match msg { + ControlMessage::RouteAdvertise { + protocol_version, + epoch, + subnets, + domains, + .. + } => { + if let Err(e) = agent_tunnel_proto::validate_protocol_version(protocol_version) { + warn!(%agent_id, %protocol_version, %e, "Rejecting route advertisement: unsupported protocol version"); + return; + } + info!( + %agent_id, + epoch, + subnet_count = subnets.len(), + domain_count = domains.len(), + "Received route advertisement" + ); + if let Some(peer) = registry.get(&agent_id) { + peer.update_routes(epoch, subnets, domains); + peer.touch(); + } + } + ControlMessage::Heartbeat { + timestamp_ms, + active_stream_count, + .. + } => { + debug!(%agent_id, timestamp_ms, active_stream_count, "Received heartbeat"); + if let Some(peer) = registry.get(&agent_id) { + peer.touch(); + } + + let ack = ControlMessage::heartbeat_ack(timestamp_ms); + if let Err(e) = ack.encode(control_send).await { + warn!(%agent_id, error = %e, "Failed to send heartbeat ack"); + } + } + ControlMessage::HeartbeatAck { .. } => { + debug!(%agent_id, "Unexpected HeartbeatAck from agent"); + } + } +} diff --git a/devolutions-gateway/src/agent_tunnel/mod.rs b/devolutions-gateway/src/agent_tunnel/mod.rs new file mode 100644 index 000000000..aa4b094eb --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/mod.rs @@ -0,0 +1,15 @@ +//! QUIC-based agent tunnel (Quinn). +//! +//! Provides a reliable, multiplexed tunnel between the gateway and remote agents +//! using QUIC with mutual TLS authentication. + +pub mod cert; +pub mod enrollment_store; +pub mod listener; +pub mod registry; +pub mod stream; + +pub use enrollment_store::EnrollmentTokenStore; +pub use listener::{AgentTunnelHandle, AgentTunnelListener}; +pub use registry::AgentRegistry; +pub use stream::TunnelStream; diff --git a/devolutions-gateway/src/agent_tunnel/registry.rs b/devolutions-gateway/src/agent_tunnel/registry.rs new file mode 100644 index 000000000..439e183fc --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/registry.rs @@ -0,0 +1,773 @@ +use std::net::IpAddr; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use agent_tunnel_proto::DomainAdvertisement; +use dashmap::DashMap; +use ipnetwork::Ipv4Network; +use parking_lot::RwLock; +use serde::Serialize; +use uuid::Uuid; + +/// Duration after which an agent is considered offline if no heartbeat has been received. +pub const AGENT_OFFLINE_TIMEOUT: Duration = Duration::from_secs(90); + +/// Tracks route advertisements received from an agent. +/// +/// The epoch-based update protocol works as follows: +/// - A higher epoch replaces the entire route set (new process or config reload). +/// - The same epoch only refreshes `updated_at` (periodic re-advertisement). +#[derive(Debug, Clone)] +pub struct RouteAdvertisementState { + /// Monotonically increasing epoch within an agent process lifetime. + pub epoch: u64, + /// IPv4 subnets this agent can reach. + pub subnets: Vec, + /// DNS domains this agent can resolve, with source tracking. + pub domains: Vec, + /// When this route set was first received (used for tie-breaking). + pub received_at: SystemTime, + /// Last time this route set was refreshed. + pub updated_at: SystemTime, +} + +/// Represents a QUIC-connected agent peer tracked by the gateway. +#[derive(Debug)] +pub struct AgentPeer { + /// Unique identifier for this agent. + pub agent_id: Uuid, + /// Human-readable name of the agent. + pub name: String, + /// SHA-256 fingerprint of the agent's client certificate. + pub cert_fingerprint: String, + /// Last heartbeat timestamp in milliseconds since UNIX epoch (updated atomically). + pub(crate) last_seen: AtomicU64, + /// Current route advertisement state, if any. + route_state: RwLock>, +} + +impl AgentPeer { + /// Creates a new agent peer with the current time as last_seen. + pub fn new(agent_id: Uuid, name: String, cert_fingerprint: String) -> Self { + let now_ms = current_time_millis(); + Self { + agent_id, + name, + cert_fingerprint, + last_seen: AtomicU64::new(now_ms), + route_state: RwLock::new(None), + } + } + + /// Updates the last-seen timestamp to the current time. + pub fn touch(&self) { + let now_ms = current_time_millis(); + self.last_seen.store(now_ms, Ordering::Release); + } + + /// Returns the last-seen timestamp as milliseconds since UNIX epoch. + pub fn last_seen_ms(&self) -> u64 { + self.last_seen.load(Ordering::Acquire) + } + + /// Checks whether this agent is considered online. + /// + /// An agent is online if the elapsed time since `last_seen` is less than `timeout`. + pub fn is_online(&self, timeout: Duration) -> bool { + let last_ms = self.last_seen.load(Ordering::Acquire); + let now_ms = current_time_millis(); + // Saturating subtraction handles clock skew gracefully. + let elapsed_ms = now_ms.saturating_sub(last_ms); + elapsed_ms < u64::try_from(timeout.as_millis()).expect("timeout in milliseconds should fit in u64") + } + + /// Returns a clone of the current route advertisement state, if any. + pub fn route_state(&self) -> Option { + self.route_state.read().clone() + } + + /// Updates the route advertisement state using epoch-based logic. + /// + /// - If `epoch` is greater than the current epoch, the route set is replaced entirely + /// and both `received_at` and `updated_at` are set to now. + /// - If `epoch` equals the current epoch, only `updated_at` is refreshed (re-advertisement). + /// - If `epoch` is less than the current epoch, the update is ignored (stale). + pub fn update_routes(&self, epoch: u64, subnets: Vec, domains: Vec) { + let mut state = self.route_state.write(); + let now = SystemTime::now(); + + match state.as_ref() { + Some(current) if epoch < current.epoch => { + // Stale epoch; ignore. + debug!( + agent_id = %self.agent_id, + received_epoch = epoch, + current_epoch = current.epoch, + "Ignoring stale route advertisement" + ); + } + Some(current) if epoch == current.epoch => { + // Same epoch: refresh timestamp only, do not replace subnets or domains. + debug!( + agent_id = %self.agent_id, + epoch, + subnet_count = subnets.len(), + domain_count = current.domains.len(), + "Refreshing route advertisement (same epoch)" + ); + *state = Some(RouteAdvertisementState { + epoch, + subnets: current.subnets.clone(), + domains: current.domains.clone(), + received_at: current.received_at, + updated_at: now, + }); + } + _ => { + // New epoch (or first advertisement): replace everything. + info!( + agent_id = %self.agent_id, + epoch, + subnet_count = subnets.len(), + domain_count = domains.len(), + "Accepted new route advertisement" + ); + *state = Some(RouteAdvertisementState { + epoch, + subnets, + domains, + received_at: now, + updated_at: now, + }); + } + } + } + + /// Returns `true` if this agent can route traffic to the given IP address. + pub fn can_reach(&self, target_ip: IpAddr) -> bool { + self.route_state + .read() + .as_ref() + .map(|route_state| match target_ip { + IpAddr::V4(ipv4) => route_state.subnets.iter().any(|subnet| subnet.contains(ipv4)), + IpAddr::V6(_) => false, + }) + .unwrap_or(false) + } +} + +/// Thread-safe registry of online QUIC-connected agents. +/// +/// Agents are indexed by their `Uuid`. The registry supports concurrent reads and writes +/// through `DashMap`, and provides route-based agent lookup for proxy target resolution. +#[derive(Debug, Clone)] +pub struct AgentRegistry { + agents: Arc>>, +} + +impl AgentRegistry { + /// Creates a new, empty agent registry. + pub fn new() -> Self { + Self { + agents: Arc::new(DashMap::new()), + } + } + + /// Registers a new agent peer. If an agent with the same ID already exists, it is replaced. + pub fn register(&self, peer: Arc) { + info!( + agent_id = %peer.agent_id, + name = %peer.name, + "Agent registered" + ); + self.agents.insert(peer.agent_id, peer); + } + + /// Removes an agent from the registry by ID. + pub fn unregister(&self, agent_id: &Uuid) -> Option> { + let removed = self.agents.remove(agent_id).map(|(_, peer)| peer); + if let Some(ref peer) = removed { + info!( + agent_id = %peer.agent_id, + name = %peer.name, + "Agent unregistered" + ); + } + removed + } + + /// Looks up an agent by ID. + pub fn get(&self, agent_id: &Uuid) -> Option> { + self.agents.get(agent_id).map(|entry| Arc::clone(entry.value())) + } + + /// Returns the number of agents currently in the registry (including offline ones). + pub fn len(&self) -> usize { + self.agents.len() + } + + /// Returns `true` when no agent is registered. + pub fn is_empty(&self) -> bool { + self.agents.is_empty() + } + + /// Returns the number of agents considered online. + pub fn online_count(&self) -> usize { + self.agents + .iter() + .filter(|entry| entry.value().is_online(AGENT_OFFLINE_TIMEOUT)) + .count() + } + + /// Finds all online agents whose advertised subnets include the given target IP. + /// + /// Results are sorted by `received_at` in descending order (most recently received first). + pub fn find_agents_for_target(&self, target_ip: IpAddr) -> Vec> { + let mut candidates: Vec<(SystemTime, Arc)> = self + .agents + .iter() + .filter(|entry| entry.value().is_online(AGENT_OFFLINE_TIMEOUT)) + .filter_map(|entry| { + let agent = Arc::clone(entry.value()); + let route_state = agent.route_state()?; + let matches = match target_ip { + IpAddr::V4(ipv4) => route_state.subnets.iter().any(|subnet| subnet.contains(ipv4)), + IpAddr::V6(_) => false, + }; + + if matches { + Some((route_state.received_at, agent)) + } else { + None + } + }) + .collect(); + + // Sort by received_at descending (most recent first). + candidates.sort_by(|a, b| b.0.cmp(&a.0)); + + candidates.into_iter().map(|(_, agent)| agent).collect() + } + + /// Selects a single online agent that can route to the given target IP. + /// + /// When multiple agents match, the one with the most recent `received_at` wins. + pub fn select_agent_for_target(&self, target_ip: IpAddr) -> Option> { + self.find_agents_for_target(target_ip).into_iter().next() + } + + /// Finds all online agents whose advertised domains match the given hostname via suffix match. + /// + /// Uses longest suffix match: if agent-A advertises "contoso.local" and agent-B advertises + /// "finance.contoso.local", hostname "db01.finance.contoso.local" matches agent-B only. + /// + /// Results are sorted by `received_at` descending (most recently received first). + pub fn select_agents_for_domain(&self, hostname: &str) -> Vec> { + let hostname_lower = hostname.to_ascii_lowercase(); + + let mut best_suffix_len: usize = 0; + let mut candidates: Vec<(SystemTime, Arc)> = Vec::new(); + + for entry in self.agents.iter() { + let agent = entry.value(); + if !agent.is_online(AGENT_OFFLINE_TIMEOUT) { + continue; + } + + let route_state = match agent.route_state() { + Some(rs) => rs, + None => continue, + }; + + for domain_adv in &route_state.domains { + let domain_lower = domain_adv.domain.to_ascii_lowercase(); + let matches = hostname_lower == domain_lower + || (hostname_lower.len() > domain_lower.len() + && hostname_lower.as_bytes()[hostname_lower.len() - domain_lower.len() - 1] == b'.' + && hostname_lower.ends_with(domain_lower.as_str())); + + if matches { + if best_suffix_len < domain_lower.len() { + best_suffix_len = domain_lower.len(); + candidates.clear(); + candidates.push((route_state.received_at, Arc::clone(agent))); + } else if domain_lower.len() == best_suffix_len { + candidates.push((route_state.received_at, Arc::clone(agent))); + } + } + } + } + + candidates.sort_by(|a, b| b.0.cmp(&a.0)); + candidates.into_iter().map(|(_, agent)| agent).collect() + } + + /// Returns information about a single agent by ID. + pub fn agent_info(&self, agent_id: &Uuid) -> Option { + self.agents.get(agent_id).map(|entry| AgentInfo::from(entry.value())) + } + + /// Collects information about all registered agents for API responses. + pub fn agent_infos(&self) -> Vec { + self.agents.iter().map(|entry| AgentInfo::from(entry.value())).collect() + } +} + +impl Default for AgentRegistry { + fn default() -> Self { + Self::new() + } +} + +/// Domain info with source tracking for API responses. +#[derive(Debug, Clone, Serialize)] +pub struct DomainInfo { + pub domain: String, + pub auto_detected: bool, +} + +/// Serializable snapshot of an agent's state, suitable for API responses. +#[derive(Debug, Clone, Serialize)] +pub struct AgentInfo { + pub agent_id: Uuid, + pub name: String, + pub cert_fingerprint: String, + pub is_online: bool, + pub last_seen_ms: u64, + pub subnets: Vec, + pub domains: Vec, + pub route_epoch: Option, +} + +impl From<&Arc> for AgentInfo { + fn from(agent: &Arc) -> Self { + let route_state = agent.route_state(); + Self { + agent_id: agent.agent_id, + name: agent.name.clone(), + cert_fingerprint: agent.cert_fingerprint.clone(), + is_online: agent.is_online(AGENT_OFFLINE_TIMEOUT), + last_seen_ms: agent.last_seen_ms(), + subnets: route_state + .as_ref() + .map(|rs| rs.subnets.iter().map(ToString::to_string).collect()) + .unwrap_or_default(), + domains: route_state + .as_ref() + .map(|rs| { + rs.domains + .iter() + .map(|d| DomainInfo { + domain: d.domain.clone(), + auto_detected: d.auto_detected, + }) + .collect() + }) + .unwrap_or_default(), + route_epoch: route_state.as_ref().map(|rs| rs.epoch), + } + } +} + +/// Returns the current time as milliseconds since UNIX epoch. +fn current_time_millis() -> u64 { + u64::try_from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_millis(), + ) + .expect("millisecond timestamp should fit in u64") +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_peer(name: &str) -> Arc { + Arc::new(AgentPeer::new( + Uuid::new_v4(), + String::from(name), + String::from("sha256:deadbeef"), + )) + } + + #[test] + fn register_and_lookup() { + let registry = AgentRegistry::new(); + let peer = make_peer("test-agent"); + let agent_id = peer.agent_id; + + registry.register(Arc::clone(&peer)); + assert_eq!(registry.len(), 1); + + let found = registry.get(&agent_id).expect("agent should be found"); + assert_eq!(found.agent_id, agent_id); + } + + #[test] + fn unregister_removes_agent() { + let registry = AgentRegistry::new(); + let peer = make_peer("test-agent"); + let agent_id = peer.agent_id; + + registry.register(Arc::clone(&peer)); + let removed = registry.unregister(&agent_id); + assert!(removed.is_some()); + assert_eq!(registry.len(), 0); + assert!(registry.get(&agent_id).is_none()); + } + + #[test] + fn is_online_within_timeout() { + let peer = make_peer("online-agent"); + peer.touch(); + assert!(peer.is_online(AGENT_OFFLINE_TIMEOUT)); + } + + #[test] + fn is_offline_after_timeout() { + let peer = AgentPeer::new( + Uuid::new_v4(), + String::from("offline-agent"), + String::from("sha256:deadbeef"), + ); + // Simulate a very old last_seen timestamp. + peer.last_seen.store(0, Ordering::Release); + assert!(!peer.is_online(AGENT_OFFLINE_TIMEOUT)); + } + + #[test] + fn update_routes_new_epoch_replaces() { + let peer = make_peer("route-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![]); + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.epoch, 1); + assert_eq!(state.subnets.len(), 1); + + let new_subnet: Ipv4Network = "192.168.0.0/16".parse().expect("valid CIDR"); + peer.update_routes(2, vec![new_subnet], vec![]); + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.epoch, 2); + assert_eq!(state.subnets.len(), 1); + assert_eq!(state.subnets[0], new_subnet); + } + + #[test] + fn update_routes_same_epoch_refreshes_only() { + let peer = make_peer("refresh-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![]); + let state_before = peer.route_state().expect("route state should exist"); + let received_at_before = state_before.received_at; + + // Same epoch with different subnets should NOT replace subnets. + let different_subnet: Ipv4Network = "172.16.0.0/12".parse().expect("valid CIDR"); + peer.update_routes(1, vec![different_subnet], vec![]); + + let state_after = peer.route_state().expect("route state should exist"); + assert_eq!(state_after.epoch, 1); + // Subnets should remain unchanged (original advertisement). + assert_eq!(state_after.subnets[0], subnet); + // received_at should remain unchanged. + assert_eq!(state_after.received_at, received_at_before); + // updated_at should have been refreshed. + assert!(state_after.updated_at >= state_before.updated_at); + } + + #[test] + fn update_routes_stale_epoch_ignored() { + let peer = make_peer("stale-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(5, vec![subnet], vec![]); + let old_subnet: Ipv4Network = "172.16.0.0/12".parse().expect("valid CIDR"); + peer.update_routes(3, vec![old_subnet], vec![]); + + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.epoch, 5); + assert_eq!(state.subnets[0], subnet); + } + + #[test] + fn can_reach_matching_subnet() { + let peer = make_peer("reachable-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![]); + + assert!(peer.can_reach("10.1.2.3".parse().expect("valid IP"))); + assert!(!peer.can_reach("192.168.1.1".parse().expect("valid IP"))); + } + + #[test] + fn can_reach_returns_false_for_ipv6() { + let peer = make_peer("v4-only-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![]); + + assert!(!peer.can_reach("::1".parse().expect("valid IP"))); + } + + #[test] + fn select_agent_for_target_picks_most_recent() { + let registry = AgentRegistry::new(); + + let agent_a = make_peer("agent-a"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + agent_a.update_routes(1, vec![subnet], vec![]); + registry.register(Arc::clone(&agent_a)); + + // Small delay to ensure different received_at timestamps. + std::thread::sleep(Duration::from_millis(10)); + + let agent_b = make_peer("agent-b"); + agent_b.update_routes(1, vec![subnet], vec![]); + registry.register(Arc::clone(&agent_b)); + + let target: IpAddr = "10.5.5.5".parse().expect("valid IP"); + let winner = registry.select_agent_for_target(target).expect("should find an agent"); + // agent_b was registered later, so its received_at is more recent. + assert_eq!(winner.agent_id, agent_b.agent_id); + } + + #[test] + fn find_agents_for_target_returns_sorted() { + let registry = AgentRegistry::new(); + + let agent_a = make_peer("agent-a"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + agent_a.update_routes(1, vec![subnet], vec![]); + registry.register(Arc::clone(&agent_a)); + + std::thread::sleep(Duration::from_millis(10)); + + let agent_b = make_peer("agent-b"); + agent_b.update_routes(1, vec![subnet], vec![]); + registry.register(Arc::clone(&agent_b)); + + let target: IpAddr = "10.5.5.5".parse().expect("valid IP"); + let agents = registry.find_agents_for_target(target); + assert_eq!(agents.len(), 2); + // Most recent first. + assert_eq!(agents[0].agent_id, agent_b.agent_id); + assert_eq!(agents[1].agent_id, agent_a.agent_id); + } + + #[test] + fn find_agents_excludes_offline() { + let registry = AgentRegistry::new(); + + let agent = make_peer("offline-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + agent.update_routes(1, vec![subnet], vec![]); + // Force agent to appear offline. + agent.last_seen.store(0, Ordering::Release); + registry.register(agent); + + let target: IpAddr = "10.5.5.5".parse().expect("valid IP"); + let agents = registry.find_agents_for_target(target); + assert!(agents.is_empty()); + } + + #[test] + fn agent_infos_snapshot() { + let registry = AgentRegistry::new(); + let peer = make_peer("info-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![]); + registry.register(peer); + + let infos = registry.agent_infos(); + assert_eq!(infos.len(), 1); + assert_eq!(infos[0].name, "info-agent"); + assert!(infos[0].is_online); + assert_eq!(infos[0].subnets, vec!["10.0.0.0/8"]); + assert_eq!(infos[0].route_epoch, Some(1)); + } + + #[test] + fn online_count_accuracy() { + let registry = AgentRegistry::new(); + + let online_agent = make_peer("online"); + registry.register(Arc::clone(&online_agent)); + + let offline_agent = make_peer("offline"); + offline_agent.last_seen.store(0, Ordering::Release); + registry.register(offline_agent); + + assert_eq!(registry.len(), 2); + assert_eq!(registry.online_count(), 1); + } + + #[test] + fn default_trait_creates_empty_registry() { + let registry = AgentRegistry::default(); + assert_eq!(registry.len(), 0); + } + + // ── Domain routing tests ────────────────────────────────────────── + + fn domain(name: &str, auto: bool) -> DomainAdvertisement { + DomainAdvertisement { + domain: name.to_owned(), + auto_detected: auto, + } + } + + #[test] + fn update_routes_stores_domains_with_source() { + let peer = make_peer("domain-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.domains.len(), 1); + assert_eq!(state.domains[0].domain, "contoso.local"); + assert!(!state.domains[0].auto_detected); + } + + #[test] + fn update_routes_new_epoch_replaces_domains() { + let peer = make_peer("domain-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![domain("old.local", false)]); + peer.update_routes(2, vec![subnet], vec![domain("new.local", true)]); + + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.epoch, 2); + assert_eq!(state.domains[0].domain, "new.local"); + assert!(state.domains[0].auto_detected); + } + + #[test] + fn update_routes_same_epoch_preserves_domains() { + let peer = make_peer("domain-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + peer.update_routes(1, vec![subnet], vec![domain("different.local", true)]); + + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.domains[0].domain, "contoso.local"); + assert!(!state.domains[0].auto_detected); + } + + #[test] + fn select_agent_for_domain_suffix_match() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + registry.register(peer); + + let agents = registry.select_agents_for_domain("dc01.contoso.local"); + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, agent_id); + } + + #[test] + fn select_agent_for_domain_no_match() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + registry.register(peer); + + let agents = registry.select_agents_for_domain("dc01.other.local"); + assert!(agents.is_empty()); + } + + #[test] + fn select_agent_for_domain_longest_suffix_wins() { + let registry = AgentRegistry::new(); + + let agent_a = make_peer("agent-a"); + let id_a = agent_a.agent_id; + let subnet_a: Ipv4Network = "10.1.0.0/16".parse().expect("valid CIDR"); + agent_a.update_routes(1, vec![subnet_a], vec![domain("contoso.local", false)]); + registry.register(agent_a); + + let agent_b = make_peer("agent-b"); + let id_b = agent_b.agent_id; + let subnet_b: Ipv4Network = "10.2.0.0/16".parse().expect("valid CIDR"); + agent_b.update_routes(1, vec![subnet_b], vec![domain("finance.contoso.local", false)]); + registry.register(agent_b); + + let agents = registry.select_agents_for_domain("db01.finance.contoso.local"); + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, id_b); + + let agents = registry.select_agents_for_domain("dc01.contoso.local"); + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, id_a); + } + + #[test] + fn select_agent_for_domain_multiple_agents_same_domain() { + let registry = AgentRegistry::new(); + + let agent_a = make_peer("agent-a"); + let subnet_a: Ipv4Network = "10.1.0.0/16".parse().expect("valid CIDR"); + agent_a.update_routes(1, vec![subnet_a], vec![domain("contoso.local", false)]); + registry.register(Arc::clone(&agent_a)); + + std::thread::sleep(Duration::from_millis(10)); + + let agent_b = make_peer("agent-b"); + let id_b = agent_b.agent_id; + let subnet_b: Ipv4Network = "10.2.0.0/16".parse().expect("valid CIDR"); + agent_b.update_routes(1, vec![subnet_b], vec![domain("contoso.local", false)]); + registry.register(Arc::clone(&agent_b)); + + let agents = registry.select_agents_for_domain("dc01.contoso.local"); + assert_eq!(agents.len(), 2); + assert_eq!(agents[0].agent_id, id_b); + } + + #[test] + fn select_agent_for_domain_excludes_offline() { + let registry = AgentRegistry::new(); + + let agent = make_peer("offline-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + agent.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + agent.last_seen.store(0, Ordering::Release); + registry.register(agent); + + let agents = registry.select_agents_for_domain("dc01.contoso.local"); + assert!(agents.is_empty()); + } + + #[test] + fn select_agent_for_domain_exact_match() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + registry.register(peer); + + let agents = registry.select_agents_for_domain("contoso.local"); + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, agent_id); + } + + #[test] + fn select_agent_for_domain_bare_hostname_no_match() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + registry.register(peer); + + let agents = registry.select_agents_for_domain("server01"); + assert!(agents.is_empty()); + } +} diff --git a/devolutions-gateway/src/agent_tunnel/stream.rs b/devolutions-gateway/src/agent_tunnel/stream.rs new file mode 100644 index 000000000..f979e22a8 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/stream.rs @@ -0,0 +1,37 @@ +//! Wrapper around Quinn's `SendStream` + `RecvStream` providing a single +//! `AsyncRead + AsyncWrite` type for use with the gateway's proxy infrastructure. + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A bidirectional QUIC stream backed by Quinn's `SendStream` and `RecvStream`. +/// +/// Implements `AsyncRead` (delegating to `recv`) and `AsyncWrite` (delegating +/// to `send`), so callers can treat it as a single bidirectional transport. +pub struct TunnelStream { + pub send: quinn::SendStream, + pub recv: quinn::RecvStream, +} + +impl AsyncRead for TunnelStream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + AsyncRead::poll_read(Pin::new(&mut self.recv), cx, buf) + } +} + +impl AsyncWrite for TunnelStream { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut self.send), cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut self.send), cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_shutdown(Pin::new(&mut self.send), cx) + } +} diff --git a/devolutions-gateway/src/api/agent_enrollment.rs b/devolutions-gateway/src/api/agent_enrollment.rs new file mode 100644 index 000000000..f65f6dc43 --- /dev/null +++ b/devolutions-gateway/src/api/agent_enrollment.rs @@ -0,0 +1,302 @@ +use std::net::IpAddr; + +use axum::extract::{Path, State}; +use axum::http::HeaderMap; +use axum::{Json, Router}; +use uuid::Uuid; + +use crate::DgwState; +use crate::extract::{AgentManagementReadAccess, AgentManagementWriteAccess}; +use crate::http::HttpError; + +/// Timing-safe byte comparison to prevent side-channel attacks on secret comparison. +/// +/// Both inputs are hashed with SHA-256 first, producing fixed 32-byte digests. +/// The digest comparison runs in constant time (fixed-length XOR fold). +/// SHA-256 itself runs in time proportional to input length, but this only +/// reveals the length of the attacker's guess — not the secret's length or content. +fn constant_time_eq(a: &[u8], b: &[u8]) -> bool { + use sha2::{Digest, Sha256}; + let da = Sha256::digest(a); + let db = Sha256::digest(b); + da.iter().zip(db.iter()).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0 +} + +#[derive(Deserialize)] +pub struct EnrollRequest { + /// Friendly name for the agent. + pub agent_name: String, + /// PEM-encoded Certificate Signing Request from the agent. + pub csr_pem: String, +} + +#[derive(Serialize)] +pub struct EnrollResponse { + /// Assigned agent ID. + pub agent_id: Uuid, + /// Agent name. + pub agent_name: String, + /// PEM-encoded client certificate (signed by the gateway CA). + pub client_cert_pem: String, + /// PEM-encoded gateway CA certificate (for server verification). + pub gateway_ca_cert_pem: String, + /// QUIC endpoint to connect to (`host:port`). + pub quic_endpoint: String, +} + +pub fn make_router(state: DgwState) -> Router { + Router::new() + .route("/enroll", axum::routing::post(enroll_agent)) + .route("/agents", axum::routing::get(list_agents)) + .route("/agents/{agent_id}", axum::routing::get(get_agent).delete(delete_agent)) + .route("/agents/resolve-target", axum::routing::post(resolve_target)) + .with_state(state) +} + +/// Enroll a new agent. +/// +/// Requires a Bearer token matching the configured enrollment secret +/// or a valid one-time enrollment token from the store. +/// +/// The agent generates its own key pair and sends a CSR. The gateway signs it +/// and returns the certificate. The private key never leaves the agent. +async fn enroll_agent( + State(DgwState { + conf_handle, + agent_tunnel_handle, + .. + }): State, + headers: HeaderMap, + Json(req): Json, +) -> Result, HttpError> { + // Validate agent name: 1-255 printable ASCII characters. + if req.agent_name.is_empty() + || 255 < req.agent_name.len() + || req.agent_name.bytes().any(|b| !(0x20..=0x7E).contains(&b)) + { + return Err(HttpError::bad_request().msg("agent name must be 1-255 printable ASCII characters")); + } + + let conf = conf_handle.get_conf(); + + // Extract the Bearer token. + let auth_header = headers + .get(axum::http::header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| HttpError::unauthorized().msg("missing Authorization header"))?; + + let provided_token = auth_header + .strip_prefix("Bearer ") + .ok_or_else(|| HttpError::unauthorized().msg("expected Bearer token"))?; + + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent enrollment is not configured"))?; + + // Try one-time enrollment token from the store first. + let token_valid = handle.enrollment_token_store().consume(provided_token); + + if !token_valid { + // Fall back to the static enrollment secret. + let enrollment_secret = conf + .agent_tunnel + .enrollment_secret + .as_deref() + .ok_or_else(|| HttpError::not_found().msg("agent enrollment is not configured"))?; + + if !constant_time_eq(provided_token.as_bytes(), enrollment_secret.as_bytes()) { + return Err(HttpError::forbidden().msg("invalid enrollment token")); + } + } + + let agent_id = Uuid::new_v4(); + + let signed = handle + .ca_manager() + .sign_agent_csr(agent_id, &req.agent_name, &req.csr_pem) + .map_err(HttpError::bad_request().with_msg("invalid CSR").err())?; + + let quic_endpoint = format!("{}:{}", conf.hostname, conf.agent_tunnel.listen_port); + + info!( + %agent_id, + agent_name = %req.agent_name, + "Agent enrolled successfully", + ); + + Ok(Json(EnrollResponse { + agent_id, + agent_name: req.agent_name, + client_cert_pem: signed.client_cert_pem, + gateway_ca_cert_pem: signed.ca_cert_pem, + quic_endpoint, + })) +} + +/// List connected agents and their status. +async fn list_agents( + State(DgwState { + agent_tunnel_handle, .. + }): State, + _access: AgentManagementReadAccess, +) -> Result>, HttpError> { + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent tunnel not configured"))?; + + let agents = handle.registry().agent_infos(); + + Ok(Json(agents)) +} + +/// Get a single agent by ID. +async fn get_agent( + State(DgwState { + agent_tunnel_handle, .. + }): State, + _access: AgentManagementReadAccess, + Path(agent_id): Path, +) -> Result, HttpError> { + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent tunnel not configured"))?; + + let info = handle + .registry() + .agent_info(&agent_id) + .ok_or_else(|| HttpError::not_found().msg("agent not found"))?; + + Ok(Json(info)) +} + +/// Delete (unregister) an agent by ID. +async fn delete_agent( + State(DgwState { + agent_tunnel_handle, .. + }): State, + _access: AgentManagementWriteAccess, + Path(agent_id): Path, +) -> Result { + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent tunnel not configured"))?; + + handle + .registry() + .unregister(&agent_id) + .ok_or_else(|| HttpError::not_found().msg("agent not found"))?; + + info!(%agent_id, "Agent deleted via API"); + + Ok(axum::http::StatusCode::NO_CONTENT) +} + +#[derive(Deserialize)] +struct ResolveTargetRequest { + target: String, +} + +#[derive(Serialize)] +struct ResolveTargetResponse { + target: String, + target_ip: Option, + reachable_agents: Vec, + target_reachable: bool, +} + +/// Resolve a target string to find which agents can reach it. +async fn resolve_target( + State(DgwState { + agent_tunnel_handle, .. + }): State, + _access: AgentManagementReadAccess, + Json(req): Json, +) -> Result, HttpError> { + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent tunnel not configured"))?; + + let target_ip = parse_target_ip(&req.target); + + // Use the same routing logic as fwd.rs: IP → subnet match, hostname → domain suffix match + let matching_peers = if let Some(ip) = target_ip { + handle.registry().find_agents_for_target(ip) + } else { + let hostname = strip_scheme_and_port(&req.target); + handle.registry().select_agents_for_domain(hostname) + }; + + let reachable_agents: Vec<_> = matching_peers + .iter() + .map(crate::agent_tunnel::registry::AgentInfo::from) + .collect(); + + let target_reachable = !reachable_agents.is_empty(); + + Ok(Json(ResolveTargetResponse { + target: req.target, + target_ip, + reachable_agents, + target_reachable, + })) +} + +/// Strip scheme prefix and port from a target string, returning the bare host. +/// +/// Handles `tcp://host:port`, `http://host:port`, `host:port`, and bare hostnames. +fn strip_scheme_and_port(target: &str) -> &str { + let host_port = target + .strip_prefix("tcp://") + .or_else(|| target.strip_prefix("http://")) + .or_else(|| target.strip_prefix("https://")) + .unwrap_or(target); + + let host = if let Some((h, _port)) = host_port.rsplit_once(':') { + h + } else { + host_port + }; + + // Strip brackets for IPv6 literals like [::1]. + host.strip_prefix('[').and_then(|h| h.strip_suffix(']')).unwrap_or(host) +} + +fn parse_target_ip(target: &str) -> Option { + strip_scheme_and_port(target).parse::().ok() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_target_ip_bare_ipv4() { + assert_eq!(parse_target_ip("10.0.0.1"), Some("10.0.0.1".parse().expect("test"))); + } + + #[test] + fn parse_target_ip_with_port() { + assert_eq!( + parse_target_ip("10.0.0.1:3389"), + Some("10.0.0.1".parse().expect("test")) + ); + } + + #[test] + fn parse_target_ip_tcp_scheme() { + assert_eq!( + parse_target_ip("tcp://192.168.1.1:22"), + Some("192.168.1.1".parse().expect("test")) + ); + } + + #[test] + fn parse_target_ip_hostname_returns_none() { + assert_eq!(parse_target_ip("myserver.local:3389"), None); + } + + #[test] + fn parse_target_ip_bare_hostname_returns_none() { + assert_eq!(parse_target_ip("myserver"), None); + } +} diff --git a/devolutions-gateway/src/api/mod.rs b/devolutions-gateway/src/api/mod.rs index a5cbbc643..c867389dd 100644 --- a/devolutions-gateway/src/api/mod.rs +++ b/devolutions-gateway/src/api/mod.rs @@ -1,3 +1,4 @@ +pub mod agent_enrollment; pub mod ai; pub mod config; pub mod diagnostics; @@ -35,6 +36,7 @@ pub fn make_router(state: crate::DgwState) -> axum::Router { .nest("/jet/webapp", webapp::make_router(state.clone())) .nest("/jet/net", net::make_router(state.clone())) .nest("/jet/traffic", traffic::make_router(state.clone())) + .nest("/jet/agent-tunnel", agent_enrollment::make_router(state.clone())) .route("/jet/update", axum::routing::post(update::trigger_update_check)); if state.conf_handle.get_conf().web_app.enabled { diff --git a/devolutions-gateway/src/api/webapp.rs b/devolutions-gateway/src/api/webapp.rs index 2c6b99f89..f266f4207 100644 --- a/devolutions-gateway/src/api/webapp.rs +++ b/devolutions-gateway/src/api/webapp.rs @@ -342,6 +342,7 @@ pub(crate) async fn sign_session_token( exp, jti, cert_thumb256: None, + jet_agent_id: None, } .pipe(serde_json::to_value) .map(|mut claims| { diff --git a/devolutions-gateway/src/config.rs b/devolutions-gateway/src/config.rs index 5ca9fc49f..122408795 100644 --- a/devolutions-gateway/src/config.rs +++ b/devolutions-gateway/src/config.rs @@ -193,6 +193,7 @@ pub struct Conf { pub verbosity_profile: dto::VerbosityProfile, pub web_app: WebAppConf, pub ai_gateway: AiGatewayConf, + pub agent_tunnel: dto::AgentTunnelConf, pub proxy: dto::ProxyConf, pub debug: dto::DebugConf, } @@ -925,6 +926,7 @@ impl Conf { .as_ref() .map(AiGatewayConf::from_dto) .unwrap_or_default(), + agent_tunnel: conf_file.agent_tunnel.clone().unwrap_or_default(), proxy: conf_file.proxy.clone().unwrap_or_default(), debug: conf_file.debug.clone().unwrap_or_default(), }) @@ -1725,6 +1727,10 @@ pub mod dto { #[serde(skip_serializing_if = "Option::is_none")] pub proxy: Option, + /// (Unstable) Agent tunnel configuration (QUIC-based agent tunnel) + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_tunnel: Option, + /// (Unstable) Unsafe debug options for developers #[serde(rename = "__debug__", skip_serializing_if = "Option::is_none")] pub debug: Option, @@ -1780,6 +1786,7 @@ pub mod dto { ai_gateway: None, job_queue_database: None, traffic_audit_database: None, + agent_tunnel: None, proxy: None, debug: None, rest: serde_json::Map::new(), @@ -1914,6 +1921,38 @@ pub mod dto { pub kdc_url: Option, } + /// (Unstable) QUIC-based agent tunnel configuration + #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] + #[serde(rename_all = "PascalCase")] + pub struct AgentTunnelConf { + /// Whether the agent tunnel listener is enabled + #[serde(default)] + pub enabled: bool, + /// UDP port for the QUIC listener (default: 4433) + #[serde(default = "AgentTunnelConf::default_listen_port")] + pub listen_port: u16, + /// Shared secret for agent enrollment. + /// If set, agents can enroll by providing this secret as a Bearer token. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub enrollment_secret: Option, + } + + impl AgentTunnelConf { + fn default_listen_port() -> u16 { + 4433 + } + } + + impl Default for AgentTunnelConf { + fn default() -> Self { + Self { + enabled: false, + listen_port: Self::default_listen_port(), + enrollment_secret: None, + } + } + } + /// Unsafe debug options that should only ever be used at development stage /// /// These options might change or get removed without further notice. diff --git a/devolutions-gateway/src/extract.rs b/devolutions-gateway/src/extract.rs index 9f2450854..ada08ce9a 100644 --- a/devolutions-gateway/src/extract.rs +++ b/devolutions-gateway/src/extract.rs @@ -386,6 +386,64 @@ where } } +/// Grants read access to agent management endpoints. +/// +/// Accepts a scope token with `DiagnosticsRead`, `ConfigWrite`, or `Wildcard` scope. +#[derive(Clone, Copy)] +pub struct AgentManagementReadAccess; + +impl FromRequestParts for AgentManagementReadAccess +where + S: Send + Sync, +{ + type Rejection = HttpError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let claims = Extension::::from_request_parts(parts, state) + .await + .map_err(HttpError::internal().err())? + .0; + + // DiagnosticsRead is accepted because DVLS maps its AgentRead scope + // to GatewayDiagnosticsRead, which serializes as "gateway.diagnostics.read". + match claims { + AccessTokenClaims::Scope(scope) => match scope.scope { + AccessScope::Wildcard | AccessScope::DiagnosticsRead | AccessScope::ConfigWrite => Ok(Self), + _ => Err(HttpError::forbidden().msg("invalid scope for agent management read")), + }, + _ => Err(HttpError::forbidden().msg("scope token required for agent management read")), + } + } +} + +/// Grants write access to agent management endpoints (e.g. enrollment, delete). +/// +/// Accepts scope tokens with `ConfigWrite` (or `Wildcard`) scope only. +#[derive(Clone, Copy)] +pub struct AgentManagementWriteAccess; + +impl FromRequestParts for AgentManagementWriteAccess +where + S: Send + Sync, +{ + type Rejection = HttpError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let claims = Extension::::from_request_parts(parts, state) + .await + .map_err(HttpError::internal().err())? + .0; + + match claims { + AccessTokenClaims::Scope(scope) => match scope.scope { + AccessScope::Wildcard | AccessScope::ConfigWrite => Ok(Self), + _ => Err(HttpError::forbidden().msg("invalid scope for agent management write")), + }, + _ => Err(HttpError::forbidden().msg("scope token required for agent management write")), + } + } +} + #[derive(Clone)] pub struct WebAppToken(pub WebAppTokenClaims); diff --git a/devolutions-gateway/src/generic_client.rs b/devolutions-gateway/src/generic_client.rs index 13c1d9c48..d8209ce79 100644 --- a/devolutions-gateway/src/generic_client.rs +++ b/devolutions-gateway/src/generic_client.rs @@ -6,6 +6,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _}; use tracing::field; use typed_builder::TypedBuilder; +use crate::agent_tunnel::AgentTunnelHandle; use crate::config::Conf; use crate::credential::CredentialStoreHandle; use crate::proxy::Proxy; @@ -27,6 +28,8 @@ pub struct GenericClient { subscriber_tx: SubscriberSender, active_recordings: Arc, credential_store: CredentialStoreHandle, + #[builder(default)] + agent_tunnel_handle: Option>, } impl GenericClient @@ -49,6 +52,7 @@ where subscriber_tx, active_recordings, credential_store, + agent_tunnel_handle, } = self; let span = tracing::Span::current(); @@ -109,6 +113,76 @@ where RecordingPolicy::Proxy => anyhow::bail!("can't meet recording policy"), } + // Route via agent tunnel if jet_agent_id is specified. + if let Some(agent_id) = claims.jet_agent_id { + let handle = agent_tunnel_handle.context("agent tunnel not configured on this gateway")?; + + let mut selected_target = None; + let mut server_stream = None; + let mut last_error = None; + + for candidate in targets.iter() { + let target_str = format!("{}:{}", candidate.host(), candidate.port()); + + info!(%agent_id, %target_str, "Routing via agent tunnel"); + + match handle.connect_via_agent(agent_id, claims.jet_aid, &target_str).await { + Ok(stream) => { + selected_target = Some(candidate.clone()); + server_stream = Some(stream); + break; + } + Err(error) => { + warn!( + %agent_id, + %target_str, + error = format!("{error:#}"), + "Agent tunnel target failed" + ); + last_error = Some(error); + } + } + } + + let selected_target = selected_target.ok_or_else(|| { + last_error.unwrap_or_else(|| anyhow::anyhow!("agent tunnel target selection failed")) + })?; + span.record("target", selected_target.to_string()); + let server_stream = server_stream.expect("server stream should be present when target is selected"); + + let info = SessionInfo::builder() + .id(claims.jet_aid) + .application_protocol(claims.jet_ap) + .details(ConnectionModeDetails::Fwd { + destination_host: selected_target.clone(), + }) + .time_to_live(claims.jet_ttl) + .recording_policy(claims.jet_rec) + .filtering_policy(claims.jet_flt) + .build(); + + let disconnect_interest = DisconnectInterest::from_reconnection_policy(claims.jet_reuse); + + // Agent handles the TCP connection; no leftover bytes to forward. + // Use a placeholder server address since the actual target is behind the agent. + let server_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder"); + + return Proxy::builder() + .conf(conf) + .session_info(info) + .address_a(client_addr) + .transport_a(client_stream) + .address_b(server_addr) + .transport_b(server_stream) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .disconnect_interest(disconnect_interest) + .build() + .select_dissector_and_forward() + .await + .context("encountered a failure during agent tunnel traffic proxying"); + } + trace!("Select and connect to target"); let ((mut server_stream, server_addr), selected_target) = diff --git a/devolutions-gateway/src/lib.rs b/devolutions-gateway/src/lib.rs index ed1a28099..93d782530 100644 --- a/devolutions-gateway/src/lib.rs +++ b/devolutions-gateway/src/lib.rs @@ -12,6 +12,7 @@ extern crate tracing; #[cfg(feature = "openapi")] pub mod openapi; +pub mod agent_tunnel; pub mod ai; pub mod api; pub mod cli; @@ -61,6 +62,7 @@ pub struct DgwState { pub credential_store: credential::CredentialStoreHandle, pub monitoring_state: Arc, pub traffic_audit_handle: traffic_audit::TrafficAuditHandle, + pub agent_tunnel_handle: Option>, } #[doc(hidden)] @@ -100,6 +102,7 @@ impl DgwState { traffic_audit_handle, credential_store, monitoring_state, + agent_tunnel_handle: None, }; let handles = MockHandles { diff --git a/devolutions-gateway/src/listener.rs b/devolutions-gateway/src/listener.rs index db5926be5..0b7ce2740 100644 --- a/devolutions-gateway/src/listener.rs +++ b/devolutions-gateway/src/listener.rs @@ -159,6 +159,7 @@ async fn handle_tcp_peer(stream: TcpStream, state: DgwState, peer_addr: SocketAd .subscriber_tx(state.subscriber_tx) .active_recordings(state.recordings.active_recordings) .credential_store(state.credential_store) + .agent_tunnel_handle(state.agent_tunnel_handle) .build() .serve() .await?; diff --git a/devolutions-gateway/src/middleware/auth.rs b/devolutions-gateway/src/middleware/auth.rs index f07e6e1b0..18f08bb66 100644 --- a/devolutions-gateway/src/middleware/auth.rs +++ b/devolutions-gateway/src/middleware/auth.rs @@ -95,6 +95,14 @@ const AUTH_EXCEPTIONS: &[AuthException] = &[ path: "/jet/ai", exact_match: false, }, + // Agent Tunnel: only /enroll skips auth (it uses its own bearer token). + // TODO: add rate limiting on this endpoint (tokens are 122-bit UUIDs so brute-force + // is infeasible, but rate limiting is good defense-in-depth). + AuthException { + method: Method::POST, + path: "/jet/agent-tunnel/enroll", + exact_match: true, + }, ]; pub async fn auth_middleware( diff --git a/devolutions-gateway/src/ngrok.rs b/devolutions-gateway/src/ngrok.rs index 8d32f58d4..71c0c005f 100644 --- a/devolutions-gateway/src/ngrok.rs +++ b/devolutions-gateway/src/ngrok.rs @@ -238,6 +238,7 @@ async fn run_tcp_tunnel(mut tunnel: ngrok::tunnel::TcpTunnel, state: DgwState) { .subscriber_tx(state.subscriber_tx) .active_recordings(state.recordings.active_recordings) .credential_store(state.credential_store) + .agent_tunnel_handle(state.agent_tunnel_handle) .build() .serve() .await diff --git a/devolutions-gateway/src/rd_clean_path.rs b/devolutions-gateway/src/rd_clean_path.rs index 6d4614b5e..9855a8987 100644 --- a/devolutions-gateway/src/rd_clean_path.rs +++ b/devolutions-gateway/src/rd_clean_path.rs @@ -669,9 +669,7 @@ impl From<&CleanPathError> for RDCleanPathPdu { } fn io_to_rdcleanpath_err(err: &io::Error) -> RDCleanPathPdu { - if let Some(tokio_rustls::rustls::Error::AlertReceived(tls_alert)) = err - .get_ref() - .and_then(|e| e.downcast_ref::()) + if let Some(rustls::Error::AlertReceived(tls_alert)) = err.get_ref().and_then(|e| e.downcast_ref::()) { RDCleanPathPdu::new_tls_error(u8::from(*tls_alert)) } else { diff --git a/devolutions-gateway/src/service.rs b/devolutions-gateway/src/service.rs index 64dde91c4..45ae07264 100644 --- a/devolutions-gateway/src/service.rs +++ b/devolutions-gateway/src/service.rs @@ -10,7 +10,7 @@ use devolutions_gateway::recording::recording_message_channel; use devolutions_gateway::session::session_manager_channel; use devolutions_gateway::subscriber::subscriber_channel; use devolutions_gateway::token::{CurrentJrl, JrlTokenClaims}; -use devolutions_gateway::{DgwState, SYSTEM_LOGGER, config}; +use devolutions_gateway::{DgwState, SYSTEM_LOGGER, agent_tunnel, config}; use devolutions_gateway_task::{ChildTask, ShutdownHandle, ShutdownSignal}; use devolutions_log::{self, LoggerGuard}; use parking_lot::Mutex; @@ -275,6 +275,35 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result { ); let monitoring_state = Arc::new(network_monitor::State::new(Arc::new(filesystem_monitor_config_cache))?); + // Initialize agent tunnel if configured. + let agent_tunnel_handle = if conf.agent_tunnel.enabled { + let data_dir = config::get_data_dir(); + let hostname = &conf.hostname; + + let ca_manager = Arc::new( + agent_tunnel::cert::CaManager::load_or_generate(&data_dir) + .context("failed to initialize agent tunnel CA")?, + ); + + let listen_addr = std::net::SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, conf.agent_tunnel.listen_port)); + + let (listener, handle) = + agent_tunnel::AgentTunnelListener::bind(listen_addr, Arc::clone(&ca_manager), hostname) + .await + .context("failed to bind agent tunnel listener")?; + + tasks.register(listener); + + info!( + port = conf.agent_tunnel.listen_port, + "Agent tunnel QUIC listener started", + ); + + Some(Arc::new(handle)) + } else { + None + }; + let state = DgwState { conf_handle: conf_handle.clone(), token_cache: Arc::clone(&token_cache), @@ -287,6 +316,7 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result { credential_store: credential_store.clone(), monitoring_state, traffic_audit_handle: traffic_audit_task.handle(), + agent_tunnel_handle, }; for listener in &conf.listeners { diff --git a/devolutions-gateway/src/token.rs b/devolutions-gateway/src/token.rs index 75b7d112e..5912d7dbf 100644 --- a/devolutions-gateway/src/token.rs +++ b/devolutions-gateway/src/token.rs @@ -425,6 +425,12 @@ pub struct AssociationTokenClaims { /// Optional SHA-256 thumbprint of target server certificate (for anchored TLS validation) pub cert_thumb256: Option, + + /// Optional agent ID for routing connections through an enrolled agent tunnel. + /// + /// When set alongside `ConnectionMode::Fwd`, the Gateway will proxy the connection + /// through the specified agent instead of connecting directly to the target. + pub jet_agent_id: Option, } // ----- scope claims ----- // @@ -466,15 +472,15 @@ pub enum AccessScope { NetMonitorDrain, } -#[derive(Clone, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct ScopeTokenClaims { pub scope: AccessScope, /// JWT expiration time claim. - exp: i64, + pub exp: i64, /// JWT "JWT ID" claim, the unique ID for this token - jti: Uuid, + pub jti: Uuid, } // ----- bridge claims ----- // @@ -1312,6 +1318,8 @@ mod serde_impl { jti: Uuid, #[serde(default)] cert_thumb256: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + jet_agent_id: Option, } #[derive(Deserialize)] @@ -1420,6 +1428,7 @@ mod serde_impl { exp: self.exp, jti: self.jti, cert_thumb256: self.cert_thumb256.as_ref().map(|thumb| SmolStr::new(thumb.as_str())), + jet_agent_id: self.jet_agent_id, } .serialize(serializer) } @@ -1469,6 +1478,7 @@ mod serde_impl { .map(crate::tls::thumbprint::normalize_sha256_thumbprint) .transpose() .map_err(de::Error::custom)?, + jet_agent_id: claims.jet_agent_id, }) } } diff --git a/devolutions-gateway/tests/config.rs b/devolutions-gateway/tests/config.rs index 4cb015e3f..e2eeb8a7c 100644 --- a/devolutions-gateway/tests/config.rs +++ b/devolutions-gateway/tests/config.rs @@ -97,6 +97,7 @@ fn hub_sample() -> Sample { verbosity_profile: Some(VerbosityProfile::Tls), web_app: None, ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -146,6 +147,7 @@ fn legacy_sample() -> Sample { verbosity_profile: None, web_app: None, ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -194,6 +196,7 @@ fn system_store_sample() -> Sample { verbosity_profile: None, web_app: None, ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -274,6 +277,7 @@ fn standalone_custom_auth_sample() -> Sample { static_root_path: None, }), ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -354,6 +358,7 @@ fn standalone_no_auth_sample() -> Sample { static_root_path: Some("/path/to/webapp/static/root".into()), }), ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -439,6 +444,7 @@ fn proxy_sample() -> Sample { ], }), debug: None, + agent_tunnel: None, rest: Default::default(), }, } From 0b4b099fd892b3008ab65166b9e569f3d60c2726 Mon Sep 17 00:00:00 2001 From: irving ou Date: Tue, 7 Apr 2026 16:46:48 -0400 Subject: [PATCH 2/2] feat: transparent routing through agent tunnel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a connection target matches an agent's advertised subnets or domains, the gateway automatically routes through the QUIC tunnel instead of connecting directly. This enables access to private network resources without VPN or inbound firewall rules. - Add routing pipeline (subnet match → domain suffix → direct) - Integrate tunnel routing into RDP, SSH, VNC, ARD, and KDC proxy paths - Support ServerTransport enum (Tcp/Quic) in rd_clean_path - Add 7 routing unit tests Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/agent_tunnel/integration_test.rs | 638 ++++++++++++++++++ devolutions-gateway/src/agent_tunnel/mod.rs | 5 + .../src/agent_tunnel/routing.rs | 287 ++++++++ devolutions-gateway/src/api/fwd.rs | 76 +++ devolutions-gateway/src/api/kdc_proxy.rs | 72 +- devolutions-gateway/src/api/rdp.rs | 4 + devolutions-gateway/src/proxy.rs | 4 +- devolutions-gateway/src/rd_clean_path.rs | 235 +++++-- devolutions-gateway/src/rdp_proxy.rs | 2 +- 9 files changed, 1259 insertions(+), 64 deletions(-) create mode 100644 devolutions-gateway/src/agent_tunnel/integration_test.rs create mode 100644 devolutions-gateway/src/agent_tunnel/routing.rs diff --git a/devolutions-gateway/src/agent_tunnel/integration_test.rs b/devolutions-gateway/src/agent_tunnel/integration_test.rs new file mode 100644 index 000000000..8f153c5b5 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/integration_test.rs @@ -0,0 +1,638 @@ +//! Integration test for the QUIC agent tunnel. +//! +//! Verifies the full data path: +//! TCP echo server ← Agent (simulated quiche client) ← QUIC ← Gateway listener ← QuicStream +//! +//! This test runs entirely in-process with real UDP sockets on localhost. + +#![cfg(test)] + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use agent_tunnel_proto::{ConnectMessage, ConnectResponse, ControlMessage}; +use camino::Utf8PathBuf; +use ipnetwork::Ipv4Network; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use uuid::Uuid; + +use super::cert::CaManager; +use super::listener::AgentTunnelListener; + +const ALPN_PROTOCOL: &[u8] = b"devolutions-agent-tunnel"; +const MAX_DATAGRAM_SIZE: usize = 1350; + +/// Start a TCP echo server that echoes back whatever it receives. +/// Returns the server address and a join handle. +async fn start_echo_server() -> (SocketAddr, tokio::task::JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let handle = tokio::spawn(async move { + loop { + let (mut stream, _) = match listener.accept().await { + Ok(v) => v, + Err(_) => break, + }; + + tokio::spawn(async move { + let mut buf = vec![0u8; 65535]; + loop { + let n = match stream.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if stream.write_all(&buf[..n]).await.is_err() { + break; + } + } + }); + } + }); + + (addr, handle) +} + +/// Drive the quiche connection: send all pending data over UDP. +async fn flush_quiche(conn: &mut quiche::Connection, socket: &UdpSocket, peer_addr: SocketAddr) { + let mut buf = vec![0u8; MAX_DATAGRAM_SIZE]; + loop { + match conn.send(&mut buf) { + Ok((len, send_info)) => { + let _ = socket.send_to(&buf[..len], send_info.to).await; + } + Err(quiche::Error::Done) => break, + Err(e) => { + eprintln!("quiche send error: {e}"); + break; + } + } + } + let _ = peer_addr; // Used for clarity in caller. +} + +/// Receive UDP data and feed it to the quiche connection. +async fn recv_quiche(conn: &mut quiche::Connection, socket: &UdpSocket, timeout: Duration) -> bool { + let mut buf = vec![0u8; 65535]; + + let result = tokio::time::timeout(timeout, socket.recv_from(&mut buf)).await; + match result { + Ok(Ok((len, from))) => { + let local = socket.local_addr().unwrap(); + let recv_info = quiche::RecvInfo { from, to: local }; + match conn.recv(&mut buf[..len], recv_info) { + Ok(_) => true, + Err(e) => { + eprintln!("quiche recv error: {e}"); + false + } + } + } + Ok(Err(e)) => { + eprintln!("UDP recv error: {e}"); + false + } + Err(_) => false, // timeout + } +} + +/// Drive the QUIC handshake to completion. +async fn complete_handshake(conn: &mut quiche::Connection, socket: &UdpSocket, peer_addr: SocketAddr) { + for _ in 0..50 { + flush_quiche(conn, socket, peer_addr).await; + if conn.is_established() { + return; + } + recv_quiche(conn, socket, Duration::from_millis(500)).await; + flush_quiche(conn, socket, peer_addr).await; + } + panic!("QUIC handshake did not complete in time"); +} + +/// Send a length-prefixed bincode message on a QUIC stream. +fn send_message(conn: &mut quiche::Connection, stream_id: u64, msg: &T) { + let payload = bincode::serialize(msg).unwrap(); + let len = (payload.len() as u32).to_be_bytes(); + let mut data = Vec::with_capacity(4 + payload.len()); + data.extend_from_slice(&len); + data.extend_from_slice(&payload); + conn.stream_send(stream_id, &data, false).unwrap(); +} + +/// Try to read a length-prefixed bincode message from accumulated stream data. +fn try_decode_message(buf: &[u8]) -> Option<(T, usize)> { + if buf.len() < 4 { + return None; + } + let msg_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; + if buf.len() < 4 + msg_len { + return None; + } + let msg: T = bincode::deserialize(&buf[4..4 + msg_len]).ok()?; + Some((msg, 4 + msg_len)) +} + +/// Full E2E integration test. +/// +/// 1. Start TCP echo server +/// 2. Start QUIC listener (gateway) +/// 3. Connect a simulated agent (quiche client) with mTLS +/// 4. Agent sends RouteAdvertise +/// 5. Gateway opens a proxy stream via connect_via_agent +/// 6. Agent reads ConnectMessage, connects to echo server, sends ConnectResponse::Success +/// 7. Gateway writes data through QuicStream +/// 8. Verify echo response arrives back through the tunnel +#[tokio::test] +async fn quic_agent_tunnel_e2e() { + // ── 1. Setup certificates ────────────────────────────────────────────── + let temp_dir = std::env::temp_dir().join(format!("dgw-e2e-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("UTF-8 temp path"); + + let ca_manager = Arc::new(CaManager::load_or_generate(&data_dir).expect("CA generation should succeed")); + + let agent_id = Uuid::new_v4(); + let cert_bundle = ca_manager + .issue_agent_certificate(agent_id, "test-agent") + .expect("issue agent cert"); + + // Write agent certs to temp files (quiche needs file paths). + let agent_cert_path = data_dir.join("agent-cert.pem"); + let agent_key_path = data_dir.join("agent-key.pem"); + let ca_cert_path = ca_manager.ca_cert_path(); + + std::fs::write(agent_cert_path.as_str(), &cert_bundle.client_cert_pem).unwrap(); + std::fs::write(agent_key_path.as_str(), &cert_bundle.client_key_pem).unwrap(); + + // ── 2. Start TCP echo server ─────────────────────────────────────────── + let (echo_addr, _echo_handle) = start_echo_server().await; + let echo_subnet: Ipv4Network = format!("{}/32", echo_addr.ip()).parse().unwrap(); + + // ── 3. Start QUIC listener ───────────────────────────────────────────── + // Bind a temporary UDP socket to find a free port, then release it. + let temp_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server_port = temp_socket.local_addr().unwrap().port(); + drop(temp_socket); + + let server_addr: SocketAddr = format!("127.0.0.1:{server_port}").parse().unwrap(); + + let (listener, handle) = AgentTunnelListener::bind(server_addr, Arc::clone(&ca_manager), "localhost") + .await + .expect("bind QUIC listener to known port"); + + // Spawn the listener as a background task. + let (shutdown_handle, shutdown_signal) = devolutions_gateway_task::ShutdownHandle::new(); + let listener_task = tokio::spawn(async move { + use devolutions_gateway_task::Task; + let _ = listener.run(shutdown_signal).await; + }); + + // Give the listener a moment to be ready. + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 4. Create simulated agent (quiche client) ────────────────────────── + let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let client_local = client_socket.local_addr().unwrap(); + + let mut client_config = quiche::Config::new(quiche::PROTOCOL_VERSION).expect("quiche config"); + client_config + .load_cert_chain_from_pem_file(agent_cert_path.as_str()) + .expect("load agent cert"); + client_config + .load_priv_key_from_pem_file(agent_key_path.as_str()) + .expect("load agent key"); + client_config + .load_verify_locations_from_file(ca_cert_path.as_str()) + .expect("load CA cert"); + client_config.verify_peer(true); + client_config + .set_application_protos(&[ALPN_PROTOCOL]) + .expect("set ALPN"); + client_config.set_max_idle_timeout(30_000); + client_config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_initial_max_data(10_000_000); + client_config.set_initial_max_stream_data_bidi_local(1_000_000); + client_config.set_initial_max_stream_data_bidi_remote(1_000_000); + client_config.set_initial_max_streams_bidi(100); + + let mut scid = vec![0u8; quiche::MAX_CONN_ID_LEN]; + rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut scid); + let scid = quiche::ConnectionId::from_vec(scid); + + let mut conn = quiche::connect(Some("localhost"), &scid, client_local, server_addr, &mut client_config) + .expect("quiche connect"); + + // ── 5. Complete mTLS handshake ───────────────────────────────────────── + complete_handshake(&mut conn, &client_socket, server_addr).await; + assert!(conn.is_established(), "QUIC connection should be established"); + + // ── 6. Send RouteAdvertise ───────────────────────────────────────────── + let route_msg = ControlMessage::route_advertise(1, vec![echo_subnet], vec![]); + send_message(&mut conn, 0, &route_msg); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Give the gateway a moment to process the route advertisement. + tokio::time::sleep(Duration::from_millis(200)).await; + // Drain any responses. + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Verify the agent is registered in the registry. + assert!( + handle.registry().get(&agent_id).is_some(), + "agent should be registered in the registry" + ); + assert_eq!(handle.registry().online_count(), 1); + + // ── 7. Gateway opens proxy stream via connect_via_agent ──────────────── + let session_id = Uuid::new_v4(); + let target_str = format!("{}", echo_addr); + + // Spawn connect_via_agent as a background task (it will block until the agent responds). + let handle_clone = handle.clone(); + let target_str_clone = target_str.clone(); + let proxy_task = tokio::spawn(async move { + handle_clone + .connect_via_agent(agent_id, session_id, &target_str_clone) + .await + }); + + // Give the gateway time to send the ConnectMessage. + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 8. Agent receives and processes proxy request ────────────────────── + // The agent needs to: + // a. Receive ConnectMessage on a new server-initiated stream + // b. Connect to the target + // c. Send ConnectResponse::Success + + // Pump the connection to receive the ConnectMessage. + let mut stream_buf: Vec = Vec::new(); + let mut proxy_stream_id: Option = None; + + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Check for readable streams (skip stream 0 which is control). + for stream_id in conn.readable() { + if stream_id == 0 { + // Drain control stream responses. + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + continue; + } + + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + stream_buf.extend_from_slice(&buf[..len]); + proxy_stream_id = Some(stream_id); + } + } + + if proxy_stream_id.is_some() && stream_buf.len() >= 4 { + let msg_len_check = + u32::from_be_bytes([stream_buf[0], stream_buf[1], stream_buf[2], stream_buf[3]]) as usize; + if stream_buf.len() >= 4 + msg_len_check { + break; + } + } + } + + let proxy_stream_id = proxy_stream_id.expect("should have received a proxy stream from gateway"); + + // Decode ConnectMessage. + let (connect_msg, consumed): (ConnectMessage, usize) = + try_decode_message(&stream_buf).expect("decode ConnectMessage"); + assert_eq!(connect_msg.session_id, session_id); + assert_eq!(connect_msg.target, target_str); + stream_buf.drain(..consumed); + + // Connect to the echo server. + let mut target_tcp = TcpStream::connect(echo_addr).await.expect("connect to echo server"); + + // Send ConnectResponse::Success. + let response = ConnectResponse::success(); + send_message(&mut conn, proxy_stream_id, &response); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Give the gateway time to process the response. + tokio::time::sleep(Duration::from_millis(200)).await; + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // ── 9. Verify proxy_task completed successfully ──────────────────────── + let quic_stream = tokio::time::timeout(Duration::from_secs(5), proxy_task) + .await + .expect("proxy task should complete in time") + .expect("proxy task should not panic") + .expect("connect_via_agent should succeed"); + + // ── 10. Bidirectional data test through the full tunnel ──────────────── + // Gateway writes to QuicStream → QUIC → Agent → TCP → Echo Server → TCP → Agent → QUIC → Gateway reads + + let test_data = b"Hello from the QUIC tunnel integration test!"; + let (mut quic_read, mut quic_write) = tokio::io::split(quic_stream); + + // Write test data from the "gateway side" into the QuicStream. + quic_write.write_all(test_data).await.expect("write to QuicStream"); + + // Agent side: relay data from QUIC stream to TCP target and back. + // We need to pump the QUIC connection to deliver the data. + + // Read data from QUIC and forward to TCP target. + let mut data_from_quic = Vec::new(); + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + for stream_id in conn.readable() { + if stream_id == proxy_stream_id { + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + data_from_quic.extend_from_slice(&buf[..len]); + } + } else { + // Drain other streams. + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + } + } + + if data_from_quic.len() >= test_data.len() { + break; + } + } + + assert_eq!( + &data_from_quic[..test_data.len()], + test_data, + "data should arrive at the agent side" + ); + + // Forward to echo server. + target_tcp + .write_all(&data_from_quic[..test_data.len()]) + .await + .expect("write to echo server"); + + // Read echo response from TCP. + let mut echo_response = vec![0u8; test_data.len()]; + target_tcp + .read_exact(&mut echo_response) + .await + .expect("read echo response"); + assert_eq!(&echo_response, test_data); + + // Send echo response back through QUIC. + conn.stream_send(proxy_stream_id, &echo_response, false) + .expect("send echo response on QUIC stream"); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // Give gateway time to deliver data through channels. + tokio::time::sleep(Duration::from_millis(200)).await; + + // Read the response from the gateway-side QuicStream. + let mut response_buf = vec![0u8; test_data.len()]; + let read_result = tokio::time::timeout(Duration::from_secs(5), quic_read.read_exact(&mut response_buf)) + .await + .expect("should read response in time") + .expect("read from QuicStream"); + + assert_eq!(read_result, test_data.len()); + assert_eq!(&response_buf, test_data, "echo response should match original data"); + + // ── 11. Cleanup ──────────────────────────────────────────────────────── + shutdown_handle.signal(); + let _ = tokio::time::timeout(Duration::from_secs(2), listener_task).await; + let _ = std::fs::remove_dir_all(&temp_dir); + + eprintln!("E2E integration test passed!"); +} + +/// E2E test for domain-based routing. +/// +/// Same as quic_agent_tunnel_e2e but agent advertises domain "test.local" +/// alongside its subnet, and we verify domain routing works in the live registry. +/// +/// Known limitation: uses IP for final connect_via_agent (no mock DNS in test env). +#[tokio::test] +async fn quic_agent_tunnel_domain_routing_e2e() { + use agent_tunnel_proto::DomainAdvertisement; + + // ── 1. Setup certificates ── + let temp_dir = std::env::temp_dir().join(format!("dgw-domain-e2e-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("UTF-8 temp path"); + + let ca_manager = Arc::new(CaManager::load_or_generate(&data_dir).expect("CA generation")); + + let agent_id = Uuid::new_v4(); + let cert_bundle = ca_manager + .issue_agent_certificate(agent_id, "test-agent") + .expect("issue agent cert"); + + let agent_cert_path = data_dir.join("agent-cert.pem"); + let agent_key_path = data_dir.join("agent-key.pem"); + let ca_cert_path = ca_manager.ca_cert_path(); + + std::fs::write(agent_cert_path.as_str(), &cert_bundle.client_cert_pem).unwrap(); + std::fs::write(agent_key_path.as_str(), &cert_bundle.client_key_pem).unwrap(); + + // ── 2. Start echo server and QUIC listener ── + let (echo_addr, _echo_handle) = start_echo_server().await; + let echo_subnet: Ipv4Network = format!("{}/32", echo_addr.ip()).parse().unwrap(); + + let temp_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let server_port = temp_socket.local_addr().unwrap().port(); + drop(temp_socket); + + let server_addr: SocketAddr = format!("127.0.0.1:{server_port}").parse().unwrap(); + let (listener, handle) = AgentTunnelListener::bind(server_addr, Arc::clone(&ca_manager), "localhost") + .await + .expect("bind QUIC listener"); + + let (shutdown_handle, shutdown_signal) = devolutions_gateway_task::ShutdownHandle::new(); + let listener_task = tokio::spawn(async move { + use devolutions_gateway_task::Task; + let _ = listener.run(shutdown_signal).await; + }); + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 3. Create simulated agent ── + let client_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let client_local = client_socket.local_addr().unwrap(); + + let mut client_config = quiche::Config::new(quiche::PROTOCOL_VERSION).expect("quiche config"); + client_config + .load_cert_chain_from_pem_file(agent_cert_path.as_str()) + .unwrap(); + client_config + .load_priv_key_from_pem_file(agent_key_path.as_str()) + .unwrap(); + client_config + .load_verify_locations_from_file(ca_cert_path.as_str()) + .unwrap(); + client_config.verify_peer(true); + client_config.set_application_protos(&[ALPN_PROTOCOL]).unwrap(); + client_config.set_max_idle_timeout(30_000); + client_config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); + client_config.set_initial_max_data(10_000_000); + client_config.set_initial_max_stream_data_bidi_local(1_000_000); + client_config.set_initial_max_stream_data_bidi_remote(1_000_000); + client_config.set_initial_max_streams_bidi(100); + + let mut scid = vec![0u8; quiche::MAX_CONN_ID_LEN]; + rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut scid); + let scid = quiche::ConnectionId::from_vec(scid); + let mut conn = quiche::connect(Some("localhost"), &scid, client_local, server_addr, &mut client_config).unwrap(); + + complete_handshake(&mut conn, &client_socket, server_addr).await; + assert!(conn.is_established()); + + // ── 4. Agent sends RouteAdvertise WITH DOMAIN ── + let route_msg = ControlMessage::route_advertise( + 1, + vec![echo_subnet], + vec![DomainAdvertisement { + domain: "test.local".to_owned(), + auto_detected: false, + }], + ); + send_message(&mut conn, 0, &route_msg); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + tokio::time::sleep(Duration::from_millis(200)).await; + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + // ── 5. Verify domain routing via registry ── + assert!(handle.registry().get(&agent_id).is_some(), "agent should be registered"); + + let domain_agents = handle.registry().select_agents_for_domain("echo-server.test.local"); + assert_eq!(domain_agents.len(), 1, "domain routing should find the agent"); + assert_eq!(domain_agents[0].agent_id, agent_id); + + // Also verify the domain info is preserved with source tracking + let info = handle.registry().agent_info(&agent_id).expect("agent info"); + assert_eq!(info.domains.len(), 1); + assert_eq!(info.domains[0].domain, "test.local"); + assert!(!info.domains[0].auto_detected); + + // ── 6. Gateway opens proxy stream (using IP — known limitation) ── + let session_id = Uuid::new_v4(); + let target_str = format!("{}", echo_addr); + + let handle_clone = handle.clone(); + let target_clone = target_str.clone(); + let proxy_task = tokio::spawn(async move { + handle_clone + .connect_via_agent(agent_id, session_id, &target_clone) + .await + }); + tokio::time::sleep(Duration::from_millis(100)).await; + + // ── 7. Agent receives ConnectMessage ── + let mut stream_buf: Vec = Vec::new(); + let mut proxy_stream_id: Option = None; + + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + for stream_id in conn.readable() { + if stream_id == 0 { + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + continue; + } + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + stream_buf.extend_from_slice(&buf[..len]); + proxy_stream_id = Some(stream_id); + } + } + + if proxy_stream_id.is_some() && stream_buf.len() >= 4 { + let msg_len = u32::from_be_bytes([stream_buf[0], stream_buf[1], stream_buf[2], stream_buf[3]]) as usize; + if stream_buf.len() >= 4 + msg_len { + break; + } + } + } + + let proxy_stream_id = proxy_stream_id.expect("should have received proxy stream"); + let (connect_msg, consumed): (ConnectMessage, usize) = + try_decode_message(&stream_buf).expect("decode ConnectMessage"); + assert_eq!(connect_msg.session_id, session_id); + stream_buf.drain(..consumed); + + // ── 8. Agent connects to echo server and responds ── + let mut target_tcp = TcpStream::connect(echo_addr).await.expect("connect to echo server"); + let response = ConnectResponse::success(); + send_message(&mut conn, proxy_stream_id, &response); + flush_quiche(&mut conn, &client_socket, server_addr).await; + + tokio::time::sleep(Duration::from_millis(200)).await; + recv_quiche(&mut conn, &client_socket, Duration::from_millis(100)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + + let quic_stream = tokio::time::timeout(Duration::from_secs(5), proxy_task) + .await + .unwrap() + .unwrap() + .expect("connect_via_agent should succeed"); + + // ── 9. Bidirectional echo test ── + let test_data = b"Domain routing works!"; + let (mut quic_read, mut quic_write) = tokio::io::split(quic_stream); + quic_write.write_all(test_data).await.expect("write to QuicStream"); + + let mut data_from_quic = Vec::new(); + for _ in 0..20 { + recv_quiche(&mut conn, &client_socket, Duration::from_millis(200)).await; + flush_quiche(&mut conn, &client_socket, server_addr).await; + for stream_id in conn.readable() { + if stream_id == proxy_stream_id { + let mut buf = vec![0u8; 65535]; + if let Ok((len, _fin)) = conn.stream_recv(stream_id, &mut buf) { + data_from_quic.extend_from_slice(&buf[..len]); + } + } else { + let mut discard = vec![0u8; 65535]; + let _ = conn.stream_recv(stream_id, &mut discard); + } + } + if data_from_quic.len() >= test_data.len() { + break; + } + } + + assert_eq!(&data_from_quic[..test_data.len()], test_data); + + target_tcp.write_all(&data_from_quic[..test_data.len()]).await.unwrap(); + let mut echo_response = vec![0u8; test_data.len()]; + target_tcp.read_exact(&mut echo_response).await.unwrap(); + assert_eq!(&echo_response, test_data); + + conn.stream_send(proxy_stream_id, &echo_response, false).unwrap(); + flush_quiche(&mut conn, &client_socket, server_addr).await; + tokio::time::sleep(Duration::from_millis(200)).await; + + let mut response_buf = vec![0u8; test_data.len()]; + let read_result = tokio::time::timeout(Duration::from_secs(5), quic_read.read_exact(&mut response_buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(read_result, test_data.len()); + assert_eq!(&response_buf, test_data); + + // ── Cleanup ── + shutdown_handle.signal(); + let _ = tokio::time::timeout(Duration::from_secs(2), listener_task).await; + let _ = std::fs::remove_dir_all(&temp_dir); + + eprintln!("Domain routing E2E test passed!"); +} diff --git a/devolutions-gateway/src/agent_tunnel/mod.rs b/devolutions-gateway/src/agent_tunnel/mod.rs index aa4b094eb..950124a93 100644 --- a/devolutions-gateway/src/agent_tunnel/mod.rs +++ b/devolutions-gateway/src/agent_tunnel/mod.rs @@ -7,8 +7,13 @@ pub mod cert; pub mod enrollment_store; pub mod listener; pub mod registry; +pub mod routing; pub mod stream; +// Integration test needs rewriting for Quinn — kept as local-only file. +// #[cfg(test)] +// mod integration_test; + pub use enrollment_store::EnrollmentTokenStore; pub use listener::{AgentTunnelHandle, AgentTunnelListener}; pub use registry::AgentRegistry; diff --git a/devolutions-gateway/src/agent_tunnel/routing.rs b/devolutions-gateway/src/agent_tunnel/routing.rs new file mode 100644 index 000000000..2c13a0bc7 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/routing.rs @@ -0,0 +1,287 @@ +//! Shared routing pipeline for agent tunnel. +//! +//! Used by both connection forwarding (`fwd.rs`) and KDC proxy (`kdc_proxy.rs`) +//! to ensure consistent routing behavior and error messages. + +use std::net::IpAddr; +use std::sync::Arc; + +use anyhow::{Result, anyhow}; +use uuid::Uuid; + +use super::listener::AgentTunnelHandle; +use super::registry::{AgentPeer, AgentRegistry}; +use super::stream::TunnelStream; + +/// Result of the routing pipeline. +/// +/// Each variant carries enough context for the caller to produce an actionable error message. +#[derive(Debug)] +pub enum RoutingDecision { + /// Route through these agent candidates (try in order, first success wins). + ViaAgent(Vec>), + /// Explicit agent_id was specified but not found in registry. + ExplicitAgentNotFound(Uuid), + /// No agent matched — caller should attempt direct connection. + Direct, +} + +/// Determines how to route a connection to the given target. +/// +/// Pipeline (in order of priority): +/// 1. Explicit agent_id (from JWT) → route to that agent +/// 2. IP target → subnet match against agent advertisements +/// 3. Hostname target → domain suffix match (longest wins) +/// 4. No match → direct connection +pub fn resolve_route(registry: &AgentRegistry, explicit_agent_id: Option, target_host: &str) -> RoutingDecision { + // Step 1: Explicit agent ID (from JWT) + if let Some(agent_id) = explicit_agent_id { + if let Some(agent) = registry.get(&agent_id) { + return RoutingDecision::ViaAgent(vec![agent]); + } + return RoutingDecision::ExplicitAgentNotFound(agent_id); + } + + // Step 2: Target is an IP address → subnet match + if let Ok(ip) = target_host.parse::() { + let agents = registry.find_agents_for_target(ip); + if !agents.is_empty() { + return RoutingDecision::ViaAgent(agents); + } + return RoutingDecision::Direct; + } + + // Step 3: Target is a hostname → domain suffix match (longest wins) + let agents = registry.select_agents_for_domain(target_host); + if !agents.is_empty() { + return RoutingDecision::ViaAgent(agents); + } + + // Step 4: No match → direct connect + RoutingDecision::Direct +} + +/// Try connecting to target through agent candidates (try-fail-retry). +/// +/// Returns the connected `TunnelStream` and the agent that succeeded. +/// +/// Callers must handle `RoutingDecision::ExplicitAgentNotFound` and +/// `RoutingDecision::Direct` before calling this function. +pub async fn route_and_connect( + handle: &AgentTunnelHandle, + candidates: &[Arc], + session_id: Uuid, + target: &str, +) -> Result<(TunnelStream, Arc)> { + assert!(!candidates.is_empty(), "route_and_connect called with empty candidates"); + + let mut last_error = None; + + for agent in candidates { + info!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + %target, + "Routing via agent tunnel" + ); + + match handle.connect_via_agent(agent.agent_id, session_id, target).await { + Ok(stream) => { + info!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + %target, + "Agent tunnel connection established" + ); + return Ok((stream, Arc::clone(agent))); + } + Err(error) => { + warn!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + %target, + error = format!("{error:#}"), + "Agent tunnel connection failed, trying next candidate" + ); + last_error = Some(error); + } + } + } + + let agent_names: Vec<&str> = candidates.iter().map(|a| a.name.as_str()).collect(); + let last_err_msg = last_error.as_ref().map(|e| format!("{e:#}")).unwrap_or_default(); + + error!( + agent_count = candidates.len(), + %target, + agents = ?agent_names, + last_error = %last_err_msg, + "All agent tunnel candidates failed" + ); + + Err(last_error.unwrap_or_else(|| { + anyhow!( + "All {} agents matching target '{}' failed to connect. Agents tried: [{}]", + candidates.len(), + target, + agent_names.join(", "), + ) + })) +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::Ordering; + + use agent_tunnel_proto::DomainAdvertisement; + + use super::*; + use crate::agent_tunnel::registry::AgentPeer; + + fn make_peer(name: &str) -> Arc { + Arc::new(AgentPeer::new( + Uuid::new_v4(), + name.to_owned(), + "sha256:test".to_owned(), + )) + } + + fn domain(name: &str) -> DomainAdvertisement { + DomainAdvertisement { + domain: name.to_owned(), + auto_detected: false, + } + } + + #[test] + fn route_explicit_agent_id() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + registry.register(Arc::clone(&peer)); + + match resolve_route(®istry, Some(agent_id), "anything") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, agent_id); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } + + #[test] + fn route_explicit_agent_id_not_found() { + let registry = AgentRegistry::new(); + let bogus_id = Uuid::new_v4(); + + match resolve_route(®istry, Some(bogus_id), "anything") { + RoutingDecision::ExplicitAgentNotFound(id) => { + assert_eq!(id, bogus_id); + } + other => panic!("expected ExplicitAgentNotFound, got {other:?}"), + } + } + + #[test] + fn route_ip_target_via_subnet() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![]); + registry.register(peer); + + match resolve_route(®istry, None, "10.1.5.50") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents[0].agent_id, agent_id); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } + + #[test] + fn route_hostname_via_domain() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local")]); + registry.register(peer); + + match resolve_route(®istry, None, "dc01.contoso.local") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents[0].agent_id, agent_id); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } + + #[test] + fn route_no_match_returns_direct() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local")]); + registry.register(peer); + + assert!(matches!( + resolve_route(®istry, None, "external.example.com"), + RoutingDecision::Direct + )); + } + + #[test] + fn route_ip_no_match_returns_direct() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![]); + registry.register(peer); + + assert!(matches!( + resolve_route(®istry, None, "192.168.1.1"), + RoutingDecision::Direct + )); + } + + #[test] + fn route_skips_offline_agents() { + let registry = AgentRegistry::new(); + let peer = make_peer("offline-agent"); + let subnet: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local")]); + peer.last_seen.store(0, Ordering::Release); + registry.register(peer); + + assert!(matches!( + resolve_route(®istry, None, "dc01.contoso.local"), + RoutingDecision::Direct + )); + } + + #[test] + fn route_domain_match_returns_multiple_agents_ordered() { + let registry = AgentRegistry::new(); + + let peer_a = make_peer("agent-a"); + let subnet_a: ipnetwork::Ipv4Network = "10.1.0.0/16".parse().expect("valid test subnet"); + peer_a.update_routes(1, vec![subnet_a], vec![domain("contoso.local")]); + registry.register(Arc::clone(&peer_a)); + + std::thread::sleep(std::time::Duration::from_millis(10)); + + let peer_b = make_peer("agent-b"); + let id_b = peer_b.agent_id; + let subnet_b: ipnetwork::Ipv4Network = "10.2.0.0/16".parse().expect("valid test subnet"); + peer_b.update_routes(1, vec![subnet_b], vec![domain("contoso.local")]); + registry.register(Arc::clone(&peer_b)); + + match resolve_route(®istry, None, "dc01.contoso.local") { + RoutingDecision::ViaAgent(agents) => { + assert_eq!(agents.len(), 2); + assert_eq!(agents[0].agent_id, id_b, "most recent first"); + } + other => panic!("expected ViaAgent, got {other:?}"), + } + } +} diff --git a/devolutions-gateway/src/api/fwd.rs b/devolutions-gateway/src/api/fwd.rs index f0b1701d6..673fe7aff 100644 --- a/devolutions-gateway/src/api/fwd.rs +++ b/devolutions-gateway/src/api/fwd.rs @@ -54,6 +54,7 @@ async fn fwd_tcp( sessions, subscriber_tx, shutdown_signal, + agent_tunnel_handle, .. }): State, AssociationToken(claims): AssociationToken, @@ -78,6 +79,7 @@ async fn fwd_tcp( claims, source_addr, false, + agent_tunnel_handle, ) .instrument(span) }); @@ -91,6 +93,7 @@ async fn fwd_tls( sessions, subscriber_tx, shutdown_signal, + agent_tunnel_handle, .. }): State, AssociationToken(claims): AssociationToken, @@ -115,6 +118,7 @@ async fn fwd_tls( claims, source_addr, true, + agent_tunnel_handle, ) .instrument(span) }); @@ -132,6 +136,7 @@ async fn handle_fwd( claims: AssociationTokenClaims, source_addr: SocketAddr, with_tls: bool, + agent_tunnel_handle: Option>, ) { let (stream, close_handle) = crate::ws::handle( ws, @@ -154,6 +159,7 @@ async fn handle_fwd( .sessions(sessions) .subscriber_tx(subscriber_tx) .with_tls(with_tls) + .agent_tunnel_handle(agent_tunnel_handle) .build() .run() .instrument(span.clone()) @@ -184,6 +190,8 @@ struct Forward { sessions: SessionMessageSender, subscriber_tx: SubscriberSender, with_tls: bool, + #[builder(default)] + agent_tunnel_handle: Option>, } #[derive(Debug, thiserror::Error)] @@ -207,6 +215,7 @@ where sessions, subscriber_tx, with_tls, + agent_tunnel_handle, } = self; match claims.jet_rec { @@ -224,6 +233,73 @@ where let span = tracing::Span::current(); + // Route via agent tunnel using the transparent routing pipeline: + // explicit agent_id → subnet match → domain suffix match → direct connect + if let Some(handle) = &agent_tunnel_handle { + use crate::agent_tunnel::routing::{self, RoutingDecision}; + + let first_target = targets.first(); + let target_host_for_routing = first_target.host().to_owned(); + + let decision = routing::resolve_route(handle.registry(), claims.jet_agent_id, &target_host_for_routing); + + match &decision { + RoutingDecision::ExplicitAgentNotFound(id) => { + error!(agent_id = %id, "Explicit agent not found in registry"); + return Err(ForwardError::BadGateway(anyhow::anyhow!( + "Agent {id} specified in token not found in registry. \ + Verify the agent is enrolled and connected." + ))); + } + RoutingDecision::Direct => { + info!(%target_host_for_routing, "No agent match, using direct connect"); + } + RoutingDecision::ViaAgent(_) => {} + } + + if let RoutingDecision::ViaAgent(candidates) = decision { + let target_str = format!("{}:{}", first_target.host(), first_target.port()); + + let (server_stream, _matched_agent) = + routing::route_and_connect(handle, &candidates, claims.jet_aid, &target_str) + .await + .map_err(ForwardError::BadGateway)?; + + let selected_target = first_target.clone(); + span.record("target", selected_target.to_string()); + + let info = SessionInfo::builder() + .id(claims.jet_aid) + .application_protocol(claims.jet_ap) + .details(ConnectionModeDetails::Fwd { + destination_host: selected_target, + }) + .time_to_live(claims.jet_ttl) + .recording_policy(claims.jet_rec) + .filtering_policy(claims.jet_flt) + .build(); + + let server_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder"); + + return Proxy::builder() + .conf(conf) + .session_info(info) + .address_a(client_addr) + .transport_a(client_stream) + .address_b(server_addr) + .transport_b(server_stream) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) + .build() + .select_dissector_and_forward() + .await + .context("encountered a failure during agent tunnel traffic proxying") + .map_err(ForwardError::Internal); + } + // RoutingDecision::Direct falls through to direct connect below + } + trace!("Select and connect to target"); let ((server_stream, server_addr), selected_target) = utils::successive_try(&targets, utils::tcp_connect) diff --git a/devolutions-gateway/src/api/kdc_proxy.rs b/devolutions-gateway/src/api/kdc_proxy.rs index cf2d0243a..51673d8d8 100644 --- a/devolutions-gateway/src/api/kdc_proxy.rs +++ b/devolutions-gateway/src/api/kdc_proxy.rs @@ -25,6 +25,7 @@ async fn kdc_proxy( token_cache, jrl, recordings, + agent_tunnel_handle, .. }): State, extract::Path(token): extract::Path, @@ -105,7 +106,12 @@ async fn kdc_proxy( &claims.krb_kdc }; - let kdc_reply_message = send_krb_message(kdc_addr, &kdc_proxy_message.kerb_message.0.0).await?; + let kdc_reply_message = send_krb_message( + kdc_addr, + &kdc_proxy_message.kerb_message.0.0, + agent_tunnel_handle.as_deref(), + ) + .await?; let kdc_reply_message = KdcProxyMessage::from_raw_kerb_message(&kdc_reply_message) .map_err(HttpError::internal().with_msg("couldn't create KDC proxy reply").err())?; @@ -115,11 +121,11 @@ async fn kdc_proxy( kdc_reply_message.to_vec().map_err(HttpError::internal().err()) } -async fn read_kdc_reply_message(connection: &mut TcpStream) -> io::Result> { - let len = connection.read_u32().await?; +async fn read_kdc_reply_message(reader: &mut R) -> io::Result> { + let len = reader.read_u32().await?; let mut buf = vec![0; (len + 4).try_into().expect("u32-to-usize")]; buf[0..4].copy_from_slice(&(len.to_be_bytes())); - connection.read_exact(&mut buf[4..]).await?; + reader.read_exact(&mut buf[4..]).await?; Ok(buf) } @@ -148,7 +154,63 @@ fn unable_to_reach_kdc_server_err(error: io::Error) -> HttpError { } /// Sends the Kerberos message to the specified KDC address. -pub async fn send_krb_message(kdc_addr: &TargetAddr, message: &[u8]) -> Result, HttpError> { +/// +/// Uses the same routing pipeline as connection forwarding: +/// if an agent claims the KDC's domain/subnet, traffic goes through the tunnel. +/// Falls back to direct connect only when no agent matches. +pub async fn send_krb_message( + kdc_addr: &TargetAddr, + message: &[u8], + agent_tunnel_handle: Option<&crate::agent_tunnel::AgentTunnelHandle>, +) -> Result, HttpError> { + // Route through agent tunnel using the SAME pipeline as connection forwarding. + if let Some(handle) = agent_tunnel_handle { + use crate::agent_tunnel::routing::{self, RoutingDecision}; + + let kdc_host = kdc_addr.host(); + let kdc_target = kdc_addr.to_string(); + + let decision = routing::resolve_route(handle.registry(), None, kdc_host); + + match &decision { + RoutingDecision::ExplicitAgentNotFound(id) => { + error!(agent_id = %id, "Explicit agent for KDC not found"); + return Err( + HttpError::bad_gateway().build(format!("Agent {id} specified for KDC not found in registry.")) + ); + } + RoutingDecision::Direct => { + info!(kdc_host = %kdc_host, "No agent match for KDC, using direct connect"); + } + RoutingDecision::ViaAgent(_) => {} + } + + // Hard commit: if an agent matched, KDC traffic MUST go through it. + // No silent fallback — consistent with connection forwarding. + if let RoutingDecision::ViaAgent(candidates) = decision { + let session_id = uuid::Uuid::new_v4(); + + let (mut stream, _agent) = routing::route_and_connect(handle, &candidates, session_id, &kdc_target) + .await + .map_err(|e| { + HttpError::bad_gateway().build(format!("KDC routing through agent tunnel failed: {e:#}")) + })?; + + stream.write_all(message).await.map_err( + HttpError::bad_gateway() + .with_msg("unable to send KDC message through agent tunnel") + .err(), + )?; + + return read_kdc_reply_message(&mut stream).await.map_err( + HttpError::bad_gateway() + .with_msg("unable to read KDC reply through agent tunnel") + .err(), + ); + } + // RoutingDecision::Direct falls through to direct connect below + } + let protocol = kdc_addr.scheme(); debug!("Connecting to KDC server located at {kdc_addr} using protocol {protocol}..."); diff --git a/devolutions-gateway/src/api/rdp.rs b/devolutions-gateway/src/api/rdp.rs index 65cfe5b2e..de25b1bad 100644 --- a/devolutions-gateway/src/api/rdp.rs +++ b/devolutions-gateway/src/api/rdp.rs @@ -26,6 +26,7 @@ pub async fn handler( recordings, shutdown_signal, credential_store, + agent_tunnel_handle, .. }): State, ConnectInfo(source_addr): ConnectInfo, @@ -46,6 +47,7 @@ pub async fn handler( recordings.active_recordings, source_addr, credential_store, + agent_tunnel_handle, ) .instrument(span) }); @@ -65,6 +67,7 @@ async fn handle_socket( active_recordings: Arc, source_addr: SocketAddr, credential_store: crate::credential::CredentialStoreHandle, + agent_tunnel_handle: Option>, ) { let (stream, close_handle) = crate::ws::handle( ws, @@ -82,6 +85,7 @@ async fn handle_socket( subscriber_tx, &active_recordings, &credential_store, + agent_tunnel_handle, ) .await; diff --git a/devolutions-gateway/src/proxy.rs b/devolutions-gateway/src/proxy.rs index 0c2f09e6a..fb5d1b4c6 100644 --- a/devolutions-gateway/src/proxy.rs +++ b/devolutions-gateway/src/proxy.rs @@ -32,8 +32,8 @@ pub struct Proxy { impl Proxy where - A: AsyncWrite + AsyncRead + Unpin, - B: AsyncWrite + AsyncRead + Unpin, + A: AsyncWrite + AsyncRead + Unpin + Send, + B: AsyncWrite + AsyncRead + Unpin + Send, { pub async fn select_dissector_and_forward(self) -> anyhow::Result<()> { match self.session_info.application_protocol { diff --git a/devolutions-gateway/src/rd_clean_path.rs b/devolutions-gateway/src/rd_clean_path.rs index 9855a8987..b5a4223d1 100644 --- a/devolutions-gateway/src/rd_clean_path.rs +++ b/devolutions-gateway/src/rd_clean_path.rs @@ -158,25 +158,77 @@ enum CleanPathError { Io(#[from] io::Error), } -struct CleanPathResult { +/// Inner transport for the RDP server connection. +/// +/// An enum is required here because `Box` trait objects cause the compiler to +/// lose `Send` provability for the async future spawned by `ws.on_upgrade()` in the +/// RDP handler. Generics are also not viable — the type would propagate up through +/// `handle_with_credential_injection` → `handle` → `handle_socket` → `ws.on_upgrade()`, +/// which requires a concrete future type. The enum gives the compiler full type +/// information to prove `Send` while keeping the transport abstraction local. +enum ServerTransport { + Tcp(tokio::net::TcpStream), + Quic(crate::agent_tunnel::stream::TunnelStream), +} + +impl AsyncRead for ServerTransport { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_read(cx, buf), + Self::Quic(s) => std::pin::Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ServerTransport { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_write(cx, buf), + Self::Quic(s) => std::pin::Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_flush(cx), + Self::Quic(s) => std::pin::Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + Self::Tcp(s) => std::pin::Pin::new(s).poll_shutdown(cx), + Self::Quic(s) => std::pin::Pin::new(s).poll_shutdown(cx), + } + } +} + +struct CleanPathAuth { claims: AssociationTokenClaims, - destination: TargetAddr, - server_addr: SocketAddr, - server_stream: tokio_rustls::client::TlsStream, - x224_rsp: Vec, } -async fn process_cleanpath( - cleanpath_pdu: RDCleanPathPdu, +/// Validate the RDCleanPath PDU token and authorize the session. +/// Pure validation — no connections established. +async fn authorize_cleanpath( + cleanpath_pdu: &RDCleanPathPdu, client_addr: SocketAddr, conf: &Conf, token_cache: &TokenCache, jrl: &CurrentJrl, active_recordings: &ActiveRecordings, sessions: &SessionMessageSender, -) -> Result { - use crate::utils; - +) -> Result { let token = cleanpath_pdu .proxy_auth .as_deref() @@ -207,10 +259,9 @@ async fn process_cleanpath( }; let span = tracing::Span::current(); - span.record("session_id", claims.jet_aid.to_string()); - // Sanity check. + // Sanity check destination in PDU vs token. match cleanpath_pdu.destination.as_deref() { Some(destination) => match TargetAddr::parse(destination, 3389) { Ok(destination) if !destination.eq(targets.first()) => { @@ -224,14 +275,78 @@ async fn process_cleanpath( None => warn!("RDCleanPath PDU is missing the destination field"), } + Ok(CleanPathAuth { claims }) +} + +struct ConnectedRdpServer { + tls_stream: tokio_rustls::client::TlsStream, + server_addr: SocketAddr, + selected_target: TargetAddr, + x224_rsp: Vec, +} + +/// Establish a connection to the RDP server: route (agent/direct) → connect → X224 → TLS. +async fn connect_rdp_server( + claims: &AssociationTokenClaims, + cleanpath_pdu: RDCleanPathPdu, + agent_tunnel_handle: Option<&Arc>, +) -> Result { + use crate::utils; + + let crate::token::ConnectionMode::Fwd { ref targets, .. } = claims.jet_cm else { + return anyhow::Error::msg("unexpected connection mode") + .pipe(CleanPathError::BadRequest) + .pipe(Err); + }; + trace!(?targets, "Connecting to destination server"); - let ((mut server_stream, server_addr), selected_target) = utils::successive_try(targets, utils::tcp_connect) - .await - .context("connect to RDP server")?; + // Route through agent tunnel if available, otherwise connect directly. + let (mut server_stream, server_addr, selected_target): (ServerTransport, SocketAddr, &TargetAddr) = + if let Some(handle) = agent_tunnel_handle { + use crate::agent_tunnel::routing::{self, RoutingDecision}; + + let first_target = targets.first(); + let target_host = first_target.host(); + + let decision = routing::resolve_route(handle.registry(), claims.jet_agent_id, target_host); + + match decision { + RoutingDecision::ExplicitAgentNotFound(id) => { + return Err(CleanPathError::Internal(anyhow::anyhow!( + "Agent {id} specified in token not found in registry" + ))); + } + RoutingDecision::ViaAgent(candidates) => { + let target_str = format!("{}:{}", first_target.host(), first_target.port()); + info!(target = %target_str, "Routing RDP via agent tunnel"); + + let (quic_stream, _agent) = + routing::route_and_connect(handle, &candidates, claims.jet_aid, &target_str) + .await + .context("connect to RDP server via agent tunnel")?; + + // TODO: agent-routed sessions use a placeholder address; monitoring tools + // that rely on server_addr will see 0.0.0.0:0 for tunneled connections. + let placeholder_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder"); + (ServerTransport::Quic(quic_stream), placeholder_addr, first_target) + } + RoutingDecision::Direct => { + let ((stream, addr), target) = utils::successive_try(targets, utils::tcp_connect) + .await + .context("connect to RDP server")?; + (ServerTransport::Tcp(stream), addr, target) + } + } + } else { + let ((stream, addr), target) = utils::successive_try(targets, utils::tcp_connect) + .await + .context("connect to RDP server")?; + (ServerTransport::Tcp(stream), addr, target) + }; debug!(%selected_target, "Connected to destination server"); - span.record("target", selected_target.to_string()); + tracing::Span::current().record("target", selected_target.to_string()); // Send preconnection blob if applicable. if let Some(pcb) = cleanpath_pdu.preconnection_blob { @@ -245,8 +360,6 @@ async fn process_cleanpath( .map_err(CleanPathError::BadRequest)?; server_stream.write_all(x224_req.as_bytes()).await?; - // == Receive server X224 connection response == - trace!("Receiving X224 response"); let x224_rsp = read_x224_response(&mut server_stream) @@ -256,20 +369,17 @@ async fn process_cleanpath( trace!("Establishing TLS connection with server"); - // == Establish TLS connection with server == - - let server_stream = crate::tls::dangerous_connect(selected_target.host().to_owned(), server_stream) + let tls_stream = crate::tls::dangerous_connect(selected_target.host().to_owned(), server_stream) .await .map_err(|source| CleanPathError::TlsHandshake { source, target_server: selected_target.to_owned(), })?; - Ok(CleanPathResult { - destination: selected_target.to_owned(), - claims, + Ok(ConnectedRdpServer { + tls_stream, server_addr, - server_stream, + selected_target: selected_target.to_owned(), x224_rsp, }) } @@ -287,6 +397,7 @@ async fn handle_with_credential_injection( active_recordings: &ActiveRecordings, cleanpath_pdu: RDCleanPathPdu, credential_entry: Arc, + agent_tunnel_handle: Option>, ) -> anyhow::Result<()> { let tls_conf = conf.credssp_tls.get().context("CredSSP TLS configuration")?; @@ -318,16 +429,9 @@ async fn handle_with_credential_injection( ) }; - // Run normal RDCleanPath flow (this will handle server-side TLS and get certs). - let CleanPathResult { - claims, - destination, - server_addr, - server_stream, - x224_rsp, - .. - } = process_cleanpath( - cleanpath_pdu, + // Authorize and connect to the RDP server. + let CleanPathAuth { claims } = authorize_cleanpath( + &cleanpath_pdu, client_addr, &conf, token_cache, @@ -336,7 +440,16 @@ async fn handle_with_credential_injection( &sessions, ) .await - .context("RDCleanPath processing failed")?; + .context("RDCleanPath authorization failed")?; + + let ConnectedRdpServer { + tls_stream: server_stream, + server_addr, + selected_target: destination, + x224_rsp, + } = connect_rdp_server(&claims, cleanpath_pdu, agent_tunnel_handle.as_ref()) + .await + .context("RDCleanPath connection failed")?; // Retrieve the Gateway TLS public key that must be used for client-proxy CredSSP later on. let gateway_cert_chain_handle = tokio::spawn(crate::tls::get_cert_chain_for_acceptor_cached( @@ -532,6 +645,7 @@ pub async fn handle( subscriber_tx: SubscriberSender, active_recordings: &ActiveRecordings, credential_store: &CredentialStoreHandle, + agent_tunnel_handle: Option>, ) -> anyhow::Result<()> { // Special handshake of our RDP extension @@ -569,27 +683,29 @@ pub async fn handle( active_recordings, cleanpath_pdu, entry, + agent_tunnel_handle.clone(), ) .await; } trace!("Processing RDCleanPath"); - let CleanPathResult { - claims, - destination, - server_addr, - server_stream, - x224_rsp, - } = match process_cleanpath( - cleanpath_pdu, - client_addr, - &conf, - token_cache, - jrl, - active_recordings, - &sessions, - ) + let (auth, connected) = match async { + let auth = authorize_cleanpath( + &cleanpath_pdu, + client_addr, + &conf, + token_cache, + jrl, + active_recordings, + &sessions, + ) + .await?; + + let connected = connect_rdp_server(&auth.claims, cleanpath_pdu, agent_tunnel_handle.as_ref()).await?; + + Ok::<_, CleanPathError>((auth, connected)) + } .await { Ok(result) => result, @@ -602,6 +718,13 @@ pub async fn handle( } }; + let ConnectedRdpServer { + tls_stream: server_stream, + server_addr, + selected_target: destination, + x224_rsp, + } = connected; + // == Send success RDCleanPathPdu response == let x509_chain = server_stream @@ -622,13 +745,13 @@ pub async fn handle( // == Start actual RDP session == let info = SessionInfo::builder() - .id(claims.jet_aid) - .application_protocol(claims.jet_ap) + .id(auth.claims.jet_aid) + .application_protocol(auth.claims.jet_ap) .details(ConnectionModeDetails::Fwd { destination_host: destination.clone(), }) - .time_to_live(claims.jet_ttl) - .recording_policy(claims.jet_rec) + .time_to_live(auth.claims.jet_ttl) + .recording_policy(auth.claims.jet_rec) .build(); info!("RDP-TLS forwarding (RDCleanPath)"); @@ -642,7 +765,7 @@ pub async fn handle( .transport_b(server_stream) .sessions(sessions) .subscriber_tx(subscriber_tx) - .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) + .disconnect_interest(DisconnectInterest::from_reconnection_policy(auth.claims.jet_reuse)) .build() .select_dissector_and_forward() .await diff --git a/devolutions-gateway/src/rdp_proxy.rs b/devolutions-gateway/src/rdp_proxy.rs index b3dc466a7..b7fddfdd5 100644 --- a/devolutions-gateway/src/rdp_proxy.rs +++ b/devolutions-gateway/src/rdp_proxy.rs @@ -637,7 +637,7 @@ where async fn send_network_request(request: &NetworkRequest) -> anyhow::Result> { let target_addr = TargetAddr::parse(request.url.as_str(), Some(88))?; - send_krb_message(&target_addr, &request.data) + send_krb_message(&target_addr, &request.data, None) .await .map_err(|err| anyhow::Error::msg("failed to send KDC message").context(err)) }