diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index b6fcf853ab5..55f607504b4 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -1843,6 +1843,7 @@ async def async_anthropic_messages_handler( # Prepare headers kwargs = kwargs or {} + timeout: Optional[Union[float, httpx.Timeout]] = kwargs.get("timeout", None) provider_specific_header = cast( Optional[litellm.types.utils.ProviderSpecificHeader], kwargs.get("provider_specific_header", None), @@ -1958,6 +1959,7 @@ async def async_anthropic_messages_handler( data=signed_json_body or json.dumps(request_body), stream=stream or False, logging_obj=logging_obj, + timeout=timeout, ) response.raise_for_status() except Exception as e: diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 42b48446e7a..d685ad0e26a 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1931,6 +1931,10 @@ class PassThroughGenericEndpoint(LiteLLMPydanticObjectBase): default=False, description="If True, requests to subpaths of the path will be forwarded to the target endpoint. For example, if the path is /bria and include_subpath is True, requests to /bria/v1/text-to-image/base/2.3 will be forwarded to the target endpoint.", ) + timeout: Optional[float] = Field( + default=None, + description="Timeout in seconds for requests to the target endpoint. Defaults to 600 seconds if not specified.", + ) cost_per_request: float = Field( default=0.0, description="The USD cost per request to the target endpoint. This is used to calculate the cost of the request to the target endpoint.", diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 5b4166d0979..e06ecfe9533 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -611,6 +611,7 @@ async def pass_through_request( # noqa: PLR0915 cost_per_request: Optional[float] = None, custom_llm_provider: Optional[str] = None, guardrails_config: Optional[dict] = None, + timeout: Optional[float] = None, ): """ Pass through endpoint handler, makes the httpx request for pass-through endpoints and ensures logging hooks are called @@ -733,9 +734,10 @@ async def pass_through_request( # noqa: PLR0915 data=_parsed_body, call_type="pass_through_endpoint", ) + _timeout = timeout if timeout is not None else 600 async_client_obj = get_async_httpx_client( llm_provider=httpxSpecialProvider.PassThroughEndpoint, - params={"timeout": 600}, + params={"timeout": _timeout}, ) async_client = async_client_obj.client passthrough_logging_payload = PassthroughStandardLoggingPayload( @@ -1101,6 +1103,7 @@ def create_pass_through_route( query_params: Optional[dict] = None, default_query_params: Optional[dict] = None, guardrails: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, ): # check if target is an adapter.py or a url from litellm._uuid import uuid @@ -1174,6 +1177,7 @@ async def endpoint_func( # type: ignore "merge_query_params": _merge_query_params, "cost_per_request": cost_per_request, "guardrails": None, + "timeout": timeout, } if passthrough_params is not None: @@ -1193,6 +1197,7 @@ async def endpoint_func( # type: ignore ) param_guardrails = target_params.get("guardrails", None) param_default_query_params = target_params.get("default_query_params", None) + param_timeout = target_params.get("timeout", None) # Construct the full target URL with subpath if needed full_target = ( @@ -1236,6 +1241,7 @@ async def endpoint_func( # type: ignore cost_per_request=cast(Optional[float], param_cost_per_request), custom_llm_provider=custom_llm_provider, guardrails_config=cast(Optional[dict], param_guardrails), + timeout=cast(Optional[float], param_timeout), ) return endpoint_func @@ -1881,6 +1887,7 @@ def add_exact_path_route( guardrails: Optional[dict] = None, methods: Optional[List[str]] = None, default_query_params: Optional[dict] = None, + timeout: Optional[float] = None, ): """Add exact path route for pass-through endpoint""" # Default to all methods if none specified (backward compatibility) @@ -1920,6 +1927,7 @@ def add_exact_path_route( cost_per_request=cost_per_request, default_query_params=default_query_params, guardrails=guardrails, + timeout=timeout, ), methods=methods, dependencies=dependencies, @@ -1940,6 +1948,7 @@ def add_exact_path_route( "dependencies": dependencies, "cost_per_request": cost_per_request, "guardrails": guardrails, + "timeout": timeout, }, } @@ -1957,6 +1966,7 @@ def add_subpath_route( guardrails: Optional[dict] = None, methods: Optional[List[str]] = None, default_query_params: Optional[dict] = None, + timeout: Optional[float] = None, ): """Add wildcard route for sub-paths""" # Default to all methods if none specified (backward compatibility) @@ -1997,6 +2007,7 @@ def add_subpath_route( cost_per_request=cost_per_request, default_query_params=default_query_params, guardrails=guardrails, + timeout=timeout, ), methods=methods, dependencies=dependencies, @@ -2017,6 +2028,7 @@ def add_subpath_route( "dependencies": dependencies, "cost_per_request": cost_per_request, "guardrails": guardrails, + "timeout": timeout, }, } @@ -2230,6 +2242,9 @@ async def initialize_pass_through_endpoints( # Get guardrails config if present _guardrails = endpoint.get("guardrails", None) + # Get timeout if present + _timeout = endpoint.get("timeout", None) + # Get methods list if present (None means all methods for backward compatibility) _methods = endpoint.get("methods", None) @@ -2250,6 +2265,7 @@ async def initialize_pass_through_endpoints( guardrails=_guardrails, methods=_methods, default_query_params=_default_query_params, + timeout=_timeout, ) # Generate route key with methods for tracking @@ -2274,6 +2290,7 @@ async def initialize_pass_through_endpoints( guardrails=_guardrails, methods=_methods, default_query_params=_default_query_params, + timeout=_timeout, ) visited_endpoints.add(f"{endpoint_id}:subpath:{_path}:{methods_str}") diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_timeout.py b/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_timeout.py new file mode 100644 index 00000000000..90ce4ca487c --- /dev/null +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_timeout.py @@ -0,0 +1,182 @@ +""" +Tests for configurable timeout on pass-through endpoints. +""" + +import sys +from unittest.mock import AsyncMock, patch, MagicMock +from typing import Optional + +import pytest + +from litellm.proxy._types import PassThroughGenericEndpoint + + +class TestPassThroughGenericEndpointTimeout: + """Test that PassThroughGenericEndpoint accepts and defaults the timeout field.""" + + def test_timeout_defaults_to_none(self): + endpoint = PassThroughGenericEndpoint( + path="/test", + target="https://example.com", + ) + assert endpoint.timeout is None + + def test_timeout_accepts_custom_value(self): + endpoint = PassThroughGenericEndpoint( + path="/test", + target="https://example.com", + timeout=1200.0, + ) + assert endpoint.timeout == 1200.0 + + def test_timeout_included_in_model_dump(self): + endpoint = PassThroughGenericEndpoint( + path="/test", + target="https://example.com", + timeout=900, + ) + dumped = endpoint.model_dump() + assert dumped["timeout"] == 900 + + def test_timeout_none_in_model_dump(self): + endpoint = PassThroughGenericEndpoint( + path="/test", + target="https://example.com", + ) + dumped = endpoint.model_dump() + assert dumped["timeout"] is None + + +def _make_mock_proxy_server_module(): + """Create a fake proxy_server module with a mock proxy_logging_obj.""" + mock_module = MagicMock() + mock_module.proxy_logging_obj = MagicMock() + mock_module.proxy_logging_obj.pre_call_hook = AsyncMock(return_value={"test": True}) + return mock_module + + +class TestPassThroughRequestTimeout: + """Test that the timeout value flows through to get_async_httpx_client.""" + + @pytest.mark.asyncio + async def test_custom_timeout_passed_to_httpx_client(self): + """Verify that a custom timeout is forwarded to get_async_httpx_client.""" + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + pass_through_request, + ) + + mock_request = MagicMock() + mock_request.method = "POST" + mock_request.headers = MagicMock() + mock_request.headers.items.return_value = [] + mock_request.headers.get.return_value = None + mock_request.query_params = {} + mock_request.body = AsyncMock(return_value=b'{"test": true}') + + mock_user_api_key_dict = MagicMock() + mock_user_api_key_dict.api_key = "test-key" + mock_user_api_key_dict.user_id = None + mock_user_api_key_dict.team_id = None + mock_user_api_key_dict.end_user_id = None + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {} + mock_response.text = "{}" + mock_response.iter_bytes = MagicMock(return_value=iter([b"{}"])) + mock_client.send = AsyncMock(return_value=mock_response) + + mock_client_obj = MagicMock() + mock_client_obj.client = mock_client + + mock_proxy_module = _make_mock_proxy_server_module() + proxy_server_key = "litellm.proxy.proxy_server" + original_module = sys.modules.get(proxy_server_key) + sys.modules[proxy_server_key] = mock_proxy_module + + try: + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client", + return_value=mock_client_obj, + ) as mock_get_client: + try: + await pass_through_request( + request=mock_request, + target="https://example.com/api", + custom_headers={"Authorization": "Bearer test"}, + user_api_key_dict=mock_user_api_key_dict, + timeout=1200, + ) + except Exception: + # We only care that get_async_httpx_client was called with the right timeout + pass + + mock_get_client.assert_called_once() + call_kwargs = mock_get_client.call_args + assert call_kwargs.kwargs["params"]["timeout"] == 1200 + finally: + if original_module is not None: + sys.modules[proxy_server_key] = original_module + else: + sys.modules.pop(proxy_server_key, None) + + @pytest.mark.asyncio + async def test_default_timeout_when_none(self): + """Verify that timeout defaults to 600 when not specified.""" + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + pass_through_request, + ) + + mock_request = MagicMock() + mock_request.method = "POST" + mock_request.headers = MagicMock() + mock_request.headers.items.return_value = [] + mock_request.headers.get.return_value = None + mock_request.query_params = {} + mock_request.body = AsyncMock(return_value=b'{"test": true}') + + mock_user_api_key_dict = MagicMock() + mock_user_api_key_dict.api_key = "test-key" + mock_user_api_key_dict.user_id = None + mock_user_api_key_dict.team_id = None + mock_user_api_key_dict.end_user_id = None + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {} + mock_response.text = "{}" + mock_client.send = AsyncMock(return_value=mock_response) + + mock_client_obj = MagicMock() + mock_client_obj.client = mock_client + + mock_proxy_module = _make_mock_proxy_server_module() + proxy_server_key = "litellm.proxy.proxy_server" + original_module = sys.modules.get(proxy_server_key) + sys.modules[proxy_server_key] = mock_proxy_module + + try: + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client", + return_value=mock_client_obj, + ) as mock_get_client: + try: + await pass_through_request( + request=mock_request, + target="https://example.com/api", + custom_headers={"Authorization": "Bearer test"}, + user_api_key_dict=mock_user_api_key_dict, + ) + except Exception: + pass + + mock_get_client.assert_called_once() + call_kwargs = mock_get_client.call_args + assert call_kwargs.kwargs["params"]["timeout"] == 600 + finally: + if original_module is not None: + sys.modules[proxy_server_key] = original_module + else: + sys.modules.pop(proxy_server_key, None)