From 7595d74c41387327fe9b7e8a1f56f4c53cd44145 Mon Sep 17 00:00:00 2001 From: Israel Blancas Date: Thu, 27 Nov 2025 15:48:53 +0100 Subject: [PATCH 1/5] client: allow overriding grpc-accept-encoding header Signed-off-by: Israel Blancas --- experimental/experimental.go | 8 +++ internal/experimental.go | 4 ++ internal/transport/http2_client.go | 3 + internal/transport/transport.go | 6 ++ rpc_util.go | 105 ++++++++++++++++++++++++----- rpc_util_test.go | 79 ++++++++++++++++++++++ server.go | 2 +- stream.go | 15 +++-- test/compressor_test.go | 25 +++++++ 9 files changed, 225 insertions(+), 22 deletions(-) diff --git a/experimental/experimental.go b/experimental/experimental.go index 719692636505..c8620cffe415 100644 --- a/experimental/experimental.go +++ b/experimental/experimental.go @@ -62,3 +62,11 @@ func WithBufferPool(bufferPool mem.BufferPool) grpc.DialOption { func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption { return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool) } + +// AcceptedCompressionNames returns a CallOption that limits the values +// advertised in the grpc-accept-encoding header for the provided RPC. The +// supplied names must correspond to compressors registered via +// encoding.RegisterCompressor. Passing no names advertises identity only. +func AcceptedCompressionNames(names ...string) grpc.CallOption { + return internal.AcceptedCompressionNames.(func(...string) grpc.CallOption)(names...) +} diff --git a/internal/experimental.go b/internal/experimental.go index 7617be215895..3482abacdc5e 100644 --- a/internal/experimental.go +++ b/internal/experimental.go @@ -25,4 +25,8 @@ var ( // BufferPool is implemented by the grpc package and returns a server // option to configure a shared buffer pool for a grpc.Server. BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption + + // AcceptedCompressionNames is implemented by the grpc package and returns + // a call option that restricts the grpc-accept-encoding header for a call. + AcceptedCompressionNames any // func(...string) grpc.CallOption ) diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 65b4ab2439e2..19c9b1ebad0b 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -551,6 +551,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te hfLen += len(authData) + len(callAuthData) registeredCompressors := t.registeredCompressors + if callHdr.AcceptedCompressors != nil { + registeredCompressors = *callHdr.AcceptedCompressors + } if callHdr.PreviousAttempts > 0 { hfLen++ } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 5ff83a7d7d74..e1e466698e34 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -553,6 +553,12 @@ type CallHdr struct { // outbound message. SendCompress string + // AcceptedCompressors overrides the grpc-accept-encoding header for this + // call. When nil, the transport advertises the default set of registered + // compressors. A non-nil pointer overrides that value (including the empty + // string to advertise none). + AcceptedCompressors *string + // Creds specifies credentials.PerRPCCredentials for a call. Creds credentials.PerRPCCredentials diff --git a/rpc_util.go b/rpc_util.go index 6b04c9e87357..32a5c0f55bda 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -33,6 +33,8 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding/proto" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" @@ -41,6 +43,10 @@ import ( "google.golang.org/grpc/status" ) +func init() { + internal.AcceptedCompressionNames = acceptedCompressionNames +} + // Compressor defines the interface gRPC uses to compress a message. // // Deprecated: use package encoding. @@ -151,16 +157,33 @@ func (d *gzipDecompressor) Type() string { // callInfo contains all related configuration and information about an RPC. type callInfo struct { - compressorName string - failFast bool - maxReceiveMessageSize *int - maxSendMessageSize *int - creds credentials.PerRPCCredentials - contentSubtype string - codec baseCodec - maxRetryRPCBufferSize int - onFinish []func(err error) - authority string + compressorName string + failFast bool + maxReceiveMessageSize *int + maxSendMessageSize *int + creds credentials.PerRPCCredentials + contentSubtype string + codec baseCodec + maxRetryRPCBufferSize int + onFinish []func(err error) + authority string + acceptedResponseCompressors *acceptedCompressionConfig +} + +type acceptedCompressionConfig struct { + headerValue string + allowed map[string]struct{} +} + +func (cfg *acceptedCompressionConfig) allows(name string) bool { + if cfg == nil { + return true + } + if name == "" || name == encoding.Identity { + return true + } + _, ok := cfg.allowed[name] + return ok } func defaultCallInfo() *callInfo { @@ -170,6 +193,35 @@ func defaultCallInfo() *callInfo { } } +func newAcceptedCompressionConfig(names []string) (*acceptedCompressionConfig, error) { + cfg := &acceptedCompressionConfig{ + allowed: make(map[string]struct{}, len(names)), + } + if len(names) == 0 { + return cfg, nil + } + var ordered []string + for _, name := range names { + name = strings.TrimSpace(name) + if name == "" || name == encoding.Identity { + continue + } + if !grpcutil.IsCompressorNameRegistered(name) { + return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name) + } + if _, dup := cfg.allowed[name]; dup { + continue + } + cfg.allowed[name] = struct{}{} + ordered = append(ordered, name) + } + if len(ordered) == 0 { + return nil, status.Error(codes.InvalidArgument, "grpc: no valid compressor names provided") + } + cfg.headerValue = strings.Join(ordered, ",") + return cfg, nil +} + // CallOption configures a Call before it starts or extracts information from // a Call after it completes. type CallOption interface { @@ -471,6 +523,26 @@ func (o CompressorCallOption) before(c *callInfo) error { } func (o CompressorCallOption) after(*callInfo, *csAttempt) {} +func acceptedCompressionNames(names ...string) CallOption { + cp := append([]string(nil), names...) + return acceptedCompressionNamesCallOption{names: cp} +} + +type acceptedCompressionNamesCallOption struct { + names []string +} + +func (o acceptedCompressionNamesCallOption) before(c *callInfo) error { + cfg, err := newAcceptedCompressionConfig(o.names) + if err != nil { + return err + } + c.acceptedResponseCompressors = cfg + return nil +} + +func (acceptedCompressionNamesCallOption) after(*callInfo, *csAttempt) {} + // CallContentSubtype returns a CallOption that will set the content-subtype // for a call. For example, if content-subtype is "json", the Content-Type over // the wire will be "application/grpc+json". The content-subtype is converted @@ -821,7 +893,7 @@ func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time } } -func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status { +func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool, acceptedCfg *acceptedCompressionConfig) *status.Status { switch pf { case compressionNone: case compressionMade: @@ -834,6 +906,9 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool } return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) } + if !isServer && acceptedCfg != nil && !acceptedCfg.allows(recvCompress) { + return status.Newf(codes.FailedPrecondition, "grpc: peer compressed the response with %q which is not allowed by AcceptedCompressionNames", recvCompress) + } default: return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf) } @@ -857,7 +932,7 @@ func (p *payloadInfo) free() { // the buffer is no longer needed. // TODO: Refactor this function to reduce the number of arguments. // See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists -func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, +func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig, ) (out mem.BufferSlice, err error) { pf, compressed, err := p.recvMsg(maxReceiveMessageSize) if err != nil { @@ -866,7 +941,7 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM compressedLength := compressed.Len() - if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil { + if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer, acceptedCfg); st != nil { compressed.Free() return nil, st.Err() } @@ -941,8 +1016,8 @@ type recvCompressor interface { // For the two compressor parameters, both should not be set, but if they are, // dc takes precedence over compressor. // TODO(dfawley): wrap the old compressor/decompressor using the new API? -func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error { - data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer) +func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig) error { + data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer, acceptedCfg) if err != nil { return err } diff --git a/rpc_util_test.go b/rpc_util_test.go index a5c5cb8b17e2..a9da704e7303 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -48,6 +48,85 @@ const ( decompressionErrorMsg = "invalid compression format" ) +func (s) TestNewAcceptedCompressionConfig(t *testing.T) { + tests := []struct { + name string + input []string + wantHeader string + wantAllowed map[string]struct{} + wantErr bool + }{ + { + name: "identity-only", + input: nil, + wantHeader: "", + wantAllowed: map[string]struct{}{}, + }, + { + name: "single valid", + input: []string{"gzip"}, + wantHeader: "gzip", + wantAllowed: map[string]struct{}{"gzip": {}}, + }, + { + name: "dedupe and trim", + input: []string{" gzip ", "gzip"}, + wantHeader: "gzip", + wantAllowed: map[string]struct{}{"gzip": {}}, + }, + { + name: "ignores identity", + input: []string{"identity", "gzip"}, + wantHeader: "gzip", + wantAllowed: map[string]struct{}{"gzip": {}}, + }, + { + name: "invalid compressor", + input: []string{"does-not-exist"}, + wantErr: true, + }, + { + name: "only whitespace", + input: []string{" ", "\t"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := newAcceptedCompressionConfig(tt.input) + if (err != nil) != tt.wantErr { + t.Fatalf("newAcceptedCompressionConfig(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + if tt.wantErr { + return + } + if cfg.headerValue != tt.wantHeader { + t.Fatalf("headerValue = %q, want %q", cfg.headerValue, tt.wantHeader) + } + if diff := cmp.Diff(tt.wantAllowed, cfg.allowed); diff != "" { + t.Fatalf("allowed diff (-want +got): %v", diff) + } + }) + } +} + +func (s) TestCheckRecvPayloadHonorsAcceptedCompressors(t *testing.T) { + cfg, err := newAcceptedCompressionConfig([]string{"gzip"}) + if err != nil { + t.Fatalf("newAcceptedCompressionConfig returned error: %v", err) + } + + if st := checkRecvPayload(compressionMade, "gzip", true, false, cfg); st != nil { + t.Fatalf("checkRecvPayload returned error for allowed compressor: %v", st) + } + + st := checkRecvPayload(compressionMade, "snappy", true, false, cfg) + if st == nil || st.Code() != codes.FailedPrecondition { + t.Fatalf("checkRecvPayload = %v, want code %v", st, codes.FailedPrecondition) + } +} + type fullReader struct { data []byte } diff --git a/server.go b/server.go index ddd377341191..2099cff4be9d 100644 --- a/server.go +++ b/server.go @@ -1381,7 +1381,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt defer payInfo.free() } - d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true) + d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true, nil) if err != nil { if e := stream.WriteStatus(status.Convert(err)); e != nil { channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e) diff --git a/stream.go b/stream.go index ca87ff9776ef..a9feb0593d03 100644 --- a/stream.go +++ b/stream.go @@ -301,6 +301,9 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client DoneFunc: doneFunc, Authority: callInfo.authority, } + if cfg := callInfo.acceptedResponseCompressors; cfg != nil { + callHdr.AcceptedCompressors = &cfg.headerValue + } // Set our outgoing compression according to the UseCompressor CallOption, if // set. In that case, also find the compressor from the encoding package. @@ -1141,7 +1144,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { // Only initialize this state once per stream. a.decompressorSet = true } - if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil { + if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err != nil { if err == io.EOF { if statusErr := a.transportStream.Status().Err(); statusErr != nil { return statusErr @@ -1179,7 +1182,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { } // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF { + if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err == io.EOF { return a.transportStream.Status().Err() // non-server streaming Recv returns nil on success } else if err != nil { return toRPCErr(err) @@ -1486,7 +1489,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Only initialize this state once per stream. as.decompressorSet = true } - if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil { + if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err != nil { if err == io.EOF { if statusErr := as.transportStream.Status().Err(); statusErr != nil { return statusErr @@ -1508,7 +1511,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF { + if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err == io.EOF { return as.transportStream.Status().Err() // non-server streaming Recv returns nil on success } else if err != nil { return toRPCErr(err) @@ -1785,7 +1788,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { payInfo = &payloadInfo{} defer payInfo.free() } - if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil { + if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true, nil); err != nil { if err == io.EOF { if len(ss.binlogs) != 0 { chc := &binarylog.ClientHalfClose{} @@ -1829,7 +1832,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { } // Special handling for non-client-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF { + if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true, nil); err == io.EOF { return nil } else if err != nil { return err diff --git a/test/compressor_test.go b/test/compressor_test.go index dbdc06222220..fb80206b164e 100644 --- a/test/compressor_test.go +++ b/test/compressor_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/experimental" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -533,6 +534,30 @@ func (s) TestClientSupportedCompressors(t *testing.T) { } } +func (s) TestAcceptedCompressionNamesCallOption(t *testing.T) { + const want = "gzip" + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + md, _ := metadata.FromIncomingContext(ctx) + if got := md.Get("grpc-accept-encoding"); len(got) != 1 || got[0] != want { + t.Fatalf("unexpected grpc-accept-encoding header: %v", got) + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("failed to start server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, experimental.AcceptedCompressionNames(want)); err != nil { + t.Fatalf("EmptyCall failed: %v", err) + } +} + func (s) TestCompressorRegister(t *testing.T) { for _, e := range listTestEnv() { testCompressorRegister(t, e) From d6f69e203702035a0d5e16a0a820e0a8375e3356 Mon Sep 17 00:00:00 2001 From: Israel Blancas Date: Tue, 2 Dec 2025 13:25:02 +0100 Subject: [PATCH 2/5] Apply feedback requested in code review Signed-off-by: Israel Blancas --- experimental/experimental.go | 9 ++-- internal/experimental.go | 4 +- rpc_util.go | 71 +++++++++++--------------- rpc_util_test.go | 99 ++++++++++++++++++++++++------------ server.go | 2 +- stream.go | 26 +++++++--- test/compressor_test.go | 62 +++++++++++++++------- 7 files changed, 167 insertions(+), 106 deletions(-) diff --git a/experimental/experimental.go b/experimental/experimental.go index c8620cffe415..3ba948bab316 100644 --- a/experimental/experimental.go +++ b/experimental/experimental.go @@ -63,10 +63,11 @@ func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption { return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool) } -// AcceptedCompressionNames returns a CallOption that limits the values +// AcceptCompressors returns a CallOption that limits the values // advertised in the grpc-accept-encoding header for the provided RPC. The // supplied names must correspond to compressors registered via -// encoding.RegisterCompressor. Passing no names advertises identity only. -func AcceptedCompressionNames(names ...string) grpc.CallOption { - return internal.AcceptedCompressionNames.(func(...string) grpc.CallOption)(names...) +// encoding.RegisterCompressor. Passing no names advertises "identity" (no +// compression) only. +func AcceptCompressors(names ...string) grpc.CallOption { + return internal.AcceptCompressors.(func(...string) grpc.CallOption)(names...) } diff --git a/internal/experimental.go b/internal/experimental.go index 3482abacdc5e..c90cc51bdd2b 100644 --- a/internal/experimental.go +++ b/internal/experimental.go @@ -26,7 +26,7 @@ var ( // option to configure a shared buffer pool for a grpc.Server. BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption - // AcceptedCompressionNames is implemented by the grpc package and returns + // AcceptCompressors is implemented by the grpc package and returns // a call option that restricts the grpc-accept-encoding header for a call. - AcceptedCompressionNames any // func(...string) grpc.CallOption + AcceptCompressors any // func(...string) grpc.CallOption ) diff --git a/rpc_util.go b/rpc_util.go index 32a5c0f55bda..40c8a8b45d85 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -44,7 +44,7 @@ import ( ) func init() { - internal.AcceptedCompressionNames = acceptedCompressionNames + internal.AcceptCompressors = AcceptCompressors } // Compressor defines the interface gRPC uses to compress a message. @@ -167,23 +167,22 @@ type callInfo struct { maxRetryRPCBufferSize int onFinish []func(err error) authority string - acceptedResponseCompressors *acceptedCompressionConfig + acceptedResponseCompressors []string } -type acceptedCompressionConfig struct { - headerValue string - allowed map[string]struct{} -} - -func (cfg *acceptedCompressionConfig) allows(name string) bool { - if cfg == nil { +func acceptedCompressorAllows(allowed []string, name string) bool { + if allowed == nil { return true } if name == "" || name == encoding.Identity { return true } - _, ok := cfg.allowed[name] - return ok + for _, a := range allowed { + if a == name { + return true + } + } + return false } func defaultCallInfo() *callInfo { @@ -193,14 +192,12 @@ func defaultCallInfo() *callInfo { } } -func newAcceptedCompressionConfig(names []string) (*acceptedCompressionConfig, error) { - cfg := &acceptedCompressionConfig{ - allowed: make(map[string]struct{}, len(names)), - } +func newAcceptedCompressionConfig(names []string) ([]string, error) { if len(names) == 0 { - return cfg, nil + return nil, nil } - var ordered []string + var allowed []string + seen := make(map[string]struct{}, len(names)) for _, name := range names { name = strings.TrimSpace(name) if name == "" || name == encoding.Identity { @@ -209,17 +206,13 @@ func newAcceptedCompressionConfig(names []string) (*acceptedCompressionConfig, e if !grpcutil.IsCompressorNameRegistered(name) { return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name) } - if _, dup := cfg.allowed[name]; dup { + if _, dup := seen[name]; dup { continue } - cfg.allowed[name] = struct{}{} - ordered = append(ordered, name) + seen[name] = struct{}{} + allowed = append(allowed, name) } - if len(ordered) == 0 { - return nil, status.Error(codes.InvalidArgument, "grpc: no valid compressor names provided") - } - cfg.headerValue = strings.Join(ordered, ",") - return cfg, nil + return allowed, nil } // CallOption configures a Call before it starts or extracts information from @@ -523,25 +516,25 @@ func (o CompressorCallOption) before(c *callInfo) error { } func (o CompressorCallOption) after(*callInfo, *csAttempt) {} -func acceptedCompressionNames(names ...string) CallOption { +func AcceptCompressors(names ...string) CallOption { cp := append([]string(nil), names...) - return acceptedCompressionNamesCallOption{names: cp} + return AcceptCompressorsCallOption{names: cp} } -type acceptedCompressionNamesCallOption struct { +type AcceptCompressorsCallOption struct { names []string } -func (o acceptedCompressionNamesCallOption) before(c *callInfo) error { - cfg, err := newAcceptedCompressionConfig(o.names) +func (o AcceptCompressorsCallOption) before(c *callInfo) error { + allowed, err := newAcceptedCompressionConfig(o.names) if err != nil { return err } - c.acceptedResponseCompressors = cfg + c.acceptedResponseCompressors = allowed return nil } -func (acceptedCompressionNamesCallOption) after(*callInfo, *csAttempt) {} +func (AcceptCompressorsCallOption) after(*callInfo, *csAttempt) {} // CallContentSubtype returns a CallOption that will set the content-subtype // for a call. For example, if content-subtype is "json", the Content-Type over @@ -893,7 +886,7 @@ func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time } } -func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool, acceptedCfg *acceptedCompressionConfig) *status.Status { +func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status { switch pf { case compressionNone: case compressionMade: @@ -906,9 +899,6 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool } return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) } - if !isServer && acceptedCfg != nil && !acceptedCfg.allows(recvCompress) { - return status.Newf(codes.FailedPrecondition, "grpc: peer compressed the response with %q which is not allowed by AcceptedCompressionNames", recvCompress) - } default: return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf) } @@ -932,8 +922,7 @@ func (p *payloadInfo) free() { // the buffer is no longer needed. // TODO: Refactor this function to reduce the number of arguments. // See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists -func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig, -) (out mem.BufferSlice, err error) { +func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) (out mem.BufferSlice, err error) { pf, compressed, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return nil, err @@ -941,7 +930,7 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM compressedLength := compressed.Len() - if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer, acceptedCfg); st != nil { + if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil { compressed.Free() return nil, st.Err() } @@ -1016,8 +1005,8 @@ type recvCompressor interface { // For the two compressor parameters, both should not be set, but if they are, // dc takes precedence over compressor. // TODO(dfawley): wrap the old compressor/decompressor using the new API? -func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig) error { - data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer, acceptedCfg) +func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error { + data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer) if err != nil { return err } diff --git a/rpc_util_test.go b/rpc_util_test.go index a9da704e7303..79628d1be1d1 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -48,37 +48,69 @@ const ( decompressionErrorMsg = "invalid compression format" ) +type testCompressorForRegistry struct { + name string +} + +func (c *testCompressorForRegistry) Compress(w io.Writer) (io.WriteCloser, error) { + return &testWriteCloser{w}, nil +} + +func (c *testCompressorForRegistry) Decompress(r io.Reader) (io.Reader, error) { + return r, nil +} + +func (c *testCompressorForRegistry) Name() string { + return c.name +} + +type testWriteCloser struct { + io.Writer +} + +func (w *testWriteCloser) Close() error { + return nil +} + func (s) TestNewAcceptedCompressionConfig(t *testing.T) { + // Register a test compressor for multi-compressor tests + testCompressor := &testCompressorForRegistry{name: "test-compressor"} + encoding.RegisterCompressor(testCompressor) + defer func() { + // Unregister the test compressor + encoding.RegisterCompressor(&testCompressorForRegistry{name: "test-compressor"}) + }() + tests := []struct { name string input []string - wantHeader string - wantAllowed map[string]struct{} + wantAllowed []string wantErr bool }{ { name: "identity-only", input: nil, - wantHeader: "", - wantAllowed: map[string]struct{}{}, + wantAllowed: nil, }, { name: "single valid", input: []string{"gzip"}, - wantHeader: "gzip", - wantAllowed: map[string]struct{}{"gzip": {}}, + wantAllowed: []string{"gzip"}, }, { name: "dedupe and trim", input: []string{" gzip ", "gzip"}, - wantHeader: "gzip", - wantAllowed: map[string]struct{}{"gzip": {}}, + wantAllowed: []string{"gzip"}, }, { name: "ignores identity", input: []string{"identity", "gzip"}, - wantHeader: "gzip", - wantAllowed: map[string]struct{}{"gzip": {}}, + wantAllowed: []string{"gzip"}, + }, + { + name: "explicit identity only", + input: []string{"identity"}, + wantAllowed: nil, }, { name: "invalid compressor", @@ -86,47 +118,48 @@ func (s) TestNewAcceptedCompressionConfig(t *testing.T) { wantErr: true, }, { - name: "only whitespace", - input: []string{" ", "\t"}, + name: "only whitespace", + input: []string{" ", "\t"}, + wantAllowed: nil, + }, + { + name: "multiple valid compressors", + input: []string{"gzip", "test-compressor"}, + wantAllowed: []string{"gzip", "test-compressor"}, + }, + { + name: "multiple with identity and whitespace", + input: []string{"gzip", "identity", " test-compressor ", " "}, + wantAllowed: []string{"gzip", "test-compressor"}, + }, + { + name: "empty string in list", + input: []string{"gzip", "", "test-compressor"}, + wantAllowed: []string{"gzip", "test-compressor"}, + }, + { + name: "mixed valid and invalid", + input: []string{"gzip", "invalid-comp"}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfg, err := newAcceptedCompressionConfig(tt.input) + allowed, err := newAcceptedCompressionConfig(tt.input) if (err != nil) != tt.wantErr { t.Fatalf("newAcceptedCompressionConfig(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr) } if tt.wantErr { return } - if cfg.headerValue != tt.wantHeader { - t.Fatalf("headerValue = %q, want %q", cfg.headerValue, tt.wantHeader) - } - if diff := cmp.Diff(tt.wantAllowed, cfg.allowed); diff != "" { + if diff := cmp.Diff(tt.wantAllowed, allowed); diff != "" { t.Fatalf("allowed diff (-want +got): %v", diff) } }) } } -func (s) TestCheckRecvPayloadHonorsAcceptedCompressors(t *testing.T) { - cfg, err := newAcceptedCompressionConfig([]string{"gzip"}) - if err != nil { - t.Fatalf("newAcceptedCompressionConfig returned error: %v", err) - } - - if st := checkRecvPayload(compressionMade, "gzip", true, false, cfg); st != nil { - t.Fatalf("checkRecvPayload returned error for allowed compressor: %v", st) - } - - st := checkRecvPayload(compressionMade, "snappy", true, false, cfg) - if st == nil || st.Code() != codes.FailedPrecondition { - t.Fatalf("checkRecvPayload = %v, want code %v", st, codes.FailedPrecondition) - } -} - type fullReader struct { data []byte } diff --git a/server.go b/server.go index 2099cff4be9d..ddd377341191 100644 --- a/server.go +++ b/server.go @@ -1381,7 +1381,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt defer payInfo.free() } - d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true, nil) + d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true) if err != nil { if e := stream.WriteStatus(status.Convert(err)); e != nil { channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e) diff --git a/stream.go b/stream.go index a9feb0593d03..e9e7adb54164 100644 --- a/stream.go +++ b/stream.go @@ -25,6 +25,7 @@ import ( "math" rand "math/rand/v2" "strconv" + "strings" "sync" "time" @@ -301,8 +302,9 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client DoneFunc: doneFunc, Authority: callInfo.authority, } - if cfg := callInfo.acceptedResponseCompressors; cfg != nil { - callHdr.AcceptedCompressors = &cfg.headerValue + if allowed := callInfo.acceptedResponseCompressors; len(allowed) > 0 { + headerValue := strings.Join(allowed, ",") + callHdr.AcceptedCompressors = &headerValue } // Set our outgoing compression according to the UseCompressor CallOption, if @@ -1137,6 +1139,10 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { a.decompressorV0 = nil a.decompressorV1 = encoding.GetCompressor(ct) } + // Validate that the compression method is acceptable for this call. + if !acceptedCompressorAllows(cs.callInfo.acceptedResponseCompressors, ct) { + return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct) + } } else { // No compression is used; disable our decompressor. a.decompressorV0 = nil @@ -1144,7 +1150,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { // Only initialize this state once per stream. a.decompressorSet = true } - if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err != nil { + if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil { if err == io.EOF { if statusErr := a.transportStream.Status().Err(); statusErr != nil { return statusErr @@ -1182,7 +1188,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { } // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err == io.EOF { + if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF { return a.transportStream.Status().Err() // non-server streaming Recv returns nil on success } else if err != nil { return toRPCErr(err) @@ -1482,6 +1488,10 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { as.decompressorV0 = nil as.decompressorV1 = encoding.GetCompressor(ct) } + // Validate that the compression method is acceptable for this call. + if !acceptedCompressorAllows(as.callInfo.acceptedResponseCompressors, ct) { + return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct) + } } else { // No compression is used; disable our decompressor. as.decompressorV0 = nil @@ -1489,7 +1499,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Only initialize this state once per stream. as.decompressorSet = true } - if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err != nil { + if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil { if err == io.EOF { if statusErr := as.transportStream.Status().Err(); statusErr != nil { return statusErr @@ -1511,7 +1521,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err == io.EOF { + if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF { return as.transportStream.Status().Err() // non-server streaming Recv returns nil on success } else if err != nil { return toRPCErr(err) @@ -1788,7 +1798,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { payInfo = &payloadInfo{} defer payInfo.free() } - if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true, nil); err != nil { + if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil { if err == io.EOF { if len(ss.binlogs) != 0 { chc := &binarylog.ClientHalfClose{} @@ -1832,7 +1842,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { } // Special handling for non-client-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true, nil); err == io.EOF { + if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF { return nil } else if err != nil { return err diff --git a/test/compressor_test.go b/test/compressor_test.go index fb80206b164e..ebc42f2ede3b 100644 --- a/test/compressor_test.go +++ b/test/compressor_test.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/experimental" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -534,27 +535,54 @@ func (s) TestClientSupportedCompressors(t *testing.T) { } } -func (s) TestAcceptedCompressionNamesCallOption(t *testing.T) { - const want = "gzip" - ss := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { - md, _ := metadata.FromIncomingContext(ctx) - if got := md.Get("grpc-accept-encoding"); len(got) != 1 || got[0] != want { - t.Fatalf("unexpected grpc-accept-encoding header: %v", got) - } - return &testpb.Empty{}, nil +func (s) TestAcceptCompressorsCallOption(t *testing.T) { + tests := []struct { + name string + callOption grpc.CallOption + wantHeader string + }{ + { + name: "with AcceptCompressors", + callOption: experimental.AcceptCompressors("gzip"), + wantHeader: "gzip", + }, + { + name: "without AcceptCompressors uses default", + callOption: nil, + wantHeader: grpcutil.RegisteredCompressors(), }, } - if err := ss.Start(nil); err != nil { - t.Fatalf("failed to start server: %v", err) - } - defer ss.Stop() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + md, _ := metadata.FromIncomingContext(ctx) + header := md.Get("grpc-accept-encoding") - if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, experimental.AcceptedCompressionNames(want)); err != nil { - t.Fatalf("EmptyCall failed: %v", err) + if len(header) != 1 || header[0] != tt.wantHeader { + t.Errorf("unexpected grpc-accept-encoding header: got %v, want %v", header, tt.wantHeader) + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("failed to start server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + opts := []grpc.CallOption{} + if tt.callOption != nil { + opts = append(opts, tt.callOption) + } + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, opts...); err != nil { + t.Fatalf("EmptyCall failed: %v", err) + } + }) } } From b34a4f2390b7e3bf6dea25afbb71824a3253a2ce Mon Sep 17 00:00:00 2001 From: Israel Blancas Date: Tue, 2 Dec 2025 13:34:16 +0100 Subject: [PATCH 3/5] Fix ci Signed-off-by: Israel Blancas --- rpc_util.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/rpc_util.go b/rpc_util.go index 40c8a8b45d85..32dddc68ab42 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -516,11 +516,16 @@ func (o CompressorCallOption) before(c *callInfo) error { } func (o CompressorCallOption) after(*callInfo, *csAttempt) {} +// AcceptCompressors returns a CallOption that limits the compression algorithms +// advertised in the grpc-accept-encoding header for response messages. +// Compression algorithms not in the provided list will not be advertised, and +// responses compressed with non-listed algorithms will be rejected. func AcceptCompressors(names ...string) CallOption { cp := append([]string(nil), names...) return AcceptCompressorsCallOption{names: cp} } +// AcceptCompressorsCallOption is a CallOption that limits response compression. type AcceptCompressorsCallOption struct { names []string } From 219e100ee00ae908ef9db1db71aea989e2ac517d Mon Sep 17 00:00:00 2001 From: Israel Blancas Date: Tue, 9 Dec 2025 15:59:08 +0100 Subject: [PATCH 4/5] Apply changes requested in code review Signed-off-by: Israel Blancas --- rpc_util.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rpc_util.go b/rpc_util.go index 32dddc68ab42..6d180aff5921 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -522,15 +522,15 @@ func (o CompressorCallOption) after(*callInfo, *csAttempt) {} // responses compressed with non-listed algorithms will be rejected. func AcceptCompressors(names ...string) CallOption { cp := append([]string(nil), names...) - return AcceptCompressorsCallOption{names: cp} + return acceptCompressorsCallOption{names: cp} } -// AcceptCompressorsCallOption is a CallOption that limits response compression. -type AcceptCompressorsCallOption struct { +// acceptCompressorsCallOption is a CallOption that limits response compression. +type acceptCompressorsCallOption struct { names []string } -func (o AcceptCompressorsCallOption) before(c *callInfo) error { +func (o acceptCompressorsCallOption) before(c *callInfo) error { allowed, err := newAcceptedCompressionConfig(o.names) if err != nil { return err @@ -539,7 +539,7 @@ func (o AcceptCompressorsCallOption) before(c *callInfo) error { return nil } -func (AcceptCompressorsCallOption) after(*callInfo, *csAttempt) {} +func (acceptCompressorsCallOption) after(*callInfo, *csAttempt) {} // CallContentSubtype returns a CallOption that will set the content-subtype // for a call. For example, if content-subtype is "json", the Content-Type over From de477d3c07540f6bfdd96a80839e6cc0091e7d87 Mon Sep 17 00:00:00 2001 From: Israel Blancas Date: Wed, 10 Dec 2025 00:37:44 +0100 Subject: [PATCH 5/5] Unexport acceptCompressors Signed-off-by: Israel Blancas --- rpc_util.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rpc_util.go b/rpc_util.go index 6d180aff5921..8160f9430405 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -44,7 +44,7 @@ import ( ) func init() { - internal.AcceptCompressors = AcceptCompressors + internal.AcceptCompressors = acceptCompressors } // Compressor defines the interface gRPC uses to compress a message. @@ -516,11 +516,11 @@ func (o CompressorCallOption) before(c *callInfo) error { } func (o CompressorCallOption) after(*callInfo, *csAttempt) {} -// AcceptCompressors returns a CallOption that limits the compression algorithms +// acceptCompressors returns a CallOption that limits the compression algorithms // advertised in the grpc-accept-encoding header for response messages. // Compression algorithms not in the provided list will not be advertised, and // responses compressed with non-listed algorithms will be rejected. -func AcceptCompressors(names ...string) CallOption { +func acceptCompressors(names ...string) CallOption { cp := append([]string(nil), names...) return acceptCompressorsCallOption{names: cp} }