diff --git a/src/main.rs b/src/main.rs index d9f760f..215ed89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use base64::Engine; use clap::{Parser, Subcommand}; +use futures_util::StreamExt; use http_body_util::BodyExt; use hudsucker::{ certificate_authority::RcgenAuthority, @@ -11,6 +12,7 @@ use hudsucker::{ tokio_tungstenite::tungstenite::http::uri::Scheme, Body, HttpContext, HttpHandler, Proxy, RequestOrResponse, }; +use hyper::Method; use hyper::{StatusCode, Uri}; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use hyper_util::{ @@ -137,11 +139,15 @@ impl HttpHandler for Handler { "GET" | "POST" | "HEAD" => { let original_url = req.uri().clone(); println!("{req:?}"); - let Ok(req) = decode_request(req) else { + let Ok(mut req) = decode_request(req) else { let mut res = Response::new("not found".into()); *res.status_mut() = StatusCode::NOT_FOUND; return res.into(); }; + let is_head_request = req.method() == "HEAD"; + if is_head_request { + *req.method_mut() = Method::GET; + } let (info, body) = req.into_parts(); let Ok(req_body) = body.collect().await.map(|x| x.to_bytes()) else { @@ -165,7 +171,7 @@ impl HttpHandler for Handler { req_body )])), ); - let store_body_info = req.method() != "HEAD"; + let store_body_info = true; let url = process_uri(original_url); if matches!(forget_regex, Some(x) if x.is_match(&url.to_string())) { forget = true; @@ -307,7 +313,20 @@ impl HttpHandler for Handler { } } }); - Response::from_parts(info, ret_body) + if is_head_request { + // For HEAD requests, drain the channel + // output; otherwise rx will be dropped, + // which closes the channel and causes + // the tx.poll_ready call above to + // return an error + let mut stream = ret_body.into_data_stream(); + while let Some(_) = stream.next().await { + // do nothing + } + Response::from_parts(info, Body::empty()) + } else { + Response::from_parts(info, ret_body) + } } else { // remove hash headers to force the software to download this // so we get sha256