diff --git a/Cargo.lock b/Cargo.lock index d859fea..6bdc7be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,7 +92,8 @@ dependencies = [ "serde_json", "thiserror", "tokio", - "tower-http", + "tower", + "tower-http 0.7.0", "tracing", "tracing-subscriber", "url", @@ -1942,7 +1943,7 @@ dependencies = [ "tokio-rustls", "tokio-util", "tower", - "tower-http", + "tower-http 0.6.11", "tower-service", "url", "wasm-bindgen", @@ -2679,6 +2680,21 @@ dependencies = [ "url", ] +[[package]] +name = "tower-http" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b11f75e912b0c2be01b63d8cf8057b8c3f97cf34abb3d431a3a4c8675498e233" +dependencies = [ + "bitflags", + "bytes", + "http", + "percent-encoding", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" diff --git a/Cargo.toml b/Cargo.toml index 7270daa..31e6425 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,7 +79,8 @@ open = "5" rustc-hash = "2.1.1" shell-words = "1.1" strip-ansi-escapes = "0.2" -tower-http = "0.6" +tower = { version = "0.5", features = ["util"] } +tower-http = "0.7" windows-sys = "0.61" diff --git a/src/agent-client-protocol-http/Cargo.toml b/src/agent-client-protocol-http/Cargo.toml index 9c29358..9a71c55 100644 --- a/src/agent-client-protocol-http/Cargo.toml +++ b/src/agent-client-protocol-http/Cargo.toml @@ -73,6 +73,7 @@ axum = { workspace = true, features = ["ws", "macros"] } tokio = { workspace = true, features = ["macros", "net", "rt", "sync", "time"] } async-tungstenite.workspace = true tracing-subscriber.workspace = true +tower.workspace = true [lints] workspace = true diff --git a/src/agent-client-protocol-http/src/server.rs b/src/agent-client-protocol-http/src/server.rs index a310997..695b169 100644 --- a/src/agent-client-protocol-http/src/server.rs +++ b/src/agent-client-protocol-http/src/server.rs @@ -65,7 +65,7 @@ impl CorsOptions { match self { Self::Disabled => None, Self::AllowOrigins(origins) => Some(AllowOrigin::list(origins.clone())), - Self::AllowAnyOrigin => Some(AllowOrigin::mirror_request()), + Self::AllowAnyOrigin => Some(AllowOrigin::any()), } } @@ -196,6 +196,8 @@ async fn handle_get( #[cfg(test)] mod tests { use super::*; + use axum::body::Body; + use tower::{Layer as _, ServiceExt as _, service_fn}; #[test] fn cors_is_disabled_by_default() { @@ -227,4 +229,60 @@ mod tests { assert!(CorsOptions::allow_any_origin().allows_origin(Some(&origin))); } + + #[tokio::test] + async fn allow_any_origin_uses_wildcard_cors_header() { + let response = default_cors( + CorsOptions::allow_any_origin() + .allow_origin_layer() + .expect("CORS layer"), + ) + .layer(service_fn(|_: axum::http::Request| async { + Ok::<_, std::convert::Infallible>(Response::new(Body::empty())) + })) + .oneshot( + axum::http::Request::builder() + .header(header::ORIGIN, "https://example.com") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), + Some(&HeaderValue::from_static("*")) + ); + assert!(response.headers().get(header::VARY).is_none()); + } + + #[tokio::test] + async fn allowlisted_origins_vary_by_origin() { + let response = default_cors( + CorsOptions::allow_origins(["https://example.com"]) + .unwrap() + .allow_origin_layer() + .expect("CORS layer"), + ) + .layer(service_fn(|_: axum::http::Request| async { + Ok::<_, std::convert::Infallible>(Response::new(Body::empty())) + })) + .oneshot( + axum::http::Request::builder() + .header(header::ORIGIN, "https://example.com") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN), + Some(&HeaderValue::from_static("https://example.com")) + ); + assert_eq!( + response.headers().get(header::VARY), + Some(&HeaderValue::from_static("origin")) + ); + } }