diff --git a/src/agents/extensions/memory/dapr_session.py b/src/agents/extensions/memory/dapr_session.py index 8d92872406..a425bf866d 100644 --- a/src/agents/extensions/memory/dapr_session.py +++ b/src/agents/extensions/memory/dapr_session.py @@ -323,19 +323,49 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: continue raise - # Update metadata - metadata = { - "session_id": self.session_id, - "created_at": str(int(time.time())), - "updated_at": str(int(time.time())), - } - await self._dapr_client.save_state( - store_name=self._state_store_name, - key=self._metadata_key, - value=json.dumps(metadata), - state_metadata=self._get_metadata(), - options=self._get_state_options(), - ) + # Update metadata, preserving created_at across subsequent writes. + # Use first-write concurrency with the read ETag so a concurrent write + # that already established `created_at` can't be clobbered by a stale + # read that saw no metadata. + now = str(int(time.time())) + meta_attempt = 0 + while True: + meta_attempt += 1 + existing_meta_response = await self._dapr_client.get_state( + store_name=self._state_store_name, + key=self._metadata_key, + state_metadata=self._get_read_metadata(), + ) + created_at = now + if existing_meta_response.data: + try: + existing_meta = json.loads(existing_meta_response.data.decode("utf-8")) + if isinstance(existing_meta, dict) and existing_meta.get("created_at"): + created_at = str(existing_meta["created_at"]) + except (json.JSONDecodeError, UnicodeDecodeError, AttributeError): + # Corrupt metadata — start fresh with current timestamp. + pass + metadata = { + "session_id": self.session_id, + "created_at": created_at, + "updated_at": now, + } + meta_etag = getattr(existing_meta_response, "etag", None) or None + try: + await self._dapr_client.save_state( + store_name=self._state_store_name, + key=self._metadata_key, + value=json.dumps(metadata), + etag=meta_etag, + state_metadata=self._get_metadata(), + options=self._get_state_options(concurrency=Concurrency.first_write), + ) + break + except Exception as error: + should_retry = await self._handle_concurrency_conflict(error, meta_attempt) + if should_retry: + continue + raise async def pop_item(self) -> TResponseInputItem | None: """Remove and return the most recent item from the session. diff --git a/tests/extensions/memory/test_dapr_session.py b/tests/extensions/memory/test_dapr_session.py index 9766f35d40..f2dff71561 100644 --- a/tests/extensions/memory/test_dapr_session.py +++ b/tests/extensions/memory/test_dapr_session.py @@ -448,6 +448,32 @@ async def test_add_empty_items_list(fake_dapr_client: FakeDaprClient): await session.close() +async def test_metadata_preserves_created_at(fake_dapr_client: FakeDaprClient): + """add_items must preserve created_at across writes; only updated_at advances.""" + session = await _create_test_session(fake_dapr_client) + try: + await session.add_items([{"role": "user", "content": "first"}]) + first_meta_raw = fake_dapr_client._state[session._metadata_key].decode("utf-8") + first_meta = json.loads(first_meta_raw) + first_created = first_meta["created_at"] + first_updated = first_meta["updated_at"] + + # Wait one second so timestamps are guaranteed to differ. + import time as _time + + _time.sleep(1) + + await session.add_items([{"role": "user", "content": "second"}]) + second_meta = json.loads(fake_dapr_client._state[session._metadata_key].decode("utf-8")) + + assert second_meta["created_at"] == first_created, ( + "created_at must be preserved across add_items calls" + ) + assert int(second_meta["updated_at"]) >= int(first_updated) + finally: + await session.close() + + async def test_unicode_content(fake_dapr_client: FakeDaprClient): """Test that session correctly stores and retrieves unicode/non-ASCII content.""" session = await _create_test_session(fake_dapr_client)