Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions google/ads/googleads/interceptors/exception_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def __await__(self):
return response
else:
return util.convert_proto_plus_to_protobuf(response)
except grpc.RpcError:
yield from self._interceptor._handle_grpc_failure_async(self._call).__await__()
except grpc.RpcError as exception:
yield from self._interceptor._handle_grpc_failure_async(self._call, exception).__await__()
raise

def cancel(self):
Expand Down Expand Up @@ -209,8 +209,8 @@ async def _wrapped_aiter():
yield response
else:
yield util.convert_proto_plus_to_protobuf(response)
except grpc.RpcError:
await self._interceptor._handle_grpc_failure_async(self._call)
except grpc.RpcError as exception:
await self._interceptor._handle_grpc_failure_async(self._call, exception)
raise
return _wrapped_aiter()

Expand Down Expand Up @@ -252,19 +252,30 @@ async def read(self):
return response
else:
return util.convert_proto_plus_to_protobuf(response)
except grpc.RpcError:
await self._interceptor._handle_grpc_failure_async(self._call)
except grpc.RpcError as exception:
await self._interceptor._handle_grpc_failure_async(self._call, exception)
raise

class _AsyncExceptionInterceptor(
ExceptionInterceptor,
):
class _AsyncExceptionInterceptor(Interceptor, grpc.aio.UnaryUnaryClientInterceptor, grpc.aio.UnaryStreamClientInterceptor,):
"""An interceptor that wraps rpc exceptions."""

async def _handle_grpc_failure_async(self, response: grpc.aio.Call):
def __init__(self, api_version: str, use_proto_plus: bool = False):
"""Initializes the ExceptionInterceptor.

Args:
api_version: a str of the API version of the request.
use_proto_plus: a boolean of whether returned messages should be
proto_plus or protobuf.
"""
super().__init__(api_version)
self._api_version = api_version
self._use_proto_plus = use_proto_plus

async def _handle_grpc_failure_async(self, response: grpc.aio.Call, exception: grpc.RpcError):
"""Async version of _handle_grpc_failure."""
status_code = response.code()
response_exception = response.exception()
status_code = await response.code()

response_exception = exception

# We need to access _RETRY_STATUS_CODES from interceptor module?
# It's imported in interceptor.py but not exposed in ExceptionInterceptor?
Expand Down Expand Up @@ -300,7 +311,7 @@ async def _handle_grpc_failure_async(self, response: grpc.aio.Call):
raise response_exception

# If we got here, maybe no exception? But we only call this on error.
raise response.exception()
raise response_exception

