22import functools
33import logging
44import sys
5+ from rx import Observable
56
67from six import string_types
78from promise import Promise , promise_for_dict , is_thenable
1516 GraphQLSchema , GraphQLUnionType )
1617from .base import (ExecutionContext , ExecutionResult , ResolveInfo ,
1718 collect_fields , default_resolve_fn , get_field_def ,
18- get_operation_root_type )
19+ get_operation_root_type , SubscriberExecutionContext )
1920from .executors .sync import SyncExecutor
2021from .middleware import MiddlewareManager
2122
2223logger = logging .getLogger (__name__ )
2324
2425
26+ def subscribe (* args , ** kwargs ):
27+ allow_subscriptions = kwargs .pop ('allow_subscriptions' , True )
28+ return execute (* args , allow_subscriptions = allow_subscriptions , ** kwargs )
29+
30+
2531def execute (schema , document_ast , root_value = None , context_value = None ,
2632 variable_values = None , operation_name = None , executor = None ,
27- return_promise = False , middleware = None ):
33+ return_promise = False , middleware = None , allow_subscriptions = False ):
2834 assert schema , 'Must provide schema'
2935 assert isinstance (schema , GraphQLSchema ), (
3036 'Schema must be an instance of GraphQLSchema. Also ensure that there are ' +
@@ -50,7 +56,8 @@ def execute(schema, document_ast, root_value=None, context_value=None,
5056 variable_values ,
5157 operation_name ,
5258 executor ,
53- middleware
59+ middleware ,
60+ allow_subscriptions
5461 )
5562
5663 def executor (v ):
@@ -61,6 +68,9 @@ def on_rejected(error):
6168 return None
6269
6370 def on_resolve (data ):
71+ if isinstance (data , Observable ):
72+ return data
73+
6474 if not context .errors :
6575 return ExecutionResult (data = data )
6676 return ExecutionResult (data = data , errors = context .errors )
@@ -88,6 +98,15 @@ def execute_operation(exe_context, operation, root_value):
8898 if operation .operation == 'mutation' :
8999 return execute_fields_serially (exe_context , type , root_value , fields )
90100
101+ if operation .operation == 'subscription' :
102+ if not exe_context .allow_subscriptions :
103+ raise Exception (
104+ "Subscriptions are not allowed. "
105+ "You will need to either use the subscribe function "
106+ "or pass allow_subscriptions=True"
107+ )
108+ return subscribe_fields (exe_context , type , root_value , fields )
109+
91110 return execute_fields (exe_context , type , root_value , fields )
92111
93112
@@ -140,6 +159,44 @@ def execute_fields(exe_context, parent_type, source_value, fields):
140159 return promise_for_dict (final_results )
141160
142161
162+ def subscribe_fields (exe_context , parent_type , source_value , fields ):
163+ exe_context = SubscriberExecutionContext (exe_context )
164+
165+ def on_error (error ):
166+ exe_context .report_error (error )
167+
168+ def map_result (data ):
169+ if exe_context .errors :
170+ result = ExecutionResult (data = data , errors = exe_context .errors )
171+ else :
172+ result = ExecutionResult (data = data )
173+ exe_context .reset ()
174+ return result
175+
176+ observables = []
177+
178+ # assert len(fields) == 1, "Can only subscribe one element at a time."
179+
180+ for response_name , field_asts in fields .items ():
181+
182+ result = subscribe_field (exe_context , parent_type ,
183+ source_value , field_asts )
184+ if result is Undefined :
185+ continue
186+
187+ def catch_error (error ):
188+ exe_context .errors .append (error )
189+ return Observable .just (None )
190+
191+ # Map observable results
192+ observable = result .catch_exception (catch_error ).map (
193+ lambda data : map_result ({response_name : data }))
194+ return observable
195+ observables .append (observable )
196+
197+ return Observable .merge (observables )
198+
199+
143200def resolve_field (exe_context , parent_type , source , field_asts ):
144201 field_ast = field_asts [0 ]
145202 field_name = field_ast .name .value
@@ -191,6 +248,64 @@ def resolve_field(exe_context, parent_type, source, field_asts):
191248 )
192249
193250
251+ def subscribe_field (exe_context , parent_type , source , field_asts ):
252+ field_ast = field_asts [0 ]
253+ field_name = field_ast .name .value
254+
255+ field_def = get_field_def (exe_context .schema , parent_type , field_name )
256+ if not field_def :
257+ return Undefined
258+
259+ return_type = field_def .type
260+ resolve_fn = field_def .resolver or default_resolve_fn
261+
262+ # We wrap the resolve_fn from the middleware
263+ resolve_fn_middleware = exe_context .get_field_resolver (resolve_fn )
264+
265+ # Build a dict of arguments from the field.arguments AST, using the variables scope to
266+ # fulfill any variable references.
267+ args = exe_context .get_argument_values (field_def , field_ast )
268+
269+ # The resolve function's optional third argument is a context value that
270+ # is provided to every resolve function within an execution. It is commonly
271+ # used to represent an authenticated user, or request-specific caches.
272+ context = exe_context .context_value
273+
274+ # The resolve function's optional third argument is a collection of
275+ # information about the current execution state.
276+ info = ResolveInfo (
277+ field_name ,
278+ field_asts ,
279+ return_type ,
280+ parent_type ,
281+ schema = exe_context .schema ,
282+ fragments = exe_context .fragments ,
283+ root_value = exe_context .root_value ,
284+ operation = exe_context .operation ,
285+ variable_values = exe_context .variable_values ,
286+ context = context
287+ )
288+
289+ executor = exe_context .executor
290+ result = resolve_or_error (resolve_fn_middleware ,
291+ source , info , args , executor )
292+
293+ if isinstance (result , Exception ):
294+ raise result
295+
296+ if not isinstance (result , Observable ):
297+ raise GraphQLError (
298+ 'Subscription must return Async Iterable or Observable. Received: {}' .format (repr (result )))
299+
300+ return result .map (functools .partial (
301+ complete_value_catching_error ,
302+ exe_context ,
303+ return_type ,
304+ field_asts ,
305+ info ,
306+ ))
307+
308+
194309def resolve_or_error (resolve_fn , source , info , args , executor ):
195310 try :
196311 return executor .execute (resolve_fn , source , info , ** args )
0 commit comments