diff --git a/plugins/remote_kernels/txl_remote_kernels/driver.py b/plugins/remote_kernels/txl_remote_kernels/driver.py index fc2ee4c..db16949 100644 --- a/plugins/remote_kernels/txl_remote_kernels/driver.py +++ b/plugins/remote_kernels/txl_remote_kernels/driver.py @@ -7,6 +7,7 @@ import httpx from anyio import Lock, sleep from anyioutils import create_task +from httpx import USE_CLIENT_DEFAULT, Timeout from httpx_ws import aconnect_ws from txl_kernel.driver import KernelMixin from txl_kernel.message import date_to_str @@ -30,6 +31,7 @@ def __init__( url: str, kernel_name: str | None = "", comm_handlers=[], + timeout: Timeout = USE_CLIENT_DEFAULT, ) -> None: super().__init__(task_group) self.task_group = task_group @@ -78,6 +80,7 @@ async def start(self): params={"session_id": self.session_id}, cookies=self.cookies, subprotocols=["v1.kernel.websocket.jupyter.org"], + timeout=self.timeout, ) as self.websocket: recv_task = create_task(self._recv(), self.task_group) try: diff --git a/plugins/remote_kernels/txl_remote_kernels/main.py b/plugins/remote_kernels/txl_remote_kernels/main.py index 64bceb4..660de58 100644 --- a/plugins/remote_kernels/txl_remote_kernels/main.py +++ b/plugins/remote_kernels/txl_remote_kernels/main.py @@ -6,6 +6,7 @@ import httpx from anyio import create_task_group, sleep from fps import Module +from httpx import USE_CLIENT_DEFAULT, Timeout from pycrdt import Map from txl.base import Kernels, Kernelspecs @@ -20,9 +21,11 @@ def __init__( self, url: str, kernel_name: str | None, + *, + timeout: Timeout = USE_CLIENT_DEFAULT, ): self.kernel = KernelDriver( - self.task_group, url, kernel_name, comm_handlers=self.comm_handlers + self.task_group, url, kernel_name, comm_handlers=self.comm_handlers, timeout=timeout ) async def execute(self, ycell: Map): @@ -56,19 +59,28 @@ async def get(self) -> dict[str, Any]: class RemoteKernelsModule(Module): - def __init__(self, name: str, url: str = "http://127.0.0.1:8000"): + def __init__( + self, + name: str, + url: str = "http://127.0.0.1:8000", + *, + timeout: Timeout = USE_CLIENT_DEFAULT, + ): super().__init__(name) self.url = url + self.timeout = timeout async def start(self) -> None: url = self.url async with create_task_group() as self.tg: + class _RemoteKernels(RemoteKernels): task_group = self.tg + timeout = self.timeout def __init__(self, *args, **kwargs): - super().__init__(url, *args, **kwargs) + super().__init__(url, *args, timeout=self.timeout, **kwargs) self.put(_RemoteKernels, Kernels) self.done()