Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,44 @@ def _load_json_from_response(cls, response: Any) -> Any:
)


def _extract_retry_info_delay_seconds(
api_error: errors.APIError,
) -> Optional[float]:
if api_error.code != 429 or api_error.status != 'RESOURCE_EXHAUSTED':
return None

if not isinstance(api_error.details, dict):
return None

for path in (['error', 'details'], ['details']):
details = _common.get_value_by_path(api_error.details, path)
if not isinstance(details, list):
continue

for detail in details:
if not isinstance(detail, dict):
continue
detail_type = detail.get('@type')
if (
not isinstance(detail_type, str)
or not detail_type.endswith('google.rpc.RetryInfo')
):
continue
retry_delay = _common.get_value_by_path(detail, ['retryDelay'])
if not isinstance(retry_delay, str):
continue
retry_delay = retry_delay.strip()
if not retry_delay.endswith('s'):
continue
try:
retry_delay_seconds = float(retry_delay[:-1])
except ValueError:
continue
if retry_delay_seconds >= 0:
return retry_delay_seconds
return None


def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
"""Returns the retry args for the given http retry options.

Expand All @@ -498,11 +536,28 @@ def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
exp_base=options.exp_base or _RETRY_EXP_BASE,
jitter=options.jitter or _RETRY_JITTER,
)
fallback_wait = wait

def wait_with_retry_info(retry_state: tenacity.RetryCallState) -> float:
if retry_state.outcome is not None and retry_state.outcome.failed:
exception = retry_state.outcome.exception()
if isinstance(exception, errors.APIError):
retry_delay_seconds = _extract_retry_info_delay_seconds(exception)
if retry_delay_seconds is not None:
# Add one second because RetryInfo delay can be truncated.
return retry_delay_seconds + 1
return fallback_wait(retry_state)

# Preserve standard attributes.
wait_with_retry_info.initial = wait.initial
wait_with_retry_info.max = wait.max
wait_with_retry_info.exp_base = wait.exp_base
wait_with_retry_info.jitter = wait.jitter
return {
'stop': stop,
'retry': retry,
'reraise': True,
'wait': wait,
'wait': wait_with_retry_info,
'before_sleep': tenacity.before_sleep_log(logger, logging.INFO),
}

Expand Down
101 changes: 99 additions & 2 deletions google/genai/tests/client/test_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import asyncio
from collections.abc import Sequence
import datetime
import json
from unittest import mock
import pytest
try:
Expand Down Expand Up @@ -61,11 +62,14 @@ def _final_codes(retried_codes: Sequence[int] = _RETRIED_CODES):
return [code for code in range(100, 600) if code not in retried_codes]


def _httpx_response(code: int):
def _httpx_response(code: int, response_json=None):
content = b''
if response_json is not None:
content = json.dumps(response_json).encode('utf-8')
return httpx.Response(
status_code=code,
headers={'status-code': str(code)},
content=b'',
content=content,
)


Expand Down Expand Up @@ -144,6 +148,99 @@ def fn():
assert timestamps[4] - timestamps[3] >= datetime.timedelta(seconds=8)


_RETRY_OPTIONS_NO_JITTER = types.HttpRetryOptions(
attempts=2,
initial_delay=0.25,
max_delay=10,
exp_base=2,
jitter=0,
)


def _resource_exhausted_error_payload(
retry_delay: str,
*,
status: str = 'RESOURCE_EXHAUSTED',
wrapped: bool = True,
):
details = {
'code': 429,
'message': 'Resource exhausted.',
'status': status,
'details': [
{
'@type': 'type.googleapis.com/google.rpc.RetryInfo',
'retryDelay': retry_delay,
}
],
}
if wrapped:
return {'error': details}
return details


def _retry_and_capture_sleep(status_code: int, error_payload: dict[str, object]):
def fn():
errors.APIError.raise_for_response(_httpx_response(status_code, error_payload))

retrying = tenacity.Retrying(
**api_client.retry_args(_RETRY_OPTIONS_NO_JITTER)
)
with mock.patch('tenacity.wait.random.uniform', return_value=0.0):
with mock.patch('tenacity.nap.time.sleep') as mock_sleep:
with pytest.raises(errors.APIError):
retrying(fn)
assert mock_sleep.call_count == 1
return mock_sleep.call_args.args[0]


def test_retry_wait_uses_retry_info_for_429_resource_exhausted():
retry_delay_seconds = _retry_and_capture_sleep(
429,
_resource_exhausted_error_payload('21.943984799s'),
)
assert retry_delay_seconds == pytest.approx(22.943984799)


def test_retry_wait_ignores_retry_info_when_status_not_resource_exhausted():
retry_delay_seconds = _retry_and_capture_sleep(
429,
_resource_exhausted_error_payload(
'9s', status='UNAVAILABLE'
),
)
assert retry_delay_seconds == 0.25


def test_retry_wait_ignores_retry_info_when_code_not_429():
retry_delay_seconds = _retry_and_capture_sleep(
500,
_resource_exhausted_error_payload('9s'),
)
assert retry_delay_seconds == 0.25


def test_retry_wait_falls_back_on_malformed_retry_delay():
retry_delay_seconds = _retry_and_capture_sleep(
429,
_resource_exhausted_error_payload('invalid-delay'),
)
assert retry_delay_seconds == 0.25


def test_retry_wait_supports_error_details_with_or_without_error_wrapper():
wrapped_retry_delay = _retry_and_capture_sleep(
429,
_resource_exhausted_error_payload('3.5s', wrapped=True),
)
unwrapped_retry_delay = _retry_and_capture_sleep(
429,
_resource_exhausted_error_payload('3.5s', wrapped=False),
)
assert wrapped_retry_delay == pytest.approx(4.5)
assert unwrapped_retry_delay == pytest.approx(4.5)


def test_retry_args_enabled_with_custom_values_are_not_overridden():
options = types.HttpRetryOptions(
attempts=10,
Expand Down