diff --git a/tests/unit/test_ai_rules.py b/tests/unit/test_ai_rules.py index 1173f9e..72abe98 100644 --- a/tests/unit/test_ai_rules.py +++ b/tests/unit/test_ai_rules.py @@ -29,6 +29,11 @@ def _ai206_rule() -> dict: return next(rule for rule in rules["rule"] if rule["id"] == "AI206") +def _ai_rule(rule_id: str) -> dict: + rules = toml.loads(RULES_PATH.read_text(encoding="utf-8")) + return next(rule for rule in rules["rule"] if rule["id"] == rule_id) + + def _ast_node(node: ast.AST) -> dict: children = {} fields = {} @@ -132,3 +137,41 @@ def test_trust_remote_code_false_safe(self): ) """ assert not fires(code, "AI206") + + +class TestAIModelDeserializationPatterns: + def test_keras_h5_model_load_metadata(self): + rule = _ai_rule("AI203") + assert rule["severity"] == "High" + assert rule["cwe"] == "CWE-502" + assert rule["pattern"] == r"keras\.models\.load_model" + + def test_keras_h5_model_load_fires(self): + code = """ + model = keras.models.load_model(model_path) + """ + assert fires(code, "AI203") + + def test_commented_keras_h5_model_load_safe(self): + code = """ + # model = keras.models.load_model(model_path) + """ + assert not fires(code, "AI203") + + def test_joblib_model_load_metadata(self): + rule = _ai_rule("AI204") + assert rule["severity"] == "High" + assert rule["cwe"] == "CWE-502" + assert rule["pattern"] == r"joblib\.load" + + def test_joblib_model_load_fires(self): + code = """ + model = joblib.load(model_path) + """ + assert fires(code, "AI204") + + def test_commented_joblib_model_load_safe(self): + code = """ + # model = joblib.load(model_path) + """ + assert not fires(code, "AI204")