diff --git a/.github/workflows/test-ros1.yml b/.github/workflows/test-ros1.yml index 6b07a53..5ac3532 100644 --- a/.github/workflows/test-ros1.yml +++ b/.github/workflows/test-ros1.yml @@ -55,5 +55,6 @@ jobs: run: | pytest tests/ros1 - name: Tear down docker containers + if: always() run: | docker rm -f rosbridge diff --git a/.github/workflows/transport-benchmark.yml b/.github/workflows/transport-benchmark.yml new file mode 100644 index 0000000..afdab8f --- /dev/null +++ b/.github/workflows/transport-benchmark.yml @@ -0,0 +1,58 @@ +name: Transport benchmark + +on: + push: + branches: + - main + tags: + - 'v*' + pull_request: + branches: + - main + workflow_dispatch: + +jobs: + transport-benchmark: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Set up Python 3.11 + uses: actions/setup-python@v6 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install wheel + - name: Install + run: | + python -m pip install --no-cache-dir -r requirements-dev.txt + python -m pip install --no-cache-dir uvloop + - name: Set up docker containers + run: | + docker build -t gramaziokohler/rosbridge:integration_tests_ros1 ./docker/ros1 + docker run -d -p 9090:9090 --name rosbridge gramaziokohler/rosbridge:integration_tests_ros1 /bin/bash -c "roslaunch /integration-tests.launch" + docker ps -a + - name: Run transport benchmark + continue-on-error: true + run: | + python benchmarks/transport.py \ + --host 127.0.0.1 \ + --port 9090 \ + --cases twisted asyncio asyncio-uvloop asyncio-no-compression asyncio-uvloop-no-compression \ + --service-count 300 \ + --topic-count 500 \ + --markdown transport-benchmark.md + cat transport-benchmark.md >> "$GITHUB_STEP_SUMMARY" + - name: Upload transport benchmark + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: transport-benchmark + path: transport-benchmark.md + if-no-files-found: ignore + - name: Tear down docker containers + if: always() + run: | + docker rm -f rosbridge diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..e88e6c7 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,13 @@ +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +sphinx: + configuration: docs/conf.py + +python: + install: + - requirements: docs/requirements.txt diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6dca910..7701d59 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,6 +12,11 @@ Unreleased **Added** +* Added an asyncio transport backend selectable with ``ROSLIBPY_TRANSPORT``, + ``set_default_transport()``, or the ``Ros(..., transport=...)`` argument. +* Added transport-parametrized ROS integration tests for CPython, while + keeping IronPython on the ``cli`` transport. + **Changed** **Fixed** diff --git a/MANIFEST.in b/MANIFEST.in index 1d9c40c..cc7865b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,11 +1,13 @@ graft docs graft src graft tests +graft benchmarks prune docker include .bumpversion.cfg include .editorconfig +include .readthedocs.yaml include AUTHORS.rst include CHANGELOG.rst diff --git a/README.rst b/README.rst index edb3122..c84691d 100644 --- a/README.rst +++ b/README.rst @@ -68,6 +68,11 @@ To install **roslibpy**, simply use ``pip``:: pip install roslibpy +The default WebSocket transport on CPython uses Twisted/Autobahn. An asyncio +transport is also available and can be installed with:: + + pip install roslibpy[asyncio] + For IronPython, the ``pip`` command is slightly different:: ipy -X:Frames -m pip install --user roslibpy @@ -75,6 +80,17 @@ For IronPython, the ``pip`` command is slightly different:: Remember that you will need a working ROS setup including the **rosbridge server** and **TF2 web republisher** accessible within your network. +Transport selection +------------------- + +CPython uses the ``twisted`` transport by default; IronPython uses ``cli``. To +select the asyncio transport, either pass it per client, set the process-wide +default, or set the environment variable:: + + roslibpy.Ros(host='localhost', port=9090, transport='asyncio') + roslibpy.set_default_transport('asyncio') + ROSLIBPY_TRANSPORT=asyncio python my_script.py + Documentation ------------- diff --git a/benchmarks/transport.py b/benchmarks/transport.py new file mode 100644 index 0000000..3976af8 --- /dev/null +++ b/benchmarks/transport.py @@ -0,0 +1,325 @@ +"""Benchmark roslibpy WebSocket transports against a running rosbridge. + +This is an opt-in development helper, not a pytest test. Start a rosbridge on +the chosen host/port, then run for example: + + python benchmarks/transport.py --host 127.0.0.1 --port 9090 + +The benchmark compares connection setup, blocking rosapi service calls, and +topic publish/subscribe round-trip latency for the ``twisted`` and ``asyncio`` +transports. + +The GitHub Actions benchmark runs on shared CI infrastructure, so occasional +topic latency spikes are expected. Prefer medians and throughput for quick +comparisons, and compare p95/max values across several runs before drawing +conclusions about tail latency. +""" + +from __future__ import print_function + +import argparse +import json +import os +import statistics +import subprocess +import sys +import tempfile +import threading +import time + +from roslibpy import Message, Ros, Topic + + +CASES = { + "twisted": { + "transport": "twisted", + "event_loop": None, + "compression": None, + }, + "asyncio": { + "transport": "asyncio", + "event_loop": "default", + "compression": "deflate", + }, + "asyncio-uvloop": { + "transport": "asyncio", + "event_loop": "uvloop", + "compression": "deflate", + }, + "asyncio-no-compression": { + "transport": "asyncio", + "event_loop": "default", + "compression": None, + }, + "asyncio-uvloop-no-compression": { + "transport": "asyncio", + "event_loop": "uvloop", + "compression": None, + }, +} + + +def percentile(values, pct): + ordered = sorted(values) + return ordered[int(round((len(ordered) - 1) * pct / 100.0))] + + +def summary(values): + return { + "mean": statistics.mean(values) * 1000.0, + "median": statistics.median(values) * 1000.0, + "p95": percentile(values, 95) * 1000.0, + "max": max(values) * 1000.0, + } + + +def format_summary_line(transport, label, values): + data = summary(values) + return "{:<30} {:<16} mean={mean:7.3f} ms median={median:7.3f} ms " "p95={p95:7.3f} ms max={max:7.3f} ms".format( + transport, label, **data + ) + + +def print_summary(transport, label, values): + print(format_summary_line(transport, label, values)) + + +def markdown_row(transport, metric, values=None, value=None): + if values is not None: + data = summary(values) + return "| {transport} | {metric} | {mean:.3f} ms | {median:.3f} ms | " "{p95:.3f} ms | {max:.3f} ms | |".format( + transport=transport, metric=metric, **data + ) + return "| {} | {} | | | | | {} |".format(transport, metric, value) + + +def wait_connected(ros, timeout): + deadline = time.time() + timeout + while time.time() < deadline: + if ros.is_connected: + return + time.sleep(0.005) + raise RuntimeError("Timed out waiting for connection") + + +def wait_rosbridge_ready(transport, args): + deadline = time.time() + args.ready_timeout + last_error = None + while time.time() < deadline: + ros = None + try: + ros = Ros(args.host, args.port, transport=transport) + ros.run() + wait_connected(ros, args.connect_timeout) + ros.get_time() + return + except Exception as error: + last_error = error + time.sleep(args.ready_interval) + finally: + if ros is not None: + try: + ros.close() + except Exception: + pass + raise RuntimeError("Timed out waiting for rosbridge readiness: {}".format(last_error)) + + +def service_latency(ros, count, warmup): + for _ in range(warmup): + ros.get_time() + timings = [] + for _ in range(count): + start = time.perf_counter() + ros.get_time() + timings.append(time.perf_counter() - start) + return timings + + +def topic_latency(ros, case_name, count, warmup, delay): + topic_name = "/roslibpy_transport_benchmark_{}".format(case_name.replace("-", "_")) + listener = Topic(ros, topic_name, "std_msgs/String") + publisher = Topic(ros, topic_name, "std_msgs/String") + + expected = count + warmup + sent = {} + timings = [] + done = threading.Event() + + def receive(message): + payload = json.loads(message["data"]) + seq = payload["seq"] + start = sent.get(seq) + if start is None: + return + if seq >= warmup: + timings.append(time.perf_counter() - start) + if seq == expected - 1: + done.set() + + listener.subscribe(receive) + time.sleep(0.5) + + start_total = time.perf_counter() + for seq in range(expected): + sent[seq] = time.perf_counter() + publisher.publish(Message({"data": json.dumps({"seq": seq})})) + if delay: + time.sleep(delay) + + if not done.wait(15): + raise RuntimeError("Timed out after receiving {} of {} messages".format(len(timings), count)) + + total = time.perf_counter() - start_total + listener.unsubscribe() + publisher.unadvertise() + return timings, len(timings) / total + + +def configure_case(case_name): + case = CASES[case_name] + if case["event_loop"] == "uvloop": + import uvloop + + uvloop.install() + + if case["transport"] == "asyncio": + from roslibpy.comm.comm_asyncio import AsyncioRosBridgeClientFactory + + AsyncioRosBridgeClientFactory.compression = case["compression"] + + return case + + +def run_case(case_name, args): + case = configure_case(case_name) + transport = case["transport"] + + wait_rosbridge_ready(transport, args) + + ros = Ros(args.host, args.port, transport=transport) + start = time.perf_counter() + ros.run() + wait_connected(ros, args.connect_timeout) + connect_time = time.perf_counter() - start + + services = service_latency(ros, args.service_count, args.warmup) + topics, topic_rate = topic_latency(ros, case_name, args.topic_count, args.warmup, args.topic_delay) + + print("{:<30} {:<16} {:7.3f} ms".format(case_name, "initial connect", connect_time * 1000.0)) + print_summary(case_name, "get_time", services) + print_summary(case_name, "topic rtt", topics) + print("{:<30} {:<16} {:7.1f} msg/s".format(case_name, "topic rate", topic_rate)) + + ros.close() + time.sleep(0.5) + return { + "connect": [connect_time], + "services": services, + "topics": topics, + "topic_rate": topic_rate, + } + + +def run_case_subprocess(case_name, args): + with tempfile.NamedTemporaryFile(delete=False) as result_file: + result_path = result_file.name + + command = [ + sys.executable, + os.path.abspath(__file__), + "--case", + case_name, + "--host", + args.host, + "--port", + str(args.port), + "--service-count", + str(args.service_count), + "--topic-count", + str(args.topic_count), + "--warmup", + str(args.warmup), + "--topic-delay", + str(args.topic_delay), + "--connect-timeout", + str(args.connect_timeout), + "--ready-timeout", + str(args.ready_timeout), + "--ready-interval", + str(args.ready_interval), + "--json-result", + result_path, + ] + try: + subprocess.check_call(command) + with open(result_path) as fh: + return json.load(fh) + finally: + try: + os.remove(result_path) + except OSError: + pass + + +def write_markdown(path, results): + lines = [ + "# roslibpy transport benchmark", + "", + "These numbers are sampled on shared CI infrastructure. Topic p95 and max", + "latencies can be noisy because of runner scheduling, Docker networking,", + "and rosbridge timing. Prefer medians and throughput for quick comparisons;", + "interpret tail latency across several runs.", + "", + "| Transport | Metric | Mean | Median | P95 | Max | Value |", + "| --- | --- | ---: | ---: | ---: | ---: | ---: |", + ] + for transport, result in results: + lines.append(markdown_row(transport, "initial connect", result["connect"])) + lines.append(markdown_row(transport, "get_time service", result["services"])) + lines.append(markdown_row(transport, "topic round trip", result["topics"])) + lines.append(markdown_row(transport, "topic throughput", value="{:.1f} msg/s".format(result["topic_rate"]))) + lines.append("") + with open(path, "w") as fh: + fh.write("\n".join(lines)) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=9090) + parser.add_argument("--cases", nargs="+") + parser.add_argument("--transports", nargs="+", help=argparse.SUPPRESS) + parser.add_argument("--case", choices=sorted(CASES), help=argparse.SUPPRESS) + parser.add_argument("--json-result", help=argparse.SUPPRESS) + parser.add_argument("--service-count", type=int, default=1000) + parser.add_argument("--topic-count", type=int, default=1000) + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument("--topic-delay", type=float, default=0.0005) + parser.add_argument("--connect-timeout", type=float, default=5.0) + parser.add_argument("--ready-timeout", type=float, default=30.0) + parser.add_argument("--ready-interval", type=float, default=0.5) + parser.add_argument("--markdown", help="Write a Markdown summary table to this path") + args = parser.parse_args() + + if args.case: + result = run_case(args.case, args) + if args.json_result: + with open(args.json_result, "w") as fh: + json.dump(result, fh) + return + + cases = args.cases or args.transports or ["twisted", "asyncio"] + + results = [] + for case_name in cases: + if case_name not in CASES: + raise ValueError("Unknown benchmark case {!r}; expected one of {}".format(case_name, sorted(CASES))) + results.append((case_name, run_case_subprocess(case_name, args))) + + if args.markdown: + write_markdown(args.markdown, results) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 82ea93d..a613457 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,6 @@ def read(*names, **kwargs): ], keywords=["ros", "ros-bridge", "robotics", "websockets"], install_requires=requirements, - extras_require={}, + extras_require={"asyncio": ["websockets>=12.0"]}, entry_points={"console_scripts": ["roslibpy=roslibpy.__main__:main"]}, ) diff --git a/src/roslibpy/__init__.py b/src/roslibpy/__init__.py index 4ccf076..ae3c469 100644 --- a/src/roslibpy/__init__.py +++ b/src/roslibpy/__init__.py @@ -34,9 +34,21 @@ Other classes that need an active connection with ROS receive this instance as an argument to their constructors. +Transport selection +------------------- + +CPython uses the ``twisted`` transport by default; IronPython uses ``cli``. The +``asyncio`` transport can be selected per client, process-wide, or via the +``ROSLIBPY_TRANSPORT`` environment variable:: + + ros = roslibpy.Ros(host='localhost', port=9090, transport='asyncio') + roslibpy.set_default_transport('asyncio') + ROSLIBPY_TRANSPORT=asyncio python my_script.py + .. autoclass:: Ros :members: .. autofunction:: set_rosapi_timeout +.. autofunction:: set_default_transport Main ROS concepts ================= @@ -129,6 +141,7 @@ class and are passed around via :class:`Topics ` using a **publish/subscr __url__, __version__, ) +from .comm import set_default_transport from .core import ( ActionClient, Feedback, @@ -167,6 +180,7 @@ class and are passed around via :class:`Topics ` using a **publish/subscr "Service", "ServiceRequest", "ServiceResponse", + "set_default_transport", "set_rosapi_timeout", "Time", "Topic", diff --git a/src/roslibpy/comm/__init__.py b/src/roslibpy/comm/__init__.py index 0e88ce7..cd96cd6 100644 --- a/src/roslibpy/comm/__init__.py +++ b/src/roslibpy/comm/__init__.py @@ -1,10 +1,107 @@ +"""Transport selection for the ROS bridge connection. + +Three transports are available: + +* ``twisted``: default. Built on autobahn + twisted. The historical + implementation. +* ``asyncio``: opt-in (2.1+). Built on the ``websockets`` library. Cleaner + per-test loop isolation; requires ``pip install roslibpy[asyncio]``. +* ``cli``: IronPython only (auto-selected on ``sys.platform == "cli"``). + +Selection precedence (highest → lowest): + +1. ``transport=`` kwarg on ``Ros`` (per-instance). +2. ``ROSLIBPY_TRANSPORT`` environment variable. +3. Module-level default set via ``roslibpy.set_default_transport(...)``. +4. Platform default: ``cli`` on IronPython, ``twisted`` elsewhere. +""" + +import os import sys from .comm import RosBridgeException, RosBridgeProtocol -if sys.platform == "cli": - from .comm_cli import CliRosBridgeClientFactory as RosBridgeClientFactory -else: - from .comm_autobahn import AutobahnRosBridgeClientFactory as RosBridgeClientFactory +__all__ = [ + "RosBridgeException", + "RosBridgeProtocol", + "RosBridgeClientFactory", + "select_factory", + "set_default_transport", + "TRANSPORT_TWISTED", + "TRANSPORT_ASYNCIO", + "TRANSPORT_CLI", +] + +TRANSPORT_TWISTED = "twisted" +TRANSPORT_ASYNCIO = "asyncio" +TRANSPORT_CLI = "cli" + +_VALID_TRANSPORTS = (TRANSPORT_TWISTED, TRANSPORT_ASYNCIO, TRANSPORT_CLI) +_PLATFORM_DEFAULT = TRANSPORT_CLI if sys.platform == "cli" else TRANSPORT_TWISTED +_DEFAULT_TRANSPORT = _PLATFORM_DEFAULT + + +def set_default_transport(name): + """Set the process-wide default transport. + + Args: + name (str): One of ``"twisted"``, ``"asyncio"``, ``"cli"``. + + Raises: + ValueError: If ``name`` is not a known transport. + """ + global _DEFAULT_TRANSPORT + if name not in _VALID_TRANSPORTS: + raise ValueError("Unknown transport %r; expected one of %s" % (name, _VALID_TRANSPORTS)) + _DEFAULT_TRANSPORT = name + + +def _resolve_transport(explicit=None): + """Apply the precedence rules to land on a single transport name.""" + if explicit is not None: + if explicit not in _VALID_TRANSPORTS: + raise ValueError("Unknown transport %r; expected one of %s" % (explicit, _VALID_TRANSPORTS)) + return explicit + env_choice = os.environ.get("ROSLIBPY_TRANSPORT") + if env_choice: + if env_choice not in _VALID_TRANSPORTS: + raise ValueError( + "Unknown ROSLIBPY_TRANSPORT=%r; expected one of %s" % (env_choice, _VALID_TRANSPORTS) + ) + return env_choice + return _DEFAULT_TRANSPORT + + +def select_factory(transport=None): + """Return the factory class for the requested (or resolved) transport. + + The optional dependencies are imported lazily so a process that never uses + the asyncio transport never imports ``websockets``, and a process that + never uses the twisted transport never imports ``twisted``. + + Args: + transport (str, optional): One of ``"twisted"``, ``"asyncio"``, + ``"cli"``. If ``None``, applies the precedence rules described in + the module docstring. + + Returns: + The factory class to use for new ``Ros`` instances. + """ + name = _resolve_transport(transport) + if name == TRANSPORT_CLI: + from .comm_cli import CliRosBridgeClientFactory + + return CliRosBridgeClientFactory + if name == TRANSPORT_ASYNCIO: + from .comm_asyncio import AsyncioRosBridgeClientFactory + + return AsyncioRosBridgeClientFactory + + # Fallback to default + from .comm_autobahn import AutobahnRosBridgeClientFactory + + return AutobahnRosBridgeClientFactory + -__all__ = ["RosBridgeException", "RosBridgeProtocol", "RosBridgeClientFactory"] +# `RosBridgeClientFactory` remains a module-level binding for back-compatibility +RosBridgeClientFactory = select_factory() diff --git a/src/roslibpy/comm/comm_asyncio.py b/src/roslibpy/comm/comm_asyncio.py new file mode 100644 index 0000000..045427d --- /dev/null +++ b/src/roslibpy/comm/comm_asyncio.py @@ -0,0 +1,517 @@ +"""Asyncio-based transport for roslibpy. + +Opt-in alternative to the default twisted/autobahn transport. Selected via: + +* env var ``ROSLIBPY_TRANSPORT=asyncio`` +* module-level ``roslibpy.set_default_transport("asyncio")`` +* per-instance ``Ros(host, port, transport="asyncio")`` + +Why a separate transport +------------------------ + +Twisted's reactor is a process-wide singleton that cannot be restarted after +``reactor.stop()``. Long-running test sessions that create many ``Ros`` +instances accumulate state on it. asyncio loops, by contrast, are first-class +objects that can be started, stopped, and discarded independently; that +gives us clean per-process (or per-test, when needed) isolation. + +The public ``Ros`` / ``Topic`` / ``Service`` / ``ActionClient`` / ``Param`` API +is unaffected — only the transport layer changes. + +Dependencies +------------ + +This module imports the ``websockets`` library lazily; it is declared as an +optional extra (``roslibpy[asyncio]``). The transport raises a clear error if +selected without the dependency available. +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from typing import Any, Callable, Optional +from urllib.parse import urlparse + +from ..core import RosTimeoutError +from ..event_emitter import EventEmitterMixin +from . import RosBridgeProtocol + +LOGGER = logging.getLogger("roslibpy") + +# Defaults matched to ReconnectingClientFactory's behaviour so users moving +# between transports see the same retry cadence. +DEFAULT_INITIAL_RECONNECT_DELAY = 1.0 +DEFAULT_MAX_RECONNECT_DELAY = 3600.0 +DEFAULT_RECONNECT_FACTOR = 2.7 +DEFAULT_RECONNECT_JITTER = 0.119 +DEFAULT_MAX_RECONNECT_RETRIES = None # None = unbounded, matching twisted + +# Single shared event loop manager, owned by the module so all factories in +# this process share one background thread + loop — same singleton semantics +# as the twisted reactor, but with proper teardown via terminate(). +_MANAGER_SINGLETON: "Optional[AsyncioEventLoopManager]" = None +_MANAGER_SINGLETON_LOCK = threading.Lock() + + +def _get_shared_manager() -> "AsyncioEventLoopManager": + global _MANAGER_SINGLETON + if _MANAGER_SINGLETON is not None: + return _MANAGER_SINGLETON + with _MANAGER_SINGLETON_LOCK: + if _MANAGER_SINGLETON is not None: + return _MANAGER_SINGLETON + _MANAGER_SINGLETON = AsyncioEventLoopManager() + return _MANAGER_SINGLETON + + +def _import_websockets(): + """Import the optional ``websockets`` dependency lazily, with a clear error.""" + try: + import websockets # noqa: F401 + from websockets.asyncio.client import connect as ws_connect # noqa: F401 + from websockets.exceptions import ConnectionClosed, InvalidStatus # noqa: F401 + except ImportError as exc: # pragma: no cover — exercised by missing-extra test + raise ImportError( + "The asyncio transport requires the 'websockets' package. " "Install with: pip install 'roslibpy[asyncio]'" + ) from exc + return ws_connect, ConnectionClosed, InvalidStatus + + +class AsyncioRosBridgeProtocol(RosBridgeProtocol): + """ROS Bridge protocol implementation over an asyncio websockets connection. + + Instances are owned by :class:`AsyncioRosBridgeClientFactory`; user code + interacts with the factory, not the protocol directly. ``send_message`` + and ``send_close`` are thread-safe — they schedule the IO onto the + background loop via ``call_soon_threadsafe``. + """ + + def __init__(self, factory: "AsyncioRosBridgeClientFactory", ws_connection: Any) -> None: + super(AsyncioRosBridgeProtocol, self).__init__() + self.factory = factory + self.ws = ws_connection + self._manual_disconnect = False + self._closed = False + self._send_queue = asyncio.Queue() + self._send_task = asyncio.create_task(self._send_loop()) + + def send_message(self, payload: bytes) -> None: + """Send an already-encoded ROS bridge message frame. + + Safe to call from any thread; the actual send is scheduled onto the + background loop. + """ + if self._closed: + return + loop = self.factory.manager.loop + if loop is None or loop.is_closed(): + return + loop.call_soon_threadsafe(self._enqueue_send, payload) + + def _enqueue_send(self, payload: bytes) -> None: + if not self._closed: + self._send_queue.put_nowait(payload) + + async def _send_loop(self) -> None: + try: + while True: + payload = await self._send_queue.get() + try: + await self._send_async(payload) + finally: + self._send_queue.task_done() + except asyncio.CancelledError: + pass + + async def _send_async(self, payload: bytes) -> None: + try: + if isinstance(payload, bytes): + payload = payload.decode("utf-8") + await self.ws.send(payload) + except Exception: + LOGGER.exception("Failed to send ROS bridge frame; connection likely dropped.") + + def send_close(self) -> None: + """Initiate a clean WebSocket close. + + Sets the manual-disconnect flag so the factory's reconnect supervisor + knows to stand down, then asks the loop to close the socket. + """ + self._manual_disconnect = True + if self._closed: + return + loop = self.factory.manager.loop + if loop is None or loop.is_closed(): + return + loop.call_soon_threadsafe(self._schedule_close) + + def _schedule_close(self) -> None: + asyncio.create_task(self._close_after_pending_sends()) + + async def _close_after_pending_sends(self) -> None: + await self._send_queue.join() + await self._close_async() + + async def _close_async(self) -> None: + try: + await self.ws.close() + except Exception: + LOGGER.debug("Error during WebSocket close (often harmless if already closed).", exc_info=True) + finally: + self._stop_sender() + + def _stop_sender(self) -> None: + if not self._send_task.done(): + self._send_task.cancel() + + +class AsyncioRosBridgeClientFactory(EventEmitterMixin): + """ROS Bridge client factory backed by the ``websockets`` library on asyncio. + + Mirrors the public surface of :class:`AutobahnRosBridgeClientFactory` so + that callers (``Ros`` and friends) don't care which transport is selected. + """ + + # Class-level reconnect tuning, kept as class attributes so the + # `set_initial_delay` / `set_max_delay` / `set_max_retries` classmethods + # behave like their autobahn counterparts. + initialDelay = DEFAULT_INITIAL_RECONNECT_DELAY + maxDelay = DEFAULT_MAX_RECONNECT_DELAY + factor = DEFAULT_RECONNECT_FACTOR + jitter = DEFAULT_RECONNECT_JITTER + maxRetries = DEFAULT_MAX_RECONNECT_RETRIES + compression = "deflate" + + def __init__(self, url: str, headers: Optional[dict] = None) -> None: + super(AsyncioRosBridgeClientFactory, self).__init__() + self._validate_url(url) + self._url = url + self._headers = headers + self._proto: Optional[AsyncioRosBridgeProtocol] = None + self._manager: Optional[AsyncioEventLoopManager] = None + # Lock guarding `_proto` reads/writes in on_ready / ready. Closes the + # TOCTOU race between checking ``_proto`` and registering a one-shot + # "ready" listener that exists in the autobahn factory. + self._proto_lock = threading.Lock() + # Background reconnect supervisor task, owned by the loop thread. + self._supervisor_task: Optional[asyncio.Task] = None + self._stop_supervisor = False + self._retry_count = 0 + + # ------------------------------------------------------------------ + # Public surface mirroring AutobahnRosBridgeClientFactory + # ------------------------------------------------------------------ + + @staticmethod + def _validate_url(url: str) -> None: + parsed = urlparse(url) + if parsed.scheme not in ("ws", "wss") or not parsed.netloc: + raise ValueError("WebSocket URL must use the ws:// or wss:// schema") + + def connect(self) -> None: + """Schedule the initial WebSocket connection on the background loop.""" + manager = self.manager + manager.run() # ensure background loop is up + loop = manager.loop + assert loop is not None + # Schedule the supervisor task; it owns the connect / reconnect loop. + loop.call_soon_threadsafe(self._launch_supervisor) + + def _launch_supervisor(self) -> None: + if self._supervisor_task is not None and not self._supervisor_task.done(): + return + self._stop_supervisor = False + self._supervisor_task = asyncio.create_task(self._supervise_connection()) + + @property + def is_connected(self) -> bool: + proto = self._proto + return proto is not None and not proto._closed + + def on_ready(self, callback: Callable[[AsyncioRosBridgeProtocol], None]) -> None: + """Register a callback to fire as soon as the connection is established. + + If the connection is already established, fires synchronously. Otherwise + registers a one-shot listener for the next "ready" event. Protected by a + lock so the TOCTOU race between checking ``_proto`` and registering the + listener can't drop callbacks the way the autobahn variant occasionally + did under reactor contention. + """ + proto_to_fire: Optional[AsyncioRosBridgeProtocol] = None + with self._proto_lock: + if self._proto is not None: + proto_to_fire = self._proto + else: + self.once("ready", callback) + if proto_to_fire is not None: + callback(proto_to_fire) + + def ready(self, proto: AsyncioRosBridgeProtocol) -> None: + """Mark the connection as ready and notify any pending listeners.""" + with self._proto_lock: + self._proto = proto + self._retry_count = 0 # reset backoff on every successful connect + self.emit("ready", proto) + + @classmethod + def create_url(cls, host: str, port: Optional[int] = None, is_secure: bool = False) -> str: + if port is None: + return host + scheme = "wss" if is_secure else "ws" + return "{}://{}:{}/".format(scheme, host, port) + + @classmethod + def set_max_delay(cls, max_delay: float) -> None: + """Set the maximum reconnect backoff delay in seconds (3600 by default).""" + cls.maxDelay = max_delay + + @classmethod + def set_initial_delay(cls, initial_delay: float) -> None: + """Set the initial reconnect backoff delay in seconds (1 by default).""" + cls.initialDelay = initial_delay + + @classmethod + def set_max_retries(cls, max_retries: Optional[int]) -> None: + """Set the max reconnect attempts when the connection is lost (unbounded by default).""" + cls.maxRetries = max_retries + + # ------------------------------------------------------------------ + # Supervisor coroutine: owns connect / reconnect with backoff + # ------------------------------------------------------------------ + + @property + def manager(self) -> "AsyncioEventLoopManager": + if self._manager is None: + self._manager = _get_shared_manager() + return self._manager + + async def _supervise_connection(self) -> None: + """Maintain an open connection, reconnecting with exponential backoff + if it drops unexpectedly. + + Stops when a manual disconnect is observed (``proto._manual_disconnect`` + set by ``send_close()``) or when ``maxRetries`` is exhausted. + """ + import random + + ws_connect, ConnectionClosed, InvalidStatus = _import_websockets() + delay = self.initialDelay + + while not self._stop_supervisor: + try: + LOGGER.debug("Connecting to %s...", self._url) + async with ws_connect( + self._url, + additional_headers=self._headers, + compression=self.compression, + open_timeout=None, # we don't impose our own connect timeout here + close_timeout=5, + ) as ws: + LOGGER.info("Connection to ROS ready.") + proto = AsyncioRosBridgeProtocol(self, ws) + self.ready(proto) + # Reset backoff on every successful connection. + delay = self.initialDelay + self._retry_count = 0 + try: + await self._receive_loop(proto) + finally: + self._on_connection_closed(proto) + if proto._manual_disconnect: + LOGGER.debug("Manual disconnect — supervisor exiting.") + self._stop_supervisor = True + break + except (ConnectionClosed, OSError, InvalidStatus) as exc: + LOGGER.debug("Connection attempt failed: %s", exc) + except asyncio.CancelledError: + raise + except Exception: # noqa: BLE001 + LOGGER.exception("Unexpected error in connection supervisor.") + + if self._stop_supervisor: + break + + self._retry_count += 1 + if self.maxRetries is not None and self._retry_count > self.maxRetries: + LOGGER.warning("Exceeded max reconnect retries (%s); supervisor exiting.", self.maxRetries) + break + + # Apply jittered exponential backoff, capped at maxDelay. + jittered = delay + (random.random() * 2 - 1) * delay * self.jitter + sleep_for = max(0.0, min(jittered, self.maxDelay)) + LOGGER.debug("Will retry connection in %.2fs (attempt %d).", sleep_for, self._retry_count) + try: + await asyncio.sleep(sleep_for) + except asyncio.CancelledError: + raise + delay = min(delay * self.factor, self.maxDelay) + + async def _receive_loop(self, proto: AsyncioRosBridgeProtocol) -> None: + """Pump incoming WebSocket frames into the protocol's ``on_message``.""" + ws_connect, ConnectionClosed, InvalidStatus = _import_websockets() + try: + async for payload in proto.ws: + if isinstance(payload, str): + payload = payload.encode("utf-8") + try: + proto.on_message(payload) + except Exception: + LOGGER.exception("Exception in user message handler; skipping frame.") + except ConnectionClosed: + LOGGER.info("WebSocket connection closed.") + + def _on_connection_closed(self, proto: AsyncioRosBridgeProtocol) -> None: + proto._closed = True + proto._stop_sender() + with self._proto_lock: + self._proto = None + # Notify listeners that the connection is gone. Matches the autobahn + # factory's "close" emit out of clientConnectionLost; downstream code + # (e.g. ``Ros.close()`` post-2.1 lifecycle work) uses this to know + # the socket is actually torn down. + try: + self.emit("close", proto) + except Exception: + LOGGER.exception("Error in user 'close' listener.") + + +class AsyncioEventLoopManager(object): + """Manage the asyncio event loop on a background thread. + + Mirrors :class:`TwistedEventLoopManager`'s surface so ``Ros`` doesn't need + transport-specific branches. There is only ever one of these per process + (held by ``_MANAGER_SINGLETON``), and the same loop is shared by every + ``Ros`` instance using the asyncio transport. + """ + + def __init__(self) -> None: + self.loop: Optional[asyncio.AbstractEventLoop] = None + self._thread: Optional[threading.Thread] = None + self._started = threading.Event() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def run(self) -> None: + """Spin up the background loop thread if it isn't running yet.""" + if self._thread is not None and self._thread.is_alive(): + return + self._started.clear() + self._thread = threading.Thread(target=self._run_thread, daemon=True, name="roslibpy-asyncio") + self._thread.start() + # Wait until the loop is actually running before returning so callers + # can safely use call_soon_threadsafe immediately after. + self._started.wait(timeout=5) + + def _run_thread(self) -> None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self._started.set() + try: + self.loop.run_forever() + finally: + self.loop.close() + + def run_forever(self) -> None: + """Run the loop on the calling thread (rarely used; matches twisted's run_forever).""" + if self.loop is None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self._started.set() + self.loop.run_forever() + + def terminate(self) -> None: + """Stop the background loop and join the thread. After this, the + manager is unusable in this process (matching twisted's one-shot + ``reactor.stop()`` semantics — though asyncio could in principle + be re-run, we keep parity). + """ + if self.loop is not None and self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=5) + + # ------------------------------------------------------------------ + # Scheduling primitives + # ------------------------------------------------------------------ + + def call_later(self, delay: float, callback: Callable[[], None]) -> None: + """Run ``callback`` on the loop after ``delay`` seconds.""" + loop = self.loop + if loop is None or loop.is_closed(): + return + loop.call_soon_threadsafe(loop.call_later, delay, callback) + + def call_in_thread(self, callback: Callable[[], None]) -> None: + """Run ``callback`` on a worker thread (off the loop).""" + loop = self.loop + if loop is None or loop.is_closed(): + return + loop.call_soon_threadsafe(lambda: loop.run_in_executor(None, callback)) + + def blocking_call_from_thread(self, callback: Callable[[Any], Any], timeout: Optional[float]) -> Any: + """Run ``callback(result_placeholder)`` on the loop, block until result is set. + + ``callback`` is a function that accepts an ``asyncio.Future``; the + callback is expected to register handlers that eventually resolve the + future. We then block on a stdlib ``threading.Event`` set by a done- + callback, so the caller doesn't have to know anything about asyncio. + + Mirrors ``TwistedEventLoopManager.blocking_call_from_thread``; the + ``result_placeholder`` argument it passes to ``callback`` is an + :class:`asyncio.Future` rather than a twisted ``Deferred``, but + ``get_inner_callback`` / ``get_inner_errback`` keep the shape uniform + for ``Ros`` consumers. + """ + loop = self.loop + if loop is None or loop.is_closed(): + raise RuntimeError("asyncio loop is not running") + + result_box: dict = {} + done = threading.Event() + + def _on_loop() -> None: + future = loop.create_future() + + def _on_future_done(fut: asyncio.Future) -> None: + try: + result_box["result"] = fut.result() + except Exception as exc: # noqa: BLE001 + result_box["error"] = exc + finally: + done.set() + + future.add_done_callback(_on_future_done) + try: + callback(future) + except Exception as exc: # noqa: BLE001 + result_box["error"] = exc + done.set() + + loop.call_soon_threadsafe(_on_loop) + if not done.wait(timeout=timeout if timeout else None): + raise RosTimeoutError("No service response received") + if "error" in result_box: + raise result_box["error"] + return result_box["result"] + + def get_inner_callback(self, result_placeholder: asyncio.Future) -> Callable[[Any], None]: + """Return a callback that resolves ``result_placeholder`` with success.""" + + def inner_callback(result: Any) -> None: + if not result_placeholder.done(): + result_placeholder.set_result({"result": result}) + + return inner_callback + + def get_inner_errback(self, result_placeholder: asyncio.Future) -> Callable[[Any], None]: + """Return an errback that resolves ``result_placeholder`` with an error.""" + + def inner_errback(error: Any) -> None: + if not result_placeholder.done(): + result_placeholder.set_result({"exception": error}) + + return inner_errback diff --git a/src/roslibpy/ros.py b/src/roslibpy/ros.py index 78a29c8..d473586 100644 --- a/src/roslibpy/ros.py +++ b/src/roslibpy/ros.py @@ -4,7 +4,7 @@ import threading from . import Message, Param, Service, ServiceRequest, Time -from .comm import RosBridgeClientFactory +from .comm import select_factory from .core import RosTimeoutError __all__ = ["Ros", "set_rosapi_timeout"] @@ -35,10 +35,25 @@ class Ros(object): headers (:obj:`dict`): Additional headers to include in the WebSocket connection. """ - def __init__(self, host, port=None, is_secure=False, headers=None): + def __init__(self, host, port=None, is_secure=False, headers=None, transport=None): + """Create a new connection manager. + + Args: + host (:obj:`str`): Name or IP address of the ROS bridge host, e.g. ``127.0.0.1``. + port (:obj:`int`): ROS bridge port, e.g. ``9090``. + is_secure (:obj:`bool`): ``True`` to use a secure web sockets connection. + headers (:obj:`dict`): Additional headers to include in the WebSocket connection. + transport (:obj:`str`, optional): Transport backend to use. One of + ``"twisted"``, ``"asyncio"``, ``"cli"``. If ``None`` (default), + resolves via the ``ROSLIBPY_TRANSPORT`` env var, + :func:`roslibpy.set_default_transport`, or the platform default + (``cli`` on IronPython, ``twisted`` elsewhere). See + ``roslibpy.comm.__init__`` for the full precedence rules. + """ self._id_counter = 0 - url = RosBridgeClientFactory.create_url(host, port, is_secure) - self.factory = RosBridgeClientFactory(url, headers=headers) + factory_cls = select_factory(transport) + url = factory_cls.create_url(host, port, is_secure) + self.factory = factory_cls(url, headers=headers) self.is_connecting = False self.connect() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8076485 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import sys + +import pytest + +ROS_TRANSPORTS = ("cli",) if sys.platform == "cli" else ("twisted", "asyncio") + + +@pytest.fixture(params=ROS_TRANSPORTS) +def ros_transport(request): + if request.param == "asyncio": + pytest.importorskip("websockets.asyncio.client") + return request.param diff --git a/tests/ros1/test_actionlib.py b/tests/ros1/test_actionlib.py index 54a88af..a9b8ca3 100644 --- a/tests/ros1/test_actionlib.py +++ b/tests/ros1/test_actionlib.py @@ -6,8 +6,8 @@ from roslibpy.ros1.actionlib import ActionClient, Goal, GoalStatus, SimpleActionServer -def test_action_success(): - ros = Ros("127.0.0.1", 9090) +def test_action_success(ros_transport): + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() server = SimpleActionServer(ros, "/test_action", "actionlib/TestAction") @@ -31,8 +31,8 @@ def execute(goal): ros.close() -def test_action_preemt(): - ros = Ros("127.0.0.1", 9090) +def test_action_preemt(ros_transport): + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() server = SimpleActionServer(ros, "/test_action", "actionlib/TestAction") diff --git a/tests/ros1/test_param.py b/tests/ros1/test_param.py index 32352eb..08a2d89 100644 --- a/tests/ros1/test_param.py +++ b/tests/ros1/test_param.py @@ -1,8 +1,8 @@ from roslibpy import Param, Ros -def test_param_manipulation(): - ros = Ros("127.0.0.1", 9090) +def test_param_manipulation(ros_transport): + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() param = Param(ros, "test_param") diff --git a/tests/ros1/test_ros.py b/tests/ros1/test_ros.py index aef72c5..1721085 100644 --- a/tests/ros1/test_ros.py +++ b/tests/ros1/test_ros.py @@ -10,8 +10,8 @@ url = "ws://%s:%d" % (host, port) -def test_reconnect_does_not_trigger_on_client_close(): - ros = Ros(host, port) +def test_reconnect_does_not_trigger_on_client_close(ros_transport): + ros = Ros(host, port, transport=ros_transport) ros.run() assert ros.is_connected, "ROS initially connected" @@ -25,22 +25,22 @@ def test_reconnect_does_not_trigger_on_client_close(): assert not ros.is_connecting, "Not trying to re-connect" -def test_connection(): - ros = Ros(host, port) +def test_connection(ros_transport): + ros = Ros(host, port, transport=ros_transport) ros.run() assert ros.is_connected ros.close() -def test_url_connection(): - ros = Ros(url) +def test_url_connection(ros_transport): + ros = Ros(url, transport=ros_transport) ros.run() assert ros.is_connected ros.close() -def test_closing_event(): - ros = Ros(url) +def test_closing_event(ros_transport): + ros = Ros(url, transport=ros_transport) ros.run() ctx = dict(closing_event_called=False, was_still_connected=False) @@ -60,12 +60,12 @@ def handle_closing(): assert closing_was_handled_synchronously_before_close -def test_multithreaded_connect_disconnect(): +def test_multithreaded_connect_disconnect(ros_transport): CONNECTIONS = 30 clients = [] def connect(clients): - ros = Ros(url) + ros = Ros(url, transport=ros_transport) ros.run() clients.append(ros) diff --git a/tests/ros1/test_rosapi.py b/tests/ros1/test_rosapi.py index 2f37ac1..a11b42d 100644 --- a/tests/ros1/test_rosapi.py +++ b/tests/ros1/test_rosapi.py @@ -9,9 +9,9 @@ url = "ws://%s:%d" % (host, port) -def test_rosapi_topics(): +def test_rosapi_topics(ros_transport): context = dict(wait=threading.Event(), result=None) - ros = Ros(host, port) + ros = Ros(host, port, transport=ros_transport) ros.run() def callback(topic_list): @@ -26,8 +26,8 @@ def callback(topic_list): ros.close() -def test_rosapi_topics_blocking(): - ros = Ros(host, port) +def test_rosapi_topics_blocking(ros_transport): + ros = Ros(host, port, transport=ros_transport) ros.run() topic_list = ros.get_topics() @@ -37,11 +37,11 @@ def test_rosapi_topics_blocking(): ros.close() -def test_connection_fails_when_missing_port(): +def test_connection_fails_when_missing_port(ros_transport): with pytest.raises(Exception): - Ros(host) + Ros(host, transport=ros_transport) -def test_connection_fails_when_schema_not_ws(): +def test_connection_fails_when_schema_not_ws(ros_transport): with pytest.raises(Exception): - Ros("http://%s:%d" % (host, port)) + Ros("http://%s:%d" % (host, port), transport=ros_transport) diff --git a/tests/ros1/test_service.py b/tests/ros1/test_service.py index 64bd165..d4b47b0 100644 --- a/tests/ros1/test_service.py +++ b/tests/ros1/test_service.py @@ -5,8 +5,8 @@ from roslibpy import Ros, Service, ServiceRequest -def test_add_two_ints_service(): - ros = Ros("127.0.0.1", 9090) +def test_add_two_ints_service(ros_transport): + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() def add_two_ints(request, response): @@ -29,8 +29,8 @@ def add_two_ints(request, response): ros.close() -def test_empty_service(): - ros = Ros("127.0.0.1", 9090) +def test_empty_service(ros_transport): + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() service = Service(ros, "/test_empty_service", "std_srvs/Empty") diff --git a/tests/ros1/test_tf.py b/tests/ros1/test_tf.py index bbb922d..40e8f1c 100644 --- a/tests/ros1/test_tf.py +++ b/tests/ros1/test_tf.py @@ -4,9 +4,9 @@ from roslibpy.tf import TFClient -def test_tf_test(): +def test_tf_test(ros_transport): context = dict(wait=threading.Event(), counter=0) - ros = Ros("127.0.0.1", 9090) + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() tf_client = TFClient(ros, fixed_frame="world") diff --git a/tests/ros1/test_topic.py b/tests/ros1/test_topic.py index 97ba430..9248d93 100644 --- a/tests/ros1/test_topic.py +++ b/tests/ros1/test_topic.py @@ -6,10 +6,10 @@ from roslibpy import Header, Message, Ros, Time, Topic -def test_topic_pubsub(): +def test_topic_pubsub(ros_transport): context = dict(wait=threading.Event(), counter=0) - ros = Ros("127.0.0.1", 9090) + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() listener = Topic(ros, "/chatter", "std_msgs/String") @@ -50,10 +50,10 @@ def start_receiving(): ros.close() -def test_topic_with_header(): +def test_topic_with_header(ros_transport): context = dict(wait=threading.Event()) - ros = Ros("127.0.0.1", 9090) + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() listener = Topic(ros, "/points", "geometry_msgs/PointStamped") diff --git a/tests/ros2/test_actions.py b/tests/ros2/test_actions.py index 6f02f08..115dba6 100644 --- a/tests/ros2/test_actions.py +++ b/tests/ros2/test_actions.py @@ -5,8 +5,8 @@ from roslibpy import ActionClient, Goal, GoalStatus, Ros -def test_fibonacci(): - ros = Ros("127.0.0.1", 9090) +def test_fibonacci(ros_transport): + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() action = ActionClient(ros, "/fibonacci", "example_interfaces/action/Fibonacci") @@ -34,8 +34,8 @@ def on_error(error): ros.close() -def test_cancel(): - ros = Ros("127.0.0.1", 9090) +def test_cancel(ros_transport): + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() action = ActionClient(ros, "/fibonacci", "example_interfaces/action/Fibonacci") diff --git a/tests/ros2/test_ros.py b/tests/ros2/test_ros.py index aef72c5..1721085 100644 --- a/tests/ros2/test_ros.py +++ b/tests/ros2/test_ros.py @@ -10,8 +10,8 @@ url = "ws://%s:%d" % (host, port) -def test_reconnect_does_not_trigger_on_client_close(): - ros = Ros(host, port) +def test_reconnect_does_not_trigger_on_client_close(ros_transport): + ros = Ros(host, port, transport=ros_transport) ros.run() assert ros.is_connected, "ROS initially connected" @@ -25,22 +25,22 @@ def test_reconnect_does_not_trigger_on_client_close(): assert not ros.is_connecting, "Not trying to re-connect" -def test_connection(): - ros = Ros(host, port) +def test_connection(ros_transport): + ros = Ros(host, port, transport=ros_transport) ros.run() assert ros.is_connected ros.close() -def test_url_connection(): - ros = Ros(url) +def test_url_connection(ros_transport): + ros = Ros(url, transport=ros_transport) ros.run() assert ros.is_connected ros.close() -def test_closing_event(): - ros = Ros(url) +def test_closing_event(ros_transport): + ros = Ros(url, transport=ros_transport) ros.run() ctx = dict(closing_event_called=False, was_still_connected=False) @@ -60,12 +60,12 @@ def handle_closing(): assert closing_was_handled_synchronously_before_close -def test_multithreaded_connect_disconnect(): +def test_multithreaded_connect_disconnect(ros_transport): CONNECTIONS = 30 clients = [] def connect(clients): - ros = Ros(url) + ros = Ros(url, transport=ros_transport) ros.run() clients.append(ros) diff --git a/tests/ros2/test_rosapi.py b/tests/ros2/test_rosapi.py index 2f37ac1..a11b42d 100644 --- a/tests/ros2/test_rosapi.py +++ b/tests/ros2/test_rosapi.py @@ -9,9 +9,9 @@ url = "ws://%s:%d" % (host, port) -def test_rosapi_topics(): +def test_rosapi_topics(ros_transport): context = dict(wait=threading.Event(), result=None) - ros = Ros(host, port) + ros = Ros(host, port, transport=ros_transport) ros.run() def callback(topic_list): @@ -26,8 +26,8 @@ def callback(topic_list): ros.close() -def test_rosapi_topics_blocking(): - ros = Ros(host, port) +def test_rosapi_topics_blocking(ros_transport): + ros = Ros(host, port, transport=ros_transport) ros.run() topic_list = ros.get_topics() @@ -37,11 +37,11 @@ def test_rosapi_topics_blocking(): ros.close() -def test_connection_fails_when_missing_port(): +def test_connection_fails_when_missing_port(ros_transport): with pytest.raises(Exception): - Ros(host) + Ros(host, transport=ros_transport) -def test_connection_fails_when_schema_not_ws(): +def test_connection_fails_when_schema_not_ws(ros_transport): with pytest.raises(Exception): - Ros("http://%s:%d" % (host, port)) + Ros("http://%s:%d" % (host, port), transport=ros_transport) diff --git a/tests/ros2/test_topic.py b/tests/ros2/test_topic.py index b899176..1cd2879 100644 --- a/tests/ros2/test_topic.py +++ b/tests/ros2/test_topic.py @@ -7,10 +7,10 @@ from roslibpy.ros2 import Header -def test_topic_pubsub(): +def test_topic_pubsub(ros_transport): context = dict(wait=threading.Event(), counter=0) - ros = Ros("127.0.0.1", 9090) + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() listener = Topic(ros, "/chatter", "std_msgs/String") @@ -51,10 +51,10 @@ def start_receiving(): ros.close() -def test_topic_with_header(): +def test_topic_with_header(ros_transport): context = dict(wait=threading.Event()) - ros = Ros("127.0.0.1", 9090) + ros = Ros("127.0.0.1", 9090, transport=ros_transport) ros.run() listener = Topic(ros, "/points", "geometry_msgs/PointStamped") diff --git a/tests/test_transport.py b/tests/test_transport.py new file mode 100644 index 0000000..0f5b497 --- /dev/null +++ b/tests/test_transport.py @@ -0,0 +1,94 @@ +import sys + +import pytest + +from roslibpy import Ros, set_default_transport +from roslibpy.comm import ( + TRANSPORT_ASYNCIO, + TRANSPORT_CLI, + TRANSPORT_TWISTED, + _resolve_transport, +) + +PLATFORM_DEFAULT = TRANSPORT_CLI if sys.platform == "cli" else TRANSPORT_TWISTED + + +@pytest.fixture(autouse=True) +def reset_transport(monkeypatch): + monkeypatch.delenv("ROSLIBPY_TRANSPORT", raising=False) + set_default_transport(PLATFORM_DEFAULT) + yield + set_default_transport(PLATFORM_DEFAULT) + + +def test_transport_can_be_selected_from_environment(monkeypatch): + monkeypatch.setenv("ROSLIBPY_TRANSPORT", TRANSPORT_ASYNCIO) + + assert _resolve_transport() == TRANSPORT_ASYNCIO + + +def test_transport_can_be_selected_as_process_default(): + set_default_transport(TRANSPORT_ASYNCIO) + + assert _resolve_transport() == TRANSPORT_ASYNCIO + + +def test_explicit_transport_takes_precedence(monkeypatch): + monkeypatch.setenv("ROSLIBPY_TRANSPORT", TRANSPORT_TWISTED) + set_default_transport(TRANSPORT_TWISTED) + + assert _resolve_transport(TRANSPORT_ASYNCIO) == TRANSPORT_ASYNCIO + + +def test_ros_passes_explicit_transport_to_factory_selector(monkeypatch): + selected = [] + + class Factory(object): + @classmethod + def create_url(cls, host, port=None, is_secure=False): + return "ws://127.0.0.1:9090" + + def __init__(self, url, headers=None): + self.is_connected = False + + def connect(self): + pass + + def on_ready(self, callback): + pass + + def select_factory(transport): + selected.append(transport) + return Factory + + monkeypatch.setattr("roslibpy.ros.select_factory", select_factory) + + Ros("127.0.0.1", 9090, transport=TRANSPORT_ASYNCIO) + + assert selected == [TRANSPORT_ASYNCIO] + + +def test_asyncio_protocol_sends_text_frames(): + asyncio = pytest.importorskip("asyncio") + + from roslibpy.comm.comm_asyncio import AsyncioRosBridgeProtocol + + class WebSocket(object): + def __init__(self): + self.payload = None + + async def send(self, payload): + self.payload = payload + + async def run_test(): + websocket = WebSocket() + protocol = AsyncioRosBridgeProtocol(object(), websocket) + + await protocol._send_async(b'{"op": "call_service"}') + protocol._stop_sender() + return websocket.payload + + payload = asyncio.run(run_test()) + + assert payload == '{"op": "call_service"}' + assert isinstance(payload, str) diff --git a/tests/test_ws_headers.py b/tests/test_ws_headers.py index 5e3d0a1..263893e 100644 --- a/tests/test_ws_headers.py +++ b/tests/test_ws_headers.py @@ -1,17 +1,14 @@ -from __future__ import print_function - -import asyncio import threading import time -import websockets +import pytest from roslibpy import Ros -headers = { - 'cookie': 'token=rosbridge', - 'authorization': 'Some auth' -} +asyncio = pytest.importorskip("asyncio") +websockets = pytest.importorskip("websockets") + +headers = {"cookie": "token=rosbridge", "authorization": "Some auth"} async def websocket_handler(websocket, path): @@ -22,7 +19,7 @@ async def websocket_handler(websocket, path): async def start_server(stop_event): - server = await websockets.serve(websocket_handler, '127.0.0.1', 9000) + server = await websockets.serve(websocket_handler, "127.0.0.1", 9000) await stop_event.wait() server.close() await server.wait_closed() @@ -32,13 +29,13 @@ def run_server(stop_event): asyncio.run(start_server(stop_event)) -def run_client(): - client = Ros('127.0.0.1', 9000, headers=headers) +def run_client(ros_transport): + client = Ros("127.0.0.1", 9000, headers=headers, transport=ros_transport) client.run() client.close() -def test_websocket_headers(): +def test_websocket_headers(ros_transport): server_stop_event = asyncio.Event() stop_event = threading.Event() @@ -47,7 +44,7 @@ def test_websocket_headers(): time.sleep(1) # Give the server time to start - client_thread = threading.Thread(target=run_client) + client_thread = threading.Thread(target=run_client, args=(ros_transport,)) client_thread.start() # Wait for the client thread to finish or timeout after 10 seconds