Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/secure-yaml-and-breakdowns.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use safe YAML loaders and remove dynamic eval from parameter breakdown handling.
30 changes: 19 additions & 11 deletions policyengine_core/parameters/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from policyengine_core.warnings import LibYAMLWarning

try:
from yaml import CLoader as Loader
from yaml import CSafeLoader as Loader
except ImportError:
message = [
"libyaml is not installed in your environment.",
Expand All @@ -17,7 +17,7 @@
]
warnings.warn(" ".join(message), LibYAMLWarning)
from yaml import (
Loader,
SafeLoader as Loader,
) # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270)

ALLOWED_PARAM_TYPES = (float, int, bool, type(None), typing.List)
Expand All @@ -33,15 +33,23 @@ def date_constructor(_loader, node):


def dict_no_duplicate_constructor(loader, node, deep=False):
    """Build a ``dict`` from a YAML mapping node, rejecting duplicate keys.

    Keys are compared after construction rather than by their raw scalar
    text, so spellings that load to the same Python value (for example
    ``true`` and ``True``) are reported as duplicates.

    Args:
        loader: The PyYAML loader/constructor handling ``node``.
        node: The mapping node being constructed.
        deep: Forwarded to ``loader.construct_pairs``.

    Returns:
        A ``dict`` holding the constructed key/value pairs.

    Raises:
        yaml.parser.ParserError: If two keys construct to the same value.
        yaml.constructor.ConstructorError: If a key cannot be hashed.
    """
    # NOTE(review): flatten_mapping is PyYAML's hook for expanding ``<<``
    # merge keys; confirm an overriding key next to a merge is still wanted
    # to count as a duplicate.
    loader.flatten_mapping(node)
    mapping = {}

    for key, value in loader.construct_pairs(node, deep=deep):
        # Membership testing hashes the key, so an unhashable key (such as a
        # sequence used as a mapping key) surfaces here as TypeError.
        try:
            is_duplicate = key in mapping
        except TypeError as exc:
            raise yaml.constructor.ConstructorError(
                "", node.start_mark, f"Found unhashable key '{key}'"
            ) from exc
        if is_duplicate:
            raise yaml.parser.ParserError(
                "", node.start_mark, f"Found duplicate key '{key}'"
            )
        mapping[key] = value

    return mapping


