Skip to content

Commit d6f69e2

Browse files
committed
Apply feedback requested in code review
Signed-off-by: Israel Blancas <iblancasa@gmail.com>
1 parent 7595d74 commit d6f69e2

File tree

7 files changed

+167
-106
lines changed

7 files changed

+167
-106
lines changed

experimental/experimental.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,11 @@ func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption {
6363
return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool)
6464
}
6565

66-
// AcceptedCompressionNames returns a CallOption that limits the values
66+
// AcceptCompressors returns a CallOption that limits the values
6767
// advertised in the grpc-accept-encoding header for the provided RPC. The
6868
// supplied names must correspond to compressors registered via
69-
// encoding.RegisterCompressor. Passing no names advertises identity only.
70-
func AcceptedCompressionNames(names ...string) grpc.CallOption {
71-
return internal.AcceptedCompressionNames.(func(...string) grpc.CallOption)(names...)
69+
// encoding.RegisterCompressor. Passing no names advertises "identity" (no
70+
// compression) only.
71+
func AcceptCompressors(names ...string) grpc.CallOption {
72+
return internal.AcceptCompressors.(func(...string) grpc.CallOption)(names...)
7273
}

internal/experimental.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ var (
2626
// option to configure a shared buffer pool for a grpc.Server.
2727
BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
2828

29-
// AcceptedCompressionNames is implemented by the grpc package and returns
29+
// AcceptCompressors is implemented by the grpc package and returns
3030
// a call option that restricts the grpc-accept-encoding header for a call.
31-
AcceptedCompressionNames any // func(...string) grpc.CallOption
31+
AcceptCompressors any // func(...string) grpc.CallOption
3232
)

rpc_util.go

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import (
4444
)
4545

4646
func init() {
47-
internal.AcceptedCompressionNames = acceptedCompressionNames
47+
internal.AcceptCompressors = AcceptCompressors
4848
}
4949

5050
// Compressor defines the interface gRPC uses to compress a message.
@@ -167,23 +167,22 @@ type callInfo struct {
167167
maxRetryRPCBufferSize int
168168
onFinish []func(err error)
169169
authority string
170-
acceptedResponseCompressors *acceptedCompressionConfig
170+
acceptedResponseCompressors []string
171171
}
172172

173-
type acceptedCompressionConfig struct {
174-
headerValue string
175-
allowed map[string]struct{}
176-
}
177-
178-
func (cfg *acceptedCompressionConfig) allows(name string) bool {
179-
if cfg == nil {
173+
func acceptedCompressorAllows(allowed []string, name string) bool {
174+
if allowed == nil {
180175
return true
181176
}
182177
if name == "" || name == encoding.Identity {
183178
return true
184179
}
185-
_, ok := cfg.allowed[name]
186-
return ok
180+
for _, a := range allowed {
181+
if a == name {
182+
return true
183+
}
184+
}
185+
return false
187186
}
188187

189188
func defaultCallInfo() *callInfo {
@@ -193,14 +192,12 @@ func defaultCallInfo() *callInfo {
193192
}
194193
}
195194

196-
func newAcceptedCompressionConfig(names []string) (*acceptedCompressionConfig, error) {
197-
cfg := &acceptedCompressionConfig{
198-
allowed: make(map[string]struct{}, len(names)),
199-
}
195+
func newAcceptedCompressionConfig(names []string) ([]string, error) {
200196
if len(names) == 0 {
201-
return cfg, nil
197+
return nil, nil
202198
}
203-
var ordered []string
199+
var allowed []string
200+
seen := make(map[string]struct{}, len(names))
204201
for _, name := range names {
205202
name = strings.TrimSpace(name)
206203
if name == "" || name == encoding.Identity {
@@ -209,17 +206,13 @@ func newAcceptedCompressionConfig(names []string) (*acceptedCompressionConfig, e
209206
if !grpcutil.IsCompressorNameRegistered(name) {
210207
return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name)
211208
}
212-
if _, dup := cfg.allowed[name]; dup {
209+
if _, dup := seen[name]; dup {
213210
continue
214211
}
215-
cfg.allowed[name] = struct{}{}
216-
ordered = append(ordered, name)
212+
seen[name] = struct{}{}
213+
allowed = append(allowed, name)
217214
}
218-
if len(ordered) == 0 {
219-
return nil, status.Error(codes.InvalidArgument, "grpc: no valid compressor names provided")
220-
}
221-
cfg.headerValue = strings.Join(ordered, ",")
222-
return cfg, nil
215+
return allowed, nil
223216
}
224217

225218
// CallOption configures a Call before it starts or extracts information from
@@ -523,25 +516,25 @@ func (o CompressorCallOption) before(c *callInfo) error {
523516
}
524517
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}
525518

526-
func acceptedCompressionNames(names ...string) CallOption {
519+
func AcceptCompressors(names ...string) CallOption {
527520
cp := append([]string(nil), names...)
528-
return acceptedCompressionNamesCallOption{names: cp}
521+
return AcceptCompressorsCallOption{names: cp}
529522
}
530523

531-
type acceptedCompressionNamesCallOption struct {
524+
type AcceptCompressorsCallOption struct {
532525
names []string
533526
}
534527

535-
func (o acceptedCompressionNamesCallOption) before(c *callInfo) error {
536-
cfg, err := newAcceptedCompressionConfig(o.names)
528+
func (o AcceptCompressorsCallOption) before(c *callInfo) error {
529+
allowed, err := newAcceptedCompressionConfig(o.names)
537530
if err != nil {
538531
return err
539532
}
540-
c.acceptedResponseCompressors = cfg
533+
c.acceptedResponseCompressors = allowed
541534
return nil
542535
}
543536

544-
func (acceptedCompressionNamesCallOption) after(*callInfo, *csAttempt) {}
537+
func (AcceptCompressorsCallOption) after(*callInfo, *csAttempt) {}
545538

546539
// CallContentSubtype returns a CallOption that will set the content-subtype
547540
// for a call. For example, if content-subtype is "json", the Content-Type over
@@ -893,7 +886,7 @@ func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time
893886
}
894887
}
895888

896-
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool, acceptedCfg *acceptedCompressionConfig) *status.Status {
889+
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
897890
switch pf {
898891
case compressionNone:
899892
case compressionMade:
@@ -906,9 +899,6 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
906899
}
907900
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
908901
}
909-
if !isServer && acceptedCfg != nil && !acceptedCfg.allows(recvCompress) {
910-
return status.Newf(codes.FailedPrecondition, "grpc: peer compressed the response with %q which is not allowed by AcceptedCompressionNames", recvCompress)
911-
}
912902
default:
913903
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
914904
}
@@ -932,16 +922,15 @@ func (p *payloadInfo) free() {
932922
// the buffer is no longer needed.
933923
// TODO: Refactor this function to reduce the number of arguments.
934924
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
935-
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig,
936-
) (out mem.BufferSlice, err error) {
925+
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) (out mem.BufferSlice, err error) {
937926
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
938927
if err != nil {
939928
return nil, err
940929
}
941930

942931
compressedLength := compressed.Len()
943932

944-
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer, acceptedCfg); st != nil {
933+
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
945934
compressed.Free()
946935
return nil, st.Err()
947936
}
@@ -1016,8 +1005,8 @@ type recvCompressor interface {
10161005
// For the two compressor parameters, both should not be set, but if they are,
10171006
// dc takes precedence over compressor.
10181007
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
1019-
func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig) error {
1020-
data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer, acceptedCfg)
1008+
func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
1009+
data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
10211010
if err != nil {
10221011
return err
10231012
}

rpc_util_test.go

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -48,85 +48,118 @@ const (
4848
decompressionErrorMsg = "invalid compression format"
4949
)
5050

51+
type testCompressorForRegistry struct {
52+
name string
53+
}
54+
55+
func (c *testCompressorForRegistry) Compress(w io.Writer) (io.WriteCloser, error) {
56+
return &testWriteCloser{w}, nil
57+
}
58+
59+
func (c *testCompressorForRegistry) Decompress(r io.Reader) (io.Reader, error) {
60+
return r, nil
61+
}
62+
63+
func (c *testCompressorForRegistry) Name() string {
64+
return c.name
65+
}
66+
67+
type testWriteCloser struct {
68+
io.Writer
69+
}
70+
71+
func (w *testWriteCloser) Close() error {
72+
return nil
73+
}
74+
5175
func (s) TestNewAcceptedCompressionConfig(t *testing.T) {
76+
// Register a test compressor for multi-compressor tests
77+
testCompressor := &testCompressorForRegistry{name: "test-compressor"}
78+
encoding.RegisterCompressor(testCompressor)
79+
defer func() {
80+
// Unregister the test compressor
81+
encoding.RegisterCompressor(&testCompressorForRegistry{name: "test-compressor"})
82+
}()
83+
5284
tests := []struct {
5385
name string
5486
input []string
55-
wantHeader string
56-
wantAllowed map[string]struct{}
87+
wantAllowed []string
5788
wantErr bool
5889
}{
5990
{
6091
name: "identity-only",
6192
input: nil,
62-
wantHeader: "",
63-
wantAllowed: map[string]struct{}{},
93+
wantAllowed: nil,
6494
},
6595
{
6696
name: "single valid",
6797
input: []string{"gzip"},
68-
wantHeader: "gzip",
69-
wantAllowed: map[string]struct{}{"gzip": {}},
98+
wantAllowed: []string{"gzip"},
7099
},
71100
{
72101
name: "dedupe and trim",
73102
input: []string{" gzip ", "gzip"},
74-
wantHeader: "gzip",
75-
wantAllowed: map[string]struct{}{"gzip": {}},
103+
wantAllowed: []string{"gzip"},
76104
},
77105
{
78106
name: "ignores identity",
79107
input: []string{"identity", "gzip"},
80-
wantHeader: "gzip",
81-
wantAllowed: map[string]struct{}{"gzip": {}},
108+
wantAllowed: []string{"gzip"},
109+
},
110+
{
111+
name: "explicit identity only",
112+
input: []string{"identity"},
113+
wantAllowed: nil,
82114
},
83115
{
84116
name: "invalid compressor",
85117
input: []string{"does-not-exist"},
86118
wantErr: true,
87119
},
88120
{
89-
name: "only whitespace",
90-
input: []string{" ", "\t"},
121+
name: "only whitespace",
122+
input: []string{" ", "\t"},
123+
wantAllowed: nil,
124+
},
125+
{
126+
name: "multiple valid compressors",
127+
input: []string{"gzip", "test-compressor"},
128+
wantAllowed: []string{"gzip", "test-compressor"},
129+
},
130+
{
131+
name: "multiple with identity and whitespace",
132+
input: []string{"gzip", "identity", " test-compressor ", " "},
133+
wantAllowed: []string{"gzip", "test-compressor"},
134+
},
135+
{
136+
name: "empty string in list",
137+
input: []string{"gzip", "", "test-compressor"},
138+
wantAllowed: []string{"gzip", "test-compressor"},
139+
},
140+
{
141+
name: "mixed valid and invalid",
142+
input: []string{"gzip", "invalid-comp"},
91143
wantErr: true,
92144
},
93145
}
94146

95147
for _, tt := range tests {
96148
t.Run(tt.name, func(t *testing.T) {
97-
cfg, err := newAcceptedCompressionConfig(tt.input)
149+
allowed, err := newAcceptedCompressionConfig(tt.input)
98150
if (err != nil) != tt.wantErr {
99151
t.Fatalf("newAcceptedCompressionConfig(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr)
100152
}
101153
if tt.wantErr {
102154
return
103155
}
104-
if cfg.headerValue != tt.wantHeader {
105-
t.Fatalf("headerValue = %q, want %q", cfg.headerValue, tt.wantHeader)
106-
}
107-
if diff := cmp.Diff(tt.wantAllowed, cfg.allowed); diff != "" {
156+
if diff := cmp.Diff(tt.wantAllowed, allowed); diff != "" {
108157
t.Fatalf("allowed diff (-want +got): %v", diff)
109158
}
110159
})
111160
}
112161
}
113162

114-
func (s) TestCheckRecvPayloadHonorsAcceptedCompressors(t *testing.T) {
115-
cfg, err := newAcceptedCompressionConfig([]string{"gzip"})
116-
if err != nil {
117-
t.Fatalf("newAcceptedCompressionConfig returned error: %v", err)
118-
}
119-
120-
if st := checkRecvPayload(compressionMade, "gzip", true, false, cfg); st != nil {
121-
t.Fatalf("checkRecvPayload returned error for allowed compressor: %v", st)
122-
}
123-
124-
st := checkRecvPayload(compressionMade, "snappy", true, false, cfg)
125-
if st == nil || st.Code() != codes.FailedPrecondition {
126-
t.Fatalf("checkRecvPayload = %v, want code %v", st, codes.FailedPrecondition)
127-
}
128-
}
129-
130163
type fullReader struct {
131164
data []byte
132165
}

server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1381,7 +1381,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
13811381
defer payInfo.free()
13821382
}
13831383

1384-
d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true, nil)
1384+
d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
13851385
if err != nil {
13861386
if e := stream.WriteStatus(status.Convert(err)); e != nil {
13871387
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)

0 commit comments

Comments
 (0)