66 "JwtSuperuserConnection" ,
77]
88
9+ import logging
910import sys
1011import time
1112from abc import abstractmethod
12- from typing import Any , Callable , Optional , Sequence , Union
13+ from typing import Any , Callable , Optional , Sequence , Set , Tuple , Union
1314
1415import jwt
15- from requests import Session
16+ from requests import ConnectionError , Session
1617from requests_toolbelt import MultipartEncoder
1718
1819from arango .exceptions import JWTAuthError , ServerConnectionError
@@ -110,6 +111,48 @@ def prep_response(self, resp: Response, deserialize: bool = True) -> Response:
110111 resp .is_success = http_ok and resp .error_code is None
111112 return resp
112113
114+ def process_request (
115+ self , host_index : int , request : Request , auth : Optional [Tuple [str , str ]] = None
116+ ) -> Response :
117+ """Execute a request until a valid response has been returned.
118+
119+ :param host_index: The index of the first host to try
120+ :type host_index: int
121+ :param request: HTTP request.
122+ :type request: arango.request.Request
123+ :return: HTTP response.
124+ :rtype: arango.response.Response
125+ """
126+ tries = 0
127+ indexes_to_filter : Set [int ] = set ()
128+ while tries < self ._host_resolver .max_tries :
129+ try :
130+ resp = self ._http .send_request (
131+ session = self ._sessions [host_index ],
132+ method = request .method ,
133+ url = self ._url_prefixes [host_index ] + request .endpoint ,
134+ params = request .params ,
135+ data = self .normalize_data (request .data ),
136+ headers = request .headers ,
137+ auth = auth ,
138+ )
139+
140+ return self .prep_response (resp , request .deserialize )
141+ except ConnectionError :
142+ url = self ._url_prefixes [host_index ] + request .endpoint
143+ logging .debug (f"ConnectionError: { url } " )
144+
145+ if len (indexes_to_filter ) == self ._host_resolver .host_count - 1 :
146+ indexes_to_filter .clear ()
147+ indexes_to_filter .add (host_index )
148+
149+ host_index = self ._host_resolver .get_host_index (indexes_to_filter )
150+ tries += 1
151+
152+ raise ConnectionAbortedError (
153+ f"Can't connect to host(s) within limit ({ self ._host_resolver .max_tries } )"
154+ )
155+
113156 def prep_bulk_err_response (self , parent_response : Response , body : Json ) -> Response :
114157 """Build and return a bulk error response.
115158
@@ -227,16 +270,7 @@ def send_request(self, request: Request) -> Response:
227270 :rtype: arango.response.Response
228271 """
229272 host_index = self ._host_resolver .get_host_index ()
230- resp = self ._http .send_request (
231- session = self ._sessions [host_index ],
232- method = request .method ,
233- url = self ._url_prefixes [host_index ] + request .endpoint ,
234- params = request .params ,
235- data = self .normalize_data (request .data ),
236- headers = request .headers ,
237- auth = self ._auth ,
238- )
239- return self .prep_response (resp , request .deserialize )
273+ return self .process_request (host_index , request , auth = self ._auth )
240274
241275
242276class JwtConnection (BaseConnection ):
@@ -302,15 +336,7 @@ def send_request(self, request: Request) -> Response:
302336 if self ._auth_header is not None :
303337 request .headers ["Authorization" ] = self ._auth_header
304338
305- resp = self ._http .send_request (
306- session = self ._sessions [host_index ],
307- method = request .method ,
308- url = self ._url_prefixes [host_index ] + request .endpoint ,
309- params = request .params ,
310- data = self .normalize_data (request .data ),
311- headers = request .headers ,
312- )
313- resp = self .prep_response (resp , request .deserialize )
339+ resp = self .process_request (host_index , request )
314340
315341 # Refresh the token and retry on HTTP 401 and error code 11.
316342 if resp .error_code != 11 or resp .status_code != 401 :
@@ -325,15 +351,7 @@ def send_request(self, request: Request) -> Response:
325351 if self ._auth_header is not None :
326352 request .headers ["Authorization" ] = self ._auth_header
327353
328- resp = self ._http .send_request (
329- session = self ._sessions [host_index ],
330- method = request .method ,
331- url = self ._url_prefixes [host_index ] + request .endpoint ,
332- params = request .params ,
333- data = self .normalize_data (request .data ),
334- headers = request .headers ,
335- )
336- return self .prep_response (resp , request .deserialize )
354+ return self .process_request (host_index , request )
337355
338356 def refresh_token (self ) -> None :
339357 """Get a new JWT token for the current user (cannot be a superuser).
@@ -349,13 +367,7 @@ def refresh_token(self) -> None:
349367
350368 host_index = self ._host_resolver .get_host_index ()
351369
352- resp = self ._http .send_request (
353- session = self ._sessions [host_index ],
354- method = request .method ,
355- url = self ._url_prefixes [host_index ] + request .endpoint ,
356- data = self .normalize_data (request .data ),
357- )
358- resp = self .prep_response (resp )
370+ resp = self .process_request (host_index , request )
359371
360372 if not resp .is_success :
361373 raise JWTAuthError (resp , request )
@@ -429,12 +441,4 @@ def send_request(self, request: Request) -> Response:
429441 host_index = self ._host_resolver .get_host_index ()
430442 request .headers ["Authorization" ] = self ._auth_header
431443
432- resp = self ._http .send_request (
433- session = self ._sessions [host_index ],
434- method = request .method ,
435- url = self ._url_prefixes [host_index ] + request .endpoint ,
436- params = request .params ,
437- data = self .normalize_data (request .data ),
438- headers = request .headers ,
439- )
440- return self .prep_response (resp , request .deserialize )
444+ return self .process_request (host_index , request )
0 commit comments