diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 01dcc102f4..c108f7739d 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -44,7 +44,7 @@ from ._entities import create_agent_entity from ._errors import IncomingRequestError from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor -from ._serialization import deserialize_value, serialize_value +from ._serialization import deserialize_value, serialize_value, strip_pickle_markers from ._workflow import ( SOURCE_HITL_RESPONSE, SOURCE_ORCHESTRATOR, @@ -515,6 +515,10 @@ async def send_hitl_response(req: func.HttpRequest, client: df.DurableOrchestrat except ValueError: return self._build_error_response("Request body must be valid JSON.") + # Sanitize untrusted HTTP input before it reaches pickle.loads(). + # See strip_pickle_markers() docstring for details on the attack vector. + response_data = strip_pickle_markers(response_data) + # Send the response as an external event # The request_id is used as the event name for correlation await client.raise_event( diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index f48e55f5d5..4ed080eceb 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -22,9 +22,14 @@ import logging from contextlib import suppress from dataclasses import is_dataclass -from typing import Any - -from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from typing import Any, cast + +from agent_framework._workflows._checkpoint_encoding import ( + _PICKLE_MARKER, # pyright: ignore[reportPrivateUsage] + _TYPE_MARKER, # pyright: ignore[reportPrivateUsage] + decode_checkpoint_value, + encode_checkpoint_value, +) from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -48,6 +53,41 @@ def resolve_type(type_key: str) -> type | None: return None +# ============================================================================ +# Pickle marker sanitization (security) +# ============================================================================ + + +def strip_pickle_markers(data: Any) -> Any: + """Recursively strip pickle/type markers from untrusted data. + + The core checkpoint encoding uses ``__pickled__`` and ``__type__`` markers to + roundtrip arbitrary Python objects via *pickle*. If an attacker crafts an + HTTP payload that contains these markers, the data would flow into + ``pickle.loads()`` and enable **arbitrary code execution**. + + This function walks the incoming data structure and replaces any ``dict`` + that contains either marker key with ``None``, neutralising the attack + vector while leaving all other data untouched. + + It **must** be called on every value that originates from an untrusted + source (e.g. ``req.get_json()``) *before* the value is passed to + ``deserialize_value`` / ``decode_checkpoint_value``. + """ + if isinstance(data, dict): + if _PICKLE_MARKER in data or _TYPE_MARKER in data: + logger.debug("Stripped pickle/type markers from untrusted input.") + return None + typed_dict = cast(dict[str, Any], data) + return {k: strip_pickle_markers(v) for k, v in typed_dict.items()} + + if isinstance(data, list): + typed_list = cast(list[Any], data) # type: ignore[redundant-cast] + return [strip_pickle_markers(item) for item in typed_list] + + return data + + # ============================================================================ # Serialize / Deserialize # ============================================================================ @@ -117,7 +157,10 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any: if not isinstance(value, dict): return value - # Try decoding if data has pickle markers (from checkpoint encoding) + # Try decoding if data has pickle markers (from checkpoint encoding). + # NOTE: This function is general-purpose. Callers that handle untrusted + # data (e.g. HITL responses) MUST call strip_pickle_markers() before + # passing data here. See _deserialize_hitl_response in _workflow.py. decoded = deserialize_value(value) if not isinstance(decoded, dict): return decoded diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 60c04ad66c..a8774353ec 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -50,7 +50,7 @@ from ._context import CapturingRunnerContext from ._orchestration import AzureFunctionsAgentExecutor -from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value +from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value, strip_pickle_markers logger = logging.getLogger(__name__) @@ -961,6 +961,13 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None type(response_data).__name__, ) + if response_data is None: + return None + + # Sanitize untrusted external input before deserialization. + # HITL response data originates from an HTTP POST and must not contain + # pickle/type markers that would reach pickle.loads(). + response_data = strip_pickle_markers(response_data) if response_data is None: return None @@ -969,7 +976,7 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__) return response_data - # Try to deserialize using the type hint + # Try to reconstruct using the type hint (Pydantic / dataclass) if response_type_str: response_type = resolve_type(response_type_str) if response_type: @@ -979,6 +986,8 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None return result logger.warning("Could not resolve response type: %s", response_type_str) - # Fall back to generic deserialization - logger.debug("Falling back to generic deserialization") - return deserialize_value(response_data) + # No type hint available - return the sanitized dict as-is. + # We intentionally do NOT call deserialize_value() here because HITL + # response data is untrusted and must never flow into pickle.loads(). + logger.debug("No type hint; returning sanitized data as-is") + return response_data # type: ignore[reportUnknownVariableType] diff --git a/python/packages/azurefunctions/tests/test_func_utils.py b/python/packages/azurefunctions/tests/test_func_utils.py index 240e2f0a2c..63f0af0182 100644 --- a/python/packages/azurefunctions/tests/test_func_utils.py +++ b/python/packages/azurefunctions/tests/test_func_utils.py @@ -21,6 +21,7 @@ deserialize_value, reconstruct_to_type, serialize_value, + strip_pickle_markers, ) @@ -353,7 +354,11 @@ class Feedback: assert result.comment == "Great" def test_reconstruct_from_checkpoint_markers(self) -> None: - """Test that data with checkpoint markers is decoded via deserialize_value.""" + """Test that data with checkpoint markers is decoded via deserialize_value. + + reconstruct_to_type is general-purpose and handles trusted checkpoint + data. Untrusted HITL callers must call strip_pickle_markers() first. + """ original = SampleData(value=99, name="marker-test") encoded = serialize_value(original) @@ -372,3 +377,73 @@ class Unrelated: result = reconstruct_to_type(data, Unrelated) assert result == data + + def test_reconstruct_strips_injected_pickle_markers(self) -> None: + """End-to-end: strip_pickle_markers + reconstruct_to_type blocks attack. + + This mirrors the real HITL flow where callers sanitize before reconstruction. + """ + malicious = {"__pickled__": "gASVDgAAAAAAAACMBHRlc3SULg==", "__type__": "builtins:str"} + sanitized = strip_pickle_markers(malicious) + result = reconstruct_to_type(sanitized, str) + assert result is None + + +class TestStripPickleMarkers: + """Security tests for strip_pickle_markers — the defence-in-depth layer + that prevents untrusted HTTP input from reaching pickle.loads().""" + + def test_strips_top_level_pickle_marker(self) -> None: + """A dict containing __pickled__ must be replaced with None.""" + data = {"__pickled__": "PAYLOAD", "__type__": "os:system"} + assert strip_pickle_markers(data) is None + + def test_strips_top_level_type_marker_only(self) -> None: + """Even __type__ alone (without __pickled__) must be neutralised.""" + data = {"__type__": "os:system", "other": "value"} + assert strip_pickle_markers(data) is None + + def test_strips_nested_pickle_marker(self) -> None: + """Pickle markers nested inside a dict must be neutralised.""" + data = {"safe": "value", "nested": {"__pickled__": "PAYLOAD", "__type__": "os:system"}} + result = strip_pickle_markers(data) + assert result == {"safe": "value", "nested": None} + + def test_strips_pickle_marker_in_list(self) -> None: + """Pickle markers inside a list element must be neutralised.""" + data = [{"__pickled__": "PAYLOAD"}, "safe"] + result = strip_pickle_markers(data) + assert result == [None, "safe"] + + def test_strips_deeply_nested_marker(self) -> None: + """Deeply nested pickle markers must be neutralised.""" + data = {"a": {"b": {"c": {"__pickled__": "deep"}}}} + result = strip_pickle_markers(data) + assert result == {"a": {"b": {"c": None}}} + + def test_preserves_safe_dict(self) -> None: + """Dicts without pickle markers must be left untouched.""" + data = {"approved": True, "reason": "Looks good"} + assert strip_pickle_markers(data) == data + + def test_preserves_primitives(self) -> None: + """Primitive values must pass through unchanged.""" + assert strip_pickle_markers("hello") == "hello" + assert strip_pickle_markers(42) == 42 + assert strip_pickle_markers(None) is None + assert strip_pickle_markers(True) is True + + def test_preserves_safe_list(self) -> None: + """Lists without pickle markers must be left untouched.""" + data = [1, "two", {"key": "value"}] + assert strip_pickle_markers(data) == data + + def test_mixed_safe_and_malicious(self) -> None: + """Only the malicious entries should be stripped; safe entries remain.""" + data = { + "user_input": "hello", + "evil": {"__pickled__": "PAYLOAD", "__type__": "os:system"}, + "count": 42, + } + result = strip_pickle_markers(data) + assert result == {"user_input": "hello", "evil": None, "count": 42}