Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions tests/unit/test_ai_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ def _ai206_rule() -> dict:
return next(rule for rule in rules["rule"] if rule["id"] == "AI206")


def _ai_rule(rule_id: str) -> dict:
rules = toml.loads(RULES_PATH.read_text(encoding="utf-8"))
return next(rule for rule in rules["rule"] if rule["id"] == rule_id)


def _ast_node(node: ast.AST) -> dict:
children = {}
fields = {}
Expand Down Expand Up @@ -132,3 +137,41 @@ def test_trust_remote_code_false_safe(self):
)
"""
assert not fires(code, "AI206")


class TestAIModelDeserializationPatterns:
def test_keras_h5_model_load_metadata(self):
rule = _ai_rule("AI203")
assert rule["severity"] == "High"
assert rule["cwe"] == "CWE-502"
assert rule["pattern"] == r"keras\.models\.load_model"

def test_keras_h5_model_load_fires(self):
code = """
model = keras.models.load_model(model_path)
"""
assert fires(code, "AI203")

def test_commented_keras_h5_model_load_safe(self):
code = """
# model = keras.models.load_model(model_path)
"""
assert not fires(code, "AI203")

def test_joblib_model_load_metadata(self):
rule = _ai_rule("AI204")
assert rule["severity"] == "High"
assert rule["cwe"] == "CWE-502"
assert rule["pattern"] == r"joblib\.load"

def test_joblib_model_load_fires(self):
code = """
model = joblib.load(model_path)
"""
assert fires(code, "AI204")

def test_commented_joblib_model_load_safe(self):
code = """
# model = joblib.load(model_path)
"""
assert not fires(code, "AI204")
Loading