diff --git a/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java b/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java index d59c53db1d1..0df8dc452c1 100644 --- a/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java +++ b/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java @@ -23,6 +23,9 @@ import com.google.common.base.Objects; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import javax.annotation.Nullable; /** @@ -33,6 +36,8 @@ public final class HttpConnectProxiedSocketAddress extends ProxiedSocketAddress private final SocketAddress proxyAddress; private final InetSocketAddress targetAddress; + @SuppressWarnings("serial") + private final Map headers; @Nullable private final String username; @Nullable @@ -41,6 +46,7 @@ public final class HttpConnectProxiedSocketAddress extends ProxiedSocketAddress private HttpConnectProxiedSocketAddress( SocketAddress proxyAddress, InetSocketAddress targetAddress, + Map headers, @Nullable String username, @Nullable String password) { checkNotNull(proxyAddress, "proxyAddress"); @@ -53,6 +59,7 @@ private HttpConnectProxiedSocketAddress( } this.proxyAddress = proxyAddress; this.targetAddress = targetAddress; + this.headers = headers; this.username = username; this.password = password; } @@ -87,6 +94,14 @@ public InetSocketAddress getTargetAddress() { return targetAddress; } + /** + * Returns the custom HTTP headers to be sent during the HTTP CONNECT handshake. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12479") + public Map getHeaders() { + return headers; + } + @Override public boolean equals(Object o) { if (!(o instanceof HttpConnectProxiedSocketAddress)) { @@ -95,13 +110,14 @@ public boolean equals(Object o) { HttpConnectProxiedSocketAddress that = (HttpConnectProxiedSocketAddress) o; return Objects.equal(proxyAddress, that.proxyAddress) && Objects.equal(targetAddress, that.targetAddress) + && Objects.equal(headers, that.headers) && Objects.equal(username, that.username) && Objects.equal(password, that.password); } @Override public int hashCode() { - return Objects.hashCode(proxyAddress, targetAddress, username, password); + return Objects.hashCode(proxyAddress, targetAddress, username, password, headers); } @Override @@ -109,6 +125,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("proxyAddr", proxyAddress) .add("targetAddr", targetAddress) + .add("headers", headers) .add("username", username) // Intentionally mask out password .add("hasPassword", password != null) @@ -129,6 +146,7 @@ public static final class Builder { private SocketAddress proxyAddress; private InetSocketAddress targetAddress; + private Map headers = Collections.emptyMap(); @Nullable private String username; @Nullable @@ -153,6 +171,18 @@ public Builder setTargetAddress(InetSocketAddress targetAddress) { return this; } + /** + * Sets custom HTTP headers to be sent during the HTTP CONNECT handshake. This is an optional + * field. The headers will be sent in addition to any authentication headers (if username and + * password are set). + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12479") + public Builder setHeaders(Map headers) { + this.headers = Collections.unmodifiableMap( + new HashMap<>(checkNotNull(headers, "headers"))); + return this; + } + /** * Sets the username used to connect to the proxy. This is an optional field and can be {@code * null}. @@ -175,7 +205,8 @@ public Builder setPassword(@Nullable String password) { * Creates an {@code HttpConnectProxiedSocketAddress}. */ public HttpConnectProxiedSocketAddress build() { - return new HttpConnectProxiedSocketAddress(proxyAddress, targetAddress, username, password); + return new HttpConnectProxiedSocketAddress( + proxyAddress, targetAddress, headers, username, password); } } } diff --git a/api/src/test/java/io/grpc/HttpConnectProxiedSocketAddressTest.java b/api/src/test/java/io/grpc/HttpConnectProxiedSocketAddressTest.java new file mode 100644 index 00000000000..6620a7d413a --- /dev/null +++ b/api/src/test/java/io/grpc/HttpConnectProxiedSocketAddressTest.java @@ -0,0 +1,248 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; + +import com.google.common.testing.EqualsTester; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HttpConnectProxiedSocketAddressTest { + + private final InetSocketAddress proxyAddress = + new InetSocketAddress(InetAddress.getLoopbackAddress(), 8080); + private final InetSocketAddress targetAddress = + InetSocketAddress.createUnresolved("example.com", 443); + + @Test + public void buildWithAllFields() { + Map headers = new HashMap<>(); + headers.put("X-Custom-Header", "custom-value"); + headers.put("Proxy-Authorization", "Bearer token"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .setUsername("user") + .setPassword("pass") + .build(); + + assertThat(address.getProxyAddress()).isEqualTo(proxyAddress); + assertThat(address.getTargetAddress()).isEqualTo(targetAddress); + assertThat(address.getHeaders()).hasSize(2); + assertThat(address.getHeaders()).containsEntry("X-Custom-Header", "custom-value"); + assertThat(address.getHeaders()).containsEntry("Proxy-Authorization", "Bearer token"); + assertThat(address.getUsername()).isEqualTo("user"); + assertThat(address.getPassword()).isEqualTo("pass"); + } + + @Test + public void buildWithoutOptionalFields() { + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .build(); + + assertThat(address.getProxyAddress()).isEqualTo(proxyAddress); + assertThat(address.getTargetAddress()).isEqualTo(targetAddress); + assertThat(address.getHeaders()).isEmpty(); + assertThat(address.getUsername()).isNull(); + assertThat(address.getPassword()).isNull(); + } + + @Test + public void buildWithEmptyHeaders() { + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(Collections.emptyMap()) + .build(); + + assertThat(address.getHeaders()).isEmpty(); + } + + @Test + public void headersAreImmutable() { + Map headers = new HashMap<>(); + headers.put("key1", "value1"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .build(); + + headers.put("key2", "value2"); + + assertThat(address.getHeaders()).hasSize(1); + assertThat(address.getHeaders()).containsEntry("key1", "value1"); + assertThat(address.getHeaders()).doesNotContainKey("key2"); + } + + @Test + public void returnedHeadersAreUnmodifiable() { + Map headers = new HashMap<>(); + headers.put("key", "value"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .build(); + + assertThrows(UnsupportedOperationException.class, + () -> address.getHeaders().put("newKey", "newValue")); + } + + @Test + public void nullHeadersThrowsException() { + assertThrows(NullPointerException.class, + () -> HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(null) + .build()); + } + + @Test + public void equalsAndHashCode() { + Map headers1 = new HashMap<>(); + headers1.put("header", "value"); + + Map headers2 = new HashMap<>(); + headers2.put("header", "value"); + + Map differentHeaders = new HashMap<>(); + differentHeaders.put("different", "header"); + + new EqualsTester() + .addEqualityGroup( + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers1) + .setUsername("user") + .setPassword("pass") + .build(), + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers2) + .setUsername("user") + .setPassword("pass") + .build()) + .addEqualityGroup( + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(differentHeaders) + .setUsername("user") + .setPassword("pass") + .build()) + .addEqualityGroup( + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .build()) + .testEquals(); + } + + @Test + public void toStringContainsHeaders() { + Map headers = new HashMap<>(); + headers.put("X-Test", "test-value"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .setUsername("user") + .setPassword("secret") + .build(); + + String toString = address.toString(); + assertThat(toString).contains("headers"); + assertThat(toString).contains("X-Test"); + assertThat(toString).contains("hasPassword=true"); + assertThat(toString).doesNotContain("secret"); + } + + @Test + public void toStringWithoutPassword() { + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .build(); + + String toString = address.toString(); + assertThat(toString).contains("hasPassword=false"); + } + + @Test + public void hashCodeDependsOnHeaders() { + Map headers1 = new HashMap<>(); + headers1.put("header", "value1"); + + Map headers2 = new HashMap<>(); + headers2.put("header", "value2"); + + HttpConnectProxiedSocketAddress address1 = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers1) + .build(); + + HttpConnectProxiedSocketAddress address2 = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers2) + .build(); + + assertNotEquals(address1.hashCode(), address2.hashCode()); + } + + @Test + public void multipleHeadersSupported() { + Map headers = new HashMap<>(); + headers.put("X-Header-1", "value1"); + headers.put("X-Header-2", "value2"); + headers.put("X-Header-3", "value3"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .build(); + + assertThat(address.getHeaders()).hasSize(3); + assertThat(address.getHeaders()).containsEntry("X-Header-1", "value1"); + assertThat(address.getHeaders()).containsEntry("X-Header-2", "value2"); + assertThat(address.getHeaders()).containsEntry("X-Header-3", "value3"); + } +} + diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 2db5ab20a91..258aa15b005 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -818,6 +818,7 @@ public ConnectionClientTransport newClientTransport( serverAddress = proxiedAddr.getTargetAddress(); localNegotiator = ProtocolNegotiators.httpProxy( proxiedAddr.getProxyAddress(), + proxiedAddr.getHeaders(), proxiedAddr.getUsername(), proxiedAddr.getPassword(), protocolNegotiator); diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index d30a6292d38..9323c58aae1 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -51,10 +51,12 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientUpgradeHandler; import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http2.Http2ClientUpgradeCodec; @@ -77,6 +79,7 @@ import java.util.Arrays; import java.util.EnumSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; import java.util.logging.Level; @@ -484,7 +487,8 @@ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession * Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation. */ public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, - final @Nullable String proxyUsername, final @Nullable String proxyPassword, + final @Nullable Map headers, final @Nullable String proxyUsername, + final @Nullable String proxyPassword, final ProtocolNegotiator negotiator) { Preconditions.checkNotNull(negotiator, "negotiator"); Preconditions.checkNotNull(proxyAddress, "proxyAddress"); @@ -494,8 +498,9 @@ class ProxyNegotiator implements ProtocolNegotiator { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler http2Handler) { ChannelHandler protocolNegotiationHandler = negotiator.newHandler(http2Handler); ChannelLogger negotiationLogger = http2Handler.getNegotiationLogger(); + HttpHeaders httpHeaders = toHttpHeaders(headers); return new ProxyProtocolNegotiationHandler( - proxyAddress, proxyUsername, proxyPassword, protocolNegotiationHandler, + proxyAddress, httpHeaders, proxyUsername, proxyPassword, protocolNegotiationHandler, negotiationLogger); } @@ -515,6 +520,22 @@ public void close() { return new ProxyNegotiator(); } + /** + * Converts generic Map of headers to Netty's HttpHeaders. + * Returns null if the map is null or empty. + */ + @Nullable + private static HttpHeaders toHttpHeaders(@Nullable Map headers) { + if (headers == null || headers.isEmpty()) { + return null; + } + HttpHeaders httpHeaders = new DefaultHttpHeaders(); + for (Map.Entry entry : headers.entrySet()) { + httpHeaders.add(entry.getKey(), entry.getValue()); + } + return httpHeaders; + } + /** * A Proxy handler follows {@link ProtocolNegotiationHandler} pattern. Upon successful proxy * connection, this handler will install {@code next} handler which should be a handler from @@ -523,17 +544,20 @@ public void close() { static final class ProxyProtocolNegotiationHandler extends ProtocolNegotiationHandler { private final SocketAddress address; + @Nullable private final HttpHeaders httpHeaders; @Nullable private final String userName; @Nullable private final String password; public ProxyProtocolNegotiationHandler( SocketAddress address, + @Nullable HttpHeaders httpHeaders, @Nullable String userName, @Nullable String password, ChannelHandler next, ChannelLogger negotiationLogger) { super(next, negotiationLogger); this.address = Preconditions.checkNotNull(address, "address"); + this.httpHeaders = httpHeaders; this.userName = userName; this.password = password; } @@ -542,9 +566,9 @@ public ProxyProtocolNegotiationHandler( protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { HttpProxyHandler nettyProxyHandler; if (userName == null || password == null) { - nettyProxyHandler = new HttpProxyHandler(address); + nettyProxyHandler = new HttpProxyHandler(address, httpHeaders); } else { - nettyProxyHandler = new HttpProxyHandler(address, userName, password); + nettyProxyHandler = new HttpProxyHandler(address, userName, password, httpHeaders); } ctx.pipeline().addBefore(ctx.name(), /* name= */ null, nettyProxyHandler); } diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 9bb5a43d792..cde33139965 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -126,6 +126,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -1088,20 +1089,21 @@ public void tls_invalidHost() throws SSLException { @Test public void httpProxy_nullAddressNpe() { assertThrows(NullPointerException.class, - () -> ProtocolNegotiators.httpProxy(null, "user", "pass", ProtocolNegotiators.plaintext())); + () -> ProtocolNegotiators.httpProxy(null, null, "user", "pass", + ProtocolNegotiators.plaintext())); } @Test public void httpProxy_nullNegotiatorNpe() { assertThrows(NullPointerException.class, () -> ProtocolNegotiators.httpProxy( - InetSocketAddress.createUnresolved("localhost", 80), "user", "pass", null)); + InetSocketAddress.createUnresolved("localhost", 80), null, "user", "pass", null)); } @Test public void httpProxy_nullUserPassNoException() throws Exception { assertNotNull(ProtocolNegotiators.httpProxy( - InetSocketAddress.createUnresolved("localhost", 80), null, null, + InetSocketAddress.createUnresolved("localhost", 80), null, null, null, ProtocolNegotiators.plaintext())); } @@ -1119,7 +1121,7 @@ public void httpProxy_completes() throws Exception { .bind(proxy).sync().channel(); ProtocolNegotiator nego = - ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); + ProtocolNegotiators.httpProxy(proxy, null, null, null, ProtocolNegotiators.plaintext()); // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, // mocking the behavior using KickStartHandler. ChannelHandler handler = @@ -1182,7 +1184,7 @@ public void httpProxy_500() throws Exception { .bind(proxy).sync().channel(); ProtocolNegotiator nego = - ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); + ProtocolNegotiators.httpProxy(proxy, null, null, null, ProtocolNegotiators.plaintext()); // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, // mocking the behavior using KickStartHandler. ChannelHandler handler = @@ -1220,6 +1222,77 @@ public void httpProxy_500() throws Exception { } } + @Test + public void httpProxy_customHeaders() throws Exception { + DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); + // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called + // the channel is already active. + LocalAddress proxy = new LocalAddress("httpProxy_customHeaders"); + SocketAddress host = InetSocketAddress.createUnresolved("example.com", 443); + + ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class); + Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class) + .childHandler(mockHandler) + .bind(proxy).sync().channel(); + + Map headers = new java.util.HashMap<>(); + headers.put("X-Custom-Header", "custom-value"); + headers.put("Proxy-Authorization", "Bearer token123"); + + ProtocolNegotiator nego = ProtocolNegotiators.httpProxy( + proxy, headers, null, null, ProtocolNegotiators.plaintext()); + // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, + // mocking the behavior using KickStartHandler. + ChannelHandler handler = + new KickStartHandler(nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler())); + Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler) + .register().sync().channel(); + pipeline = channel.pipeline(); + // Wait for initialization to complete + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + channel.connect(host).sync(); + serverChannel.close(); + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(ChannelHandlerContext.class); + Mockito.verify(mockHandler).channelActive(contextCaptor.capture()); + ChannelHandlerContext serverContext = contextCaptor.getValue(); + + final String golden = "testData"; + ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel)); + + // Wait for sending initial request to complete + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + ArgumentCaptor objectCaptor = ArgumentCaptor.forClass(Object.class); + Mockito.verify(mockHandler) + .channelRead(ArgumentMatchers.any(), objectCaptor.capture()); + ByteBuf b = (ByteBuf) objectCaptor.getValue(); + String request = b.toString(UTF_8); + b.release(); + + // Verify custom headers are present in the CONNECT request + assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n")); + assertTrue("No CONNECT: " + request, request.startsWith("CONNECT example.com:443 ")); + assertTrue("No custom header: " + request, + request.contains("X-Custom-Header: custom-value")); + assertTrue("No proxy authorization: " + request, + request.contains("Proxy-Authorization: Bearer token123")); + + assertFalse(negotiationFuture.isDone()); + serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync(); + negotiationFuture.sync(); + + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + objectCaptor = ArgumentCaptor.forClass(Object.class); + Mockito.verify(mockHandler, times(2)) + .channelRead(ArgumentMatchers.any(), objectCaptor.capture()); + b = (ByteBuf) objectCaptor.getAllValues().get(1); + String preface = b.toString(UTF_8); + b.release(); + assertEquals(golden, preface); + + channel.close(); + } + @Test public void waitUntilActiveHandler_firesNegotiation() throws Exception { EventLoopGroup elg = new DefaultEventLoopGroup(1);