diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index fcb47dbfb23..4b77484a811 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -229,6 +229,17 @@ "type": "", "validation": {} }, + { + "category": "PATHS", + "default": "wildcards", + "description": "Path to directory for dynamic prompt wildcard files.", + "env_var": "INVOKEAI_WILDCARDS_DIR", + "literal_values": [], + "name": "wildcards_dir", + "required": false, + "type": "", + "validation": {} + }, { "category": "LOGGING", "default": [ diff --git a/invokeai/app/api/routers/utilities.py b/invokeai/app/api/routers/utilities.py index 568546603ab..15e6311539a 100644 --- a/invokeai/app/api/routers/utilities.py +++ b/invokeai/app/api/routers/utilities.py @@ -17,6 +17,7 @@ from invokeai.app.api.routers._access import assert_image_read_access from invokeai.app.services.image_files.image_files_common import ImageFileNotFoundException from invokeai.app.services.model_records.model_records_base import UnknownModelException +from invokeai.app.util.dynamicprompts import find_missing_wildcards, get_wildcard_manager from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline from invokeai.backend.model_manager.taxonomy import ModelType from invokeai.backend.text_llm_pipeline import DEFAULT_SYSTEM_PROMPT, TextLLMPipeline @@ -52,13 +53,22 @@ async def parse_dynamicprompts( """Creates a batch process""" max_prompts = min(max_prompts, 10000) generator: Union[RandomPromptGenerator, CombinatorialPromptGenerator] + wildcard_manager = get_wildcard_manager(ApiDependencies.invoker.services.configuration.wildcards_path) + error: Optional[str] = None + + # An unknown wildcard sends the combinatorial generator into an infinite loop, so bail out early + # with a clear message instead of hanging the request (and with it the UI preview). + missing_wildcards = find_missing_wildcards(prompt, wildcard_manager) + if missing_wildcards: + wildcards = ", ".join(missing_wildcards) + return DynamicPromptsResponse(prompts=[prompt], error=f"No values found for wildcard(s): {wildcards}") + try: - error: Optional[str] = None if combinatorial: - generator = CombinatorialPromptGenerator() + generator = CombinatorialPromptGenerator(wildcard_manager=wildcard_manager) prompts = generator.generate(prompt, max_prompts=max_prompts) else: - generator = RandomPromptGenerator(seed=seed) + generator = RandomPromptGenerator(wildcard_manager=wildcard_manager, seed=seed) prompts = generator.generate(prompt, num_images=max_prompts) except ParseException as e: prompts = [prompt] diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 48eec0ac0ef..d46a00b05ac 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -9,6 +9,7 @@ from invokeai.app.invocations.fields import InputField, UIComponent from invokeai.app.invocations.primitives import StringCollectionOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.util.dynamicprompts import find_missing_wildcards, get_wildcard_manager @invocation( @@ -30,11 +31,19 @@ class DynamicPromptInvocation(BaseInvocation): combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") def invoke(self, context: InvocationContext) -> StringCollectionOutput: + wildcard_manager = get_wildcard_manager(context.config.get().wildcards_path) + + # An unknown wildcard sends the combinatorial generator into an infinite loop, so fail fast + # with a clear message instead of hanging the invocation. + missing_wildcards = find_missing_wildcards(self.prompt, wildcard_manager) + if missing_wildcards: + raise ValueError(f"No values found for wildcard(s): {', '.join(missing_wildcards)}") + if self.combinatorial: - generator = CombinatorialPromptGenerator() + generator = CombinatorialPromptGenerator(wildcard_manager=wildcard_manager) prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) else: - generator = RandomPromptGenerator() + generator = RandomPromptGenerator(wildcard_manager=wildcard_manager) prompts = generator.generate(self.prompt, num_images=self.max_prompts) return StringCollectionOutput(collection=prompts) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index e6cc7c2798c..840898d95e0 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -85,6 +85,7 @@ class InvokeAIAppConfig(BaseSettings): custom_nodes_dir: Path to directory for custom nodes. style_presets_dir: Path to directory for style presets. workflow_thumbnails_dir: Path to directory for workflow thumbnails. + wildcards_dir: Path to directory for dynamic prompt wildcard files. log_handlers: Log handler. Valid options are "console", "file=", "syslog=path|address:host:port", "http=". log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.
Valid values: `plain`, `color`, `syslog`, `legacy` log_level: Emit logging messages at this level or higher.
Valid values: `debug`, `info`, `warning`, `error`, `critical` @@ -171,6 +172,7 @@ class InvokeAIAppConfig(BaseSettings): custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.") style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.") workflow_thumbnails_dir: Path = Field(default=Path("workflow_thumbnails"), description="Path to directory for workflow thumbnails.") + wildcards_dir: Path = Field(default=Path("wildcards"), description="Path to directory for dynamic prompt wildcard files.") # LOGGING log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=", "syslog=path|address:host:port", "http=".') @@ -373,6 +375,11 @@ def workflow_thumbnails_path(self) -> Path: """Path to the workflow thumbnails directory, resolved to an absolute path..""" return self._resolve(self.workflow_thumbnails_dir) + @property + def wildcards_path(self) -> Path: + """Path to the dynamic prompt wildcards directory, resolved to an absolute path..""" + return self._resolve(self.wildcards_dir) + @property def convert_cache_path(self) -> Path: """Path to the converted cache models directory, resolved to an absolute path..""" diff --git a/invokeai/app/util/dynamicprompts.py b/invokeai/app/util/dynamicprompts.py new file mode 100644 index 00000000000..e5bc1c2597a --- /dev/null +++ b/invokeai/app/util/dynamicprompts.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path + +from dynamicprompts.commands import ( + Command, + SequenceCommand, + VariantCommand, + WildcardCommand, + WrapCommand, +) +from dynamicprompts.parser.parse import parse +from dynamicprompts.wildcards import WildcardManager +from pyparsing import ParseException + + +def get_wildcard_manager(wildcards_path: Path) -> WildcardManager: + """Build a `WildcardManager` rooted at the given directory, creating the directory if needed. + + Wildcard collections are plain `.txt` files (one value per line) placed under this directory and + referenced in prompts as `__name__` (subdirectories map to `__subdir/name__`). + """ + wildcards_path.mkdir(parents=True, exist_ok=True) + return WildcardManager(wildcards_path) + + +def _iter_wildcard_names(command: Command) -> Iterator[str]: + """Recursively yield the statically-known wildcard names referenced in a parsed prompt.""" + if isinstance(command, WildcardCommand): + # The wildcard name may itself be a dynamic Command (e.g. `__${var}__`). Only plain string + # names can be validated ahead of time, so the dynamic case is intentionally skipped. + if isinstance(command.wildcard, str): + yield command.wildcard + elif isinstance(command, SequenceCommand): + for token in command.tokens: + yield from _iter_wildcard_names(token) + elif isinstance(command, VariantCommand): + for value in command.values: + yield from _iter_wildcard_names(value) + elif isinstance(command, WrapCommand): + yield from _iter_wildcard_names(command.wrapper) + yield from _iter_wildcard_names(command.inner) + # LiteralCommand and variable commands reference no wildcards we can resolve statically. + + +def find_missing_wildcards(prompt: str, wildcard_manager: WildcardManager) -> list[str]: + """Return the unique wildcard names referenced in `prompt` that resolve to no values. + + Referencing an unknown wildcard makes dynamicprompts' combinatorial generator loop forever: its + not-found fallback (`get_wildcard_not_found_fallback`) yields the wrapped wildcard infinitely, and + the combinatorial variant logic dedupes those duplicates away without ever advancing. Detecting + the missing names up front lets callers report a clear error instead of hanging. + """ + try: + tree = parse(prompt) + except ParseException: + # Malformed prompts are surfaced separately by the generators; nothing to validate here. + return [] + + missing: list[str] = [] + for name in _iter_wildcard_names(tree): + if name not in missing and not wildcard_manager.get_values(name): + missing.append(name) + return missing diff --git a/tests/app/routers/test_utilities.py b/tests/app/routers/test_utilities.py index ce91f2efd24..2403efb1af0 100644 --- a/tests/app/routers/test_utilities.py +++ b/tests/app/routers/test_utilities.py @@ -67,6 +67,23 @@ def test_dynamicprompts_works_for_user(client: TestClient, user1_token: str): assert "prompts" in body +def test_dynamicprompts_unknown_wildcard_returns_error_without_hanging(client: TestClient, user1_token: str): + """An unknown `__wildcard__` would otherwise loop forever in the combinatorial generator. + + The endpoint must instead return promptly with a clear error and the original prompt echoed back. + """ + r = client.post( + "/api/v1/utilities/dynamicprompts", + json={"prompt": "{__random__8chan|fenster|stuff}"}, + headers={"Authorization": f"Bearer {user1_token}"}, + ) + assert r.status_code == status.HTTP_200_OK + body = r.json() + assert body["error"] is not None + assert "random" in body["error"] + assert body["prompts"] == ["{__random__8chan|fenster|stuff}"] + + # ----------------------------- image_to_prompt: ownership / read-access ----------------------------- diff --git a/tests/app/util/test_dynamicprompts.py b/tests/app/util/test_dynamicprompts.py new file mode 100644 index 00000000000..6ac7674693f --- /dev/null +++ b/tests/app/util/test_dynamicprompts.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from invokeai.app.util.dynamicprompts import find_missing_wildcards, get_wildcard_manager + + +@pytest.fixture +def wildcards_dir(tmp_path: Path) -> Path: + """A wildcards directory containing a single `animals` collection.""" + (tmp_path / "animals.txt").write_text("cat\ndog\nbird\n", encoding="utf-8") + return tmp_path + + +def test_get_wildcard_manager_creates_directory(tmp_path: Path) -> None: + target = tmp_path / "does-not-exist-yet" + assert not target.exists() + get_wildcard_manager(target) + assert target.is_dir() + + +def test_find_missing_wildcards_detects_unknown_wildcard_in_variant(wildcards_dir: Path) -> None: + # Regression: `__random__` inside a variant is parsed as a wildcard reference. Left unchecked it + # sends the combinatorial generator into an infinite loop, so it must be reported up front. + wm = get_wildcard_manager(wildcards_dir) + assert find_missing_wildcards("{__random__8chan|fenster|stuff}", wm) == ["random"] + + +def test_find_missing_wildcards_passes_known_wildcard(wildcards_dir: Path) -> None: + wm = get_wildcard_manager(wildcards_dir) + assert find_missing_wildcards("a {__animals__|house}", wm) == [] + + +@pytest.mark.parametrize("prompt", ["plain text", "{a|b|c}", "a {2$$x|y|z}"]) +def test_find_missing_wildcards_ignores_prompts_without_wildcards(wildcards_dir: Path, prompt: str) -> None: + wm = get_wildcard_manager(wildcards_dir) + assert find_missing_wildcards(prompt, wm) == [] + + +def test_find_missing_wildcards_dedupes_repeated_unknown_wildcards(wildcards_dir: Path) -> None: + wm = get_wildcard_manager(wildcards_dir) + assert find_missing_wildcards("__nope__ and __nope__ and __animals__", wm) == ["nope"]