diff --git a/python/packages/foundry/agent_framework_foundry/_chat_client.py b/python/packages/foundry/agent_framework_foundry/_chat_client.py index fc2b29e1e4..5d04fb3c83 100644 --- a/python/packages/foundry/agent_framework_foundry/_chat_client.py +++ b/python/packages/foundry/agent_framework_foundry/_chat_client.py @@ -47,9 +47,9 @@ else: from typing_extensions import override # type: ignore # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import Self, TypedDict # type: ignore # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover if TYPE_CHECKING: from agent_framework import ChatAndFunctionMiddlewareTypes, ToolTypes @@ -189,6 +189,7 @@ def __init__( "Either 'project_endpoint' or 'project_client' is required. " "Set project_endpoint via parameter or 'FOUNDRY_PROJECT_ENDPOINT' environment variable." ) + self._should_close_client = False if not project_client: if not project_endpoint: raise ValueError( @@ -206,6 +207,7 @@ def __init__( if allow_preview is not None: project_client_kwargs["allow_preview"] = allow_preview project_client = AIProjectClient(**project_client_kwargs) + self._should_close_client = True openai_kwargs: dict[str, Any] = {} if default_headers: @@ -222,6 +224,24 @@ def __init__( ) self.project_client = project_client + async def close(self) -> None: + """Close the project client if we created it.""" + if self._should_close_client: + await self.project_client.close() + + async def __aenter__(self) -> Self: + """Enter the async context manager.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Close the project client on exit (only when owned).""" + await self.close() + @override def _check_model_presence(self, options: dict[str, Any]) -> None: if not options.get("model"):