Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions nerve/gateway/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ def get_engine() -> AgentEngine:
return _engine


async def _send_session_status(
websocket: WebSocket,
session_id: str,
is_running: bool,
session_record: dict | None,
) -> None:
"""Send a ``session_status`` event to the freshly-bound listener.

Called from the initial WS handshake (only when a turn is in flight so an
idle client doesn't get a no-op message) and from ``switch_session``
(always, to refresh client-side ``is_running``/``status``). When the
session is running, the accumulated stream buffer is attached so the
client can rebuild ``streamingBlocks``, panels, todos, and interaction
state without waiting for new events.
"""
status_msg: dict = {
"type": "session_status",
"session_id": session_id,
"is_running": is_running,
"status": session_record.get("status") if session_record else "unknown",
}
if is_running:
status_msg["buffered_events"] = broadcaster.get_buffer(session_id)
await websocket.send_json(status_msg)


@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan — initialize DB, engine, channels on startup."""
Expand Down Expand Up @@ -310,6 +336,18 @@ async def ws_broadcast(session_id: str, message: dict):
"session_id": active_session,
})

# If a turn is mid-flight (page reload, transient WS drop, sticky
# reconnect after a network blip), replay the broadcaster buffer so
# the freshly-bound listener can rebuild the in-flight stream
# without waiting for new events. Idle sessions get nothing here;
# they hydrate via REST + the existing ``session_switched`` event.
if broadcaster.is_buffering(active_session):
is_running = _engine.is_session_running(active_session)
session_record = await _engine.db.get_session(active_session)
await _send_session_status(
websocket, active_session, is_running, session_record,
)

try:
while True:
data = await websocket.receive_json()
Expand Down Expand Up @@ -368,18 +406,16 @@ async def ws_broadcast(session_id: str, message: dict):
# Persist channel mapping so next page load resumes this session
await router.switch_session("web:default", new_session)

# Send session status (running/idle + buffered events for reconnect)
# Send session status (running/idle + buffered events for
# reconnect). Unlike the initial-bind branch, we always
# ship a status here so the client can flip its
# ``isStreaming`` / ``status`` for the newly-selected
# session even when the session is idle.
is_running = _engine.is_session_running(new_session)
session_record = await _engine.db.get_session(new_session)
status_msg: dict = {
"type": "session_status",
"session_id": new_session,
"is_running": is_running,
"status": session_record.get("status") if session_record else "unknown",
}
if is_running:
status_msg["buffered_events"] = broadcaster.get_buffer(new_session)
await websocket.send_json(status_msg)
await _send_session_status(
websocket, new_session, is_running, session_record,
)

await websocket.send_json({
"type": "session_switched",
Expand Down
244 changes: 244 additions & 0 deletions tests/test_gateway_ws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
"""Tests for nerve.gateway.server WebSocket handshake buffer replay.

The initial WS handshake replays the broadcaster buffer when a turn is in
flight, and stays silent when the session is idle. The existing
``switch_session`` path is covered too so the refactor onto
``_send_session_status`` doesn't regress.
"""

from __future__ import annotations

import pytest

from nerve.agent.streaming import StreamBroadcaster, broadcaster as _global_broadcaster
from nerve.gateway.server import _send_session_status


class FakeWebSocket:
"""Minimal WebSocket stand-in that captures ``send_json`` payloads."""

def __init__(self) -> None:
self.sent: list[dict] = []

async def send_json(self, payload: dict) -> None:
self.sent.append(payload)


@pytest.fixture(autouse=True)
def _reset_broadcaster_buffers():
"""Clear the module-global broadcaster between tests.

Tests poke ``_global_broadcaster.start_buffering`` directly because the
helper reads ``broadcaster.get_buffer`` off the module global, not a
parameter. Reset before and after so a failing test can't leak state.
"""
_global_broadcaster._session_buffers.clear()
yield
_global_broadcaster._session_buffers.clear()


# ---------------------------------------------------------------------------
# _send_session_status helper
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestSendSessionStatus:
"""Unit tests for the shared helper called from both WS branches."""

async def test_running_session_attaches_buffered_events(self):
ws = FakeWebSocket()
session_id = "sess-running"
_global_broadcaster.start_buffering(session_id)
await _global_broadcaster.broadcast(session_id, {
"type": "token", "session_id": session_id, "content": "hello ",
})
await _global_broadcaster.broadcast(session_id, {
"type": "token", "session_id": session_id, "content": "world",
})

await _send_session_status(
ws, session_id, is_running=True,
session_record={"status": "active"},
)

assert len(ws.sent) == 1
msg = ws.sent[0]
assert msg["type"] == "session_status"
assert msg["session_id"] == session_id
assert msg["is_running"] is True
assert msg["status"] == "active"
assert "buffered_events" in msg
contents = [e["content"] for e in msg["buffered_events"]]
assert contents == ["hello ", "world"]

async def test_idle_session_omits_buffered_events(self):
ws = FakeWebSocket()
# No start_buffering: the buffer is empty / absent.
await _send_session_status(
ws, "sess-idle", is_running=False,
session_record={"status": "active"},
)

assert len(ws.sent) == 1
msg = ws.sent[0]
assert msg["is_running"] is False
assert msg["status"] == "active"
# buffered_events MUST be absent when not running; the frontend
# gates buffer replay on its presence, not on length.
assert "buffered_events" not in msg

async def test_missing_session_record_uses_unknown_status(self):
ws = FakeWebSocket()
await _send_session_status(
ws, "sess-gone", is_running=False, session_record=None,
)

assert ws.sent[0]["status"] == "unknown"

async def test_running_session_with_empty_buffer_still_attaches_list(self):
"""is_running gates ``buffered_events``; an empty list is still a signal.

