diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index b92be199e..df1b192b4 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -688,24 +688,10 @@ fn is_loopback_gateway_endpoint(endpoint: &str) -> bool { } } -/// Check whether mTLS client certs exist on disk for the gateway that -/// would serve this endpoint. -/// -/// Loopback endpoints (`localhost`, `127.0.0.1`, `::1`) resolve to the -/// `"openshell"` gateway name, matching the convention used by local -/// `openshell-gateway generate-certs` and the TLS cert resolver in `tls.rs`. -fn mtls_certs_exist_for_endpoint(name: &str, endpoint: &str) -> bool { - let cert_name = if is_loopback_gateway_endpoint(endpoint) { - "openshell" - } else { - name - }; +/// Check whether mTLS client certs exist on disk for a gateway name. +fn mtls_certs_exist_for_gateway(name: &str) -> bool { openshell_core::paths::xdg_config_dir().is_ok_and(|d| { - let mtls = d - .join("openshell") - .join("gateways") - .join(cert_name) - .join("mtls"); + let mtls = d.join("openshell").join("gateways").join(name).join("mtls"); mtls.join("ca.crt").is_file() && mtls.join("tls.crt").is_file() && mtls.join("tls.key").is_file() @@ -1030,7 +1016,7 @@ pub async fn gateway_add( if endpoint.starts_with("http://") { // Warn if mTLS certs exist for this gateway — the user likely // meant to use https:// instead of http://. - let has_mtls_certs = mtls_certs_exist_for_endpoint(name, &endpoint); + let has_mtls_certs = mtls_certs_exist_for_gateway(name); if has_mtls_certs { let https_endpoint = endpoint.replacen("http://", "https://", 1); @@ -1084,8 +1070,7 @@ pub async fn gateway_add( } else { None }; - let certs_on_disk = - imported_mtls_dir.is_some() || mtls_certs_exist_for_endpoint(name, &endpoint); + let certs_on_disk = imported_mtls_dir.is_some() || mtls_certs_exist_for_gateway(name); if !certs_on_disk { return Err(miette::miette!( "mTLS certificates for gateway '{name}' were not found.\n\ @@ -6986,12 +6971,12 @@ mod tests { gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, inferred_provider_type, local_upload_path_exists, local_upload_path_is_symlink, - package_managed_tls_dirs, parse_cli_setting_value, parse_credential_expiry_cli_value, - parse_credential_expiry_pairs, parse_credential_pairs, plaintext_gateway_is_remote, - progress_step_from_metadata, provider_profile_allows_refresh_bootstrap, - provisioning_timeout_message, ready_false_condition_message, refresh_status_header, - refresh_status_row, resolve_from, sandbox_should_persist, service_expose_status_error, - service_url_for_gateway, + mtls_certs_exist_for_gateway, package_managed_tls_dirs, parse_cli_setting_value, + parse_credential_expiry_cli_value, parse_credential_expiry_pairs, parse_credential_pairs, + plaintext_gateway_is_remote, progress_step_from_metadata, + provider_profile_allows_refresh_bootstrap, provisioning_timeout_message, + ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from, + sandbox_should_persist, service_expose_status_error, service_url_for_gateway, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -7993,6 +7978,21 @@ mod tests { }); } + #[test] + fn mtls_certs_exist_for_gateway_uses_explicit_name_for_loopback_endpoint() { + let tmpdir = tempfile::tempdir().expect("create tmpdir"); + let mtls = tmpdir.path().join("openshell/gateways/k8s/mtls"); + fs::create_dir_all(&mtls).expect("create mtls dir"); + fs::write(mtls.join("ca.crt"), "ca").expect("write ca"); + fs::write(mtls.join("tls.crt"), "client cert").expect("write cert"); + fs::write(mtls.join("tls.key"), "client key").expect("write key"); + + with_tmp_xdg(tmpdir.path(), || { + assert!(mtls_certs_exist_for_gateway("k8s")); + assert!(!mtls_certs_exist_for_gateway("openshell")); + }); + } + #[test] fn plaintext_gateway_locality_infers_loopback_endpoints_as_local() { assert!(!plaintext_gateway_is_remote( diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index 8f83599b1..7cb9e1e76 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -6,7 +6,11 @@ mod helpers; use helpers::{ EnvVarGuard, build_ca, build_client_cert, build_server_cert, install_rustls_provider, }; -use openshell_cli::tls::{TlsOptions, grpc_client}; +use openshell_bootstrap::{get_gateway_metadata, load_active_gateway}; +use openshell_cli::{ + run, + tls::{TlsOptions, grpc_client}, +}; use openshell_core::proto::{ CreateProviderRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, ExecSandboxEvent, ExecSandboxInput, @@ -494,6 +498,167 @@ async fn run_server( addr } +fn write_gateway_mtls_bundle( + config_dir: &std::path::Path, + gateway_name: &str, + ca_cert: &str, + client_cert: &str, + client_key: &str, +) { + let mtls = config_dir + .join("openshell") + .join("gateways") + .join(gateway_name) + .join("mtls"); + std::fs::create_dir_all(&mtls).unwrap(); + std::fs::write(mtls.join("ca.crt"), ca_cert).unwrap(); + std::fs::write(mtls.join("tls.crt"), client_cert).unwrap(); + std::fs::write(mtls.join("tls.key"), client_key).unwrap(); +} + +fn isolated_gateway_add_env( + config_dir: &std::path::Path, + state_dir: &std::path::Path, +) -> EnvVarGuard { + let xdg_config = config_dir.to_string_lossy().into_owned(); + let xdg_state = state_dir.to_string_lossy().into_owned(); + let local_tls_dir = state_dir.join("no-package-managed-tls"); + let local_tls = local_tls_dir.to_string_lossy().into_owned(); + + EnvVarGuard::set(&[ + ("XDG_CONFIG_HOME", xdg_config.as_str()), + ("XDG_STATE_HOME", xdg_state.as_str()), + ("HOME", xdg_state.as_str()), + ("OPENSHELL_LOCAL_TLS_DIR", local_tls.as_str()), + ("OPENSHELL_GATEWAY", "unused-by-named-gateway-add"), + ]) +} + +#[tokio::test] +async fn gateway_add_mtls_loopback_uses_explicit_gateway_name() { + install_rustls_provider(); + + let (ca, ca_key) = build_ca(); + let (server_cert, server_key) = build_server_cert(&ca, &ca_key); + let (client_cert, client_key) = build_client_cert(&ca, &ca_key); + let ca_cert = ca.pem(); + let addr = run_server(server_cert, server_key, ca_cert.clone()).await; + + let config_dir = tempdir().unwrap(); + let state_dir = tempdir().unwrap(); + write_gateway_mtls_bundle( + config_dir.path(), + "k8s", + &ca_cert, + &client_cert, + &client_key, + ); + let _env = isolated_gateway_add_env(config_dir.path(), state_dir.path()); + + let endpoint = format!("https://localhost:{}", addr.port()); + run::gateway_add( + &endpoint, + Some("k8s"), + None, + true, + None, + "openshell-cli", + None, + None, + false, + ) + .await + .unwrap(); + + let metadata = get_gateway_metadata("k8s").unwrap(); + assert_eq!(metadata.name, "k8s"); + assert_eq!(metadata.gateway_endpoint, endpoint); + assert_eq!(metadata.auth_mode.as_deref(), Some("mtls")); + assert_eq!(load_active_gateway().as_deref(), Some("k8s")); + assert!(get_gateway_metadata("openshell").is_none()); +} + +#[tokio::test] +async fn gateway_add_mtls_loopback_without_name_uses_openshell_default() { + install_rustls_provider(); + + let (ca, ca_key) = build_ca(); + let (server_cert, server_key) = build_server_cert(&ca, &ca_key); + let (client_cert, client_key) = build_client_cert(&ca, &ca_key); + let ca_cert = ca.pem(); + let addr = run_server(server_cert, server_key, ca_cert.clone()).await; + + let config_dir = tempdir().unwrap(); + let state_dir = tempdir().unwrap(); + write_gateway_mtls_bundle( + config_dir.path(), + "openshell", + &ca_cert, + &client_cert, + &client_key, + ); + let _env = isolated_gateway_add_env(config_dir.path(), state_dir.path()); + + let endpoint = format!("https://localhost:{}", addr.port()); + run::gateway_add( + &endpoint, + None, + None, + true, + None, + "openshell-cli", + None, + None, + false, + ) + .await + .unwrap(); + + let metadata = get_gateway_metadata("openshell").unwrap(); + assert_eq!(metadata.name, "openshell"); + assert_eq!(metadata.gateway_endpoint, endpoint); + assert_eq!(metadata.auth_mode.as_deref(), Some("mtls")); + assert_eq!(load_active_gateway().as_deref(), Some("openshell")); +} + +#[tokio::test] +async fn gateway_add_mtls_loopback_explicit_name_does_not_fallback_to_openshell_certs() { + install_rustls_provider(); + + let (ca, ca_key) = build_ca(); + let (client_cert, client_key) = build_client_cert(&ca, &ca_key); + let ca_cert = ca.pem(); + + let config_dir = tempdir().unwrap(); + let state_dir = tempdir().unwrap(); + write_gateway_mtls_bundle( + config_dir.path(), + "openshell", + &ca_cert, + &client_cert, + &client_key, + ); + let _env = isolated_gateway_add_env(config_dir.path(), state_dir.path()); + + let err = run::gateway_add( + "https://localhost:1", + Some("k8s"), + None, + true, + None, + "openshell-cli", + None, + None, + false, + ) + .await + .expect_err("explicit name should require matching named mTLS material"); + + assert!(err.to_string().contains("gateway 'k8s'")); + assert!(get_gateway_metadata("k8s").is_none()); + assert!(load_active_gateway().is_none()); +} + #[tokio::test] async fn cli_connects_with_client_cert() { install_rustls_provider();