Skip to content
28 changes: 26 additions & 2 deletions amplifier_app_cli/session_spawner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

from amplifier_core import AmplifierSession
from amplifier_foundation import generate_sub_session_id
from amplifier_foundation import bridge_child_cost

from .agent_config import merge_configs

logger = logging.getLogger(__name__)


# Capture default sys.path entries at import time.
# Used to filter out bundle-added paths when forwarding sys_paths to subprocess children.
# Stored as a frozenset, so this is an immutable snapshot: entries appended to
# sys.path after this module is imported are not part of it.
_DEFAULT_SYS_PATHS: frozenset[str] = frozenset(sys.path)
Expand Down Expand Up @@ -688,6 +690,13 @@ async def _capture_completion(event: str, data: dict) -> HookResult:
store.save(sub_session_id, transcript, metadata)
logger.debug(f"Sub-session {sub_session_id} state persisted")

# Bridge child session costs to parent coordinator (bridge_child_cost never raises)
await bridge_child_cost(
child_coordinator=child_session.coordinator,
parent_coordinator=parent_session.coordinator,
child_session_id=sub_session_id,
)

finally:
# Unregister child cancellation token before cleanup
# MUST run even if execution was cancelled (CancelledError) or failed
Expand All @@ -714,7 +723,11 @@ async def _capture_completion(event: str, data: dict) -> HookResult:
}


async def resume_sub_session(sub_session_id: str, instruction: str, parent_session: AmplifierSession | None = None) -> dict:
async def resume_sub_session(
sub_session_id: str,
instruction: str,
parent_session: AmplifierSession | None = None,
) -> dict:
"""Resume existing sub-session for multi-turn engagement.

Loads previously saved sub-session state, recreates the session with
Expand Down Expand Up @@ -1013,10 +1026,21 @@ async def _capture_completion(event: str, data: dict) -> HookResult:
f"Sub-session {sub_session_id} state updated (turn {metadata['turn_count']})"
)

# Bridge child session costs to parent coordinator (bridge_child_cost never raises)
if parent_session is not None:
await bridge_child_cost(
child_coordinator=child_session.coordinator,
parent_coordinator=parent_session.coordinator,
child_session_id=sub_session_id,
)

finally:
# Unregister child cancellation token before cleanup
# MUST run even if execution was cancelled (CancelledError) or failed
if resume_parent_cancellation is not None and resume_child_cancellation is not None:
if (
resume_parent_cancellation is not None
and resume_child_cancellation is not None
):
resume_parent_cancellation.unregister_child(resume_child_cancellation)
logger.debug(
f"Unregistered child cancellation token for resumed sub-session {sub_session_id}"
Expand Down
222 changes: 222 additions & 0 deletions tests/test_cost_bridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
"""Tests for spawn cost bridge helpers.

sum_cost_usd and bridge_child_cost are provided by amplifier_foundation (implemented
as _sum_cost_usd and _bridge_child_cost in amplifier_foundation.bundle._prepared) and
are imported here from the package's top level (app-cli delegates, not reimplements).
"""

from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock

import pytest

from amplifier_foundation import bridge_child_cost, sum_cost_usd


def test_sums_single_contribution():
    """A lone contribution's cost is returned as the total."""
    contributions = [{"cost_usd": Decimal("0.05")}]
    assert sum_cost_usd(contributions) == Decimal("0.05")


def test_sums_multiple_contributions():
    """Costs from several contributions are added together."""
    amounts = ("0.03", "0.05", "0.01")
    contributions = [{"cost_usd": Decimal(amount)} for amount in amounts]
    assert sum_cost_usd(contributions) == Decimal("0.09")


def test_returns_none_for_empty_list():
    """An empty contribution list produces None."""
    assert sum_cost_usd([]) is None


def test_returns_none_when_all_none():
    """None is returned when no entry carries a cost value."""
    costless_entries = [{"cost_usd": None}, None, {}]
    assert sum_cost_usd(costless_entries) is None


def test_accepts_string_cost_usd():
    """A string cost_usd is accepted and the total comes back as a Decimal."""
    string_cost_entry = {"cost_usd": "0.05"}
    total = sum_cost_usd([string_cost_entry])
    assert total == Decimal("0.05")
    assert isinstance(total, Decimal)


