Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from ._entities import create_agent_entity
from ._errors import IncomingRequestError
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor
from ._serialization import deserialize_value, serialize_value
from ._serialization import deserialize_value, serialize_value, strip_pickle_markers
from ._workflow import (
SOURCE_HITL_RESPONSE,
SOURCE_ORCHESTRATOR,
Expand Down Expand Up @@ -515,6 +515,10 @@ async def send_hitl_response(req: func.HttpRequest, client: df.DurableOrchestrat
except ValueError:
return self._build_error_response("Request body must be valid JSON.")

# Sanitize untrusted HTTP input before it reaches pickle.loads().
# See strip_pickle_markers() docstring for details on the attack vector.
response_data = strip_pickle_markers(response_data)

# Send the response as an external event
# The request_id is used as the event name for correlation
await client.raise_event(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@
import logging
from contextlib import suppress
from dataclasses import is_dataclass
from typing import Any

from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from typing import Any, cast

from agent_framework._workflows._checkpoint_encoding import (
_PICKLE_MARKER, # pyright: ignore[reportPrivateUsage]
_TYPE_MARKER, # pyright: ignore[reportPrivateUsage]
decode_checkpoint_value,
encode_checkpoint_value,
)
from pydantic import BaseModel

logger = logging.getLogger(__name__)
Expand All @@ -48,6 +53,41 @@ def resolve_type(type_key: str) -> type | None:
return None


# ============================================================================
# Pickle marker sanitization (security)
# ============================================================================


def strip_pickle_markers(data: Any) -> Any:
"""Recursively strip pickle/type markers from untrusted data.
The core checkpoint encoding uses ``__pickled__`` and ``__type__`` markers to
roundtrip arbitrary Python objects via *pickle*. If an attacker crafts an
HTTP payload that contains these markers, the data would flow into
``pickle.loads()`` and enable **arbitrary code execution**.
This function walks the incoming data structure and replaces any ``dict``
that contains either marker key with ``None``, neutralising the attack
vector while leaving all other data untouched.
It **must** be called on every value that originates from an untrusted
source (e.g. ``req.get_json()``) *before* the value is passed to
``deserialize_value`` / ``decode_checkpoint_value``.
"""
if isinstance(data, dict):
if _PICKLE_MARKER in data or _TYPE_MARKER in data:
logger.debug("Stripped pickle/type markers from untrusted input.")
return None
typed_dict = cast(dict[str, Any], data)
return {k: strip_pickle_markers(v) for k, v in typed_dict.items()}

if isinstance(data, list):
typed_list = cast(list[Any], data) # type: ignore[redundant-cast]
return [strip_pickle_markers(item) for item in typed_list]

return data


# ============================================================================
# Serialize / Deserialize
# ============================================================================
Expand Down Expand Up @@ -117,7 +157,10 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any:
if not isinstance(value, dict):
return value

# Try decoding if data has pickle markers (from checkpoint encoding)
# Try decoding if data has pickle markers (from checkpoint encoding).
# NOTE: This function is general-purpose. Callers that handle untrusted
# data (e.g. HITL responses) MUST call strip_pickle_markers() before
# passing data here. See _deserialize_hitl_response in _workflow.py.
decoded = deserialize_value(value)
if not isinstance(decoded, dict):
return decoded
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

from ._context import CapturingRunnerContext
from ._orchestration import AzureFunctionsAgentExecutor
from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value
from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value, strip_pickle_markers

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -961,6 +961,13 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
type(response_data).__name__,
)

if response_data is None:
return None

# Sanitize untrusted external input before deserialization.
# HITL response data originates from an HTTP POST and must not contain
# pickle/type markers that would reach pickle.loads().
response_data = strip_pickle_markers(response_data)
if response_data is None:
return None

Expand All @@ -969,7 +976,7 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__)
return response_data

# Try to deserialize using the type hint
# Try to reconstruct using the type hint (Pydantic / dataclass)
if response_type_str:
response_type = resolve_type(response_type_str)
if response_type:
Expand All @@ -979,6 +986,8 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
return result
logger.warning("Could not resolve response type: %s", response_type_str)

# Fall back to generic deserialization
logger.debug("Falling back to generic deserialization")
return deserialize_value(response_data)
# No type hint available - return the sanitized dict as-is.
# We intentionally do NOT call deserialize_value() here because HITL
# response data is untrusted and must never flow into pickle.loads().
logger.debug("No type hint; returning sanitized data as-is")
return response_data # type: ignore[reportUnknownVariableType]
77 changes: 76 additions & 1 deletion python/packages/azurefunctions/tests/test_func_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
deserialize_value,
reconstruct_to_type,
serialize_value,
strip_pickle_markers,
)


