Skip to content

Commit f3d358a

Browse files
subscription: check all fields until find method
1 parent 4c99d57 commit f3d358a

File tree

1 file changed

+83
-64
lines changed

1 file changed

+83
-64
lines changed

internal/exec/subscribe.go

Lines changed: 83 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/golangid/graphql-go/internal/exec/resolvable"
1414
"github.com/golangid/graphql-go/internal/exec/selected"
1515
"github.com/golangid/graphql-go/internal/query"
16+
"github.com/golangid/graphql-go/internal/schema"
1617
)
1718

1819
type Response struct {
@@ -22,75 +23,28 @@ type Response struct {
2223

2324
func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query.Operation) <-chan *Response {
2425
var result reflect.Value
25-
var f *fieldToExec
2626
var err *errors.QueryError
27-
func() {
28-
defer r.handlePanic(ctx)
27+
sels := selected.ApplyOperation(&r.Request, s, op)
2928

30-
sels := selected.ApplyOperation(&r.Request, s, op)
31-
var fields []*fieldToExec
32-
collectFieldsToResolve(sels, s, s.ResolverSubscription, &fields, make(map[string]*fieldToExec))
33-
34-
// TODO: move this check into validation.Validate
35-
if len(fields) != 1 {
36-
err = errors.Errorf("%s", "can subscribe to at most one subscription at a time")
37-
return
38-
}
39-
f = fields[0]
40-
41-
// TODO: add check all childs
42-
func() {
43-
tmpF := *f
44-
defer func() {
45-
if r := recover(); r != nil {
46-
f = &tmpF
47-
}
48-
}()
49-
50-
if f.resolver.Kind() == reflect.Ptr {
51-
f.resolver = f.resolver.Elem()
52-
}
53-
f.resolver = f.resolver.FieldByIndex(f.field.FieldIndex)
54-
55-
var fieldsDeep []*fieldToExec
56-
collectFieldsToResolve(f.sels, s, f.resolver, &fieldsDeep, make(map[string]*fieldToExec))
57-
if len(fieldsDeep) == 1 {
58-
f = fieldsDeep[0]
59-
} else {
60-
f = &tmpF
61-
}
62-
}()
63-
64-
var in []reflect.Value
65-
if f.field.HasContext {
66-
in = append(in, reflect.ValueOf(ctx))
67-
}
68-
if f.field.ArgsPacker != nil {
69-
in = append(in, f.field.PackedArgs)
70-
}
71-
callOut := f.resolver.Method(f.field.MethodIndex).Call(in)
72-
result = callOut[0]
73-
74-
if f.field.HasError && !callOut[1].IsNil() {
75-
switch resolverErr := callOut[1].Interface().(type) {
76-
case *errors.QueryError:
77-
err = resolverErr
78-
case error:
79-
err = errors.Errorf("%s", resolverErr)
80-
err.ResolverError = resolverErr
81-
default:
82-
panic(fmt.Errorf("can only deal with *QueryError and error types, got %T", resolverErr))
83-
}
84-
}
85-
}()
29+
f := r.subscriptionSearchFieldMethod(ctx, sels, nil, s, s.ResolverSubscription)
30+
if f == nil {
31+
return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}})
32+
}
8633

87-
// Handles the case where the locally executed func above panicked
88-
if len(r.Request.Errs) > 0 {
89-
return sendAndReturnClosed(&Response{Errors: r.Request.Errs})
34+
var in []reflect.Value
35+
if f.field.HasContext {
36+
in = append(in, reflect.ValueOf(ctx))
37+
}
38+
if f.field.ArgsPacker != nil {
39+
in = append(in, f.field.PackedArgs)
9040
}
41+
callOut := f.resolver.Method(f.field.MethodIndex).Call(in)
42+
result = callOut[0]
9143

92-
if f == nil {
93-
return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}})
44+
if f.field.HasError && !callOut[1].IsNil() {
45+
resolverErr := callOut[1].Interface().(error)
46+
err = errors.Errorf("%s", resolverErr)
47+
err.ResolverError = resolverErr
9448
}
9549

9650
if err != nil {
@@ -201,3 +155,68 @@ func sendAndReturnClosed(resp *Response) chan *Response {
201155
close(c)
202156
return c
203157
}
158+
159+
func (r *Request) subscriptionSearchFieldMethod(ctx context.Context, sels []selected.Selection, path *pathSegment, s *resolvable.Schema, resolver reflect.Value) (foundField *fieldToExec) {
160+
161+
var collectFields func(sels []selected.Selection, path *pathSegment, s *resolvable.Schema, resolver reflect.Value)
162+
var execField func(s *resolvable.Schema, f *fieldToExec, path *pathSegment)
163+
var execSelectionSet func(sels []selected.Selection, typ common.Type, path *pathSegment, s *resolvable.Schema, resolver reflect.Value)
164+
165+
collectFields = func(sels []selected.Selection, path *pathSegment, s *resolvable.Schema, resolver reflect.Value) {
166+
var fields []*fieldToExec
167+
collectFieldsToResolve(sels, s, resolver, &fields, make(map[string]*fieldToExec))
168+
169+
for _, f := range fields {
170+
execField(s, f, &pathSegment{path, f.field.Alias})
171+
if f.field.UseMethodResolver() && !f.field.FixedResult.IsValid() {
172+
foundField = f
173+
return
174+
}
175+
}
176+
}
177+
178+
execField = func(s *resolvable.Schema, f *fieldToExec, path *pathSegment) {
179+
var result reflect.Value
180+
181+
if f.field.FixedResult.IsValid() {
182+
result = f.field.FixedResult
183+
return
184+
}
185+
186+
res := f.resolver
187+
if !f.field.UseMethodResolver() {
188+
res = unwrapPtr(res)
189+
result = res.FieldByIndex(f.field.FieldIndex)
190+
}
191+
192+
execSelectionSet(f.sels, f.field.Type, path, s, result)
193+
}
194+
195+
execSelectionSet = func(sels []selected.Selection, typ common.Type, path *pathSegment, s *resolvable.Schema, resolver reflect.Value) {
196+
t, nonNull := unwrapNonNull(typ)
197+
switch t := t.(type) {
198+
case *schema.Object, *schema.Interface, *schema.Union:
199+
if resolver.Kind() == reflect.Invalid || ((resolver.Kind() == reflect.Ptr || resolver.Kind() == reflect.Interface) && resolver.IsNil()) {
200+
if nonNull {
201+
err := errors.Errorf("graphql: got nil for non-null %q", t)
202+
err.Path = path.toSlice()
203+
r.AddError(err)
204+
}
205+
return
206+
}
207+
208+
collectFields(sels, path, s, resolver)
209+
return
210+
}
211+
}
212+
213+
collectFields(sels, path, s, resolver)
214+
return
215+
}
216+
217+
func unwrapPtr(v reflect.Value) reflect.Value {
218+
if v.Kind() == reflect.Ptr {
219+
return v.Elem()
220+
}
221+
return v
222+
}

0 commit comments

Comments
 (0)