diff --git a/google/ads/googleads/interceptors/exception_interceptor.py b/google/ads/googleads/interceptors/exception_interceptor.py index 5fd3c21cd..3d0fc6d26 100644 --- a/google/ads/googleads/interceptors/exception_interceptor.py +++ b/google/ads/googleads/interceptors/exception_interceptor.py @@ -159,8 +159,8 @@ def __await__(self): return response else: return util.convert_proto_plus_to_protobuf(response) - except grpc.RpcError: - yield from self._interceptor._handle_grpc_failure_async(self._call).__await__() + except grpc.RpcError as exception: + yield from self._interceptor._handle_grpc_failure_async(self._call, exception).__await__() raise def cancel(self): @@ -209,8 +209,8 @@ async def _wrapped_aiter(): yield response else: yield util.convert_proto_plus_to_protobuf(response) - except grpc.RpcError: - await self._interceptor._handle_grpc_failure_async(self._call) + except grpc.RpcError as exception: + await self._interceptor._handle_grpc_failure_async(self._call, exception) raise return _wrapped_aiter() @@ -252,19 +252,30 @@ async def read(self): return response else: return util.convert_proto_plus_to_protobuf(response) - except grpc.RpcError: - await self._interceptor._handle_grpc_failure_async(self._call) + except grpc.RpcError as exception: + await self._interceptor._handle_grpc_failure_async(self._call, exception) raise -class _AsyncExceptionInterceptor( - ExceptionInterceptor, -): +class _AsyncExceptionInterceptor(Interceptor, grpc.aio.UnaryUnaryClientInterceptor, grpc.aio.UnaryStreamClientInterceptor,): """An interceptor that wraps rpc exceptions.""" - async def _handle_grpc_failure_async(self, response: grpc.aio.Call): + def __init__(self, api_version: str, use_proto_plus: bool = False): + """Initializes the ExceptionInterceptor. + + Args: + api_version: a str of the API version of the request. + use_proto_plus: a boolean of whether returned messages should be + proto_plus or protobuf. + """ + super().__init__(api_version) + self._api_version = api_version + self._use_proto_plus = use_proto_plus + + async def _handle_grpc_failure_async(self, response: grpc.aio.Call, exception: grpc.RpcError): """Async version of _handle_grpc_failure.""" - status_code = response.code() - response_exception = response.exception() + status_code = await response.code() + + response_exception = exception # We need to access _RETRY_STATUS_CODES from interceptor module? # It's imported in interceptor.py but not exposed in ExceptionInterceptor? @@ -300,7 +311,7 @@ async def _handle_grpc_failure_async(self, response: grpc.aio.Call): raise response_exception # If we got here, maybe no exception? But we only call this on error. - raise response.exception() + raise response_exception async def intercept_unary_unary( self, diff --git a/google/ads/googleads/interceptors/logging_interceptor.py b/google/ads/googleads/interceptors/logging_interceptor.py index c57108b48..6e5c47783 100644 --- a/google/ads/googleads/interceptors/logging_interceptor.py +++ b/google/ads/googleads/interceptors/logging_interceptor.py @@ -468,10 +468,13 @@ async def _log_request_async( # Since this is called in on_done, it is done. try: - # This might raise if cancelled? - exception = response.exception() - except Exception: - exception = None + if hasattr(response, "exception"): + exception = response.code() + else: + await response.code() + exception = None + except Exception as ex: + exception = ex if exception: # We need to adapt exception logging for async exception? diff --git a/tests/interceptors/exception_interceptor_test.py b/tests/interceptors/exception_interceptor_test.py index 7e3d7c878..2def1807e 100644 --- a/tests/interceptors/exception_interceptor_test.py +++ b/tests/interceptors/exception_interceptor_test.py @@ -361,7 +361,7 @@ async def test_handle_grpc_failure(self): mock_error_message = _MOCK_FAILURE_VALUE class MockRpcErrorResponse(grpc.RpcError): - def code(self): + async def code(self): return grpc.StatusCode.INVALID_ARGUMENT async def trailing_metadata(self): @@ -373,13 +373,16 @@ def exception(self): interceptor = self._create_test_interceptor() with self.assertRaises(GoogleAdsException): - await interceptor._handle_grpc_failure_async(MockRpcErrorResponse()) + error_response = MockRpcErrorResponse() + await interceptor._handle_grpc_failure_async( + error_response, error_response + ) async def test_handle_grpc_failure_retryable(self): """Raises retryable exceptions as-is.""" class MockRpcErrorResponse(grpc.RpcError): - def code(self): + async def code(self): return grpc.StatusCode.INTERNAL def exception(self): @@ -388,13 +391,16 @@ def exception(self): interceptor = self._create_test_interceptor() with self.assertRaises(MockRpcErrorResponse): - await interceptor._handle_grpc_failure_async(MockRpcErrorResponse()) + error_response = MockRpcErrorResponse() + await interceptor._handle_grpc_failure_async( + error_response, error_response + ) async def test_handle_grpc_failure_not_google_ads_failure(self): """Raises as-is non-retryable non-GoogleAdsFailure exceptions.""" class MockRpcErrorResponse(grpc.RpcError): - def code(self): + async def code(self): return grpc.StatusCode.INVALID_ARGUMENT async def trailing_metadata(self): @@ -406,7 +412,10 @@ def exception(self): interceptor = self._create_test_interceptor() with self.assertRaises(MockRpcErrorResponse): - await interceptor._handle_grpc_failure_async(MockRpcErrorResponse()) + error_response = MockRpcErrorResponse() + await interceptor._handle_grpc_failure_async( + error_response, error_response + ) async def test_intercept_unary_unary_response_is_exception(self): """If response.exception() is not None exception is handled.""" @@ -439,7 +448,7 @@ async def mock_continuation(client_call_details, request): except grpc.RpcError: pass - mock_handle.assert_called_once_with(mock_call) + mock_handle.assert_called_once_with(mock_call, mock_exception) async def test_intercept_unary_stream_response_is_exception(self): """Ensure errors raised from response iteration are handled/wrapped.""" @@ -470,21 +479,15 @@ async def mock_continuation(client_call_details, request): mock_continuation, mock_client_call_details, mock_request ) - # Ensure the returned value is a wrapped response object. self.assertIsInstance(response, _AsyncUnaryStreamCallWrapper) - # Initiate an iteration of the wrapped response object try: async for _ in response: - # This loop body should not be entered because the exception - # is raised on the first attempt to get an item. pass except grpc.RpcError: pass - # Check that the error handler method on the interceptor instance - # was called as a result of the iteration. - mock_handle.assert_called_once_with(mock_call) + mock_handle.assert_called_once_with(mock_call, mock_exception) async def test_intercept_unary_unary_response_is_successful(self): """If response.exception() is None response is returned.""" @@ -625,7 +628,9 @@ async def mock_continuation(client_call_details, request): found = True break # We only need the first item - self.assertTrue(found, "Iterator should have yielded at least one message") + self.assertTrue( + found, "Iterator should have yielded at least one message" + ) self.assertIsInstance(message, proto.Message) async def test_intercept_unary_stream_protobuf_proto(self): @@ -661,5 +666,7 @@ async def mock_continuation(client_call_details, request): found = True break # We only need the first item - self.assertTrue(found, "Iterator should have yielded at least one message") + self.assertTrue( + found, "Iterator should have yielded at least one message" + ) self.assertIsInstance(message, ProtobufMessageType) diff --git a/tests/interceptors/logging_interceptor_test.py b/tests/interceptors/logging_interceptor_test.py index 3ee0415b3..5fff77772 100644 --- a/tests/interceptors/logging_interceptor_test.py +++ b/tests/interceptors/logging_interceptor_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for the Logging gRPC Interceptor.""" - from importlib import import_module import json import logging @@ -32,7 +31,6 @@ import google.ads.googleads.interceptors.logging_interceptor as interceptor_module from google.ads.googleads import util - default_version = Client._DEFAULT_VERSION module_prefix = f"google.ads.googleads.{default_version}" @@ -60,7 +58,6 @@ ) - class AwaitableMagicMock(mock.MagicMock): def __await__(self): return self._await_impl().__await__() @@ -937,7 +934,9 @@ def _create_test_interceptor( if not version: version = default_version - return interceptor_module._AsyncLoggingInterceptor(logger, version, endpoint) + return interceptor_module._AsyncLoggingInterceptor( + logger, version, endpoint + ) def _get_mock_client_call_details(self): mock_client_call_details = mock.Mock() @@ -970,19 +969,26 @@ def _get_mock_exception(self): def _get_mock_response(self, failed=False, streaming=False): mock_response = AwaitableMagicMock() - # Async trailing_metadata async def mock_trailing_metadata(): return self._MOCK_TRAILING_METADATA + mock_response.trailing_metadata = mock_trailing_metadata - # Sync exception - def mock_exception_fn(): - if failed: + if failed: + + def mock_exception_fn(): return self._get_mock_exception() - return None - mock_response.exception = mock_exception_fn - # Async await for UnaryUnary + mock_response.exception = mock_exception_fn + mock_response.code = lambda: self._get_mock_exception() + else: + del mock_response.exception + + async def mock_code(): + return 0 + + mock_response.code = mock_code + async def get_result(): if streaming: return self._MOCK_STREAM @@ -990,7 +996,6 @@ async def get_result(): mock_response._await_impl = get_result - # For streaming, we might need 'read' attribute to distinguish if streaming: mock_response.read = mock.AsyncMock(return_value=self._MOCK_STREAM) else: @@ -998,11 +1003,7 @@ async def get_result(): del mock_response.result - # Sync add_done_callback def mock_add_done_callback(fn): - # In async interceptor, this is called to register the logging task. - # We want to execute it immediately or schedule it. - # Since fn expects a future (the call), and mock_response is the call. fn(mock_response) mock_response.add_done_callback = mock_add_done_callback @@ -1013,6 +1014,7 @@ def _get_mock_continuation_fn(self, fail=False): async def mock_continuation_fn(*args): mock_response = self._get_mock_response(fail) return mock_response + return mock_continuation_fn async def test_intercept_unary_unary_unconfigured(self): @@ -1058,7 +1060,9 @@ async def test_intercept_unary_unary_successful_request(self): mock_request = self._get_mock_request() # We need to get the response to assert against it - mock_response = await mock_continuation_fn(mock_client_call_details, mock_request) + mock_response = await mock_continuation_fn( + mock_client_call_details, mock_request + ) mock_trailing_metadata = await mock_response.trailing_metadata() with ( @@ -1156,7 +1160,7 @@ async def mock_continuation_fn(*args): initial_metadata, mock_request, trailing_metadata, - None, # Result is None for stream in async interceptor + None, # Result is None for stream in async interceptor ) )