From 655784cdefe087593ab2dc271620e1c8361b27d8 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sun, 12 Apr 2026 12:27:54 -0400 Subject: [PATCH 1/2] Use safe YAML loaders for parameters --- .../secure-yaml-and-breakdowns.fixed.md | 1 + policyengine_core/parameters/config.py | 4 +- .../operations/homogenize_parameters.py | 81 ++++++++++++++++--- policyengine_core/tools/test_runner.py | 9 ++- tests/core/test_parameter_security.py | 52 ++++++++++++ .../tools/test_runner/test_yaml_runner.py | 24 ++++++ 6 files changed, 155 insertions(+), 16 deletions(-) create mode 100644 changelog.d/secure-yaml-and-breakdowns.fixed.md create mode 100644 tests/core/test_parameter_security.py diff --git a/changelog.d/secure-yaml-and-breakdowns.fixed.md b/changelog.d/secure-yaml-and-breakdowns.fixed.md new file mode 100644 index 000000000..586241d33 --- /dev/null +++ b/changelog.d/secure-yaml-and-breakdowns.fixed.md @@ -0,0 +1 @@ +Use safe YAML loaders and remove dynamic eval from parameter breakdown handling. diff --git a/policyengine_core/parameters/config.py b/policyengine_core/parameters/config.py index 73982d8a3..d2dbee20a 100644 --- a/policyengine_core/parameters/config.py +++ b/policyengine_core/parameters/config.py @@ -7,7 +7,7 @@ from policyengine_core.warnings import LibYAMLWarning try: - from yaml import CLoader as Loader + from yaml import CSafeLoader as Loader except ImportError: message = [ "libyaml is not installed in your environment.", @@ -17,7 +17,7 @@ ] warnings.warn(" ".join(message), LibYAMLWarning) from yaml import ( - Loader, + SafeLoader as Loader, ) # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270) ALLOWED_PARAM_TYPES = (float, int, bool, type(None), typing.List) diff --git a/policyengine_core/parameters/operations/homogenize_parameters.py b/policyengine_core/parameters/operations/homogenize_parameters.py index 6d2e830a7..cd310a9e5 100644 --- a/policyengine_core/parameters/operations/homogenize_parameters.py +++ b/policyengine_core/parameters/operations/homogenize_parameters.py @@ -1,3 +1,4 @@ +import ast import logging from typing import Any, Dict, List, Type @@ -71,8 +72,7 @@ def homogenize_parameter_node( elif dtype == bool: possible_values = [True, False] else: - # Try to execute the breakdown as Python code - possible_values = list(eval(first_breakdown)) + possible_values = evaluate_dynamic_breakdown(first_breakdown) if not hasattr(node, "children"): node = ParameterNode( node.name, @@ -96,19 +96,76 @@ def homogenize_parameter_node( {"0000-01-01": default_value, "2040-01-01": default_value}, ), ) + possible_values_str = {str(v) for v in possible_values} + extra_children = [] + for child in node.children: + child_key = child.split(".")[-1] + if ( + child_key not in possible_values_str + and str(child_key) not in possible_values_str + ): + extra_children.append(child_key) + if extra_children: + raise ValueError( + f"Parameter {node.name} has children {extra_children} " + f"that are not in the possible values of the breakdown " + f"variable '{first_breakdown}'. Check that the breakdown " + f"metadata references the correct variable and that all " + f"parameter keys are valid enum values." + ) for child in node.children: - if child.split(".")[-1] not in possible_values: - try: - int(child) - is_int = True - except: - is_int = False - if not is_int or str(child) not in node.children: - logging.warning( - f"Parameter {node.name} has a child {child} that is not in the possible values of {first_breakdown}, ignoring." - ) if further_breakdown: node.children[child] = homogenize_parameter_node( node.children[child], breakdown[1:], variables, default_value ) return node + + +def evaluate_dynamic_breakdown(expression: str) -> List[Any]: + """Safely evaluate a dynamic breakdown expression. + + The parameter metadata only needs literal collections and the documented + ``range(...)`` / ``list(range(...))`` forms. Anything else is rejected. + """ + + parsed = ast.parse(expression, mode="eval") + evaluated = evaluate_dynamic_breakdown_node(parsed.body) + if isinstance(evaluated, range): + return list(evaluated) + if isinstance(evaluated, (list, tuple)): + return list(evaluated) + if isinstance(evaluated, set): + return list(evaluated) + raise ValueError( + f"Invalid dynamic breakdown expression '{expression}'. " + "Only literal collections and range() calls are allowed." + ) + + +def evaluate_dynamic_breakdown_node(node: ast.AST) -> Any: + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.List): + return [evaluate_dynamic_breakdown_node(element) for element in node.elts] + if isinstance(node, ast.Tuple): + return tuple(evaluate_dynamic_breakdown_node(element) for element in node.elts) + if isinstance(node, ast.Set): + return {evaluate_dynamic_breakdown_node(element) for element in node.elts} + if isinstance(node, ast.UnaryOp) and isinstance( + node.op, (ast.UAdd, ast.USub) + ): + operand = evaluate_dynamic_breakdown_node(node.operand) + return operand if isinstance(node.op, ast.UAdd) else -operand + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + if node.func.id == "range": + args = [evaluate_dynamic_breakdown_node(arg) for arg in node.args] + if node.keywords: + raise ValueError("range() keyword arguments are not allowed") + return range(*args) + if node.func.id == "list": + if len(node.args) != 1 or node.keywords: + raise ValueError("list() must contain a single positional argument") + return list(evaluate_dynamic_breakdown_node(node.args[0])) + raise ValueError( + f"Unsupported dynamic breakdown expression: {ast.unparse(node) if hasattr(ast, 'unparse') else type(node).__name__}" + ) diff --git a/policyengine_core/tools/test_runner.py b/policyengine_core/tools/test_runner.py index 1874ceed7..1fa0139f8 100644 --- a/policyengine_core/tools/test_runner.py +++ b/policyengine_core/tools/test_runner.py @@ -31,7 +31,7 @@ def import_yaml(): import yaml try: - from yaml import CLoader as Loader + from yaml import CSafeLoader as Loader except ImportError: log.warning( " " @@ -119,7 +119,12 @@ def __init__(self, *, tax_benefit_system, options, **kwargs): def collect(self): try: tests = yaml.load(self.path.open(), Loader=Loader) - except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError): + except ( + yaml.scanner.ScannerError, + yaml.parser.ParserError, + yaml.constructor.ConstructorError, + TypeError, + ): message = os.linesep.join( [ traceback.format_exc(), diff --git a/tests/core/test_parameter_security.py b/tests/core/test_parameter_security.py new file mode 100644 index 000000000..add0659ed --- /dev/null +++ b/tests/core/test_parameter_security.py @@ -0,0 +1,52 @@ +import pytest + +from policyengine_core.errors import ParameterParsingError +from policyengine_core.parameters import ParameterNode, homogenize_parameter_structures +from policyengine_core.parameters.helpers import _load_yaml_file + + +def test_parameter_yaml_loader_rejects_python_object_tags(tmp_path, monkeypatch): + calls = [] + + monkeypatch.setattr( + "os.system", + lambda command: calls.append(command) or 0, + ) + + yaml_path = tmp_path / "malicious.yaml" + yaml_path.write_text( + '!!python/object/apply:os.system ["echo pwned"]\n', + encoding="utf-8", + ) + + with pytest.raises(ParameterParsingError): + _load_yaml_file(str(yaml_path)) + + assert calls == [] + + +def test_homogenize_parameter_structures_rejects_dynamic_breakdown_code( + monkeypatch, +): + eval_calls = [] + + monkeypatch.setattr( + "builtins.eval", + lambda expression, globals=None, locals=None: eval_calls.append(expression) + or range(1, 4), + ) + + root = ParameterNode( + data={ + "value_by_category": { + "metadata": { + "breakdown": ['__import__("os").system("echo pwned")'], + }, + } + } + ) + + with pytest.raises(ValueError, match="breakdown"): + homogenize_parameter_structures(root, {}, default_value=0) + + assert eval_calls == [] diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py index 72e3faadb..e15d191f1 100644 --- a/tests/core/tools/test_runner/test_yaml_runner.py +++ b/tests/core/tools/test_runner/test_yaml_runner.py @@ -203,6 +203,30 @@ def test_performance_tables_option_output(): clean_performance_files(paths) +def test_yaml_runner_rejects_python_object_tags(tmp_path, monkeypatch): + calls = [] + yaml_path = tmp_path / "malicious.yaml" + yaml_path.write_text( + '!!python/object/apply:os.system ["echo pwned"]\n', + encoding="utf-8", + ) + + monkeypatch.setattr( + "os.system", + lambda command: calls.append(command) or 0, + ) + + malicious_yaml_file = object.__new__(YamlFile) + malicious_yaml_file.path = yaml_path + malicious_yaml_file.options = {} + malicious_yaml_file.tax_benefit_system = TaxBenefitSystem() + + with pytest.raises(ValueError): + list(malicious_yaml_file.collect()) + + assert calls == [] + + def clean_performance_files(paths: List[str]): for path in paths: if os.path.isfile(path): From 41a493d422441a6b81a538f6c3942574d0701cc5 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sun, 12 Apr 2026 12:35:27 -0400 Subject: [PATCH 2/2] Format YAML security hardening changes --- policyengine_core/parameters/config.py | 26 ++++-- .../operations/homogenize_parameters.py | 48 +++++++++-- policyengine_core/tools/test_runner.py | 22 +++-- tests/core/test_parameter_security.py | 63 +++++++++++++- .../tools/test_runner/test_yaml_runner.py | 83 +++++++++++++++++++ 5 files changed, 219 insertions(+), 23 deletions(-) diff --git a/policyengine_core/parameters/config.py b/policyengine_core/parameters/config.py index d2dbee20a..2de0d0c5d 100644 --- a/policyengine_core/parameters/config.py +++ b/policyengine_core/parameters/config.py @@ -33,15 +33,23 @@ def date_constructor(_loader, node): def dict_no_duplicate_constructor(loader, node, deep=False): - keys = [key.value for key, value in node.value] - - if len(keys) != len(set(keys)): - duplicate = next((key for key in keys if keys.count(key) > 1)) - raise yaml.parser.ParserError( - "", node.start_mark, f"Found duplicate key '{duplicate}'" - ) - - return loader.construct_mapping(node, deep) + loader.flatten_mapping(node) + pairs = loader.construct_pairs(node, deep=deep) + mapping = {} + + for key, value in pairs: + try: + if key in mapping: + raise yaml.parser.ParserError( + "", node.start_mark, f"Found duplicate key '{key}'" + ) + except TypeError as exc: + raise yaml.constructor.ConstructorError( + "", node.start_mark, f"Found unhashable key '{key}'" + ) from exc + mapping[key] = value + + return mapping yaml.add_constructor( diff --git a/policyengine_core/parameters/operations/homogenize_parameters.py b/policyengine_core/parameters/operations/homogenize_parameters.py index cd310a9e5..39164b721 100644 --- a/policyengine_core/parameters/operations/homogenize_parameters.py +++ b/policyengine_core/parameters/operations/homogenize_parameters.py @@ -7,6 +7,8 @@ from policyengine_core.parameters.parameter_node import ParameterNode from policyengine_core.variables import Variable +MAX_DYNAMIC_BREAKDOWN_VALUES = 10_000 + def homogenize_parameter_structures( root: ParameterNode, variables: Dict[str, Variable], default_value: Any = 0 @@ -44,6 +46,11 @@ def get_breakdown_variables(node: ParameterNode) -> List[str]: f"Invalid breakdown metadata for parameter {node.name}: {type(breakdown)}" ) return None + if len(breakdown) == 0: + logging.warning( + f"Invalid breakdown metadata for parameter {node.name}: empty list" + ) + return None return breakdown else: return None @@ -131,10 +138,13 @@ def evaluate_dynamic_breakdown(expression: str) -> List[Any]: parsed = ast.parse(expression, mode="eval") evaluated = evaluate_dynamic_breakdown_node(parsed.body) if isinstance(evaluated, range): + validate_dynamic_breakdown_range_cardinality(evaluated, expression) return list(evaluated) if isinstance(evaluated, (list, tuple)): + validate_dynamic_breakdown_cardinality(len(evaluated), expression) return list(evaluated) if isinstance(evaluated, set): + validate_dynamic_breakdown_cardinality(len(evaluated), expression) return list(evaluated) raise ValueError( f"Invalid dynamic breakdown expression '{expression}'. " @@ -142,18 +152,39 @@ def evaluate_dynamic_breakdown(expression: str) -> List[Any]: ) +def validate_dynamic_breakdown_cardinality(count: int, expression: str) -> None: + if count > MAX_DYNAMIC_BREAKDOWN_VALUES: + raise ValueError( + f"Dynamic breakdown expression '{expression}' produces {count} values, " + f"which exceeds the maximum of {MAX_DYNAMIC_BREAKDOWN_VALUES}." + ) + + +def validate_dynamic_breakdown_range_cardinality( + values: range, expression: str +) -> None: + try: + count = len(values) + except OverflowError as exc: + raise ValueError( + f"Dynamic breakdown expression '{expression}' produces too many values." + ) from exc + validate_dynamic_breakdown_cardinality(count, expression) + + def evaluate_dynamic_breakdown_node(node: ast.AST) -> Any: if isinstance(node, ast.Constant): return node.value if isinstance(node, ast.List): + validate_dynamic_breakdown_cardinality(len(node.elts), ast.unparse(node)) return [evaluate_dynamic_breakdown_node(element) for element in node.elts] if isinstance(node, ast.Tuple): + validate_dynamic_breakdown_cardinality(len(node.elts), ast.unparse(node)) return tuple(evaluate_dynamic_breakdown_node(element) for element in node.elts) if isinstance(node, ast.Set): + validate_dynamic_breakdown_cardinality(len(node.elts), ast.unparse(node)) return {evaluate_dynamic_breakdown_node(element) for element in node.elts} - if isinstance(node, ast.UnaryOp) and isinstance( - node.op, (ast.UAdd, ast.USub) - ): + if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)): operand = evaluate_dynamic_breakdown_node(node.operand) return operand if isinstance(node.op, ast.UAdd) else -operand if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): @@ -161,11 +192,18 @@ def evaluate_dynamic_breakdown_node(node: ast.AST) -> Any: args = [evaluate_dynamic_breakdown_node(arg) for arg in node.args] if node.keywords: raise ValueError("range() keyword arguments are not allowed") - return range(*args) + result = range(*args) + validate_dynamic_breakdown_range_cardinality(result, ast.unparse(node)) + return result if node.func.id == "list": if len(node.args) != 1 or node.keywords: raise ValueError("list() must contain a single positional argument") - return list(evaluate_dynamic_breakdown_node(node.args[0])) + evaluated = evaluate_dynamic_breakdown_node(node.args[0]) + if isinstance(evaluated, (range, list, tuple, set)): + return evaluated + raise ValueError( + "list() only supports range() and literal collection expressions" + ) raise ValueError( f"Unsupported dynamic breakdown expression: {ast.unparse(node) if hasattr(ast, 'unparse') else type(node).__name__}" ) diff --git a/policyengine_core/tools/test_runner.py b/policyengine_core/tools/test_runner.py index 1fa0139f8..dfaa65104 100644 --- a/policyengine_core/tools/test_runner.py +++ b/policyengine_core/tools/test_runner.py @@ -119,12 +119,7 @@ def __init__(self, *, tax_benefit_system, options, **kwargs): def collect(self): try: tests = yaml.load(self.path.open(), Loader=Loader) - except ( - yaml.scanner.ScannerError, - yaml.parser.ParserError, - yaml.constructor.ConstructorError, - TypeError, - ): + except (yaml.YAMLError, TypeError): message = os.linesep.join( [ traceback.format_exc(), @@ -137,6 +132,11 @@ def collect(self): tests: List[Dict] = [tests] for test in tests: + if not isinstance(test, dict): + raise ValueError( + f"'{self.path}' is not a valid YAML test file. " + "Expected a mapping or a list of mappings." + ) if not self.should_ignore(test): yield YamlItem.from_parent( self, @@ -148,11 +148,19 @@ def collect(self): def should_ignore(self, test): name_filter = self.options.get("name_filter") + keywords = test.get("keywords", []) + if keywords is None: + keywords = [] + if not isinstance(keywords, list): + raise ValueError( + f"'{self.path}' is not a valid YAML test file. " + "'keywords' must be a list." + ) return ( name_filter is not None and name_filter not in os.path.splitext(self.fspath.basename)[0] and name_filter not in test.get("name", "") - and name_filter not in test.get("keywords", []) + and name_filter not in keywords ) diff --git a/tests/core/test_parameter_security.py b/tests/core/test_parameter_security.py index add0659ed..3c8cace2c 100644 --- a/tests/core/test_parameter_security.py +++ b/tests/core/test_parameter_security.py @@ -3,6 +3,9 @@ from policyengine_core.errors import ParameterParsingError from policyengine_core.parameters import ParameterNode, homogenize_parameter_structures from policyengine_core.parameters.helpers import _load_yaml_file +from policyengine_core.parameters.operations.homogenize_parameters import ( + MAX_DYNAMIC_BREAKDOWN_VALUES, +) def test_parameter_yaml_loader_rejects_python_object_tags(tmp_path, monkeypatch): @@ -32,8 +35,9 @@ def test_homogenize_parameter_structures_rejects_dynamic_breakdown_code( monkeypatch.setattr( "builtins.eval", - lambda expression, globals=None, locals=None: eval_calls.append(expression) - or range(1, 4), + lambda expression, globals=None, locals=None: ( + eval_calls.append(expression) or range(1, 4) + ), ) root = ParameterNode( @@ -50,3 +54,58 @@ def test_homogenize_parameter_structures_rejects_dynamic_breakdown_code( homogenize_parameter_structures(root, {}, default_value=0) assert eval_calls == [] + + +def test_homogenize_parameter_structures_rejects_oversized_dynamic_breakdown(): + root = ParameterNode( + data={ + "value_by_category": { + "metadata": { + "breakdown": [f"list(range({MAX_DYNAMIC_BREAKDOWN_VALUES + 1}))"], + }, + } + } + ) + + with pytest.raises(ValueError, match="exceeds the maximum"): + homogenize_parameter_structures(root, {}, default_value=0) + + +def test_homogenize_parameter_structures_rejects_overflowing_dynamic_breakdown(): + huge_stop = "1" + ("0" * 100) + root = ParameterNode( + data={ + "value_by_category": { + "metadata": { + "breakdown": [f"range(0, {huge_stop})"], + }, + } + } + ) + + with pytest.raises(ValueError, match="too many values"): + homogenize_parameter_structures(root, {}, default_value=0) + + +def test_parameter_yaml_loader_rejects_implicit_duplicate_keys(tmp_path): + yaml_path = tmp_path / "duplicate-bools.yaml" + yaml_path.write_text("true: 1\nTrue: 2\n", encoding="utf-8") + + with pytest.raises(ParameterParsingError, match="duplicate key"): + _load_yaml_file(str(yaml_path)) + + +def test_homogenize_parameter_structures_ignores_empty_breakdown_lists(): + root = ParameterNode( + data={ + "value_by_category": { + "metadata": { + "breakdown": [], + }, + } + } + ) + + result = homogenize_parameter_structures(root, {}, default_value=0) + + assert result is root diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py index e15d191f1..9bd6f36f6 100644 --- a/tests/core/tools/test_runner/test_yaml_runner.py +++ b/tests/core/tools/test_runner/test_yaml_runner.py @@ -227,6 +227,89 @@ def test_yaml_runner_rejects_python_object_tags(tmp_path, monkeypatch): assert calls == [] +def test_yaml_runner_wraps_composer_errors(tmp_path): + yaml_path = tmp_path / "invalid-anchor.yaml" + yaml_path.write_text("value: *missing_anchor\n", encoding="utf-8") + + invalid_yaml_file = object.__new__(YamlFile) + invalid_yaml_file.path = yaml_path + invalid_yaml_file.options = {} + invalid_yaml_file.tax_benefit_system = TaxBenefitSystem() + + with pytest.raises(ValueError, match="not a valid YAML file"): + list(invalid_yaml_file.collect()) + + +def test_yaml_runner_rejects_scalar_roots(tmp_path): + yaml_path = tmp_path / "scalar.yaml" + yaml_path.write_text("foo\n", encoding="utf-8") + + scalar_yaml_file = object.__new__(YamlFile) + scalar_yaml_file.path = yaml_path + scalar_yaml_file.options = {} + scalar_yaml_file.tax_benefit_system = TaxBenefitSystem() + + with pytest.raises(ValueError, match="list of mappings"): + list(scalar_yaml_file.collect()) + + +def test_yaml_runner_rejects_scalar_keywords(tmp_path): + yaml_path = tmp_path / "invalid-keywords.yaml" + yaml_path.write_text( + "name: Example\nkeywords: 0\noutput: {}\n", + encoding="utf-8", + ) + + invalid_yaml_file = object.__new__(YamlFile) + invalid_yaml_file.path = yaml_path + invalid_yaml_file.options = {"name_filter": "missing"} + invalid_yaml_file.tax_benefit_system = TaxBenefitSystem() + + with pytest.raises(ValueError, match="'keywords' must be a list"): + list(invalid_yaml_file.collect()) + + +def test_yaml_runner_allows_yaml_merge_anchors(tmp_path): + yaml_path = tmp_path / "anchors.yaml" + yaml_path.write_text( + """ +- name: define anchor + input: + persons: &persons + Alicia: + salary: 4000 + households: + household: + parents: [Alicia] + output: + salary: 4000 + +- name: merge anchor + input: + persons: + <<: *persons + households: + household: + parents: [Alicia] + output: + salary: 4000 +""".strip(), + encoding="utf-8", + ) + + yaml_file = object.__new__(YamlFile) + yaml_file.config = None + yaml_file.session = None + yaml_file._nodeid = "anchors" + yaml_file.path = yaml_path + yaml_file.options = {} + yaml_file.tax_benefit_system = TaxBenefitSystem() + + collected = list(yaml_file.collect()) + + assert len(collected) == 2 + + def clean_performance_files(paths: List[str]): for path in paths: if os.path.isfile(path):