From b8d691fe231df13b435f9c78fec3073f0264988a Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 21 Aug 2025 02:08:31 +0000 Subject: [PATCH 01/15] feat: Introduce get_trajectory function and enhance gather.py typing --- src/art/gather.py | 24 +- src/art/get_trajectory.py | 6 + tests/unit/test_gather_trajectory.ipynb | 43 +++ tests/unit/test_request_observability.ipynb | 337 ++++++++++++++++++++ 4 files changed, 403 insertions(+), 7 deletions(-) create mode 100644 src/art/get_trajectory.py create mode 100644 tests/unit/test_gather_trajectory.ipynb create mode 100644 tests/unit/test_request_observability.ipynb diff --git a/src/art/gather.py b/src/art/gather.py index 89eea5b4..1cc259ed 100644 --- a/src/art/gather.py +++ b/src/art/gather.py @@ -3,7 +3,16 @@ import contextvars from collections import Counter from dataclasses import dataclass, field -from typing import Awaitable, Callable, Iterable, Iterator, Literal, overload +from typing import ( + Any, + Awaitable, + Callable, + Coroutine, + Iterable, + Iterator, + Literal, + overload, +) from openai.types.chat.chat_completion import Choice from tqdm import auto as tqdm @@ -18,10 +27,12 @@ async def gather_trajectory_groups( pbar_total_completion_tokens: bool = True, max_exceptions: int | float = 0, max_metrics: int | None = None, - after_each: Callable[ - [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]] - ] - | None = None, + after_each: ( + Callable[ + [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]] + ] + | None + ) = None, ) -> list[TrajectoryGroup]: groups = list(groups) context = GatherContext( @@ -182,8 +193,7 @@ def record_metrics(context: "GatherContext", trajectory: Trajectory) -> None: ] if logprobs: trajectory.metrics["completion_tokens"] = sum( - len(l.content or l.refusal or []) - for l in logprobs # noqa: E741 + len(l.content or l.refusal or []) for l in logprobs # noqa: E741 ) / len(logprobs) context.metric_sums["reward"] += trajectory.reward # type: ignore context.metric_divisors["reward"] += 1 diff --git a/src/art/get_trajectory.py b/src/art/get_trajectory.py new file mode 100644 index 00000000..c23fd5a5 --- /dev/null +++ b/src/art/get_trajectory.py @@ -0,0 +1,6 @@ +from typing import Any, Coroutine + +from .trajectories import Trajectory + + +async def get_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory: ... 
diff --git a/tests/unit/test_gather_trajectory.ipynb b/tests/unit/test_gather_trajectory.ipynb new file mode 100644 index 00000000..ad6fd6f2 --- /dev/null +++ b/tests/unit/test_gather_trajectory.ipynb @@ -0,0 +1,43 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "e24171fe", + "metadata": {}, + "outputs": [], + "source": [ + "from aiohttp import web\n", + "\n", + "\n", + "async def handler(request: web.Request) -> web.Response:\n", + " body = await request.read()\n", + " return web.Response(body=body)\n", + "\n", + "\n", + "app = web.Application()\n", + "app.router.add_route(\"POST\", \"/{p:.*}\", handler)\n", + "\n", + "# Non-blocking async runner\n", + "runner = web.AppRunner(app)\n", + "await runner.setup()\n", + "site = web.TCPSite(runner, 'localhost', 8888)\n", + "await site.start()\n", + "print(\"Server running on http://localhost:8888\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/unit/test_request_observability.ipynb b/tests/unit/test_request_observability.ipynb new file mode 100644 index 00000000..21f2cec3 --- /dev/null +++ b/tests/unit/test_request_observability.ipynb @@ -0,0 +1,337 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "cfb0cf3c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Server running on http://localhost:8888\n" + ] + } + ], + "source": [ + "from aiohttp import web\n", + "\n", + "\n", + "async def h(r: web.Request) -> web.Response:\n", + " return web.Response(text=f\"{r.method} {r.path}\")\n", + "\n", + "\n", + "app = web.Application()\n", + "app.router.add_route(\"*\", \"/{p:.*}\", h)\n", + "\n", + "# Non-blocking async runner\n", + "runner = web.AppRunner(app)\n", + "await runner.setup()\n", + "site = web.TCPSite(runner, 'localhost', 8888)\n", + "await site.start()\n", + "print(\"Server running on http://localhost:8888\")\n", + "# Server is now running in the background!\n", + "# To stop later: await runner.cleanup()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f1ce74a9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Making test requests...\n", + "Response 1: GET /\n", + "Response 2: POST /test\n", + "Response 3: GET /path/to/resource\n" + ] + } + ], + "source": [ + "# Test the callback by making HTTP requests\n", + "import httpx\n", + "\n", + "async with httpx.AsyncClient() as client:\n", + " # Make multiple requests to see callbacks in action\n", + " print(\"Making test requests...\")\n", + " \n", + " r1 = await client.get(\"http://localhost:8888/\")\n", + " print(f\"Response 1: {r1.text}\")\n", + " \n", + " r2 = await client.post(\"http://localhost:8888/test\", json={\"data\": \"test\"})\n", + " print(f\"Response 2: {r2.text}\")\n", + " \n", + " r3 = await client.get(\"http://localhost:8888/path/to/resource\")\n", + " print(f\"Response 3: {r3.text}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "948f1bcd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Response aclosed: 200 OK\n" + ] + } + ], + "source": [ + "import httpx._models\n", + "\n", + "original_close = httpx._models.Response.close\n", + "original_aclose = httpx._models.Response.aclose\n", + "\n", + "def 
patched_close(self: httpx._models.Response) -> None:\n", + " original_close(self)\n", + " print(f\"Response closed: {self.status_code} {self.reason_phrase}\")\n", + "\n", + "async def patched_aclose(self: httpx._models.Response) -> None:\n", + " await original_aclose(self)\n", + " print(f\"Response aclosed: {self.status_code} {self.reason_phrase}\")\n", + "\n", + "httpx._models.Response.close = patched_close\n", + "httpx._models.Response.aclose = patched_aclose\n", + "\n", + "\n", + "async with httpx.AsyncClient() as client:\n", + " r = await client.get(\"http://localhost:8888/\")\n", + " r.content" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f97afafe", + "metadata": {}, + "outputs": [], + "source": [ + "import httpx\n", + "import httpx._models\n", + "from typing import Callable, List, Any\n", + "import functools\n", + "\n", + "# Storage for response callbacks\n", + "response_callbacks: List[Callable] = []\n", + "\n", + "def add_response_callback(callback: Callable) -> None:\n", + " \"\"\"Add a callback to be called when a Response is created.\"\"\"\n", + " response_callbacks.append(callback)\n", + "\n", + "def clear_response_callbacks() -> None:\n", + " \"\"\"Clear all response callbacks.\"\"\"\n", + " response_callbacks.clear()\n", + "\n", + "# Store the original Response.__init__\n", + "original_response_init = httpx._models.Response.__init__\n", + "\n", + "@functools.wraps(original_response_init)\n", + "def patched_response_init(self, *args, **kwargs):\n", + " \"\"\"Patched Response.__init__ that calls callbacks.\"\"\"\n", + " # Call the original init first\n", + " original_response_init(self, *args, **kwargs)\n", + " \n", + " # Call all registered callbacks\n", + " for callback in response_callbacks:\n", + " try:\n", + " callback(self)\n", + " except Exception as e:\n", + " print(f\"Error in response callback: {e}\")\n", + "\n", + "# Apply the patch\n", + "httpx._models.Response.__init__ = patched_response_init\n", + "print(\"โœ… httpx Response class patched with callback support\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2c201a3", + "metadata": {}, + "outputs": [], + "source": [ + "# Example: Create a callback to track all responses\n", + "from datetime import datetime\n", + "\n", + "response_log = []\n", + "\n", + "def log_response(response):\n", + " \"\"\"Callback that logs response details.\"\"\"\n", + " log_entry = {\n", + " 'timestamp': datetime.now().isoformat(),\n", + " 'status_code': response.status_code,\n", + " 'url': str(response.url) if hasattr(response, '_request') and response._request else 'N/A',\n", + " 'headers': dict(response.headers),\n", + " 'http_version': response.http_version,\n", + " 'reason_phrase': response.reason_phrase\n", + " }\n", + " response_log.append(log_entry)\n", + " print(f\"๐Ÿ“Š Response captured: {response.status_code} {response.reason_phrase}\")\n", + "\n", + "# Register the callback\n", + "add_response_callback(log_response)\n", + "print(\"Callback registered - will log all HTTP responses\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99a56e15", + "metadata": {}, + "outputs": [], + "source": [ + "# Display the captured response log\n", + "import json\n", + "\n", + "print(f\"\\n๐Ÿ“‹ Total responses captured: {len(response_log)}\")\n", + "print(\"-\" * 60)\n", + "\n", + "for i, entry in enumerate(response_log, 1):\n", + " print(f\"\\n๐Ÿ“Œ Response #{i}:\")\n", + " print(f\" Timestamp: {entry['timestamp']}\")\n", + " print(f\" Status: 
{entry['status_code']} {entry['reason_phrase']}\")\n", + " print(f\" URL: {entry['url']}\")\n", + " print(f\" HTTP Version: {entry['http_version']}\")\n", + " print(f\" Headers (sample): {dict(list(entry['headers'].items())[:3])}...\")\n", + "\n", + "# You can also get more detailed info\n", + "print(f\"\\n๐Ÿ” Full log available in 'response_log' variable\")\n", + "print(f\"Example: response_log[0] = {json.dumps(response_log[0] if response_log else {}, indent=2)[:200]}...\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ef7a5f2", + "metadata": {}, + "outputs": [], + "source": [ + "# Advanced example: Multiple callbacks with different purposes\n", + "import time\n", + "from collections import defaultdict\n", + "\n", + "# Callback 1: Count responses by status code\n", + "status_counter = defaultdict(int)\n", + "\n", + "def count_status_codes(response):\n", + " status_counter[response.status_code] += 1\n", + "\n", + "# Callback 2: Track response times\n", + "response_times = []\n", + "start_times = {}\n", + "\n", + "def track_response_time(response):\n", + " # Note: This is a simplified example. In real scenarios, \n", + " # you'd track request start time differently\n", + " response_times.append({\n", + " 'url': str(response.url) if hasattr(response, '_request') and response._request else 'N/A',\n", + " 'status': response.status_code,\n", + " 'timestamp': time.time()\n", + " })\n", + "\n", + "# Callback 3: Alert on errors\n", + "def alert_on_errors(response):\n", + " if response.status_code >= 400:\n", + " print(f\"โš ๏ธ ERROR RESPONSE: {response.status_code} for {response.url if hasattr(response, '_request') and response._request else 'unknown URL'}\")\n", + "\n", + "# Register all callbacks\n", + "add_response_callback(count_status_codes)\n", + "add_response_callback(track_response_time)\n", + "add_response_callback(alert_on_errors)\n", + "\n", + "print(\"โœ… Advanced callbacks registered!\")\n", + "print(f\"Total active callbacks: {len(response_callbacks)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00ce67f9", + "metadata": {}, + "outputs": [], + "source": [ + "# Test advanced callbacks and demonstrate cleanup\n", + "async with httpx.AsyncClient() as client:\n", + " print(\"Testing with all callbacks active...\")\n", + " \n", + " # Make requests including one that will cause an error (404)\n", + " await client.get(\"http://localhost:8888/\")\n", + " await client.get(\"http://localhost:8888/api/users\")\n", + " await client.post(\"http://localhost:8888/api/data\", json={\"test\": \"data\"})\n", + " \n", + " # This will trigger a 404 (path doesn't exist on most servers)\n", + " # But our test server responds to all paths, so let's simulate an error scenario\n", + " # by making a request to a non-existent server\n", + " try:\n", + " await client.get(\"http://localhost:9999/\", timeout=1.0)\n", + " except Exception as e:\n", + " print(f\"Expected error: {e}\")\n", + "\n", + "print(f\"\\n๐Ÿ“Š Status code counts: {dict(status_counter)}\")\n", + "print(f\"๐Ÿ“ˆ Total responses tracked: {len(response_times)}\")\n", + "\n", + "# Demonstrate cleanup\n", + "print(f\"\\n๐Ÿงน Cleaning up...\")\n", + "print(f\"Before cleanup: {len(response_callbacks)} callbacks\")\n", + "clear_response_callbacks()\n", + "print(f\"After cleanup: {len(response_callbacks)} callbacks\")\n", + "\n", + "# You can also restore the original Response.__init__ if needed\n", + "# httpx._models.Response.__init__ = original_response_init\n", + "# print(\"โœ… Original 
Response.__init__ restored\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f646052c", + "metadata": {}, + "outputs": [], + "source": [ + "# Optional: Restore original behavior completely\n", + "def restore_original_response():\n", + " \"\"\"Restore the original Response.__init__ method.\"\"\"\n", + " httpx._models.Response.__init__ = original_response_init\n", + " clear_response_callbacks()\n", + " print(\"โœ… Original httpx Response behavior restored\")\n", + " print(\" - Original __init__ method restored\")\n", + " print(\" - All callbacks cleared\")\n", + "\n", + "# Uncomment to restore:\n", + "# restore_original_response()\n", + "\n", + "# Or keep the patched version for observability\n", + "print(\"๐Ÿ’ก Patch remains active. Call restore_original_response() to revert.\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From d1714de772af8af5f7a5b244a745aaf33a93ebd1 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 21 Aug 2025 02:22:38 +0000 Subject: [PATCH 02/15] refactor: Clean up code formatting and update server port in test notebooks --- src/art/gather.py | 3 +- tests/unit/test_gather_trajectory.ipynb | 31 +++- tests/unit/test_request_observability.ipynb | 177 +++++++++++++++----- 3 files changed, 165 insertions(+), 46 deletions(-) diff --git a/src/art/gather.py b/src/art/gather.py index 1cc259ed..ff9adbb3 100644 --- a/src/art/gather.py +++ b/src/art/gather.py @@ -193,7 +193,8 @@ def record_metrics(context: "GatherContext", trajectory: Trajectory) -> None: ] if logprobs: trajectory.metrics["completion_tokens"] = sum( - len(l.content or l.refusal or []) for l in logprobs # noqa: E741 + len(l.content or l.refusal or []) + for l in logprobs # noqa: E741 ) / len(logprobs) context.metric_sums["reward"] += trajectory.reward # type: ignore context.metric_divisors["reward"] += 1 diff --git a/tests/unit/test_gather_trajectory.ipynb b/tests/unit/test_gather_trajectory.ipynb index ad6fd6f2..c8085d2f 100644 --- a/tests/unit/test_gather_trajectory.ipynb +++ b/tests/unit/test_gather_trajectory.ipynb @@ -2,10 +2,18 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "e24171fe", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Server running on http://localhost:8889\n" + ] + } + ], "source": [ "from aiohttp import web\n", "\n", @@ -21,9 +29,16 @@ "# Non-blocking async runner\n", "runner = web.AppRunner(app)\n", "await runner.setup()\n", - "site = web.TCPSite(runner, 'localhost', 8888)\n", + "site = web.TCPSite(runner, \"localhost\", 8889)\n", "await site.start()\n", - "print(\"Server running on http://localhost:8888\")" + "print(\"Server running on http://localhost:8889\")\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" ] } ], @@ -34,7 +49,15 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", "version": "3.10.13" } }, diff --git a/tests/unit/test_request_observability.ipynb 
b/tests/unit/test_request_observability.ipynb index 21f2cec3..be906abd 100644 --- a/tests/unit/test_request_observability.ipynb +++ b/tests/unit/test_request_observability.ipynb @@ -28,7 +28,7 @@ "# Non-blocking async runner\n", "runner = web.AppRunner(app)\n", "await runner.setup()\n", - "site = web.TCPSite(runner, 'localhost', 8888)\n", + "site = web.TCPSite(runner, \"localhost\", 8888)\n", "await site.start()\n", "print(\"Server running on http://localhost:8888\")\n", "# Server is now running in the background!\n", @@ -59,20 +59,20 @@ "async with httpx.AsyncClient() as client:\n", " # Make multiple requests to see callbacks in action\n", " print(\"Making test requests...\")\n", - " \n", + "\n", " r1 = await client.get(\"http://localhost:8888/\")\n", " print(f\"Response 1: {r1.text}\")\n", - " \n", + "\n", " r2 = await client.post(\"http://localhost:8888/test\", json={\"data\": \"test\"})\n", " print(f\"Response 2: {r2.text}\")\n", - " \n", + "\n", " r3 = await client.get(\"http://localhost:8888/path/to/resource\")\n", " print(f\"Response 3: {r3.text}\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "948f1bcd", "metadata": {}, "outputs": [ @@ -90,14 +90,17 @@ "original_close = httpx._models.Response.close\n", "original_aclose = httpx._models.Response.aclose\n", "\n", + "\n", "def patched_close(self: httpx._models.Response) -> None:\n", " original_close(self)\n", " print(f\"Response closed: {self.status_code} {self.reason_phrase}\")\n", "\n", + "\n", "async def patched_aclose(self: httpx._models.Response) -> None:\n", " await original_aclose(self)\n", " print(f\"Response aclosed: {self.status_code} {self.reason_phrase}\")\n", "\n", + "\n", "httpx._models.Response.close = patched_close\n", "httpx._models.Response.aclose = patched_aclose\n", "\n", @@ -109,36 +112,49 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "f97afafe", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "โœ… httpx Response class patched with callback support\n" + ] + } + ], "source": [ + "import functools\n", + "from typing import Any, Callable, List\n", + "\n", "import httpx\n", "import httpx._models\n", - "from typing import Callable, List, Any\n", - "import functools\n", "\n", "# Storage for response callbacks\n", "response_callbacks: List[Callable] = []\n", "\n", + "\n", "def add_response_callback(callback: Callable) -> None:\n", " \"\"\"Add a callback to be called when a Response is created.\"\"\"\n", " response_callbacks.append(callback)\n", "\n", + "\n", "def clear_response_callbacks() -> None:\n", " \"\"\"Clear all response callbacks.\"\"\"\n", " response_callbacks.clear()\n", "\n", + "\n", "# Store the original Response.__init__\n", "original_response_init = httpx._models.Response.__init__\n", "\n", + "\n", "@functools.wraps(original_response_init)\n", "def patched_response_init(self, *args, **kwargs):\n", " \"\"\"Patched Response.__init__ that calls callbacks.\"\"\"\n", " # Call the original init first\n", " original_response_init(self, *args, **kwargs)\n", - " \n", + "\n", " # Call all registered callbacks\n", " for callback in response_callbacks:\n", " try:\n", @@ -146,47 +162,73 @@ " except Exception as e:\n", " print(f\"Error in response callback: {e}\")\n", "\n", + "\n", "# Apply the patch\n", "httpx._models.Response.__init__ = patched_response_init\n", - "print(\"โœ… httpx Response class patched with callback support\")\n" + "print(\"โœ… httpx Response class patched 
with callback support\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "c2c201a3", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Callback registered - will log all HTTP responses\n" + ] + } + ], "source": [ "# Example: Create a callback to track all responses\n", "from datetime import datetime\n", "\n", "response_log = []\n", "\n", + "\n", "def log_response(response):\n", " \"\"\"Callback that logs response details.\"\"\"\n", " log_entry = {\n", - " 'timestamp': datetime.now().isoformat(),\n", - " 'status_code': response.status_code,\n", - " 'url': str(response.url) if hasattr(response, '_request') and response._request else 'N/A',\n", - " 'headers': dict(response.headers),\n", - " 'http_version': response.http_version,\n", - " 'reason_phrase': response.reason_phrase\n", + " \"timestamp\": datetime.now().isoformat(),\n", + " \"status_code\": response.status_code,\n", + " \"url\": str(response.url)\n", + " if hasattr(response, \"_request\") and response._request\n", + " else \"N/A\",\n", + " \"headers\": dict(response.headers),\n", + " \"http_version\": response.http_version,\n", + " \"reason_phrase\": response.reason_phrase,\n", " }\n", " response_log.append(log_entry)\n", " print(f\"๐Ÿ“Š Response captured: {response.status_code} {response.reason_phrase}\")\n", "\n", + "\n", "# Register the callback\n", "add_response_callback(log_response)\n", - "print(\"Callback registered - will log all HTTP responses\")\n" + "print(\"Callback registered - will log all HTTP responses\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "99a56e15", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "๐Ÿ“‹ Total responses captured: 0\n", + "------------------------------------------------------------\n", + "\n", + "๐Ÿ” Full log available in 'response_log' variable\n", + "Example: response_log[0] = {}...\n" + ] + } + ], "source": [ "# Display the captured response log\n", "import json\n", @@ -204,15 +246,26 @@ "\n", "# You can also get more detailed info\n", "print(f\"\\n๐Ÿ” Full log available in 'response_log' variable\")\n", - "print(f\"Example: response_log[0] = {json.dumps(response_log[0] if response_log else {}, indent=2)[:200]}...\")\n" + "print(\n", + " f\"Example: response_log[0] = {json.dumps(response_log[0] if response_log else {}, indent=2)[:200]}...\"\n", + ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "9ef7a5f2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "โœ… Advanced callbacks registered!\n", + "Total active callbacks: 4\n" + ] + } + ], "source": [ "# Advanced example: Multiple callbacks with different purposes\n", "import time\n", @@ -221,26 +274,37 @@ "# Callback 1: Count responses by status code\n", "status_counter = defaultdict(int)\n", "\n", + "\n", "def count_status_codes(response):\n", " status_counter[response.status_code] += 1\n", "\n", + "\n", "# Callback 2: Track response times\n", "response_times = []\n", "start_times = {}\n", "\n", + "\n", "def track_response_time(response):\n", - " # Note: This is a simplified example. In real scenarios, \n", + " # Note: This is a simplified example. 
In real scenarios,\n", " # you'd track request start time differently\n", - " response_times.append({\n", - " 'url': str(response.url) if hasattr(response, '_request') and response._request else 'N/A',\n", - " 'status': response.status_code,\n", - " 'timestamp': time.time()\n", - " })\n", + " response_times.append(\n", + " {\n", + " \"url\": str(response.url)\n", + " if hasattr(response, \"_request\") and response._request\n", + " else \"N/A\",\n", + " \"status\": response.status_code,\n", + " \"timestamp\": time.time(),\n", + " }\n", + " )\n", + "\n", "\n", "# Callback 3: Alert on errors\n", "def alert_on_errors(response):\n", " if response.status_code >= 400:\n", - " print(f\"โš ๏ธ ERROR RESPONSE: {response.status_code} for {response.url if hasattr(response, '_request') and response._request else 'unknown URL'}\")\n", + " print(\n", + " f\"โš ๏ธ ERROR RESPONSE: {response.status_code} for {response.url if hasattr(response, '_request') and response._request else 'unknown URL'}\"\n", + " )\n", + "\n", "\n", "# Register all callbacks\n", "add_response_callback(count_status_codes)\n", @@ -248,25 +312,47 @@ "add_response_callback(alert_on_errors)\n", "\n", "print(\"โœ… Advanced callbacks registered!\")\n", - "print(f\"Total active callbacks: {len(response_callbacks)}\")\n" + "print(f\"Total active callbacks: {len(response_callbacks)}\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "00ce67f9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing with all callbacks active...\n", + "๐Ÿ“Š Response captured: 200 OK\n", + "Response aclosed: 200 OK\n", + "๐Ÿ“Š Response captured: 200 OK\n", + "Response aclosed: 200 OK\n", + "๐Ÿ“Š Response captured: 200 OK\n", + "Response aclosed: 200 OK\n", + "Expected error: All connection attempts failed\n", + "\n", + "๐Ÿ“Š Status code counts: {200: 3}\n", + "๐Ÿ“ˆ Total responses tracked: 3\n", + "\n", + "๐Ÿงน Cleaning up...\n", + "Before cleanup: 4 callbacks\n", + "After cleanup: 0 callbacks\n" + ] + } + ], "source": [ "# Test advanced callbacks and demonstrate cleanup\n", "async with httpx.AsyncClient() as client:\n", " print(\"Testing with all callbacks active...\")\n", - " \n", + "\n", " # Make requests including one that will cause an error (404)\n", " await client.get(\"http://localhost:8888/\")\n", " await client.get(\"http://localhost:8888/api/users\")\n", " await client.post(\"http://localhost:8888/api/data\", json={\"test\": \"data\"})\n", - " \n", + "\n", " # This will trigger a 404 (path doesn't exist on most servers)\n", " # But our test server responds to all paths, so let's simulate an error scenario\n", " # by making a request to a non-existent server\n", @@ -286,15 +372,23 @@ "\n", "# You can also restore the original Response.__init__ if needed\n", "# httpx._models.Response.__init__ = original_response_init\n", - "# print(\"โœ… Original Response.__init__ restored\")\n" + "# print(\"โœ… Original Response.__init__ restored\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "f646052c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿ’ก Patch remains active. 
Call restore_original_response() to revert.\n" + ] + } + ], "source": [ "# Optional: Restore original behavior completely\n", "def restore_original_response():\n", @@ -305,11 +399,12 @@ " print(\" - Original __init__ method restored\")\n", " print(\" - All callbacks cleared\")\n", "\n", + "\n", "# Uncomment to restore:\n", "# restore_original_response()\n", "\n", "# Or keep the patched version for observability\n", - "print(\"๐Ÿ’ก Patch remains active. Call restore_original_response() to revert.\")\n" + "print(\"๐Ÿ’ก Patch remains active. Call restore_original_response() to revert.\")" ] } ], From 0fc4500b440a671dd315c8b40c194429f602c9f6 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 21 Aug 2025 02:25:32 +0000 Subject: [PATCH 03/15] refactor: Simplify server startup message in test notebook --- tests/unit/test_gather_trajectory.ipynb | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/unit/test_gather_trajectory.ipynb b/tests/unit/test_gather_trajectory.ipynb index c8085d2f..faed6c1b 100644 --- a/tests/unit/test_gather_trajectory.ipynb +++ b/tests/unit/test_gather_trajectory.ipynb @@ -31,14 +31,7 @@ "await runner.setup()\n", "site = web.TCPSite(runner, \"localhost\", 8889)\n", "await site.start()\n", - "print(\"Server running on http://localhost:8889\")\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n" + "print(\"Server running on http://localhost:8889\")" ] } ], From 74ad94428c49509acf8ef6c06a9373b3982503e6 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 21 Aug 2025 19:57:18 +0000 Subject: [PATCH 04/15] feat: Add asyncio support with pytest-asyncio and implement trajectory context management --- pyproject.toml | 4 + src/art/__init__.py | 5 + src/art/get_trajectory.py | 6 - src/art/with_trajectory.py | 40 ++++ tests/unit/data.ipynb | 322 +++++++++++++++++++++++++++++ tests/unit/test_with_trajectory.py | 146 +++++++++++++ uv.lock | 24 +++ 7 files changed, 541 insertions(+), 6 deletions(-) delete mode 100644 src/art/get_trajectory.py create mode 100644 src/art/with_trajectory.py create mode 100644 tests/unit/data.ipynb create mode 100644 tests/unit/test_with_trajectory.py diff --git a/pyproject.toml b/pyproject.toml index 32661d20..96664245 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,9 @@ select = ["I"] [tool.ruff.lint.isort] case-sensitive = false +[tool.pytest.ini_options] +asyncio_mode = "auto" + [tool.uv] required-version = ">=0.6.15" dev-dependencies = [ @@ -96,6 +99,7 @@ dev-dependencies = [ "nbval>=0.11.0", "pytest-xdist>=3.8.0", "pyright[nodejs]>=1.1.403", + "pytest-asyncio>=1.1.0", ] [tool.uv.sources] diff --git a/src/art/__init__.py b/src/art/__init__.py index bf8015af..d214472a 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -20,10 +20,12 @@ from .backend import Backend from .batches import trajectory_group_batches from .gather import gather_trajectories, gather_trajectory_groups + from .model import Model, TrainableModel from .trajectories import Trajectory, TrajectoryGroup from .types import Messages, MessagesAndChoices, Tools, TrainConfig from .utils import retry +from .with_trajectory import with_trajectory, contextual_trajectory, required_trajectory __all__ = [ "dev", @@ -40,4 +42,7 @@ "TrainConfig", "Trajectory", "TrajectoryGroup", + "with_trajectory", + "contextual_trajectory", + "required_trajectory", ] diff --git a/src/art/get_trajectory.py b/src/art/get_trajectory.py deleted file mode 100644 index c23fd5a5..00000000 --- a/src/art/get_trajectory.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing 
import Any, Coroutine - -from .trajectories import Trajectory - - -async def get_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory: ... diff --git a/src/art/with_trajectory.py b/src/art/with_trajectory.py new file mode 100644 index 00000000..8d65d8f0 --- /dev/null +++ b/src/art/with_trajectory.py @@ -0,0 +1,40 @@ +import contextlib +import contextvars +from typing import Any, Coroutine, Iterator + +from .trajectories import Trajectory + + +async def with_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory: + trajectory = Trajectory(messages_and_choices=[], reward=0.0) + with set_trajectory_context(trajectory): + await coroutine + return trajectory + + +trajectory_context_var: contextvars.ContextVar[Trajectory | None] = ( + contextvars.ContextVar("trajectory", default=None) +) + + +@contextlib.contextmanager +def set_trajectory_context(trajectory: Trajectory) -> Iterator[None]: + token = trajectory_context_var.set(trajectory) + try: + yield + finally: + trajectory_context_var.reset(token) + + +def contextual_trajectory() -> Trajectory | None: + return trajectory_context_var.get() + + +def required_trajectory() -> Trajectory: + trajectory = contextual_trajectory() + if trajectory is None: + raise RuntimeError( + "No trajectory found. You must run this function in a context that has a trajectory. " + "Try calling your entry coroutine with get_trajectory or using current_trajectory for flexibility." + ) + return trajectory diff --git a/tests/unit/data.ipynb b/tests/unit/data.ipynb new file mode 100644 index 00000000..91c7dd1a --- /dev/null +++ b/tests/unit/data.ipynb @@ -0,0 +1,322 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a7ff6842", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "01f78de0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mbradhilton\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.21.0" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/ubuntu/sky_workdir/tests/unit/wandb/run-20250821_183738-test" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run test to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/bradhilton/tests" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/bradhilton/tests/runs/test" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 08-21 18:37:49 [__init__.py:235] Automatically detected platform cuda.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/sky_workdir/src/art/__init__.py:10: UserWarning: WARNING: Unsloth should be imported before transformers, peft to ensure all optimizations are applied. Your code may run slower or encounter memory issues without these optimizations.\n", + "\n", + "Please restructure your imports with 'import unsloth' at the top of your file.\n", + " import unsloth # type: ignore # noqa: F401\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿฆฅ Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", + "INFO 08-21 18:37:59 [__init__.py:235] Automatically detected platform cuda.\n", + "๐Ÿฆฅ Unsloth Zoo will now patch everything to make training faster!\n", + "Unsloth: Patching vLLM v1 graph capture\n", + "Unsloth: Patching vLLM v0 graph capture\n", + "==((====))== Unsloth 2025.8.6: Fast Qwen2 patching. Transformers: 4.53.2. vLLM: 0.10.0.\n", + " \\\\ /| NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.189 GB. Platform: Linux.\n", + "O^O/ \\_/ \\ Torch: 2.7.1+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.1\n", + "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]\n", + " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", + "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n", + "Unsloth: vLLM loading unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit with actual GPU utilization = 78.47%\n", + "Unsloth: Your GPU has CUDA compute capability 9.0 with VRAM = 79.19 GB.\n", + "Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 32768. Num Sequences = 368.\n", + "Unsloth: vLLM's KV Cache can use up to 56.27 GB. Also swap space = 6 GB.\n", + "Unsloth: Not an error, but `device` is not supported in vLLM. 
Skipping.\n", + "INFO 08-21 18:38:20 [config.py:1604] Using max model len 32768\n", + "Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bit': False, 'load_in_4bit': True, 'bnb_4bit_compute_dtype': 'bfloat16', 'bnb_4bit_quant_storage': 'uint8', 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_use_double_quant': True, 'llm_int8_enable_fp32_cpu_offload': False, 'llm_int8_has_fp16_weight': False, 'llm_int8_skip_modules': ['lm_head', 'multi_modal_projector', 'merger', 'modality_projection', 'model.layers.0.self_attn', 'model.layers.1.self_attn', 'model.layers.2.mlp', 'model.layers.3.mlp', 'model.layers.4.mlp', 'model.layers.25.mlp', 'model.layers.26.mlp'], 'llm_int8_threshold': 6.0}\n", + "INFO 08-21 18:38:21 [llm_engine.py:228] Initializing a V0 LLM engine (v0.10.0) with config: model='unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit', speculative_config=None, tokenizer='unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=bitsandbytes, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit, num_scheduler_steps=16, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=False, use_async_output_proc=True, pooler_config=None, compilation_config={\"level\":0,\"debug_dump_path\":\"\",\"cache_dir\":\"\",\"backend\":\"inductor\",\"custom_ops\":[],\"splitting_ops\":[],\"use_inductor\":true,\"compile_sizes\":[],\"inductor_compile_config\":{\"epilogue_fusion\":true,\"max_autotune\":false,\"shape_padding\":true,\"trace.enabled\":false,\"triton.cudagraphs\":true,\"debug\":false,\"dce\":true,\"memory_planning\":true,\"coordinate_descent_tuning\":true,\"trace.graph_diagram\":false,\"compile_threads\":26,\"group_fusion\":true,\"disable_progress\":false,\"verbose_progress\":true,\"triton.multi_kernel\":0,\"triton.use_block_ptr\":true,\"triton.enable_persistent_tma_matmul\":true,\"triton.autotune_at_compile_time\":false,\"triton.cooperative_reductions\":false,\"cuda.compile_opt_level\":\"-O2\",\"cuda.enable_cuda_lto\":true,\"combo_kernels\":false,\"benchmark_combo_kernel\":true,\"combo_kernel_foreach_dynamic_shapes\":true,\"enable_auto_functionalized_v2\":false},\"inductor_passes\":{},\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":1,\"cudagraph_capture_sizes\":[368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"cudagraph_copy_inputs\":false,\"full_cuda_graph\":false,\"max_capture_size\":368,\"local_cache_dir\":null}, use_cached_outputs=False, \n", + "INFO 08-21 18:38:24 [cuda.py:398] Using Flash Attention backend.\n", + "INFO 08-21 18:38:26 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0\n", + "INFO 08-21 18:38:26 [model_runner.py:1083] Starting to load model unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit...\n", + 
"INFO 08-21 18:38:26 [bitsandbytes_loader.py:733] Loading weights with BitsAndBytes quantization. May take a while ...\n", + "INFO 08-21 18:38:27 [weight_utils.py:296] Using model weights format ['*.safetensors']\n", + "INFO 08-21 18:38:42 [weight_utils.py:312] Time spent downloading weights for unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit: 15.237522 seconds\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00 web.Response: + return web.json_response(mock_response) + + app = web.Application() + app.router.add_route("POST", "/v1/chat/completions", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8888) + await site.start() + print(f"Test server started on http://localhost:8888") + + yield # Tests run here + + print("Cleaning up test server...") + await runner.cleanup() + + +async def test_with_trajectory(test_server: None) -> None: + async def say_hi() -> str | None: + """A method that says hi to an assistant and returns the response.""" + client = AsyncOpenAI(base_url="http://localhost:8888/v1", api_key="default") + message: ChatCompletionMessageParam = {"role": "user", "content": "Hi!"} + chat_completion = await client.chat.completions.create( + model="test", + messages=[message], + ) + return chat_completion.choices[0].message.content + + trajectory = await art.with_trajectory(say_hi()) + assert trajectory.messages_and_choices == [ + {"role": "user", "content": "Hi!"}, + Choice(**mock_response["choices"][0]), + ] diff --git a/uv.lock b/uv.lock index ac495301..5a96e1a5 100644 --- a/uv.lock +++ b/uv.lock @@ -336,6 +336,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148, upload-time = "2022-10-05T19:19:30.546Z" }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "backports-tarfile" version = "1.2.0" @@ -3725,6 +3734,7 @@ dev = [ { name = "openpipe" }, { name = "pyright", extra = ["nodejs"] }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-xdist" }, { name = "ruff" }, ] @@ -3774,6 +3784,7 @@ dev = [ { name = "openpipe", specifier = ">=4.49.0" }, { name = "pyright", extras = ["nodejs"], specifier = ">=1.1.403" }, { name = "pytest", specifier = ">=8.4.1" }, + { name = "pytest-asyncio", specifier = ">=1.1.0" }, { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "ruff", specifier = ">=0.12.1" }, ] @@ -5035,6 +5046,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = 
"sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/51/f8794af39eeb870e87a8c8068642fc07bce0c854d6865d7dd0f2a9d338c2/pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea", size = 46652, upload-time = "2025-07-16T04:29:26.393Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/9d/bf86eddabf8c6c9cb1ea9a869d6873b46f105a5d292d3a6f7071f5b07935/pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf", size = 15157, upload-time = "2025-07-16T04:29:24.929Z" }, +] + [[package]] name = "pytest-xdist" version = "3.8.0" From e2cb601ac0261b0e7c0c551e503e31307055d7e6 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 21 Aug 2025 20:32:19 +0000 Subject: [PATCH 05/15] feat: Enhance trajectory support in tests by integrating optional ART functionality --- tests/unit/test_with_trajectory.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/unit/test_with_trajectory.py b/tests/unit/test_with_trajectory.py index d7c6451e..2cff1e2b 100644 --- a/tests/unit/test_with_trajectory.py +++ b/tests/unit/test_with_trajectory.py @@ -137,8 +137,13 @@ async def say_hi() -> str | None: model="test", messages=[message], ) + # Add optional ART support with a few lines of code + if trajectory := art.contextual_trajectory(): + trajectory.messages_and_choices = [message, chat_completion.choices[0]] + trajectory.reward = 1.0 return chat_completion.choices[0].message.content + # Use the with_trajectory utility to get a trajectory from the coroutine trajectory = await art.with_trajectory(say_hi()) assert trajectory.messages_and_choices == [ {"role": "user", "content": "Hi!"}, From 583b777a02e7d49d9114fdc08cec041f62877332 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 21 Aug 2025 22:31:13 +0000 Subject: [PATCH 06/15] feat: Introduce yield_trajectory and capture_yielded_trajectory for enhanced trajectory management --- {tests/unit => dev}/data.ipynb | 0 src/art/__init__.py | 8 ++-- src/art/with_trajectory.py | 40 ------------------- src/art/yield_trajectory.py | 33 +++++++++++++++ ...trajectory.py => test_yield_trajectory.py} | 15 ++++--- 5 files changed, 45 insertions(+), 51 deletions(-) rename {tests/unit => dev}/data.ipynb (100%) delete mode 100644 src/art/with_trajectory.py create mode 100644 src/art/yield_trajectory.py rename tests/unit/{test_with_trajectory.py => test_yield_trajectory.py} (93%) diff --git a/tests/unit/data.ipynb b/dev/data.ipynb similarity index 100% rename from tests/unit/data.ipynb rename to dev/data.ipynb diff --git a/src/art/__init__.py b/src/art/__init__.py index d214472a..8cc81c48 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -20,12 +20,11 @@ from .backend import Backend from .batches import trajectory_group_batches from .gather import gather_trajectories, gather_trajectory_groups - from .model import Model, TrainableModel from .trajectories import Trajectory, TrajectoryGroup from .types import Messages, MessagesAndChoices, Tools, TrainConfig from .utils import retry -from .with_trajectory import with_trajectory, contextual_trajectory, required_trajectory +from 
.yield_trajectory import capture_yielded_trajectory, yield_trajectory __all__ = [ "dev", @@ -42,7 +41,6 @@ "TrainConfig", "Trajectory", "TrajectoryGroup", - "with_trajectory", - "contextual_trajectory", - "required_trajectory", + "capture_yielded_trajectory", + "yield_trajectory", ] diff --git a/src/art/with_trajectory.py b/src/art/with_trajectory.py deleted file mode 100644 index 8d65d8f0..00000000 --- a/src/art/with_trajectory.py +++ /dev/null @@ -1,40 +0,0 @@ -import contextlib -import contextvars -from typing import Any, Coroutine, Iterator - -from .trajectories import Trajectory - - -async def with_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory: - trajectory = Trajectory(messages_and_choices=[], reward=0.0) - with set_trajectory_context(trajectory): - await coroutine - return trajectory - - -trajectory_context_var: contextvars.ContextVar[Trajectory | None] = ( - contextvars.ContextVar("trajectory", default=None) -) - - -@contextlib.contextmanager -def set_trajectory_context(trajectory: Trajectory) -> Iterator[None]: - token = trajectory_context_var.set(trajectory) - try: - yield - finally: - trajectory_context_var.reset(token) - - -def contextual_trajectory() -> Trajectory | None: - return trajectory_context_var.get() - - -def required_trajectory() -> Trajectory: - trajectory = contextual_trajectory() - if trajectory is None: - raise RuntimeError( - "No trajectory found. You must run this function in a context that has a trajectory. " - "Try calling your entry coroutine with get_trajectory or using current_trajectory for flexibility." - ) - return trajectory diff --git a/src/art/yield_trajectory.py b/src/art/yield_trajectory.py new file mode 100644 index 00000000..26a1ebd5 --- /dev/null +++ b/src/art/yield_trajectory.py @@ -0,0 +1,33 @@ +import contextvars +from typing import Any, Coroutine + +from .trajectories import Trajectory + + +def yield_trajectory(trajectory: Trajectory) -> None: + yield_trajectory_context_var.get().trajectory = trajectory + + +async def capture_yielded_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory: + with YieldTrajectoryContext(): + await coroutine + trajectory = yield_trajectory_context_var.get().trajectory + if trajectory is None: + raise RuntimeError("No trajectory yielded") + return trajectory + + +class YieldTrajectoryContext: + def __init__(self) -> None: + self.trajectory: Trajectory | None = None + + def __enter__(self) -> None: + self.token = yield_trajectory_context_var.set(self) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + yield_trajectory_context_var.reset(self.token) + + +yield_trajectory_context_var: contextvars.ContextVar[YieldTrajectoryContext] = ( + contextvars.ContextVar("trajectory", default=YieldTrajectoryContext()) +) diff --git a/tests/unit/test_with_trajectory.py b/tests/unit/test_yield_trajectory.py similarity index 93% rename from tests/unit/test_with_trajectory.py rename to tests/unit/test_yield_trajectory.py index 2cff1e2b..95a53fe5 100644 --- a/tests/unit/test_with_trajectory.py +++ b/tests/unit/test_yield_trajectory.py @@ -1,8 +1,8 @@ +import pytest_asyncio from aiohttp import web from openai import AsyncOpenAI from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam -import pytest_asyncio import art @@ -138,13 +138,16 @@ async def say_hi() -> str | None: messages=[message], ) # Add optional ART support with a few lines of code - if trajectory := art.contextual_trajectory(): - 
trajectory.messages_and_choices = [message, chat_completion.choices[0]] - trajectory.reward = 1.0 + art.yield_trajectory( + art.Trajectory( + messages_and_choices=[message, chat_completion.choices[0]], + reward=1.0, + ) + ) return chat_completion.choices[0].message.content - # Use the with_trajectory utility to get a trajectory from the coroutine - trajectory = await art.with_trajectory(say_hi()) + # Use the capture_yielded_trajectory utility to capture the yielded trajectory + trajectory = await art.capture_yielded_trajectory(say_hi()) assert trajectory.messages_and_choices == [ {"role": "user", "content": "Hi!"}, Choice(**mock_response["choices"][0]), From 86d1b85c73968a4f6eee12c35a150bd79741efab Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 22 Aug 2025 00:36:54 +0000 Subject: [PATCH 07/15] feat(auto-trajectory): add auto_trajectory and capture_auto_trajectory functions; update imports and tests --- src/art/__init__.py | 3 + src/art/auto_trajectory.py | 48 +++++++++ src/art/yield_trajectory.py | 2 +- tests/unit/test_auto_trajectory.py | 150 ++++++++++++++++++++++++++++ tests/unit/test_yield_trajectory.py | 2 +- 5 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 src/art/auto_trajectory.py create mode 100644 tests/unit/test_auto_trajectory.py diff --git a/src/art/__init__.py b/src/art/__init__.py index 8cc81c48..1ee47c6e 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -17,6 +17,7 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(conf) from . import dev +from .auto_trajectory import auto_trajectory, capture_auto_trajectory from .backend import Backend from .batches import trajectory_group_batches from .gather import gather_trajectories, gather_trajectory_groups @@ -28,6 +29,8 @@ __all__ = [ "dev", + "auto_trajectory", + "capture_auto_trajectory", "gather_trajectories", "gather_trajectory_groups", "trajectory_group_batches", diff --git a/src/art/auto_trajectory.py b/src/art/auto_trajectory.py new file mode 100644 index 00000000..78123e7d --- /dev/null +++ b/src/art/auto_trajectory.py @@ -0,0 +1,48 @@ +import contextvars +from typing import Any, Coroutine, Literal, overload + +from .trajectories import Trajectory + + +@overload +def auto_trajectory(*, required: Literal[True]) -> Trajectory: ... + + +@overload +def auto_trajectory(*, required: Literal[False] = False) -> Trajectory | None: ... + + +def auto_trajectory(*, required: bool = False) -> Trajectory | None: + context = auto_trajectory_context_var.get(None) + if context is None: + if required: + raise RuntimeError( + "No auto trajectory in context. `auto_trajectory(required=True)` must be called in a `capture_auto_trajectory(...)` scope." 
+ ) + return None + return context.trajectory + + +async def capture_auto_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory: + with AutoTrajectoryContext(): + await coroutine + return auto_trajectory_context_var.get().trajectory + + +class AutoTrajectoryContext: + def __init__(self) -> None: + self.trajectory = Trajectory( + messages_and_choices=[], + reward=0.0, + ) + + def __enter__(self) -> None: + self.token = auto_trajectory_context_var.set(self) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + auto_trajectory_context_var.reset(self.token) + + +auto_trajectory_context_var: contextvars.ContextVar[AutoTrajectoryContext] = ( + contextvars.ContextVar("auto_trajectory_context") +) diff --git a/src/art/yield_trajectory.py b/src/art/yield_trajectory.py index 26a1ebd5..5109d193 100644 --- a/src/art/yield_trajectory.py +++ b/src/art/yield_trajectory.py @@ -29,5 +29,5 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: yield_trajectory_context_var: contextvars.ContextVar[YieldTrajectoryContext] = ( - contextvars.ContextVar("trajectory", default=YieldTrajectoryContext()) + contextvars.ContextVar("yield_trajectory_context", default=YieldTrajectoryContext()) ) diff --git a/tests/unit/test_auto_trajectory.py b/tests/unit/test_auto_trajectory.py new file mode 100644 index 00000000..11e5a362 --- /dev/null +++ b/tests/unit/test_auto_trajectory.py @@ -0,0 +1,150 @@ +import pytest_asyncio +from aiohttp import web +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam + +import art + +mock_response = { + "id": "chatcmpl-293ce9f37dba40e5be39448acaf6fb49", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": { + "content": [ + { + "token": "token_id:9707", + "bytes": [72, 101, 108, 108, 111], + "logprob": -0.0017243054462596774, + "top_logprobs": [], + }, + { + "token": "token_id:0", + "bytes": [33], + "logprob": -0.007611795328557491, + "top_logprobs": [], + }, + { + "token": "token_id:2585", + "bytes": [32, 72, 111, 119], + "logprob": -0.03061593696475029, + "top_logprobs": [], + }, + { + "token": "token_id:646", + "bytes": [32, 99, 97, 110], + "logprob": -1.1920858014491387e-05, + "top_logprobs": [], + }, + { + "token": "token_id:358", + "bytes": [32, 73], + "logprob": -2.3841855067985307e-07, + "top_logprobs": [], + }, + { + "token": "token_id:7789", + "bytes": [32, 97, 115, 115, 105, 115, 116], + "logprob": -0.020548323169350624, + "top_logprobs": [], + }, + { + "token": "token_id:498", + "bytes": [32, 121, 111, 117], + "logprob": 0.0, + "top_logprobs": [], + }, + { + "token": "token_id:3351", + "bytes": [32, 116, 111, 100, 97, 121], + "logprob": -4.410734163684538e-06, + "top_logprobs": [], + }, + { + "token": "token_id:30", + "bytes": [63], + "logprob": -2.3841855067985307e-07, + "top_logprobs": [], + }, + { + "token": "token_id:151645", + "bytes": [], + "logprob": -0.0083366259932518, + "top_logprobs": [], + }, + ], + "refusal": None, + }, + "message": { + "content": "Hello! 
How can I assist you today?", + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": [], + "reasoning_content": None, + }, + "stop_reason": None, + } + ], + "created": 1755801745, + "model": "test", + "object": "chat.completion", + "service_tier": None, + "system_fingerprint": None, + "usage": { + "completion_tokens": 10, + "prompt_tokens": 31, + "total_tokens": 41, + "completion_tokens_details": None, + "prompt_tokens_details": None, + }, + "prompt_logprobs": None, + "kv_transfer_params": None, +} + + +@pytest_asyncio.fixture +async def test_server(): + """Start a test server for the module.""" + + async def handler(_: web.Request) -> web.Response: + return web.json_response(mock_response) + + app = web.Application() + app.router.add_route("POST", "/v1/chat/completions", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8888) + await site.start() + print(f"Test server started on http://localhost:8888") + + yield # Tests run here + + print("Cleaning up test server...") + await runner.cleanup() + + +async def test_auto_trajectory(test_server: None) -> None: + async def say_hi() -> str | None: + """A method that says hi to an assistant and returns the response.""" + client = AsyncOpenAI(base_url="http://localhost:8888/v1", api_key="default") + message: ChatCompletionMessageParam = {"role": "user", "content": "Hi!"} + chat_completion = await client.chat.completions.create( + model="test", + messages=[message], + ) + # Add optional ART support with a few lines of code + if trajectory := art.auto_trajectory(): + trajectory.reward = 1.0 + return chat_completion.choices[0].message.content + + # Use the capture_auto_trajectory utility to capture a trajectory automatically + trajectory = await art.capture_auto_trajectory(say_hi()) + assert trajectory.messages_and_choices == [ + {"role": "user", "content": "Hi!"}, + Choice(**mock_response["choices"][0]), + ] diff --git a/tests/unit/test_yield_trajectory.py b/tests/unit/test_yield_trajectory.py index 95a53fe5..c730c4e7 100644 --- a/tests/unit/test_yield_trajectory.py +++ b/tests/unit/test_yield_trajectory.py @@ -128,7 +128,7 @@ async def handler(_: web.Request) -> web.Response: await runner.cleanup() -async def test_with_trajectory(test_server: None) -> None: +async def test_yield_trajectory(test_server: None) -> None: async def say_hi() -> str | None: """A method that says hi to an assistant and returns the response.""" client = AsyncOpenAI(base_url="http://localhost:8888/v1", api_key="default") From 8a2deb3299540c96651de569cc53c549516cdd33 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 22 Aug 2025 01:34:34 +0000 Subject: [PATCH 08/15] feat(auto-trajectory): enhance trajectory handling with HTTPX response management and context integration; remove obsolete test notebooks --- src/art/auto_trajectory.py | 87 +++- src/art/trajectories.py | 3 + tests/unit/test_gather_trajectory.ipynb | 59 --- tests/unit/test_request_observability.ipynb | 432 ------------------ .../test_tokenize_trajectory_groups.ipynb | 135 ------ 5 files changed, 87 insertions(+), 629 deletions(-) delete mode 100644 tests/unit/test_gather_trajectory.ipynb delete mode 100644 tests/unit/test_request_observability.ipynb delete mode 100644 tests/unit/test_tokenize_trajectory_groups.ipynb diff --git a/src/art/auto_trajectory.py b/src/art/auto_trajectory.py index 78123e7d..ed7b1b9c 100644 --- a/src/art/auto_trajectory.py +++ b/src/art/auto_trajectory.py @@ -1,7 
+1,11 @@ import contextvars -from typing import Any, Coroutine, Literal, overload +import json +from typing import Any, AsyncIterator, Coroutine, Iterator, Literal, overload -from .trajectories import Trajectory +import httpx._models +from openai.types.chat.chat_completion import Choice + +from .trajectories import History, Trajectory @overload @@ -26,7 +30,9 @@ def auto_trajectory(*, required: bool = False) -> Trajectory | None: async def capture_auto_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory: with AutoTrajectoryContext(): await coroutine - return auto_trajectory_context_var.get().trajectory + trajectory = auto_trajectory_context_var.get().trajectory + trajectory.finish() + return trajectory class AutoTrajectoryContext: @@ -42,7 +48,82 @@ def __enter__(self) -> None: def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: auto_trajectory_context_var.reset(self.token) + def handle_httpx_response(self, response: httpx._models.Response) -> None: + try: + request_content = json.loads(getattr(response.request, "_content", b"")) + messages = request_content["messages"] + tools = request_content.get("tools", None) + choice = Choice( + **json.loads(getattr(response, "_content_so_far", b""))["choices"][0] + ) + history: Trajectory | History = self.trajectory + history_index = -1 + while True: + history_messages = history.messages() + if history_messages == messages[: len(history_messages)] and ( + history.tools == tools + or (history_messages == [] and history.tools is None) + ): + break + history_index += 1 + try: + history = self.trajectory.additional_histories[history_index] + except IndexError: + history = History(messages_and_choices=[]) + self.trajectory.additional_histories.append(history) + break + history.messages_and_choices.extend( + messages[len(history.messages_and_choices) :] + ) + history.messages_and_choices.append(choice) + history.tools = tools + except: + pass + auto_trajectory_context_var: contextvars.ContextVar[AutoTrajectoryContext] = ( contextvars.ContextVar("auto_trajectory_context") ) + + +def patch_httpx() -> None: + original_iter_bytes = httpx._models.Response.iter_bytes + original_aiter_bytes = httpx._models.Response.aiter_bytes + original_close = httpx._models.Response.close + original_aclose = httpx._models.Response.aclose + + def patched_iter_bytes( + self: httpx._models.Response, chunk_size: int | None = None + ) -> Iterator[bytes]: + for chunk in original_iter_bytes(self, chunk_size): + setattr( + self, "_content_so_far", getattr(self, "_content_so_far", b"") + chunk + ) + yield chunk + + async def patched_aiter_bytes( + self: httpx._models.Response, chunk_size: int | None = None + ) -> AsyncIterator[bytes]: + async for chunk in original_aiter_bytes(self, chunk_size): + setattr( + self, "_content_so_far", getattr(self, "_content_so_far", b"") + chunk + ) + yield chunk + + def patched_close(self: httpx._models.Response) -> None: + original_close(self) + if context := auto_trajectory_context_var.get(None): + context.handle_httpx_response(self) + + async def patched_aclose(self: httpx._models.Response) -> None: + await original_aclose(self) + if context := auto_trajectory_context_var.get(None): + context.handle_httpx_response(self) + + httpx._models.Response.iter_bytes = patched_iter_bytes + httpx._models.Response.aiter_bytes = patched_aiter_bytes + httpx._models.Response.close = patched_close + httpx._models.Response.aclose = patched_aclose + + +patch_httpx() diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 
d7c340d6..4a2020a5 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -23,6 +23,9 @@ class History(pydantic.BaseModel): messages_and_choices: MessagesAndChoices tools: Tools | None = None + def messages(self) -> Messages: + return get_messages(self.messages_and_choices) + class Trajectory(pydantic.BaseModel): messages_and_choices: MessagesAndChoices diff --git a/tests/unit/test_gather_trajectory.ipynb b/tests/unit/test_gather_trajectory.ipynb deleted file mode 100644 index faed6c1b..00000000 --- a/tests/unit/test_gather_trajectory.ipynb +++ /dev/null @@ -1,59 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "e24171fe", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Server running on http://localhost:8889\n" - ] - } - ], - "source": [ - "from aiohttp import web\n", - "\n", - "\n", - "async def handler(request: web.Request) -> web.Response:\n", - " body = await request.read()\n", - " return web.Response(body=body)\n", - "\n", - "\n", - "app = web.Application()\n", - "app.router.add_route(\"POST\", \"/{p:.*}\", handler)\n", - "\n", - "# Non-blocking async runner\n", - "runner = web.AppRunner(app)\n", - "await runner.setup()\n", - "site = web.TCPSite(runner, \"localhost\", 8889)\n", - "await site.start()\n", - "print(\"Server running on http://localhost:8889\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tests/unit/test_request_observability.ipynb b/tests/unit/test_request_observability.ipynb deleted file mode 100644 index be906abd..00000000 --- a/tests/unit/test_request_observability.ipynb +++ /dev/null @@ -1,432 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "cfb0cf3c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Server running on http://localhost:8888\n" - ] - } - ], - "source": [ - "from aiohttp import web\n", - "\n", - "\n", - "async def h(r: web.Request) -> web.Response:\n", - " return web.Response(text=f\"{r.method} {r.path}\")\n", - "\n", - "\n", - "app = web.Application()\n", - "app.router.add_route(\"*\", \"/{p:.*}\", h)\n", - "\n", - "# Non-blocking async runner\n", - "runner = web.AppRunner(app)\n", - "await runner.setup()\n", - "site = web.TCPSite(runner, \"localhost\", 8888)\n", - "await site.start()\n", - "print(\"Server running on http://localhost:8888\")\n", - "# Server is now running in the background!\n", - "# To stop later: await runner.cleanup()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "f1ce74a9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Making test requests...\n", - "Response 1: GET /\n", - "Response 2: POST /test\n", - "Response 3: GET /path/to/resource\n" - ] - } - ], - "source": [ - "# Test the callback by making HTTP requests\n", - "import httpx\n", - "\n", - "async with httpx.AsyncClient() as client:\n", - " # Make multiple requests to see callbacks in action\n", - " print(\"Making test requests...\")\n", - "\n", - " r1 = await client.get(\"http://localhost:8888/\")\n", - " print(f\"Response 1: 
{r1.text}\")\n", - "\n", - " r2 = await client.post(\"http://localhost:8888/test\", json={\"data\": \"test\"})\n", - " print(f\"Response 2: {r2.text}\")\n", - "\n", - " r3 = await client.get(\"http://localhost:8888/path/to/resource\")\n", - " print(f\"Response 3: {r3.text}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "948f1bcd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Response aclosed: 200 OK\n" - ] - } - ], - "source": [ - "import httpx._models\n", - "\n", - "original_close = httpx._models.Response.close\n", - "original_aclose = httpx._models.Response.aclose\n", - "\n", - "\n", - "def patched_close(self: httpx._models.Response) -> None:\n", - " original_close(self)\n", - " print(f\"Response closed: {self.status_code} {self.reason_phrase}\")\n", - "\n", - "\n", - "async def patched_aclose(self: httpx._models.Response) -> None:\n", - " await original_aclose(self)\n", - " print(f\"Response aclosed: {self.status_code} {self.reason_phrase}\")\n", - "\n", - "\n", - "httpx._models.Response.close = patched_close\n", - "httpx._models.Response.aclose = patched_aclose\n", - "\n", - "\n", - "async with httpx.AsyncClient() as client:\n", - " r = await client.get(\"http://localhost:8888/\")\n", - " r.content" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "f97afafe", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "โœ… httpx Response class patched with callback support\n" - ] - } - ], - "source": [ - "import functools\n", - "from typing import Any, Callable, List\n", - "\n", - "import httpx\n", - "import httpx._models\n", - "\n", - "# Storage for response callbacks\n", - "response_callbacks: List[Callable] = []\n", - "\n", - "\n", - "def add_response_callback(callback: Callable) -> None:\n", - " \"\"\"Add a callback to be called when a Response is created.\"\"\"\n", - " response_callbacks.append(callback)\n", - "\n", - "\n", - "def clear_response_callbacks() -> None:\n", - " \"\"\"Clear all response callbacks.\"\"\"\n", - " response_callbacks.clear()\n", - "\n", - "\n", - "# Store the original Response.__init__\n", - "original_response_init = httpx._models.Response.__init__\n", - "\n", - "\n", - "@functools.wraps(original_response_init)\n", - "def patched_response_init(self, *args, **kwargs):\n", - " \"\"\"Patched Response.__init__ that calls callbacks.\"\"\"\n", - " # Call the original init first\n", - " original_response_init(self, *args, **kwargs)\n", - "\n", - " # Call all registered callbacks\n", - " for callback in response_callbacks:\n", - " try:\n", - " callback(self)\n", - " except Exception as e:\n", - " print(f\"Error in response callback: {e}\")\n", - "\n", - "\n", - "# Apply the patch\n", - "httpx._models.Response.__init__ = patched_response_init\n", - "print(\"โœ… httpx Response class patched with callback support\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "c2c201a3", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Callback registered - will log all HTTP responses\n" - ] - } - ], - "source": [ - "# Example: Create a callback to track all responses\n", - "from datetime import datetime\n", - "\n", - "response_log = []\n", - "\n", - "\n", - "def log_response(response):\n", - " \"\"\"Callback that logs response details.\"\"\"\n", - " log_entry = {\n", - " \"timestamp\": datetime.now().isoformat(),\n", - " \"status_code\": response.status_code,\n", - " \"url\": 
str(response.url)\n", - " if hasattr(response, \"_request\") and response._request\n", - " else \"N/A\",\n", - " \"headers\": dict(response.headers),\n", - " \"http_version\": response.http_version,\n", - " \"reason_phrase\": response.reason_phrase,\n", - " }\n", - " response_log.append(log_entry)\n", - " print(f\"๐Ÿ“Š Response captured: {response.status_code} {response.reason_phrase}\")\n", - "\n", - "\n", - "# Register the callback\n", - "add_response_callback(log_response)\n", - "print(\"Callback registered - will log all HTTP responses\")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "99a56e15", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "๐Ÿ“‹ Total responses captured: 0\n", - "------------------------------------------------------------\n", - "\n", - "๐Ÿ” Full log available in 'response_log' variable\n", - "Example: response_log[0] = {}...\n" - ] - } - ], - "source": [ - "# Display the captured response log\n", - "import json\n", - "\n", - "print(f\"\\n๐Ÿ“‹ Total responses captured: {len(response_log)}\")\n", - "print(\"-\" * 60)\n", - "\n", - "for i, entry in enumerate(response_log, 1):\n", - " print(f\"\\n๐Ÿ“Œ Response #{i}:\")\n", - " print(f\" Timestamp: {entry['timestamp']}\")\n", - " print(f\" Status: {entry['status_code']} {entry['reason_phrase']}\")\n", - " print(f\" URL: {entry['url']}\")\n", - " print(f\" HTTP Version: {entry['http_version']}\")\n", - " print(f\" Headers (sample): {dict(list(entry['headers'].items())[:3])}...\")\n", - "\n", - "# You can also get more detailed info\n", - "print(f\"\\n๐Ÿ” Full log available in 'response_log' variable\")\n", - "print(\n", - " f\"Example: response_log[0] = {json.dumps(response_log[0] if response_log else {}, indent=2)[:200]}...\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "9ef7a5f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "โœ… Advanced callbacks registered!\n", - "Total active callbacks: 4\n" - ] - } - ], - "source": [ - "# Advanced example: Multiple callbacks with different purposes\n", - "import time\n", - "from collections import defaultdict\n", - "\n", - "# Callback 1: Count responses by status code\n", - "status_counter = defaultdict(int)\n", - "\n", - "\n", - "def count_status_codes(response):\n", - " status_counter[response.status_code] += 1\n", - "\n", - "\n", - "# Callback 2: Track response times\n", - "response_times = []\n", - "start_times = {}\n", - "\n", - "\n", - "def track_response_time(response):\n", - " # Note: This is a simplified example. 
In real scenarios,\n", - " # you'd track request start time differently\n", - " response_times.append(\n", - " {\n", - " \"url\": str(response.url)\n", - " if hasattr(response, \"_request\") and response._request\n", - " else \"N/A\",\n", - " \"status\": response.status_code,\n", - " \"timestamp\": time.time(),\n", - " }\n", - " )\n", - "\n", - "\n", - "# Callback 3: Alert on errors\n", - "def alert_on_errors(response):\n", - " if response.status_code >= 400:\n", - " print(\n", - " f\"โš ๏ธ ERROR RESPONSE: {response.status_code} for {response.url if hasattr(response, '_request') and response._request else 'unknown URL'}\"\n", - " )\n", - "\n", - "\n", - "# Register all callbacks\n", - "add_response_callback(count_status_codes)\n", - "add_response_callback(track_response_time)\n", - "add_response_callback(alert_on_errors)\n", - "\n", - "print(\"โœ… Advanced callbacks registered!\")\n", - "print(f\"Total active callbacks: {len(response_callbacks)}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "00ce67f9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Testing with all callbacks active...\n", - "๐Ÿ“Š Response captured: 200 OK\n", - "Response aclosed: 200 OK\n", - "๐Ÿ“Š Response captured: 200 OK\n", - "Response aclosed: 200 OK\n", - "๐Ÿ“Š Response captured: 200 OK\n", - "Response aclosed: 200 OK\n", - "Expected error: All connection attempts failed\n", - "\n", - "๐Ÿ“Š Status code counts: {200: 3}\n", - "๐Ÿ“ˆ Total responses tracked: 3\n", - "\n", - "๐Ÿงน Cleaning up...\n", - "Before cleanup: 4 callbacks\n", - "After cleanup: 0 callbacks\n" - ] - } - ], - "source": [ - "# Test advanced callbacks and demonstrate cleanup\n", - "async with httpx.AsyncClient() as client:\n", - " print(\"Testing with all callbacks active...\")\n", - "\n", - " # Make requests including one that will cause an error (404)\n", - " await client.get(\"http://localhost:8888/\")\n", - " await client.get(\"http://localhost:8888/api/users\")\n", - " await client.post(\"http://localhost:8888/api/data\", json={\"test\": \"data\"})\n", - "\n", - " # This will trigger a 404 (path doesn't exist on most servers)\n", - " # But our test server responds to all paths, so let's simulate an error scenario\n", - " # by making a request to a non-existent server\n", - " try:\n", - " await client.get(\"http://localhost:9999/\", timeout=1.0)\n", - " except Exception as e:\n", - " print(f\"Expected error: {e}\")\n", - "\n", - "print(f\"\\n๐Ÿ“Š Status code counts: {dict(status_counter)}\")\n", - "print(f\"๐Ÿ“ˆ Total responses tracked: {len(response_times)}\")\n", - "\n", - "# Demonstrate cleanup\n", - "print(f\"\\n๐Ÿงน Cleaning up...\")\n", - "print(f\"Before cleanup: {len(response_callbacks)} callbacks\")\n", - "clear_response_callbacks()\n", - "print(f\"After cleanup: {len(response_callbacks)} callbacks\")\n", - "\n", - "# You can also restore the original Response.__init__ if needed\n", - "# httpx._models.Response.__init__ = original_response_init\n", - "# print(\"โœ… Original Response.__init__ restored\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "f646052c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "๐Ÿ’ก Patch remains active. 
Call restore_original_response() to revert.\n" - ] - } - ], - "source": [ - "# Optional: Restore original behavior completely\n", - "def restore_original_response():\n", - " \"\"\"Restore the original Response.__init__ method.\"\"\"\n", - " httpx._models.Response.__init__ = original_response_init\n", - " clear_response_callbacks()\n", - " print(\"โœ… Original httpx Response behavior restored\")\n", - " print(\" - Original __init__ method restored\")\n", - " print(\" - All callbacks cleared\")\n", - "\n", - "\n", - "# Uncomment to restore:\n", - "# restore_original_response()\n", - "\n", - "# Or keep the patched version for observability\n", - "print(\"๐Ÿ’ก Patch remains active. Call restore_original_response() to revert.\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tests/unit/test_tokenize_trajectory_groups.ipynb b/tests/unit/test_tokenize_trajectory_groups.ipynb deleted file mode 100644 index b90739ed..00000000 --- a/tests/unit/test_tokenize_trajectory_groups.ipynb +++ /dev/null @@ -1,135 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "93a238e4", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "84717d3b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TokenizedResult(advantage=-1.0, chat='<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is the capital of France?<|im_end|>\\n<|im_start|>assistant\\nLondon<|im_end|>\\n', tokens=['<|im_start|>', 'system', '\\n', 'You', ' are', ' Q', 'wen', ',', ' created', ' by', ' Alibaba', ' Cloud', '.', ' You', ' are', ' a', ' helpful', ' assistant', '.', '<|im_end|>', '\\n', '<|im_start|>', 'user', '\\n', 'What', ' is', ' the', ' capital', ' of', ' France', '?', '<|im_end|>', '\\n', '<|im_start|>', 'assistant', '\\n', 'London', '<|im_end|>', '\\n'], token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 374, 279, 6722, 315, 9625, 30, 151645, 198, 151644, 77091, 198, 39572, 151645, 198], input_pos=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], assistant_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], logprobs=[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], weight=1.0, prompt_id=0, prompt_length=36)" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "TokenizedResult(advantage=1.0, chat='<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is the capital of France?<|im_end|>\\n<|im_start|>assistant\\nParis<|im_end|>\\n', tokens=['<|im_start|>', 'system', '\\n', 'You', ' are', ' Q', 'wen', ',', ' created', ' by', ' Alibaba', ' Cloud', '.', ' You', ' are', ' a', ' helpful', ' assistant', '.', '<|im_end|>', '\\n', '<|im_start|>', 'user', '\\n', 'What', ' is', ' the', ' capital', ' of', ' France', '?', '<|im_end|>', '\\n', '<|im_start|>', 'assistant', '\\n', 'Paris', '<|im_end|>', '\\n'], token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 374, 279, 6722, 315, 9625, 30, 151645, 198, 151644, 77091, 198, 59604, 151645, 198], input_pos=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], assistant_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], logprobs=[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, -0.01, nan, nan], weight=1.0, prompt_id=0, prompt_length=36)" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from IPython.display import clear_output\n", - "from openai.types.chat.chat_completion import Choice, ChoiceLogprobs\n", - "from openai.types.chat.chat_completion_message import ChatCompletionMessage\n", - "from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob\n", - "from transformers import AutoTokenizer\n", - "from transformers.utils.logging import disable_progress_bar\n", - "\n", - "import art\n", - "from art.preprocessing.tokenize import tokenize_trajectory_groups\n", - "\n", - "disable_progress_bar()\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\")\n", - "\n", - "tokenized_results = list(\n", - " tokenize_trajectory_groups(\n", - " tokenizer,\n", - " [\n", - " art.TrajectoryGroup(\n", - " [\n", - " art.Trajectory(\n", - " messages_and_choices=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": \"What is the capital of France?\",\n", - " },\n", - " {\"role\": \"assistant\", \"content\": \"London\"},\n", - " ],\n", - " reward=0.0,\n", - " ),\n", - " art.Trajectory(\n", - " messages_and_choices=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": \"What is the capital of France?\",\n", - " },\n", - " Choice(\n", - " finish_reason=\"stop\",\n", - " index=0,\n", - " logprobs=ChoiceLogprobs(\n", - " content=[\n", - " ChatCompletionTokenLogprob(\n", - " token=\"token:59604\",\n", - " bytes=[80, 97, 114, 105, 115],\n", - " logprob=-0.01,\n", - " top_logprobs=[],\n", - " )\n", - " ]\n", - " ),\n", - " message=ChatCompletionMessage(\n", - " content=\"Paris\",\n", - " role=\"assistant\",\n", - " ),\n", - " ),\n", - " ],\n", - " reward=1.0,\n", - " ),\n", - " ]\n", - " )\n", - " ],\n", - " allow_training_without_logprobs=True,\n", - " scale_rewards=True,\n", - " shuffle_group_trajectories=False,\n", - " )\n", - ")\n", - "for result in tokenized_results:\n", - " result.advantage = round(result.advantage, 2)\n", - " result.weight = round(result.weight, 2)\n", - " # set prompt_id to 0 to eliminate stochasticity\n", - " result.prompt_id = 0\n", - " display(result)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - 
"language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 127d3b0daa6086a08e0476a477a9d21fdfc98487 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 22 Aug 2025 01:49:48 +0000 Subject: [PATCH 09/15] feat(tests): add unit tests for tokenize_trajectory_groups functionality in a new Jupyter notebook --- .../test_tokenize_trajectory_groups.ipynb | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tests/unit/test_tokenize_trajectory_groups.ipynb diff --git a/tests/unit/test_tokenize_trajectory_groups.ipynb b/tests/unit/test_tokenize_trajectory_groups.ipynb new file mode 100644 index 00000000..b90739ed --- /dev/null +++ b/tests/unit/test_tokenize_trajectory_groups.ipynb @@ -0,0 +1,135 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "93a238e4", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "84717d3b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TokenizedResult(advantage=-1.0, chat='<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is the capital of France?<|im_end|>\\n<|im_start|>assistant\\nLondon<|im_end|>\\n', tokens=['<|im_start|>', 'system', '\\n', 'You', ' are', ' Q', 'wen', ',', ' created', ' by', ' Alibaba', ' Cloud', '.', ' You', ' are', ' a', ' helpful', ' assistant', '.', '<|im_end|>', '\\n', '<|im_start|>', 'user', '\\n', 'What', ' is', ' the', ' capital', ' of', ' France', '?', '<|im_end|>', '\\n', '<|im_start|>', 'assistant', '\\n', 'London', '<|im_end|>', '\\n'], token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 374, 279, 6722, 315, 9625, 30, 151645, 198, 151644, 77091, 198, 39572, 151645, 198], input_pos=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], assistant_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], logprobs=[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], weight=1.0, prompt_id=0, prompt_length=36)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TokenizedResult(advantage=1.0, chat='<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is the capital of France?<|im_end|>\\n<|im_start|>assistant\\nParis<|im_end|>\\n', tokens=['<|im_start|>', 'system', '\\n', 'You', ' are', ' Q', 'wen', ',', ' created', ' by', ' Alibaba', ' Cloud', '.', ' You', ' are', ' a', ' helpful', ' assistant', '.', '<|im_end|>', '\\n', '<|im_start|>', 'user', '\\n', 'What', ' is', ' the', ' capital', ' of', ' France', '?', '<|im_end|>', '\\n', '<|im_start|>', 'assistant', '\\n', 'Paris', '<|im_end|>', '\\n'], token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 374, 279, 6722, 315, 9625, 30, 151645, 198, 151644, 77091, 198, 59604, 151645, 198], input_pos=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], assistant_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], logprobs=[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, -0.01, nan, nan], weight=1.0, prompt_id=0, prompt_length=36)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import clear_output\n", + "from openai.types.chat.chat_completion import Choice, ChoiceLogprobs\n", + "from openai.types.chat.chat_completion_message import ChatCompletionMessage\n", + "from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob\n", + "from transformers import AutoTokenizer\n", + "from transformers.utils.logging import disable_progress_bar\n", + "\n", + "import art\n", + "from art.preprocessing.tokenize import tokenize_trajectory_groups\n", + "\n", + "disable_progress_bar()\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\")\n", + "\n", + "tokenized_results = list(\n", + " tokenize_trajectory_groups(\n", + " tokenizer,\n", + " [\n", + " art.TrajectoryGroup(\n", + " [\n", + " art.Trajectory(\n", + " messages_and_choices=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What is the capital of France?\",\n", + " },\n", + " {\"role\": \"assistant\", \"content\": \"London\"},\n", + " ],\n", + " reward=0.0,\n", + " ),\n", + " art.Trajectory(\n", + " messages_and_choices=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What is the capital of France?\",\n", + " },\n", + " Choice(\n", + " finish_reason=\"stop\",\n", + " index=0,\n", + " logprobs=ChoiceLogprobs(\n", + " content=[\n", + " ChatCompletionTokenLogprob(\n", + " token=\"token:59604\",\n", + " bytes=[80, 97, 114, 105, 115],\n", + " logprob=-0.01,\n", + " top_logprobs=[],\n", + " )\n", + " ]\n", + " ),\n", + " message=ChatCompletionMessage(\n", + " content=\"Paris\",\n", + " role=\"assistant\",\n", + " ),\n", + " ),\n", + " ],\n", + " reward=1.0,\n", + " ),\n", + " ]\n", + " )\n", + " ],\n", + " allow_training_without_logprobs=True,\n", + " scale_rewards=True,\n", + " shuffle_group_trajectories=False,\n", + " )\n", + ")\n", + "for result in tokenized_results:\n", + " result.advantage = round(result.advantage, 2)\n", + " result.weight = round(result.weight, 2)\n", + " # set prompt_id to 0 to eliminate stochasticity\n", + " result.prompt_id = 0\n", + " display(result)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + 
"language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 553b76d6ecf79b430b7be6d6ac2b9f88f43f14e8 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 22 Aug 2025 02:15:48 +0000 Subject: [PATCH 10/15] feat(tests): enhance auto_trajectory tests with tool integration and additional history assertions --- dev/data.ipynb | 174 +---------------------------- tests/unit/test_auto_trajectory.py | 80 ++++++++++++- 2 files changed, 80 insertions(+), 174 deletions(-) diff --git a/dev/data.ipynb b/dev/data.ipynb index 91c7dd1a..ea8abbae 100644 --- a/dev/data.ipynb +++ b/dev/data.ipynb @@ -13,180 +13,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "01f78de0", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mbradhilton\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] - }, - { - "data": { - "text/html": [ - "Tracking run with wandb version 0.21.0" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Run data is saved locally in /home/ubuntu/sky_workdir/tests/unit/wandb/run-20250821_183738-test" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Syncing run test to Weights & Biases (docs)
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View project at https://wandb.ai/bradhilton/tests" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View run at https://wandb.ai/bradhilton/tests/runs/test" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO 08-21 18:37:49 [__init__.py:235] Automatically detected platform cuda.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ubuntu/sky_workdir/src/art/__init__.py:10: UserWarning: WARNING: Unsloth should be imported before transformers, peft to ensure all optimizations are applied. Your code may run slower or encounter memory issues without these optimizations.\n", - "\n", - "Please restructure your imports with 'import unsloth' at the top of your file.\n", - " import unsloth # type: ignore # noqa: F401\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "๐Ÿฆฅ Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", - "INFO 08-21 18:37:59 [__init__.py:235] Automatically detected platform cuda.\n", - "๐Ÿฆฅ Unsloth Zoo will now patch everything to make training faster!\n", - "Unsloth: Patching vLLM v1 graph capture\n", - "Unsloth: Patching vLLM v0 graph capture\n", - "==((====))== Unsloth 2025.8.6: Fast Qwen2 patching. Transformers: 4.53.2. vLLM: 0.10.0.\n", - " \\\\ /| NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.189 GB. Platform: Linux.\n", - "O^O/ \\_/ \\ Torch: 2.7.1+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.1\n", - "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]\n", - " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", - "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n", - "Unsloth: vLLM loading unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit with actual GPU utilization = 78.47%\n", - "Unsloth: Your GPU has CUDA compute capability 9.0 with VRAM = 79.19 GB.\n", - "Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 32768. Num Sequences = 368.\n", - "Unsloth: vLLM's KV Cache can use up to 56.27 GB. Also swap space = 6 GB.\n", - "Unsloth: Not an error, but `device` is not supported in vLLM. 
Skipping.\n", - "INFO 08-21 18:38:20 [config.py:1604] Using max model len 32768\n", - "Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bit': False, 'load_in_4bit': True, 'bnb_4bit_compute_dtype': 'bfloat16', 'bnb_4bit_quant_storage': 'uint8', 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_use_double_quant': True, 'llm_int8_enable_fp32_cpu_offload': False, 'llm_int8_has_fp16_weight': False, 'llm_int8_skip_modules': ['lm_head', 'multi_modal_projector', 'merger', 'modality_projection', 'model.layers.0.self_attn', 'model.layers.1.self_attn', 'model.layers.2.mlp', 'model.layers.3.mlp', 'model.layers.4.mlp', 'model.layers.25.mlp', 'model.layers.26.mlp'], 'llm_int8_threshold': 6.0}\n", - "INFO 08-21 18:38:21 [llm_engine.py:228] Initializing a V0 LLM engine (v0.10.0) with config: model='unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit', speculative_config=None, tokenizer='unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=bitsandbytes, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit, num_scheduler_steps=16, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=False, use_async_output_proc=True, pooler_config=None, compilation_config={\"level\":0,\"debug_dump_path\":\"\",\"cache_dir\":\"\",\"backend\":\"inductor\",\"custom_ops\":[],\"splitting_ops\":[],\"use_inductor\":true,\"compile_sizes\":[],\"inductor_compile_config\":{\"epilogue_fusion\":true,\"max_autotune\":false,\"shape_padding\":true,\"trace.enabled\":false,\"triton.cudagraphs\":true,\"debug\":false,\"dce\":true,\"memory_planning\":true,\"coordinate_descent_tuning\":true,\"trace.graph_diagram\":false,\"compile_threads\":26,\"group_fusion\":true,\"disable_progress\":false,\"verbose_progress\":true,\"triton.multi_kernel\":0,\"triton.use_block_ptr\":true,\"triton.enable_persistent_tma_matmul\":true,\"triton.autotune_at_compile_time\":false,\"triton.cooperative_reductions\":false,\"cuda.compile_opt_level\":\"-O2\",\"cuda.enable_cuda_lto\":true,\"combo_kernels\":false,\"benchmark_combo_kernel\":true,\"combo_kernel_foreach_dynamic_shapes\":true,\"enable_auto_functionalized_v2\":false},\"inductor_passes\":{},\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":1,\"cudagraph_capture_sizes\":[368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"cudagraph_copy_inputs\":false,\"full_cuda_graph\":false,\"max_capture_size\":368,\"local_cache_dir\":null}, use_cached_outputs=False, \n", - "INFO 08-21 18:38:24 [cuda.py:398] Using Flash Attention backend.\n", - "INFO 08-21 18:38:26 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0\n", - "INFO 08-21 18:38:26 [model_runner.py:1083] Starting to load model unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit...\n", - 
"INFO 08-21 18:38:26 [bitsandbytes_loader.py:733] Loading weights with BitsAndBytes quantization. May take a while ...\n", - "INFO 08-21 18:38:27 [weight_utils.py:296] Using model weights format ['*.safetensors']\n", - "INFO 08-21 18:38:42 [weight_utils.py:312] Time spent downloading weights for unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit: 15.237522 seconds\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00 web.Response: async def test_auto_trajectory(test_server: None) -> None: + message: ChatCompletionMessageParam = {"role": "user", "content": "Hi!"} + tools: list[ChatCompletionToolParam] = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + } + ] + async def say_hi() -> str | None: """A method that says hi to an assistant and returns the response.""" client = AsyncOpenAI(base_url="http://localhost:8888/v1", api_key="default") - message: ChatCompletionMessageParam = {"role": "user", "content": "Hi!"} chat_completion = await client.chat.completions.create( model="test", messages=[message], + tools=tools, + ) + # test a follow up message + chat_completion = await client.chat.completions.create( + model="test", + messages=[ + message, + { + "role": "assistant", + "content": chat_completion.choices[0].message.content, + }, + message, + ], + tools=tools, + ) + # and another call without tools (should create a new history) + chat_completion = await client.chat.completions.create( + model="test", + messages=[ + message, + { + "role": "assistant", + "content": chat_completion.choices[0].message.content, + }, + message, + { + "role": "assistant", + "content": chat_completion.choices[0].message.content, + }, + message, + ], + ) + # and another call with tools, but limited messages (should create another history) + chat_completion = await client.chat.completions.create( + model="test", + messages=[message], + tools=tools, ) # Add optional ART support with a few lines of code if trajectory := art.auto_trajectory(): @@ -145,6 +198,29 @@ async def say_hi() -> str | None: # Use the capture_auto_trajectory utility to capture a trajectory automatically trajectory = await art.capture_auto_trajectory(say_hi()) assert trajectory.messages_and_choices == [ - {"role": "user", "content": "Hi!"}, + message, + Choice(**mock_response["choices"][0]), + message, + Choice(**mock_response["choices"][0]), + ] + assert trajectory.tools == tools + assert trajectory.additional_histories[0].messages_and_choices == [ + message, + { + "content": "Hello! How can I assist you today?", + "role": "assistant", + }, + message, + { + "content": "Hello! 
How can I assist you today?", + "role": "assistant", + }, + message, + Choice(**mock_response["choices"][0]), + ] + assert trajectory.additional_histories[0].tools is None + assert trajectory.additional_histories[1].messages_and_choices == [ + message, Choice(**mock_response["choices"][0]), ] + assert trajectory.additional_histories[1].tools == tools From 4db00f304390540e9ae719c4c8f01b332d762323 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 22 Aug 2025 03:13:39 +0000 Subject: [PATCH 11/15] feat(auto-trajectory): integrate synchronous chat completion streaming and enhance trajectory handling with tool calls in tests --- dev/data.ipynb | 464 ++++++++++++++++++++++++++--- src/art/auto_trajectory.py | 22 +- src/art/openai.py | 184 +++++++----- tests/unit/test_auto_trajectory.py | 86 +++++- 4 files changed, 631 insertions(+), 125 deletions(-) diff --git a/dev/data.ipynb b/dev/data.ipynb index ea8abbae..2b1671b0 100644 --- a/dev/data.ipynb +++ b/dev/data.ipynb @@ -13,23 +13,199 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "01f78de0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mbradhilton\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.21.0" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/ubuntu/sky_workdir/dev/wandb/run-20250822_022145-test" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Resuming run test to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/bradhilton/tests" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/bradhilton/tests/runs/test" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 08-22 02:21:51 [__init__.py:235] Automatically detected platform cuda.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/sky_workdir/src/art/__init__.py:10: UserWarning: WARNING: Unsloth should be imported before transformers, peft to ensure all optimizations are applied. Your code may run slower or encounter memory issues without these optimizations.\n", + "\n", + "Please restructure your imports with 'import unsloth' at the top of your file.\n", + " import unsloth # type: ignore # noqa: F401\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "๐Ÿฆฅ Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", + "INFO 08-22 02:21:59 [__init__.py:235] Automatically detected platform cuda.\n", + "๐Ÿฆฅ Unsloth Zoo will now patch everything to make training faster!\n", + "Unsloth: Patching vLLM v1 graph capture\n", + "Unsloth: Patching vLLM v0 graph capture\n", + "==((====))== Unsloth 2025.8.6: Fast Qwen2 patching. Transformers: 4.53.2. vLLM: 0.10.0.\n", + " \\\\ /| NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.189 GB. Platform: Linux.\n", + "O^O/ \\_/ \\ Torch: 2.7.1+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.1\n", + "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]\n", + " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", + "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n", + "Unsloth: vLLM loading unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit with actual GPU utilization = 78.47%\n", + "Unsloth: Your GPU has CUDA compute capability 9.0 with VRAM = 79.19 GB.\n", + "Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 32768. Num Sequences = 368.\n", + "Unsloth: vLLM's KV Cache can use up to 56.27 GB. Also swap space = 6 GB.\n", + "Unsloth: Not an error, but `device` is not supported in vLLM. 
Skipping.\n", + "INFO 08-22 02:22:18 [config.py:1604] Using max model len 32768\n", + "Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bit': False, 'load_in_4bit': True, 'bnb_4bit_compute_dtype': 'bfloat16', 'bnb_4bit_quant_storage': 'uint8', 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_use_double_quant': True, 'llm_int8_enable_fp32_cpu_offload': False, 'llm_int8_has_fp16_weight': False, 'llm_int8_skip_modules': ['lm_head', 'multi_modal_projector', 'merger', 'modality_projection', 'model.layers.0.self_attn', 'model.layers.1.self_attn', 'model.layers.2.mlp', 'model.layers.3.mlp', 'model.layers.4.mlp', 'model.layers.25.mlp', 'model.layers.26.mlp'], 'llm_int8_threshold': 6.0}\n", + "INFO 08-22 02:22:18 [llm_engine.py:228] Initializing a V0 LLM engine (v0.10.0) with config: model='unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit', speculative_config=None, tokenizer='unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=bitsandbytes, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=False, use_async_output_proc=True, pooler_config=None, compilation_config={\"level\":0,\"debug_dump_path\":\"\",\"cache_dir\":\"\",\"backend\":\"inductor\",\"custom_ops\":[],\"splitting_ops\":[],\"use_inductor\":true,\"compile_sizes\":[],\"inductor_compile_config\":{\"epilogue_fusion\":true,\"max_autotune\":false,\"shape_padding\":true,\"trace.enabled\":false,\"triton.cudagraphs\":true,\"debug\":false,\"dce\":true,\"memory_planning\":true,\"coordinate_descent_tuning\":true,\"trace.graph_diagram\":false,\"compile_threads\":26,\"group_fusion\":true,\"disable_progress\":false,\"verbose_progress\":true,\"triton.multi_kernel\":0,\"triton.use_block_ptr\":true,\"triton.enable_persistent_tma_matmul\":true,\"triton.autotune_at_compile_time\":false,\"triton.cooperative_reductions\":false,\"cuda.compile_opt_level\":\"-O2\",\"cuda.enable_cuda_lto\":true,\"combo_kernels\":false,\"benchmark_combo_kernel\":true,\"combo_kernel_foreach_dynamic_shapes\":true,\"enable_auto_functionalized_v2\":false},\"inductor_passes\":{},\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":1,\"cudagraph_capture_sizes\":[368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"cudagraph_copy_inputs\":false,\"full_cuda_graph\":false,\"max_capture_size\":368,\"local_cache_dir\":null}, use_cached_outputs=False, \n", + "INFO 08-22 02:22:20 [cuda.py:398] Using Flash Attention backend.\n", + "INFO 08-22 02:22:20 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0\n", + "INFO 08-22 02:22:20 [model_runner.py:1083] Starting to load model unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit...\n", + "INFO 
08-22 02:22:21 [bitsandbytes_loader.py:733] Loading weights with BitsAndBytes quantization. May take a while ...\n", + "INFO 08-22 02:22:21 [weight_utils.py:296] Using model weights format ['*.safetensors']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00 None: messages_and_choices=[], reward=0.0, ) + self.openai_client = OpenAI(api_key="") def __enter__(self) -> None: self.token = auto_trajectory_context_var.set(self) @@ -53,9 +58,20 @@ def handle_httpx_response(self, response: httpx._models.Response) -> None: request_content = json.loads(getattr(response.request, "_content", b"")) messages = request_content["messages"] tools = request_content.get("tools", None) - choice = Choice( - **json.loads(getattr(response, "_content_so_far", b""))["choices"][0] - ) + setattr(response, "_content", getattr(response, "_content_so_far", b"")) + print(getattr(response, "_content")) + if request_content.get("stream", False): + choice = consume_sync_chat_completion_stream( + Stream( + cast_to=ChatCompletionChunk, + response=response, + client=self.openai_client, + ) + ).choices[0] + else: + choice = Choice( + **json.loads(getattr(response, "_content"))["choices"][0] + ) history: Trajectory | History = self.trajectory history_index = -1 while True: diff --git a/src/art/openai.py b/src/art/openai.py index 9a2811b1..dd9a32e6 100644 --- a/src/art/openai.py +++ b/src/art/openai.py @@ -1,7 +1,7 @@ from typing import Any, AsyncIterator, Callable, cast import openai -from openai import AsyncStream +from openai import AsyncStream, Stream from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_message import ( @@ -82,86 +82,8 @@ async def consume_chat_completion_stream( chat_completion: ChatCompletion | None = None async for chunk in stream: if chat_completion is None: - chat_completion = ChatCompletion( - id=chunk.id, - choices=[ - Choice( - finish_reason="stop", - index=choice.index, - logprobs=(ChoiceLogprobs() if choice.logprobs else None), - message=ChatCompletionMessage(role="assistant"), - ) - for choice in chunk.choices - ], - created=chunk.created, - model=chunk.model, - object="chat.completion", - ) - for choice, chunk_choice in zip(chat_completion.choices, chunk.choices): - choice.finish_reason = chunk_choice.finish_reason or "stop" - if chunk_choice.logprobs: - if choice.logprobs is None: - choice.logprobs = ChoiceLogprobs() - if chunk_choice.logprobs.content: - if choice.logprobs.content is None: - choice.logprobs.content = [] - choice.logprobs.content.extend(chunk_choice.logprobs.content) - if chunk_choice.logprobs.refusal: - if choice.logprobs.refusal is None: - choice.logprobs.refusal = [] - choice.logprobs.refusal.extend(chunk_choice.logprobs.refusal) - if chunk_choice.delta.content: - if choice.message.content is None: - choice.message.content = "" - choice.message.content += chunk_choice.delta.content - if chunk_choice.delta.refusal: - if choice.message.refusal is None: - choice.message.refusal = "" - choice.message.refusal += chunk_choice.delta.refusal - if chunk_choice.delta.function_call: - if choice.message.function_call is None: - choice.message.function_call = FunctionCall(arguments="", name="") - choice.message.function_call.name += ( - chunk_choice.delta.function_call.name or "" - ) - choice.message.function_call.arguments += ( - 
-                    chunk_choice.delta.function_call.arguments or ""
-                )
-            if chunk_choice.delta.tool_calls:
-                if choice.message.tool_calls is None:
-                    choice.message.tool_calls = []
-                for tool_call in chunk_choice.delta.tool_calls:
-                    while tool_call.index not in range(len(choice.message.tool_calls)):
-                        choice.message.tool_calls.append(
-                            ChatCompletionMessageToolCall(
-                                id="",
-                                function=Function(arguments="", name=""),
-                                type="function",
-                            )
-                        )
-                    if tool_call.id:
-                        choice.message.tool_calls[tool_call.index].id = tool_call.id
-                    if tool_call.function:
-                        if tool_call.function.name:
-                            choice.message.tool_calls[
-                                tool_call.index
-                            ].function.name = tool_call.function.name
-                        if tool_call.function.arguments:
-                            choice.message.tool_calls[
-                                tool_call.index
-                            ].function.arguments += tool_call.function.arguments
-            if getattr(chunk_choice.delta, "reasoning", None):
-                if not hasattr(choice.message, "reasoning"):
-                    setattr(choice.message, "reasoning", "")
-                setattr(
-                    choice.message,
-                    "reasoning",
-                    getattr(choice.message, "reasoning")
-                    + getattr(chunk_choice.delta, "reasoning"),
-                )
-        chat_completion.service_tier = chunk.service_tier
-        chat_completion.system_fingerprint = chunk.system_fingerprint
-        chat_completion.usage = chunk.usage
+            chat_completion = init_chat_completion(chunk)
+        update_chat_completion(chat_completion, chunk)
         if on_chunk:
             try:
                 on_chunk(chunk, chat_completion)
@@ -170,3 +92,103 @@ async def consume_chat_completion_stream(
             break
     assert chat_completion is not None
     return chat_completion
+
+
+def consume_sync_chat_completion_stream(
+    stream: Stream[ChatCompletionChunk],
+) -> ChatCompletion:
+    chat_completion: ChatCompletion | None = None
+    for chunk in stream:
+        if chat_completion is None:
+            chat_completion = init_chat_completion(chunk)
+        update_chat_completion(chat_completion, chunk)
+    assert chat_completion is not None
+    return chat_completion
+
+
+def init_chat_completion(chunk: ChatCompletionChunk) -> ChatCompletion:
+    return ChatCompletion(
+        id=chunk.id,
+        choices=[
+            Choice(
+                finish_reason="stop",
+                index=choice.index,
+                logprobs=(ChoiceLogprobs() if choice.logprobs else None),
+                message=ChatCompletionMessage(role="assistant"),
+            )
+            for choice in chunk.choices
+        ],
+        created=chunk.created,
+        model=chunk.model,
+        object="chat.completion",
+    )
+
+
+def update_chat_completion(
+    chat_completion: ChatCompletion, chunk: ChatCompletionChunk
+) -> None:
+    for choice, chunk_choice in zip(chat_completion.choices, chunk.choices):
+        choice.finish_reason = chunk_choice.finish_reason or "stop"
+        if chunk_choice.logprobs:
+            if choice.logprobs is None:
+                choice.logprobs = ChoiceLogprobs()
+            if chunk_choice.logprobs.content:
+                if choice.logprobs.content is None:
+                    choice.logprobs.content = []
+                choice.logprobs.content.extend(chunk_choice.logprobs.content)
+            if chunk_choice.logprobs.refusal:
+                if choice.logprobs.refusal is None:
+                    choice.logprobs.refusal = []
+                choice.logprobs.refusal.extend(chunk_choice.logprobs.refusal)
+        if chunk_choice.delta.content:
+            if choice.message.content is None:
+                choice.message.content = ""
+            choice.message.content += chunk_choice.delta.content
+        if chunk_choice.delta.refusal:
+            if choice.message.refusal is None:
+                choice.message.refusal = ""
+            choice.message.refusal += chunk_choice.delta.refusal
+        if chunk_choice.delta.function_call:
+            if choice.message.function_call is None:
+                choice.message.function_call = FunctionCall(arguments="", name="")
+            choice.message.function_call.name += (
+                chunk_choice.delta.function_call.name or ""
+            )
+            choice.message.function_call.arguments += (
+                chunk_choice.delta.function_call.arguments or ""
+            )
+        if chunk_choice.delta.tool_calls:
+            if choice.message.tool_calls is None:
+                choice.message.tool_calls = []
+            for tool_call in chunk_choice.delta.tool_calls:
+                while tool_call.index not in range(len(choice.message.tool_calls)):
+                    choice.message.tool_calls.append(
+                        ChatCompletionMessageToolCall(
+                            id="",
+                            function=Function(arguments="", name=""),
+                            type="function",
+                        )
+                    )
+                if tool_call.id:
+                    choice.message.tool_calls[tool_call.index].id = tool_call.id
+                if tool_call.function:
+                    if tool_call.function.name:
+                        choice.message.tool_calls[
+                            tool_call.index
+                        ].function.name = tool_call.function.name
+                    if tool_call.function.arguments:
+                        choice.message.tool_calls[
+                            tool_call.index
+                        ].function.arguments += tool_call.function.arguments
+        if getattr(chunk_choice.delta, "reasoning", None):
+            if not hasattr(choice.message, "reasoning"):
+                setattr(choice.message, "reasoning", "")
+            setattr(
+                choice.message,
+                "reasoning",
+                getattr(choice.message, "reasoning")
+                + getattr(chunk_choice.delta, "reasoning"),
+            )
+    chat_completion.service_tier = chunk.service_tier
+    chat_completion.system_fingerprint = chunk.system_fingerprint
+    chat_completion.usage = chunk.usage
diff --git a/tests/unit/test_auto_trajectory.py b/tests/unit/test_auto_trajectory.py
index 77b7d421..2c72cb48 100644
--- a/tests/unit/test_auto_trajectory.py
+++ b/tests/unit/test_auto_trajectory.py
@@ -106,13 +106,83 @@
     "prompt_logprobs": None,
     "kv_transfer_params": None,
 }
+mock_stream_response = b"""data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"id":"chatcmpl-tool-29e663261e524fcfa2162f4f3d76a7f0","type":"function","index":0,"function":{"name":"get_current_weather","arguments":"{"}}]},"logprobs":{"content":[{"token":"token_id:314","logprob":-0.00015293381875380874,"bytes":[32,123],"top_logprobs":[]}]},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"id":"chatcmpl-tool-29e663261e524fcfa2162f4f3d76a7f0","type":"function","index":0,"function":{"name":"get_current_weather","arguments":"{"}}]},"logprobs":{"content":[{"token":"token_id:314","logprob":-0.00015293381875380874,"bytes":[32,123],"top_logprobs":[]}]},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":null,"arguments":"}"}}]},"logprobs":{"content":[{"token":"token_id:3417","logprob":-3.576278118089249e-7,"bytes":[125,125],"top_logprobs":[]}]},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":null,"arguments":"}"}}]},"logprobs":{"content":[{"token":"token_id:3417","logprob":-3.576278118089249e-7,"bytes":[125,125],"top_logprobs":[]}]},"finish_reason":null}]}
+
+data: [DONE]
+
+data: [DONE]
+
+"""
+mock_stream_choice = Choice(
+    **{
+        "finish_reason": "stop",
+        "index": 0,
+        "logprobs": {
+            "content": [
+                {
+                    "token": "token_id:314",
+                    "bytes": [32, 123],
+                    "logprob": -0.00015293381875380874,
+                    "top_logprobs": [],
+                },
+                {
+                    "token": "token_id:314",
+                    "bytes": [32, 123],
+                    "logprob": -0.00015293381875380874,
+                    "top_logprobs": [],
+                },
+                {
+                    "token": "token_id:3417",
+                    "bytes": [125, 125],
+                    "logprob": -3.576278118089249e-07,
+                    "top_logprobs": [],
+                },
+                {
+                    "token": "token_id:3417",
+                    "bytes": [125, 125],
+                    "logprob": -3.576278118089249e-07,
+                    "top_logprobs": [],
+                },
+            ],
+            "refusal": None,
+        },
+        "message": {
+            "content": None,
+            "refusal": None,
+            "role": "assistant",
+            "annotations": None,
+            "audio": None,
+            "function_call": None,
+            "tool_calls": [
+                {
+                    "id": "chatcmpl-tool-29e663261e524fcfa2162f4f3d76a7f0",
+                    "function": {"arguments": "{{}}", "name": "get_current_weather"},
+                    "type": "function",
+                }
+            ],
+        },
+    }
+)
 
 
 @pytest_asyncio.fixture
 async def test_server():
     """Start a test server for the module."""
 
-    async def handler(_: web.Request) -> web.Response:
+    async def handler(request: web.Request) -> web.Response:
+        body = await request.json()
+        if body.get("stream", False):
+            return web.Response(body=mock_stream_response)
         return web.json_response(mock_response)
 
     app = web.Application()
@@ -190,6 +260,15 @@ async def say_hi() -> str | None:
             messages=[message],
             tools=tools,
         )
+        # and another call with tool_choice="required" & stream=True
+        async for _ in await client.chat.completions.create(
+            model="test",
+            messages=[message],
+            tool_choice="required",
+            tools=tools,
+            stream=True,
+        ):
+            pass
         # Add optional ART support with a few lines of code
         if trajectory := art.auto_trajectory():
             trajectory.reward = 1.0
@@ -224,3 +303,8 @@ async def say_hi() -> str | None:
         Choice(**mock_response["choices"][0]),
     ]
     assert trajectory.additional_histories[1].tools == tools
+    assert trajectory.additional_histories[2].messages_and_choices == [
+        message,
+        mock_stream_choice,
+    ]
+    assert trajectory.additional_histories[2].tools == tools

From 4392bcf161cb72b48f0c7838775c4c9093e37e0d Mon Sep 17 00:00:00 2001
From: Brad Hilton
Date: Fri, 22 Aug 2025 03:14:56 +0000
Subject: [PATCH 12/15] refactor(tests): update comment for optional ART support in auto_trajectory test

---
 tests/unit/test_auto_trajectory.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/test_auto_trajectory.py b/tests/unit/test_auto_trajectory.py
index 2c72cb48..8dc2af5e 100644
--- a/tests/unit/test_auto_trajectory.py
+++ b/tests/unit/test_auto_trajectory.py
@@ -269,7 +269,7 @@ async def say_hi() -> str | None:
             stream=True,
         ):
             pass
-        # Add optional ART support with a few lines of code
+        # Add ART support with a couple lines of optional code
         if trajectory := art.auto_trajectory():
             trajectory.reward = 1.0
         return chat_completion.choices[0].message.content

From 3349d7f3970714e1d6e0b27b817cc54b2e029407 Mon Sep 17 00:00:00 2001
From: Brad Hilton
Date: Fri, 22 Aug 2025 14:53:07 +0000
Subject: [PATCH 13/15] feat(tests): add unit tests for litellm auto trajectory handling and suppress pydantic warnings

---
 tests/unit/test_auto_trajectory.py | 85 ++++++++++++++++++++++++++++++
 1 file changed, 85 insertions(+)

diff --git a/tests/unit/test_auto_trajectory.py b/tests/unit/test_auto_trajectory.py
index 8dc2af5e..9a4d19a5 100644
--- a/tests/unit/test_auto_trajectory.py
+++ b/tests/unit/test_auto_trajectory.py
@@ -1,3 +1,12 @@
+import warnings
+
+# Suppress pydantic warnings at module level
+warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic")
+
+import litellm
+import litellm.litellm_core_utils.streaming_handler
+import litellm.types.utils
+import pytest
 import pytest_asyncio
 from aiohttp import web
 from openai import AsyncOpenAI
@@ -6,6 +15,7 @@
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
 
 import art
+from art.utils.litellm import convert_litellm_choice_to_openai
 
 mock_response = {
     "id": "chatcmpl-293ce9f37dba40e5be39448acaf6fb49",
@@ -308,3 +318,78 @@ async def say_hi() -> str | None:
         mock_stream_choice,
     ]
     assert trajectory.additional_histories[2].tools == tools
+
+
+@pytest.mark.filterwarnings("ignore::UserWarning:pydantic")
+async def test_litellm_auto_trajectory(test_server: None) -> None:
+    message: ChatCompletionMessageParam = {"role": "user", "content": "Hi!"}
+    tools: list[ChatCompletionToolParam] = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {},
+                    "required": [],
+                },
+            },
+        }
+    ]
+
+    async def say_hi() -> str | None:
+        """A method that says hi to an assistant and returns the response."""
+        response = await litellm.acompletion(
+            model="openai/test",
+            messages=[message],
+            tools=tools,
+            base_url="http://localhost:8888/v1",
+        )
+        assert isinstance(response, litellm.types.utils.ModelResponse)
+        choice = convert_litellm_choice_to_openai(response.choices[0])
+        # follow up message
+        response = await litellm.acompletion(
+            model="openai/test",
+            messages=[
+                message,
+                {"role": "assistant", "content": choice.message.content},
+                message,
+            ],
+            tools=tools,
+            base_url="http://localhost:8888/v1",
+        )
+        assert isinstance(response, litellm.types.utils.ModelResponse)
+        choice = convert_litellm_choice_to_openai(response.choices[0])
+        # another call with tool_choice="required" & stream=True
+        stream = await litellm.acompletion(
+            model="openai/test",
+            messages=[message],
+            tool_choice="required",
+            tools=tools,
+            stream=True,
+            base_url="http://localhost:8888/v1",
+        )
+        assert isinstance(
+            stream, litellm.litellm_core_utils.streaming_handler.CustomStreamWrapper
+        )
+        async for _ in stream:
+            pass
+        # Add ART support with a couple lines of optional code
+        if trajectory := art.auto_trajectory():
+            trajectory.reward = 1.0
+        return choice.message.content
+
+    # Use the capture_auto_trajectory utility to capture a trajectory automatically
+    trajectory = await art.capture_auto_trajectory(say_hi())
+    assert trajectory.messages_and_choices == [
+        message,
+        Choice(**mock_response["choices"][0]),
+        message,
+        Choice(**mock_response["choices"][0]),
+    ]
+    assert trajectory.additional_histories[0].messages_and_choices == [
+        message,
+        mock_stream_choice,
+    ]
+    assert trajectory.additional_histories[0].tools == tools

From 07d92f2ac3e5eedac10ff5dc961e577650c4e7b8 Mon Sep 17 00:00:00 2001
From: Brad Hilton
Date: Fri, 22 Aug 2025 14:59:07 +0000
Subject: [PATCH 14/15] feat(tests): add api_key parameter to litellm auto trajectory tests for improved configuration

---
 tests/unit/test_auto_trajectory.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/unit/test_auto_trajectory.py b/tests/unit/test_auto_trajectory.py
index 9a4d19a5..af170ad7 100644
--- a/tests/unit/test_auto_trajectory.py
+++ b/tests/unit/test_auto_trajectory.py
@@ -345,6 +345,7 @@ async def say_hi() -> str | None:
             messages=[message],
             tools=tools,
             base_url="http://localhost:8888/v1",
+            api_key="default",
         )
         assert isinstance(response, litellm.types.utils.ModelResponse)
         choice = convert_litellm_choice_to_openai(response.choices[0])
@@ -358,6 +359,7 @@ async def say_hi() -> str | None:
             ],
             tools=tools,
             base_url="http://localhost:8888/v1",
+            api_key="default",
         )
         assert isinstance(response, litellm.types.utils.ModelResponse)
         choice = convert_litellm_choice_to_openai(response.choices[0])
@@ -369,6 +371,7 @@
             tools=tools,
             stream=True,
             base_url="http://localhost:8888/v1",
+            api_key="default",
         )
         assert isinstance(
             stream, litellm.litellm_core_utils.streaming_handler.CustomStreamWrapper

From 93e5ffd00e4e569dbca73d651bce14a843497176 Mon Sep 17 00:00:00 2001
From: Brad Hilton
Date: Fri, 22 Aug 2025 15:03:08 +0000
Subject: [PATCH 15/15] refactor(gather): simplify type imports and clean up after_each parameter definition

---
 src/art/gather.py | 21 +++++----------------
 1 file changed, 5 insertions(+), 16 deletions(-)

diff --git a/src/art/gather.py b/src/art/gather.py
index ff9adbb3..89eea5b4 100644
--- a/src/art/gather.py
+++ b/src/art/gather.py
@@ -3,16 +3,7 @@
 import contextvars
 from collections import Counter
 from dataclasses import dataclass, field
-from typing import (
-    Any,
-    Awaitable,
-    Callable,
-    Coroutine,
-    Iterable,
-    Iterator,
-    Literal,
-    overload,
-)
+from typing import Awaitable, Callable, Iterable, Iterator, Literal, overload
 
 from openai.types.chat.chat_completion import Choice
 from tqdm import auto as tqdm
@@ -27,12 +18,10 @@ async def gather_trajectory_groups(
     pbar_total_completion_tokens: bool = True,
     max_exceptions: int | float = 0,
     max_metrics: int | None = None,
-    after_each: (
-        Callable[
-            [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]]
-        ]
-        | None
-    ) = None,
+    after_each: Callable[
+        [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]]
+    ]
+    | None = None,
 ) -> list[TrajectoryGroup]:
     groups = list(groups)
     context = GatherContext(