def test_skips_none_entries_in_mixed_list():
    """None entries and None costs are ignored; the remaining costs are summed."""
    mixed_entries = [
        {"cost_usd": Decimal("0.03")},
        None,
        {"cost_usd": None},
        {"cost_usd": Decimal("0.02")},
    ]
    assert sum_cost_usd(mixed_entries) == Decimal("0.05")


@pytest.mark.asyncio
async def test_spawn_bridge_registers_child_cost_on_parent():
    """Bridging a child that has cost registers a delegate contributor on the parent.

    Calls bridge_child_cost directly with mocked coordinators (the same call the
    spawn path makes after the sub-session state is persisted).
    """
    child = MagicMock()
    child.collect_contributions = AsyncMock(
        return_value=[{"cost_usd": Decimal("0.07")}]
    )

    # Record each registration under its (channel, name) key so the callback
    # can be looked up and invoked below.
    captured = {}

    def _record(channel, name, callback):
        captured[(channel, name)] = callback

    parent = MagicMock()
    parent.register_contributor = _record

    await bridge_child_cost(
        child_coordinator=child,
        parent_coordinator=parent,
        child_session_id="test-child-123",
    )

    expected_key = ("session.cost", "delegate:test-child-123")
    assert expected_key in captured
    assert captured[expected_key]() == {"cost_usd": Decimal("0.07")}


@pytest.mark.asyncio
async def test_bridge_swallows_exception_and_logs():
    """bridge_child_cost does not propagate a failure raised while collecting costs.

    NOTE(review): only the non-raising behavior and the absence of a registration
    are asserted; the warning log implied by the test name is not checked here.
    """
    # Simulate a failure inside collect_contributions
    failing_child = MagicMock()
    failing_child.collect_contributions = AsyncMock(
        side_effect=RuntimeError("simulated")
    )

    parent = MagicMock()
    parent.register_contributor = MagicMock()

    # An exception escaping this await would fail the test.
    await bridge_child_cost(
        child_coordinator=failing_child,
        parent_coordinator=parent,
        child_session_id="test-child-err",
    )

    # The bridge failed before it could register anything on the parent.
    parent.register_contributor.assert_not_called()


@pytest.mark.asyncio
async def test_spawn_bridge_skips_registration_when_no_cost():
    """A child reporting no contributions leaves the parent without a contributor."""
    costless_child = MagicMock()
    costless_child.collect_contributions = AsyncMock(return_value=[])

    parent = MagicMock()
    parent.register_contributor = MagicMock()

    await bridge_child_cost(
        child_coordinator=costless_child,
        parent_coordinator=parent,
        child_session_id="test-child-456",
    )

    parent.register_contributor.assert_not_called()


@pytest.mark.asyncio
async def test_resume_bridge_registers_child_cost_on_parent():
    """Bridging a resumed child registers its delegate contributor on the parent.

    Calls bridge_child_cost directly with mocked coordinators, mirroring the call
    resume_sub_session makes when a parent session is present.
    """
    child = MagicMock()
    child.collect_contributions = AsyncMock(
        return_value=[{"cost_usd": Decimal("0.04")}]
    )

    # Record each registration under its (channel, name) key.
    captured = {}

    def _record(channel, name, callback):
        captured[(channel, name)] = callback

    parent = MagicMock()
    parent.register_contributor = _record

    await bridge_child_cost(
        child_coordinator=child,
        parent_coordinator=parent,
        child_session_id="resumed-child-789",
    )

    expected_key = ("session.cost", "delegate:resumed-child-789")
    assert expected_key in captured


