Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
When ``allowed_types`` is supplied to :func:`decode_checkpoint_value`, a
``RestrictedUnpickler`` is used that limits which classes may be instantiated
during deserialization. The default built-in safe set covers common Python
value types (primitives, datetime, uuid, ...), all ``agent_framework`` internal
types, and all ``openai.types`` types. Callers can extend the set by passing
additional ``"module:qualname"`` strings.
value types (primitives, datetime, uuid, ...) only; ``agent_framework`` and
``openai.types`` classes are not implicitly allowed. Callers MUST extend the
set by passing additional ``"module:qualname"`` strings for any framework or
custom types they wish to deserialize.
"""

from __future__ import annotations
Expand All @@ -34,12 +34,6 @@
# Types that are natively JSON-serializable and don't need pickling
_JSON_NATIVE_TYPES = (str, int, float, bool, type(None))

# Module prefix for framework-internal types that are always allowed
_FRAMEWORK_MODULE_PREFIX = "agent_framework."

# Module prefix for OpenAI SDK types that are always allowed
_OPENAI_MODULE_PREFIX = "openai.types."

# Built-in types considered safe for checkpoint deserialization.
# Each entry is a ``module:qualname`` string matching the format produced by
# :func:`_type_to_key`. These are the classes for which pickle's
Expand Down Expand Up @@ -87,9 +81,8 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
"""Unpickler that restricts which classes may be instantiated.

Only classes whose ``module:qualname`` key appears in the combined allow
set (built-in safe types + framework types + OpenAI SDK types +
caller-specified extras) are permitted. All other classes raise
:class:`pickle.UnpicklingError`.
set (built-in safe types + caller-specified extras) are permitted. All
other classes raise :class:`pickle.UnpicklingError`.
"""

def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
Expand All @@ -99,11 +92,11 @@ def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
def find_class(self, module: str, name: str) -> type:
type_key = f"{module}:{name}"

# SECURITY: Only exact ``module:qualname`` matches against the explicit allow
# sets are permitted. Over-broad ``module.startswith()`` prefix checks were
# removed because they let any class under a matching package be instantiated.
if (
type_key in _BUILTIN_ALLOWED_TYPE_KEYS
or type_key in self._allowed_types
or module.startswith(_FRAMEWORK_MODULE_PREFIX)
or module.startswith(_OPENAI_MODULE_PREFIX)
):
return super().find_class(module, name) # type: ignore[no-any-return] # nosec

Expand Down Expand Up @@ -142,7 +135,7 @@ def decode_checkpoint_value(value: Any, *, allowed_types: frozenset[str] | None
Args:
value: A JSON-deserialized value from checkpoint storage.
allowed_types: If not ``None``, restrict pickle deserialization to the
built-in safe set, framework types, and the types listed here.
built-in safe set and the types listed here.
Each entry should use ``"module:qualname"`` format — that is, the
dotted module path followed by a colon and the class
``__qualname__``. For example, given a user-defined class::
Expand Down Expand Up @@ -261,8 +254,7 @@ def _base64_to_unpickle(encoded: str, *, allowed_types: frozenset[str] | None =
Args:
encoded: Base64-encoded pickle data.
allowed_types: If not ``None``, use restricted unpickling that only
permits built-in safe types, framework types, and the specified
extra types.
permits built-in safe types and the specified extra types.

Raises:
WorkflowCheckpointException: If the base64 data is corrupted, the pickle
Expand Down