diff --git a/docs/architecture.md b/docs/architecture.md index f41c244..dccc437 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -74,7 +74,7 @@ Skill system including the embedded and local skill registries, SKILL.md parser, | `pipeline` | Build pipeline context and orchestration | `Pipeline`, `Stage`, `BuildContext` | | `plugins` | Plugin and framework plugin interfaces | `Plugin`, `FrameworkPlugin`, `AgentConfig`, `FrameworkRegistry` | | `registry` | Embedded skill registry | — | -| `runtime` | LLM agent loop, executor, hooks, memory, guardrails | `AgentExecutor`, `LLMExecutor`, `ToolExecutor` | +| `runtime` | LLM agent loop, executor, hooks, memory, guardrail interface | `AgentExecutor`, `LLMExecutor`, `ToolExecutor`, `GuardrailChecker` | | `schemas` | Embedded JSON schemas | `agentspec.v1.0.schema.json` | | `security` | Egress allowlist, security policies, network policies | `EgressConfig`, `Resolve`, `GenerateAllowlistJSON` | | `skills` | Skill parsing, compilation, requirements resolution | `CompiledSkills`, `Compile`, `WriteArtifacts` | @@ -98,7 +98,7 @@ Skill system including the embedded and local skill registries, SKILL.md parser, | `plugins/crewai` | CrewAI framework adapter | — | | `plugins/langchain` | LangChain framework adapter | — | | `plugins/custom` | Custom framework plugin | — | -| `runtime` | CLI-specific runtime (subprocess, watchers, stubs, mocks) | — | +| `runtime` | CLI-specific runtime (subprocess, guardrail engine, watchers, stubs, mocks) | `LibraryGuardrailEngine` | | `server` | A2A HTTP server implementation | — | | `channels` | Channel configuration and routing | — | | `skills` | Skill file loading and writing | — | diff --git a/docs/commands.md b/docs/commands.md index 5c6ccac..33e019e 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -38,6 +38,18 @@ forge init [name] [flags] | `--from-skills` | | | Path to a SKILL.md file for auto-configuration | | `--non-interactive` | | `false` | Skip interactive prompts | +### 
Generated Files + +`forge init` generates these key files: + +| File | Purpose | +|------|---------| +| `forge.yaml` | Agent configuration | +| `guardrails.json` | Guardrail policy config (PII, security, secret patterns, gate config) | +| `SKILL.md` | Agent skill definition | +| `.env` | Environment variables | +| `.gitignore` | Includes `guardrails.json`, `.env`, `.forge/` | + ### Examples ```bash diff --git a/docs/configuration.md b/docs/configuration.md index 5cbdb9f..cdc5a60 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -77,6 +77,8 @@ memory: keyword_weight: 0.3 # Hybrid search keyword weight decay_half_life_days: 7 # Temporal decay half-life +guardrails_path: "guardrails.json" # Path to guardrails config (default: "guardrails.json") + schedules: # Recurring scheduled tasks (optional) - id: "daily-report" cron: "@daily" @@ -108,6 +110,9 @@ schedules: # Recurring scheduled tasks (optional) | `FORGE_CORS_ORIGINS` | Comma-separated CORS allowed origins for A2A server | | `FORGE_AUTH_URL` | External auth provider URL for token validation | | `FORGE_AUTH_ORG_ID` | Organization ID sent to external auth provider | +| `FORGE_GUARDRAILS_DB` | MongoDB URI for DB-backed guardrails config + audit | +| `FORGE_AGENT_ID` | Agent identifier for DB guardrails (falls back to `agent_id` in YAML) | +| `FORGE_ORG_ID` | Organization identifier for DB guardrails | | `FORGE_PASSPHRASE` | Passphrase for encrypted secrets file | --- diff --git a/docs/deployment.md b/docs/deployment.md index 5ca63a6..7509497 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -50,6 +50,7 @@ Every `forge build` generates container-ready artifacts: | Artifact | Purpose | |----------|---------| +| `guardrails.json` | Guardrail policy config (copied from project root if present) | | `Dockerfile` | Container image with minimal attack surface | | `deployment.yaml` | Kubernetes Deployment manifest | | `service.yaml` | Kubernetes Service manifest | diff --git a/docs/memory.md 
b/docs/memory.md index 868ee03..d3326bd 100644 --- a/docs/memory.md +++ b/docs/memory.md @@ -15,6 +15,7 @@ memory: ``` - Sessions are saved as JSON files with atomic writes (temp file + fsync + rename) +- Orphaned tool calls (assistant tool_calls without matching tool results) are stripped on both save and recovery, preventing API rejection errors - Automatic cleanup of sessions older than 7 days at startup - Session recovery on subsequent requests (disk snapshot supersedes task history) - **Session max age** (default 30 minutes): stale sessions are discarded on recovery to prevent poisoned error context from blocking tool retries. When an LLM accumulates repeated tool failures in a session, it may stop retrying altogether. The max age ensures these poisoned sessions expire, giving the agent a fresh start. diff --git a/docs/runtime.md b/docs/runtime.md index 94f8caf..bc0f4e2 100644 --- a/docs/runtime.md +++ b/docs/runtime.md @@ -19,7 +19,7 @@ The core agent loop follows a simple pattern: User message → Memory → LLM → tool_calls? → Execute tools → LLM → ... → text → Done ``` -The loop terminates when `FinishReason == "stop"` or `len(ToolCalls) == 0`. +The loop terminates when `len(ToolCalls) == 0`. Tool calls are always executed even if `FinishReason` is `"stop"` — this prevents orphaned function calls that would cause API rejection on session recovery. ### Q&A Nudge Suppression @@ -245,7 +245,7 @@ For details on session persistence, context window management, compaction, and l The engine fires hooks at key points in the loop. See [Hooks](hooks.md) for details. -The runner registers five hook groups: logging, audit, progress, global guardrail hooks, and skill guardrail hooks. The global guardrail `AfterToolExec` hook scans tool output for secrets and PII, redacting or blocking before results enter the LLM context. 
Skill guardrail hooks enforce domain-specific rules declared in `SKILL.md` — blocking commands, redacting output, intercepting capability enumeration probes, and replacing binary-enumerating responses. Skill guardrails are loaded from build artifacts or parsed directly from `SKILL.md` at runtime (no `forge build` required). See [Tool Output Scanning](security/guardrails.md#tool-output-scanning) and [Skill Guardrails](security/guardrails.md#skill-guardrails). +The runner registers five hook groups: logging, audit, progress, global guardrail hooks, and skill guardrail hooks. Global guardrails use the `GuardrailChecker` interface backed by the `github.com/initializ/guardrails` library — the `AfterToolExec` hook scans tool output for secrets and PII, redacting or blocking before results enter the LLM context. Guardrail config is loaded from `guardrails.json` (file mode) or MongoDB (DB mode). Skill guardrail hooks enforce domain-specific rules declared in `SKILL.md` — blocking commands, redacting output, intercepting capability enumeration probes, and replacing binary-enumerating responses. Skill guardrails are loaded from build artifacts or parsed directly from `SKILL.md` at runtime (no `forge build` required). See [Guardrails](security/guardrails.md) for full details. ## Streaming diff --git a/docs/security/guardrails.md b/docs/security/guardrails.md index 763f2ce..727a722 100644 --- a/docs/security/guardrails.md +++ b/docs/security/guardrails.md @@ -2,130 +2,186 @@ > Part of [Forge Documentation](../../README.md) -The guardrail engine checks inbound and outbound messages against configurable policy rules. +The guardrail engine validates inbound and outbound messages against configurable policy rules using the `github.com/initializ/guardrails` library. 
-## Built-in Guardrails +## Architecture -| Guardrail | Direction | Description | -|-----------|-----------|-------------| -| `content_filter` | Inbound + Outbound | Blocks messages containing configured blocked words | -| `no_pii` | Outbound | Detects email, phone, SSNs (with structural validation), and credit cards (with Luhn check) | -| `jailbreak_protection` | Inbound | Detects common jailbreak phrases ("ignore previous instructions", etc.) | -| `no_secrets` | Outbound | Detects API keys, tokens, and private keys (OpenAI, Anthropic, AWS, GitHub, Slack, Telegram, etc.) | +Guardrails are implemented as a `GuardrailChecker` interface in forge-core, with the concrete engine in forge-cli wrapping the external guardrails library. Two operational modes are supported: + +| Mode | Config Source | Use Case | +|------|--------------|----------| +| **File mode** (default) | `guardrails.json` in project root | Local development, standalone deployments | +| **DB mode** | MongoDB (`AgentConfig` collection) | Platform deployments with centralized config + audit | + +Priority: `FORGE_GUARDRAILS_DB` env → `guardrails.json` → built-in defaults. 
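The priority order above can be sketched in Go as follows. This is a minimal illustration under stated assumptions, not the forge-cli implementation: `resolveSource`, the `source` type, and its constants are hypothetical names, and the real engine may probe the file or fall back differently.

```go
package main

import (
	"fmt"
	"os"
)

// source names where guardrail config would be loaded from (hypothetical type).
type source string

const (
	sourceDB       source = "db"
	sourceFile     source = "file"
	sourceDefaults source = "defaults"
)

// resolveSource applies the documented priority:
// FORGE_GUARDRAILS_DB env, then guardrails.json, then built-in defaults.
// getenv and exists are injected so the logic can be exercised without
// touching the real environment or filesystem.
func resolveSource(getenv func(string) string, exists func(string) bool, path string) source {
	if getenv("FORGE_GUARDRAILS_DB") != "" {
		return sourceDB
	}
	if exists(path) {
		return sourceFile
	}
	return sourceDefaults
}

// fileExists reports whether path can be stat'ed.
func fileExists(path string) bool {
	_, err := os.Stat(path)
	return err == nil
}

func main() {
	fmt.Println(resolveSource(os.Getenv, fileExists, "guardrails.json"))
}
```

Injecting the environment and filesystem lookups keeps the decision itself a pure function, which makes each of the three branches easy to check in isolation.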
+ +## Built-in Evaluators + +The guardrails library provides these evaluator categories: + +| Category | Direction | Description | +|----------|-----------|-------------| +| PII detection | Inbound + Outbound | Detects email, phone, SSN, credit card numbers | +| Jailbreak detection | Inbound | Detects jailbreak and prompt manipulation attempts | +| Prompt injection | Inbound | Detects injection attacks in user input | +| Command injection | Inbound | Detects shell/command injection patterns | +| Secret detection | Outbound + Tool output | Detects API keys, tokens, and private keys via regex rules | +| Custom rules | Configurable per gate | User-defined regex and keyword rules | ## Modes | Mode | Behavior | |------|----------| -| `enforce` | Blocks violating inbound messages; **redacts** outbound messages (see below) | +| `enforce` | Blocks violating inbound messages; **redacts** outbound messages | | `warn` | Logs violation, allows message to pass | -### Outbound Redaction +### Inbound Masking -Outbound messages (from the agent to the user) are always **redacted** rather than blocked, even in `enforce` mode. Blocking would discard a potentially useful agent response (e.g., code analysis) over a false positive from broad PII/secret patterns matching source code. Matched content is replaced with `[REDACTED]` and a warning is logged. +When PII or secrets are detected in inbound messages with action `mask`, the content is redacted **before** it reaches the LLM. The LLM never sees the original sensitive data. 
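As a rough illustration of masking before text reaches the LLM, the sketch below applies two of the default secret regexes with Go's standard `regexp` package. It does not use the guardrails library: `maskSecrets`, `secretPatterns`, and the `[REDACTED]` placeholder are assumptions made for this example, and the library's actual masked output may look different.

```go
package main

import (
	"fmt"
	"regexp"
)

// secretPatterns holds two of the default secret rules from guardrails.json
// (secret_openai and secret_aws), compiled once at startup.
var secretPatterns = []*regexp.Regexp{
	regexp.MustCompile(`sk-[A-Za-z0-9]{20,}`),
	regexp.MustCompile(`AKIA[0-9A-Z]{16}`),
}

// maskPlaceholder is a stand-in chosen for this sketch only.
const maskPlaceholder = "[REDACTED]"

// maskSecrets replaces every match of the known patterns and reports
// whether anything was changed.
func maskSecrets(text string) (string, bool) {
	masked := text
	for _, re := range secretPatterns {
		masked = re.ReplaceAllLiteralString(masked, maskPlaceholder)
	}
	return masked, masked != text
}

func main() {
	out, changed := maskSecrets("deploy with AKIAABCDEFGHIJKLMNOP today")
	fmt.Println(out, changed)
}
```

Only the masked string would be forwarded to the model in this sketch; the boolean lets the caller decide whether to log a redaction event.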
-### PII Validators - -To reduce false positives, PII patterns use structural validators beyond simple regex: +### Outbound Redaction -| Pattern | Validator | What it checks | -|---------|-----------|---------------| -| SSN | `validateSSN` | Rejects area=000/666/900+, group=00, serial=0000, all-same digits, known test SSNs | -| Credit card | `validateLuhn` | Luhn checksum validation, 13-19 digit length check | -| Email | — | Regex only | -| Phone | — | Regex only (area code 2-9, separators required) | +Outbound messages (from the agent to the user) are always **redacted** rather than blocked, even in `enforce` mode. Blocking would discard a potentially useful agent response over a false positive. Matched content is replaced with the library's masked output and a warning is logged. ## Configuration -Guardrails are defined in the policy scaffold, loaded from `policy-scaffold.json` or generated during `forge build`. +### `guardrails.json` -Custom guardrail rules can be added to the policy scaffold: +Guardrails are configured in `guardrails.json` at the project root. 
This file is generated by `forge init` and can be customized: ```json { - "guardrails": { - "content_filter": { - "mode": "enforce", - "blocked_words": ["password", "credit card"] - }, - "no_pii": { - "mode": "enforce" + "pii": { + "enabled": true, + "action": "mask", + "categories": { + "email": { "enabled": true, "action": "mask" }, + "phoneNumber": { "enabled": true, "action": "mask" }, + "ssn": { "enabled": true, "action": "mask" }, + "creditCard": { "enabled": true, "action": "mask" } + } + }, + "security": { + "jailbreakDetection": { + "enabled": true, + "confidenceThreshold": 25, + "action": "block" }, - "jailbreak_protection": { - "mode": "warn" + "promptInjection": { + "enabled": true, + "confidenceThreshold": 30, + "action": "block" }, - "no_secrets": { - "mode": "enforce" + "commandInjection": { + "enabled": true, + "confidenceThreshold": 35, + "action": "block" } + }, + "customRules": { + "rules": [ + { + "id": "secret_openai", + "name": "OpenAI API Key", + "type": "regex", + "constraint": "hard", + "pattern": "sk-[A-Za-z0-9]{20,}", + "action": "mask", + "gates": ["output", "tool_call"] + } + ] + }, + "gateConfig": { + "inputGate": true, + "toolCallGate": true, + "outputGate": true, + "contextGate": false, + "streamGate": false } } ``` +### Custom Path + +Override the guardrails config file path in `forge.yaml`: + +```yaml +guardrails_path: "config/my-guardrails.json" +``` + +### Default Secret Patterns + +The default `guardrails.json` includes regex rules for these secret types: + +| Rule ID | Pattern | +|---------|---------| +| `secret_anthropic` | `sk-ant-[A-Za-z0-9\-]{20,}` | +| `secret_openai` | `sk-[A-Za-z0-9]{20,}` | +| `secret_github_pat` | `ghp_[A-Za-z0-9]{36}` | +| `secret_github_oauth` | `gho_[A-Za-z0-9]{36}` | +| `secret_github_server` | `ghs_[A-Za-z0-9]{36}` | +| `secret_github_fine` | `github_pat_[A-Za-z0-9_]{22,}` | +| `secret_aws` | `AKIA[0-9A-Z]{16}` | +| `secret_slack_bot` | `xoxb-[0-9]{10,}-[A-Za-z0-9-]+` | +| `secret_slack_user` | 
`xoxp-[0-9]{10,}-[A-Za-z0-9-]+` | +| `secret_private_key` | `-----BEGIN (RSA\|EC\|OPENSSH\|PRIVATE) .*KEY-----` | +| `secret_telegram` | `[0-9]{8,10}:[A-Za-z0-9_-]{35,}` | + +### Gate Configuration + +Gates control which evaluation points are active: + +| Gate | Default | Description | +|------|---------|-------------| +| `inputGate` | `true` | Validates user messages before LLM processing | +| `toolCallGate` | `true` | Validates tool arguments before execution | +| `outputGate` | `true` | Validates agent responses before delivery | +| `contextGate` | `false` | Validates context window content | +| `streamGate` | `false` | Validates streaming chunks | + +## DB Mode (Platform Deployments) + +When `FORGE_GUARDRAILS_DB` is set to a MongoDB connection URI, the engine loads guardrails config from the `AgentConfig` collection and enables audit logging. + +```bash +export FORGE_GUARDRAILS_DB="mongodb://localhost:27017" +export FORGE_AGENT_ID="my-agent" +export FORGE_ORG_ID="my-org" +forge run +``` + +The library queries `AgentConfig` with `{agent_id, org_id}` to load the `StructuredGuardrails` config. If the DB is unreachable, it falls back to file mode. + +| Environment Variable | Description | +|---------------------|-------------| +| `FORGE_GUARDRAILS_DB` | MongoDB connection URI | +| `FORGE_AGENT_ID` | Agent identifier (falls back to `agent_id` in `forge.yaml`) | +| `FORGE_ORG_ID` | Organization identifier | + ## Runtime ```bash -# Default: guardrails enforced (all built-in guardrails active) +# Default: guardrails enforced (all evaluators active) forge run # Explicitly disable guardrail enforcement forge run --no-guardrails ``` -All four built-in guardrails (`content_filter`, `no_pii`, `jailbreak_protection`, `no_secrets`) are active by default, even without running `forge build`. Use `--no-guardrails` to opt out. +All configured guardrails are active by default, even without running `forge build`. Use `--no-guardrails` to opt out. 
## Tool Output Scanning -The guardrail engine scans tool output via an `AfterToolExec` hook, catching secrets and PII before they enter the LLM context or outbound messages. The hook passes the tool name to enable per-tool exemptions (see [Per-Tool PII Exemptions](#per-tool-pii-exemptions) below). - -| Guardrail | What it detects in tool output | -|-----------|-------------------------------| -| `no_secrets` | API keys, tokens, private keys (same patterns as outbound message scanning) | -| `no_pii` | Email addresses, phone numbers, SSNs | +The guardrail engine scans tool output via an `AfterToolExec` hook, catching secrets and PII before they enter the LLM context or outbound messages. The engine calls the library's `OutputGate` with tool metadata attached. **Behavior by mode:** | Mode | Behavior | |------|----------| -| `enforce` | Returns an error identifying the guardrail that triggered (e.g., `"tool output blocked by no_pii guardrail (PII detected in output)"`), blocking the result from entering the LLM context. | -| `warn` | Replaces matched patterns with `[REDACTED]`, logs a warning, and allows the redacted output through | - -The hook writes the redacted text back to `HookContext.ToolOutput`, which the agent loop reads after all hooks fire. This is backwards-compatible — existing hooks that don't modify `ToolOutput` leave it unchanged. - -### Per-Tool PII Exemptions +| `enforce` | Returns an error identifying the violation, blocking the result from entering the LLM context | +| `warn` | Replaces matched patterns with masked content, logs a warning, allows through | -Some tools legitimately return PII as part of their function (e.g., `github_get_user` returning public email addresses). The `allow_tools` config option lets specific tools bypass a guardrail entirely. 
- -```json -{ - "guardrails": [ - { - "type": "no_pii", - "config": { - "allow_tools": [ - "github_get_user", - "github_pr_author_profiles", - "github_stargazer_profiles", - "file_create", - "code_agent_write", - "code_agent_edit", - "cli_execute", - "web_search" - ] - } - } - ] -} -``` - -**Key behaviors:** - -| Behavior | Detail | -|----------|--------| -| Per-guardrail scope | `allow_tools` on `no_pii` does **not** bypass `no_secrets` — each guardrail has its own allowlist | -| Write tools included | `file_create`, `code_agent_write`, `code_agent_edit`, and `cli_execute` are included because they echo back content the LLM already has or return operational output that may contain incidental PII (e.g., git log author emails) | -| Web search included | `web_search` is included because search results routinely contain names, emails, and other PII that is public web content — blocking these results would make Q&A conversations unusable | -| Default config | The default policy scaffold pre-configures `allow_tools` for GitHub profile tools and write tools | -| Custom overrides | Override via `policy-scaffold.json` to add or remove tools from the allowlist | +The hook writes the redacted text back to `HookContext.ToolOutput`, which the agent loop reads after all hooks fire. ## Path Containment @@ -241,9 +297,11 @@ The `cli_execute` tool blocks arguments containing `file://` URLs (case-insensit Guardrail evaluations are logged as structured audit events: ```json -{"ts":"2026-02-28T10:00:00Z","event":"guardrail_check","correlation_id":"a1b2c3d4","fields":{"guardrail":"no_pii","direction":"outbound","result":"blocked"}} +{"ts":"2026-02-28T10:00:00Z","event":"guardrail_check","correlation_id":"a1b2c3d4","fields":{"guardrail":"pii","direction":"inbound","result":"masked"}} ``` +In DB mode, the guardrails library writes audit records to MongoDB automatically when `EnableAudit` is set. + See [Security Overview](overview.md) for the full security architecture. 
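For consumers of the audit log, the sample event above can be decoded with Go's `encoding/json`. This is a sketch only: the `auditEvent` struct and `parseAuditEvent` helper are hypothetical and inferred from the one sample line, so real events may carry additional or differently typed fields.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// auditEvent mirrors the JSON shape of the sample guardrail_check event.
// Fields is kept as map[string]any because only one sample is documented.
type auditEvent struct {
	TS            string         `json:"ts"`
	Event         string         `json:"event"`
	CorrelationID string         `json:"correlation_id"`
	Fields        map[string]any `json:"fields"`
}

// parseAuditEvent decodes a single JSON log line into an auditEvent.
func parseAuditEvent(line string) (auditEvent, error) {
	var ev auditEvent
	err := json.Unmarshal([]byte(line), &ev)
	return ev, err
}

// sampleLine is the example event from the documentation.
const sampleLine = `{"ts":"2026-02-28T10:00:00Z","event":"guardrail_check","correlation_id":"a1b2c3d4","fields":{"guardrail":"pii","direction":"inbound","result":"masked"}}`

func main() {
	ev, err := parseAuditEvent(sampleLine)
	if err != nil {
		fmt.Println("decode error:", err)
		return
	}
	fmt.Println(ev.Event, ev.Fields["guardrail"], ev.Fields["result"])
}
```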
--- diff --git a/forge-cli/build/dockerfile_stage.go b/forge-cli/build/dockerfile_stage.go index 08ea3f4..e61f344 100644 --- a/forge-cli/build/dockerfile_stage.go +++ b/forge-cli/build/dockerfile_stage.go @@ -140,7 +140,7 @@ func (s *DockerfileStage) copyProjectSources(bc *pipeline.BuildContext) error { outDir := bc.Opts.OutputDir // Individual files to copy - filesToCopy := []string{"forge.yaml"} + filesToCopy := []string{"forge.yaml", "guardrails.json"} // Include channel config files (e.g. slack-config.yaml, telegram-config.yaml) if bc.Config != nil { for _, ch := range bc.Config.Channels { diff --git a/forge-cli/build/skills_stage.go b/forge-cli/build/skills_stage.go index e349737..9ba117f 100644 --- a/forge-cli/build/skills_stage.go +++ b/forge-cli/build/skills_stage.go @@ -28,17 +28,17 @@ func (s *SkillsStage) Execute(ctx context.Context, bc *pipeline.BuildContext) er skillsPath = filepath.Join(bc.Opts.WorkDir, skillsPath) } - // Skip silently if not found - if _, err := os.Stat(skillsPath); os.IsNotExist(err) { - return nil - } - - entries, _, err := cliskills.ParseFileWithMetadata(skillsPath) - if err != nil { - return fmt.Errorf("parsing skills file: %w", err) + // Parse root skills file if it exists + var entries []contract.SkillEntry + if _, err := os.Stat(skillsPath); err == nil { + parsed, _, parseErr := cliskills.ParseFileWithMetadata(skillsPath) + if parseErr != nil { + return fmt.Errorf("parsing skills file: %w", parseErr) + } + entries = parsed } - // Scan skills/ subdirectory for additional SKILL.md files + // Always scan skills/ subdirectory (skills may exist without root SKILL.md) skillsSubDir := filepath.Join(bc.Opts.WorkDir, "skills") subEntries, subErr := scanSkillsSubDir(skillsSubDir) if subErr != nil { diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index d437e81..0f4a4c3 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -883,6 +883,7 @@ func getFileManifest(opts *initOptions) []fileToRender { {TemplatePath: 
"forge.yaml.tmpl", OutputPath: "forge.yaml"}, {TemplatePath: "env.example.tmpl", OutputPath: ".env.example"}, {TemplatePath: "gitignore.tmpl", OutputPath: ".gitignore"}, + {TemplatePath: "guardrails.json.tmpl", OutputPath: "guardrails.json"}, } switch opts.Framework { diff --git a/forge-cli/go.mod b/forge-cli/go.mod index aca359f..f434380 100644 --- a/forge-cli/go.mod +++ b/forge-cli/go.mod @@ -26,21 +26,31 @@ require ( github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.5.0 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/initializ/guardrails v0.12.0 // indirect + github.com/klauspost/compress v1.16.7 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-runewidth v0.0.19 // indirect + github.com/montanaflynn/stats v0.7.1 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.16.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/spf13/pflag v1.0.9 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect + go.mongodb.org/mongo-driver v1.17.3 // indirect golang.org/x/crypto v0.48.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect 
golang.org/x/text v0.34.0 // indirect golang.org/x/time v0.15.0 // indirect diff --git a/forge-cli/go.sum b/forge-cli/go.sum index 50774e9..d856e76 100644 --- a/forge-cli/go.sum +++ b/forge-cli/go.sum @@ -27,8 +27,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/initializ/guardrails v0.12.0 h1:YzScnl+YPihLA1gepjxKjnKH+2EbYmGC4dUtpVRAJ2c= +github.com/initializ/guardrails v0.12.0/go.mod h1:bDdHx73MF0+O09KqoXmUTTiFG4H7yEVQ0NR6juP1F3Q= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -37,6 +43,8 @@ github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2J github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= 
+github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= +github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= @@ -55,6 +63,12 @@ github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= @@ -63,21 +77,51 @@ github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17 github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e 
h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mongodb.org/mongo-driver v1.17.3 h1:TQyXhnsWfWtgAhMtOgtYHMTkZIfBTpMTsMnd9ZBeHxQ= +go.mongodb.org/mongo-driver v1.17.3/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync 
v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod 
h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/forge-cli/runtime/guardrails_engine.go b/forge-cli/runtime/guardrails_engine.go new file mode 100644 index 0000000..61c131f --- /dev/null +++ b/forge-cli/runtime/guardrails_engine.go @@ -0,0 +1,240 @@ +package runtime + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/initializ/guardrails" + "github.com/initializ/guardrails/models" + + "github.com/initializ/forge/forge-core/a2a" + coreruntime "github.com/initializ/forge/forge-core/runtime" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +// LibraryGuardrailEngine implements coreruntime.GuardrailChecker using the +// github.com/initializ/guardrails library. It supports two modes: +// - File mode: uses StructuredGuardrails loaded from guardrails.json +// - DB mode: loads config from MongoDB (set via FORGE_GUARDRAILS_DB env) +type LibraryGuardrailEngine struct { + manager *guardrails.GuardrailManager + structured *models.StructuredGuardrails + enforce bool + useDB bool + agentID string + orgID string + configVersion int64 + logger coreruntime.Logger +} + +// NewFileGuardrailEngine creates a guardrail engine backed by a local +// StructuredGuardrails config (loaded from guardrails.json). 
+func NewFileGuardrailEngine(sg *models.StructuredGuardrails, enforce bool, logger coreruntime.Logger) (*LibraryGuardrailEngine, error) { + mgr, err := guardrails.NewGuardrailManager(guardrails.Config{}) + if err != nil { + return nil, fmt.Errorf("creating guardrail manager: %w", err) + } + return &LibraryGuardrailEngine{ + manager: mgr, + structured: sg, + enforce: enforce, + logger: logger, + }, nil +} + +// NewDBGuardrailEngine creates a guardrail engine backed by MongoDB. +// Config is loaded from the AgentConfig collection; audit logging is enabled. +func NewDBGuardrailEngine(mongoURI, agentID, orgID string, enforce bool, logger coreruntime.Logger) (*LibraryGuardrailEngine, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI)) + if err != nil { + return nil, fmt.Errorf("connecting to guardrails DB: %w", err) + } + + // Verify connectivity + if err := client.Ping(ctx, nil); err != nil { + return nil, fmt.Errorf("pinging guardrails DB: %w", err) + } + + mgr, err := guardrails.NewGuardrailManager(guardrails.Config{ + MongoClient: client, + DatabaseName: "Initializ", + EnableAudit: true, + }) + if err != nil { + return nil, fmt.Errorf("creating guardrail manager with DB: %w", err) + } + + return &LibraryGuardrailEngine{ + manager: mgr, + enforce: enforce, + useDB: true, + agentID: agentID, + orgID: orgID, + logger: logger, + }, nil +} + +// structuredIfFileMode returns the StructuredGuardrails pointer only in file +// mode. In DB mode the library loads config from MongoDB automatically. +func (e *LibraryGuardrailEngine) structuredIfFileMode() *models.StructuredGuardrails { + if e.useDB { + return nil + } + return e.structured +} + +// CheckInbound validates an inbound (user) message via the library's InputGate. 
+func (e *LibraryGuardrailEngine) CheckInbound(msg *a2a.Message) error { + text := coreruntime.ExtractText(msg) + if text == "" { + return nil + } + + result, err := e.manager.InputGate(context.Background(), guardrails.InputRequest{ + Content: text, + EntityID: e.agentID, + OrgID: e.orgID, + EntityType: guardrails.EntityTypeAgent, + StructuredGuardrails: e.structuredIfFileMode(), + ConfigVersion: e.configVersion, + }) + if err != nil { + e.logger.Warn("guardrail input gate error", map[string]any{"error": err.Error()}) + // On library error, allow request through (fail-open) + return nil + } + + switch result.Decision { + case guardrails.DecisionMask: + if result.MaskedContent != "" { + for i := range msg.Parts { + if msg.Parts[i].Kind == a2a.PartKindText && msg.Parts[i].Text != "" { + msg.Parts[i].Text = result.MaskedContent + } + } + e.logger.Info("inbound guardrail redaction applied", map[string]any{ + "direction": "inbound", + }) + } + case guardrails.DecisionBlock: + desc := violationSummary(result) + if e.enforce { + return fmt.Errorf("input blocked: %s", desc) + } + e.logger.Warn("guardrail input violation (warn mode)", map[string]any{ + "direction": "inbound", + "detail": desc, + }) + } + return nil +} + +// CheckOutbound validates an outbound (agent) message via the library's OutputGate. +// Masked content is applied in-place; blocked content returns an error only in enforce mode. 
+func (e *LibraryGuardrailEngine) CheckOutbound(msg *a2a.Message) error { + for i, p := range msg.Parts { + if p.Kind != a2a.PartKindText || p.Text == "" { + continue + } + + result, err := e.manager.OutputGate(context.Background(), guardrails.OutputRequest{ + Content: p.Text, + EntityID: e.agentID, + OrgID: e.orgID, + EntityType: guardrails.EntityTypeAgent, + StructuredGuardrails: e.structuredIfFileMode(), + ConfigVersion: e.configVersion, + }) + if err != nil { + e.logger.Warn("guardrail output gate error", map[string]any{"error": err.Error()}) + continue + } + + switch result.Decision { + case guardrails.DecisionMask: + if result.MaskedContent != "" { + msg.Parts[i].Text = result.MaskedContent + e.logger.Warn("outbound guardrail redaction applied", map[string]any{ + "direction": "outbound", + }) + } + case guardrails.DecisionBlock: + desc := violationSummary(result) + if e.enforce { + return fmt.Errorf("output blocked: %s", desc) + } + e.logger.Warn("guardrail output violation (warn mode)", map[string]any{ + "direction": "outbound", + "detail": desc, + }) + } + } + return nil +} + +// CheckToolOutput scans tool output text via the library's OutputGate. +// Returns the (possibly masked) text and any blocking error. 
+func (e *LibraryGuardrailEngine) CheckToolOutput(toolName, text string) (string, error) { + if text == "" { + return text, nil + } + + result, err := e.manager.OutputGate(context.Background(), guardrails.OutputRequest{ + Content: text, + EntityID: e.agentID, + OrgID: e.orgID, + EntityType: guardrails.EntityTypeAgent, + StructuredGuardrails: e.structuredIfFileMode(), + ConfigVersion: e.configVersion, + Metadata: map[string]interface{}{"tool_name": toolName}, + }) + if err != nil { + e.logger.Warn("guardrail tool output gate error", map[string]any{ + "tool": toolName, + "error": err.Error(), + }) + return text, nil + } + + switch result.Decision { + case guardrails.DecisionMask: + if result.MaskedContent != "" { + e.logger.Warn("guardrail redaction", map[string]any{ + "direction": "tool_output", + "tool": toolName, + "detail": "content redacted", + }) + return result.MaskedContent, nil + } + case guardrails.DecisionBlock: + desc := violationSummary(result) + if e.enforce { + return "", fmt.Errorf("tool output blocked: %s", desc) + } + e.logger.Warn("guardrail tool output violation (warn mode)", map[string]any{ + "direction": "tool_output", + "tool": toolName, + "detail": desc, + }) + } + + return text, nil +} + +// violationSummary builds a human-readable summary from result violations. 
+func violationSummary(r *guardrails.Result) string { + if len(r.Violations) == 0 { + return string(r.Decision) + } + var parts []string + for _, v := range r.Violations { + parts = append(parts, v.Description) + } + return strings.Join(parts, "; ") +} diff --git a/forge-cli/runtime/guardrails_engine_test.go b/forge-cli/runtime/guardrails_engine_test.go new file mode 100644 index 0000000..607fd97 --- /dev/null +++ b/forge-cli/runtime/guardrails_engine_test.go @@ -0,0 +1,186 @@ +package runtime + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/initializ/guardrails/models" + + "github.com/initializ/forge/forge-core/a2a" + coreruntime "github.com/initializ/forge/forge-core/runtime" +) + +// grTestLogger is a no-op logger for tests. +type grTestLogger struct{} + +func (l *grTestLogger) Info(_ string, _ map[string]any) {} +func (l *grTestLogger) Debug(_ string, _ map[string]any) {} +func (l *grTestLogger) Warn(_ string, _ map[string]any) {} +func (l *grTestLogger) Error(_ string, _ map[string]any) {} + +// TestLibraryGuardrailEngine_ImplementsInterface verifies the engine +// satisfies the GuardrailChecker interface. +func TestLibraryGuardrailEngine_ImplementsInterface(t *testing.T) { + sg := DefaultStructuredGuardrails() + engine, err := NewFileGuardrailEngine(sg, false, &grTestLogger{}) + if err != nil { + t.Fatalf("NewFileGuardrailEngine() error: %v", err) + } + var _ coreruntime.GuardrailChecker = engine +} + +// TestFileGuardrailEngine_CheckInbound tests basic inbound checking. 
+func TestFileGuardrailEngine_CheckInbound(t *testing.T) { + sg := DefaultStructuredGuardrails() + engine, err := NewFileGuardrailEngine(sg, true, &grTestLogger{}) + if err != nil { + t.Fatalf("NewFileGuardrailEngine() error: %v", err) + } + + // Normal message should pass + msg := &a2a.Message{ + Role: "user", + Parts: []a2a.Part{{Kind: a2a.PartKindText, Text: "Hello, how are you?"}}, + } + if err := engine.CheckInbound(msg); err != nil { + t.Errorf("normal message should pass inbound check: %v", err) + } + + // Empty message should pass + emptyMsg := &a2a.Message{Role: "user"} + if err := engine.CheckInbound(emptyMsg); err != nil { + t.Errorf("empty message should pass inbound check: %v", err) + } +} + +// TestFileGuardrailEngine_CheckOutbound tests outbound content handling. +func TestFileGuardrailEngine_CheckOutbound(t *testing.T) { + sg := DefaultStructuredGuardrails() + engine, err := NewFileGuardrailEngine(sg, false, &grTestLogger{}) + if err != nil { + t.Fatalf("NewFileGuardrailEngine() error: %v", err) + } + + // Normal message should pass through unchanged + msg := &a2a.Message{ + Role: "agent", + Parts: []a2a.Part{{Kind: a2a.PartKindText, Text: "Here is the result."}}, + } + if err := engine.CheckOutbound(msg); err != nil { + t.Errorf("normal message should pass outbound check: %v", err) + } +} + +// TestFileGuardrailEngine_CheckToolOutput tests tool output scanning. 
+func TestFileGuardrailEngine_CheckToolOutput(t *testing.T) { + sg := DefaultStructuredGuardrails() + engine, err := NewFileGuardrailEngine(sg, false, &grTestLogger{}) + if err != nil { + t.Fatalf("NewFileGuardrailEngine() error: %v", err) + } + + // Normal text should pass through + out, err := engine.CheckToolOutput("some_tool", "some normal output") + if err != nil { + t.Errorf("normal output should pass: %v", err) + } + if out != "some normal output" { + t.Errorf("normal output should not be modified, got %q", out) + } + + // Empty text should pass through + out, err = engine.CheckToolOutput("some_tool", "") + if err != nil { + t.Errorf("empty output should pass: %v", err) + } + if out != "" { + t.Errorf("empty output should remain empty, got %q", out) + } +} + +// TestBuildGuardrailChecker_FileMode tests the builder with file-based config. +func TestBuildGuardrailChecker_FileMode(t *testing.T) { + logger := &grTestLogger{} + checker := BuildGuardrailChecker(nil, "/nonexistent", false, logger) + if checker == nil { + t.Fatal("BuildGuardrailChecker should return a non-nil checker") + } + + // Should still work (uses defaults) + msg := &a2a.Message{ + Role: "user", + Parts: []a2a.Part{{Kind: a2a.PartKindText, Text: "hello"}}, + } + if err := checker.CheckInbound(msg); err != nil { + t.Errorf("default checker should pass normal message: %v", err) + } +} + +// TestLoadGuardrailsJSON tests parsing a guardrails.json file. 
+func TestLoadGuardrailsJSON(t *testing.T) { + dir := t.TempDir() + sg := &models.StructuredGuardrails{ + PII: &models.PIIConfig{ + Enabled: true, + Action: "mask", + Categories: map[string]models.PIICategoryConfig{ + "email": {Enabled: true, Action: "mask"}, + }, + }, + GateConfig: &models.GateConfig{ + InputGate: true, + OutputGate: true, + }, + } + + data, err := json.MarshalIndent(sg, "", " ") + if err != nil { + t.Fatalf("marshaling test config: %v", err) + } + + if err := os.WriteFile(filepath.Join(dir, "guardrails.json"), data, 0o644); err != nil { + t.Fatalf("writing test guardrails.json: %v", err) + } + + loaded := LoadGuardrailsJSON(nil, dir) + if loaded == nil { + t.Fatal("LoadGuardrailsJSON returned nil for existing file") + } + if loaded.PII == nil || !loaded.PII.Enabled { + t.Error("expected PII to be enabled in loaded config") + } + if loaded.GateConfig == nil || !loaded.GateConfig.InputGate { + t.Error("expected InputGate to be enabled in loaded config") + } +} + +// TestLoadGuardrailsJSON_Missing tests loading when file doesn't exist. +func TestLoadGuardrailsJSON_Missing(t *testing.T) { + loaded := LoadGuardrailsJSON(nil, "/nonexistent") + if loaded != nil { + t.Error("LoadGuardrailsJSON should return nil for missing file") + } +} + +// TestDefaultStructuredGuardrails tests the default config has expected sections. 
+func TestDefaultStructuredGuardrails(t *testing.T) { + sg := DefaultStructuredGuardrails() + + if sg.PII == nil || !sg.PII.Enabled { + t.Error("default should have PII enabled") + } + if len(sg.PII.Categories) != 4 { + t.Errorf("default PII should have 4 categories, got %d", len(sg.PII.Categories)) + } + if sg.Security == nil || sg.Security.JailbreakDetection == nil { + t.Error("default should have jailbreak detection") + } + if sg.CustomRules == nil || len(sg.CustomRules.Rules) != 11 { + t.Errorf("default should have 11 secret rules, got %d", len(sg.CustomRules.Rules)) + } + if sg.GateConfig == nil || !sg.GateConfig.InputGate || !sg.GateConfig.OutputGate { + t.Error("default should have input and output gates enabled") + } +} diff --git a/forge-cli/runtime/guardrails_loader.go b/forge-cli/runtime/guardrails_loader.go index 2b86ac9..6959deb 100644 --- a/forge-cli/runtime/guardrails_loader.go +++ b/forge-cli/runtime/guardrails_loader.go @@ -6,11 +6,16 @@ import ( "os" "path/filepath" + "github.com/initializ/guardrails/models" + "github.com/initializ/forge/forge-core/agentspec" + coreruntime "github.com/initializ/forge/forge-core/runtime" + "github.com/initializ/forge/forge-core/types" ) // LoadPolicyScaffold reads policy-scaffold.json from the output directory. // Returns nil (no error) if the file does not exist. +// Kept for SkillGuardrails loading (separate concern from main guardrails). func LoadPolicyScaffold(workDir string) (*agentspec.PolicyScaffold, error) { path := filepath.Join(workDir, ".forge-output", "policy-scaffold.json") data, err := os.ReadFile(path) @@ -27,32 +32,123 @@ func LoadPolicyScaffold(workDir string) (*agentspec.PolicyScaffold, error) { return &ps, nil } -// DefaultPolicyScaffold returns a scaffold with all built-in guardrails enabled. -// Used when no policy-scaffold.json exists (e.g. running without forge build). +// DefaultPolicyScaffold returns a scaffold for SkillGuardrails only. 
+// The main guardrail checks are now handled by BuildGuardrailChecker. func DefaultPolicyScaffold() *agentspec.PolicyScaffold { - return &agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{ - { - Type: "content_filter", - Config: map[string]any{"enabled": true}, + return &agentspec.PolicyScaffold{} +} + +// BuildGuardrailChecker creates the guardrail engine based on configuration. +// Priority: FORGE_GUARDRAILS_DB env → guardrails.json file → defaults. +func BuildGuardrailChecker(cfg *types.ForgeConfig, workDir string, enforce bool, logger coreruntime.Logger) coreruntime.GuardrailChecker { + // DB mode: connect to MongoDB for config + audit + if mongoURI := os.Getenv("FORGE_GUARDRAILS_DB"); mongoURI != "" { + agentID := os.Getenv("FORGE_AGENT_ID") + if agentID == "" && cfg != nil { + agentID = cfg.AgentID + } + orgID := os.Getenv("FORGE_ORG_ID") + engine, err := NewDBGuardrailEngine(mongoURI, agentID, orgID, enforce, logger) + if err == nil { + logger.Info("guardrails: using MongoDB-backed config", map[string]any{ + "agent_id": agentID, + }) + return engine + } + logger.Warn("failed to connect guardrails DB, falling back to file", map[string]any{ + "error": err.Error(), + }) + } + + // File mode: load from guardrails.json + sg := LoadGuardrailsJSON(cfg, workDir) + if sg == nil { + sg = DefaultStructuredGuardrails() + } + + engine, err := NewFileGuardrailEngine(sg, enforce, logger) + if err != nil { + logger.Warn("failed to create file guardrail engine, using noop", map[string]any{ + "error": err.Error(), + }) + return &coreruntime.NoopGuardrailChecker{} + } + return engine +} + +// LoadGuardrailsJSON reads guardrails.json from the project directory. +// Returns nil if the file does not exist, cannot be read, or is not valid JSON. 
+func LoadGuardrailsJSON(cfg *types.ForgeConfig, workDir string) *models.StructuredGuardrails { + filename := "guardrails.json" + if cfg != nil && cfg.GuardrailsPath != "" { + filename = cfg.GuardrailsPath + } + + path := filepath.Join(workDir, filename) + data, err := os.ReadFile(path) + if err != nil { + return nil + } + + var sg models.StructuredGuardrails + if err := json.Unmarshal(data, &sg); err != nil { + return nil + } + return &sg +} + +// DefaultStructuredGuardrails returns default guardrails matching the +// previously built-in patterns (PII, jailbreak, secrets). +func DefaultStructuredGuardrails() *models.StructuredGuardrails { + return &models.StructuredGuardrails{ + PII: &models.PIIConfig{ + Enabled: true, + Action: "mask", + Categories: map[string]models.PIICategoryConfig{ + "email": {Enabled: true, Action: "mask"}, + "phoneNumber": {Enabled: true, Action: "mask"}, + "ssn": {Enabled: true, Action: "mask"}, + "creditCard": {Enabled: true, Action: "mask"}, + }, + }, + Security: &models.SecurityConfig{ + JailbreakDetection: &models.ThresholdConfig{ + Enabled: true, + ConfidenceThreshold: 25, + Action: "block", + }, + PromptInjection: &models.ThresholdConfig{ + Enabled: true, + ConfidenceThreshold: 30, + Action: "block", }, - { - Type: "no_pii", - Config: map[string]any{ - "allow_tools": []any{ - "github_get_user", - "github_pr_author_profiles", - "github_stargazer_profiles", - "file_create", - "code_agent_write", - "code_agent_edit", - "cli_execute", - "web_search", - }, - }, + CommandInjection: &models.ThresholdConfig{ + Enabled: true, + ConfidenceThreshold: 35, + Action: "block", }, - {Type: "jailbreak_protection"}, - {Type: "no_secrets"}, + }, + CustomRules: &models.CustomRulesConfig{ + Rules: []models.CustomRule{ + {ID: "secret_anthropic", Name: "Anthropic API Key", Type: "regex", Constraint: "hard", Pattern: `sk-ant-[A-Za-z0-9\-]{20,}`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_openai", Name: "OpenAI API Key", Type: 
"regex", Constraint: "hard", Pattern: `sk-[A-Za-z0-9]{20,}`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_github_pat", Name: "GitHub PAT", Type: "regex", Constraint: "hard", Pattern: `ghp_[A-Za-z0-9]{36}`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_github_oauth", Name: "GitHub OAuth", Type: "regex", Constraint: "hard", Pattern: `gho_[A-Za-z0-9]{36}`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_github_server", Name: "GitHub Server Token", Type: "regex", Constraint: "hard", Pattern: `ghs_[A-Za-z0-9]{36}`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_github_fine", Name: "GitHub Fine-grained PAT", Type: "regex", Constraint: "hard", Pattern: `github_pat_[A-Za-z0-9_]{22,}`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_aws", Name: "AWS Access Key", Type: "regex", Constraint: "hard", Pattern: `AKIA[0-9A-Z]{16}`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_slack_bot", Name: "Slack Bot Token", Type: "regex", Constraint: "hard", Pattern: `xoxb-[0-9]{10,}-[A-Za-z0-9-]+`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_slack_user", Name: "Slack User Token", Type: "regex", Constraint: "hard", Pattern: `xoxp-[0-9]{10,}-[A-Za-z0-9-]+`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_private_key", Name: "Private Key", Type: "regex", Constraint: "hard", Pattern: `-----BEGIN (RSA|EC|OPENSSH|PRIVATE) .*KEY-----`, Action: "mask", Gates: []string{"output", "tool_call"}}, + {ID: "secret_telegram", Name: "Telegram Bot Token", Type: "regex", Constraint: "hard", Pattern: `[0-9]{8,10}:[A-Za-z0-9_-]{35,}`, Action: "mask", Gates: []string{"output", "tool_call"}}, + }, + }, + GateConfig: &models.GateConfig{ + InputGate: true, + ToolCallGate: true, + OutputGate: true, + ContextGate: false, + StreamGate: false, }, } } diff --git a/forge-cli/runtime/runner.go 
b/forge-cli/runtime/runner.go index 710833f..7ec6a67 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -205,15 +205,17 @@ func (r *Runner) Run(ctx context.Context) error { return err } - // 2. Load policy scaffold (fall back to built-in defaults) + // 2. Build guardrail checker (DB mode → file mode → defaults) + guardrails := BuildGuardrailChecker(r.cfg.Config, r.cfg.WorkDir, r.cfg.EnforceGuardrails, r.logger) + + // Still load scaffold for SkillGuardrails (separate concern) scaffold, err := LoadPolicyScaffold(r.cfg.WorkDir) if err != nil { r.logger.Warn("failed to load policy scaffold", map[string]any{"error": err.Error()}) } - if scaffold == nil || len(scaffold.Guardrails) == 0 { + if scaffold == nil { scaffold = DefaultPolicyScaffold() } - guardrails := coreruntime.NewGuardrailEngine(scaffold, r.cfg.EnforceGuardrails, r.logger) // 3. Build agent card card, err := BuildAgentCard(r.cfg.WorkDir, r.cfg.Config, r.cfg.Port) @@ -693,7 +695,7 @@ func (r *Runner) Run(ctx context.Context) error { return srv.Start(ctx) } -func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.AgentExecutor, guardrails *coreruntime.GuardrailEngine, egressClient *http.Client, auditLogger *coreruntime.AuditLogger) { +func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.AgentExecutor, guardrails coreruntime.GuardrailChecker, egressClient *http.Client, auditLogger *coreruntime.AuditLogger) { store := srv.TaskStore() // tasks/send — synchronous request @@ -1004,7 +1006,7 @@ func (r *Runner) executeTask( params a2a.SendTaskParams, store *a2a.TaskStore, executor coreruntime.AgentExecutor, - guardrails *coreruntime.GuardrailEngine, + guardrails coreruntime.GuardrailChecker, egressClient *http.Client, auditLogger *coreruntime.AuditLogger, ) (*a2a.Task, error) { @@ -1124,7 +1126,7 @@ type restTaskRequest struct { } // registerRESTHandlers registers REST-style HTTP endpoints on the server. 
-func (r *Runner) registerRESTHandlers(srv *server.Server, executor coreruntime.AgentExecutor, guardrails *coreruntime.GuardrailEngine, egressClient *http.Client, auditLogger *coreruntime.AuditLogger) { +func (r *Runner) registerRESTHandlers(srv *server.Server, executor coreruntime.AgentExecutor, guardrails coreruntime.GuardrailChecker, egressClient *http.Client, auditLogger *coreruntime.AuditLogger) { store := srv.TaskStore() // POST /tasks/send — synchronous REST endpoint @@ -1506,7 +1508,7 @@ func (r *Runner) registerProgressHooks(hooks *coreruntime.HookRegistry) { // registerGuardrailHooks registers an AfterToolExec hook that scans tool output // for secrets and PII, redacting or blocking based on guardrail mode. -func (r *Runner) registerGuardrailHooks(hooks *coreruntime.HookRegistry, guardrails *coreruntime.GuardrailEngine) { +func (r *Runner) registerGuardrailHooks(hooks *coreruntime.HookRegistry, guardrails coreruntime.GuardrailChecker) { hooks.Register(coreruntime.AfterToolExec, func(_ context.Context, hctx *coreruntime.HookContext) error { if hctx.ToolOutput == "" { return nil diff --git a/forge-cli/templates/init/gitignore.tmpl b/forge-cli/templates/init/gitignore.tmpl index f9b63dd..e21721b 100644 --- a/forge-cli/templates/init/gitignore.tmpl +++ b/forge-cli/templates/init/gitignore.tmpl @@ -1,5 +1,6 @@ .env .forge/ +guardrails.json __pycache__/ *.pyc node_modules/ diff --git a/forge-cli/templates/init/guardrails.json.tmpl b/forge-cli/templates/init/guardrails.json.tmpl new file mode 100644 index 0000000..8222e21 --- /dev/null +++ b/forge-cli/templates/init/guardrails.json.tmpl @@ -0,0 +1,39 @@ +{ + "pii": { + "enabled": true, + "action": "mask", + "categories": { + "email": {"enabled": true, "action": "mask"}, + "phoneNumber": {"enabled": true, "action": "mask"}, + "ssn": {"enabled": true, "action": "mask"}, + "creditCard": {"enabled": true, "action": "mask"} + } + }, + "security": { + "jailbreakDetection": {"enabled": true, "confidenceThreshold": 25, 
"action": "block"}, + "promptInjection": {"enabled": true, "confidenceThreshold": 30, "action": "block"}, + "commandInjection": {"enabled": true, "confidenceThreshold": 35, "action": "block"} + }, + "customRules": { + "rules": [ + {"id": "secret_anthropic", "name": "Anthropic API Key", "type": "regex", "constraint": "hard", "pattern": "sk-ant-[A-Za-z0-9\\-]{20,}", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_openai", "name": "OpenAI API Key", "type": "regex", "constraint": "hard", "pattern": "sk-[A-Za-z0-9]{20,}", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_github_pat", "name": "GitHub PAT", "type": "regex", "constraint": "hard", "pattern": "ghp_[A-Za-z0-9]{36}", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_github_oauth", "name": "GitHub OAuth", "type": "regex", "constraint": "hard", "pattern": "gho_[A-Za-z0-9]{36}", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_github_server", "name": "GitHub Server Token", "type": "regex", "constraint": "hard", "pattern": "ghs_[A-Za-z0-9]{36}", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_github_fine", "name": "GitHub Fine-grained PAT", "type": "regex", "constraint": "hard", "pattern": "github_pat_[A-Za-z0-9_]{22,}", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_aws", "name": "AWS Access Key", "type": "regex", "constraint": "hard", "pattern": "AKIA[0-9A-Z]{16}", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_slack_bot", "name": "Slack Bot Token", "type": "regex", "constraint": "hard", "pattern": "xoxb-[0-9]{10,}-[A-Za-z0-9-]+", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_slack_user", "name": "Slack User Token", "type": "regex", "constraint": "hard", "pattern": "xoxp-[0-9]{10,}-[A-Za-z0-9-]+", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_private_key", "name": "Private Key", "type": "regex", "constraint": "hard", 
"pattern": "-----BEGIN (RSA|EC|OPENSSH|PRIVATE) .*KEY-----", "action": "mask", "gates": ["output", "tool_call"]}, + {"id": "secret_telegram", "name": "Telegram Bot Token", "type": "regex", "constraint": "hard", "pattern": "[0-9]{8,10}:[A-Za-z0-9_-]{35,}", "action": "mask", "gates": ["output", "tool_call"]} + ] + }, + "gateConfig": { + "inputGate": true, + "toolCallGate": true, + "outputGate": true, + "contextGate": false, + "streamGate": false + } +} diff --git a/forge-core/forgecore.go b/forge-core/forgecore.go index 38a4856..317a826 100644 --- a/forge-core/forgecore.go +++ b/forge-core/forgecore.go @@ -120,7 +120,7 @@ type RuntimeConfig struct { Hooks *runtime.HookRegistry SystemPrompt string MaxIterations int - Guardrails *runtime.GuardrailEngine // optional + Guardrails runtime.GuardrailChecker // optional Logger runtime.Logger // optional } diff --git a/forge-core/runtime/guardrails.go b/forge-core/runtime/guardrails.go index c374d08..3a45830 100644 --- a/forge-core/runtime/guardrails.go +++ b/forge-core/runtime/guardrails.go @@ -1,139 +1,38 @@ package runtime import ( - "fmt" - "regexp" "strings" - "unicode" "github.com/initializ/forge/forge-core/a2a" - "github.com/initializ/forge/forge-core/agentspec" ) -// GuardrailEngine checks inbound and outbound messages against policy rules. -type GuardrailEngine struct { - scaffold *agentspec.PolicyScaffold - enforce bool - logger Logger -} +// GuardrailChecker validates messages and tool output against guardrail policies. +// Implementations may use file-based config, database-backed config, or no-op passthrough. +type GuardrailChecker interface { + // CheckInbound validates an inbound (user) message against guardrails. + CheckInbound(msg *a2a.Message) error -// NewGuardrailEngine creates a GuardrailEngine. If scaffold is nil, a default -// is used. When enforce is true, violations return errors; otherwise they are -// logged as warnings. 
-func NewGuardrailEngine(scaffold *agentspec.PolicyScaffold, enforce bool, logger Logger) *GuardrailEngine { - if scaffold == nil { - scaffold = &agentspec.PolicyScaffold{} - } - return &GuardrailEngine{scaffold: scaffold, enforce: enforce, logger: logger} -} + // CheckOutbound validates an outbound (agent) message against guardrails. + // Implementations should prefer redacting sensitive content over blocking. + CheckOutbound(msg *a2a.Message) error -// CheckInbound validates an inbound (user) message against guardrails. -func (g *GuardrailEngine) CheckInbound(msg *a2a.Message) error { - return g.check(msg, "inbound") + // CheckToolOutput scans tool output text against configured guardrails. + // Returns the (possibly redacted) text and any blocking error. + CheckToolOutput(toolName, text string) (string, error) } -// CheckOutbound validates an outbound (agent) message against guardrails. -// Unlike CheckInbound, outbound violations are always handled by redacting -// the offending content rather than blocking the entire response. Blocking -// throws away a potentially useful agent response (e.g., code analysis) over -// a false positive from broad PII/secret patterns matching source code. -func (g *GuardrailEngine) CheckOutbound(msg *a2a.Message) error { - for i, p := range msg.Parts { - if p.Kind != a2a.PartKindText || p.Text == "" { - continue - } - text := p.Text - redacted := false - - for _, gr := range g.scaffold.Guardrails { - switch gr.Type { - case "no_secrets": - for _, re := range secretPatterns { - if re.MatchString(text) { - text = re.ReplaceAllString(text, "[REDACTED]") - redacted = true - } - } - case "no_pii": - for _, p := range piiPatterns { - matches := p.re.FindAllString(text, -1) - for _, m := range matches { - if p.validate != nil && !p.validate(m) { - continue - } - text = strings.ReplaceAll(text, m, "[REDACTED]") - redacted = true - } - } - case "content_filter": - // Content filter: redact blocked words inline. 
- if gr.Config != nil { - if words, ok := gr.Config["blocked_words"]; ok { - if list, ok := words.([]any); ok { - lower := strings.ToLower(text) - for _, w := range list { - if s, ok := w.(string); ok && strings.Contains(lower, strings.ToLower(s)) { - text = strings.ReplaceAll(text, s, "[BLOCKED]") - redacted = true - } - } - } - } - } - } - } +// NoopGuardrailChecker is a passthrough implementation that performs no checks. +// Used as a fallback when no guardrail configuration is available. +type NoopGuardrailChecker struct{} - if redacted { - msg.Parts[i].Text = text - g.logger.Warn("outbound guardrail redaction applied", map[string]any{ - "direction": "outbound", - }) - } - } - return nil -} - -func (g *GuardrailEngine) check(msg *a2a.Message, direction string) error { - text := extractText(msg) - if text == "" { - return nil - } - - for _, gr := range g.scaffold.Guardrails { - var err error - switch gr.Type { - case "content_filter": - err = g.checkContentFilter(text, gr) - case "no_pii": - if direction == "outbound" { - err = g.checkNoPII(text) - } - case "jailbreak_protection": - if direction == "inbound" { - err = g.checkJailbreak(text) - } - case "no_secrets": - if direction == "outbound" { - err = g.checkNoSecrets(text) - } - default: - continue - } - if err != nil { - if g.enforce { - return fmt.Errorf("guardrail %s (%s): %w", gr.Type, direction, err) - } - g.logger.Warn("guardrail violation", map[string]any{ - "guardrail": gr.Type, - "direction": direction, - "detail": err.Error(), - }) - } - } - return nil +func (n *NoopGuardrailChecker) CheckInbound(_ *a2a.Message) error { return nil } +func (n *NoopGuardrailChecker) CheckOutbound(_ *a2a.Message) error { return nil } +func (n *NoopGuardrailChecker) CheckToolOutput(_ string, text string) (string, error) { + return text, nil } -func extractText(msg *a2a.Message) string { +// ExtractText extracts all text parts from a message into a single string. 
+func ExtractText(msg *a2a.Message) string { var parts []string for _, p := range msg.Parts { if p.Kind == a2a.PartKindText && p.Text != "" { @@ -142,293 +41,3 @@ func extractText(msg *a2a.Message) string { } return strings.Join(parts, " ") } - -func (g *GuardrailEngine) checkContentFilter(text string, gr agentspec.Guardrail) error { - // Use blocked words from config, or defaults - blocked := []string{"BLOCKED_CONTENT"} - if gr.Config != nil { - if words, ok := gr.Config["blocked_words"]; ok { - if list, ok := words.([]any); ok { - blocked = blocked[:0] - for _, w := range list { - if s, ok := w.(string); ok { - blocked = append(blocked, s) - } - } - } - } - } - lower := strings.ToLower(text) - for _, word := range blocked { - if strings.Contains(lower, strings.ToLower(word)) { - return fmt.Errorf("content filter: blocked word %q detected", word) - } - } - return nil -} - -// piiCheckerPattern pairs a regex with an optional validator function. -// When a validator is present, regex matches are only considered true positives -// if the validator confirms the matched text (e.g., Luhn check for credit cards, -// structure validation for SSNs). This follows the pattern from the reference -// guardrails library to reduce false positives. -type piiCheckerPattern struct { - re *regexp.Regexp - validate func(string) bool // nil means regex match alone is sufficient -} - -// Credit card regex: Visa, Mastercard, Amex, Discover with optional separators. 
-var ccRegex = `\b(?:` + - `4[0-9]{3}[\s-]?[0-9]{4}[\s-]?[0-9]{4}[\s-]?[0-9]{1,4}|` + // Visa - `(?:5[1-5][0-9]{2}|222[1-9]|22[3-9][0-9]|2[3-6][0-9]{2}|27[01][0-9]|2720)[\s-]?[0-9]{4}[\s-]?[0-9]{4}[\s-]?[0-9]{4}|` + // Mastercard - `3[47][0-9]{2}[\s-]?[0-9]{6}[\s-]?[0-9]{5}|` + // Amex - `(?:6011|65[0-9]{2}|64[4-9][0-9])[\s-]?[0-9]{4}[\s-]?[0-9]{4}[\s-]?[0-9]{4}` + // Discover - `)\b` - -var piiPatterns = []piiCheckerPattern{ - {re: regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`)}, // email - {re: regexp.MustCompile(`\b(?:\+?1[-.\s])?\(?[2-9]\d{2}\)?[-.\s]\d{3}[-.\s]\d{4}\b`)}, // phone (area code 2-9, separators required) - {re: regexp.MustCompile(`\b\d{3}[-.\s]?\d{2}[-.\s]?\d{4}\b`), validate: validateSSN}, // SSN with structural validation - {re: regexp.MustCompile(ccRegex), validate: validateLuhn}, // credit card with Luhn check -} - -func (g *GuardrailEngine) checkNoPII(text string) error { - for _, p := range piiPatterns { - matches := p.re.FindAllString(text, -1) - for _, m := range matches { - if p.validate != nil && !p.validate(m) { - continue - } - return fmt.Errorf("PII pattern detected: %s", p.re.String()) - } - } - return nil -} - -var jailbreakPhrases = []string{ - "ignore previous instructions", - "ignore all instructions", - "disregard your instructions", - "forget your rules", - "you are now", - "act as if you have no restrictions", -} - -func (g *GuardrailEngine) checkJailbreak(text string) error { - lower := strings.ToLower(text) - for _, phrase := range jailbreakPhrases { - if strings.Contains(lower, phrase) { - return fmt.Errorf("jailbreak pattern detected: %q", phrase) - } - } - return nil -} - -var secretPatterns = []*regexp.Regexp{ - regexp.MustCompile(`sk-ant-[A-Za-z0-9\-]{20,}`), // Anthropic API keys - regexp.MustCompile(`sk-[A-Za-z0-9]{20,}`), // OpenAI API keys - regexp.MustCompile(`ghp_[A-Za-z0-9]{36}`), // GitHub PATs - regexp.MustCompile(`gho_[A-Za-z0-9]{36}`), // GitHub OAuth tokens - 
regexp.MustCompile(`ghs_[A-Za-z0-9]{36}`), // GitHub server tokens - regexp.MustCompile(`github_pat_[A-Za-z0-9_]{22,}`), // GitHub fine-grained PATs - regexp.MustCompile(`AKIA[0-9A-Z]{16}`), // AWS access key IDs - regexp.MustCompile(`xoxb-[0-9]{10,}-[A-Za-z0-9-]+`), // Slack bot tokens - regexp.MustCompile(`xoxp-[0-9]{10,}-[A-Za-z0-9-]+`), // Slack user tokens - regexp.MustCompile(`-----BEGIN (RSA|EC|OPENSSH|PRIVATE) .*KEY-----`), // Private keys - regexp.MustCompile(`[0-9]{8,10}:[A-Za-z0-9_-]{35,}`), // Telegram bot tokens -} - -func (g *GuardrailEngine) checkNoSecrets(text string) error { - for _, re := range secretPatterns { - if re.MatchString(text) { - return fmt.Errorf("potential secret or credential detected in output") - } - } - return nil -} - -// CheckToolOutput scans tool output text against configured guardrails -// (no_secrets and no_pii). Matches are always redacted rather than blocked, -// because tool outputs are internal (sent to the LLM, not the user) and -// blocking would kill the entire agent session. Search tools routinely find -// code containing API key patterns in test files, config examples, etc. -// -// The toolName parameter enables per-tool PII exemptions: if a guardrail's -// config contains "allow_tools" (a list of tool names), tools in that list -// skip the corresponding check. This lets tools like github_get_user return -// public profile data (emails, bios) without triggering PII blocks. -func (g *GuardrailEngine) CheckToolOutput(toolName, text string) (string, error) { - if text == "" { - return text, nil - } - - for _, gr := range g.scaffold.Guardrails { - // Check if this tool is in the guardrail's allow_tools list. 
- if g.toolAllowed(toolName, gr) { - continue - } - - switch gr.Type { - case "no_secrets": - for _, re := range secretPatterns { - if !re.MatchString(text) { - continue - } - if g.enforce { - return "", fmt.Errorf("tool output blocked by no_secrets guardrail (secret/credential detected in output)") - } - text = re.ReplaceAllString(text, "[REDACTED]") - g.logger.Warn("guardrail redaction", map[string]any{ - "guardrail": gr.Type, - "direction": "tool_output", - "detail": fmt.Sprintf("pattern %s matched, content redacted", re.String()), - }) - } - case "no_pii": - for _, p := range piiPatterns { - if !p.re.MatchString(text) { - continue - } - // Check if any match passes validation - hasValidMatch := false - if p.validate == nil { - hasValidMatch = true - } else { - for _, m := range p.re.FindAllString(text, -1) { - if p.validate(m) { - hasValidMatch = true - break - } - } - } - if !hasValidMatch { - continue - } - if g.enforce { - return "", fmt.Errorf("tool output blocked by no_pii guardrail (PII detected in output)") - } - // Warn mode: redact only validated matches - if p.validate != nil { - v := p.validate // capture for closure - text = p.re.ReplaceAllStringFunc(text, func(s string) string { - if v(s) { - return "[REDACTED]" - } - return s - }) - } else { - text = p.re.ReplaceAllString(text, "[REDACTED]") - } - g.logger.Warn("guardrail redaction", map[string]any{ - "guardrail": gr.Type, - "direction": "tool_output", - "detail": fmt.Sprintf("pattern %s matched, content redacted", p.re.String()), - }) - } - default: - continue - } - } - return text, nil -} - -// toolAllowed checks whether toolName is in the guardrail's "allow_tools" config list. 
-func (g *GuardrailEngine) toolAllowed(toolName string, gr agentspec.Guardrail) bool { - if toolName == "" || gr.Config == nil { - return false - } - allowRaw, ok := gr.Config["allow_tools"] - if !ok { - return false - } - list, ok := allowRaw.([]any) - if !ok { - return false - } - for _, v := range list { - if s, ok := v.(string); ok && s == toolName { - return true - } - } - return false -} - -// --- PII Validators --- -// Ported from the reference guardrails library to reduce false positives. - -// validateSSN validates a US Social Security Number structure. -// Rejects area=000/666/900+, group=00, serial=0000, all-same digits, and known test SSNs. -func validateSSN(s string) bool { - cleaned := strings.NewReplacer("-", "", " ", "", ".", "").Replace(s) - if len(cleaned) != 9 { - return false - } - for _, r := range cleaned { - if !unicode.IsDigit(r) { - return false - } - } - - area := cleaned[0:3] - group := cleaned[3:5] - serial := cleaned[5:9] - - if area == "000" || area == "666" || area[0] == '9' { - return false - } - if group == "00" { - return false - } - if serial == "0000" { - return false - } - - // All same digits - allSame := true - for i := 1; i < len(cleaned); i++ { - if cleaned[i] != cleaned[0] { - allSame = false - break - } - } - if allSame { - return false - } - - // Known test/advertising SSNs - testSSNs := map[string]bool{ - "078051120": true, - "219099999": true, - "123456789": true, - } - return !testSSNs[cleaned] -} - -// validateLuhn performs Luhn checksum validation on a credit card number. -// Strips separators (spaces, dashes) before validating. 
-func validateLuhn(s string) bool { - cleaned := strings.NewReplacer(" ", "", "-", "").Replace(s) - if len(cleaned) < 13 || len(cleaned) > 19 { - return false - } - for _, r := range cleaned { - if !unicode.IsDigit(r) { - return false - } - } - - sum := 0 - double := false - for i := len(cleaned) - 1; i >= 0; i-- { - digit := int(cleaned[i] - '0') - if double { - digit *= 2 - if digit > 9 { - digit -= 9 - } - } - sum += digit - double = !double - } - return sum%10 == 0 -} diff --git a/forge-core/runtime/guardrails_test.go b/forge-core/runtime/guardrails_test.go index 3f1ef39..8a4a259 100644 --- a/forge-core/runtime/guardrails_test.go +++ b/forge-core/runtime/guardrails_test.go @@ -1,387 +1,99 @@ package runtime import ( - "strings" "testing" "github.com/initializ/forge/forge-core/a2a" - "github.com/initializ/forge/forge-core/agentspec" ) -// --- Validator unit tests --- - -func TestValidateSSN(t *testing.T) { - tests := []struct { - name string - input string - want bool - }{ - {"known test SSN with dashes", "123-45-6789", false}, // 123456789 is a known test SSN - {"known test SSN no separators", "123456789", false}, - {"valid SSN dots", "456.78.9012", true}, - {"area 000", "000-12-3456", false}, - {"area 666", "666-12-3456", false}, - {"area 900+", "900-12-3456", false}, - {"area 999", "999-12-3456", false}, - {"group 00", "123-00-4567", false}, - {"serial 0000", "123-45-0000", false}, - {"all same digits", "111111111", false}, - {"all same digits 555", "555555555", false}, - {"known test SSN 078051120", "078051120", false}, - {"known test SSN 219099999", "219099999", false}, - {"too short", "12345678", false}, - {"too long", "1234567890", false}, - {"non-digit", "12a-45-6789", false}, - {"valid 456-78-9012", "456-78-9012", true}, - {"valid 321-54-9876", "321-54-9876", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := validateSSN(tt.input); got != tt.want { - t.Errorf("validateSSN(%q) = %v, want %v", tt.input, got, tt.want) 
- } - }) - } -} - -func TestValidateLuhn(t *testing.T) { - tests := []struct { - name string - input string - want bool - }{ - // Known valid test card numbers - {"Visa test", "4111111111111111", true}, - {"Visa with spaces", "4111 1111 1111 1111", true}, - {"Visa with dashes", "4111-1111-1111-1111", true}, - {"Mastercard test", "5500000000000004", true}, - {"Amex test", "378282246310005", true}, - {"Discover test", "6011111111111117", true}, - // Invalid - {"bad checksum", "4111111111111112", false}, - {"too short", "411111111111", false}, - {"too long", "41111111111111111111", false}, - {"non-digit", "4111abcd11111111", false}, - // Random numbers that happen to be 16 digits should usually fail - {"random 16 digits", "1234567890123456", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := validateLuhn(tt.input); got != tt.want { - t.Errorf("validateLuhn(%q) = %v, want %v", tt.input, got, tt.want) - } - }) - } -} - -// --- PII pattern matching tests --- - -func TestCheckNoPII_Phone(t *testing.T) { - noopLogger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, true, noopLogger) - - tests := []struct { - name string - text string - wantErr bool - }{ - {"US phone with dashes", "call 212-555-1234", true}, - {"US phone with dots", "call 212.555.1234", true}, - {"US phone with +1", "call +1-212-555-1234", true}, - {"US phone with parens", "call (212) 555-1234", true}, - // Area code must start with 2-9 - {"area code starts with 0", "call 012-555-1234", false}, - {"area code starts with 1", "call 112-555-1234", false}, - // K8s byte counts should NOT match - {"k8s memory bytes 4Gi", "memory: 4294967296 bytes", false}, - {"k8s memory bytes 1Gi", "memory: 1073741824 bytes", false}, - {"k8s memory 10 digits", "allocatable: 3221225472 bytes", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := g.checkNoPII(tt.text) - if 
(err != nil) != tt.wantErr { - t.Errorf("checkNoPII(%q) error = %v, wantErr %v", tt.text, err, tt.wantErr) - } - }) - } -} - -func TestCheckNoPII_SSN(t *testing.T) { - noopLogger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, true, noopLogger) - - tests := []struct { - name string - text string - wantErr bool - }{ - {"valid SSN", "SSN: 456-78-9012", true}, - {"valid SSN no sep", "SSN: 456789012", true}, - {"invalid area 000", "SSN: 000-12-3456", false}, - {"invalid area 666", "SSN: 666-12-3456", false}, - {"invalid area 900+", "SSN: 900-12-3456", false}, - {"invalid group 00", "SSN: 123-00-4567", false}, - {"invalid serial 0000", "SSN: 123-45-0000", false}, - {"all same digits", "SSN: 111-11-1111", false}, - {"known test SSN", "SSN: 123-45-6789", false}, // 123456789 is a known test SSN - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := g.checkNoPII(tt.text) - if (err != nil) != tt.wantErr { - t.Errorf("checkNoPII(%q) error = %v, wantErr %v", tt.text, err, tt.wantErr) - } - }) - } +// testLogger is a no-op logger for tests. Shared across test files in this package. 
+type testLogger struct { + warnings []string } -func TestCheckNoPII_CreditCard(t *testing.T) { - noopLogger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, true, noopLogger) - - tests := []struct { - name string - text string - wantErr bool - }{ - {"Visa", "card: 4111111111111111", true}, - {"Visa with spaces", "card: 4111 1111 1111 1111", true}, - {"Mastercard", "card: 5500000000000004", true}, - {"Amex", "card: 378282246310005", true}, - {"bad Luhn", "card: 4111111111111112", false}, - {"random 16 digits", "card: 1234567890123456", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := g.checkNoPII(tt.text) - if (err != nil) != tt.wantErr { - t.Errorf("checkNoPII(%q) error = %v, wantErr %v", tt.text, err, tt.wantErr) - } - }) - } +func (l *testLogger) Info(_ string, _ map[string]any) {} +func (l *testLogger) Debug(_ string, _ map[string]any) {} +func (l *testLogger) Warn(msg string, _ map[string]any) { + l.warnings = append(l.warnings, msg) } +func (l *testLogger) Error(_ string, _ map[string]any) {} -func TestCheckNoPII_Email(t *testing.T) { - noopLogger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, true, noopLogger) +// TestNoopGuardrailChecker_ImplementsInterface verifies NoopGuardrailChecker +// satisfies the GuardrailChecker interface. 
+func TestNoopGuardrailChecker_ImplementsInterface(t *testing.T) { + var checker GuardrailChecker = &NoopGuardrailChecker{} - tests := []struct { - name string - text string - wantErr bool - }{ - {"simple email", "contact: user@example.com", true}, - {"email with plus", "contact: user+tag@example.com", true}, - {"not an email", "contact: user at example dot com", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := g.checkNoPII(tt.text) - if (err != nil) != tt.wantErr { - t.Errorf("checkNoPII(%q) error = %v, wantErr %v", tt.text, err, tt.wantErr) - } - }) + msg := &a2a.Message{ + Role: "user", + Parts: []a2a.Part{ + {Kind: a2a.PartKindText, Text: "hello world"}, + }, } -} -// --- CheckToolOutput tests --- - -func TestCheckToolOutput_RedactsWithValidation(t *testing.T) { - logger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, false, logger) // warn mode - - // Valid SSN should be redacted - out, err := g.CheckToolOutput("some_tool", "SSN is 456-78-9012") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if out == "SSN is 456-78-9012" { - t.Error("expected valid SSN to be redacted") + if err := checker.CheckInbound(msg); err != nil { + t.Errorf("NoopGuardrailChecker.CheckInbound() unexpected error: %v", err) } - // Invalid SSN (area 000) should NOT be redacted - out, err = g.CheckToolOutput("some_tool", "code 000-12-3456 here") - if err != nil { - t.Fatalf("unexpected error: %v", err) + if err := checker.CheckOutbound(msg); err != nil { + t.Errorf("NoopGuardrailChecker.CheckOutbound() unexpected error: %v", err) } - if out != "code 000-12-3456 here" { - t.Errorf("expected invalid SSN to pass through, got %q", out) - } -} -func TestCheckToolOutput_K8sBytesNotBlocked(t *testing.T) { - logger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, true, logger) // 
enforce mode - - // K8s memory byte counts should not trigger PII detection - k8sOutput := `{"memory": "4294967296", "cpu": "2000m", "pods": "110", "allocatable_memory": "3221225472"}` - out, err := g.CheckToolOutput("some_tool", k8sOutput) + out, err := checker.CheckToolOutput("some_tool", "some text") if err != nil { - t.Fatalf("k8s output blocked as PII: %v", err) + t.Errorf("NoopGuardrailChecker.CheckToolOutput() unexpected error: %v", err) } - if out != k8sOutput { - t.Errorf("k8s output was modified: %q", out) + if out != "some text" { + t.Errorf("NoopGuardrailChecker.CheckToolOutput() = %q, want %q", out, "some text") } } -func TestCheckToolOutput_EnforceBlocksValidPII(t *testing.T) { - logger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, true, logger) // enforce mode - - _, err := g.CheckToolOutput("some_tool", "SSN: 456-78-9012") - if err == nil { - t.Error("expected enforce mode to block valid SSN") - } - if !strings.Contains(err.Error(), "no_pii") { - t.Errorf("expected error to mention no_pii guardrail, got: %v", err) - } -} - -func TestCheckToolOutput_AllowToolsBypassesPII(t *testing.T) { - logger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{ - { - Type: "no_pii", - Config: map[string]any{ - "allow_tools": []any{"github_get_user", "github_pr_author_profiles"}, - }, +// TestExtractText verifies the text extraction helper. 
+func TestExtractText(t *testing.T) { + tests := []struct { + name string + msg *a2a.Message + want string + }{ + { + name: "single text part", + msg: &a2a.Message{ + Parts: []a2a.Part{{Kind: a2a.PartKindText, Text: "hello"}}, }, + want: "hello", }, - }, true, logger) // enforce mode - - // Allowed tool should pass through with PII - out, err := g.CheckToolOutput("github_get_user", `{"email": "user@example.com"}`) - if err != nil { - t.Fatalf("allowed tool should not be blocked: %v", err) - } - if !strings.Contains(out, "user@example.com") { - t.Error("expected email to pass through for allowed tool") - } - - // Non-allowed tool should still be blocked - _, err = g.CheckToolOutput("some_other_tool", `{"email": "user@example.com"}`) - if err == nil { - t.Error("expected non-allowed tool to be blocked for PII") - } -} - -func TestCheckToolOutput_AllowToolsOnlyAffectsConfiguredGuardrail(t *testing.T) { - logger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{ - {Type: "no_secrets"}, // no allow_tools — applies to all tools - { - Type: "no_pii", - Config: map[string]any{ - "allow_tools": []any{"github_get_user"}, + { + name: "multiple text parts", + msg: &a2a.Message{ + Parts: []a2a.Part{ + {Kind: a2a.PartKindText, Text: "hello"}, + {Kind: a2a.PartKindText, Text: "world"}, }, }, + want: "hello world", }, - }, true, logger) - - // Allowed tool bypasses PII but NOT secrets - _, err := g.CheckToolOutput("github_get_user", "token: ghp_abcdefghijklmnopqrstuvwxyz0123456789") - if err == nil { - t.Error("allow_tools for no_pii should not bypass no_secrets") - } - if !strings.Contains(err.Error(), "no_secrets") { - t.Errorf("expected error to mention no_secrets, got: %v", err) - } -} - -func TestCheckToolOutput_ErrorMessageMentionsGuardrailType(t *testing.T) { - logger := &testLogger{} - - // Test no_secrets error message - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: 
"no_secrets"}}, - }, true, logger) - _, err := g.CheckToolOutput("some_tool", "key: sk-ant-abcdefghijklmnopqrstuv") - if err == nil { - t.Fatal("expected error for secret") - } - if !strings.Contains(err.Error(), "no_secrets") { - t.Errorf("expected error to mention no_secrets, got: %v", err) - } - - // Test no_pii error message - g2 := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, true, logger) - _, err = g2.CheckToolOutput("some_tool", "email: test@example.com") - if err == nil { - t.Fatal("expected error for PII") - } - if !strings.Contains(err.Error(), "no_pii") { - t.Errorf("expected error to mention no_pii, got: %v", err) - } -} - -// --- CheckOutbound message tests --- - -func TestCheckOutbound_PIIRedacted(t *testing.T) { - logger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, true, logger) - - msg := &a2a.Message{ - Role: "agent", - Parts: []a2a.Part{ - {Kind: a2a.PartKindText, Text: "Your SSN is 456-78-9012"}, + { + name: "empty parts", + msg: &a2a.Message{}, + want: "", }, - } - err := g.CheckOutbound(msg) - if err != nil { - t.Errorf("CheckOutbound should redact, not block: %v", err) - } - if !strings.Contains(msg.Parts[0].Text, "[REDACTED]") { - t.Error("expected PII to be redacted in outbound message") - } -} - -func TestCheckOutbound_InvalidSSNPasses(t *testing.T) { - logger := &testLogger{} - g := NewGuardrailEngine(&agentspec.PolicyScaffold{ - Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, - }, true, logger) - - msg := &a2a.Message{ - Role: "agent", - Parts: []a2a.Part{ - {Kind: a2a.PartKindText, Text: "code: 000-12-3456"}, + { + name: "non-text parts ignored", + msg: &a2a.Message{ + Parts: []a2a.Part{ + {Kind: a2a.PartKindText, Text: "text"}, + {Kind: "data", Text: ""}, + }, + }, + want: "text", }, } - err := g.CheckOutbound(msg) - if err != nil { - t.Errorf("invalid SSN should pass through, got error: 
%v", err) - } -} - -// testLogger is a no-op logger for tests. -type testLogger struct { - warnings []string -} -func (l *testLogger) Info(msg string, fields map[string]any) {} -func (l *testLogger) Debug(msg string, fields map[string]any) {} -func (l *testLogger) Warn(msg string, fields map[string]any) { - l.warnings = append(l.warnings, msg) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractText(tt.msg) + if got != tt.want { + t.Errorf("ExtractText() = %q, want %q", got, tt.want) + } + }) + } } -func (l *testLogger) Error(msg string, fields map[string]any) {} diff --git a/forge-core/runtime/loop.go b/forge-core/runtime/loop.go index 65dda5f..63a5f15 100644 --- a/forge-core/runtime/loop.go +++ b/forge-core/runtime/loop.go @@ -248,8 +248,11 @@ func (e *LLMExecutor) Execute(ctx context.Context, task *a2a.Task, msg *a2a.Mess // Append assistant message to memory mem.Append(resp.Message) - // Check if we're done (no tool calls) - if resp.FinishReason == "stop" || len(resp.Message.ToolCalls) == 0 { + // Check if we're done (no tool calls). + // Always execute tool calls even when finish_reason is "stop" — + // otherwise we persist an assistant message with orphaned function + // calls that the Responses API will reject on session recovery. + if len(resp.Message.ToolCalls) == 0 { // If the LLM stopped after executing tools, send a continuation // nudge. This catches cases where the LLM reports findings instead // of completing the full workflow (e.g., stops after exploration @@ -500,17 +503,48 @@ func (e *LLMExecutor) Execute(ctx context.Context, task *a2a.Task, msg *a2a.Mess } // persistSession saves the current memory state to disk (best-effort). +// It strips orphaned tool calls from the last assistant message to prevent +// the Responses API from rejecting recovered sessions with +// "No tool output found for function call". 
 func (e *LLMExecutor) persistSession(taskID string, mem *Memory) {
 	if e.store == nil {
 		return
 	}
 	mem.mu.Lock()
+	msgs := make([]llm.ChatMessage, len(mem.messages))
+	copy(msgs, mem.messages)
+	mem.mu.Unlock()
+
+	// Build a set of tool call IDs that have matching tool results.
+	answeredCalls := make(map[string]bool)
+	for _, m := range msgs {
+		if m.Role == llm.RoleTool && m.ToolCallID != "" {
+			answeredCalls[m.ToolCallID] = true
+		}
+	}
+
+	// Strip unanswered tool calls from assistant messages to avoid
+	// orphaned function_call items in the persisted session.
+	for i := range msgs {
+		if msgs[i].Role != llm.RoleAssistant || len(msgs[i].ToolCalls) == 0 {
+			continue
+		}
+		var kept []llm.ToolCall
+		for _, tc := range msgs[i].ToolCalls {
+			if answeredCalls[tc.ID] {
+				kept = append(kept, tc)
+			}
+		}
+		if len(kept) != len(msgs[i].ToolCalls) {
+			msgs[i].ToolCalls = kept
+		}
+	}
+
 	data := &SessionData{
 		TaskID:   taskID,
-		Messages: mem.messages,
+		Messages: msgs,
 		Summary:  mem.existingSummary,
 	}
-	mem.mu.Unlock()
 
 	if err := e.store.Save(data); err != nil {
 		e.logger.Warn("failed to persist session", map[string]any{
diff --git a/forge-core/runtime/memory.go b/forge-core/runtime/memory.go
index 6746a38..a4ba4b3 100644
--- a/forge-core/runtime/memory.go
+++ b/forge-core/runtime/memory.go
@@ -119,13 +119,46 @@ func (m *Memory) Messages() []llm.ChatMessage {
 }
 
 // LoadFromStore restores memory state from a persisted SessionData.
+// It sanitizes the loaded messages by stripping orphaned tool calls —
+// assistant messages whose tool_calls have no matching tool result.
+// This prevents the Responses API from rejecting recovered sessions
+// with "No tool output found for function call".
 func (m *Memory) LoadFromStore(data *SessionData) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
-	m.messages = data.Messages
+	m.messages = sanitizeToolCalls(data.Messages)
 	m.existingSummary = data.Summary
 }
 
+// sanitizeToolCalls removes tool calls from assistant messages that have
+// no matching tool result in the message history.
+func sanitizeToolCalls(msgs []llm.ChatMessage) []llm.ChatMessage {
+	// Build set of tool call IDs that have results.
+	answered := make(map[string]bool, len(msgs))
+	for _, m := range msgs {
+		if m.Role == llm.RoleTool && m.ToolCallID != "" {
+			answered[m.ToolCallID] = true
+		}
+	}
+
+	// Strip unanswered tool calls.
+	for i := range msgs {
+		if msgs[i].Role != llm.RoleAssistant || len(msgs[i].ToolCalls) == 0 {
+			continue
+		}
+		var kept []llm.ToolCall
+		for _, tc := range msgs[i].ToolCalls {
+			if answered[tc.ID] {
+				kept = append(kept, tc)
+			}
+		}
+		if len(kept) != len(msgs[i].ToolCalls) {
+			msgs[i].ToolCalls = kept
+		}
+	}
+	return msgs
+}
+
 // Reset clears the conversation history (keeps the system prompt).
 func (m *Memory) Reset() {
 	m.mu.Lock()
diff --git a/forge-core/types/config.go b/forge-core/types/config.go
index 77330a3..faf90af 100644
--- a/forge-core/types/config.go
+++ b/forge-core/types/config.go
@@ -9,22 +9,23 @@ import (
 
 // ForgeConfig represents the top-level forge.yaml configuration.
 type ForgeConfig struct {
-	AgentID      string           `yaml:"agent_id"`
-	Version      string           `yaml:"version"`
-	Framework    string           `yaml:"framework"`
-	Entrypoint   string           `yaml:"entrypoint"`
-	Model        ModelRef         `yaml:"model,omitempty"`
-	Tools        []ToolRef        `yaml:"tools,omitempty"`
-	BuiltinTools []string         `yaml:"builtin_tools,omitempty"`
-	Channels     []string         `yaml:"channels,omitempty"`
-	Registry     string           `yaml:"registry,omitempty"`
-	Egress       EgressRef        `yaml:"egress,omitempty"`
-	Skills       SkillsRef        `yaml:"skills,omitempty"`
-	Memory       MemoryConfig     `yaml:"memory,omitempty"`
-	Secrets      SecretsConfig    `yaml:"secrets,omitempty"`
-	Schedules    []ScheduleConfig `yaml:"schedules,omitempty"`
-	CORSOrigins  []string         `yaml:"cors_origins,omitempty"`
-	Package      PackageConfig    `yaml:"package,omitempty"`
+	AgentID        string           `yaml:"agent_id"`
+	Version        string           `yaml:"version"`
+	Framework      string           `yaml:"framework"`
+	Entrypoint     string           `yaml:"entrypoint"`
+	Model          ModelRef         `yaml:"model,omitempty"`
+	Tools          []ToolRef        `yaml:"tools,omitempty"`
+	BuiltinTools   []string         `yaml:"builtin_tools,omitempty"`
+	Channels       []string         `yaml:"channels,omitempty"`
+	Registry       string           `yaml:"registry,omitempty"`
+	Egress         EgressRef        `yaml:"egress,omitempty"`
+	Skills         SkillsRef        `yaml:"skills,omitempty"`
+	Memory         MemoryConfig     `yaml:"memory,omitempty"`
+	Secrets        SecretsConfig    `yaml:"secrets,omitempty"`
+	Schedules      []ScheduleConfig `yaml:"schedules,omitempty"`
+	CORSOrigins    []string         `yaml:"cors_origins,omitempty"`
+	Package        PackageConfig    `yaml:"package,omitempty"`
+	GuardrailsPath string           `yaml:"guardrails_path,omitempty"` // path to guardrails.json (default: "guardrails.json")
 }
 
 // ScheduleConfig defines a recurring scheduled task in forge.yaml.
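
Reviewer note: the orphaned-tool-call stripping that this diff adds to `persistSession` and `LoadFromStore` can be exercised in isolation. The sketch below is a minimal standalone version, not the code from the diff: `ToolCall`, `ChatMessage`, the role constants, and `stripOrphanedToolCalls` are hypothetical stand-ins for the `llm` package types, and unlike `sanitizeToolCalls` it works on a copy of the slice instead of mutating its input.

```go
package main

import "fmt"

// ToolCall and ChatMessage are hypothetical stand-ins for llm.ToolCall and
// llm.ChatMessage; the real types are assumed to carry more fields.
type ToolCall struct {
	ID   string
	Name string
}

type ChatMessage struct {
	Role       string
	Content    string
	ToolCalls  []ToolCall
	ToolCallID string
}

// Assumed role values, standing in for llm.RoleAssistant and llm.RoleTool.
const (
	roleAssistant = "assistant"
	roleTool      = "tool"
)

// stripOrphanedToolCalls returns a copy of msgs in which each assistant
// message keeps only the tool calls that have a matching tool-result message.
func stripOrphanedToolCalls(msgs []ChatMessage) []ChatMessage {
	// Collect the IDs of tool calls that received a result.
	answered := make(map[string]bool, len(msgs))
	for _, m := range msgs {
		if m.Role == roleTool && m.ToolCallID != "" {
			answered[m.ToolCallID] = true
		}
	}

	// Work on a copy so the caller's messages are left untouched.
	out := make([]ChatMessage, len(msgs))
	copy(out, msgs)
	for i := range out {
		if out[i].Role != roleAssistant || len(out[i].ToolCalls) == 0 {
			continue
		}
		var kept []ToolCall
		for _, tc := range out[i].ToolCalls {
			if answered[tc.ID] {
				kept = append(kept, tc)
			}
		}
		out[i].ToolCalls = kept
	}
	return out
}

func main() {
	history := []ChatMessage{
		{Role: roleAssistant, ToolCalls: []ToolCall{
			{ID: "call_1", Name: "search"},
			{ID: "call_2", Name: "fetch"}, // never answered
		}},
		{Role: roleTool, ToolCallID: "call_1", Content: "result"},
	}
	cleaned := stripOrphanedToolCalls(history)
	fmt.Println(len(cleaned[0].ToolCalls), cleaned[0].ToolCalls[0].ID) // 1 call_1
	fmt.Println(len(history[0].ToolCalls))                            // 2
}
```

Copying first keeps the in-memory history intact, which is the same reason `persistSession` copies `mem.messages` before stripping.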