Expand Down Expand Up @@ -353,7 +354,11 @@ class Feedback:
assert result.comment == "Great"

def test_reconstruct_from_checkpoint_markers(self) -> None:
"""Test that data with checkpoint markers is decoded via deserialize_value."""
"""Test that data with checkpoint markers is decoded via deserialize_value.

reconstruct_to_type is general-purpose and handles trusted checkpoint
data. Untrusted HITL callers must call strip_pickle_markers() first.
"""
original = SampleData(value=99, name="marker-test")
encoded = serialize_value(original)

Expand All @@ -372,3 +377,73 @@ class Unrelated:
result = reconstruct_to_type(data, Unrelated)

assert result == data

def test_reconstruct_strips_injected_pickle_markers(self) -> None:
"""End-to-end: strip_pickle_markers + reconstruct_to_type blocks attack.

This mirrors the real HITL flow where callers sanitize before reconstruction.
"""
malicious = {"__pickled__": "gASVDgAAAAAAAACMBHRlc3SULg==", "__type__": "builtins:str"}
sanitized = strip_pickle_markers(malicious)
result = reconstruct_to_type(sanitized, str)
assert result is None


class TestStripPickleMarkers:
"""Security tests for strip_pickle_markers — the defence-in-depth layer
that prevents untrusted HTTP input from reaching pickle.loads()."""

def test_strips_top_level_pickle_marker(self) -> None:
"""A dict containing __pickled__ must be replaced with None."""
data = {"__pickled__": "PAYLOAD", "__type__": "os:system"}
assert strip_pickle_markers(data) is None

def test_strips_top_level_type_marker_only(self) -> None:
"""Even __type__ alone (without __pickled__) must be neutralised."""
data = {"__type__": "os:system", "other": "value"}
assert strip_pickle_markers(data) is None

def test_strips_nested_pickle_marker(self) -> None:
"""Pickle markers nested inside a dict must be neutralised."""
data = {"safe": "value", "nested": {"__pickled__": "PAYLOAD", "__type__": "os:system"}}
result = strip_pickle_markers(data)
assert result == {"safe": "value", "nested": None}

def test_strips_pickle_marker_in_list(self) -> None:
"""Pickle markers inside a list element must be neutralised."""
data = [{"__pickled__": "PAYLOAD"}, "safe"]
result = strip_pickle_markers(data)
assert result == [None, "safe"]

def test_strips_deeply_nested_marker(self) -> None:
"""Deeply nested pickle markers must be neutralised."""
data = {"a": {"b": {"c": {"__pickled__": "deep"}}}}
result = strip_pickle_markers(data)
assert result == {"a": {"b": {"c": None}}}

def test_preserves_safe_dict(self) -> None:
"""Dicts without pickle markers must be left untouched."""
data = {"approved": True, "reason": "Looks good"}
assert strip_pickle_markers(data) == data

def test_preserves_primitives(self) -> None:
"""Primitive values must pass through unchanged."""
assert strip_pickle_markers("hello") == "hello"
assert strip_pickle_markers(42) == 42
assert strip_pickle_markers(None) is None
assert strip_pickle_markers(True) is True

def test_preserves_safe_list(self) -> None:
"""Lists without pickle markers must be left untouched."""
data = [1, "two", {"key": "value"}]
assert strip_pickle_markers(data) == data

def test_mixed_safe_and_malicious(self) -> None:
"""Only the malicious entries should be stripped; safe entries remain."""
data = {
"user_input": "hello",
"evil": {"__pickled__": "PAYLOAD", "__type__": "os:system"},
"count": 42,
}
result = strip_pickle_markers(data)
assert result == {"user_input": "hello", "evil": None, "count": 42}