diff --git a/cmd/proxygenerator/interceptor.go b/cmd/proxygenerator/interceptor.go index af16d5e1..ba17e9eb 100644 --- a/cmd/proxygenerator/interceptor.go +++ b/cmd/proxygenerator/interceptor.go @@ -45,11 +45,15 @@ type VisitPayloadsContext struct { SinglePayloadRequired bool } +// PayloadsVisitor accepts a context and payloads array to be able to visit and optionally +// transform the payloads as a return value. +type PayloadsVisitor func(*VisitPayloadsContext, []*common.Payload) ([]*common.Payload, error) + // VisitPayloadsOptions configure visitor behaviour. type VisitPayloadsOptions struct { // Context is the same for every call of a visit, callers should not store it. This must never // return an empty set of payloads. - Visitor func(*VisitPayloadsContext, []*common.Payload) ([]*common.Payload, error) + Visitor PayloadsVisitor // Don't visit search attribute payloads. SkipSearchAttributes bool // Will be called for each Any encountered. If not set, the default is to recurse into the Any @@ -85,28 +89,42 @@ var failureTypes = []string{ {{ range $i, $name := .GrpcFailure }}{{ if $i }}, { func NewPayloadVisitorInterceptor(options PayloadVisitorInterceptorOptions) (grpc.UnaryClientInterceptor, error) { return func(ctx context.Context, method string, req, response interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if reqMsg, ok := req.(proto.Message); ok && options.Outbound != nil { - err := VisitPayloads(ctx, reqMsg, *options.Outbound) + outboundVisitorOptions := VisitPayloadsOptions{ + Visitor: preserveExternalPayloadsVisitor(options.Outbound.Visitor), + SkipSearchAttributes: options.Outbound.SkipSearchAttributes, + WellKnownAnyVisitor: options.Outbound.WellKnownAnyVisitor, + } + + err := VisitPayloads(ctx, reqMsg, outboundVisitorOptions) if err != nil { return err - } + } } err := invoker(ctx, method, req, response, cc, opts...) - if err != nil && options.Inbound != nil { - if s, ok := status.FromError(err); ok { - // user provided payloads can sometimes end up in the status details of - // gRPC errors, make sure to visit those as well - err = visitGrpcErrorPayload(ctx, err, s, options.Inbound) + if options.Inbound != nil { + inboundVisitorOptions := VisitPayloadsOptions{ + Visitor: preserveExternalPayloadsVisitor(options.Inbound.Visitor), + SkipSearchAttributes: options.Inbound.SkipSearchAttributes, + WellKnownAnyVisitor: options.Inbound.WellKnownAnyVisitor, } - } - if resMsg, ok := response.(proto.Message); ok && options.Inbound != nil { - if visitErr := VisitPayloads(ctx, resMsg, *options.Inbound); visitErr != nil { - // We are choosing visit error over RPC error in this basically-never-should-happen case - err = visitErr + if err != nil { + if s, ok := status.FromError(err); ok { + // user provided payloads can sometimes end up in the status details of + // gRPC errors, make sure to visit those as well + err = visitGrpcErrorPayload(ctx, err, s, &inboundVisitorOptions) + } + } + + if resMsg, ok := response.(proto.Message); ok { + if visitErr := VisitPayloads(ctx, resMsg, inboundVisitorOptions); visitErr != nil { + // We are choosing visit error over RPC error in this basically-never-should-happen case + err = visitErr + } } } - + return err }, nil } @@ -123,6 +141,43 @@ func visitGrpcErrorPayload(ctx context.Context, err error, s *status.Status, inb return status.ErrorProto(p) } +func preserveExternalPayloadsVisitor(visitor PayloadsVisitor) PayloadsVisitor { + return func(vpc *VisitPayloadsContext, payloads []*common.Payload) ([]*common.Payload, error) { + // Save a copy of external payload details + externalPayloadDetails := make([][]*common.Payload_ExternalPayloadDetails, len(payloads)) + for payloadIndex, payload := range payloads { + if payload != nil { + externalPayloadDetails[payloadIndex] = make([]*common.Payload_ExternalPayloadDetails, len(payload.ExternalPayloads)) + for detailsIndex, details := range payload.ExternalPayloads { + if details != nil { + externalPayloadDetails[payloadIndex][detailsIndex] = &common.Payload_ExternalPayloadDetails{ + SizeBytes: details.SizeBytes, + } + } + } + } + } + + newPayloads, err := visitor(vpc, payloads) + if err != nil { + return newPayloads, err + } + + if len(payloads) != len(newPayloads) { + return newPayloads, fmt.Errorf("expected payload count %d but received %d", len(payloads), len(newPayloads)) + } + + // Restore external payload details + for payloadIndex, payload := range newPayloads { + if payload != nil { + payload.ExternalPayloads = externalPayloadDetails[payloadIndex] + } + } + + return newPayloads, err + } +} + // VisitFailuresContext provides Failure context for visitor functions. type VisitFailuresContext struct { context.Context diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 7004ac54..d61e34ca 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -39,11 +39,15 @@ type VisitPayloadsContext struct { SinglePayloadRequired bool } +// PayloadsVisitor accepts a context and payloads array to be able to visit and optionally +// transform the payloads as a return value. +type PayloadsVisitor func(*VisitPayloadsContext, []*common.Payload) ([]*common.Payload, error) + // VisitPayloadsOptions configure visitor behaviour. type VisitPayloadsOptions struct { // Context is the same for every call of a visit, callers should not store it. This must never // return an empty set of payloads. - Visitor func(*VisitPayloadsContext, []*common.Payload) ([]*common.Payload, error) + Visitor PayloadsVisitor // Don't visit search attribute payloads. SkipSearchAttributes bool // Will be called for each Any encountered. If not set, the default is to recurse into the Any @@ -79,25 +83,39 @@ var failureTypes = []string{"temporal.api.errordetails.v1.QueryFailedFailure", " func NewPayloadVisitorInterceptor(options PayloadVisitorInterceptorOptions) (grpc.UnaryClientInterceptor, error) { return func(ctx context.Context, method string, req, response interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if reqMsg, ok := req.(proto.Message); ok && options.Outbound != nil { - err := VisitPayloads(ctx, reqMsg, *options.Outbound) + outboundVisitorOptions := VisitPayloadsOptions{ + Visitor: preserveExternalPayloadsVisitor(options.Outbound.Visitor), + SkipSearchAttributes: options.Outbound.SkipSearchAttributes, + WellKnownAnyVisitor: options.Outbound.WellKnownAnyVisitor, + } + + err := VisitPayloads(ctx, reqMsg, outboundVisitorOptions) if err != nil { return err } } err := invoker(ctx, method, req, response, cc, opts...) - if err != nil && options.Inbound != nil { - if s, ok := status.FromError(err); ok { - // user provided payloads can sometimes end up in the status details of - // gRPC errors, make sure to visit those as well - err = visitGrpcErrorPayload(ctx, err, s, options.Inbound) + if options.Inbound != nil { + inboundVisitorOptions := VisitPayloadsOptions{ + Visitor: preserveExternalPayloadsVisitor(options.Inbound.Visitor), + SkipSearchAttributes: options.Inbound.SkipSearchAttributes, + WellKnownAnyVisitor: options.Inbound.WellKnownAnyVisitor, } - } - if resMsg, ok := response.(proto.Message); ok && options.Inbound != nil { - if visitErr := VisitPayloads(ctx, resMsg, *options.Inbound); visitErr != nil { - // We are choosing visit error over RPC error in this basically-never-should-happen case - err = visitErr + if err != nil { + if s, ok := status.FromError(err); ok { + // user provided payloads can sometimes end up in the status details of + // gRPC errors, make sure to visit those as well + err = visitGrpcErrorPayload(ctx, err, s, &inboundVisitorOptions) + } + } + + if resMsg, ok := response.(proto.Message); ok { + if visitErr := VisitPayloads(ctx, resMsg, inboundVisitorOptions); visitErr != nil { + // We are choosing visit error over RPC error in this basically-never-should-happen case + err = visitErr + } } } @@ -117,6 +135,43 @@ func visitGrpcErrorPayload(ctx context.Context, err error, s *status.Status, inb return status.ErrorProto(p) } +func preserveExternalPayloadsVisitor(visitor PayloadsVisitor) PayloadsVisitor { + return func(vpc *VisitPayloadsContext, payloads []*common.Payload) ([]*common.Payload, error) { + // Save a copy of external payload details + externalPayloadDetails := make([][]*common.Payload_ExternalPayloadDetails, len(payloads)) + for payloadIndex, payload := range payloads { + if payload != nil { + externalPayloadDetails[payloadIndex] = make([]*common.Payload_ExternalPayloadDetails, len(payload.ExternalPayloads)) + for detailsIndex, details := range payload.ExternalPayloads { + if details != nil { + externalPayloadDetails[payloadIndex][detailsIndex] = &common.Payload_ExternalPayloadDetails{ + SizeBytes: details.SizeBytes, + } + } + } + } + } + + newPayloads, err := visitor(vpc, payloads) + if err != nil { + return newPayloads, err + } + + if len(payloads) != len(newPayloads) { + return newPayloads, fmt.Errorf("expected payload count %d but received %d", len(payloads), len(newPayloads)) + } + + // Restore external payload details + for payloadIndex, payload := range newPayloads { + if payload != nil { + payload.ExternalPayloads = externalPayloadDetails[payloadIndex] + } + } + + return newPayloads, err + } +} + // VisitFailuresContext provides Failure context for visitor functions. type VisitFailuresContext struct { context.Context diff --git a/proxy/interceptor_test.go b/proxy/interceptor_test.go index f7950559..511e4e94 100644 --- a/proxy/interceptor_test.go +++ b/proxy/interceptor_test.go @@ -53,6 +53,14 @@ func inputPayload() *common.Payload { "encoding": []byte("plain/json"), }, Data: []byte("test"), + ExternalPayloads: []*common.Payload_ExternalPayloadDetails{ + { + SizeBytes: 2097152, // 2 MiB + }, + { + SizeBytes: 1024, // 1 KiB + }, + }, } } @@ -421,6 +429,159 @@ func TestClientInterceptor(t *testing.T) { require.True(proto.Equal(inputs.Payloads[0], inboundPayload)) } +func TestClientInterceptorDifferentPayloadsCount(t *testing.T) { + require := require.New(t) + + server, err := startTestGRPCServer() + require.NoError(err) + + interceptor, err := NewPayloadVisitorInterceptor( + PayloadVisitorInterceptorOptions{ + Outbound: &VisitPayloadsOptions{ + Visitor: func(vpc *VisitPayloadsContext, payloads []*common.Payload) ([]*common.Payload, error) { + // Return empty array of payloads to validate length checks (interceptor expects same payload count) + return []*common.Payload{}, nil + }, + }, + Inbound: &VisitPayloadsOptions{ + Visitor: func(vpc *VisitPayloadsContext, payloads []*common.Payload) ([]*common.Payload, error) { + // Return array of 2 payloads to validate length checks (interceptor expects same payload count) + return make([]*common.Payload, 2), nil + }, + }, + }, + ) + require.NoError(err) + + c, err := grpc.Dial( + server.addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithChainUnaryInterceptor(interceptor), + ) + require.NoError(err) + + client := workflowservice.NewWorkflowServiceClient(c) + + _, err = client.StartWorkflowExecution( + context.Background(), + &workflowservice.StartWorkflowExecutionRequest{ + Input: inputPayloads(), + }, + ) + require.ErrorContains(err, "expected payload count 1 but received 0") + + _, err = client.PollActivityTaskQueue( + context.Background(), + &workflowservice.PollActivityTaskQueueRequest{}, + ) + require.ErrorContains(err, "expected payload count 1 but received 2") +} + +func TestClientInterceptorExternalPayloadsPreserved(t *testing.T) { + require := require.New(t) + + server, err := startTestGRPCServer() + require.NoError(err) + + mutatingInterceptor, err := NewPayloadVisitorInterceptor( + PayloadVisitorInterceptorOptions{ + Outbound: &VisitPayloadsOptions{ + Visitor: func(vpc *VisitPayloadsContext, payloads []*common.Payload) ([]*common.Payload, error) { + // Mutate external payloads to attempt to cause mutations that should be corrected by the interceptor + payloads[0].ExternalPayloads[0].SizeBytes = 1337 + payloads[0].ExternalPayloads = nil + return payloads, nil + }, + }, + Inbound: &VisitPayloadsOptions{ + Visitor: func(vpc *VisitPayloadsContext, payloads []*common.Payload) ([]*common.Payload, error) { + // Mutate external payloads to attempt to cause mutations that should be corrected by the interceptor + payloads[0].ExternalPayloads[0].SizeBytes = 7331 + payloads[0].ExternalPayloads = nil + return payloads, nil + }, + }, + }, + ) + require.NoError(err) + + var externalPayload1SizeBytesExpected int64 + var externalPayload2SizeBytesExpected int64 + + validatingInterceptor, err := newChainedPayloadVisitorInterceptor( + mutatingInterceptor, + PayloadVisitorInterceptorOptions{ + Outbound: &VisitPayloadsOptions{ + Visitor: func(vpc *VisitPayloadsContext, payloads []*common.Payload) ([]*common.Payload, error) { + require.Equal(1, len(payloads)) + require.Equal(2, len(payloads[0].ExternalPayloads)) + externalPayload1SizeBytesExpected = payloads[0].ExternalPayloads[0].SizeBytes + externalPayload2SizeBytesExpected = payloads[0].ExternalPayloads[1].SizeBytes + return payloads, nil + }, + }, + Inbound: &VisitPayloadsOptions{ + Visitor: func(vpc *VisitPayloadsContext, payloads []*common.Payload) ([]*common.Payload, error) { + require.Equal(1, len(payloads)) + require.Equal(2, len(payloads[0].ExternalPayloads)) + require.Equal(externalPayload1SizeBytesExpected, payloads[0].ExternalPayloads[0].SizeBytes) + require.Equal(externalPayload2SizeBytesExpected, payloads[0].ExternalPayloads[1].SizeBytes) + return payloads, nil + }, + }, + }, + ) + require.NoError(err) + + c, err := grpc.Dial( + server.addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithChainUnaryInterceptor(validatingInterceptor), + ) + require.NoError(err) + + client := workflowservice.NewWorkflowServiceClient(c) + + _, err = client.StartWorkflowExecution( + context.Background(), + &workflowservice.StartWorkflowExecutionRequest{ + Input: inputPayloads(), + }, + ) + require.NoError(err) + + _, err = client.PollActivityTaskQueue( + context.Background(), + &workflowservice.PollActivityTaskQueueRequest{}, + ) + require.NoError(err) +} + +func newChainedPayloadVisitorInterceptor(interceptor grpc.UnaryClientInterceptor, options PayloadVisitorInterceptorOptions) (grpc.UnaryClientInterceptor, error) { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if reqMsg, ok := req.(proto.Message); ok && options.Outbound != nil { + err := VisitPayloads(ctx, reqMsg, *options.Outbound) + if err != nil { + return err + } + } + + err := interceptor(ctx, method, req, reply, cc, invoker, opts...) + if err != nil { + return err + } + + if resMsg, ok := reply.(proto.Message); ok && options.Inbound != nil { + err := VisitPayloads(ctx, resMsg, *options.Inbound) + if err != nil { + return err + } + } + + return nil + }, nil +} + func TestClientInterceptorGrpcFailures(t *testing.T) { require := require.New(t) @@ -485,7 +646,18 @@ func TestClientInterceptorGrpcFailures(t *testing.T) { err = multiOpFailure.Statuses[0].Details[0].UnmarshalTo(payloads) require.NoError(err) - newPayload := &common.Payload{Data: []byte("new-val")} + newPayload := &common.Payload{ + Data: []byte("new-val"), + ExternalPayloads: []*common.Payload_ExternalPayloadDetails{ + { + SizeBytes: 2097152, // 2 MiB + }, + { + SizeBytes: 1024, // 1 KiB + }, + }, + } + require.True(proto.Equal(payloads.Payloads[0], newPayload)) }