async def intercept_unary_unary(
self,
Expand Down
11 changes: 7 additions & 4 deletions google/ads/googleads/interceptors/logging_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,13 @@ async def _log_request_async(
# Since this is called in on_done, it is done.

try:
# This might raise if cancelled?
exception = response.exception()
except Exception:
exception = None
if hasattr(response, "exception"):
exception = response.code()
else:
await response.code()
exception = None
except Exception as ex:
exception = ex

if exception:
# We need to adapt exception logging for async exception?
Expand Down
39 changes: 23 additions & 16 deletions tests/interceptors/exception_interceptor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ async def test_handle_grpc_failure(self):
mock_error_message = _MOCK_FAILURE_VALUE

class MockRpcErrorResponse(grpc.RpcError):
def code(self):
async def code(self):
return grpc.StatusCode.INVALID_ARGUMENT

async def trailing_metadata(self):
Expand All @@ -373,13 +373,16 @@ def exception(self):
interceptor = self._create_test_interceptor()

with self.assertRaises(GoogleAdsException):
await interceptor._handle_grpc_failure_async(MockRpcErrorResponse())
error_response = MockRpcErrorResponse()
await interceptor._handle_grpc_failure_async(
error_response, error_response
)

async def test_handle_grpc_failure_retryable(self):
"""Raises retryable exceptions as-is."""

class MockRpcErrorResponse(grpc.RpcError):
def code(self):
async def code(self):
return grpc.StatusCode.INTERNAL

def exception(self):
Expand All @@ -388,13 +391,16 @@ def exception(self):
interceptor = self._create_test_interceptor()

with self.assertRaises(MockRpcErrorResponse):
await interceptor._handle_grpc_failure_async(MockRpcErrorResponse())
error_response = MockRpcErrorResponse()
await interceptor._handle_grpc_failure_async(
error_response, error_response
)

async def test_handle_grpc_failure_not_google_ads_failure(self):
"""Raises as-is non-retryable non-GoogleAdsFailure exceptions."""

class MockRpcErrorResponse(grpc.RpcError):
def code(self):
async def code(self):
return grpc.StatusCode.INVALID_ARGUMENT

async def trailing_metadata(self):
Expand All @@ -406,7 +412,10 @@ def exception(self):
interceptor = self._create_test_interceptor()

with self.assertRaises(MockRpcErrorResponse):
await interceptor._handle_grpc_failure_async(MockRpcErrorResponse())
error_response = MockRpcErrorResponse()
await interceptor._handle_grpc_failure_async(
error_response, error_response
)

async def test_intercept_unary_unary_response_is_exception(self):
"""If response.exception() is not None exception is handled."""
Expand Down Expand Up @@ -439,7 +448,7 @@ async def mock_continuation(client_call_details, request):
except grpc.RpcError:
pass

mock_handle.assert_called_once_with(mock_call)
mock_handle.assert_called_once_with(mock_call, mock_exception)

async def test_intercept_unary_stream_response_is_exception(self):
"""Ensure errors raised from response iteration are handled/wrapped."""
Expand Down Expand Up @@ -470,21 +479,15 @@ async def mock_continuation(client_call_details, request):
mock_continuation, mock_client_call_details, mock_request
)

# Ensure the returned value is a wrapped response object.
self.assertIsInstance(response, _AsyncUnaryStreamCallWrapper)

# Initiate an iteration of the wrapped response object
try:
async for _ in response:
# This loop body should not be entered because the exception
# is raised on the first attempt to get an item.
pass
except grpc.RpcError:
pass

# Check that the error handler method on the interceptor instance
# was called as a result of the iteration.
mock_handle.assert_called_once_with(mock_call)
mock_handle.assert_called_once_with(mock_call, mock_exception)

async def test_intercept_unary_unary_response_is_successful(self):
"""If response.exception() is None response is returned."""
Expand Down Expand Up @@ -625,7 +628,9 @@ async def mock_continuation(client_call_details, request):
found = True
break # We only need the first item

self.assertTrue(found, "Iterator should have yielded at least one message")
self.assertTrue(
found, "Iterator should have yielded at least one message"
)
self.assertIsInstance(message, proto.Message)

async def test_intercept_unary_stream_protobuf_proto(self):
Expand Down Expand Up @@ -661,5 +666,7 @@ async def mock_continuation(client_call_details, request):
found = True
break # We only need the first item

self.assertTrue(found, "Iterator should have yielded at least one message")
self.assertTrue(
found, "Iterator should have yielded at least one message"
)
self.assertIsInstance(message, ProtobufMessageType)
40 changes: 22 additions & 18 deletions tests/interceptors/logging_interceptor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""Tests for the Logging gRPC Interceptor."""


from importlib import import_module
import json
import logging
Expand All @@ -32,7 +31,6 @@
import google.ads.googleads.interceptors.logging_interceptor as interceptor_module
from google.ads.googleads import util


default_version = Client._DEFAULT_VERSION
module_prefix = f"google.ads.googleads.{default_version}"

Expand Down Expand Up @@ -60,7 +58,6 @@
)



class AwaitableMagicMock(mock.MagicMock):
def __await__(self):
return self._await_impl().__await__()
Expand Down Expand Up @@ -937,7 +934,9 @@ def _create_test_interceptor(
if not version:
version = default_version

return interceptor_module._AsyncLoggingInterceptor(logger, version, endpoint)
return interceptor_module._AsyncLoggingInterceptor(
logger, version, endpoint
)

def _get_mock_client_call_details(self):
mock_client_call_details = mock.Mock()
Expand Down Expand Up @@ -970,39 +969,41 @@ def _get_mock_exception(self):
def _get_mock_response(self, failed=False, streaming=False):
mock_response = AwaitableMagicMock()

# Async trailing_metadata
async def mock_trailing_metadata():
return self._MOCK_TRAILING_METADATA

mock_response.trailing_metadata = mock_trailing_metadata

# Sync exception
def mock_exception_fn():
if failed:
if failed:

def mock_exception_fn():
return self._get_mock_exception()
return None
mock_response.exception = mock_exception_fn

# Async await for UnaryUnary
mock_response.exception = mock_exception_fn
mock_response.code = lambda: self._get_mock_exception()
else:
del mock_response.exception

async def mock_code():
return 0

mock_response.code = mock_code

async def get_result():
if streaming:
return self._MOCK_STREAM
return self._MOCK_RESPONSE_MSG

mock_response._await_impl = get_result

# For streaming, we might need 'read' attribute to distinguish
if streaming:
mock_response.read = mock.AsyncMock(return_value=self._MOCK_STREAM)
else:
del mock_response.read

del mock_response.result

# Sync add_done_callback
def mock_add_done_callback(fn):
# In async interceptor, this is called to register the logging task.
# We want to execute it immediately or schedule it.
# Since fn expects a future (the call), and mock_response is the call.
fn(mock_response)

mock_response.add_done_callback = mock_add_done_callback
Expand All @@ -1013,6 +1014,7 @@ def _get_mock_continuation_fn(self, fail=False):
async def mock_continuation_fn(*args):
mock_response = self._get_mock_response(fail)
return mock_response

return mock_continuation_fn

async def test_intercept_unary_unary_unconfigured(self):
Expand Down Expand Up @@ -1058,7 +1060,9 @@ async def test_intercept_unary_unary_successful_request(self):
mock_request = self._get_mock_request()

# We need to get the response to assert against it
mock_response = await mock_continuation_fn(mock_client_call_details, mock_request)
mock_response = await mock_continuation_fn(
mock_client_call_details, mock_request
)
mock_trailing_metadata = await mock_response.trailing_metadata()

with (
Expand Down Expand Up @@ -1156,7 +1160,7 @@ async def mock_continuation_fn(*args):
initial_metadata,
mock_request,
trailing_metadata,
None, # Result is None for stream in async interceptor
None, # Result is None for stream in async interceptor
)
)

Expand Down