Frontend code branches on ``msg.buffered_events !== undefined``;
shipping an empty list tells the client "this session is running
but the stream has produced nothing yet" so it can flip
``isStreaming`` without inventing fake blocks.
"""
ws = FakeWebSocket()
session_id = "sess-running-empty"
_global_broadcaster.start_buffering(session_id)

await _send_session_status(
ws, session_id, is_running=True,
session_record={"status": "active"},
)

msg = ws.sent[0]
assert msg["buffered_events"] == []


# ---------------------------------------------------------------------------
# Initial-bind handshake (AC15)
#
# The actual handler is a closure inside ``create_app`` so it can't be unit-
# tested without spinning a full FastAPI app + lifespan. We re-exercise the
# *same logic* it runs (``is_buffering`` gate + helper invocation) so that a
# regression in either guard or call shape fails this test.
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestInitialBindReplay:

async def _simulate_initial_bind(
self,
ws: FakeWebSocket,
session_id: str,
is_running: bool,
session_record: dict | None,
) -> None:
"""Mirror the gate + helper call from the WS handshake."""
if _global_broadcaster.is_buffering(session_id):
await _send_session_status(
ws, session_id, is_running, session_record,
)

async def test_replays_when_turn_in_flight(self):
ws = FakeWebSocket()
session_id = "sess-in-flight"
_global_broadcaster.start_buffering(session_id)
await _global_broadcaster.broadcast(session_id, {
"type": "tool_use", "session_id": session_id,
"tool": "Read", "input": {"file_path": "/x"},
})

await self._simulate_initial_bind(
ws, session_id, is_running=True,
session_record={"status": "active"},
)

assert len(ws.sent) == 1
msg = ws.sent[0]
assert msg["type"] == "session_status"
assert msg["is_running"] is True
assert msg["buffered_events"][0]["tool"] == "Read"

async def test_no_replay_when_session_idle(self):
ws = FakeWebSocket()
# No start_buffering: handshake guard must short-circuit.
await self._simulate_initial_bind(
ws, "sess-idle", is_running=False,
session_record={"status": "active"},
)

assert ws.sent == []


# ---------------------------------------------------------------------------
# switch_session regression guard (existing behaviour)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestSwitchSessionStillReplays:
"""``switch_session`` ALWAYS sends ``session_status`` (running or idle)."""

async def _simulate_switch_session(
self,
ws: FakeWebSocket,
new_session: str,
is_running: bool,
session_record: dict | None,
) -> None:
await _send_session_status(
ws, new_session, is_running, session_record,
)

async def test_running_target_replays_buffer(self):
ws = FakeWebSocket()
session_id = "sess-switch-running"
_global_broadcaster.start_buffering(session_id)
await _global_broadcaster.broadcast(session_id, {
"type": "token", "session_id": session_id, "content": "x",
})

await self._simulate_switch_session(
ws, session_id, is_running=True,
session_record={"status": "active"},
)

assert ws.sent[0]["is_running"] is True
assert ws.sent[0]["buffered_events"] == [
{"type": "token", "session_id": session_id, "content": "x"},
]

async def test_idle_target_still_sends_status(self):
ws = FakeWebSocket()
await self._simulate_switch_session(
ws, "sess-switch-idle", is_running=False,
session_record={"status": "active"},
)

assert len(ws.sent) == 1
assert ws.sent[0]["is_running"] is False
assert "buffered_events" not in ws.sent[0]


# ---------------------------------------------------------------------------
# Buffer fidelity: large stream survives intact through replay
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_replay_preserves_event_order_under_load():
"""Replay must hand events back in arrival order with no truncation."""
bc = StreamBroadcaster(max_buffer_size=100)
bc.start_buffering("sess-load")
for i in range(50):
await bc.broadcast("sess-load", {
"type": "token", "session_id": "sess-load", "content": f"#{i}",
})

snapshot = bc.get_buffer("sess-load")
assert len(snapshot) == 50
assert [e["content"] for e in snapshot] == [f"#{i}" for i in range(50)]
Loading