@pytest.mark.asyncio
async def test_resume_bridge_accumulates_incremental_costs():
    """Two resumes of one sub-session each contribute their incremental cost once.

    Every resume_sub_session call builds a FRESH child coordinator, and the provider
    re-mounts from zero, so the child's session.cost channel holds only the cost of
    THAT resume's turns rather than the full session history. bridge_child_cost
    therefore forwards an incremental amount for each resume.

    In amplifier-core, register_contributor APPENDS (coordinator.rs: .push(entry))
    and does NOT overwrite an entry with a duplicate name, so collect_contributions
    returns both entries and sum_cost_usd adds them up.

    Properties checked:
    - Both registrations use the same (channel, name) key.
    - Each callback reports only the incremental cost of its own resume.
    - sum_cost_usd over both callback results equals first + second cost
      (no double-counting).
    """
    registrations: list[tuple] = []

    def _record(channel, name, callback):
        registrations.append((channel, name, callback))

    parent = MagicMock()
    parent.register_contributor = _record

    # Each resume gets a fresh child coordinator that accumulated only that
    # turn's cost: $0.04 for the first resume, $0.06 for the second.
    for turn_cost in (Decimal("0.04"), Decimal("0.06")):
        resumed_child = MagicMock()
        resumed_child.collect_contributions = AsyncMock(
            return_value=[{"cost_usd": turn_cost}]
        )
        await bridge_child_cost(
            child_coordinator=resumed_child,
            parent_coordinator=parent,
            child_session_id="test-child-xyz",
        )

    assert len(registrations) == 2, "Expected exactly two register_contributor calls"

    first_call, second_call = registrations
    first_channel, first_name, first_cb = first_call
    second_channel, second_name, second_cb = second_call

    # Same channel + name on both calls: no key uniqueness is required because
    # both entries are kept and summed.
    assert first_channel == second_channel == "session.cost"
    assert first_name == second_name == "delegate:test-child-xyz"

    # Each callback carries only its own resume's cost.
    assert first_cb()["cost_usd"] == Decimal("0.04")
    assert second_cb()["cost_usd"] == Decimal("0.06")

    # Simulate what collecting both contributions and summing them produces:
    # $0.10 in total, with nothing counted twice.
    total = sum_cost_usd([first_cb(), second_cb()])
    assert total == Decimal("0.10"), (
        f"Expected $0.10 from two incremental contributions, got {total!r}"
    )
8 changes: 8 additions & 0 deletions tests/test_session_spawner.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ def child_get(name):

child_coordinator.get = child_get
child_coordinator.mount = AsyncMock()
child_coordinator.collect_contributions = AsyncMock(return_value=[])

async def mock_execute(instruction):
# Simulate orchestrator emitting orchestrator:complete during execute
Expand Down Expand Up @@ -769,6 +770,7 @@ def child_get(name):

child_coordinator.get = child_get
child_coordinator.mount = AsyncMock()
child_coordinator.collect_contributions = AsyncMock(return_value=[])

child_session = MagicMock()
child_session.coordinator = child_coordinator
Expand Down Expand Up @@ -1052,6 +1054,7 @@ def child_get(name):

child_coordinator.get = child_get
child_coordinator.mount = AsyncMock()
child_coordinator.collect_contributions = AsyncMock(return_value=[])

child_session = MagicMock()
child_session.coordinator = child_coordinator
Expand Down Expand Up @@ -1217,6 +1220,7 @@ def child_get(name):

child_coordinator.get = child_get
child_coordinator.mount = AsyncMock()
child_coordinator.collect_contributions = AsyncMock(return_value=[])

child_session = MagicMock()
child_session.coordinator = child_coordinator
Expand Down Expand Up @@ -1335,6 +1339,7 @@ def child_get(name):

child_coordinator.get = child_get
child_coordinator.mount = AsyncMock()
child_coordinator.collect_contributions = AsyncMock(return_value=[])

child_session = MagicMock()
child_session.coordinator = child_coordinator
Expand Down Expand Up @@ -1445,6 +1450,7 @@ def child_get(name):

child_coordinator.get = child_get
child_coordinator.mount = AsyncMock()
child_coordinator.collect_contributions = AsyncMock(return_value=[])

child_session = MagicMock()
child_session.coordinator = child_coordinator
Expand Down Expand Up @@ -1561,6 +1567,7 @@ def child_get(name):

child_coordinator.get = child_get
child_coordinator.mount = AsyncMock()
child_coordinator.collect_contributions = AsyncMock(return_value=[])

child_session = MagicMock()
child_session.coordinator = child_coordinator
Expand Down Expand Up @@ -1657,6 +1664,7 @@ def child_get(name):

child_coordinator.get = child_get
child_coordinator.mount = AsyncMock()
child_coordinator.collect_contributions = AsyncMock(return_value=[])

child_session = MagicMock()
child_session.coordinator = child_coordinator
Expand Down