diff --git a/dreadnode/airt/__init__.py b/dreadnode/airt/__init__.py index 0c4820be..aacb9373 100644 --- a/dreadnode/airt/__init__.py +++ b/dreadnode/airt/__init__.py @@ -1,4 +1,4 @@ -from dreadnode.airt import attack, search +from dreadnode.airt import attack, compliance, search from dreadnode.airt.attack import ( Attack, goat_attack, @@ -9,20 +9,37 @@ tap_attack, zoo_attack, ) +from dreadnode.airt.compliance import ( + ATTACK_MAPPINGS, + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, + tag_transform, +) from dreadnode.airt.target import CustomTarget, LLMTarget, Target __all__ = [ + "ATTACK_MAPPINGS", + "ATLASTechnique", "Attack", "CustomTarget", "LLMTarget", + "NISTAIRMFFunction", + "OWASPCategory", + "SAIFCategory", "Target", "attack", + "compliance", "goat_attack", "hop_skip_jump_attack", "nes_attack", "prompt_attack", "search", "simba_attack", + "tag_attack", + "tag_transform", "tap_attack", "target", "zoo_attack", diff --git a/dreadnode/airt/attack/base.py b/dreadnode/airt/attack/base.py index ac12dbff..ad9610ad 100644 --- a/dreadnode/airt/attack/base.py +++ b/dreadnode/airt/attack/base.py @@ -26,6 +26,9 @@ class Attack(Study[In, Out]): tags: list[str] = Config(default_factory=lambda: ["attack"]) """A list of tags associated with the attack for logging.""" + compliance_tags: dict[str, t.Any] = Config(default_factory=dict) + """Compliance framework tags (OWASP, ATLAS, SAIF, NIST) for this attack.""" + hooks: list[EvalHook] = Field(default_factory=list, exclude=True, repr=False) """Hooks to run at various points in the attack lifecycle.""" diff --git a/dreadnode/airt/attack/crescendo.py b/dreadnode/airt/attack/crescendo.py index a0d74c57..5dd53f0a 100644 --- a/dreadnode/airt/attack/crescendo.py +++ b/dreadnode/airt/attack/crescendo.py @@ -3,6 +3,13 @@ import yaml from dreadnode.airt.attack import Attack +from dreadnode.airt.compliance import ( + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, +) from dreadnode.airt.target.base import Target from dreadnode.constants import CRESCENDO_VARIANT_1 from dreadnode.data_types.message import Message as DnMessage @@ -14,6 +21,20 @@ from dreadnode.transforms.base import Transform from dreadnode.transforms.refine import adapt_prompt_trials, llm_refine +# Compliance framework tags for Crescendo attack +# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.) +# are added when transforms targeting those categories are used +COMPLIANCE_TAGS = tag_attack( + atlas=[ + ATLASTechnique.PROMPT_INJECTION_DIRECT, + ATLASTechnique.LLM_JAILBREAK, + ], + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", +) + def crescendo_attack( goal: str, @@ -179,6 +200,7 @@ async def crescendo_refiner(trials: list[Trial[DnMessage]]) -> DnMessage: "objective": objective_judge, }, hooks=hooks or [], + compliance_tags=COMPLIANCE_TAGS, ) # Add stop condition based on early_stopping_score diff --git a/dreadnode/airt/attack/goat.py b/dreadnode/airt/attack/goat.py index 33f15738..6c598ea0 100644 --- a/dreadnode/airt/attack/goat.py +++ b/dreadnode/airt/attack/goat.py @@ -1,6 +1,13 @@ import typing as t from dreadnode.airt.attack import Attack +from dreadnode.airt.compliance import ( + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, +) from dreadnode.data_types.message import Message as DnMessage from dreadnode.meta.context import TrialCandidate from dreadnode.optimization.search.graph import graph_neighborhood_search @@ -18,6 +25,21 @@ from dreadnode.optimization.trial import Trial +# Compliance framework tags for GOAT attack +# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.) +# are added when transforms targeting those categories are used +COMPLIANCE_TAGS = tag_attack( + atlas=[ + ATLASTechnique.PROMPT_INJECTION_DIRECT, + ATLASTechnique.LLM_JAILBREAK, + ], + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", +) + + def goat_attack( goal: str, target: "Target[DnMessage, DnMessage]", @@ -121,6 +143,7 @@ async def message_refiner(trials: list["Trial[DnMessage]"]) -> DnMessage: }, constraints=[topic_constraint], hooks=hooks or [], + compliance_tags=COMPLIANCE_TAGS, ) if early_stopping_score is not None: diff --git a/dreadnode/airt/attack/prompt.py b/dreadnode/airt/attack/prompt.py index 28c7f7eb..a33febb9 100644 --- a/dreadnode/airt/attack/prompt.py +++ b/dreadnode/airt/attack/prompt.py @@ -3,6 +3,7 @@ import rigging as rg from dreadnode.airt.attack.base import Attack +from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_attack from dreadnode.data_types.message import Message as DnMessage from dreadnode.meta import TrialCandidate from dreadnode.optimization.search.graph import beam_search @@ -18,6 +19,19 @@ from dreadnode.optimization.trial import Trial +# Compliance framework tags for prompt attack +# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.) +# are added when transforms targeting those categories are used +COMPLIANCE_TAGS = tag_attack( + atlas=[ + ATLASTechnique.PROMPT_INJECTION_DIRECT, + ATLASTechnique.LLM_JAILBREAK, + ], + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, +) + + def prompt_attack( goal: str, target: "Target[DnMessage, DnMessage]", @@ -117,6 +131,7 @@ async def message_refiner(trials: list["Trial[DnMessage]"]) -> DnMessage: "prompt_judge": prompt_judge, }, hooks=hooks or [], + compliance_tags=COMPLIANCE_TAGS, ) if early_stopping_score is not None: diff --git a/dreadnode/airt/attack/tap.py b/dreadnode/airt/attack/tap.py index 9aa98278..a24e9ca4 100644 --- a/dreadnode/airt/attack/tap.py +++ b/dreadnode/airt/attack/tap.py @@ -2,6 +2,13 @@ from dreadnode.airt.attack import Attack from dreadnode.airt.attack.prompt import prompt_attack +from dreadnode.airt.compliance import ( + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, +) from dreadnode.data_types.message import Message as DnMessage from dreadnode.scorers.judge import llm_judge @@ -10,6 +17,21 @@ from dreadnode.eval.hooks.base import EvalHook +# Compliance framework tags for TAP attack +# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.) +# are added when transforms targeting those categories are used +COMPLIANCE_TAGS = tag_attack( + atlas=[ + ATLASTechnique.PROMPT_INJECTION_DIRECT, + ATLASTechnique.LLM_JAILBREAK, + ], + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", +) + + def tap_attack( goal: str, target: "Target[DnMessage, DnMessage]", @@ -45,7 +67,7 @@ def tap_attack( topic_constraint = llm_judge(evaluator_model, ON_TOPIC_RUBRIC.format(goal=goal)) - return prompt_attack( + base_attack = prompt_attack( goal, target, attacker_model, @@ -58,7 +80,13 @@ def tap_attack( branching_factor=branching_factor, context_depth=context_depth, hooks=hooks or [], - ).with_(constraints={"on_topic": topic_constraint}) + ) + + # Set compliance tags before cloning + base_attack.compliance_tags = COMPLIANCE_TAGS + + # Add constraint and return + return base_attack.with_(constraints={"on_topic": topic_constraint}) REFINE_GUIDANCE = """\ diff --git a/dreadnode/airt/compliance/__init__.py b/dreadnode/airt/compliance/__init__.py new file mode 100644 index 00000000..bc2b3f7f --- /dev/null +++ b/dreadnode/airt/compliance/__init__.py @@ -0,0 +1,215 @@ +""" +Compliance framework tagging for AI red teaming. + +Maps attacks, transforms, and security tests to industry-standard frameworks: +- MITRE ATLAS: AI/ML attack taxonomy +- OWASP Top 10 for LLM Applications: Security vulnerabilities +- Google SAIF: Secure AI Framework categories +- NIST AI RMF: Risk management functions + +Example: + ```python + import dreadnode as dn + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, tag_attack + + # Tag an attack + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION_DIRECT, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + ) + + # Tags appear in run metadata + with dn.run("jailbreak-test", **tags): + result = await my_attack.run() + ``` +""" + +import typing as t + +from dreadnode.airt.compliance.atlas import ATLASTechnique +from dreadnode.airt.compliance.nist import NIST_SUBCATEGORIES, NISTAIRMFFunction +from dreadnode.airt.compliance.owasp import OWASPCategory +from dreadnode.airt.compliance.saif import SAIFCategory + + +def tag_attack( + *, + atlas: ATLASTechnique | list[ATLASTechnique] | None = None, + owasp: OWASPCategory | list[OWASPCategory] | None = None, + saif: SAIFCategory | list[SAIFCategory] | None = None, + nist_function: NISTAIRMFFunction | None = None, + nist_subcategory: str | None = None, +) -> dict[str, t.Any]: + """ + Tag an attack with compliance framework mappings. + + Returns a dictionary suitable for run metadata or span attributes. + All parameters are optional - provide only relevant frameworks. + + Args: + atlas: MITRE ATLAS technique ID(s) + owasp: OWASP LLM Application category/categories + saif: Google SAIF security category/categories + nist_function: NIST AI RMF core function + nist_subcategory: NIST AI RMF subcategory code (e.g., "MS-2.7") + + Returns: + Dictionary with framework tags suitable for run metadata + + Example: + ```python + # Single framework + tags = tag_attack(owasp=OWASPCategory.LLM01_PROMPT_INJECTION) + + # Multiple frameworks + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION_DIRECT, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ) + + # Multiple categories from same framework + tags = tag_attack( + owasp=[ + OWASPCategory.LLM01_PROMPT_INJECTION, + OWASPCategory.LLM06_EXCESSIVE_AGENCY, + ] + ) + + # Use in run context + with dn.run("my-attack", **tags): + result = await attack.run() + ``` + """ + tags: dict[str, t.Any] = {} + + if atlas is not None: + atlas_list = [atlas] if isinstance(atlas, (str, ATLASTechnique)) else atlas + tags["atlas_techniques"] = [str(t) for t in atlas_list] + + if owasp is not None: + owasp_list = [owasp] if isinstance(owasp, (str, OWASPCategory)) else owasp + tags["owasp_categories"] = [str(c) for c in owasp_list] + + if saif is not None: + saif_list = [saif] if isinstance(saif, (str, SAIFCategory)) else saif + tags["saif_categories"] = [str(c) for c in saif_list] + + if nist_function is not None: + tags["nist_ai_rmf_function"] = str(nist_function) + if nist_subcategory: + tags["nist_ai_rmf_subcategory"] = nist_subcategory + + return tags + + +def tag_transform( + *, + atlas: ATLASTechnique | list[ATLASTechnique] | None = None, + owasp: OWASPCategory | list[OWASPCategory] | None = None, + saif: SAIFCategory | list[SAIFCategory] | None = None, +) -> dict[str, t.Any]: + """ + Tag a transform with compliance framework mappings. + + Similar to tag_attack() but for transforms. Transforms typically don't + map to NIST RMF functions (which are organizational processes). + + Args: + atlas: MITRE ATLAS technique ID(s) + owasp: OWASP LLM Application category/categories + saif: Google SAIF security category/categories + + Returns: + Dictionary with framework tags + + Example: + ```python + from dreadnode.transforms.pii_extraction import repeat_word_divergence + + # Tags are stored in transform metadata + transform = repeat_word_divergence() + transform.compliance_tags = tag_transform( + atlas=ATLASTechnique.INFER_TRAINING_DATA, + owasp=OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE, + saif=SAIFCategory.PRIVACY_LEAKAGE, + ) + ``` + """ + return tag_attack(atlas=atlas, owasp=owasp, saif=saif) + + +# Pre-defined mappings for common attack patterns +ATTACK_MAPPINGS = { + "jailbreak": tag_attack( + atlas=ATLASTechnique.LLM_JAILBREAK, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ), + "prompt_injection_direct": tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION_DIRECT, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ), + "prompt_injection_indirect": tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION_INDIRECT, + owasp=[OWASPCategory.LLM01_PROMPT_INJECTION, OWASPCategory.LLM03_SUPPLY_CHAIN], + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ), + "tool_misuse": tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM06_EXCESSIVE_AGENCY, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + ), + "pii_extraction": tag_attack( + atlas=[ATLASTechnique.MODEL_INVERSION, ATLASTechnique.MEMBERSHIP_INFERENCE], + owasp=OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE, + saif=SAIFCategory.PRIVACY_LEAKAGE, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.8", + ), + "system_prompt_leakage": tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM07_SYSTEM_PROMPT_LEAKAGE, + saif=SAIFCategory.PRIVACY_LEAKAGE, + nist_function=NISTAIRMFFunction.MEASURE, + ), + "model_extraction": tag_attack( + atlas=ATLASTechnique.MODEL_EXTRACTION, + saif=SAIFCategory.MODEL_THEFT, + nist_function=NISTAIRMFFunction.MEASURE, + ), + "denial_of_service": tag_attack( + atlas=ATLASTechnique.DENIAL_OF_ML_SERVICE, + owasp=OWASPCategory.LLM10_UNBOUNDED_CONSUMPTION, + saif=SAIFCategory.AVAILABILITY_ATTACKS, + nist_function=NISTAIRMFFunction.MEASURE, + ), + "data_poisoning": tag_attack( + atlas=ATLASTechnique.POISON_TRAINING_DATA, + owasp=OWASPCategory.LLM04_DATA_MODEL_POISONING, + saif=SAIFCategory.DATA_POISONING, + nist_function=NISTAIRMFFunction.MEASURE, + ), +} + + +__all__ = [ + "ATTACK_MAPPINGS", + "NIST_SUBCATEGORIES", + "ATLASTechnique", + "NISTAIRMFFunction", + "OWASPCategory", + "SAIFCategory", + "tag_attack", + "tag_transform", +] diff --git a/dreadnode/airt/compliance/atlas.py b/dreadnode/airt/compliance/atlas.py new file mode 100644 index 00000000..8c782bf0 --- /dev/null +++ b/dreadnode/airt/compliance/atlas.py @@ -0,0 +1,132 @@ +""" +MITRE ATLAS (Adversarial Threat Landscape for Artificial-Intelligence Systems). + +AI/ML-specific attack techniques organized by tactics similar to MITRE ATT&CK. + +Reference: https://atlas.mitre.org/ +""" + +from enum import Enum + + +class ATLASTechnique(str, Enum): + """ + MITRE ATLAS AI/ML Attack Technique IDs. + + Techniques are organized by tactics: Reconnaissance, Resource Development, + Initial Access, Execution, Persistence, Privilege Escalation, Defense Evasion, + Discovery, Collection, ML Attack Staging, Exfiltration, and Impact. + + Reference: https://atlas.mitre.org/techniques/ + """ + + # === Reconnaissance === + DISCOVER_ML_ARTIFACTS = "AML.T0002" + """Discover ML artifacts and metadata in public or target systems.""" + + DISCOVER_MODEL_ONTOLOGY = "AML.T0003" + """Discover model architecture, capabilities, and ontology.""" + + # === Resource Development === + ACQUIRE_PUBLIC_ML_ARTIFACTS = "AML.T0000" + """Acquire public ML artifacts like pre-trained models or datasets.""" + + DEVELOP_ADVERSARIAL_ML_MODEL = "AML.T0001" + """Develop models designed to evade or attack target systems.""" + + POISON_TRAINING_DATA = "AML.T0020" + """Inject malicious data into training sets to corrupt model behavior.""" + + # === Initial Access === + PROMPT_INJECTION = "AML.T0051" + """Manipulate LLM inputs to override instructions or execute unintended actions.""" + + PROMPT_INJECTION_DIRECT = "AML.T0051.000" + """Direct prompt injection via user-controlled input.""" + + PROMPT_INJECTION_INDIRECT = "AML.T0051.001" + """Indirect prompt injection via external data sources (emails, documents, web).""" + + SUPPLY_CHAIN_COMPROMISE = "AML.T0010" + """Compromise ML supply chain through malicious models, datasets, or dependencies.""" + + # === Execution === + UNSAFE_ML_ARTIFACT = "AML.T0018" + """Execute unsafe ML artifacts like poisoned models or malicious code.""" + + # === Persistence === + BACKDOOR_ML_MODEL = "AML.T0019" + """Embed backdoors in ML models that activate on specific triggers.""" + + # === Privilege Escalation === + # (Uses techniques from other tactics) + + # === Defense Evasion === + EVADE_ML_MODEL = "AML.T0043" + """Craft inputs that evade detection or classification by ML models.""" + + ADVERSARIAL_PERTURBATION = "AML.T0043.001" + """Add imperceptible perturbations to inputs to cause misclassification.""" + + TRANSFER_ATTACK = "AML.T0043.002" + """Transfer adversarial examples from surrogate to target model.""" + + OBFUSCATE_ARTIFACTS = "AML.T0044" + """Obfuscate malicious content to evade ML-based detection.""" + + # === Credential Access === + # (Uses techniques from traditional ATT&CK) + + # === Discovery === + DISCOVER_TRAINING_DATA = "AML.T0052" + """Infer characteristics or contents of training data.""" + + DISCOVER_MODEL_FAMILY = "AML.T0053" + """Determine model architecture family (transformer, CNN, etc.).""" + + # === Lateral Movement === + # (Uses techniques from traditional ATT&CK) + + # === Collection === + INFER_TRAINING_DATA = "AML.T0024" + """Extract or infer training data through model inversion or membership inference.""" + + MODEL_INVERSION = "AML.T0024.000" + """Reconstruct training data from model outputs (e.g., recover faces, text).""" + + MEMBERSHIP_INFERENCE = "AML.T0024.001" + """Determine if specific data was in the training set.""" + + # === ML Attack Staging === + CRAFT_ADVERSARIAL_DATA = "AML.T0049" + """Generate adversarial examples optimized to fool target models.""" + + VERIFY_ATTACK = "AML.T0042" + """Test adversarial inputs against surrogate or target models.""" + + # === Command and Control === + # (Uses techniques from traditional ATT&CK) + + # === Exfiltration === + EXFILTRATION_VIA_ML_INFERENCE = "AML.T0026" + """Extract sensitive data through repeated model queries and inference.""" + + MODEL_EXTRACTION = "AML.T0040" + """Steal model functionality by querying and replicating behavior.""" + + # === Impact === + ERODE_ML_MODEL_INTEGRITY = "AML.T0048" + """Degrade model accuracy, fairness, or reliability.""" + + LLM_JAILBREAK = "AML.T0054" + """Bypass LLM safety mechanisms to generate prohibited content.""" + + DENIAL_OF_ML_SERVICE = "AML.T0029" + """Exhaust model resources through adversarial queries or sponge examples.""" + + def __str__(self) -> str: + """Return the technique ID.""" + return self.value + + +__all__ = ["ATLASTechnique"] diff --git a/dreadnode/airt/compliance/nist.py b/dreadnode/airt/compliance/nist.py new file mode 100644 index 00000000..8e21731a --- /dev/null +++ b/dreadnode/airt/compliance/nist.py @@ -0,0 +1,63 @@ +""" +NIST AI Risk Management Framework (AI RMF). + +Risk management functions and categories for AI systems. + +Reference: https://www.nist.gov/itl/ai-risk-management-framework +""" + +from enum import Enum + + +class NISTAIRMFFunction(str, Enum): + """ + NIST AI Risk Management Framework Core Functions. + + The AI RMF organizes risk management activities into four core functions + that work together to manage AI risks throughout the system lifecycle. + + Reference: https://www.nist.gov/itl/ai-risk-management-framework + """ + + GOVERN = "GOVERN" + """ + Govern: Cultivate and manage organizational culture, processes, and structures + for responsible AI development and deployment. Includes policies, accountability, + and risk governance. + """ + + MAP = "MAP" + """ + Map: Establish context and understand risks. Includes categorizing AI systems, + identifying stakeholders, and mapping potential risks and impacts. + """ + + MEASURE = "MEASURE" + """ + Measure: Analyze, assess, benchmark, and monitor AI risks and impacts. + Includes testing, evaluation, auditing, and continuous monitoring. + """ + + MANAGE = "MANAGE" + """ + Manage: Allocate resources to prioritize and respond to AI risks. Includes + risk mitigation, treatment, incident response, and continuous improvement. + """ + + def __str__(self) -> str: + """Return the function name.""" + return self.value + + +# Common NIST AI RMF subcategories for reference +NIST_SUBCATEGORIES = { + "MS-2.7": "AI system reliability and robustness under adversarial conditions", + "MS-2.8": "Privacy risks from AI systems", + "MS-2.9": "Security vulnerabilities in AI systems", + "MG-3.1": "AI risks are prioritized and treated", + "MG-3.2": "Adverse events are documented and monitored", + "GV-1.1": "Legal and regulatory requirements are understood and documented", +} + + +__all__ = ["NIST_SUBCATEGORIES", "NISTAIRMFFunction"] diff --git a/dreadnode/airt/compliance/owasp.py b/dreadnode/airt/compliance/owasp.py new file mode 100644 index 00000000..b4088625 --- /dev/null +++ b/dreadnode/airt/compliance/owasp.py @@ -0,0 +1,86 @@ +""" +OWASP Top 10 for LLM Applications 2025. + +Reference: https://genai.owasp.org/llm-top-10/ +""" + +from enum import Enum + + +class OWASPCategory(str, Enum): + """ + OWASP Top 10 for LLM Applications 2025. + + Each category represents a critical security vulnerability class specific + to Large Language Model applications. + + Reference: https://genai.owasp.org/llm-top-10/ + """ + + LLM01_PROMPT_INJECTION = "LLM01:2025" + """ + Prompt Injection: Manipulating LLM inputs to override system instructions, + execute unintended actions, or access unauthorized data. Includes both direct + (user input) and indirect (external data sources) injection vectors. + """ + + LLM02_SENSITIVE_INFORMATION_DISCLOSURE = "LLM02:2025" + """ + Sensitive Information Disclosure: Exposing confidential data through LLM outputs, + including PII, credentials, proprietary information, or training data memorization. + """ + + LLM03_SUPPLY_CHAIN = "LLM03:2025" + """ + Supply Chain Vulnerabilities: Risks from third-party models, datasets, plugins, + or dependencies that may be compromised, outdated, or malicious. + """ + + LLM04_DATA_MODEL_POISONING = "LLM04:2025" + """ + Data and Model Poisoning: Manipulation of training data or fine-tuning processes + to inject backdoors, biases, or vulnerabilities into the model. + """ + + LLM05_IMPROPER_OUTPUT_HANDLING = "LLM05:2025" + """ + Improper Output Handling: Insufficient validation of LLM outputs before downstream + use, leading to injection attacks (XSS, SQL injection) or code execution. + """ + + LLM06_EXCESSIVE_AGENCY = "LLM06:2025" + """ + Excessive Agency: LLM systems with too much autonomy or permissions, enabling + unintended actions, privilege escalation, or unauthorized system modifications. + """ + + LLM07_SYSTEM_PROMPT_LEAKAGE = "LLM07:2025" + """ + System Prompt Leakage: Disclosure of system prompts, instructions, or configuration + details that reveal security mechanisms or enable targeted attacks. + """ + + LLM08_VECTOR_EMBEDDING_WEAKNESSES = "LLM08:2025" + """ + Vector and Embedding Weaknesses: Vulnerabilities in RAG systems, vector databases, + or embedding models that enable data poisoning or unauthorized access. + """ + + LLM09_MISINFORMATION = "LLM09:2025" + """ + Misinformation: Generation of false, misleading, or fabricated information + (hallucinations) that appears credible but lacks factual grounding. + """ + + LLM10_UNBOUNDED_CONSUMPTION = "LLM10:2025" + """ + Unbounded Consumption: Resource exhaustion through excessive LLM requests, + context window abuse, or denial-of-service attacks targeting inference costs. + """ + + def __str__(self) -> str: + """Return the category ID.""" + return self.value + + +__all__ = ["OWASPCategory"] diff --git a/dreadnode/airt/compliance/saif.py b/dreadnode/airt/compliance/saif.py new file mode 100644 index 00000000..c2ef18d5 --- /dev/null +++ b/dreadnode/airt/compliance/saif.py @@ -0,0 +1,69 @@ +""" +Google SAIF (Secure AI Framework). + +Security categories for AI/ML systems aligned with Google's security principles. + +Reference: https://blog.google/technology/safety-security/google-secure-ai-framework/ +""" + +from enum import Enum + + +class SAIFCategory(str, Enum): + """ + Google SAIF (Secure AI Framework) Security Categories. + + Organizes AI security risks into actionable categories aligned with + traditional security controls and threat modeling. + + Reference: https://blog.google/technology/safety-security/google-secure-ai-framework/ + """ + + INPUT_MANIPULATION = "INPUT_MANIPULATION" + """ + Input Manipulation: Adversarial inputs designed to manipulate model behavior, + including prompt injection, adversarial examples, and input perturbations. + """ + + OUTPUT_MANIPULATION = "OUTPUT_MANIPULATION" + """ + Output Manipulation: Attacks targeting model outputs, including response + poisoning, hallucination exploitation, and output handling vulnerabilities. + """ + + MODEL_THEFT = "MODEL_THEFT" + """ + Model Theft: Stealing model functionality or intellectual property through + model extraction, knowledge distillation, or architecture inference. + """ + + DATA_POISONING = "DATA_POISONING" + """ + Data Poisoning: Corruption of training data to inject backdoors, biases, + or vulnerabilities into the model during training or fine-tuning. + """ + + SUPPLY_CHAIN_COMPROMISE = "SUPPLY_CHAIN_COMPROMISE" + """ + Supply Chain Compromise: Attacks targeting the ML supply chain including + malicious dependencies, poisoned datasets, or compromised pre-trained models. + """ + + PRIVACY_LEAKAGE = "PRIVACY_LEAKAGE" + """ + Privacy Leakage: Disclosure of sensitive information through model outputs, + including PII extraction, training data memorization, and membership inference. + """ + + AVAILABILITY_ATTACKS = "AVAILABILITY_ATTACKS" + """ + Availability Attacks: Denial of service, resource exhaustion, or system + degradation through adversarial queries or sponge examples. + """ + + def __str__(self) -> str: + """Return the category name.""" + return self.value + + +__all__ = ["SAIFCategory"] diff --git a/dreadnode/transforms/base.py b/dreadnode/transforms/base.py index c1e8619c..e41dae74 100644 --- a/dreadnode/transforms/base.py +++ b/dreadnode/transforms/base.py @@ -43,6 +43,7 @@ def __init__( catch: bool = False, config: dict[str, ConfigInfo] | None = None, context: dict[str, Context] | None = None, + compliance_tags: dict[str, t.Any] | None = None, ): super().__init__( t.cast("t.Callable[[In], Out]", func), name=name, config=config, context=context @@ -55,6 +56,8 @@ def __init__( If True, catches exceptions during the transform and attempts to return the original, unmodified object from the input. If False, exceptions are raised. """ + self.compliance_tags = compliance_tags or {} + """Compliance framework tags (OWASP, ATLAS, SAIF, NIST) for this transform.""" @classmethod def fit(cls, transform: "TransformLike[In, Out]") -> "Transform[In, Out]": diff --git a/dreadnode/transforms/cipher.py b/dreadnode/transforms/cipher.py index d56a45ca..4e2bd365 100644 --- a/dreadnode/transforms/cipher.py +++ b/dreadnode/transforms/cipher.py @@ -1,4 +1,5 @@ import codecs +import functools import random import string import typing as t @@ -7,6 +8,18 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_obfuscation_tags() -> dict[str, t.Any]: + """Get compliance tags for obfuscation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def atbash_cipher(*, name: str = "atbash") -> Transform[str, str]: """Encodes text using the Atbash cipher.""" @@ -19,7 +32,7 @@ def transform(text: str) -> str: translation_table = str.maketrans("".join(alphabet), "".join(reversed_alphabet)) return text.translate(translation_table) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def caesar_cipher(offset: int, *, name: str = "caesar") -> Transform[str, str]: @@ -39,7 +52,7 @@ def shift(alphabet: str) -> str: translation_table = str.maketrans("".join(alphabet), "".join(shifted_alphabet)) return text.translate(translation_table) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def rot13_cipher(*, name: str = "rot13") -> Transform[str, str]: @@ -48,7 +61,7 @@ def rot13_cipher(*, name: str = "rot13") -> Transform[str, str]: def transform(text: str) -> str: return codecs.encode(text, "rot13") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def rot47_cipher(*, name: str = "rot47") -> Transform[str, str]: @@ -67,7 +80,7 @@ def transform(text: str) -> str: transformed.append(char) return "".join(transformed) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def vigenere_cipher( @@ -115,7 +128,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def substitution_cipher( @@ -155,7 +168,7 @@ def transform(text: str) -> str: ) return text.translate(translation_table) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def xor_cipher( @@ -201,7 +214,7 @@ def transform( # raw return xored.decode("latin-1") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def rail_fence_cipher( @@ -246,7 +259,7 @@ def transform( # Read off the rails return "".join("".join(rail) for rail in fence) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def columnar_transposition( @@ -297,7 +310,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def playfair_cipher( @@ -379,7 +392,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def affine_cipher( @@ -428,7 +441,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def bacon_cipher( @@ -522,7 +535,7 @@ def transform( return " ".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def autokey_cipher( @@ -571,7 +584,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def beaufort_cipher( @@ -622,4 +635,4 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) diff --git a/dreadnode/transforms/constitutional.py b/dreadnode/transforms/constitutional.py index ac92525b..2e250d46 100644 --- a/dreadnode/transforms/constitutional.py +++ b/dreadnode/transforms/constitutional.py @@ -22,6 +22,7 @@ - Hybrid: Static fallback to LLM for unknown terms """ +import functools import json import random import re @@ -34,6 +35,18 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_constitutional_evasion_tags() -> dict[str, t.Any]: + """Get compliance tags for constitutional_evasion transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=[ATLASTechnique.EVADE_ML_MODEL, ATLASTechnique.OBFUSCATE_ARTIFACTS], + owasp=[OWASPCategory.LLM01_PROMPT_INJECTION, OWASPCategory.LLM05_IMPROPER_OUTPUT_HANDLING], + saif=[SAIFCategory.INPUT_MANIPULATION, SAIFCategory.OUTPUT_MANIPULATION], + ) + + @dataclass class MetaphorMapping: """Mapping of technical terms to metaphorical language.""" @@ -431,7 +444,7 @@ def transform(text: str) -> str: return "\n".join(code_lines) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) def _format_as_markdown(fragments: list[str], labels: list[str], *, add_metadata: bool) -> str: @@ -579,7 +592,7 @@ def transform(text: str) -> str: return text - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) def multi_turn_fragmentation( @@ -627,7 +640,7 @@ def transform(text: str) -> list[str]: return turns - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) # ============================================================================ @@ -777,7 +790,7 @@ async def transform(text: str) -> str: return encoded - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) def character_separation( @@ -819,7 +832,7 @@ def transform(text: str) -> str: return separated - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) # Static riddle templates for riddle_encoding @@ -897,7 +910,7 @@ async def transform(text: str) -> str: # Generic riddle return f"Solve this puzzle: {text}\n(Express your answer in metaphorical terms)" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) # Default chemistry substitution mappings for contextual_substitution @@ -972,4 +985,4 @@ def transform(text: str) -> str: return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) diff --git a/dreadnode/transforms/encoding.py b/dreadnode/transforms/encoding.py index 06699b3a..93d9894b 100644 --- a/dreadnode/transforms/encoding.py +++ b/dreadnode/transforms/encoding.py @@ -1,4 +1,5 @@ import base64 +import functools import html import json import random @@ -9,13 +10,25 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_obfuscation_tags() -> dict[str, t.Any]: + """Get compliance tags for obfuscation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def ascii85_encode(*, name: str = "ascii85") -> Transform[str, str]: """Encodes text to ASCII85.""" def transform(text: str) -> str: return base64.a85encode(text.encode("utf-8")).decode("ascii") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def base32_encode(*, name: str = "base32") -> Transform[str, str]: @@ -24,7 +37,7 @@ def base32_encode(*, name: str = "base32") -> Transform[str, str]: def transform(text: str) -> str: return base64.b32encode(text.encode("utf-8")).decode("ascii") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def base64_encode(*, name: str = "base64") -> Transform[str, str]: @@ -33,7 +46,7 @@ def base64_encode(*, name: str = "base64") -> Transform[str, str]: def transform(text: str) -> str: return base64.b64encode(text.encode("utf-8")).decode("utf-8") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def binary_encode(bits_per_char: int = 16, *, name: str = "binary") -> Transform[str, str]: @@ -52,7 +65,7 @@ def transform( ) return " ".join(format(ord(char), f"0{bits_per_char}b") for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def hex_encode(*, name: str = "hex") -> Transform[str, str]: @@ -61,7 +74,7 @@ def hex_encode(*, name: str = "hex") -> Transform[str, str]: def transform(text: str) -> str: return text.encode("utf-8").hex().upper() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def html_escape(*, name: str = "html_escape") -> Transform[str, str]: @@ -70,7 +83,7 @@ def html_escape(*, name: str = "html_escape") -> Transform[str, str]: def transform(text: str) -> str: return html.escape(text, quote=True) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def url_encode(*, name: str = "url_encode") -> Transform[str, str]: @@ -79,7 +92,7 @@ def url_encode(*, name: str = "url_encode") -> Transform[str, str]: def transform(text: str) -> str: return urllib.parse.quote(text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def unicode_escape( @@ -122,7 +135,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def json_encode( @@ -148,7 +161,7 @@ def transform( ) -> str: return json.dumps(text, ensure_ascii=ensure_ascii) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def punycode_encode(*, name: str = "punycode") -> Transform[str, str]: @@ -161,7 +174,7 @@ def punycode_encode(*, name: str = "punycode") -> Transform[str, str]: def transform(text: str) -> str: return text.encode("punycode").decode("ascii") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def quoted_printable_encode(*, name: str = "quoted_printable") -> Transform[str, str]: @@ -175,7 +188,7 @@ def quoted_printable_encode(*, name: str = "quoted_printable") -> Transform[str, def transform(text: str) -> str: return quopri.encodestring(text.encode("utf-8")).decode("ascii") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def base58_encode(*, name: str = "base58") -> Transform[str, str]: @@ -210,7 +223,7 @@ def transform(text: str) -> str: return "".join(reversed(result)) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def percent_encoding( @@ -241,7 +254,7 @@ def transform( encoded = urllib.parse.quote(encoded, safe="") return encoded - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def html_entity_encode( @@ -285,7 +298,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def octal_encode(*, name: str = "octal") -> Transform[str, str]: @@ -298,7 +311,7 @@ def octal_encode(*, name: str = "octal") -> Transform[str, str]: def transform(text: str) -> str: return "".join(f"\\{ord(char):03o}" for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def utf7_encode(*, name: str = "utf7") -> Transform[str, str]: @@ -323,7 +336,7 @@ def transform(text: str) -> str: result.append(f"+{base64.b64encode(bytes([byte])).decode('ascii').rstrip('=')}-") return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def base91_encode(*, name: str = "base91") -> Transform[str, str]: @@ -366,7 +379,7 @@ def transform(text: str) -> str: return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def mixed_case_hex(*, name: str = "mixed_case_hex") -> Transform[str, str]: @@ -386,7 +399,7 @@ def transform(text: str) -> str: result.append(mixed) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def backslash_escape( @@ -417,7 +430,7 @@ def transform( result.append(char) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def zero_width_encode( @@ -471,7 +484,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def leetspeak_encoding( @@ -526,7 +539,7 @@ def transform( result.append(char) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def morse_encode( @@ -595,4 +608,4 @@ def transform( return " ".join(morse_chars) return "".join(morse_chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) diff --git a/dreadnode/transforms/language.py b/dreadnode/transforms/language.py index 1151892f..92c16b4e 100644 --- a/dreadnode/transforms/language.py +++ b/dreadnode/transforms/language.py @@ -1,3 +1,4 @@ +import functools import typing as t import rigging as rg @@ -7,6 +8,18 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_style_manipulation_tags() -> dict[str, t.Any]: + """Get compliance tags for style_manipulation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def adapt_language( target_language: str, *, @@ -141,7 +154,7 @@ async def transform( return adapted_text - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) def transliterate( @@ -409,7 +422,7 @@ def transform( result.append(char) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) def code_switch( @@ -515,7 +528,7 @@ async def transform( return result_text.strip() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) def dialectal_variation( @@ -616,4 +629,4 @@ async def transform( return result_text.strip() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) diff --git a/dreadnode/transforms/perturbation.py b/dreadnode/transforms/perturbation.py index 9ab17f9b..02df8d2c 100644 --- a/dreadnode/transforms/perturbation.py +++ b/dreadnode/transforms/perturbation.py @@ -1,3 +1,4 @@ +import functools import random import re import string @@ -10,6 +11,18 @@ from dreadnode.util import catch_import_error +@functools.lru_cache(maxsize=1) +def _get_perturbation_tags() -> dict[str, t.Any]: + """Get compliance tags for perturbation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.ADVERSARIAL_PERTURBATION, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def random_capitalization( *, ratio: float = 0.2, @@ -45,7 +58,7 @@ def transform( chars[i] = chars[i].upper() return "".join(chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def insert_punctuation( @@ -95,7 +108,7 @@ def transform( words[i] = words[i] + punc return " ".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def diacritic( @@ -133,7 +146,9 @@ def transform( for char in text ) - return Transform(transform, name=name or f"diacritic_{accent}") + return Transform( + transform, name=name or f"diacritic_{accent}", compliance_tags=_get_perturbation_tags() + ) def underline(*, name: str = "underline") -> Transform[str, str]: @@ -142,7 +157,7 @@ def underline(*, name: str = "underline") -> Transform[str, str]: def transform(text: str) -> str: return "".join(char + "\u0332" for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def character_space(*, name: str = "character_space") -> Transform[str, str]: @@ -153,7 +168,7 @@ def transform(text: str) -> str: text_no_punc = text.translate(punctuation_to_remove) return " ".join(text_no_punc) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def zero_width(*, name: str = "zero_width") -> Transform[str, str]: @@ -162,7 +177,7 @@ def zero_width(*, name: str = "zero_width") -> Transform[str, str]: def transform(text: str) -> str: return "\u200b".join(text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def zalgo( @@ -214,7 +229,9 @@ def transform( return "".join(chars) - return Transform(transform, name=name or f"zalgo_{intensity}") + return Transform( + transform, name=name or f"zalgo_{intensity}", compliance_tags=_get_perturbation_tags() + ) def unicode_confusable( @@ -269,7 +286,7 @@ def transform( chars[i] = rand.choice(replacement_options) return "".join(chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def unicode_replacement( @@ -289,7 +306,7 @@ def transform(text: str) -> str: result = result.replace("\\u0020", " ") return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def unicode_substitution( @@ -306,7 +323,7 @@ def unicode_substitution( def transform(text: str) -> str: return "".join(chr(start_value + ord(ch)) for ch in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def repeat_token( @@ -355,7 +372,7 @@ def transform( return f"{text}{repeated}" return repeated - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def emoji_substitution( @@ -469,7 +486,7 @@ def transform( return "".join(chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def token_smuggling( @@ -507,7 +524,7 @@ def transform( # both return f"{smuggled_text}{separator}{text}{separator}{smuggled_text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def semantic_preserving_perturbation( @@ -569,7 +586,7 @@ def transform( result = "".join(words) return re.sub(r'\s([?.!,\'"])', r"\1", result).strip() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def instruction_hierarchy_confusion( @@ -608,7 +625,7 @@ def transform( hidden = f"\n\n" return f"{text}{hidden}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def context_overflow( @@ -638,7 +655,7 @@ def transform( overflow = filler_text * repetitions return f"{overflow}\n\n{text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def adversarial_suffix( @@ -664,7 +681,7 @@ def transform( ) -> str: return f"{text} {suffix}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def gradient_based_perturbation( @@ -721,7 +738,7 @@ def transform( return "".join(chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def multilingual_mixing( @@ -799,7 +816,7 @@ def transform( return "".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def cognitive_hacking( @@ -844,7 +861,7 @@ def transform( template = frames[framing] return template.format(text=text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def payload_splitting( @@ -895,7 +912,7 @@ def transform( return delimiter.join(parts) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def attention_diversion( @@ -930,7 +947,7 @@ def transform( # sandwich return f"{distractor_text}\n\n{text}\n\n{distractor_text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def style_injection( @@ -1008,7 +1025,7 @@ def transform( return " ".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def implicit_continuation( @@ -1034,7 +1051,7 @@ def transform( ) -> str: return f"{continuation_prompt} {text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def authority_exploitation( @@ -1070,7 +1087,7 @@ def transform( template = authority_frames[authority_source] return template.format(text=text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def linguistic_camouflage( @@ -1139,7 +1156,7 @@ def transform( return "".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def temporal_misdirection( @@ -1174,7 +1191,7 @@ def transform( template = temporal_frames[tense] return template.format(text=text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def complexity_amplification( @@ -1217,7 +1234,7 @@ def transform( return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def error_injection( @@ -1289,7 +1306,7 @@ def transform( return "".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def encoding_nesting( @@ -1330,7 +1347,7 @@ def transform(text: str) -> str: result = apply_encoding(result, encoding) return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def token_boundary_manipulation( @@ -1393,7 +1410,7 @@ def transform( return " ".join(w for w in words if w) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def meta_instruction_injection( @@ -1431,7 +1448,7 @@ def transform( encoded = base64.b64encode(meta_instruction.encode()).decode() return f"[Encoded]: {encoded}\n\n{text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def sentiment_inversion( @@ -1489,4 +1506,4 @@ def transform( return "".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) diff --git a/dreadnode/transforms/pii_extraction.py b/dreadnode/transforms/pii_extraction.py index 0e8551c7..6e2c97bd 100644 --- a/dreadnode/transforms/pii_extraction.py +++ b/dreadnode/transforms/pii_extraction.py @@ -1,8 +1,21 @@ +import functools import typing as t from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_pii_extraction_tags() -> dict[str, t.Any]: + """Get compliance tags for PII extraction transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=[ATLASTechnique.INFER_TRAINING_DATA, ATLASTechnique.MODEL_INVERSION], + owasp=OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE, + saif=SAIFCategory.PRIVACY_LEAKAGE, + ) + + def repeat_word_divergence( word: str = "company", *, @@ -38,7 +51,7 @@ def transform( repeated = " ".join([word] * repetitions) return f"{text} {repeated}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) def continue_exact_text( @@ -75,7 +88,7 @@ def transform( return f"{instruction}\n\n{prefix}\n\nContext: {text}" return f"{instruction}\n\n{prefix}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) def complete_from_internet( @@ -113,7 +126,7 @@ def transform( return f"{context}\n\n{partial_text}\n\nAdditional context: {text}" return f"{context}\n\n{partial_text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) def partial_pii_completion( @@ -170,7 +183,7 @@ def transform( return f"{prompt}\n\nContext: {text}" return prompt - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) def public_figure_pii_probe( @@ -224,4 +237,4 @@ def transform( return f"{prompt}\n\nAdditional context: {text}" return prompt - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) diff --git a/dreadnode/transforms/refine.py b/dreadnode/transforms/refine.py index 822affe8..d17c6103 100644 --- a/dreadnode/transforms/refine.py +++ b/dreadnode/transforms/refine.py @@ -1,3 +1,4 @@ +import functools import typing as t from collections import defaultdict from textwrap import dedent, indent @@ -9,6 +10,19 @@ from dreadnode.meta import Config from dreadnode.transforms.base import Transform + +@functools.lru_cache(maxsize=1) +def _get_refinement_tags() -> dict[str, t.Any]: + """Get compliance tags for refinement transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.CRAFT_ADVERSARIAL_DATA, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + if t.TYPE_CHECKING: from ulid import ULID @@ -73,7 +87,7 @@ async def transform( refinement = await refine.bind(generator)(refiner_input) return refinement.prompt - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_refinement_tags()) def adapt_prompt_trials(trials: "list[Trial[DnMessage]]") -> str: diff --git a/dreadnode/transforms/stylistic.py b/dreadnode/transforms/stylistic.py index 79608352..b7c7993e 100644 --- a/dreadnode/transforms/stylistic.py +++ b/dreadnode/transforms/stylistic.py @@ -1,3 +1,4 @@ +import functools import typing as t from dreadnode.meta import Config @@ -5,6 +6,18 @@ from dreadnode.util import catch_import_error +@functools.lru_cache(maxsize=1) +def _get_style_manipulation_tags() -> dict[str, t.Any]: + """Get compliance tags for style_manipulation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def ascii_art(font: str = "rand", *, name: str = "ascii_art") -> Transform[str, str]: """Converts text into ASCII art using the 'art' library.""" @@ -14,7 +27,7 @@ def ascii_art(font: str = "rand", *, name: str = "ascii_art") -> Transform[str, def transform(text: str, *, font: str = Config(font, help="The font to use")) -> str: return str(text2art(text, font=font)) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) def role_play_wrapper( @@ -69,4 +82,4 @@ def transform( } return templates[scenario] - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) diff --git a/dreadnode/transforms/substitution.py b/dreadnode/transforms/substitution.py index 1ab5fa4e..902c1856 100644 --- a/dreadnode/transforms/substitution.py +++ b/dreadnode/transforms/substitution.py @@ -1,9 +1,23 @@ +import functools import random import re import typing as t from dreadnode.transforms.base import Transform + +@functools.lru_cache(maxsize=1) +def _get_obfuscation_tags() -> dict[str, t.Any]: + """Get compliance tags for obfuscation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + # ruff: noqa: RUF001 @@ -57,7 +71,7 @@ def get_replacement(item: str) -> str: result = " ".join(substituted_words) return re.sub(r'\s([?.!,"\'`])', r"\1", result).strip() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -87,7 +101,7 @@ def transform(text: str) -> str: result.append(BRAILLE_MAP.get(char, char)) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -192,7 +206,7 @@ def transform(text: str) -> str: i += 1 return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -227,7 +241,7 @@ def transform(text: str) -> str: i += 1 return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -293,7 +307,7 @@ def small_caps(*, name: str = "small_caps") -> Transform[str, str]: def transform(text: str) -> str: return "".join(SMALL_CAPS_MAP.get(char.lower(), char) for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -313,7 +327,7 @@ def wingdings(*, name: str = "wingdings") -> Transform[str, str]: def transform(text: str) -> str: return "".join(WINGDINGS_MAP.get(char.upper(), char) for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -339,7 +353,7 @@ def transform(text: str) -> str: text_clean = " ".join([line.strip() for line in str.splitlines(text)]) return " ".join([MORSE_MAP.get(char, MORSE_ERROR) for char in text_clean.upper()]) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -361,7 +375,7 @@ def nato_phonetic(*, name: str = "nato_phonetic") -> Transform[str, str]: def transform(text: str) -> str: return " ".join(NATO_MAP.get(char.upper(), char) for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -385,7 +399,7 @@ def transform(text: str) -> str: reversed_text = text[::-1] return "".join(MIRROR_MAP.get(char, char) for char in reversed_text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -431,4 +445,4 @@ def transform(text: str) -> str: words = re.findall(r"\w+|[^\w\s]", text) return "".join(_to_pig_latin_word(word) for word in words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) diff --git a/dreadnode/transforms/swap.py b/dreadnode/transforms/swap.py index 2bfbf6ce..13d0ca49 100644 --- a/dreadnode/transforms/swap.py +++ b/dreadnode/transforms/swap.py @@ -1,3 +1,4 @@ +import functools import random import re import typing as t @@ -6,6 +7,18 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_obfuscation_tags() -> dict[str, t.Any]: + """Get compliance tags for obfuscation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def swap( *, unit: t.Literal["char", "word"] = "char", @@ -62,7 +75,7 @@ def transform( return re.sub(r'\s([?.!,"\'`])', r"\1", result).strip() return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def adjacent_char_swap( diff --git a/dreadnode/transforms/text.py b/dreadnode/transforms/text.py index f1802b4f..5417c209 100644 --- a/dreadnode/transforms/text.py +++ b/dreadnode/transforms/text.py @@ -1,3 +1,4 @@ +import functools import random import re import typing as t @@ -6,13 +7,30 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_text_manipulation_tags() -> dict[str, t.Any]: + """Get compliance tags for text manipulation transforms (cached).""" + from dreadnode.airt.compliance import ( + ATLASTechnique, + OWASPCategory, + SAIFCategory, + tag_transform, + ) + + return tag_transform( + atlas=ATLASTechnique.CRAFT_ADVERSARIAL_DATA, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def reverse(*, name: str = "reverse") -> Transform[str, str]: """Reverses the order of characters in a string.""" def transform(text: str) -> str: return text[::-1] - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def search_replace( @@ -63,7 +81,7 @@ def transform(text: str) -> str: flags=re.IGNORECASE, ) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def join( @@ -89,7 +107,7 @@ def transform( items = list(text) if unit == "char" else text.split() return delimiter.join(items) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def char_join(delimiter: str = "-", *, name: str = "char_join") -> Transform[str, str]: @@ -145,7 +163,7 @@ def transform( return text_to_add + delimiter + text return text + delimiter + text_to_add - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def prefix(text: str, *, name: str = "prefix") -> Transform[str, str]: @@ -229,7 +247,7 @@ def transform(text: str) -> str: words = re.findall(r"\w+|[^\w\s]", text) return "".join(_to_pig_latin_word(word) for word in words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def word_removal( @@ -280,7 +298,7 @@ def transform( return "".join(result_words) return " ".join(result_words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def word_duplication( @@ -336,7 +354,7 @@ def transform( return "".join(result_words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def case_alternation( @@ -383,7 +401,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def whitespace_manipulation( @@ -429,7 +447,7 @@ def replace_space(_: re.Match[str]) -> str: return re.sub(r"\s+", replace_space, text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def sentence_reordering( @@ -469,7 +487,7 @@ def transform(text: str) -> str: return "".join(paired) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def question_transformation( @@ -505,7 +523,7 @@ def transform( # nested return f"For research purposes, I need to understand the answer to: {text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def contextual_wrapping( @@ -544,7 +562,7 @@ def transform( template = wrappers[wrapper] return template.format(text=text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def length_manipulation( @@ -596,4 +614,4 @@ def transform( repetitions = (target_length // current_len) + 1 return (text * repetitions)[:target_length] - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) diff --git a/tests/airt/__init__.py b/tests/airt/__init__.py new file mode 100644 index 00000000..8194f2b5 --- /dev/null +++ b/tests/airt/__init__.py @@ -0,0 +1 @@ +"""Tests for AI Red Teaming module.""" diff --git a/tests/airt/test_attack_compliance_tags.py b/tests/airt/test_attack_compliance_tags.py new file mode 100644 index 00000000..89994cbe --- /dev/null +++ b/tests/airt/test_attack_compliance_tags.py @@ -0,0 +1,119 @@ +"""Tests for attack compliance tags.""" + +from dreadnode.airt.attack.crescendo import COMPLIANCE_TAGS as CRESCENDO_TAGS +from dreadnode.airt.attack.goat import COMPLIANCE_TAGS as GOAT_TAGS +from dreadnode.airt.attack.prompt import COMPLIANCE_TAGS as PROMPT_TAGS +from dreadnode.airt.attack.tap import COMPLIANCE_TAGS as TAP_TAGS + + +def test_prompt_attack_has_compliance_tags() -> None: + """Prompt attack has compliance tags.""" + assert "atlas_techniques" in PROMPT_TAGS + assert "owasp_categories" in PROMPT_TAGS + assert "saif_categories" in PROMPT_TAGS + + +def test_prompt_attack_core_technique_only() -> None: + """Prompt attack has only core jailbreak technique tags.""" + assert PROMPT_TAGS["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in PROMPT_TAGS["atlas_techniques"] + assert "AML.T0054" in PROMPT_TAGS["atlas_techniques"] + assert "INPUT_MANIPULATION" in PROMPT_TAGS["saif_categories"] + + +def test_prompt_attack_no_vulnerability_categories() -> None: + """Prompt attack does not include specific vulnerability categories.""" + owasp = PROMPT_TAGS["owasp_categories"] + assert "LLM02:2025" not in owasp + assert "LLM07:2025" not in owasp + assert "LLM09:2025" not in owasp + assert "LLM10:2025" not in owasp + + +def test_tap_attack_has_compliance_tags() -> None: + """TAP attack has compliance tags.""" + assert "atlas_techniques" in TAP_TAGS + assert "owasp_categories" in TAP_TAGS + assert "saif_categories" in TAP_TAGS + + +def test_tap_attack_core_technique_only() -> None: + """TAP attack has only core jailbreak technique tags.""" + assert TAP_TAGS["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in TAP_TAGS["atlas_techniques"] + assert "AML.T0054" in TAP_TAGS["atlas_techniques"] + + +def test_tap_attack_has_nist() -> None: + """TAP attack includes NIST AI RMF tags.""" + assert "nist_ai_rmf_function" in TAP_TAGS + assert TAP_TAGS["nist_ai_rmf_function"] == "MEASURE" + assert "nist_ai_rmf_subcategory" in TAP_TAGS + assert TAP_TAGS["nist_ai_rmf_subcategory"] == "MS-2.7" + + +def test_goat_attack_has_compliance_tags() -> None: + """GOAT attack has compliance tags.""" + assert "atlas_techniques" in GOAT_TAGS + assert "owasp_categories" in GOAT_TAGS + assert "saif_categories" in GOAT_TAGS + + +def test_goat_attack_core_technique_only() -> None: + """GOAT attack has only core jailbreak technique tags.""" + assert GOAT_TAGS["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in GOAT_TAGS["atlas_techniques"] + assert "AML.T0054" in GOAT_TAGS["atlas_techniques"] + + +def test_goat_attack_has_nist() -> None: + """GOAT attack includes NIST AI RMF tags.""" + assert "nist_ai_rmf_function" in GOAT_TAGS + assert GOAT_TAGS["nist_ai_rmf_function"] == "MEASURE" + + +def test_crescendo_attack_has_compliance_tags() -> None: + """Crescendo attack has compliance tags.""" + assert "atlas_techniques" in CRESCENDO_TAGS + assert "owasp_categories" in CRESCENDO_TAGS + assert "saif_categories" in CRESCENDO_TAGS + + +def test_crescendo_attack_core_technique_only() -> None: + """Crescendo attack has only core jailbreak technique tags.""" + assert CRESCENDO_TAGS["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in CRESCENDO_TAGS["atlas_techniques"] + assert "AML.T0054" in CRESCENDO_TAGS["atlas_techniques"] + + +def test_crescendo_attack_has_nist() -> None: + """Crescendo attack includes NIST AI RMF tags.""" + assert "nist_ai_rmf_function" in CRESCENDO_TAGS + assert CRESCENDO_TAGS["nist_ai_rmf_function"] == "MEASURE" + + +def test_all_jailbreak_attacks_consistent() -> None: + """All jailbreak attacks have consistent core tags.""" + attacks = [PROMPT_TAGS, TAP_TAGS, GOAT_TAGS, CRESCENDO_TAGS] + + for tags in attacks: + assert tags["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in tags["atlas_techniques"] + assert "AML.T0054" in tags["atlas_techniques"] + assert "INPUT_MANIPULATION" in tags["saif_categories"] + + +def test_attacks_do_not_duplicate_transform_tags() -> None: + """Attacks do not include tags that should come from transforms.""" + attacks = [PROMPT_TAGS, TAP_TAGS, GOAT_TAGS, CRESCENDO_TAGS] + + for tags in attacks: + owasp = tags["owasp_categories"] + atlas = tags["atlas_techniques"] + + # Should not include PII extraction tags + assert "LLM02:2025" not in owasp + assert "AML.T0024" not in atlas + + # Should not include system prompt leakage tags + assert "LLM07:2025" not in owasp diff --git a/tests/airt/test_compliance.py b/tests/airt/test_compliance.py new file mode 100644 index 00000000..5cad2cff --- /dev/null +++ b/tests/airt/test_compliance.py @@ -0,0 +1,150 @@ +"""Tests for compliance framework tags.""" + +from dreadnode.airt.compliance import ( + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, + tag_transform, +) + + +def test_owasp_categories_exist() -> None: + """All OWASP Top 10 categories are defined.""" + assert OWASPCategory.LLM01_PROMPT_INJECTION.value == "LLM01:2025" + assert OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE.value == "LLM02:2025" + assert OWASPCategory.LLM03_SUPPLY_CHAIN.value == "LLM03:2025" + assert OWASPCategory.LLM04_DATA_MODEL_POISONING.value == "LLM04:2025" + assert OWASPCategory.LLM05_IMPROPER_OUTPUT_HANDLING.value == "LLM05:2025" + assert OWASPCategory.LLM06_EXCESSIVE_AGENCY.value == "LLM06:2025" + assert OWASPCategory.LLM07_SYSTEM_PROMPT_LEAKAGE.value == "LLM07:2025" + assert OWASPCategory.LLM08_VECTOR_EMBEDDING_WEAKNESSES.value == "LLM08:2025" + assert OWASPCategory.LLM09_MISINFORMATION.value == "LLM09:2025" + assert OWASPCategory.LLM10_UNBOUNDED_CONSUMPTION.value == "LLM10:2025" + + +def test_atlas_techniques_exist() -> None: + """ATLAS techniques are defined.""" + assert ATLASTechnique.PROMPT_INJECTION.value == "AML.T0051" + assert ATLASTechnique.PROMPT_INJECTION_DIRECT.value == "AML.T0051.000" + assert ATLASTechnique.LLM_JAILBREAK.value == "AML.T0054" + assert ATLASTechnique.OBFUSCATE_ARTIFACTS.value == "AML.T0044" + assert ATLASTechnique.ADVERSARIAL_PERTURBATION.value == "AML.T0043.001" + assert ATLASTechnique.INFER_TRAINING_DATA.value == "AML.T0024" + assert ATLASTechnique.MODEL_INVERSION.value == "AML.T0024.000" + + +def test_saif_categories_exist() -> None: + """SAIF categories are defined.""" + assert SAIFCategory.INPUT_MANIPULATION.value == "INPUT_MANIPULATION" + assert SAIFCategory.OUTPUT_MANIPULATION.value == "OUTPUT_MANIPULATION" + assert SAIFCategory.PRIVACY_LEAKAGE.value == "PRIVACY_LEAKAGE" + + +def test_nist_functions_exist() -> None: + """NIST AI RMF functions are defined.""" + assert NISTAIRMFFunction.GOVERN.value == "GOVERN" + assert NISTAIRMFFunction.MAP.value == "MAP" + assert NISTAIRMFFunction.MEASURE.value == "MEASURE" + assert NISTAIRMFFunction.MANAGE.value == "MANAGE" + + +def test_tag_attack_single_values() -> None: + """Tag attack with single values.""" + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + assert tags["atlas_techniques"] == ["AML.T0051"] + assert tags["owasp_categories"] == ["LLM01:2025"] + assert tags["saif_categories"] == ["INPUT_MANIPULATION"] + + +def test_tag_attack_multiple_values() -> None: + """Tag attack with multiple values.""" + tags = tag_attack( + atlas=[ATLASTechnique.PROMPT_INJECTION, ATLASTechnique.LLM_JAILBREAK], + owasp=[ + OWASPCategory.LLM01_PROMPT_INJECTION, + OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE, + ], + saif=[SAIFCategory.INPUT_MANIPULATION, SAIFCategory.PRIVACY_LEAKAGE], + ) + + assert tags["atlas_techniques"] == ["AML.T0051", "AML.T0054"] + assert tags["owasp_categories"] == ["LLM01:2025", "LLM02:2025"] + assert tags["saif_categories"] == ["INPUT_MANIPULATION", "PRIVACY_LEAKAGE"] + + +def test_tag_attack_with_nist() -> None: + """Tag attack with NIST AI RMF.""" + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ) + + assert "nist_ai_rmf_function" in tags + assert tags["nist_ai_rmf_function"] == "MEASURE" + assert "nist_ai_rmf_subcategory" in tags + assert tags["nist_ai_rmf_subcategory"] == "MS-2.7" + + +def test_tag_attack_optional_parameters() -> None: + """Tag attack with only required parameters.""" + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + assert "atlas_techniques" in tags + assert "owasp_categories" in tags + assert "saif_categories" in tags + assert "nist_ai_rmf_function" not in tags + assert "nist_ai_rmf_subcategory" not in tags + + +def test_tag_transform_single_values() -> None: + """Tag transform with single values.""" + tags = tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + assert tags["atlas_techniques"] == ["AML.T0044"] + assert tags["owasp_categories"] == ["LLM01:2025"] + assert tags["saif_categories"] == ["INPUT_MANIPULATION"] + + +def test_tag_transform_multiple_values() -> None: + """Tag transform with multiple values.""" + tags = tag_transform( + atlas=[ATLASTechnique.EVADE_ML_MODEL, ATLASTechnique.OBFUSCATE_ARTIFACTS], + owasp=[OWASPCategory.LLM01_PROMPT_INJECTION, OWASPCategory.LLM05_IMPROPER_OUTPUT_HANDLING], + saif=[SAIFCategory.INPUT_MANIPULATION, SAIFCategory.OUTPUT_MANIPULATION], + ) + + assert "AML.T0043" in tags["atlas_techniques"] + assert "AML.T0044" in tags["atlas_techniques"] + assert "LLM01:2025" in tags["owasp_categories"] + assert "LLM05:2025" in tags["owasp_categories"] + + +def test_tag_attack_none_values() -> None: + """Tag attack handles None values.""" + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=None, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + assert "atlas_techniques" in tags + assert "owasp_categories" not in tags + assert "saif_categories" in tags diff --git a/tests/test_transform_compliance_tags.py b/tests/test_transform_compliance_tags.py new file mode 100644 index 00000000..d581d3d8 --- /dev/null +++ b/tests/test_transform_compliance_tags.py @@ -0,0 +1,218 @@ +"""Tests for transform compliance tags.""" + +from dreadnode.transforms.cipher import caesar_cipher +from dreadnode.transforms.constitutional import code_fragmentation +from dreadnode.transforms.encoding import base64_encode +from dreadnode.transforms.language import adapt_language +from dreadnode.transforms.perturbation import adversarial_suffix +from dreadnode.transforms.pii_extraction import repeat_word_divergence +from dreadnode.transforms.refine import llm_refine +from dreadnode.transforms.stylistic import role_play_wrapper +from dreadnode.transforms.substitution import braille +from dreadnode.transforms.swap import adjacent_char_swap +from dreadnode.transforms.text import reverse + + +def test_pii_transform_has_compliance_tags() -> None: + """PII extraction transforms have compliance tags.""" + transform = repeat_word_divergence() + + assert hasattr(transform, "compliance_tags") + assert isinstance(transform.compliance_tags, dict) + assert "atlas_techniques" in transform.compliance_tags + assert "owasp_categories" in transform.compliance_tags + assert "saif_categories" in transform.compliance_tags + + +def test_pii_transform_has_correct_tags() -> None: + """PII extraction transforms have correct vulnerability tags.""" + transform = repeat_word_divergence() + tags = transform.compliance_tags + + assert "LLM02:2025" in tags["owasp_categories"] + assert "PRIVACY_LEAKAGE" in tags["saif_categories"] + assert any("AML.T0024" in t for t in tags["atlas_techniques"]) + + +def test_cipher_transform_has_compliance_tags() -> None: + """Cipher transforms have compliance tags.""" + transform = caesar_cipher(offset=3) + + assert hasattr(transform, "compliance_tags") + assert isinstance(transform.compliance_tags, dict) + + +def test_cipher_transform_has_obfuscation_tags() -> None: + """Cipher transforms have obfuscation tags.""" + transform = caesar_cipher(offset=3) + tags = transform.compliance_tags + + assert "LLM01:2025" in tags["owasp_categories"] + assert "INPUT_MANIPULATION" in tags["saif_categories"] + assert "AML.T0044" in tags["atlas_techniques"] + + +def test_encoding_transform_has_compliance_tags() -> None: + """Encoding transforms have compliance tags.""" + transform = base64_encode() + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_encoding_transform_has_obfuscation_tags() -> None: + """Encoding transforms have obfuscation tags.""" + transform = base64_encode() + tags = transform.compliance_tags + + assert "LLM01:2025" in tags["owasp_categories"] + assert "AML.T0044" in tags["atlas_techniques"] + + +def test_perturbation_transform_has_compliance_tags() -> None: + """Perturbation transforms have compliance tags.""" + transform = adversarial_suffix() + + assert hasattr(transform, "compliance_tags") + assert "atlas_techniques" in transform.compliance_tags + + +def test_perturbation_transform_has_adversarial_tags() -> None: + """Perturbation transforms have adversarial perturbation tags.""" + transform = adversarial_suffix() + tags = transform.compliance_tags + + assert "AML.T0043.001" in tags["atlas_techniques"] + assert "INPUT_MANIPULATION" in tags["saif_categories"] + + +def test_constitutional_transform_has_compliance_tags() -> None: + """Constitutional evasion transforms have compliance tags.""" + transform = code_fragmentation() + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_constitutional_transform_has_evasion_tags() -> None: + """Constitutional evasion transforms have multiple OWASP tags.""" + transform = code_fragmentation() + tags = transform.compliance_tags + + assert "LLM01:2025" in tags["owasp_categories"] + assert "LLM05:2025" in tags["owasp_categories"] + assert "INPUT_MANIPULATION" in tags["saif_categories"] + assert "OUTPUT_MANIPULATION" in tags["saif_categories"] + + +def test_stylistic_transform_has_compliance_tags() -> None: + """Stylistic transforms have compliance tags.""" + transform = role_play_wrapper(scenario="educational", character="researcher") + + assert hasattr(transform, "compliance_tags") + assert "atlas_techniques" in transform.compliance_tags + + +def test_language_transform_has_compliance_tags() -> None: + """Language transforms have compliance tags.""" + transform = adapt_language(target_language="es", adapter_model="gpt-4") + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_text_transform_has_compliance_tags() -> None: + """Text manipulation transforms have compliance tags.""" + transform = reverse() + + assert hasattr(transform, "compliance_tags") + assert "atlas_techniques" in transform.compliance_tags + + +def test_refine_transform_has_compliance_tags() -> None: + """Refinement transforms have compliance tags.""" + transform = llm_refine(model="gpt-4", guidance="test") + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_substitution_transform_has_compliance_tags() -> None: + """Substitution transforms have compliance tags.""" + transform = braille() + + assert hasattr(transform, "compliance_tags") + assert "atlas_techniques" in transform.compliance_tags + + +def test_swap_transform_has_compliance_tags() -> None: + """Swap transforms have compliance tags.""" + transform = adjacent_char_swap() + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_all_transforms_have_required_keys() -> None: + """All transforms have required compliance tag keys.""" + transforms = [ + caesar_cipher(offset=3), + base64_encode(), + adversarial_suffix(), + repeat_word_divergence(), + code_fragmentation(), + role_play_wrapper(scenario="educational", character="researcher"), + adapt_language(target_language="es", adapter_model="gpt-4"), + reverse(), + llm_refine(model="gpt-4", guidance="test"), + braille(), + adjacent_char_swap(), + ] + + for transform in transforms: + assert "atlas_techniques" in transform.compliance_tags + assert "owasp_categories" in transform.compliance_tags + assert "saif_categories" in transform.compliance_tags + + +def test_transform_tags_are_lists() -> None: + """Transform tag values are lists.""" + transform = repeat_word_divergence() + tags = transform.compliance_tags + + assert isinstance(tags["atlas_techniques"], list) + assert isinstance(tags["owasp_categories"], list) + assert isinstance(tags["saif_categories"], list) + + +def test_transform_tags_not_empty() -> None: + """Transform tags contain at least one value.""" + transforms = [ + caesar_cipher(offset=3), + repeat_word_divergence(), + adversarial_suffix(), + ] + + for transform in transforms: + tags = transform.compliance_tags + assert len(tags["atlas_techniques"]) > 0 + assert len(tags["owasp_categories"]) > 0 + assert len(tags["saif_categories"]) > 0 + + +def test_pii_and_obfuscation_different_tags() -> None: + """PII and obfuscation transforms have different vulnerability tags.""" + pii = repeat_word_divergence() + cipher = caesar_cipher(offset=3) + + pii_owasp = pii.compliance_tags["owasp_categories"] + cipher_owasp = cipher.compliance_tags["owasp_categories"] + + # PII targets sensitive info disclosure + assert "LLM02:2025" in pii_owasp + + # Cipher targets prompt injection + assert "LLM01:2025" in cipher_owasp + + # Different vulnerability categories + assert pii_owasp != cipher_owasp