Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions litellm/llms/custom_httpx/llm_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,6 +1843,7 @@ async def async_anthropic_messages_handler(

# Prepare headers
kwargs = kwargs or {}
timeout: Optional[Union[float, httpx.Timeout]] = kwargs.get("timeout", None)
provider_specific_header = cast(
Optional[litellm.types.utils.ProviderSpecificHeader],
kwargs.get("provider_specific_header", None),
Expand Down Expand Up @@ -1958,6 +1959,7 @@ async def async_anthropic_messages_handler(
data=signed_json_body or json.dumps(request_body),
stream=stream or False,
logging_obj=logging_obj,
timeout=timeout,
)
response.raise_for_status()
except Exception as e:
Expand Down
4 changes: 4 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1931,6 +1931,10 @@ class PassThroughGenericEndpoint(LiteLLMPydanticObjectBase):
default=False,
description="If True, requests to subpaths of the path will be forwarded to the target endpoint. For example, if the path is /bria and include_subpath is True, requests to /bria/v1/text-to-image/base/2.3 will be forwarded to the target endpoint.",
)
timeout: Optional[float] = Field(
default=None,
description="Timeout in seconds for requests to the target endpoint. Defaults to 600 seconds if not specified.",
)
cost_per_request: float = Field(
default=0.0,
description="The USD cost per request to the target endpoint. This is used to calculate the cost of the request to the target endpoint.",
Expand Down
19 changes: 18 additions & 1 deletion litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ async def pass_through_request( # noqa: PLR0915
cost_per_request: Optional[float] = None,
custom_llm_provider: Optional[str] = None,
guardrails_config: Optional[dict] = None,
timeout: Optional[float] = None,
):
"""
Pass through endpoint handler, makes the httpx request for pass-through endpoints and ensures logging hooks are called
Expand Down Expand Up @@ -733,9 +734,10 @@ async def pass_through_request( # noqa: PLR0915
data=_parsed_body,
call_type="pass_through_endpoint",
)
_timeout = timeout if timeout is not None else 600
async_client_obj = get_async_httpx_client(
llm_provider=httpxSpecialProvider.PassThroughEndpoint,
params={"timeout": 600},
params={"timeout": _timeout},
)
async_client = async_client_obj.client
passthrough_logging_payload = PassthroughStandardLoggingPayload(
Expand Down Expand Up @@ -1101,6 +1103,7 @@ def create_pass_through_route(
query_params: Optional[dict] = None,
default_query_params: Optional[dict] = None,
guardrails: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
):
# check if target is an adapter.py or a url
from litellm._uuid import uuid
Expand Down Expand Up @@ -1174,6 +1177,7 @@ async def endpoint_func( # type: ignore
"merge_query_params": _merge_query_params,
"cost_per_request": cost_per_request,
"guardrails": None,
"timeout": timeout,
}

if passthrough_params is not None:
Expand All @@ -1193,6 +1197,7 @@ async def endpoint_func( # type: ignore
)
param_guardrails = target_params.get("guardrails", None)
param_default_query_params = target_params.get("default_query_params", None)
param_timeout = target_params.get("timeout", None)

# Construct the full target URL with subpath if needed
full_target = (
Expand Down Expand Up @@ -1236,6 +1241,7 @@ async def endpoint_func( # type: ignore
cost_per_request=cast(Optional[float], param_cost_per_request),
custom_llm_provider=custom_llm_provider,
guardrails_config=cast(Optional[dict], param_guardrails),
timeout=cast(Optional[float], param_timeout),
)

return endpoint_func
Expand Down Expand Up @@ -1881,6 +1887,7 @@ def add_exact_path_route(
guardrails: Optional[dict] = None,
methods: Optional[List[str]] = None,
default_query_params: Optional[dict] = None,
timeout: Optional[float] = None,
):
"""Add exact path route for pass-through endpoint"""
# Default to all methods if none specified (backward compatibility)
Expand Down Expand Up @@ -1920,6 +1927,7 @@ def add_exact_path_route(
cost_per_request=cost_per_request,
default_query_params=default_query_params,
guardrails=guardrails,
timeout=timeout,
),
methods=methods,
dependencies=dependencies,
Expand All @@ -1940,6 +1948,7 @@ def add_exact_path_route(
"dependencies": dependencies,
"cost_per_request": cost_per_request,
"guardrails": guardrails,
"timeout": timeout,
},
}

Expand All @@ -1957,6 +1966,7 @@ def add_subpath_route(
guardrails: Optional[dict] = None,
methods: Optional[List[str]] = None,
default_query_params: Optional[dict] = None,
timeout: Optional[float] = None,
):
"""Add wildcard route for sub-paths"""
# Default to all methods if none specified (backward compatibility)
Expand Down Expand Up @@ -1997,6 +2007,7 @@ def add_subpath_route(
cost_per_request=cost_per_request,
default_query_params=default_query_params,
guardrails=guardrails,
timeout=timeout,
),
methods=methods,
dependencies=dependencies,
Expand All @@ -2017,6 +2028,7 @@ def add_subpath_route(
"dependencies": dependencies,
"cost_per_request": cost_per_request,
"guardrails": guardrails,
"timeout": timeout,
},
}

Expand Down Expand Up @@ -2230,6 +2242,9 @@ async def initialize_pass_through_endpoints(
# Get guardrails config if present
_guardrails = endpoint.get("guardrails", None)

# Get timeout if present
_timeout = endpoint.get("timeout", None)

# Get methods list if present (None means all methods for backward compatibility)
_methods = endpoint.get("methods", None)

Expand All @@ -2250,6 +2265,7 @@ async def initialize_pass_through_endpoints(
guardrails=_guardrails,
methods=_methods,
default_query_params=_default_query_params,
timeout=_timeout,
)

# Generate route key with methods for tracking
Expand All @@ -2274,6 +2290,7 @@ async def initialize_pass_through_endpoints(
guardrails=_guardrails,
methods=_methods,
default_query_params=_default_query_params,
timeout=_timeout,
)

visited_endpoints.add(f"{endpoint_id}:subpath:{_path}:{methods_str}")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""
Tests for configurable timeout on pass-through endpoints.
"""

import sys
from unittest.mock import AsyncMock, patch, MagicMock
from typing import Optional

import pytest

from litellm.proxy._types import PassThroughGenericEndpoint


class TestPassThroughGenericEndpointTimeout:
    """Checks for the optional ``timeout`` field on PassThroughGenericEndpoint."""

    def test_timeout_defaults_to_none(self):
        # No timeout supplied at construction time.
        model = PassThroughGenericEndpoint(path="/test", target="https://example.com")
        assert model.timeout is None

    def test_timeout_accepts_custom_value(self):
        # An explicit float timeout is readable back from the attribute.
        fields = {
            "path": "/test",
            "target": "https://example.com",
            "timeout": 1200.0,
        }
        model = PassThroughGenericEndpoint(**fields)
        assert model.timeout == 1200.0

    def test_timeout_included_in_model_dump(self):
        # The configured value must appear under the "timeout" key of the dump.
        as_dict = PassThroughGenericEndpoint(
            path="/test", target="https://example.com", timeout=900
        ).model_dump()
        assert as_dict["timeout"] == 900

    def test_timeout_none_in_model_dump(self):
        # With no timeout configured, the dump still carries the key, set to None.
        as_dict = PassThroughGenericEndpoint(
            path="/test", target="https://example.com"
        ).model_dump()
        assert as_dict["timeout"] is None


def _make_mock_proxy_server_module():
"""Create a fake proxy_server module with a mock proxy_logging_obj."""
mock_module = MagicMock()
mock_module.proxy_logging_obj = MagicMock()
mock_module.proxy_logging_obj.pre_call_hook = AsyncMock(return_value={"test": True})
return mock_module


class TestPassThroughRequestTimeout:
    """Test that the timeout value flows through to get_async_httpx_client."""

    # Key under which the fake proxy_server module is installed in sys.modules.
    _PROXY_SERVER_KEY = "litellm.proxy.proxy_server"
    # Dotted path of the client factory that is patched during each test.
    _GET_CLIENT_PATH = (
        "litellm.proxy.pass_through_endpoints.pass_through_endpoints"
        ".get_async_httpx_client"
    )

    @staticmethod
    def _build_request():
        """Return a mocked POST request with no headers and a small JSON body."""
        request = MagicMock()
        request.method = "POST"
        request.headers = MagicMock()
        request.headers.items.return_value = []
        request.headers.get.return_value = None
        request.query_params = {}
        request.body = AsyncMock(return_value=b'{"test": true}')
        return request

    @staticmethod
    def _build_user_api_key_dict():
        """Return a mocked user-key object with an API key and no user/team ids."""
        user_api_key_dict = MagicMock()
        user_api_key_dict.api_key = "test-key"
        user_api_key_dict.user_id = None
        user_api_key_dict.team_id = None
        user_api_key_dict.end_user_id = None
        return user_api_key_dict

    @staticmethod
    def _build_client_obj():
        """Return a mocked client wrapper whose ``client.send`` yields a 200 response."""
        response = MagicMock()
        response.status_code = 200
        response.headers = {}
        response.text = "{}"
        response.iter_bytes = MagicMock(return_value=iter([b"{}"]))

        client = MagicMock()
        client.send = AsyncMock(return_value=response)

        client_obj = MagicMock()
        client_obj.client = client
        return client_obj

    async def _captured_timeout(self, **request_kwargs):
        """Run pass_through_request against mocks and return the client timeout.

        Args:
            **request_kwargs: Extra keyword arguments forwarded to
                ``pass_through_request`` (e.g. ``timeout=1200``).

        Returns:
            The ``params["timeout"]`` value that the patched
            ``get_async_httpx_client`` was called with.

        Raises:
            AssertionError: If ``get_async_httpx_client`` was not called
                exactly once.
        """
        from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
            pass_through_request,
        )

        request = self._build_request()
        user_api_key_dict = self._build_user_api_key_dict()
        client_obj = self._build_client_obj()

        # Swap in a fake proxy_server module for the duration of the call and
        # put back whatever was registered before (or remove the key again).
        original_module = sys.modules.get(self._PROXY_SERVER_KEY)
        sys.modules[self._PROXY_SERVER_KEY] = _make_mock_proxy_server_module()
        try:
            with patch(
                self._GET_CLIENT_PATH,
                return_value=client_obj,
            ) as mock_get_client:
                try:
                    await pass_through_request(
                        request=request,
                        target="https://example.com/api",
                        custom_headers={"Authorization": "Bearer test"},
                        user_api_key_dict=user_api_key_dict,
                        **request_kwargs,
                    )
                except Exception:
                    # Deliberately ignored: we only care that
                    # get_async_httpx_client was called with the right
                    # timeout, which is asserted below.
                    pass

                mock_get_client.assert_called_once()
                return mock_get_client.call_args.kwargs["params"]["timeout"]
        finally:
            if original_module is not None:
                sys.modules[self._PROXY_SERVER_KEY] = original_module
            else:
                sys.modules.pop(self._PROXY_SERVER_KEY, None)

    @pytest.mark.asyncio
    async def test_custom_timeout_passed_to_httpx_client(self):
        """Verify that a custom timeout is forwarded to get_async_httpx_client."""
        assert await self._captured_timeout(timeout=1200) == 1200

    @pytest.mark.asyncio
    async def test_default_timeout_when_none(self):
        """Verify that timeout defaults to 600 when not specified."""
        assert await self._captured_timeout() == 600
Loading