yaml.add_constructor(
Expand Down
119 changes: 107 additions & 12 deletions policyengine_core/parameters/operations/homogenize_parameters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import logging
from typing import Any, Dict, List, Type

Expand All @@ -6,6 +7,8 @@
from policyengine_core.parameters.parameter_node import ParameterNode
from policyengine_core.variables import Variable

MAX_DYNAMIC_BREAKDOWN_VALUES = 10_000


def homogenize_parameter_structures(
root: ParameterNode, variables: Dict[str, Variable], default_value: Any = 0
Expand Down Expand Up @@ -43,6 +46,11 @@ def get_breakdown_variables(node: ParameterNode) -> List[str]:
f"Invalid breakdown metadata for parameter {node.name}: {type(breakdown)}"
)
return None
if len(breakdown) == 0:
logging.warning(
f"Invalid breakdown metadata for parameter {node.name}: empty list"
)
return None
return breakdown
else:
return None
Expand Down Expand Up @@ -71,8 +79,7 @@ def homogenize_parameter_node(
elif dtype == bool:
possible_values = [True, False]
else:
# Try to execute the breakdown as Python code
possible_values = list(eval(first_breakdown))
possible_values = evaluate_dynamic_breakdown(first_breakdown)
if not hasattr(node, "children"):
node = ParameterNode(
node.name,
Expand All @@ -96,19 +103,107 @@ def homogenize_parameter_node(
{"0000-01-01": default_value, "2040-01-01": default_value},
),
)
possible_values_str = {str(v) for v in possible_values}
extra_children = []
for child in node.children:
child_key = child.split(".")[-1]
if (
child_key not in possible_values_str
and str(child_key) not in possible_values_str
):
extra_children.append(child_key)
if extra_children:
raise ValueError(
f"Parameter {node.name} has children {extra_children} "
f"that are not in the possible values of the breakdown "
f"variable '{first_breakdown}'. Check that the breakdown "
f"metadata references the correct variable and that all "
f"parameter keys are valid enum values."
)
for child in node.children:
if child.split(".")[-1] not in possible_values:
try:
int(child)
is_int = True
except:
is_int = False
if not is_int or str(child) not in node.children:
logging.warning(
f"Parameter {node.name} has a child {child} that is not in the possible values of {first_breakdown}, ignoring."
)
if further_breakdown:
node.children[child] = homogenize_parameter_node(
node.children[child], breakdown[1:], variables, default_value
)
return node


def evaluate_dynamic_breakdown(expression: str) -> List[Any]:
    """Safely evaluate a dynamic breakdown expression.

    The parameter metadata only needs literal collections and the documented
    ``range(...)`` / ``list(range(...))`` forms. Anything else is rejected.

    Args:
        expression: Source text of the breakdown expression.

    Returns:
        The values produced by the expression, as a list.

    Raises:
        ValueError: If the expression cannot be parsed, uses an unsupported
            construct, does not evaluate to a collection or range, or
            produces more values than the configured maximum.
    """
    invalid_message = (
        f"Invalid dynamic breakdown expression '{expression}'. "
        "Only literal collections and range() calls are allowed."
    )

    # Malformed, non-text or excessively nested input must be reported the
    # same way as any other rejected expression, not leak parser errors.
    try:
        parsed = ast.parse(expression, mode="eval")
    except (SyntaxError, ValueError, TypeError, RecursionError, MemoryError) as exc:
        raise ValueError(invalid_message) from exc

    # ValueError raised by the evaluator carries a more specific message and
    # is deliberately left to propagate unchanged.
    try:
        evaluated = evaluate_dynamic_breakdown_node(parsed.body)
    except (TypeError, RecursionError) as exc:
        raise ValueError(invalid_message) from exc

    if isinstance(evaluated, range):
        # Check the size before materialising so a huge range is never built.
        validate_dynamic_breakdown_range_cardinality(evaluated, expression)
        return list(evaluated)
    if isinstance(evaluated, (list, tuple, set)):
        validate_dynamic_breakdown_cardinality(len(evaluated), expression)
        return list(evaluated)
    raise ValueError(invalid_message)


def validate_dynamic_breakdown_cardinality(count: int, expression: str) -> None:
    """Reject a breakdown expression that yields too many values.

    Args:
        count: Number of values the expression produces.
        expression: Source text of the expression, used in the error message.

    Raises:
        ValueError: If ``count`` is greater than
            ``MAX_DYNAMIC_BREAKDOWN_VALUES``.
    """
    limit = MAX_DYNAMIC_BREAKDOWN_VALUES
    if count <= limit:
        return
    raise ValueError(
        f"Dynamic breakdown expression '{expression}' produces {count} values, "
        f"which exceeds the maximum of {limit}."
    )


def validate_dynamic_breakdown_range_cardinality(
    values: range, expression: str
) -> None:
    """Check the size of a ``range`` without materialising it.

    Args:
        values: The range produced by the breakdown expression.
        expression: Source text of the expression, used in error messages.

    Raises:
        ValueError: If ``len(values)`` overflows, or if the length is
            rejected by ``validate_dynamic_breakdown_cardinality``.
    """
    try:
        size = len(values)
    except OverflowError as exc:
        # len() cannot represent the size at all, so it is certainly too big.
        raise ValueError(
            f"Dynamic breakdown expression '{expression}' produces too many values."
        ) from exc
    else:
        validate_dynamic_breakdown_cardinality(size, expression)


def _describe_breakdown_node(node: ast.AST) -> str:
    """Return text identifying ``node`` for use in error messages."""
    # ast.unparse is missing on older interpreters; fall back to the node type.
    if hasattr(ast, "unparse"):
        return ast.unparse(node)
    return type(node).__name__


def evaluate_dynamic_breakdown_node(node: ast.AST) -> Any:
    """Evaluate one node of a parsed breakdown expression without ``eval``.

    Supported nodes are constants, list/tuple/set literals, unary ``+``/``-``
    applied to a number, ``range(...)`` with positional arguments, and
    ``list(...)`` wrapping a range or a literal collection.

    Args:
        node: The AST node to evaluate.

    Returns:
        The Python value the node denotes. ``list(...)`` returns its evaluated
        argument unchanged, so the result may be a ``range``, ``list``,
        ``tuple`` or ``set``.

    Raises:
        ValueError: If the node is not one of the supported forms, if its
            arguments are invalid, or if a collection or range exceeds the
            cardinality limit.
    """
    if isinstance(node, ast.Constant):
        return node.value

    if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        # Check the literal's length before evaluating any of its elements.
        validate_dynamic_breakdown_cardinality(
            len(node.elts), _describe_breakdown_node(node)
        )
        elements = [evaluate_dynamic_breakdown_node(element) for element in node.elts]
        if isinstance(node, ast.List):
            return elements
        if isinstance(node, ast.Tuple):
            return tuple(elements)
        try:
            return set(elements)
        except TypeError as exc:
            # A set literal holding an unhashable element, e.g. ``{[1]}``.
            raise ValueError(
                "Unsupported dynamic breakdown expression: "
                f"{_describe_breakdown_node(node)}"
            ) from exc

    if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
        operand = evaluate_dynamic_breakdown_node(node.operand)
        # A sign only makes sense on a number; reject ``-'a'`` or ``+[1]``.
        if not isinstance(operand, (int, float, complex)):
            raise ValueError(
                "Unsupported dynamic breakdown expression: "
                f"{_describe_breakdown_node(node)}"
            )
        return operand if isinstance(node.op, ast.UAdd) else -operand

    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
        if node.func.id == "range":
            args = [evaluate_dynamic_breakdown_node(arg) for arg in node.args]
            if node.keywords:
                raise ValueError("range() keyword arguments are not allowed")
            try:
                result = range(*args)
            except TypeError as exc:
                # Wrong number of arguments or non-integer arguments.
                raise ValueError(
                    "Unsupported dynamic breakdown expression: "
                    f"{_describe_breakdown_node(node)}"
                ) from exc
            validate_dynamic_breakdown_range_cardinality(
                result, _describe_breakdown_node(node)
            )
            return result
        if node.func.id == "list":
            if len(node.args) != 1 or node.keywords:
                raise ValueError("list() must contain a single positional argument")
            evaluated = evaluate_dynamic_breakdown_node(node.args[0])
            if isinstance(evaluated, (range, list, tuple, set)):
                return evaluated
            raise ValueError(
                "list() only supports range() and literal collection expressions"
            )

    raise ValueError(
        f"Unsupported dynamic breakdown expression: {_describe_breakdown_node(node)}"
    )
19 changes: 16 additions & 3 deletions policyengine_core/tools/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def import_yaml():
import yaml

try:
from yaml import CLoader as Loader
from yaml import CSafeLoader as Loader
except ImportError:
log.warning(
" "
Expand Down Expand Up @@ -119,7 +119,7 @@ def __init__(self, *, tax_benefit_system, options, **kwargs):
def collect(self):
try:
tests = yaml.load(self.path.open(), Loader=Loader)
except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError):
except (yaml.YAMLError, TypeError):
message = os.linesep.join(
[
traceback.format_exc(),
Expand All @@ -132,6 +132,11 @@ def collect(self):
tests: List[Dict] = [tests]

for test in tests:
if not isinstance(test, dict):
raise ValueError(
f"'{self.path}' is not a valid YAML test file. "
"Expected a mapping or a list of mappings."
)
if not self.should_ignore(test):
yield YamlItem.from_parent(
self,
Expand All @@ -143,11 +148,19 @@ def collect(self):

def should_ignore(self, test):
    """Tell whether ``test`` is excluded by the ``name_filter`` option.

    A test is ignored only when a name filter is set and the filter matches
    none of: the test file's base name without extension, the test's
    ``name``, or any entry of the test's ``keywords`` list.

    Args:
        test: Mapping describing one YAML test.

    Returns:
        ``True`` if the test should be skipped, ``False`` otherwise.

    Raises:
        ValueError: If ``keywords`` is present and is neither ``None`` nor a
            list.
    """
    name_filter = self.options.get("name_filter")

    # ``keywords:`` left empty in YAML loads as None; treat it as no keywords.
    keywords = test.get("keywords", [])
    if keywords is None:
        keywords = []
    if not isinstance(keywords, list):
        raise ValueError(
            f"'{self.path}' is not a valid YAML test file. "
            "'keywords' must be a list."
        )

    if name_filter is None:
        return False

    file_stem = os.path.splitext(self.fspath.basename)[0]
    # An empty ``name:`` also loads as None, which cannot be searched.
    test_name = test.get("name") or ""
    return (
        name_filter not in file_stem
        and name_filter not in test_name
        and name_filter not in keywords
    )


Expand Down
111 changes: 111 additions & 0 deletions tests/core/test_parameter_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import pytest

from policyengine_core.errors import ParameterParsingError
from policyengine_core.parameters import ParameterNode, homogenize_parameter_structures
from policyengine_core.parameters.helpers import _load_yaml_file
from policyengine_core.parameters.operations.homogenize_parameters import (
MAX_DYNAMIC_BREAKDOWN_VALUES,
)


def test_parameter_yaml_loader_rejects_python_object_tags(tmp_path, monkeypatch):
    """A ``!!python/object/apply`` tag is refused and never executed."""
    recorded_commands = []

    def fake_system(command):
        # Record the command instead of running it; 0 stands in for success.
        recorded_commands.append(command)
        return 0

    monkeypatch.setattr("os.system", fake_system)

    payload_file = tmp_path / "malicious.yaml"
    payload_file.write_text(
        '!!python/object/apply:os.system ["echo pwned"]\n',
        encoding="utf-8",
    )

    with pytest.raises(ParameterParsingError):
        _load_yaml_file(str(payload_file))

    # Loading must fail without ever reaching the patched os.system.
    assert recorded_commands == []


def test_homogenize_parameter_structures_rejects_dynamic_breakdown_code(
    monkeypatch,
):
    """Arbitrary code in breakdown metadata is rejected without ``eval``."""
    seen_expressions = []

    def fake_eval(expression, globals=None, locals=None):
        # Record any attempt to evaluate, then return a harmless range.
        seen_expressions.append(expression)
        return range(1, 4)

    monkeypatch.setattr("builtins.eval", fake_eval)

    malicious_breakdown = ['__import__("os").system("echo pwned")']
    tree = {
        "value_by_category": {
            "metadata": {"breakdown": malicious_breakdown},
        }
    }
    root = ParameterNode(data=tree)

    with pytest.raises(ValueError, match="breakdown"):
        homogenize_parameter_structures(root, {}, default_value=0)

    # The patched eval must never have been reached.
    assert seen_expressions == []


def test_homogenize_parameter_structures_rejects_oversized_dynamic_breakdown():
    """A range one element over the limit is rejected."""
    oversized_expression = f"list(range({MAX_DYNAMIC_BREAKDOWN_VALUES + 1}))"
    tree = {
        "value_by_category": {
            "metadata": {"breakdown": [oversized_expression]},
        }
    }
    root = ParameterNode(data=tree)

    with pytest.raises(ValueError, match="exceeds the maximum"):
        homogenize_parameter_structures(root, {}, default_value=0)


def test_homogenize_parameter_structures_rejects_overflowing_dynamic_breakdown():
    """A range whose length overflows ``len()`` is rejected."""
    # A one followed by a hundred zeros.
    huge_stop = str(10**100)
    tree = {
        "value_by_category": {
            "metadata": {"breakdown": [f"range(0, {huge_stop})"]},
        }
    }
    root = ParameterNode(data=tree)

    with pytest.raises(ValueError, match="too many values"):
        homogenize_parameter_structures(root, {}, default_value=0)


def test_parameter_yaml_loader_rejects_implicit_duplicate_keys(tmp_path):
    """``true`` and ``True`` as keys of one mapping are reported as duplicates."""
    duplicate_file = tmp_path / "duplicate-bools.yaml"
    duplicate_file.write_text("true: 1\nTrue: 2\n", encoding="utf-8")

    with pytest.raises(ParameterParsingError, match="duplicate key"):
        _load_yaml_file(str(duplicate_file))


def test_homogenize_parameter_structures_ignores_empty_breakdown_lists():
    """An empty ``breakdown`` list does not raise and the root is returned."""
    tree = {
        "value_by_category": {
            "metadata": {"breakdown": []},
        }
    }
    root = ParameterNode(data=tree)

    homogenized = homogenize_parameter_structures(root, {}, default_value=0)

    assert homogenized is root
Loading
Loading