diff --git a/docs/contributing/call_saved_workflow.md b/docs/contributing/call_saved_workflow.md new file mode 100644 index 00000000000..3e598ef8918 --- /dev/null +++ b/docs/contributing/call_saved_workflow.md @@ -0,0 +1,795 @@ +# Call Saved Workflow Architecture + +## Goal + +`CallSavedWorkflowInvocation` should become an engine-native workflow call boundary, not a frontend-only dynamic node +and not a compile-time graph inliner. + +The long-term feature goal is: + +- A parent workflow can call a saved workflow selected by ID. +- The call node redraws in the editor based on the selected workflow's exposed form fields. +- Parent values and inbound connections bind to those exposed fields as call arguments. +- Execution suspends at the call node, runs the selected workflow as a dependent workflow execution, captures explicit + return values, and then resumes the parent workflow. +- The architecture must work for Invoke frontend graphs and for externally submitted graphs that use the same node type. + +This document records the current state, the target architecture, and the execution contract needed to continue +development later. + +## Implementation Priority + +Favor the architecturally correct design over the fastest implementation path. + +The work may still proceed incrementally, but each increment should satisfy all of the following: + +- testable in isolation +- compatible with the long-term architecture described here +- non-breaking to existing code and existing workflow execution behavior + +Speed is not the primary goal for this phase. The primary goal is to move toward the durable design without introducing +throwaway execution semantics that would need to be unwound later. + +## Current State + +Implemented already in the branch: + +- A real invocation exists: `call_saved_workflow`. +- A real return node exists: `workflow_return`. +- Named returns exist through `workflow_return_value`, `workflow_return`, and caller-side `workflow_return_get`. 
+- `workflow_return` accepts one key/value return member directly or a collected list of return members, then emits a + named `values: dict[str, Any]` map. +- Only one `workflow_return` node is allowed per workflow, enforced in both frontend validation and Python validation. +- The frontend provides a saved-workflow picker using a reusable `SavedWorkflowField` UI type. +- The node redraws dynamically based on the selected saved workflow's exposed form fields. +- Dynamic field values persist with the parent workflow. +- Compatible inbound edges are preserved when switching between workflows with matching exposed field identities and + compatible types. +- Incompatible or no-longer-exposed inbound edges are removed in the editor. +- Backend validation exists for `workflow_id` existence and access rights. + +Implemented runtime scaffolding: + +- `GraphExecutionState` now persists workflow-call runtime state: + - `workflow_call_stack` + - `workflow_call_history` + - `workflow_call_parent` + - `waiting_workflow_call` + - `waiting_workflow_call_execution` + - `waiting_workflow_call_child_session` + - `max_workflow_call_depth` +- Nested and recursive calls are represented by the stack, with a runtime depth cap of 4. +- Parent/child workflow-call identity is now explicit in runtime state: + - the parent tracks an active `WorkflowCallExecution` record while waiting + - completed and failed calls are preserved in `workflow_call_history` + - child sessions carry a `workflow_call_parent` reference back to the parent call relationship +- `GraphExecutionState.next()` returns no runnable node while the parent session is waiting on a child workflow call. +- `GraphExecutionState.is_complete()` stays false while waiting. +- `DefaultSessionRunner.run_node()` now treats `call_saved_workflow` as a call boundary instead of a normal executable + node. 
+- On boundary entry, the runner: + - validates the selected workflow + - builds a workflow call frame + - converts the saved workflow JSON into a backend `Graph` + - validates and applies parent call arguments to the child graph + - creates a child `GraphExecutionState` + - attaches that child session to the waiting parent session +- Workflow-call runtime responsibilities are now split: + - `WorkflowCallCoordinator` handles call-specific setup: + - build the child graph + - apply parent call arguments + - create the child `GraphExecutionState` + - suspend the parent and enqueue the child queue item + - `WorkflowCallQueueLifecycle` handles queue-visible parent/child lifecycle: + - run child queue items + - resume waiting parents after child success + - complete the parent call node with the child `workflow_return` values + - fail suspended parents after child failure and cascade that failure upward through parent call chains +- Child `SessionQueueItem` rows now carry explicit relationship metadata: + - `workflow_call_id` + - `parent_item_id` + - `parent_session_id` + - `root_item_id` + - `workflow_call_depth` + - this metadata is now used directly by queue-visible child execution and parent resume/failure handling +- The `session_queue` table now has matching durable columns for that relationship metadata: + - `workflow_call_id` + - `parent_item_id` + - `parent_session_id` + - `root_item_id` + - `workflow_call_depth` + - child workflow executions are now inserted as their own pending queue rows using those columns +- Parent queue items now enter a real `waiting` status while suspended on a child workflow execution. +- `_on_after_run_session()` no longer completes queue items whose sessions are incomplete but waiting. 
+- Dynamic call arguments now execute end-to-end in the current runner path: + - literal dynamic values are serialized into a hidden `workflow_inputs` payload on the parent node at graph-build time + - stale hidden `workflow_inputs` values from recalled graphs are ignored unless a matching current dynamic field + exists + - existing dynamic input values are preserved across refresh only while the exposed field type remains compatible; if + the selected child workflow changes the exposed field type at the same node/field path, the caller input resets to + the child workflow's current initial value + - connected dynamic values are accepted as special call-boundary edges and are resolved from parent results at runtime + - both are validated against the child workflow's exposed form interface before being applied to the child graph +- Queue lifecycle semantics now exist for workflow-call chains: + - parent queue items are suspended in `waiting` while a child queue row runs + - child success resumes the suspended parent and completes the parent call node with the child `workflow_return` + values + - child failure fails the suspended parent and cascades upward through any waiting parent chain + - canceling a parent cancels its descendant child chain + - canceling a child cancels the waiting parent chain upward + - canceling remaining siblings after a batched child failure also cancels descendants of those sibling rows + - deleting any queue row in a workflow-call chain deletes the full chain to avoid leaving orphaned parent or child + rows behind + - `cancel_all_except_current` and `delete_all_except_current` preserve the active queue item plus its workflow-call + ancestors and descendants; unrelated waiting chains are still canceled or deleted + - retry is root-oriented rather than child-oriented; child queue rows should not be directly retried from the UI + - the current UI policy is: + - child queue rows keep `Cancel` + - child queue rows hide `Retry` + - child 
queue-row creation is now fail-clean: + - if call-boundary setup fails after some child rows have already been inserted, those child rows are deleted before + the parent invocation is failed + - child queue-row fan-out is bounded by remaining queue capacity, not just the global queue-size setting: + - a workflow call that would exceed the remaining pending capacity now fails instead of silently truncating or + over-enqueuing child rows + +Implemented conversion helper: + +- `workflow_graph_builder.py` converts saved workflow JSON into an executable backend `Graph`. +- It currently supports the invocation-node subset needed for this feature. +- It flattens connector nodes and omits explicit destination field values when a connection exists, matching frontend + graph-build semantics. +- It now serves as the first explicit callable-workflow compatibility gate: + - the selected workflow must contain exactly one `workflow_return` node + - connected batch child inputs produced by ordinary non-generator upstream nodes still fail early with a clear + unsupported-feature error + - malformed batch input wiring, including multiple connected inputs to one batch field, is reported as + `unsupported_batch_input` compatibility rather than a generic unsupported-node failure + - child workflows that mix supported batch nodes with unrelated generator nodes are currently rejected with a clear + unsupported-feature error + - unsupported callees are rejected before any child queue row is created +- Compatibility metadata is now exposed through workflow library API responses: + - workflow list items and workflow detail responses include `call_saved_workflow_compatibility` + - workflow list items use structural generator-backed batch checks so list/picker rendering does not enumerate every + image in board-backed generators; workflow detail and runtime execution still resolve real generator values + - the saved-workflow picker uses that metadata to disable unsupported workflows before 
execution + - the picker still allows an already-selected unsupported workflow to render, with an explicit unsupported state and + backend-provided reason message + - workflow library list items now surface an explicit unsupported badge and backend-provided reason message without + blocking normal workflow viewing or editing + +What is still not implemented: + +- connected batch child inputs whose batch values are produced by ordinary non-generator upstream nodes are still not + supported and must fail with a clear domain error +- child workflows that mix supported batch nodes with unrelated generator nodes are still not supported and must fail + with a clear domain error +- broader child-workflow compatibility coverage still needs to be expanded from real unsupported shapes rather than + trying to interpret every frontend-only workflow representation through the current graph-builder path +- the current workflow-call queue lifecycle is still implemented through dedicated workflow-call runtime classes rather + than a fully generalized parent/child scheduler model + +Conclusion: + +- the editor contract is largely in place +- the parent-side runtime call boundary is in place +- child execution, argument forwarding, explicit child return capture, suspended parent status, queue-visible child + rows, and upward failure cascade now work +- the remaining major runtime work is to harden and generalize the parent/child scheduler model rather than prove the + basic call boundary + +## Architectural Direction + +Use the architecture that is more likely to be kept long-term: + +- `call_saved_workflow` is a call boundary. +- The parent graph does not inline the full child workflow into itself at queue time. +- Runtime execution pauses at the call node and creates a dependent child workflow execution. +- The child workflow receives arguments from the parent. +- The child workflow returns explicit outputs to the parent. +- The parent resumes once the child returns successfully. 
+ +This is preferred over full graph expansion because it: + +- avoids execution-graph blowup +- preserves workflow boundaries +- matches the conceptual model of workflow reuse +- supports explicit return values +- keeps externally submitted graphs viable as long as they use the same node type and contract + +## Non-Goals For The Next Phase + +These should not be the first implementation target: + +- full inline graph expansion of called workflows +- unlimited nested workflow call support +- automatic exposure of arbitrary internal child workflow state +- implicit output inference from arbitrary child nodes + +## Execution Contract + +### 1. Callable Interface + +The callable interface of a saved workflow is defined by its saved workflow JSON. + +Primary source: + +- `workflow.form` + +Fallback source for older workflows: + +- `workflow.exposedFields` + +Only fields exposed by the child workflow form are callable inputs. Internal child inputs that exist in the workflow +graph but are not exposed by the form are not part of the public call interface. + +### 2. Input Arguments + +`CallSavedWorkflowInvocation` exposes dynamic inputs in the editor based on the selected workflow's callable interface. + +The saved-workflow picker sends typed search text to the workflow-list endpoint. This keeps large workflow libraries +discoverable even when the desired workflow has not already been loaded into the combobox pages. + +Each dynamic input must have: + +- a stable external handle name +- a type +- a default value if defined by the child workflow +- a user-facing label and description when available + +Current fast-path identity is based on child `nodeId + fieldName`. That is acceptable short-term in the editor, but a +longer-term stable interface ID would be better if child workflows are frequently duplicated or refactored. + +### 3. 
Input Binding At Runtime + +At runtime, when the parent reaches `call_saved_workflow`: + +- the engine resolves `workflow_id` +- the engine loads the selected child workflow record +- the engine reconstructs the callable interface from the saved workflow JSON +- the engine collects argument values from the parent node's dynamic inputs +- the engine starts a dependent child workflow execution using those arguments + +Argument values may come from: + +- parent literal field values +- resolved inbound connections into the call node's dynamic inputs + +For batch-aware child workflows, the parent call boundary should still pass normal exposed form inputs. Batching should +emerge from the child workflow's own internal batch nodes or generators, not from a separate caller-side batch protocol. + +### 4. Child Workflow Execution + +The child workflow runs as its own dependent execution context, not as an inlined copy of the parent graph. + +Desired semantics: + +- parent execution pauses at the call node +- child execution runs with inherited context where appropriate +- child workflow finishes or fails +- parent resumes only if child execution succeeds + +This implies the queue/session/runtime layer needs an explicit parent-child execution relationship. 
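The argument-binding step above can be sketched as follows. The `(node_id, field_name)` identity matches the current fast-path interface identity; `bind_call_arguments` and the dict-shaped graph are hypothetical simplifications of the real backend `Graph` handling:

```python
from typing import Any

# The current fast-path identity for an exposed child field.
ExposedField = tuple[str, str]  # (node_id, field_name)


def bind_call_arguments(
    exposed: set[ExposedField],
    literal_args: dict[ExposedField, Any],
    connected_args: dict[ExposedField, Any],
    child_graph: dict[str, dict[str, Any]],  # node_id -> field values
) -> None:
    """Apply parent call arguments to the child graph's exposed fields.

    Values resolved from inbound connections win over literal node values
    for the same exposed field.
    """
    merged = {**literal_args, **connected_args}
    for (node_id, field_name), value in merged.items():
        if (node_id, field_name) not in exposed:
            # Internal child fields are not part of the public call interface.
            raise ValueError(f"{node_id}.{field_name} is not an exposed field")
        child_graph.setdefault(node_id, {})[field_name] = value


child_graph = {"prompt_node": {"prompt": "default"}}
exposed = {("prompt_node", "prompt"), ("size_node", "width")}
bind_call_arguments(
    exposed,
    literal_args={("size_node", "width"): 512},
    connected_args={("prompt_node", "prompt"): "a cat"},
    child_graph=child_graph,
)
assert child_graph == {"prompt_node": {"prompt": "a cat"}, "size_node": {"width": 512}}

# Binding to a non-exposed internal field fails before the child ever runs.
try:
    bind_call_arguments(exposed, {("hidden", "seed"): 1}, {}, child_graph)
except ValueError:
    pass
else:
    raise AssertionError("expected non-exposed binding to fail")
```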
+ +Current limitation: + +- the temporary `workflow_graph_builder.py` path still reconstructs only the ordinary invocation subset of child + workflows +- direct batch-special child workflows now bypass that path and use queue batch expansion instead +- generator-backed batch child workflows now bypass that path too when the batch is fed directly by a supported + generator node +- connected batch child inputs produced by ordinary non-generator upstream nodes are still not supported and should fail + early with a clear unsupported-feature error +- the current queue-visible child execution path still relies on `WorkflowCallCoordinator` to resume or fail parents + directly rather than a more general queue scheduler abstraction +- the current implementation is still an intermediate architecture step, but it is now materially closer to the intended + durable parent/child model than the earlier inline-runner path + +### 4a. Queue Lifecycle Contract + +The current queue-visible implementation uses the following lifecycle contract: + +- root or parent queue items may enter `waiting` while suspended on a child workflow call +- child workflow executions are represented as real queue rows with explicit parent/child relationship metadata +- child completion resumes the suspended parent and returns control to normal queue execution +- child failure fails the suspended parent call node and cascades upward through any ancestor chain +- cancel operations are chain-aware: + - canceling a waiting parent cancels descendants + - canceling a child cancels waiting ancestors + - canceling batched siblings after one child fails includes nested descendants of those siblings + - bulk "all except current" actions preserve the active queue item and its parent/child chain, not just the single + `in_progress` row +- retry operations are root-aware: + - retrying a root queue item creates a new root execution + - retrying a child queue item should be normalized to the root by backend code + - 
child queue rows should not expose direct retry affordances in the UI + - retry websocket delivery is owner-scoped; when an admin retries roots owned by multiple users, each non-admin user + must receive only the retry item ids for their own roots, while admins can still observe the full retried set +- workflow live-update sockets join workflow event rooms in both authenticated multiuser mode and unauthenticated + single-user mode; the frontend relies on those events to invalidate workflow library data and clear deleted saved + workflow selections; in single-user mode, workflow CRUD events emit only to the admin room to avoid duplicate delivery + to sockets that are also joined to `user:system` +- the saved-workflow node picker queries owned/default workflows and public shared workflows separately, merges them by + workflow id, and fetches additional pages as the combobox menu reaches the end + +This is now part of the intended user-facing contract, even though the orchestration still lives in +`WorkflowCallCoordinator`. + +### 4b. 
Batch Child Workflows + +The current implementation now supports direct batch-special child workflows for: + +- `image_batch` +- `string_batch` +- `integer_batch` +- `float_batch` + +It also supports generator-backed batch child workflows when those batch nodes are fed directly by: + +- `integer_generator` +- `float_generator` +- `string_generator` +- `image_generator` using `image_generator_images_from_board` + +Current semantics: + +- batch-special nodes are removed from the executable child graph before ordinary graph validation +- supported generator nodes that feed those batch-special nodes are removed from the executable child graph as well +- their outgoing edges are converted into queue batch substitutions +- ungrouped batch nodes expand as a cartesian product +- grouped batch nodes zip by `batch_group_id` +- the workflow call creates one child queue row per expanded batch session +- supported generator value shapes are resolved into concrete batch items before queue batch expansion +- batch outputs may feed a named `workflow_return_value.value` directly; each expanded child returns one value for that + key +- parent resume waits for all child rows tied to that workflow call +- parent return aggregation produces `values: dict[str, list[Any]]`, where each key maps to one value per child row +- all child rows in one batch call must return the same key set; mismatched keys fail the parent call clearly +- if any child row fails, remaining sibling child rows are canceled and the parent call fails +- generator-backed image batches must respect board access: + - the caller may expand images from a board they own + - admins may expand any board + - shared/public boards may be expanded by other users + - inaccessible private boards must fail before image expansion rather than leaking board contents across users + +Current generator coverage: + +- integer generators: + - arithmetic sequence + - linear distribution + - parse string + - seeded uniform random 
distribution +- float generators: + - arithmetic sequence + - linear distribution + - parse string + - seeded uniform random distribution +- string generators: + - parse string + - dynamic prompts combinatorial + - dynamic prompts random +- image generators: + - images from board + +Still unsupported: + +- connected batch inputs whose batch values are produced by non-generator upstream nodes + +Plain-English summary: + +1. The parent workflow reaches `call_saved_workflow`. +1. The parent pauses and enters `waiting`. +1. The child workflow is inspected before execution. +1. If the child contains supported batch inputs, that one call expands into multiple child executions instead of one. +1. Each expanded child execution becomes its own queue row. +1. Each child queue row keeps the substituted batch `field_values`, matching ordinary batch queue rows. +1. Those child queue rows run independently. +1. The parent does not resume until all child queue rows for that call have finished. +1. Each child execution produces its own named `workflow_return.values` map. +1. The parent aggregates those maps into `values: dict[str, list[Any]]`. +1. The `call_saved_workflow` node completes with that named values map, and the parent workflow continues. + +Expansion rules: + +- ungrouped batch inputs expand as a cartesian product +- batch inputs that share the same `batch_group_id` zip together by position + +Example: + +- ungrouped inputs `[1, 2]` and `[10, 20]` produce 4 child executions: + - `(1, 10)` + - `(1, 20)` + - `(2, 10)` + - `(2, 20)` +- grouped inputs `[1, 2, 3]` and `[10, 20, 30]` with the same `batch_group_id` produce 3 child executions: + - `(1, 10)` + - `(2, 20)` + - `(3, 30)` + +### 4c. Tricky Areas + +The following parts of the runtime contract are easy to misread and should stay explicit in both code and tests. 
+ +Waiting and resume: + +- a parent queue row in `waiting` is suspended, not completed +- a parent resumes only after every child queue row tied to that workflow call has reached a terminal state + +Return aggregation: + +- each child queue row returns its own named `workflow_return.values` +- for batched calls, the parent call node output is `values: dict[str, list[Any]]` +- all child rows in one batched call must return the same key set so each returned list is row-aligned +- if one key should contain multiple images for a non-batch child, the child must collect those images into a single + list value before returning that key + +Sibling failure behavior: + +- if one child queue row in a batched workflow call fails, remaining sibling child rows for that same workflow call are + canceled +- if parent return aggregation rejects a completed child row, remaining sibling child rows for that same workflow call + are canceled +- after sibling cancelation, the parent call fails +- if that parent is itself a child of another workflow call, failure continues upward through the ancestor chain + +Cancel behavior: + +- canceling a waiting parent cancels descendant child rows +- canceling a child row cancels waiting ancestors +- cancelation should stay cancelation; it should not be rewritten into ordinary failure semantics +- startup recovery cancels any interrupted `in_progress` or `waiting` workflow-call chain, including pending + descendants, so a restart cannot leave a suspended parent waiting on a child row that will never report back + +Retry behavior: + +- retry is root-oriented +- child queue rows should not be directly retried from the UI +- backend retry of a child id should normalize to the root workflow call chain rather than create an isolated child-only + rerun + +### 5. Return Values + +Return values should be explicit. 
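The batched aggregation and key-set rules described in section 4c can be sketched as follows; `aggregate_child_returns` is a hypothetical helper name, not the real lifecycle code:

```python
from typing import Any


def aggregate_child_returns(child_maps: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Aggregate per-child named return maps into values: dict[str, list[Any]].

    All child rows in one batched call must return the same key set; a
    mismatch fails the parent call instead of silently dropping keys.
    """
    if not child_maps:
        return {}
    expected_keys = set(child_maps[0])
    for i, child_map in enumerate(child_maps):
        if set(child_map) != expected_keys:
            raise ValueError(
                f"child row {i} returned keys {sorted(child_map)}, "
                f"expected {sorted(expected_keys)}"
            )
    # Row-aligned: each key maps to one value per child row, in completion order.
    return {key: [m[key] for m in child_maps] for key in expected_keys}


assert aggregate_child_returns(
    [{"image": "img_1"}, {"image": "img_2"}]
) == {"image": ["img_1", "img_2"]}

try:
    aggregate_child_returns([{"image": "img_1"}, {"mask": "m_1"}])
except ValueError:
    pass
else:
    raise AssertionError("mismatched key sets must fail the parent call")
```

Because the key-set check runs before any list is built, a rejected child row leaves no partially aggregated output behind, which keeps the "cancel remaining siblings, then fail the parent" path unambiguous.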
+ +Recommended model: + +- introduce a workflow return node analogous in concept to Canvas Output +- the child workflow declares named return values through explicit key/value return members +- each return member has a stable string key and a connected value +- when the workflow is run independently, the return node has no caller-visible effect +- when the workflow is run via `call_saved_workflow`, the named return map becomes the return value of the call +- `call_saved_workflow` exposes that named return map to the parent workflow + +Only one workflow return node may exist per workflow. That rule should be enforced in both the frontend editor and in +Python validation/runtime code. + +Do not infer child outputs from arbitrary terminal nodes. That is too ambiguous and too brittle. + +Named return contract: + +- the called workflow builds return members with a dedicated key/value node +- `workflow_return` accepts one return member directly, or a collected list of return members when the workflow returns + multiple named values +- non-batch execution rejects duplicate return keys +- if a non-batch workflow needs to return multiple images under one key, the child workflow should collect those images + into one list value and return that list under the key +- the caller extracts a named return value with a companion caller-side extraction node +- missing keys should fail clearly unless the extraction node explicitly supports and receives a default value + +Batch return aggregation: + +- when a called workflow expands into multiple child queue rows, each child row produces its own named return map +- the parent aggregates those child maps as `dict[str, list[Any]]` +- each key maps to the list of values returned by completed child rows for that key +- child rows are still aggregated in child-completion order unless a later contract explicitly requires stable input + order +- duplicate keys within a single child return map are still invalid; repeated keys across batch 
children are the normal + aggregation path + +### 6. Error Propagation + +If child execution fails: + +- the call node fails +- the parent workflow fails unless a later design adds explicit error-handling semantics + +For the first implementation, failure propagation should be simple and strict. + +### 7. Access Control + +Runtime must enforce the same access rules used elsewhere for saved workflows. + +The caller may execute a child workflow only if it is allowed to access that saved workflow at runtime. + +This matters even if the parent workflow was authored in a context where the child was once visible. + +### 8. Recursion And Nesting + +Nested and recursive `call_saved_workflow` execution should be allowed, but bounded. + +Initial implementation should enforce: + +- nested workflow calls are allowed +- recursive workflow calls are allowed +- maximum workflow call depth is capped at 4 call frames +- the depth cap is enforced at runtime, based on the active call stack, not by static validation alone + +This allows legitimate recursive or conditionally terminating workflow structures while still preventing unbounded call +growth. + +## Where The Runtime Work Belongs + +The goal is to support externally submitted graphs, not only frontend-authored graphs. Therefore the authoritative +execution logic must live in Python. 
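The depth cap from section 8 is a concrete example of logic that must live in Python: static validation cannot see recursion that only materializes at runtime. A sketch of runtime enforcement from the active call stack, with hypothetical names:

```python
MAX_WORKFLOW_CALL_DEPTH = 4  # the cap described in section 8


class WorkflowCallDepthError(RuntimeError):
    """Hypothetical error type for exceeding the call-depth cap."""


def check_call_depth(active_call_stack: list[str], workflow_id: str) -> int:
    """Enforce the depth cap against the active runtime call stack.

    Static validation alone cannot catch this: a workflow may legitimately
    call itself, directly or through another workflow, and only the live
    stack shows how deep the chain actually is.
    """
    depth = len(active_call_stack) + 1  # the frame we are about to push
    if depth > MAX_WORKFLOW_CALL_DEPTH:
        chain = " -> ".join([*active_call_stack, workflow_id])
        raise WorkflowCallDepthError(
            f"workflow call depth {depth} exceeds cap {MAX_WORKFLOW_CALL_DEPTH}: {chain}"
        )
    return depth


# Direct recursion up to four frames is allowed...
assert check_call_depth(["wf-a", "wf-a", "wf-a"], "wf-a") == 4
# ...but a fifth frame is rejected at runtime.
try:
    check_call_depth(["wf-a"] * 4, "wf-a")
except WorkflowCallDepthError:
    pass
else:
    raise AssertionError("expected depth cap to trigger")
```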
+ +Recommended high-level design: + +- a backend `GraphExpander` or broader graph-preparation service may still exist as an abstraction point +- but for this feature, the preferred long-term runtime model is not full graph expansion +- instead, the runtime needs a call-execution mechanism in the Python execution stack + +Relevant existing path: + +- frontend builds and submits a graph and workflow payload +- backend receives the batch via session queue APIs +- session queue stores session state +- runtime executes through `GraphExecutionState` + +Current insertion points already used: + +- `DefaultSessionRunner.run_node()` detects `call_saved_workflow` and enters boundary state +- `GraphExecutionState` stores the waiting/call-stack state and attached child session +- `WorkflowCallCoordinator` currently establishes the call boundary and enqueues child workflow executions as real queue + rows +- `WorkflowCallQueueLifecycle` currently resumes or fails parents when those child rows complete +- child queue items already carry stable parent/child identifiers in both runtime objects and durable queue columns + +Next runtime work still needed: + +- keep `WorkflowCallQueueLifecycle` as the bounded workflow-call lifecycle component for this PR + - the current workflow-call feature is the only caller of parent/child queue semantics + - replacing it with a generalized queue dependency scheduler now would add regression risk without unlocking a + concrete user workflow + - revisit only if another feature needs dependent queue items, richer retry/cancel policies, or resumable waits +- if support expands beyond the currently supported direct and generator-backed batch shapes, route those new child + workflow execution shapes through machinery that can honor ordinary Invoke batch semantics + +## Suggested Runtime Components + +### CallSavedWorkflowRuntime + +A dedicated runtime helper for this node type should be introduced. 
Responsibilities: + +- load and validate the selected child workflow record +- validate runtime access rights +- extract callable inputs from the child workflow definition +- build child execution arguments from the parent node state +- launch dependent execution +- collect declared returns +- map returned values back to the parent node outputs + +### Workflow Return Node + +A dedicated child-workflow return node should be introduced. Responsibilities: + +- define the return interface of the called workflow +- accept collected named key/value return members representing the workflow result +- provide that named values map back to the parent call site when invoked through `call_saved_workflow` +- remain inert from a caller perspective when the workflow is run independently +- guarantee that only one such node exists per workflow +- behave as a normal node in the editor, with singularity enforced by both frontend and Python validation/runtime code + +This should likely become the canonical reusable return mechanism for any future subworkflow call behavior. + +### Execution Relationship Tracking + +Session/runtime state will likely need to record: + +- parent execution waiting on child execution +- child execution belonging to a parent node call site +- result propagation back to the parent +- strict failure propagation rules + +### Workflow Return Value Flow + +The workflow return value should not be persisted back into the saved workflow record and should not be derived from +frontend state. + +The intended runtime flow is: + +1. The child workflow computes named return members like ordinary node outputs. +1. The child workflow connects one return member directly to `workflow_return.values`, or collects multiple return + members and connects that list to `workflow_return.values`. +1. When the child reaches `workflow_return`, runtime captures the resolved named return map as the child workflow + result. +1. 
The child workflow result is stored in child execution state. +1. That result is handed back to the suspended parent call frame. +1. The parent `call_saved_workflow` node is completed with that returned named value map. +1. The parent graph resumes. + +## Named Return Implementation Plan + +This is the next planned feature slice. Development should proceed test-first and keep documentation updated as each +stage lands. + +### Stage 1: Backend Return Contract + +Status: implemented in backend invocation tests. + +Goal: + +- establish the named return data model and invocation primitives + +Contract: + +- `WorkflowReturnValueField` stores one `key: str` and one `value: Any` +- `workflow_return_value` creates a single `WorkflowReturnValueField` from a key and connected value +- `workflow_return` accepts either one `WorkflowReturnValueField` member or a list of `WorkflowReturnValueField` members +- `WorkflowReturnOutput` exposes `values: dict[str, Any]` +- duplicate keys in one non-batch `workflow_return` execution are invalid and must fail clearly + +Tests first: + +- `workflow_return_value` emits the requested key/value pair +- `workflow_return` emits a named value map from one or more return members +- duplicate keys in one `workflow_return` execution are rejected +- empty returns are valid only if that remains an intentional callable-workflow contract + +### Stage 2: Caller-Side Extraction Primitive + +Status: implemented in backend invocation tests. 
+ +Goal: + +- let the calling workflow extract a named return value without relying on collection position + +Contract: + +- `workflow_return_get` accepts the named return map and a key +- `workflow_return_get` outputs the selected value as `Any` +- missing keys fail clearly unless a later version intentionally adds default-value support + +Tests first: + +- extracting an existing key returns the stored value +- extracting a missing key fails with a useful message +- extracted `Any` values can feed typed downstream nodes through the existing connection compatibility rules + +### Stage 3: Runtime Propagation + +Status: implemented in backend runtime tests. + +Goal: + +- carry named return maps through queue-visible child execution and parent resume + +Contract: + +- non-batch child execution returns `values: dict[str, Any]` +- `call_saved_workflow` exposes that map on its output +- failed child execution behavior is unchanged +- cancel/retry lifecycle behavior is unchanged + +Tests first: + +- a called workflow returning `{image: image_value}` completes the parent `call_saved_workflow` output with that key +- a caller-side extraction node can consume that output after parent resume +- missing or invalid `workflow_return` nodes still fail with the existing clear errors + +### Stage 4: Batch Return Aggregation + +Status: implemented in backend runtime tests. 
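The Stage 2 extraction contract above reduces to a strict named lookup. A minimal sketch, assuming a plain function `get_return_value` in place of the real `workflow_return_get` invocation; missing keys fail with a useful message rather than returning a default:

```python
from typing import Any


def get_return_value(values: dict[str, Any], key: str) -> Any:
    """Extract one named return value; fail loudly on empty or missing keys."""
    key = key.strip()
    if not key:
        raise ValueError("Workflow return key must not be empty.")
    if key not in values:
        # No default-value support yet: missing keys are a hard failure.
        raise ValueError(f"Workflow return key '{key}' was not found.")
    return values[key]
```

Because extraction is by name rather than by collection position, the caller stays correct even if the child workflow reorders its return members.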
+ +Goal: + +- define named returns for child workflows that expand into multiple queue rows + +Contract: + +- each child queue row produces one named return map +- the parent aggregates child maps as `dict[str, list[Any]]` +- each key maps to values returned by completed child rows for that key +- all child rows in one batch call must return the same key set +- repeated keys across child rows are expected +- duplicate keys within one child row remain invalid +- if a non-batch workflow wants multiple images under one key, it must collect those images into a single list value + before returning that key + +Tests first: + +- a batched child returning `{image: image_value}` from each child row produces `{image: [image_1, image_2, ...]}` +- sibling failure still cancels remaining siblings and fails the parent +- duplicate keys inside one child row are rejected rather than silently aggregated + +### Stage 5: Frontend Schema, UI, And Docs + +Status: partially implemented. Schema/type generation includes the backend nodes and fields; editor-specific UX cleanup +is still pending. 
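The Stage 4 aggregation contract above can be sketched as a pure function. `aggregate_batch_returns` is a hypothetical helper, not part of the branch; it shows the intended `dict[str, list[Any]]` shape and the same-key-set rule across child rows:

```python
from typing import Any


def aggregate_batch_returns(row_maps: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Aggregate per-row named return maps into dict[str, list[Any]].

    Every completed child row must return the same key set; repeated keys
    across rows are expected, and each key collects its row values in order.
    (Duplicate keys within one row cannot occur in a dict and are rejected
    upstream at return-map construction time.)
    """
    if not row_maps:
        return {}
    expected_keys = set(row_maps[0])
    aggregated: dict[str, list[Any]] = {key: [] for key in row_maps[0]}
    for row in row_maps:
        if set(row) != expected_keys:
            raise ValueError("All child rows in one batch call must return the same key set.")
        for key, value in row.items():
            aggregated[key].append(value)
    return aggregated
```

For example, two child rows each returning `{"image": ...}` aggregate to `{"image": [image_1, image_2]}`, which is the shape the parent `call_saved_workflow` output exposes for batch calls.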
+ +Goal: + +- make named returns usable and visible in the editor + +Contract: + +- generated schema/types include the new return field, return-value node, and extraction node +- visible UI strings are localized through `en.json` +- `call_saved_workflow` exposes the named return map output +- users can wire that output to `workflow_return_get` + +Tests first: + +- frontend connection/type tests cover return-value collection wiring +- frontend connection/type tests cover wiring one `workflow_return_value.value` directly to `workflow_return.values` +- frontend connection/type tests cover `call_saved_workflow.values -> workflow_return_get.values` +- docs describe how a called workflow creates named returns and how a caller extracts them + +## Frontend Responsibilities In The Long-Term Design + +The frontend remains responsible for editor-time behavior: + +- choosing the saved workflow +- redrawing dynamic inputs based on the child workflow callable interface +- persisting those dynamic fields and their values +- preserving compatible inbound edges when workflow selection changes +- clearing incompatible edges and invalid selections in a predictable way +- using backend compatibility metadata so unsupported saved workflows are not presented as callable choices + - compatibility analysis now tolerates required exposed caller inputs by synthesizing placeholder values for those + inputs during backend compatibility evaluation, so workflows that are valid once the caller supplies exposed values + are not disabled prematurely + +Potential future optimization: + +- add a backend endpoint that returns a normalized callable workflow interface +- this would let the frontend avoid re-parsing full saved workflow payloads to redraw the node +- it would also give the frontend a backend-authoritative interface hash for drift detection + +## Tests Needed Going Forward + +Already covered: + +- workflow-call stack and waiting state on `GraphExecutionState` +- depth-limit enforcement +- 
waiting blocks scheduling +- parent sessions are not completed while waiting +- runner boundary entry for `call_saved_workflow` +- validation failures and depth-limit failures still follow normal node-error behavior +- child workflow JSON conversion to backend `Graph` +- child graph build failure does not leave the parent in a partial waiting state +- child `GraphExecutionState` is attached to the waiting parent session +- coordinator-owned child execution completes the parent queue item instead of leaving it stuck in `in_progress` +- literal and connected dynamic call arguments are applied to the child graph at runtime +- non-exposed dynamic call arguments are rejected at runtime +- child `workflow_return` output is captured and becomes the parent `call_saved_workflow` output +- named `workflow_return` values can be constructed, propagated to the parent, extracted by key, and batch-aggregated as + `dict[str, list[Any]]` +- child workflows without a `workflow_return` node fail cleanly when called +- child execution events now include stable workflow-call relationship metadata on the child `SessionQueueItem` +- parent-child resume and failure propagation through queue-visible child rows +- nested runtime execution with bounded stack depth +- direct and generator-backed batch-special child workflows through queue child-row expansion +- compatibility metadata for required exposed inputs, missing/multiple returns, supported named batch-return shapes, and + unsupported batch input wiring + +Still needed in later increments: + +- focused coverage for any newly supported batch or generator shape when its contract changes +- possible migration from dedicated workflow-call queue lifecycle handling to a more general scheduler or + queue-lifecycle model only if another feature needs reusable dependent queue items + +## Recommended Immediate Next Step + +The next incremental step should be: + +- stop adding feature slices unless they close a concrete correctness gap or unlock 
a realistic user workflow +- stabilize the current branch with review, targeted test runs, and cleanup of stale design-doc language +- treat migration from `WorkflowCallQueueLifecycle` to a generalized parent/child queue lifecycle as a larger + architecture slice, not as small follow-on busywork + +The current branch is at the point where: + +- parent call-boundary state exists +- child execution state can be created from the selected saved workflow +- child execution, argument forwarding, explicit return propagation, suspended parent status, queue-visible child rows, + and upward failure cascade work through the current coordinator + queue path +- but long-term generalized parent/child scheduling semantics are still missing diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 41a5a411c7a..d63d28444dc 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -72,6 +72,11 @@ def sanitize_queue_item_for_user( sanitized_item.priority = 0 sanitized_item.field_values = None sanitized_item.retried_from_item_id = None + sanitized_item.workflow_call_id = None + sanitized_item.parent_item_id = None + sanitized_item.parent_session_id = None + sanitized_item.root_item_id = None + sanitized_item.workflow_call_depth = None sanitized_item.workflow = None sanitized_item.error_type = None sanitized_item.error_message = None @@ -312,20 +317,27 @@ async def retry_items_by_id( ) -> RetryItemsResult: """Retries the given queue items. 
Users can only retry their own items unless they are an admin.""" try: - # Check authorization: user must own all items or be an admin - if not current_user.is_admin: - for item_id in item_ids: - try: - queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id) - if queue_item.user_id != current_user.user_id: - raise HTTPException( - status_code=403, detail=f"You do not have permission to retry queue item {item_id}" - ) - except SessionQueueItemNotFoundError: - # Skip items that don't exist - they will be handled by retry_items_by_id - continue + # Check queue membership for all items and ownership for non-admins. + valid_item_ids: list[int] = [] + for item_id in item_ids: + try: + queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id) + if queue_item.queue_id != queue_id: + raise HTTPException( + status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}" + ) + if not current_user.is_admin and queue_item.user_id != current_user.user_id: + raise HTTPException( + status_code=403, detail=f"You do not have permission to retry queue item {item_id}" + ) + valid_item_ids.append(item_id) + except SessionQueueItemNotFoundError: + # Skip items that don't exist - they will be handled by retry_items_by_id + continue - return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids) + return ApiDependencies.invoker.services.session_queue.retry_items_by_id( + queue_id=queue_id, item_ids=valid_item_ids + ) except HTTPException: raise except Exception as e: @@ -507,6 +519,8 @@ async def delete_queue_item( try: # Get the queue item to check ownership queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id) + if queue_item.queue_id != queue_id: + raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}") # Check authorization: user must own the item or be an admin if queue_item.user_id 
!= current_user.user_id and not current_user.is_admin: @@ -537,6 +551,8 @@ async def cancel_queue_item( try: # Get the queue item to check ownership queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id) + if queue_item.queue_id != queue_id: + raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}") # Check authorization: user must own the item or be an admin if queue_item.user_id != current_user.user_id and not current_user.is_admin: diff --git a/invokeai/app/api/routers/workflows.py b/invokeai/app/api/routers/workflows.py index eb893251953..7daf93fb5ce 100644 --- a/invokeai/app/api/routers/workflows.py +++ b/invokeai/app/api/routers/workflows.py @@ -10,6 +10,7 @@ from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.shared.workflow_call_compatibility import get_workflow_call_compatibility from invokeai.app.services.workflow_records.workflow_records_common import ( Workflow, WorkflowCategory, @@ -51,7 +52,18 @@ async def get_workflow( raise HTTPException(status_code=403, detail="Not authorized to access this workflow") thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id) - return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump()) + compatibility = get_workflow_call_compatibility( + workflow=workflow.workflow.model_dump(), + workflow_id=workflow.workflow_id, + services=ApiDependencies.invoker.services, + user_id=current_user.user_id, + maximum_children=ApiDependencies.invoker.services.configuration.max_queue_size, + ) + return WorkflowRecordWithThumbnailDTO( + thumbnail_url=thumbnail_url, + call_saved_workflow_compatibility=compatibility, + **workflow.model_dump(), + ) @workflows_router.patch( @@ -66,17 +78,24 @@ async def update_workflow( 
workflow: Workflow = Body(description="The updated workflow", embed=True), ) -> WorkflowRecordDTO: """Updates a workflow""" + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration if config.multiuser: - try: - existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id) - except WorkflowNotFoundError: - raise HTTPException(status_code=404, detail="Workflow not found") if not current_user.is_admin and existing.user_id != current_user.user_id: raise HTTPException(status_code=403, detail="Not authorized to update this workflow") - # Pass user_id for defense-in-depth SQL scoping; admins pass None to allow any. user_id = None if current_user.is_admin else current_user.user_id - return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id) + updated = ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id) + ApiDependencies.invoker.services.events.emit_workflow_updated( + workflow_id=updated.workflow_id, + user_id=updated.user_id, + old_is_public=existing.is_public, + new_is_public=updated.is_public, + ) + return updated @workflows_router.delete( @@ -88,12 +107,13 @@ async def delete_workflow( workflow_id: str = Path(description="The workflow to delete"), ) -> None: """Deletes a workflow""" + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration if config.multiuser: - try: - existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) - except WorkflowNotFoundError: - raise HTTPException(status_code=404, detail="Workflow not found") if not current_user.is_admin and existing.user_id != 
current_user.user_id: raise HTTPException(status_code=403, detail="Not authorized to delete this workflow") try: @@ -103,6 +123,11 @@ async def delete_workflow( pass user_id = None if current_user.is_admin else current_user.user_id ApiDependencies.invoker.services.workflow_records.delete(workflow_id, user_id=user_id) + ApiDependencies.invoker.services.events.emit_workflow_deleted( + workflow_id=existing.workflow_id, + user_id=existing.user_id, + is_public=existing.is_public, + ) @workflows_router.post( @@ -121,9 +146,15 @@ async def create_workflow( # workflows remain visible. In multiuser mode, workflows are private to the creator by default. config = ApiDependencies.invoker.services.configuration is_public = not config.multiuser - return ApiDependencies.invoker.services.workflow_records.create( + created = ApiDependencies.invoker.services.workflow_records.create( workflow=workflow, user_id=current_user.user_id, is_public=is_public ) + ApiDependencies.invoker.services.events.emit_workflow_created( + workflow_id=created.workflow_id, + user_id=created.user_id, + is_public=created.is_public, + ) + return created @workflows_router.get( @@ -171,16 +202,31 @@ async def list_workflows( user_id=user_id_filter, is_public=is_public, ) + skipped_missing_workflows = 0 for workflow in workflows.items: + try: + full_workflow = ApiDependencies.invoker.services.workflow_records.get(workflow.workflow_id) + except WorkflowNotFoundError: + skipped_missing_workflows += 1 + continue + compatibility = get_workflow_call_compatibility( + workflow=full_workflow.workflow.model_dump(), + workflow_id=full_workflow.workflow_id, + services=ApiDependencies.invoker.services, + user_id=current_user.user_id, + maximum_children=ApiDependencies.invoker.services.configuration.max_queue_size, + resolve_generator_items=False, + ) workflows_with_thumbnails.append( WorkflowRecordListItemWithThumbnailDTO( thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id), + 
call_saved_workflow_compatibility=compatibility, **workflow.model_dump(), ) ) return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]( items=workflows_with_thumbnails, - total=workflows.total, + total=max(len(workflows_with_thumbnails), workflows.total - skipped_missing_workflows), page=workflows.page, pages=workflows.pages, per_page=workflows.per_page, @@ -312,9 +358,16 @@ async def update_workflow_is_public( raise HTTPException(status_code=403, detail="Not authorized to update this workflow") user_id = None if current_user.is_admin else current_user.user_id - return ApiDependencies.invoker.services.workflow_records.update_is_public( + updated = ApiDependencies.invoker.services.workflow_records.update_is_public( workflow_id=workflow_id, is_public=is_public, user_id=user_id ) + ApiDependencies.invoker.services.events.emit_workflow_updated( + workflow_id=updated.workflow_id, + user_id=updated.user_id, + old_is_public=existing.is_public, + new_is_public=updated.is_public, + ) + return updated @workflows_router.get("/tags", operation_id="get_all_tags") diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 5783b804c0b..c281cc62ab1 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -35,8 +35,13 @@ ModelLoadStartedEvent, QueueClearedEvent, QueueEventBase, + QueueItemsRetriedEvent, QueueItemStatusChangedEvent, RecallParametersUpdatedEvent, + WorkflowCreatedEvent, + WorkflowDeletedEvent, + WorkflowEventBase, + WorkflowUpdatedEvent, register_events, ) from invokeai.backend.util.logging import InvokeAILogger @@ -65,6 +70,7 @@ class BulkDownloadSubscriptionEvent(BaseModel): InvocationErrorEvent, QueueItemStatusChangedEvent, BatchEnqueuedEvent, + QueueItemsRetriedEvent, QueueClearedEvent, RecallParametersUpdatedEvent, } @@ -86,6 +92,7 @@ class BulkDownloadSubscriptionEvent(BaseModel): } BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent} +WORKFLOW_EVENTS = 
{WorkflowCreatedEvent, WorkflowUpdatedEvent, WorkflowDeletedEvent} class SocketIO: @@ -115,6 +122,7 @@ def __init__(self, app: FastAPI): register_events(QUEUE_EVENTS, self._handle_queue_event) register_events(MODEL_EVENTS, self._handle_model_event) register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event) + register_events(WORKFLOW_EVENTS, self._handle_workflow_event) async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> bool: """Handle socket connection and authenticate the user. @@ -167,6 +175,10 @@ async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> b logger.info( f"Socket {sid} connected with user_id: {token_data.user_id}, is_admin: {token_data.is_admin}" ) + await self._sio.enter_room(sid, f"user:{token_data.user_id}") + await self._sio.enter_room(sid, "workflows:shared") + if token_data.is_admin: + await self._sio.enter_room(sid, "admin") return True # No valid token provided. In multiuser mode this is not allowed — reject @@ -183,6 +195,9 @@ async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> b "is_admin": True, } logger.debug(f"Socket {sid} connected as system admin (single-user mode)") + await self._sio.enter_room(sid, "user:system") + await self._sio.enter_room(sid, "workflows:shared") + await self._sio.enter_room(sid, "admin") return True @staticmethod @@ -329,6 +344,26 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room") + # QueueItemsRetriedEvent carries queue item ids that should only be visible + # to the affected owners + admins. 
+ elif isinstance(event_data, QueueItemsRetriedEvent): + for user_id in event_data.user_ids: + user_room = f"user:{user_id}" + owner_event_data = event_data.model_copy( + update={ + "retried_item_ids": event_data.retried_item_ids_by_user.get(user_id, []), + "user_ids": [user_id], + "retried_item_ids_by_user": {user_id: event_data.retried_item_ids_by_user.get(user_id, [])}, + } + ) + await self._sio.emit( + event=event_name, data=owner_event_data.model_dump(mode="json"), room=user_room + ) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + logger.debug( + f"Emitted private queue_items_retried event to user rooms {event_data.user_ids} and admin room" + ) + else: # For remaining queue events (e.g. QueueClearedEvent) that do not # carry user identity, emit to all subscribers in the queue room. @@ -360,3 +395,32 @@ async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownlo await self._sio.emit( event=event_name, data=event_data.model_dump(mode="json"), room=event_data.bulk_download_id ) + + async def _handle_workflow_event(self, event: FastAPIEvent[WorkflowEventBase]) -> None: + event_name, event_data = event + payload = event_data.model_dump(mode="json") + + if not self._is_multiuser_enabled(): + await self._sio.emit(event=event_name, data=payload, room="admin") + return + + await self._sio.emit(event=event_name, data=payload, room=f"user:{event_data.user_id}") + await self._sio.emit(event=event_name, data=payload, room="admin") + + if event_name == "workflow_created": + if getattr(event_data, "is_public", False): + await self._sio.emit(event=event_name, data=payload, room="workflows:shared") + return + + if event_name == "workflow_deleted": + if getattr(event_data, "is_public", False): + await self._sio.emit(event=event_name, data=payload, room="workflows:shared") + return + + if event_name == "workflow_updated": + if getattr(event_data, "new_is_public", False): + await 
self._sio.emit(event=event_name, data=payload, room="workflows:shared") + elif getattr(event_data, "old_is_public", False): + await self._sio.emit( + event="workflow_deleted", data={"workflow_id": event_data.workflow_id}, room="workflows:shared" + ) diff --git a/invokeai/app/invocations/call_saved_workflow.py b/invokeai/app/invocations/call_saved_workflow.py new file mode 100644 index 00000000000..36f03021b1c --- /dev/null +++ b/invokeai/app/invocations/call_saved_workflow.py @@ -0,0 +1,75 @@ +from typing import Any + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import InputField, UIType +from invokeai.app.invocations.workflow_return import WorkflowReturnOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowCategory, WorkflowNotFoundError + +CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX = "saved_workflow_input::" + + +def is_call_saved_workflow_dynamic_input(field_name: str) -> bool: + return field_name.startswith(CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX) + + +def parse_call_saved_workflow_dynamic_input(field_name: str) -> tuple[str, str]: + if not is_call_saved_workflow_dynamic_input(field_name): + raise ValueError(f"'{field_name}' is not a call_saved_workflow dynamic input field") + + raw_identifier = field_name.removeprefix(CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX) + node_id, separator, input_field_name = raw_identifier.rpartition("::") + if not separator or not node_id or not input_field_name: + raise ValueError(f"Invalid call_saved_workflow dynamic input field '{field_name}'") + + return node_id, input_field_name + + +@invocation( + "call_saved_workflow", + title="Call Saved Workflow", + tags=["workflow", "saved", "library"], + category="workflow", + version="1.0.0", + use_cache=False, + classification=Classification.Beta, +) +class 
CallSavedWorkflowInvocation(BaseInvocation): + """Displays and later executes against a selected saved workflow.""" + + workflow_id: str = InputField( + default="", + description="The selected saved workflow ID, managed by the workflow editor UI.", + ui_type=UIType.SavedWorkflow, + ) + workflow_inputs: dict[str, Any] = InputField( + default={}, + description="Literal values for the selected workflow's exposed inputs, managed by the workflow editor UI.", + ui_hidden=True, + ) + + def validate_selected_workflow(self, context: InvocationContext): + if not self.workflow_id: + raise ValueError("A saved workflow must be selected before executing call_saved_workflow.") + + try: + workflow_record = context._services.workflow_records.get(self.workflow_id) + except WorkflowNotFoundError as e: + raise ValueError(f"The selected saved workflow '{self.workflow_id}' could not be found.") from e + + config = context._services.configuration + if config.multiuser: + queue_user_id = context._data.queue_item.user_id + user = context._services.users.get(queue_user_id) + is_admin = bool(user and user.is_admin) + is_owner = workflow_record.user_id == queue_user_id + is_default = workflow_record.workflow.meta.category is WorkflowCategory.Default + if not (is_default or is_owner or workflow_record.is_public or is_admin): + raise ValueError(f"The selected saved workflow '{self.workflow_id}' is not accessible to this user.") + + return workflow_record + + def invoke(self, context: InvocationContext) -> WorkflowReturnOutput: + self.validate_selected_workflow(context) + + return WorkflowReturnOutput(values={}) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index e53aeb417b2..4418c86371a 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -49,6 +49,7 @@ class UIType(str, Enum, metaclass=MetaEnum): # region Misc Field Types Scheduler = "SchedulerField" Any = "AnyField" + SavedWorkflow = "SavedWorkflowField" # endregion 
# region Internal Field Types diff --git a/invokeai/app/invocations/workflow_return.py b/invokeai/app/invocations/workflow_return.py new file mode 100644 index 00000000000..9f517a26482 --- /dev/null +++ b/invokeai/app/invocations/workflow_return.py @@ -0,0 +1,139 @@ +from typing import Any + +from pydantic import BaseModel, Field + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + Classification, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType +from invokeai.app.services.shared.invocation_context import InvocationContext + + +@invocation_output("workflow_return_output") +class WorkflowReturnOutput(BaseInvocationOutput): + """The explicit named values returned from a callable workflow.""" + + values: dict[str, Any] = OutputField( + default={}, + description="The workflow return values, keyed by return name.", + title="Values", + ui_type=UIType.Any, + ) + + +class WorkflowReturnValueField(BaseModel): + """One named workflow return value.""" + + key: str = Field(description="The workflow return key.") + value: Any = Field(default=None, description="The workflow return value.") + + +@invocation_output("workflow_return_value_output") +class WorkflowReturnValueOutput(BaseInvocationOutput): + """A named workflow return value.""" + + value: WorkflowReturnValueField = OutputField( + description="The named workflow return value.", + title="Return Value", + ui_type=UIType._CollectionItem, + ) + + +@invocation( + "workflow_return_value", + title="Workflow Return Value", + tags=["workflow", "return", "output"], + category="workflow", + version="1.0.0", + classification=Classification.Beta, + use_cache=False, +) +class WorkflowReturnValueInvocation(BaseInvocation): + """Creates one named value for a callable workflow return.""" + + key: str = InputField(default="", description="The return key.", title="Key") + value: Any = InputField( + default=None, + 
description="The value returned under this key.", + title="Value", + ui_type=UIType.Any, + ) + + def invoke(self, context: InvocationContext) -> WorkflowReturnValueOutput: + key = self.key.strip() + if not key: + raise ValueError("Workflow return key must not be empty.") + return WorkflowReturnValueOutput(value=WorkflowReturnValueField(key=key, value=self.value)) + + +@invocation( + "workflow_return", + title="Workflow Return", + tags=["workflow", "return", "output"], + category="workflow", + version="1.0.0", + classification=Classification.Beta, + use_cache=False, +) +class WorkflowReturnInvocation(BaseInvocation): + """Defines the explicit named result returned by a callable workflow.""" + + values: WorkflowReturnValueField | list[WorkflowReturnValueField] = InputField( + default=[], + description="The named values returned to a calling workflow.", + title="Values", + input=Input.Connection, + ) + + def invoke(self, context: InvocationContext) -> WorkflowReturnOutput: + named_values: dict[str, Any] = {} + return_values = self.values if isinstance(self.values, list) else [self.values] + for value in return_values: + key = value.key.strip() + if not key: + raise ValueError("Workflow return key must not be empty.") + if key in named_values: + raise ValueError(f"Duplicate workflow return key '{key}'.") + named_values[key] = value.value + return WorkflowReturnOutput(values=named_values) + + +@invocation_output("workflow_return_get_output") +class WorkflowReturnGetOutput(BaseInvocationOutput): + """A value extracted from named workflow return values.""" + + value: Any = OutputField(description="The extracted workflow return value.", title="Value", ui_type=UIType.Any) + + +@invocation( + "workflow_return_get", + title="Get Workflow Return Value", + tags=["workflow", "return", "input"], + category="workflow", + version="1.0.0", + classification=Classification.Beta, + use_cache=False, +) +class WorkflowReturnGetInvocation(BaseInvocation): + """Extracts one named value 
from a callable workflow return.""" + + values: dict[str, Any] = InputField( + default={}, + description="The named workflow return values.", + title="Values", + ui_type=UIType.Any, + input=Input.Connection, + ) + key: str = InputField(default="", description="The return key to extract.", title="Key") + + def invoke(self, context: InvocationContext) -> WorkflowReturnGetOutput: + key = self.key.strip() + if not key: + raise ValueError("Workflow return key must not be empty.") + if key not in self.values: + raise ValueError(f"Workflow return key '{key}' was not found.") + return WorkflowReturnGetOutput(value=self.values[key]) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 935b422a732..0e1f71c2bc7 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -32,6 +32,9 @@ QueueItemsRetriedEvent, QueueItemStatusChangedEvent, RecallParametersUpdatedEvent, + WorkflowCreatedEvent, + WorkflowDeletedEvent, + WorkflowUpdatedEvent, ) if TYPE_CHECKING: @@ -104,9 +107,11 @@ def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", user_id: str """Emitted when a batch is enqueued""" self.dispatch(BatchEnqueuedEvent.build(enqueue_result, user_id)) - def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None: + def emit_queue_items_retried( + self, retry_result: "RetryItemsResult", user_ids: list[str], retried_item_ids_by_user: dict[str, list[int]] + ) -> None: """Emitted when a list of queue items are retried""" - self.dispatch(QueueItemsRetriedEvent.build(retry_result)) + self.dispatch(QueueItemsRetriedEvent.build(retry_result, user_ids, retried_item_ids_by_user)) def emit_queue_cleared(self, queue_id: str) -> None: """Emitted when a queue is cleared""" @@ -118,6 +123,29 @@ def emit_recall_parameters_updated(self, queue_id: str, user_id: str, parameters # endregion + # region Workflow library + + def emit_workflow_created(self, workflow_id: 
str, user_id: str, is_public: bool) -> None: + """Emitted when a workflow is created.""" + self.dispatch(WorkflowCreatedEvent.build(workflow_id=workflow_id, user_id=user_id, is_public=is_public)) + + def emit_workflow_updated(self, workflow_id: str, user_id: str, old_is_public: bool, new_is_public: bool) -> None: + """Emitted when a workflow is updated.""" + self.dispatch( + WorkflowUpdatedEvent.build( + workflow_id=workflow_id, + user_id=user_id, + old_is_public=old_is_public, + new_is_public=new_is_public, + ) + ) + + def emit_workflow_deleted(self, workflow_id: str, user_id: str, is_public: bool) -> None: + """Emitted when a workflow is deleted.""" + self.dispatch(WorkflowDeletedEvent.build(workflow_id=workflow_id, user_id=user_id, is_public=is_public)) + + # endregion + # region Download def emit_download_started(self, job: "DownloadJob") -> None: diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 0c530f9a2f7..8297f1f42ca 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -308,12 +308,20 @@ class QueueItemsRetriedEvent(QueueEventBase): __event_name__ = "queue_items_retried" retried_item_ids: list[int] = Field(description="The IDs of the queue items that were retried") + user_ids: list[str] = Field(description="The IDs of the users who own the retried root queue items") + retried_item_ids_by_user: dict[str, list[int]] = Field( + description="The retried root queue item IDs keyed by owner user ID." 
+ ) @classmethod - def build(cls, retry_result: RetryItemsResult) -> "QueueItemsRetriedEvent": + def build( + cls, retry_result: RetryItemsResult, user_ids: list[str], retried_item_ids_by_user: dict[str, list[int]] + ) -> "QueueItemsRetriedEvent": return cls( queue_id=retry_result.queue_id, retried_item_ids=retry_result.retried_item_ids, + user_ids=user_ids, + retried_item_ids_by_user=retried_item_ids_by_user, ) @@ -328,6 +336,58 @@ def build(cls, queue_id: str) -> "QueueClearedEvent": return cls(queue_id=queue_id) +class WorkflowEventBase(EventBase): + """Base class for workflow library CRUD events.""" + + workflow_id: str = Field(description="The ID of the workflow") + user_id: str = Field(description="The owner of the workflow") + + +@payload_schema.register +class WorkflowCreatedEvent(WorkflowEventBase): + """Event model for workflow_created""" + + __event_name__ = "workflow_created" + + is_public: bool = Field(description="Whether the workflow is shared with all users") + + @classmethod + def build(cls, workflow_id: str, user_id: str, is_public: bool) -> "WorkflowCreatedEvent": + return cls(workflow_id=workflow_id, user_id=user_id, is_public=is_public) + + +@payload_schema.register +class WorkflowUpdatedEvent(WorkflowEventBase): + """Event model for workflow_updated""" + + __event_name__ = "workflow_updated" + + old_is_public: bool = Field(description="Whether the workflow was shared before the update") + new_is_public: bool = Field(description="Whether the workflow is shared after the update") + + @classmethod + def build(cls, workflow_id: str, user_id: str, old_is_public: bool, new_is_public: bool) -> "WorkflowUpdatedEvent": + return cls( + workflow_id=workflow_id, + user_id=user_id, + old_is_public=old_is_public, + new_is_public=new_is_public, + ) + + +@payload_schema.register +class WorkflowDeletedEvent(WorkflowEventBase): + """Event model for workflow_deleted""" + + __event_name__ = "workflow_deleted" + + is_public: bool = Field(description="Whether the 
workflow was shared when it was deleted") + + @classmethod + def build(cls, workflow_id: str, user_id: str, is_public: bool) -> "WorkflowDeletedEvent": + return cls(workflow_id=workflow_id, user_id=user_id, is_public=is_public) + + class DownloadEventBase(EventBase): """Base class for events associated with a download""" diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 7159c19e746..bbfa8d4f40b 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -6,6 +6,7 @@ from typing import Optional from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.invocations.call_saved_workflow import CallSavedWorkflowInvocation from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, FastAPIEvent, @@ -27,6 +28,10 @@ SessionRunnerBase, ) from invokeai.app.services.session_processor.session_processor_common import CanceledException, SessionProcessorStatus +from invokeai.app.services.session_processor.workflow_call_runtime import ( + WorkflowCallCoordinator, + WorkflowCallQueueLifecycle, +) from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context @@ -58,6 +63,8 @@ def __init__( self._on_after_run_node_callbacks = on_after_run_node_callbacks or [] self._on_node_error_callbacks = on_node_error_callbacks or [] self._on_after_run_session_callbacks = on_after_run_session_callbacks or [] + self.workflow_call_coordinator = WorkflowCallCoordinator(self) + self.workflow_call_queue_lifecycle = WorkflowCallQueueLifecycle(self) def start(self, services: InvocationServices, cancel_event: 
ThreadEvent, profiler: Optional[Profiler] = None): self._services = services @@ -69,11 +76,7 @@ def _is_canceled(self) -> bool: denoising to check if the session has been canceled.""" return self._cancel_event.is_set() - def run(self, queue_item: SessionQueueItem): - # Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here. - - self._on_before_run_session(queue_item=queue_item) - + def _run_session_loop(self, queue_item: SessionQueueItem) -> None: # Loop over invocations until the session is complete or canceled while True: try: @@ -107,6 +110,11 @@ def run(self, queue_item: SessionQueueItem): ): break + def run(self, queue_item: SessionQueueItem): + # Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here. + + self._on_before_run_session(queue_item=queue_item) + self._run_session_loop(queue_item) self._on_after_run_session(queue_item=queue_item) def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): @@ -126,6 +134,11 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): is_canceled=self._is_canceled, ) + if isinstance(invocation, CallSavedWorkflowInvocation): + workflow_record = invocation.validate_selected_workflow(context) + self.workflow_call_coordinator.begin_workflow_call_boundary(invocation, queue_item, workflow_record) + return + # Invoke the node output = invocation.invoke_internal(context=context, services=self._services) # Save output and history @@ -201,7 +214,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # The queue item may have been canceled or failed while the session was running. We should only complete it # if it is not already canceled or failed. 
- if queue_item.status not in ["canceled", "failed"]: + if queue_item.status not in ["canceled", "failed"] and queue_item.session.is_complete(): queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id) # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor @@ -316,6 +329,7 @@ def __init__( super().__init__() self.session_runner = session_runner if session_runner else DefaultSessionRunner() + self.workflow_call_queue_lifecycle = self.session_runner.workflow_call_queue_lifecycle self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] self._thread_limit = thread_limit self._polling_interval = polling_interval @@ -458,7 +472,7 @@ def _process( cancel_event.clear() # Run the graph - self.session_runner.run(queue_item=self._queue_item) + self.workflow_call_queue_lifecycle.run_queue_item(self._queue_item) except Exception as e: error_type = e.__class__.__name__ diff --git a/invokeai/app/services/session_processor/workflow_call_batch.py b/invokeai/app/services/session_processor/workflow_call_batch.py new file mode 100644 index 00000000000..a37b9d5100b --- /dev/null +++ b/invokeai/app/services/session_processor/workflow_call_batch.py @@ -0,0 +1,701 @@ +from __future__ import annotations + +import copy +import json +import random +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import Any + +from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator + +from invokeai.app.invocations.fields import ImageField +from invokeai.app.services.board_records.board_records_common import BoardRecordOrderBy, BoardVisibility +from invokeai.app.services.image_records.image_records_common import ASSETS_CATEGORIES, IMAGE_CATEGORIES +from invokeai.app.services.session_queue.session_queue_common import ( + Batch, + BatchDatum, + NodeFieldValue, + TooManySessionsError, + calc_session_count, + 
create_session_nfv_tuples, +) +from invokeai.app.services.shared.graph import GraphExecutionState, WorkflowCallFrame +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.shared.workflow_graph_builder import ( + UnsupportedWorkflowNodeError, + apply_workflow_inputs_to_workflow, + build_graph_from_workflow, +) + +BATCH_FIELD_NAMES = { + "image_batch": "images", + "string_batch": "strings", + "integer_batch": "integers", + "float_batch": "floats", +} +SUPPORTED_BATCH_TYPES = set(BATCH_FIELD_NAMES) +SUPPORTED_BATCH_GROUP_IDS = { + "None", + "Group 1", + "Group 2", + "Group 3", + "Group 4", + "Group 5", +} +CONNECTOR_INPUT_HANDLE = "in" + + +@dataclass(frozen=True) +class WorkflowCallChildSessionResult: + session: GraphExecutionState + field_values: list[NodeFieldValue] | None = None + + +def _is_mapping(value: Any) -> bool: + return isinstance(value, Mapping) + + +def _is_invocation_node(node: Any) -> bool: + return _is_mapping(node) and node.get("type") == "invocation" and _is_mapping(node.get("data")) + + +def _is_connector_node(node: Any) -> bool: + return _is_mapping(node) and node.get("type") == "connector" + + +def workflow_contains_supported_batch_nodes(workflow: Mapping[str, Any]) -> bool: + workflow_nodes = workflow.get("nodes", []) + if not isinstance(workflow_nodes, Sequence): + return False + return any( + _is_invocation_node(node) and node["data"].get("type") in SUPPORTED_BATCH_TYPES for node in workflow_nodes + ) + + +def _get_workflow_nodes(workflow: Mapping[str, Any]) -> dict[str, Mapping[str, Any]]: + workflow_nodes = workflow.get("nodes", []) + if not isinstance(workflow_nodes, Sequence): + return {} + return {node["id"]: node for node in workflow_nodes if _is_mapping(node) and isinstance(node.get("id"), str)} + + +def _get_default_edges(workflow: Mapping[str, Any]) -> list[Mapping[str, Any]]: + workflow_edges = workflow.get("edges", []) + if not isinstance(workflow_edges, Sequence): + return [] + 
return [edge for edge in workflow_edges if _is_mapping(edge) and edge.get("type") == "default"] + + +def _get_connector_input_edge( + connector_id: str, workflow_edges: Sequence[Mapping[str, Any]] +) -> Mapping[str, Any] | None: + return next( + ( + edge + for edge in workflow_edges + if edge.get("target") == connector_id and edge.get("targetHandle") == CONNECTOR_INPUT_HANDLE + ), + None, + ) + + +def _resolve_connector_source( + connector_id: str, workflow_nodes: Mapping[str, Mapping[str, Any]], workflow_edges: Sequence[Mapping[str, Any]] +) -> tuple[str, str] | None: + visited: set[str] = set() + + def resolve(node_id: str) -> tuple[str, str] | None: + if node_id in visited: + return None + visited.add(node_id) + + incoming_edge = _get_connector_input_edge(node_id, workflow_edges) + if incoming_edge is None: + return None + + source_id = incoming_edge.get("source") + source_handle = incoming_edge.get("sourceHandle") + if not isinstance(source_id, str) or not isinstance(source_handle, str): + return None + + source_node = workflow_nodes.get(source_id) + if source_node is None: + return None + + if _is_invocation_node(source_node): + return (source_id, source_handle) + + if _is_connector_node(source_node): + return resolve(source_id) + + return None + + return resolve(connector_id) + + +def _build_child_graph_workflow(workflow: Mapping[str, Any], used_generator_node_ids: set[str]) -> dict[str, Any]: + workflow_nodes = workflow.get("nodes", []) + workflow_edges = workflow.get("edges", []) + if not isinstance(workflow_nodes, list) or not isinstance(workflow_edges, list): + raise UnsupportedWorkflowNodeError("call_saved_workflow child workflow is malformed") + + filtered_nodes = [ + node + for node in workflow_nodes + if not ( + _is_invocation_node(node) + and ( + node["data"].get("type") in SUPPORTED_BATCH_TYPES + or (isinstance(node.get("id"), str) and node["id"] in used_generator_node_ids) + ) + ) + ] + filtered_node_ids = {node["id"] for node in filtered_nodes if 
_is_mapping(node) and isinstance(node.get("id"), str)} + filtered_edges = [ + edge + for edge in workflow_edges + if _is_mapping(edge) + and edge.get("type") == "default" + and edge.get("source") in filtered_node_ids + and edge.get("target") in filtered_node_ids + ] + return {**workflow, "nodes": filtered_nodes, "edges": filtered_edges} + + +def _reject_unrelated_generator_nodes(workflow: Mapping[str, Any], used_generator_node_ids: set[str]) -> None: + workflow_nodes = workflow.get("nodes", []) + if not isinstance(workflow_nodes, list): + raise UnsupportedWorkflowNodeError("call_saved_workflow child workflow is malformed") + + unrelated_generator_nodes: list[tuple[str, str]] = [] + for node in workflow_nodes: + if not _is_invocation_node(node): + continue + + node_data = node["data"] + node_id = node_data.get("id") + node_type = node_data.get("type") + if not isinstance(node_id, str) or not isinstance(node_type, str): + continue + if node_type.endswith("_generator") and node_id not in used_generator_node_ids: + unrelated_generator_nodes.append((node_type, node_id)) + + if unrelated_generator_nodes: + unsupported_nodes = ", ".join( + f"'{node_type}' (node '{node_id}')" for node_type, node_id in unrelated_generator_nodes + ) + raise UnsupportedWorkflowNodeError( + "call_saved_workflow does not yet support child workflows that mix supported batch nodes with " + f"unrelated generator nodes: {unsupported_nodes}" + ) + + +def _get_batch_group_id(node_data: Mapping[str, Any]) -> str: + inputs = node_data.get("inputs") + if not _is_mapping(inputs): + return "None" + batch_group_input = inputs.get("batch_group_id") + if not _is_mapping(batch_group_input): + return "None" + batch_group_id = batch_group_input.get("value") + if not isinstance(batch_group_id, str): + return "None" + if batch_group_id not in SUPPORTED_BATCH_GROUP_IDS: + raise UnsupportedWorkflowNodeError(f"Unsupported batch group id '{batch_group_id}' in called workflow") + return batch_group_id + + +def 
_get_batch_items(node_data: Mapping[str, Any], field_name: str) -> list[Any]: + inputs = node_data.get("inputs") + if not _is_mapping(inputs): + raise UnsupportedWorkflowNodeError("call_saved_workflow batch child workflow node inputs are malformed") + batch_input = inputs.get(field_name) + if not _is_mapping(batch_input): + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow batch child workflow node is missing required '{field_name}' input" + ) + batch_items = batch_input.get("value") + if not isinstance(batch_items, list): + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow batch child workflow node '{node_data.get('id')}' must provide a direct list for '{field_name}'" + ) + return batch_items + + +def _parse_split_values(input_value: str, split_on: str) -> list[str]: + if split_on == "": + return [input_value] + try: + return input_value.split(json.loads(f'"{split_on}"')) + except Exception: + return input_value.split(split_on) + + +def _resolve_float_generator(value: Mapping[str, Any]) -> list[float]: + generator_type = value.get("type") + if generator_type == "float_generator_arithmetic_sequence": + start = float(value.get("start", 0)) + step = float(value.get("step", 1)) + count = int(value.get("count", 10)) + if step == 0: + return [start] + return [start + i * step for i in range(count)] + if generator_type == "float_generator_linear_distribution": + start = float(value.get("start", 0)) + end = float(value.get("end", 1)) + count = int(value.get("count", 10)) + if count == 1: + return [start] + return [start + (end - start) * (i / (count - 1)) for i in range(count)] + if generator_type == "float_generator_random_distribution_uniform": + minimum = float(value.get("min", 0)) + maximum = float(value.get("max", 1)) + count = int(value.get("count", 10)) + if "values" in value and isinstance(value["values"], list): + return [float(v) for v in value["values"]] + rng = random.Random(value.get("seed")) + return [rng.random() * (maximum - minimum) + 
minimum for _ in range(count)] + if generator_type == "float_generator_parse_string": + if "values" in value and isinstance(value["values"], list): + return [float(v) for v in value["values"]] + split_values = _parse_split_values(str(value.get("input", "")), str(value.get("splitOn", ","))) + return [float(v.strip()) for v in split_values if v.strip()] + raise UnsupportedWorkflowNodeError(f"Unsupported float generator type '{generator_type}'") + + +def _resolve_integer_generator(value: Mapping[str, Any]) -> list[int]: + generator_type = value.get("type") + if generator_type == "integer_generator_arithmetic_sequence": + start = int(value.get("start", 0)) + step = int(value.get("step", 1)) + count = int(value.get("count", 10)) + if step == 0: + return [start] + return [start + i * step for i in range(count)] + if generator_type == "integer_generator_linear_distribution": + start = int(value.get("start", 0)) + end = int(value.get("end", 10)) + count = int(value.get("count", 10)) + if count == 1: + return [start] + return [start + round((end - start) * (i / (count - 1))) for i in range(count)] + if generator_type == "integer_generator_random_distribution_uniform": + minimum = int(value.get("min", 0)) + maximum = int(value.get("max", 10)) + count = int(value.get("count", 10)) + rng = random.Random(value.get("seed")) + return [int(rng.random() * (maximum - minimum + 1)) + minimum for _ in range(count)] + if generator_type == "integer_generator_parse_string": + split_values = _parse_split_values(str(value.get("input", "")), str(value.get("splitOn", ","))) + return [int(v.strip()) for v in split_values if v.strip()] + raise UnsupportedWorkflowNodeError(f"Unsupported integer generator type '{generator_type}'") + + +def _resolve_string_generator(value: Mapping[str, Any]) -> list[str]: + generator_type = value.get("type") + if generator_type == "string_generator_parse_string": + return [v for v in _parse_split_values(str(value.get("input", "")), str(value.get("splitOn", ","))) 
if v] + if generator_type == "string_generator_dynamic_prompts_combinatorial": + generator = CombinatorialPromptGenerator() + return list(generator.generate(str(value.get("input", "")), max_prompts=int(value.get("maxPrompts", 10)))) + if generator_type == "string_generator_dynamic_prompts_random": + seed = value.get("seed") + if seed is None: + seed = random.randint(0, 2**31 - 1) + generator = RandomPromptGenerator(seed=int(seed)) + return list(generator.generate(str(value.get("input", "")), num_images=int(value.get("count", 10)))) + raise UnsupportedWorkflowNodeError(f"Unsupported string generator type '{generator_type}'") + + +def _assert_user_can_access_board(board_id: str, services: Any, user_id: str | None) -> None: + if not user_id: + return + + board_records = getattr(services, "board_records", None) + if board_records is None or not hasattr(board_records, "get"): + return + + users = getattr(services, "users", None) + user = users.get(user_id) if users is not None and hasattr(users, "get") else None + is_admin = bool(user and getattr(user, "is_admin", False)) + if is_admin: + return + + try: + board_record = board_records.get(board_id) + except Exception as e: + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow could not access board '{board_id}' for image generator expansion" + ) from e + + if getattr(board_record, "user_id", None) == user_id: + return + + board_visibility = getattr(board_record, "board_visibility", BoardVisibility.Private) + if isinstance(board_visibility, str): + try: + board_visibility = BoardVisibility(board_visibility) + except ValueError: + board_visibility = BoardVisibility.Private + if board_visibility in {BoardVisibility.Shared, BoardVisibility.Public}: + return + + if hasattr(board_records, "get_all"): + try: + accessible_boards = board_records.get_all( + user_id=user_id, + is_admin=False, + order_by=BoardRecordOrderBy.Name, + direction=SQLiteDirection.Ascending, + include_archived=True, + ) + except Exception: + 
accessible_boards = [] + if any(getattr(board, "board_id", None) == board_id for board in accessible_boards): + return + + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow caller does not have access to board '{board_id}' for image generator expansion" + ) + + +def _resolve_image_generator(value: Mapping[str, Any], services: Any, user_id: str | None) -> list[ImageField]: + generator_type = value.get("type") + if generator_type != "image_generator_images_from_board": + raise UnsupportedWorkflowNodeError(f"Unsupported image generator type '{generator_type}'") + board_id = value.get("board_id") + if not isinstance(board_id, str) or not board_id: + return [] + _assert_user_can_access_board(board_id, services, user_id) + category = value.get("category", "images") + categories = IMAGE_CATEGORIES if category == "images" else ASSETS_CATEGORIES + image_names = services.board_images.get_all_board_image_names_for_board( + board_id=board_id, + categories=categories, + is_intermediate=False, + ) + return [ImageField(image_name=image_name) for image_name in image_names] + + +def _resolve_generator_items(generator_node: Mapping[str, Any], services: Any, user_id: str | None) -> list[Any]: + generator_node_data = generator_node["data"] + node_type = generator_node_data.get("type") + inputs = generator_node_data.get("inputs") + if not isinstance(node_type, str) or not _is_mapping(inputs): + raise UnsupportedWorkflowNodeError("call_saved_workflow generator node is malformed") + generator_input = inputs.get("generator") + if not _is_mapping(generator_input): + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow generator node '{generator_node_data.get('id')}' is missing generator input" + ) + generator_value = generator_input.get("value") + if not _is_mapping(generator_value): + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow generator node '{generator_node_data.get('id')}' has invalid generator value" + ) + if node_type == "integer_generator": + return 
_resolve_integer_generator(generator_value) + if node_type == "float_generator": + return _resolve_float_generator(generator_value) + if node_type == "string_generator": + return _resolve_string_generator(generator_value) + if node_type == "image_generator": + return _resolve_image_generator(generator_value, services, user_id) + raise UnsupportedWorkflowNodeError(f"Unsupported generator node type '{node_type}'") + + +def _get_generator_placeholder_items(generator_node: Mapping[str, Any]) -> list[Any]: + generator_node_data = generator_node["data"] + node_type = generator_node_data.get("type") + if node_type == "integer_generator": + return [0] + if node_type == "float_generator": + return [0.0] + if node_type == "string_generator": + return [""] + if node_type == "image_generator": + return [ImageField(image_name="compatibility-placeholder")] + raise UnsupportedWorkflowNodeError(f"Unsupported generator node type '{node_type}'") + + +def _get_outgoing_default_edges( + node_id: str, source_handle: str, workflow_edges: Sequence[Mapping[str, Any]] +) -> list[Mapping[str, Any]]: + return [ + edge + for edge in workflow_edges + if edge.get("source") == node_id and edge.get("sourceHandle") == source_handle and edge.get("type") == "default" + ] + + +def _resolve_connector_destinations( + connector_id: str, workflow_nodes: Mapping[str, Mapping[str, Any]], workflow_edges: Sequence[Mapping[str, Any]] +) -> list[tuple[str, str]]: + visited: set[str] = set() + destinations: list[tuple[str, str]] = [] + stack = [connector_id] + while stack: + current_id = stack.pop() + if current_id in visited: + continue + visited.add(current_id) + outgoing_edges = _get_outgoing_default_edges(current_id, "out", workflow_edges) + for edge in outgoing_edges: + target_id = edge.get("target") + target_handle = edge.get("targetHandle") + if not isinstance(target_id, str) or not isinstance(target_handle, str): + continue + target_node = workflow_nodes.get(target_id) + if target_node is None: + 
continue + if _is_invocation_node(target_node): + destinations.append((target_id, target_handle)) + elif _is_connector_node(target_node): + stack.append(target_id) + return destinations + + +def _resolve_batch_destinations( + node_id: str, + source_handle: str, + workflow_nodes: Mapping[str, Mapping[str, Any]], + workflow_edges: Sequence[Mapping[str, Any]], +) -> list[tuple[str, str]]: + destinations: list[tuple[str, str]] = [] + for edge in _get_outgoing_default_edges(node_id, source_handle, workflow_edges): + target_id = edge.get("target") + target_handle = edge.get("targetHandle") + if not isinstance(target_id, str) or not isinstance(target_handle, str): + continue + target_node = workflow_nodes.get(target_id) + if target_node is None: + continue + if _is_invocation_node(target_node): + destinations.append((target_id, target_handle)) + elif _is_connector_node(target_node): + destinations.extend(_resolve_connector_destinations(target_id, workflow_nodes, workflow_edges)) + return destinations + + +def _normalize_batch_item_for_destination(destination_field: str, batch_items: list[Any]) -> list[Any]: + if destination_field == "collection": + return [[item] for item in batch_items] + return batch_items + + +def _resolve_batch_items_from_inputs( + node_id: str, + field_name: str, + workflow_edges: Sequence[Mapping[str, Any]], + workflow_nodes: Mapping[str, Mapping[str, Any]], +) -> str | None: + # Returns the ID of the generator node backing this batch input, or None if the input is not connected. + incoming_edges = [ + edge + for edge in workflow_edges + if edge.get("target") == node_id and edge.get("targetHandle") == field_name and edge.get("type") == "default" + ] + if not incoming_edges: + return None + incoming_source_ids = [edge.get("source") for edge in incoming_edges if isinstance(edge.get("source"), str)] + if len(incoming_source_ids) != 1: + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow does not yet support multiple connected batch inputs on node '{node_id}'" + ) + source_id = incoming_source_ids[0] + source_node = 
workflow_nodes.get(source_id) + if _is_invocation_node(source_node) and source_node["data"].get("type", "").endswith("_generator"): + return source_id + if _is_connector_node(source_node): + resolved_source = _resolve_connector_source(source_id, workflow_nodes, workflow_edges) + if resolved_source is not None: + resolved_source_id, _resolved_source_handle = resolved_source + resolved_source_node = workflow_nodes.get(resolved_source_id) + if _is_invocation_node(resolved_source_node) and resolved_source_node["data"].get("type", "").endswith( + "_generator" + ): + return resolved_source_id + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow does not yet support connected batch child workflow inputs on node '{node_id}'" + ) + + +def build_batch_child_workflow_session_results( + *, + parent_session: GraphExecutionState, + workflow: Mapping[str, Any], + workflow_inputs: Mapping[str, Any], + call_frame: WorkflowCallFrame, + maximum_children: int, + services: Any = None, + user_id: str | None = None, + resolve_generator_items: bool = True, +) -> list[WorkflowCallChildSessionResult]: + mutable_workflow = copy.deepcopy(workflow) + apply_workflow_inputs_to_workflow(mutable_workflow, workflow_inputs) + + workflow_nodes = _get_workflow_nodes(mutable_workflow) + workflow_edges = _get_default_edges(mutable_workflow) + + batch_data_by_group: dict[str, list[BatchDatum]] = {} + used_generator_node_ids: set[str] = set() + for node in workflow_nodes.values(): + if not _is_invocation_node(node): + continue + node_data = node["data"] + node_id = node_data.get("id") + node_type = node_data.get("type") + if not isinstance(node_id, str) or not isinstance(node_type, str): + continue + if node_type.endswith("_generator"): + continue + if node_type not in SUPPORTED_BATCH_TYPES: + continue + + field_name = BATCH_FIELD_NAMES[node_type] + generator_source_id = _resolve_batch_items_from_inputs(node_id, field_name, workflow_edges, workflow_nodes) + if generator_source_id is not None: + 
generator_node = workflow_nodes.get(generator_source_id) + if generator_node is None: + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow generator-backed batch child workflow is missing generator node '{generator_source_id}'" + ) + generator_node_type = generator_node["data"].get("type") if _is_invocation_node(generator_node) else None + if generator_node_type == "image_generator" and services is None and resolve_generator_items: + raise UnsupportedWorkflowNodeError( + "call_saved_workflow image-generator-backed batch child workflows require runtime services" + ) + batch_items = ( + _resolve_generator_items(generator_node, services, user_id) + if resolve_generator_items + else _get_generator_placeholder_items(generator_node) + ) + used_generator_node_ids.add(generator_source_id) + if not batch_items: + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow generator-backed batch child workflow node '{generator_source_id}' produced no batch items" + ) + else: + batch_items = _get_batch_items(node_data, field_name) + if not batch_items: + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow batch child workflow node '{node_id}' must provide at least one batch item" + ) + batch_group_id = _get_batch_group_id(node_data) + destinations = _resolve_batch_destinations(node_id, field_name, workflow_nodes, workflow_edges) + if not destinations: + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow batch child workflow node '{node_id}' is not connected to any invocation input" + ) + group_batch_data = batch_data_by_group.setdefault(batch_group_id, []) + for destination_node_id, destination_field in destinations: + group_batch_data.append( + BatchDatum( + node_path=destination_node_id, + field_name=destination_field, + items=_normalize_batch_item_for_destination(destination_field, batch_items), + ) + ) + + if not batch_data_by_group: + raise UnsupportedWorkflowNodeError("call_saved_workflow batch child workflow contains no supported batch nodes") 
+ + _reject_unrelated_generator_nodes(mutable_workflow, used_generator_node_ids) + sanitized_workflow = _build_child_graph_workflow(mutable_workflow, used_generator_node_ids) + child_graph = build_graph_from_workflow(sanitized_workflow) + batch_data = [[datum] for datum in batch_data_by_group.pop("None", [])] + batch_data.extend(batch_data_by_group.values()) + batch = Batch(graph=child_graph, data=batch_data) + if calc_session_count(batch) > maximum_children: + raise TooManySessionsError("call_saved_workflow exceeds remaining queue capacity for child workflow executions") + + child_session_results: list[WorkflowCallChildSessionResult] = [] + for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, maximum_children): + generated_session = GraphExecutionState.model_validate_json(session_json) + child_session = parent_session.create_child_workflow_execution_state(generated_session.graph, call_frame) + child_session.id = session_id + field_values = [NodeFieldValue.model_validate(field_value) for field_value in json.loads(field_values_json)] + child_session_results.append(WorkflowCallChildSessionResult(session=child_session, field_values=field_values)) + return child_session_results + + +def build_batch_child_workflow_sessions( + *, + parent_session: GraphExecutionState, + workflow: Mapping[str, Any], + workflow_inputs: Mapping[str, Any], + call_frame: WorkflowCallFrame, + maximum_children: int, + services: Any = None, + user_id: str | None = None, + resolve_generator_items: bool = True, +) -> list[GraphExecutionState]: + return [ + child_result.session + for child_result in build_batch_child_workflow_session_results( + parent_session=parent_session, + workflow=workflow, + workflow_inputs=workflow_inputs, + call_frame=call_frame, + maximum_children=maximum_children, + services=services, + user_id=user_id, + resolve_generator_items=resolve_generator_items, + ) + ] + + +def build_child_workflow_session_results( + *, + parent_session: 
GraphExecutionState, + workflow: Mapping[str, Any], + workflow_inputs: Mapping[str, Any], + call_frame: WorkflowCallFrame, + maximum_children: int, + services: Any = None, + user_id: str | None = None, + resolve_generator_items: bool = True, +) -> list[WorkflowCallChildSessionResult]: + if workflow_contains_supported_batch_nodes(workflow): + return build_batch_child_workflow_session_results( + parent_session=parent_session, + workflow=workflow, + workflow_inputs=workflow_inputs, + call_frame=call_frame, + maximum_children=maximum_children, + services=services, + user_id=user_id, + resolve_generator_items=resolve_generator_items, + ) + + mutable_workflow = copy.deepcopy(workflow) + apply_workflow_inputs_to_workflow(mutable_workflow, workflow_inputs) + child_graph = build_graph_from_workflow(mutable_workflow) + child_session = parent_session.create_child_workflow_execution_state(child_graph, call_frame) + return [WorkflowCallChildSessionResult(session=child_session)] + + +def build_child_workflow_sessions( + *, + parent_session: GraphExecutionState, + workflow: Mapping[str, Any], + workflow_inputs: Mapping[str, Any], + call_frame: WorkflowCallFrame, + maximum_children: int, + services: Any = None, + user_id: str | None = None, + resolve_generator_items: bool = True, +) -> list[GraphExecutionState]: + return [ + child_result.session + for child_result in build_child_workflow_session_results( + parent_session=parent_session, + workflow=workflow, + workflow_inputs=workflow_inputs, + call_frame=call_frame, + maximum_children=maximum_children, + services=services, + user_id=user_id, + resolve_generator_items=resolve_generator_items, + ) + ] diff --git a/invokeai/app/services/session_processor/workflow_call_runtime.py b/invokeai/app/services/session_processor/workflow_call_runtime.py new file mode 100644 index 00000000000..f7ff3a729de --- /dev/null +++ b/invokeai/app/services/session_processor/workflow_call_runtime.py @@ -0,0 +1,265 @@ +from __future__ import annotations + 
+from typing import TYPE_CHECKING, Any + +from invokeai.app.invocations.call_saved_workflow import ( + CallSavedWorkflowInvocation, + is_call_saved_workflow_dynamic_input, +) +from invokeai.app.invocations.workflow_return import WorkflowReturnOutput +from invokeai.app.services.session_processor.workflow_call_batch import build_child_workflow_session_results +from invokeai.app.services.session_queue.session_queue_common import NodeFieldValue, SessionQueueItem +from invokeai.app.services.shared.graph import GraphExecutionState + +if TYPE_CHECKING: + from invokeai.app.services.session_processor.session_processor_default import DefaultSessionRunner + + +class WorkflowCallCoordinator: + """Builds child workflow sessions and child queue items when a call_saved_workflow boundary is reached.""" + + def __init__(self, session_runner: DefaultSessionRunner) -> None: + self._session_runner = session_runner + + def _collect_call_saved_workflow_inputs( + self, invocation: CallSavedWorkflowInvocation, queue_item: SessionQueueItem + ) -> dict[str, Any]: + workflow_inputs = dict(invocation.workflow_inputs) + for edge in queue_item.session.execution_graph._get_input_edges(invocation.id): + if not is_call_saved_workflow_dynamic_input(edge.destination.field): + continue + if edge.source.node_id not in queue_item.session.results: + continue + workflow_inputs[edge.destination.field] = getattr( + queue_item.session.results[edge.source.node_id], edge.source.field + ) + return workflow_inputs + + @staticmethod + def build_child_queue_item( + queue_item: SessionQueueItem, + child_session: GraphExecutionState, + field_values: list[NodeFieldValue] | None = None, + ) -> SessionQueueItem: + workflow_call_execution = queue_item.session.waiting_workflow_call_execution + if workflow_call_execution is None: + raise ValueError("Parent queue item is missing active workflow call execution metadata.") + root_item_id = getattr(queue_item, "root_item_id", None) or queue_item.item_id + child_updates = { + "session": child_session, + "session_id": 
child_session.id, + "workflow_call_id": workflow_call_execution.id, + "parent_item_id": queue_item.item_id, + "parent_session_id": queue_item.session_id, + "root_item_id": root_item_id, + "workflow_call_depth": workflow_call_execution.depth, + "field_values": field_values, + } + if hasattr(queue_item, "model_copy"): + return queue_item.model_copy(update=child_updates) + + child_queue_item = type(queue_item).__new__(type(queue_item)) + child_queue_item.__dict__ = {**queue_item.__dict__, **child_updates} + return child_queue_item + + def begin_workflow_call_boundary( + self, + invocation: CallSavedWorkflowInvocation, + queue_item: SessionQueueItem, + workflow_record: Any, + ) -> SessionQueueItem: + queue_status = self._session_runner._services.session_queue.get_queue_status(queue_item.queue_id) + remaining_queue_capacity = self._session_runner._services.configuration.max_queue_size - queue_status.pending + if remaining_queue_capacity <= 0: + raise ValueError("call_saved_workflow cannot start: the queue has no remaining capacity for child workflow executions") + + call_frame = queue_item.session.build_workflow_call_frame(invocation.id, invocation.workflow_id) + workflow_inputs = self._collect_call_saved_workflow_inputs(invocation, queue_item) + child_session_results = build_child_workflow_session_results( + parent_session=queue_item.session, + workflow=workflow_record.workflow.model_dump(), + workflow_inputs=workflow_inputs, + call_frame=call_frame, + maximum_children=remaining_queue_capacity, + services=self._session_runner._services, + user_id=getattr(queue_item, "user_id", None), + ) + child_sessions = [child_result.session for child_result in child_session_results] + if len(child_sessions) > remaining_queue_capacity: + raise ValueError("call_saved_workflow produced more child workflow executions than the queue's remaining capacity") + queue_item.session.begin_waiting_on_workflow_call(call_frame) + queue_item.session.attach_waiting_workflow_call_child_sessions(child_sessions) + child_queue_item = 
None + enqueued_child_item_ids: list[int] = [] + try: + self._session_runner._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + for child_result in child_session_results: + child_queue_item = self._session_runner._services.session_queue.enqueue_workflow_call_child( + parent_queue_item=queue_item, + child_session=child_result.session, + field_values=child_result.field_values, + ) + enqueued_child_item_ids.append(child_queue_item.item_id) + self._session_runner._services.session_queue.suspend_queue_item(queue_item.item_id) + except Exception as e: + if enqueued_child_item_ids: + self._session_runner._services.session_queue.delete_queue_items_by_id(enqueued_child_item_ids) + queue_item.session.end_waiting_on_workflow_call(status="failed", error_message=str(e)) + raise + queue_item.status = "waiting" + if child_queue_item is None: + raise ValueError("Workflow call did not produce any child executions.") + return child_queue_item + + +class WorkflowCallQueueLifecycle: + """Coordinates queue-visible child workflow execution and parent lifecycle transitions.""" + + def __init__(self, session_runner: DefaultSessionRunner) -> None: + self._session_runner = session_runner + + @staticmethod + def get_waiting_workflow_call_invocation(queue_item: SessionQueueItem) -> CallSavedWorkflowInvocation: + waiting_frame = queue_item.session.waiting_workflow_call + if waiting_frame is None: + raise ValueError("Execution state is not waiting on a workflow call.") + invocation = queue_item.session.execution_graph.nodes.get(waiting_frame.prepared_call_node_id) + if not isinstance(invocation, CallSavedWorkflowInvocation): + raise ValueError("Waiting workflow call frame does not point to a call_saved_workflow invocation.") + return invocation + + @staticmethod + def get_child_workflow_return_output(child_session: GraphExecutionState) -> WorkflowReturnOutput: + workflow_return_node_ids = [ + node_id for node_id, node in child_session.graph.nodes.items() if 
node.get_type() == "workflow_return" + ] + if not workflow_return_node_ids: + raise ValueError("The selected saved workflow must contain exactly one workflow_return node.") + if len(workflow_return_node_ids) > 1: + raise ValueError("The selected saved workflow must not contain more than one workflow_return node.") + + workflow_return_node_id = workflow_return_node_ids[0] + prepared_return_node_ids = child_session.source_prepared_mapping.get(workflow_return_node_id, set()) + if len(prepared_return_node_ids) != 1: + raise ValueError( + "The selected saved workflow produced an unsupported number of workflow_return executions." + ) + + prepared_return_node_id = next(iter(prepared_return_node_ids)) + output = child_session.results.get(prepared_return_node_id) + if not isinstance(output, WorkflowReturnOutput): + raise ValueError("The selected saved workflow did not produce a valid workflow_return output.") + + return output + + def resume_waiting_workflow_call(self, queue_item: SessionQueueItem) -> None: + invocation = self.get_waiting_workflow_call_invocation(queue_item) + child_session = queue_item.session.waiting_workflow_call_child_session + if child_session is None: + raise ValueError("Execution state is waiting on a workflow call but has no attached child session.") + output = self.get_child_workflow_return_output(child_session) + queue_item.session.end_waiting_on_workflow_call(status="completed") + queue_item.session.complete(invocation.id, output) + self._session_runner._on_after_run_node(invocation, queue_item, output) + + def fail_waiting_workflow_call(self, queue_item: SessionQueueItem, error_message: str) -> None: + invocation = self.get_waiting_workflow_call_invocation(queue_item) + queue_item.session.end_waiting_on_workflow_call(status="failed", error_message=error_message) + self._session_runner._on_node_error( + invocation=invocation, + queue_item=queue_item, + error_type="ValueError", + error_message=error_message, + error_traceback=error_message, + ) + 
+ def _get_parent_queue_item(self, child_queue_item: SessionQueueItem) -> SessionQueueItem: + parent_item_id = child_queue_item.parent_item_id + if parent_item_id is None: + raise ValueError("Child workflow queue item is missing parent_item_id metadata.") + return self._session_runner._services.session_queue.get_queue_item(parent_item_id) + + def _resume_parent_from_completed_child(self, child_queue_item: SessionQueueItem) -> None: + parent_queue_item = self._get_parent_queue_item(child_queue_item) + if parent_queue_item.status in ("completed", "failed", "canceled"): + return + try: + output = self.get_child_workflow_return_output(child_queue_item.session) + should_resume_parent, aggregated_values = ( + parent_queue_item.session.record_waiting_workflow_call_child_completion( + child_queue_item.item_id, output.values + ) + ) + except Exception as e: + workflow_call_execution = parent_queue_item.session.waiting_workflow_call_execution + if workflow_call_execution is not None: + self._session_runner._services.session_queue.cancel_workflow_call_children( + workflow_call_execution.id, + exclude_item_ids={child_queue_item.item_id}, + ) + self.fail_waiting_workflow_call(parent_queue_item, str(e)) + parent_queue_item = self._session_runner._services.session_queue.get_queue_item(parent_queue_item.item_id) + if getattr(parent_queue_item, "parent_item_id", None) is not None: + self._fail_parent_from_failed_child(parent_queue_item) + return + if not should_resume_parent: + self._session_runner._services.session_queue.set_queue_item_session( + parent_queue_item.item_id, parent_queue_item.session + ) + return + parent_queue_item.session.waiting_workflow_call_child_session = child_queue_item.session + waiting_invocation = self.get_waiting_workflow_call_invocation(parent_queue_item) + parent_queue_item.session.end_waiting_on_workflow_call(status="completed") + parent_output = WorkflowReturnOutput(values=aggregated_values) + parent_queue_item.session.complete(waiting_invocation.id, 
parent_output) + self._session_runner._on_after_run_node(waiting_invocation, parent_queue_item, parent_output) + parent_queue_item = self._session_runner._services.session_queue.set_queue_item_session( + parent_queue_item.item_id, parent_queue_item.session + ) + if parent_queue_item.session.is_complete(): + parent_queue_item = self._session_runner._services.session_queue.complete_queue_item( + parent_queue_item.item_id + ) + if getattr(parent_queue_item, "parent_item_id", None) is not None: + self._resume_parent_from_completed_child(parent_queue_item) + return + self._session_runner._services.session_queue.resume_queue_item(parent_queue_item.item_id) + + def _fail_parent_from_failed_child(self, child_queue_item: SessionQueueItem) -> None: + parent_queue_item = self._get_parent_queue_item(child_queue_item) + if parent_queue_item.status in ("completed", "failed", "canceled"): + return + waiting_frame = parent_queue_item.session.waiting_workflow_call + if waiting_frame is None: + raise ValueError("Parent queue item is missing workflow call waiting state.") + workflow_call_execution = parent_queue_item.session.waiting_workflow_call_execution + if workflow_call_execution is not None: + self._session_runner._services.session_queue.cancel_workflow_call_children( + workflow_call_execution.id, + exclude_item_ids={child_queue_item.item_id}, + ) + child_error_message = getattr(child_queue_item, "error_message", None) or ( + f"The selected saved workflow '{waiting_frame.workflow_id}' failed during child execution." 
+ ) + self.fail_waiting_workflow_call(parent_queue_item, child_error_message) + parent_queue_item = self._session_runner._services.session_queue.get_queue_item(parent_queue_item.item_id) + if getattr(parent_queue_item, "parent_item_id", None) is not None: + self._fail_parent_from_failed_child(parent_queue_item) + + def _cancel_parent_from_canceled_child(self, child_queue_item: SessionQueueItem) -> None: + parent_queue_item = self._get_parent_queue_item(child_queue_item) + if parent_queue_item.status == "canceled": + return + self._session_runner._services.session_queue.cancel_queue_item(parent_queue_item.item_id) + + def run_queue_item(self, queue_item: SessionQueueItem) -> None: + self._session_runner.run(queue_item) + updated_queue_item = self._session_runner._services.session_queue.get_queue_item(queue_item.item_id) + if getattr(updated_queue_item, "parent_item_id", None) is None: + return + if updated_queue_item.status == "completed": + self._resume_parent_from_completed_child(updated_queue_item) + elif updated_queue_item.status == "failed": + self._fail_parent_from_failed_child(updated_queue_item) + elif updated_queue_item.status == "canceled": + self._cancel_parent_from_canceled_child(updated_queue_item) diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 73acf9c31aa..6e7b2c0caff 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -16,6 +16,7 @@ IsEmptyResult, IsFullResult, ItemIdsResult, + NodeFieldValue, PruneResult, RetryItemsResult, SessionQueueCountsByDestination, @@ -106,6 +107,16 @@ def complete_queue_item(self, item_id: int) -> SessionQueueItem: """Completes a session queue item""" pass + @abstractmethod + def suspend_queue_item(self, item_id: int) -> SessionQueueItem: + """Suspends a session queue item while waiting on a child workflow execution.""" + pass + + @abstractmethod + def 
resume_queue_item(self, item_id: int) -> SessionQueueItem: + """Resumes a suspended session queue item by returning it to pending state.""" + pass + @abstractmethod def cancel_queue_item(self, item_id: int) -> SessionQueueItem: """Cancels a session queue item""" @@ -116,6 +127,11 @@ def delete_queue_item(self, item_id: int) -> None: """Deletes a session queue item""" pass + @abstractmethod + def delete_queue_items_by_id(self, item_ids: list[int]) -> None: + """Deletes session queue items by ID.""" + pass + @abstractmethod def fail_queue_item( self, item_id: int, error_type: str, error_message: str, error_traceback: str @@ -201,6 +217,23 @@ def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> """Sets the session for a session queue item. Use this to update the session state.""" pass + @abstractmethod + def enqueue_workflow_call_child( + self, + parent_queue_item: SessionQueueItem, + child_session: GraphExecutionState, + field_values: list[NodeFieldValue] | None = None, + ) -> SessionQueueItem: + """Enqueues a child workflow execution linked to a suspended parent queue item.""" + pass + + @abstractmethod + def cancel_workflow_call_children( + self, workflow_call_id: str, exclude_item_ids: set[int] | None = None + ) -> list[int]: + """Cancels child workflow queue items for a workflow call without canceling the waiting parent chain.""" + pass + @abstractmethod def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult: """Retries the given queue items""" diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index d87221fbbae..37f60653046 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -51,7 +51,8 @@ class SessionQueueItemNotFoundError(ValueError): # region Batch -BatchDataType = Union[StrictStr, float, int, ImageField] +BatchScalarDataType = 
Union[StrictStr, float, int, ImageField] +BatchDataType = Union[BatchScalarDataType, list[BatchScalarDataType]] class NodeFieldValue(BaseModel): @@ -172,7 +173,7 @@ def validate_graph(cls, v: Graph): DEFAULT_QUEUE_ID = "default" SYSTEM_USER_ID = "system" # Default user_id for system-generated queue items -QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"] +QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "waiting", "completed", "failed", "canceled"] class ItemIdsResult(BaseModel): @@ -262,6 +263,21 @@ class SessionQueueItem(BaseModel): retried_from_item_id: Optional[int] = Field( default=None, description="The item_id of the queue item that this item was retried from" ) + workflow_call_id: Optional[str] = Field( + default=None, description="The active workflow-call relationship id when this queue item is a child execution." + ) + parent_item_id: Optional[int] = Field( + default=None, description="The parent queue item id when this queue item is a child workflow execution." + ) + parent_session_id: Optional[str] = Field( + default=None, description="The parent session id when this queue item is a child workflow execution." + ) + root_item_id: Optional[int] = Field( + default=None, description="The root queue item id for this workflow call chain, if any." + ) + workflow_call_depth: Optional[int] = Field( + default=None, description="The 1-based workflow-call depth for this queue item when it is a child execution." 
+ ) session: GraphExecutionState = Field(description="The fully-populated session to be executed") workflow: Optional[WorkflowWithoutID] = Field( default=None, description="The workflow associated with this queue item" @@ -305,6 +321,7 @@ class SessionQueueStatus(BaseModel): session_id: Optional[str] = Field(description="The current queue item's session id") pending: int = Field(..., description="Number of queue items with status 'pending'") in_progress: int = Field(..., description="Number of queue items with status 'in_progress'") + waiting: int = Field(..., description="Number of queue items with status 'waiting'") completed: int = Field(..., description="Number of queue items with status 'complete'") failed: int = Field(..., description="Number of queue items with status 'error'") canceled: int = Field(..., description="Number of queue items with status 'canceled'") @@ -316,6 +333,7 @@ class SessionQueueCountsByDestination(BaseModel): destination: str = Field(..., description="The destination of queue items included in this status") pending: int = Field(..., description="Number of queue items with status 'pending' for the destination") in_progress: int = Field(..., description="Number of queue items with status 'in_progress' for the destination") + waiting: int = Field(..., description="Number of queue items with status 'waiting' for the destination") completed: int = Field(..., description="Number of queue items with status 'complete' for the destination") failed: int = Field(..., description="Number of queue items with status 'error' for the destination") canceled: int = Field(..., description="Number of queue items with status 'canceled' for the destination") @@ -329,6 +347,7 @@ class BatchStatus(BaseModel): destination: str | None = Field(..., description="The destination of the batch") pending: int = Field(..., description="Number of queue items with status 'pending'") in_progress: int = Field(..., description="Number of queue items with status 
'in_progress'") + waiting: int = Field(..., description="Number of queue items with status 'waiting'") completed: int = Field(..., description="Number of queue items with status 'complete'") failed: int = Field(..., description="Number of queue items with status 'error'") canceled: int = Field(..., description="Number of queue items with status 'canceled'") diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index a05ed468857..d4f64de56ba 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -1,7 +1,7 @@ import asyncio import json import sqlite3 -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast from pydantic_core import to_jsonable_python @@ -23,6 +23,7 @@ IsEmptyResult, IsFullResult, ItemIdsResult, + NodeFieldValue, PruneResult, RetryItemsResult, SessionQueueCountsByDestination, @@ -65,17 +66,40 @@ def __init__(self, db: SqliteDatabase) -> None: def _set_in_progress_to_canceled(self) -> None: """ - Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue. - This is necessary because the invoker may have been killed while processing a queue item. + Sets all in_progress or waiting queue items to canceled. Run on app startup, not associated with any queue. + This is necessary because the invoker may have been killed while processing a queue item or while a parent + queue item was suspended waiting on a child workflow execution. 
""" with self._db.transaction() as cursor: cursor.execute( """--sql + SELECT item_id + FROM session_queue + WHERE status = 'in_progress' + OR status = 'waiting'; + """ + ) + interrupted_item_ids = [row[0] for row in cast(list[sqlite3.Row], cursor.fetchall())] + item_ids_to_cancel: set[int] = set() + for item_id in interrupted_item_ids: + item_ids_to_cancel.update(self._get_workflow_call_chain_item_ids(item_id)) + if not item_ids_to_cancel: + return + with self._db.transaction() as cursor: + placeholders = ",".join("?" for _ in item_ids_to_cancel) + cursor.execute( + f"""--sql UPDATE session_queue SET status = 'canceled', status_sequence = COALESCE(status_sequence, 0) + 1 - WHERE status = 'in_progress'; - """ + WHERE item_id IN ({placeholders}) + AND ( + status = 'pending' + OR status = 'in_progress' + OR status = 'waiting' + ); + """, + tuple(item_ids_to_cancel), ) def _prune_terminal_to_limit(self, queue_id: str, keep: int) -> int: @@ -327,6 +351,53 @@ def _set_queue_item_status( self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status) return queue_item + def _get_workflow_call_child_ids(self, item_id: int) -> list[int]: + with self._db.transaction() as cursor: + cursor.execute( + """--sql + SELECT item_id + FROM session_queue + WHERE parent_item_id = ? 
+ ORDER BY item_id ASC + """, + (item_id,), + ) + rows = cast(list[sqlite3.Row], cursor.fetchall()) + return [row[0] for row in rows] + + def _get_workflow_call_descendant_ids(self, item_id: int) -> list[int]: + descendant_ids: list[int] = [] + queue: list[int] = [item_id] + while queue: + current_item_id = queue.pop(0) + child_ids = self._get_workflow_call_child_ids(current_item_id) + descendant_ids.extend(child_ids) + queue.extend(child_ids) + return descendant_ids + + def _get_workflow_call_ancestor_ids(self, item_id: int) -> list[int]: + ancestor_ids: list[int] = [] + current_queue_item = self.get_queue_item(item_id) + while current_queue_item.parent_item_id is not None: + parent_item_id = current_queue_item.parent_item_id + ancestor_ids.append(parent_item_id) + current_queue_item = self.get_queue_item(parent_item_id) + return ancestor_ids + + def _get_workflow_call_chain_item_ids(self, item_id: int) -> list[int]: + ancestor_ids = self._get_workflow_call_ancestor_ids(item_id) + root_item_id = ancestor_ids[-1] if ancestor_ids else item_id + descendant_ids = self._get_workflow_call_descendant_ids(root_item_id) + chain_item_ids = ancestor_ids + [item_id] + descendant_ids + deduped_chain_item_ids = list(dict.fromkeys(chain_item_ids)) + return deduped_chain_item_ids + + def _get_current_workflow_call_chain_item_ids(self, queue_id: str) -> set[int]: + current_queue_item = self.get_current(queue_id) + if current_queue_item is None: + return set() + return set(self._get_workflow_call_chain_item_ids(current_queue_item.item_id)) + def is_empty(self, queue_id: str) -> IsEmptyResult: with self._db.transaction() as cursor: cursor.execute( @@ -384,6 +455,20 @@ def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult: self.__invoker.services.events.emit_queue_cleared(queue_id) return ClearResult(deleted=count) + def delete_queue_items_by_id(self, item_ids: list[int]) -> None: + if not item_ids: + return + placeholders = ", ".join(["?" 
for _ in item_ids]) + with self._db.transaction() as cursor: + cursor.execute( + f"""--sql + DELETE + FROM session_queue + WHERE item_id IN ({placeholders}) + """, + tuple(item_ids), + ) + def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult: with self._db.transaction() as cursor: # Build WHERE clause with optional user_id filter @@ -398,7 +483,7 @@ def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult: ) {user_filter} """ - params = [queue_id] + params: list[Any] = [queue_id] if user_id is not None: params.append(user_id) @@ -422,29 +507,33 @@ def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult: return PruneResult(deleted=count) def cancel_queue_item(self, item_id: int) -> SessionQueueItem: - queue_item = self._set_queue_item_status(item_id=item_id, status="canceled") - return queue_item + chain_item_ids = self._get_workflow_call_chain_item_ids(item_id) + for chain_item_id in chain_item_ids: + self._set_queue_item_status(item_id=chain_item_id, status="canceled") + return self.get_queue_item(item_id) def delete_queue_item(self, item_id: int) -> None: """Deletes a session queue item""" - try: + try: + chain_item_ids = self._get_workflow_call_chain_item_ids(item_id) + except SessionQueueItemNotFoundError: + # Preserve the previous tolerant behavior: deleting a missing item is a no-op. + return + if any( + self.get_queue_item(chain_item_id).status not in {"completed", "failed", "canceled"} + for chain_item_id in chain_item_ids + ): self.cancel_queue_item(item_id) - except SessionQueueItemNotFoundError: - pass - with self._db.transaction() as cursor: - cursor.execute( - """--sql - DELETE - FROM session_queue - WHERE item_id = ? 
- """, - (item_id,), - ) + self.delete_queue_items_by_id(chain_item_ids) def complete_queue_item(self, item_id: int) -> SessionQueueItem: queue_item = self._set_queue_item_status(item_id=item_id, status="completed") return queue_item + def suspend_queue_item(self, item_id: int) -> SessionQueueItem: + queue_item = self._set_queue_item_status(item_id=item_id, status="waiting") + return queue_item + + def resume_queue_item(self, item_id: int) -> SessionQueueItem: + queue_item = self._set_queue_item_status(item_id=item_id, status="pending") + return queue_item + def fail_queue_item( self, item_id: int, @@ -602,18 +691,25 @@ def delete_by_destination( return DeleteByDestinationResult(deleted=count) def delete_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> DeleteAllExceptCurrentResult: + current_chain_item_ids = self._get_current_workflow_call_chain_item_ids(queue_id) with self._db.transaction() as cursor: # Build WHERE clause with optional user_id filter user_filter = "AND user_id = ?" if user_id is not None else "" + current_chain_filter = "" + if current_chain_item_ids: + placeholders = ", ".join(["?" for _ in current_chain_item_ids]) + current_chain_filter = f"AND item_id NOT IN ({placeholders})" where = f"""--sql WHERE queue_id == ? 
- AND status == 'pending' + AND status IN ('pending', 'waiting') {user_filter} + {current_chain_filter} """ - params = [queue_id] + params: list[Any] = [queue_id] if user_id is not None: params.append(user_id) + params.extend(current_chain_item_ids) cursor.execute( f"""--sql @@ -671,18 +767,25 @@ def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: return CancelByQueueIDResult(canceled=count) def cancel_all_except_current(self, queue_id: str, user_id: Optional[str] = None) -> CancelAllExceptCurrentResult: + current_chain_item_ids = self._get_current_workflow_call_chain_item_ids(queue_id) with self._db.transaction() as cursor: # Build WHERE clause with optional user_id filter user_filter = "AND user_id = ?" if user_id is not None else "" + current_chain_filter = "" + if current_chain_item_ids: + placeholders = ", ".join(["?" for _ in current_chain_item_ids]) + current_chain_filter = f"AND item_id NOT IN ({placeholders})" where = f"""--sql WHERE queue_id == ? - AND status == 'pending' + AND status IN ('pending', 'waiting') {user_filter} + {current_chain_filter} """ params = [queue_id] if user_id is not None: params.append(user_id) + params.extend(current_chain_item_ids) cursor.execute( f"""--sql @@ -739,6 +842,102 @@ def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> ) return self.get_queue_item(item_id) + def enqueue_workflow_call_child( + self, + parent_queue_item: SessionQueueItem, + child_session: GraphExecutionState, + field_values: list[NodeFieldValue] | None = None, + ) -> SessionQueueItem: + workflow_call_execution = parent_queue_item.session.waiting_workflow_call_execution + if workflow_call_execution is None: + raise ValueError("Parent queue item is missing active workflow call execution metadata.") + + session_json = child_session.model_dump_json(warnings=False, exclude_none=True) + field_values_json = json.dumps(field_values, default=to_jsonable_python) if field_values is not None else None + root_item_id = 
parent_queue_item.root_item_id or parent_queue_item.item_id + + with self._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue ( + queue_id, + session, + session_id, + batch_id, + field_values, + priority, + workflow, + origin, + destination, + retried_from_item_id, + user_id, + workflow_call_id, + parent_item_id, + parent_session_id, + root_item_id, + workflow_call_depth, + status + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'pending') + """, + ( + parent_queue_item.queue_id, + session_json, + child_session.id, + parent_queue_item.batch_id, + field_values_json, + parent_queue_item.priority, + None, + parent_queue_item.origin, + parent_queue_item.destination, + None, + parent_queue_item.user_id, + workflow_call_execution.id, + parent_queue_item.item_id, + parent_queue_item.session_id, + root_item_id, + workflow_call_execution.depth, + ), + ) + item_id = cursor.lastrowid + + queue_item = self.get_queue_item(item_id) + batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) + queue_status = self.get_queue_status(queue_id=queue_item.queue_id, acting_user_id=queue_item.user_id) + self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status) + return queue_item + + def cancel_workflow_call_children( + self, workflow_call_id: str, exclude_item_ids: set[int] | None = None + ) -> list[int]: + exclude_item_ids = exclude_item_ids or set() + with self._db.transaction() as cursor: + cursor.execute( + """--sql + SELECT item_id + FROM session_queue + WHERE workflow_call_id = ? 
+                ORDER BY item_id ASC
+                """,
+                (workflow_call_id,),
+            )
+            item_ids = [row[0] for row in cast(list[sqlite3.Row], cursor.fetchall())]
+            item_ids_with_descendants: list[int] = []
+            for item_id in item_ids:
+                item_ids_with_descendants.append(item_id)
+                item_ids_with_descendants.extend(self._get_workflow_call_descendant_ids(item_id))
+            item_ids = list(dict.fromkeys(item_ids_with_descendants))
+            canceled_item_ids: list[int] = []
+            for item_id in item_ids:
+                if item_id in exclude_item_ids:
+                    continue
+                queue_item = self.get_queue_item(item_id)
+                if queue_item.status in {"completed", "failed", "canceled"}:
+                    continue
+                self._set_queue_item_status(item_id=item_id, status="canceled")
+                canceled_item_ids.append(item_id)
+            return canceled_item_ids
+
     def list_queue_items(
         self,
         queue_id: str,
@@ -905,6 +1104,7 @@ def get_queue_status(
             batch_id=current_item.batch_id if show_current_item else None,
             pending=counts.get("pending", 0),
             in_progress=counts.get("in_progress", 0),
+            waiting=counts.get("waiting", 0),
             completed=counts.get("completed", 0),
             failed=counts.get("failed", 0),
             canceled=counts.get("canceled", 0),
@@ -937,6 +1137,7 @@ def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str]
             queue_id=queue_id,
             pending=counts.get("pending", 0),
             in_progress=counts.get("in_progress", 0),
+            waiting=counts.get("waiting", 0),
             completed=counts.get("completed", 0),
             failed=counts.get("failed", 0),
             canceled=counts.get("canceled", 0),
@@ -968,6 +1169,7 @@ def get_counts_by_destination(
             destination=destination,
             pending=counts.get("pending", 0),
             in_progress=counts.get("in_progress", 0),
+            waiting=counts.get("waiting", 0),
             completed=counts.get("completed", 0),
             failed=counts.get("failed", 0),
             canceled=counts.get("canceled", 0),
@@ -978,47 +1180,77 @@ def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsRes
         """Retries the given queue items"""
         with self._db.transaction() as cursor:
             values_to_insert: list[ValueToInsertTuple] = []
-            retried_item_ids: list[int] = []
+            retried_root_item_ids: list[int] = []
+            retried_user_ids: list[str] = []
+            retried_item_ids_by_user: dict[str, list[int]] = {}
+            seen_root_item_ids: set[int] = set()
+            max_new_queue_items = self.__invoker.services.configuration.max_queue_size - self._get_current_queue_size(
+                queue_id
+            )
+
+            if max_new_queue_items <= 0:
+                return RetryItemsResult(queue_id=queue_id, retried_item_ids=[])
             for item_id in item_ids:
-                queue_item = self.get_queue_item(item_id)
+                try:
+                    queue_item = self.get_queue_item(item_id)
+                except SessionQueueItemNotFoundError:
+                    continue
+                if queue_item.queue_id != queue_id:
+                    continue
                 if queue_item.status not in ("failed", "canceled"):
                     continue
-                retried_item_ids.append(item_id)
+                root_item_id = queue_item.root_item_id or queue_item.item_id
+                if root_item_id in seen_root_item_ids:
+                    continue
+                seen_root_item_ids.add(root_item_id)
+
+                root_queue_item = self.get_queue_item(root_item_id)
+                if root_queue_item.status not in ("failed", "canceled"):
+                    continue
+
+                retried_root_item_ids.append(root_item_id)
+                retried_user_ids.append(root_queue_item.user_id)
+                retried_item_ids_by_user.setdefault(root_queue_item.user_id, []).append(root_item_id)
                 field_values_json = (
-                    json.dumps(queue_item.field_values, default=to_jsonable_python) if queue_item.field_values else None
+                    json.dumps(root_queue_item.field_values, default=to_jsonable_python)
+                    if root_queue_item.field_values
+                    else None
                 )
                 workflow_json = (
-                    json.dumps(queue_item.workflow, default=to_jsonable_python) if queue_item.workflow else None
+                    json.dumps(root_queue_item.workflow, default=to_jsonable_python)
+                    if root_queue_item.workflow
+                    else None
                 )
-                cloned_session = GraphExecutionState(graph=queue_item.session.graph)
+                cloned_session = GraphExecutionState(graph=root_queue_item.session.graph)
                 cloned_session_json = cloned_session.model_dump_json(warnings=False, exclude_none=True)
                 retried_from_item_id = (
-                    queue_item.retried_from_item_id
-                    if queue_item.retried_from_item_id is not None
-                    else queue_item.item_id
+                    root_queue_item.retried_from_item_id
+                    if root_queue_item.retried_from_item_id is not None
+                    else root_queue_item.item_id
                 )
                 value_to_insert: ValueToInsertTuple = (
-                    queue_item.queue_id,
+                    root_queue_item.queue_id,
                     cloned_session_json,
                     cloned_session.id,
-                    queue_item.batch_id,
+                    root_queue_item.batch_id,
                     field_values_json,
-                    queue_item.priority,
+                    root_queue_item.priority,
                     workflow_json,
-                    queue_item.origin,
-                    queue_item.destination,
+                    root_queue_item.origin,
+                    root_queue_item.destination,
                     retried_from_item_id,
-                    queue_item.user_id,
+                    root_queue_item.user_id,
                 )
                 values_to_insert.append(value_to_insert)
-                # TODO(psyche): Handle max queue size?
+                if len(values_to_insert) >= max_new_queue_items:
+                    break
 
             cursor.executemany(
                 """--sql
@@ -1030,7 +1262,11 @@ def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsRes
 
         retry_result = RetryItemsResult(
             queue_id=queue_id,
-            retried_item_ids=retried_item_ids,
+            retried_item_ids=retried_root_item_ids,
+        )
+        self.__invoker.services.events.emit_queue_items_retried(
+            retry_result,
+            user_ids=list(dict.fromkeys(retried_user_ids)),
+            retried_item_ids_by_user=retried_item_ids_by_user,
         )
-        self.__invoker.services.events.emit_queue_items_retried(retry_result)
         return retry_result
diff --git a/invokeai/app/services/shared/README.md b/invokeai/app/services/shared/README.md
index f92b1f1ea2e..c6603f2c5d2 100644
--- a/invokeai/app/services/shared/README.md
+++ b/invokeai/app/services/shared/README.md
@@ -58,6 +58,19 @@ Runs a sequence of checks:
 1. **Type compatibility**
    `get_output_field_type` vs `get_input_field_type` and `are_connection_types_compatible`.
 
+   Special case:
+
+   - `call_saved_workflow` currently accepts dynamic destination handles of the form
+     `saved_workflow_input::{childNodeId}::{childFieldName}` as part of its temporary call-boundary contract.
+   - Those handles are allowed through graph validation even though they are not static Python model fields on the
+     invocation class.
+   - Runtime later validates them against the selected child workflow's exposed callable interface before applying
+     values to the child graph.
+   - The editor preserves dynamic caller values only while the exposed field type remains compatible; type drift at the
+     same child node/field path resets to the selected workflow's current initial value.
+   - Saved-workflow picker search is server-backed so large workflow libraries do not require scrolling every page
+     before selecting a workflow by name.
+
 1. **Iterator / collector structure**
    Enforce special rules:
    - Iterator's input must be `collection`; its outgoing edges use `item`.
@@ -105,6 +118,14 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre
 - `prepared_source_mapping: dict[str, str]` - exec id -> source id.
 - `source_prepared_mapping: dict[str, set[str]]` - source id -> exec ids.
 - `indegree: dict[str, int]` - unmet inputs per exec node.
+- Workflow-call runtime state:
+  - `workflow_call_stack` - active parent call frames.
+  - `workflow_call_history` - completed or failed workflow-call relationships observed by this execution state.
+  - `workflow_call_parent` - parent workflow-call relationship metadata when this execution state is a child session.
+  - `waiting_workflow_call` - the call frame currently suspending this execution state, if any.
+  - `waiting_workflow_call_execution` - the active parent/child workflow-call relationship record for the waiting call.
+  - `waiting_workflow_call_child_session` - attached child execution state for the waiting workflow call, if any.
+  - `max_workflow_call_depth` - runtime guardrail for nested or recursive workflow calls.
 - Prepared exec metadata caches:
   - source node id
   - iteration path
@@ -115,10 +136,42 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre
 
 ### 4.2 Core methods
 
 - `next()`
   Returns the next ready exec node. If none are ready, it asks the materializer to expand more source nodes and
-  then retries. Before returning a node, the runtime helper deep-copies inbound values into the node fields.
+  then retries. If the execution state is paused on a workflow call boundary, it returns `None` without scheduling more
+  work. Before returning a node, the runtime helper deep-copies inbound values into the node fields.
 - `complete(node_id, output)`
   Records the result, marks the exec node executed, marks the source node executed once all of its prepared exec
   copies are done, then decrements downstream indegrees and enqueues newly ready nodes.
 
+Workflow-call note:
+
+- `GraphExecutionState` can represent a paused parent execution plus an attached child execution state, but it does not
+  itself orchestrate child execution.
+- In the current implementation, `DefaultSessionRunner.run_node()` establishes the workflow call boundary and attaches
+  the child execution state, while `WorkflowCallCoordinator` handles call-specific setup and
+  `WorkflowCallQueueLifecycle` later resumes or fails the parent based on that child queue row's outcome.
+- Child `SessionQueueItem` rows created by the coordinator now carry explicit relationship metadata such as
+  `workflow_call_id`, `parent_item_id`, `parent_session_id`, `root_item_id`, and `workflow_call_depth`, even though the
+  higher-level scheduler semantics are still evolving.
+- The `session_queue` schema now has matching columns for those relationship fields, and parent queue items can enter a
+  `waiting` status while suspended on a child workflow execution.
+- Queue lifecycle semantics are now partially defined for workflow-call chains:
+  - child success resumes the waiting parent
+  - multiple child queue rows may complete under one waiting parent when the called workflow contains direct batch
+    nodes; the parent resumes only after all expected child rows complete
+  - child failure fails the waiting parent and can cascade upward through ancestors
+  - failing child rows cancel their remaining workflow-call siblings before the parent is failed
+  - cancelation is chain-aware across parents and children, including nested descendants of batched siblings
+  - "all except current" queue actions preserve the active current item plus its workflow-call chain, while still
+    canceling or deleting unrelated waiting chains
+  - startup recovery cancels interrupted `in_progress` or `waiting` workflow-call chains, including pending descendants
+  - deleting a workflow-call queue row currently deletes the whole parent/child chain rather than leaving orphaned rows
+    behind
+  - retry is root-oriented and should not be exposed directly on child queue rows in the UI
+  - child queue-row creation is cleaned up on boundary-setup failure and child fan-out is bounded by remaining queue
+    capacity
+  - child workflows that mix supported batch nodes with unrelated generator nodes are rejected for now
+- This is still an intermediate architecture step and should eventually be replaced by a more general parent/child
+  execution mechanism rather than workflow-call-specific queue lifecycle handling.
+
 ### 4.3 Runtime helper classes
 
 `GraphExecutionState` now delegates most runtime behavior to internal helpers:
@@ -218,7 +271,7 @@ This behavior is implemented in the runtime scheduler, not in the invocation bod
    - Execute node externally -> `output`.
    - `state.complete(node.id, output)` -> updates indegrees, `If` state, and ready queues.
-1. Finish when `next()` returns `None`.
+1. Finish when `next()` returns `None` and the execution state is not paused waiting on a workflow call boundary.
 
 In normal execution, all runtime expansion occurs in `execution_graph` with traceability back to source nodes.
@@ -239,6 +292,35 @@ In normal execution, all runtime expansion occurs in `execution_graph` with trac
   complexity.
 - **Dynamic behaviors** (future): can be added in `GraphExecutionState` by creating exec nodes and edges at
   `complete()` time, as long as the DAG invariant holds.
+- **Workflow call boundaries**: `GraphExecutionState` can suspend a parent execution state on a workflow call, attach a
+  child execution state, and later resume the parent without mutating the source graph.
+
+Current limitation:
+
+- Child workflow executions are now represented as first-class queue items. Parent resume/failure is intentionally
+  handled by a dedicated workflow-call queue lifecycle component for this PR because no other feature currently needs a
+  generalized dependent-queue scheduler.
+- Called workflows currently require exactly one valid `workflow_return` node to be callable at all.
+- A single `workflow_return_value.value` may connect directly to `workflow_return.values`; multiple named return members
+  should be collected and then connected to `workflow_return.values`.
+- Direct batch-special child workflows are now supported by expanding them into multiple child queue rows.
+- Batch outputs may feed a named `workflow_return_value.value` directly. Parent resume aggregates named return maps as
+  `values: dict[str, list[Any]]`, and all rows in one batch call must return the same key set.
+- Generator-backed batch child workflows are now supported when the batch node is fed directly by a supported integer,
+  float, string, or image generator.
+- Connected batch child inputs produced by ordinary non-generator upstream nodes are still rejected before any child
+  queue row is created.
+- Workflow library API responses now include compatibility metadata so the frontend can disable unsupported callees
+  before execution rather than failing only at runtime.
+- Workflow library list compatibility uses structural generator-backed batch validation so list and picker rendering do
+  not enumerate every image in board-backed generators; workflow detail and runtime execution still resolve real
+  generator values.
+- Batch-specific compatibility failures, including multiple connected inputs to one batch field, are reported as
+  `unsupported_batch_input` rather than generic unsupported-node failures.
+- The workflow library list also surfaces that metadata as an informational unsupported state; workflows remain
+  viewable/editable even when they are not currently callable by `call_saved_workflow`.
+- Single-user workflow CRUD socket events emit only to the admin room because every single-user socket already joins
+  that room, avoiding duplicate delivery through both `user:system` and `admin`.
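The dynamic-handle contract described in the README hunks above can be sketched in isolation. Note this is an illustrative assumption, not the real `parse_call_saved_workflow_dynamic_input` / `is_call_saved_workflow_dynamic_input` helpers from `invokeai.app.invocations.call_saved_workflow`; it only demonstrates the expected `saved_workflow_input::{childNodeId}::{childFieldName}` handle shape and why both graph validation and runtime can cheaply recognize such handles:

```python
# Hypothetical sketch of the dynamic destination-handle format used by
# call_saved_workflow. Helper names and validation rules here are assumptions
# for illustration, not the actual InvokeAI implementation.

DYNAMIC_INPUT_PREFIX = "saved_workflow_input"


def parse_dynamic_input(field_name: str) -> tuple[str, str]:
    """Split a dynamic destination handle into (child_node_id, child_field_name)."""
    parts = field_name.split("::")
    if len(parts) != 3 or parts[0] != DYNAMIC_INPUT_PREFIX or not parts[1] or not parts[2]:
        raise ValueError(f"Not a dynamic call input: {field_name!r}")
    return parts[1], parts[2]


def is_dynamic_input(field_name: str) -> bool:
    """True if the handle names an exposed field on the called child workflow."""
    try:
        parse_dynamic_input(field_name)
        return True
    except ValueError:
        return False
```

Because the handle encodes the child node id and field name directly, graph validation can let such edges through without a static model field, and runtime can later resolve the pair against the child workflow's exposed callable interface.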
 
 ## 8) Error Model (selected)
diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py
index aa47c3b4bb5..41554a15af1 100644
--- a/invokeai/app/services/shared/graph.py
+++ b/invokeai/app/services/shared/graph.py
@@ -29,6 +29,10 @@
     invocation,
     invocation_output,
 )
+from invokeai.app.invocations.call_saved_workflow import (
+    CallSavedWorkflowInvocation,
+    is_call_saved_workflow_dynamic_input,
+)
 from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
 from invokeai.app.invocations.logic import IfInvocation
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -66,6 +70,51 @@ def __str__(self):
 PreparedExecState = Literal["pending", "ready", "executed", "skipped"]
 
 
+WorkflowCallStatus = Literal["waiting_for_child", "running_child", "completed", "failed"]
+
+
+class WorkflowCallFrame(BaseModel):
+    """Represents one workflow-call frame in a nested call chain."""
+
+    prepared_call_node_id: str = Field(description="The prepared exec node id for the call site.")
+    source_call_node_id: str = Field(description="The source graph node id for the call site.")
+    workflow_id: str = Field(description="The saved workflow being called.")
+    depth: int = Field(description="The 1-based depth of this call frame.", ge=1)
+
+
+class WorkflowCallExecution(BaseModel):
+    """Tracks one parent/child workflow-call relationship and its lifecycle."""
+
+    id: str = Field(description="The workflow-call execution id.", default_factory=uuid_string)
+    parent_session_id: str = Field(description="The parent graph execution state id.")
+    child_session_id: Optional[str] = Field(default=None, description="The child graph execution state id, if any.")
+    prepared_call_node_id: str = Field(description="The prepared exec node id for the parent call site.")
+    source_call_node_id: str = Field(description="The source graph node id for the parent call site.")
+    workflow_id: str = Field(description="The saved workflow being called.")
+    depth: int = Field(description="The 1-based depth of this call frame.", ge=1)
+    status: WorkflowCallStatus = Field(description="The current workflow-call lifecycle state.")
+    error_message: Optional[str] = Field(default=None, description="Failure reason, if the call failed.")
+    child_session_ids: list[str] = Field(default_factory=list, description="All child graph execution state ids.")
+    expected_child_count: int = Field(default=1, ge=1, description="The number of child executions for this call.")
+    completed_child_item_ids: list[int] = Field(
+        default_factory=list,
+        description="The child queue item ids whose workflow_return outputs have been aggregated.",
+    )
+    aggregated_values: dict[str, list[Any]] = Field(
+        default_factory=dict,
+        description="The aggregated workflow_return values accumulated from child executions.",
+    )
+
+
+class WorkflowCallParentRef(BaseModel):
+    """Reference from a child execution state back to its parent workflow-call relationship."""
+
+    workflow_call_id: str = Field(description="The workflow-call execution id.")
+    parent_session_id: str = Field(description="The parent graph execution state id.")
+    prepared_call_node_id: str = Field(description="The prepared exec node id for the parent call site.")
+    source_call_node_id: str = Field(description="The source graph node id for the parent call site.")
+    workflow_id: str = Field(description="The saved workflow being called.")
+    depth: int = Field(description="The 1-based depth of this call frame.", ge=1)
 
 
 @dataclass
@@ -772,6 +821,10 @@ def _set_node_inputs(
         for edge in input_edges:
             if allowed_fields is not None and edge.destination.field not in allowed_fields:
                 continue
+            if isinstance(node, CallSavedWorkflowInvocation) and is_call_saved_workflow_dynamic_input(
+                edge.destination.field
+            ):
+                continue
             setattr(node, edge.destination.field, self._get_copied_result_value(edge))
 
     def _prepare_collect_inputs(self, node: "CollectInvocation", input_edges: list[Edge]) -> None:
@@ -1201,6 +1254,10 @@ def _validate_edge_nodes_and_fields(self) -> None:
                 )
 
             if edge.destination.field not in type(destination_node).model_fields:
+                if isinstance(destination_node, CallSavedWorkflowInvocation) and is_call_saved_workflow_dynamic_input(
+                    edge.destination.field
+                ):
+                    continue
                 raise NodeFieldNotFoundError(
                     f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
                 )
@@ -1212,10 +1269,15 @@ def _validate_graph_is_acyclic(self) -> None:
 
     def _validate_edge_type_compatibility(self) -> None:
         for edge in self.edges:
+            destination_node = self.get_node(edge.destination.node_id)
+            if isinstance(destination_node, CallSavedWorkflowInvocation) and is_call_saved_workflow_dynamic_input(
+                edge.destination.field
+            ):
+                continue
             if not are_connections_compatible(
                 self.get_node(edge.source.node_id),
                 edge.source.field,
-                self.get_node(edge.destination.node_id),
+                destination_node,
                 edge.destination.field,
             ):
                 raise InvalidEdgeError(f"Edge source and target types do not match ({edge})")
@@ -1305,6 +1367,10 @@ def _validate_edge_would_not_create_cycle(self, edge: Edge) -> None:
     def _validate_edge_field_compatibility(
         self, edge: Edge, source_node: BaseInvocation, destination_node: BaseInvocation
     ) -> None:
+        if isinstance(destination_node, CallSavedWorkflowInvocation) and is_call_saved_workflow_dynamic_input(
+            edge.destination.field
+        ):
+            return
         if not are_connections_compatible(source_node, edge.source.field, destination_node, edge.destination.field):
             raise InvalidEdgeError(f"Field types are incompatible ({edge})")
 
@@ -1736,6 +1802,36 @@ class GraphExecutionState(BaseModel):
     # Errors raised when executing nodes
     errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
 
+    workflow_call_stack: list[WorkflowCallFrame] = Field(
+        description="The nested workflow call stack inherited by this execution state.",
+        default_factory=list,
+    )
+    workflow_call_history: list[WorkflowCallExecution] = Field(
+        description="Completed or failed workflow-call relationships observed by this execution state.",
+        default_factory=list,
+    )
+    workflow_call_parent: Optional[WorkflowCallParentRef] = Field(
+        default=None,
+        description="Parent workflow-call relationship metadata when this execution state is a child workflow session.",
+    )
+    waiting_workflow_call: Optional[WorkflowCallFrame] = Field(
+        default=None,
+        description="The child workflow call this execution state is currently waiting on, if any.",
+    )
+    waiting_workflow_call_execution: Optional[WorkflowCallExecution] = Field(
+        default=None,
+        description="The active workflow-call relationship metadata for the current waiting child workflow, if any.",
+    )
+    waiting_workflow_call_child_session: Optional["GraphExecutionState"] = Field(
+        default=None,
+        description="The child workflow execution state spawned by the current waiting workflow call, if any.",
+    )
+    max_workflow_call_depth: int = Field(
+        default=4,
+        ge=1,
+        description="The maximum permitted workflow call depth for nested workflow execution.",
+    )
+
     # Map of prepared/executed nodes to their original nodes
     prepared_source_mapping: dict[str, str] = Field(
         description="The map of prepared nodes to original graph nodes",
@@ -1918,6 +2014,8 @@ def model_post_init(self, __context: Any) -> None:
             "executed_history",
             "results",
             "errors",
+            "workflow_call_stack",
+            "workflow_call_history",
             "prepared_source_mapping",
             "source_prepared_mapping",
         ]
@@ -1936,6 +2034,9 @@ def next(self) -> Optional[BaseInvocation]:
         # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
         #       possibly with a timeout?
+        if self.is_waiting_on_workflow_call():
+            return None
+
         # If there are no prepared nodes, prepare some nodes
         next_node = self._get_next_node()
         if next_node is None:
@@ -1961,6 +2062,8 @@ def set_node_error(self, node_id: str, error: str):
 
     def is_complete(self) -> bool:
         """Returns true if the graph is complete"""
+        if self.is_waiting_on_workflow_call():
+            return False
         node_ids = set(self.graph.nx_graph_flat().nodes)
         return self.has_error() or all((k in self.executed for k in node_ids))
 
@@ -1968,6 +2071,132 @@ def has_error(self) -> bool:
         """Returns true if the graph has any errors"""
         return len(self.errors) > 0
 
+    def get_workflow_call_depth(self) -> int:
+        return len(self.workflow_call_stack)
+
+    def is_waiting_on_workflow_call(self) -> bool:
+        return self.waiting_workflow_call is not None
+
+    def build_workflow_call_frame(self, exec_node_id: str, workflow_id: str) -> WorkflowCallFrame:
+        if exec_node_id not in self.execution_graph.nodes:
+            raise NodeNotFoundError(f"Node {exec_node_id} not found in execution graph")
+        if exec_node_id not in self.prepared_source_mapping:
+            raise ValueError(f"Node {exec_node_id} is not a prepared execution node")
+
+        next_depth = self.get_workflow_call_depth() + 1
+        if next_depth > self.max_workflow_call_depth:
+            raise ValueError(
+                f"Maximum workflow call depth exceeded ({self.max_workflow_call_depth}) for workflow '{workflow_id}'"
+            )
+
+        return WorkflowCallFrame(
+            prepared_call_node_id=exec_node_id,
+            source_call_node_id=self.prepared_source_mapping[exec_node_id],
+            workflow_id=workflow_id,
+            depth=next_depth,
+        )
+
+    def begin_waiting_on_workflow_call(self, frame: WorkflowCallFrame) -> None:
+        if self.waiting_workflow_call is not None:
+            raise ValueError("Execution state is already waiting on a workflow call")
+        self.waiting_workflow_call = frame
+        self.waiting_workflow_call_execution = WorkflowCallExecution(
+            parent_session_id=self.id,
+            prepared_call_node_id=frame.prepared_call_node_id,
+            source_call_node_id=frame.source_call_node_id,
+            workflow_id=frame.workflow_id,
+            depth=frame.depth,
+            status="waiting_for_child",
+        )
+
+    def attach_waiting_workflow_call_child_session(self, child_session: "GraphExecutionState") -> None:
+        if self.waiting_workflow_call is None:
+            raise ValueError("Execution state must be waiting on a workflow call before attaching a child session")
+        if self.waiting_workflow_call_execution is None:
+            raise ValueError("Execution state is waiting on a workflow call but has no workflow call execution")
+        self.waiting_workflow_call_child_session = child_session
+        self.waiting_workflow_call_execution.child_session_id = child_session.id
+        self.waiting_workflow_call_execution.child_session_ids = [child_session.id]
+        self.waiting_workflow_call_execution.expected_child_count = 1
+        self.waiting_workflow_call_execution.status = "running_child"
+        child_session.workflow_call_parent = WorkflowCallParentRef(
+            workflow_call_id=self.waiting_workflow_call_execution.id,
+            parent_session_id=self.waiting_workflow_call_execution.parent_session_id,
+            prepared_call_node_id=self.waiting_workflow_call_execution.prepared_call_node_id,
+            source_call_node_id=self.waiting_workflow_call_execution.source_call_node_id,
+            workflow_id=self.waiting_workflow_call_execution.workflow_id,
+            depth=self.waiting_workflow_call_execution.depth,
+        )
+
+    def attach_waiting_workflow_call_child_sessions(self, child_sessions: list["GraphExecutionState"]) -> None:
+        if not child_sessions:
+            raise ValueError("Workflow call must attach at least one child session")
+        if self.waiting_workflow_call_execution is None:
+            raise ValueError("Execution state is waiting on a workflow call but has no workflow call execution")
+        self.waiting_workflow_call_child_session = child_sessions[0] if len(child_sessions) == 1 else None
+        self.waiting_workflow_call_execution.child_session_id = child_sessions[0].id
+        self.waiting_workflow_call_execution.child_session_ids = [child_session.id for child_session in child_sessions]
+        self.waiting_workflow_call_execution.expected_child_count = len(child_sessions)
+        self.waiting_workflow_call_execution.status = "running_child"
+        for child_session in child_sessions:
+            child_session.workflow_call_parent = WorkflowCallParentRef(
+                workflow_call_id=self.waiting_workflow_call_execution.id,
+                parent_session_id=self.waiting_workflow_call_execution.parent_session_id,
+                prepared_call_node_id=self.waiting_workflow_call_execution.prepared_call_node_id,
+                source_call_node_id=self.waiting_workflow_call_execution.source_call_node_id,
+                workflow_id=self.waiting_workflow_call_execution.workflow_id,
+                depth=self.waiting_workflow_call_execution.depth,
+            )
+
+    def record_waiting_workflow_call_child_completion(
+        self, child_item_id: int, output_values: dict[str, Any]
+    ) -> tuple[bool, dict[str, Any]]:
+        if self.waiting_workflow_call_execution is None:
+            raise ValueError("Execution state is not waiting on a workflow call.")
+        if child_item_id not in self.waiting_workflow_call_execution.completed_child_item_ids:
+            if (
+                self.waiting_workflow_call_execution.expected_child_count > 1
+                and self.waiting_workflow_call_execution.completed_child_item_ids
+                and set(output_values.keys()) != set(self.waiting_workflow_call_execution.aggregated_values.keys())
+            ):
+                raise ValueError("Batched child workflows returned different workflow return keys.")
+            self.waiting_workflow_call_execution.completed_child_item_ids.append(child_item_id)
+            for key, value in output_values.items():
+                self.waiting_workflow_call_execution.aggregated_values.setdefault(key, []).append(value)
+        is_complete = (
+            len(self.waiting_workflow_call_execution.completed_child_item_ids)
+            >= self.waiting_workflow_call_execution.expected_child_count
+        )
+        if self.waiting_workflow_call_execution.expected_child_count == 1:
+            return (
+                is_complete,
+                {key: values[0] for key, values in self.waiting_workflow_call_execution.aggregated_values.items()},
+            )
+        return (
+            is_complete,
+            {key: list(values) for key, values in self.waiting_workflow_call_execution.aggregated_values.items()},
+        )
+
+    def end_waiting_on_workflow_call(
+        self,
+        status: Literal["completed", "failed"] = "completed",
+        error_message: Optional[str] = None,
+    ) -> None:
+        if self.waiting_workflow_call_execution is not None:
+            self.waiting_workflow_call_execution.status = status
+            self.waiting_workflow_call_execution.error_message = error_message
+            self.workflow_call_history.append(self.waiting_workflow_call_execution.model_copy(deep=True))
+        self.waiting_workflow_call = None
+        self.waiting_workflow_call_execution = None
+        self.waiting_workflow_call_child_session = None
+
+    def create_child_workflow_execution_state(self, graph: Graph, frame: WorkflowCallFrame) -> "GraphExecutionState":
+        return GraphExecutionState(
+            graph=graph,
+            workflow_call_stack=[*self.workflow_call_stack, frame],
+            max_workflow_call_depth=self.max_workflow_call_depth,
+        )
+
     def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
         return self._materializer().create_execution_node(node_id, iteration_node_map)
 
diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py
index 19e3b897202..12642610c8c 100644
--- a/invokeai/app/services/shared/sqlite/sqlite_util.py
+++ b/invokeai/app/services/shared/sqlite/sqlite_util.py
@@ -33,6 +33,7 @@
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_28 import build_migration_28
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_30 import build_migration_30
+from invokeai.app.services.shared.sqlite_migrator.migrations.migration_31 import build_migration_31
 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
 
@@ -83,6 +84,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
     migrator.register_migration(build_migration_28())
     migrator.register_migration(build_migration_29())
     migrator.register_migration(build_migration_30())
+    migrator.register_migration(build_migration_31())
     migrator.run_migrations()
 
     return db
diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_31.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_31.py
new file mode 100644
index 00000000000..9dba8832924
--- /dev/null
+++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_31.py
@@ -0,0 +1,53 @@
+"""Migration 31: Add workflow-call relationship columns to session_queue."""
+
+import sqlite3
+
+from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
+
+
+class Migration31Callback:
+    """Add durable parent/child workflow-call relationship columns to session_queue."""
+
+    def __call__(self, cursor: sqlite3.Cursor) -> None:
+        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='session_queue';")
+        if cursor.fetchone() is None:
+            return
+
+        cursor.execute("PRAGMA table_info(session_queue);")
+        columns = [row[1] for row in cursor.fetchall()]
+
+        if "workflow_call_id" not in columns:
+            cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow_call_id TEXT;")
+            cursor.execute(
+                "CREATE INDEX IF NOT EXISTS idx_session_queue_workflow_call_id ON session_queue(workflow_call_id);"
+            )
+
+        if "parent_item_id" not in columns:
+            cursor.execute("ALTER TABLE session_queue ADD COLUMN parent_item_id INTEGER;")
+            cursor.execute(
+                "CREATE INDEX IF NOT EXISTS idx_session_queue_parent_item_id ON session_queue(parent_item_id);"
+            )
+
+        if "parent_session_id" not in columns:
+            cursor.execute("ALTER TABLE session_queue ADD COLUMN parent_session_id TEXT;")
+            cursor.execute(
+                "CREATE INDEX IF NOT EXISTS idx_session_queue_parent_session_id ON session_queue(parent_session_id);"
+            )
+
+        if "root_item_id" not in columns:
+            cursor.execute("ALTER TABLE session_queue ADD COLUMN root_item_id INTEGER;")
+            cursor.execute("CREATE INDEX IF NOT EXISTS idx_session_queue_root_item_id ON session_queue(root_item_id);")
+
+        if "workflow_call_depth" not in columns:
+            cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow_call_depth INTEGER;")
+            cursor.execute(
+                "CREATE INDEX IF NOT EXISTS idx_session_queue_workflow_call_depth ON session_queue(workflow_call_depth);"
+            )
+
+
+def build_migration_31() -> Migration:
+    return Migration(
+        from_version=30,
+        to_version=31,
+        callback=Migration31Callback(),
+    )
diff --git a/invokeai/app/services/shared/workflow_call_compatibility.py b/invokeai/app/services/shared/workflow_call_compatibility.py
new file mode 100644
index 00000000000..741aba5a0b0
--- /dev/null
+++ b/invokeai/app/services/shared/workflow_call_compatibility.py
@@ -0,0 +1,222 @@
+from collections.abc import Mapping
+from copy import deepcopy
+from typing import Any, get_args, get_origin
+
+from pydantic import BaseModel
+from pydantic_core import PydanticUndefined
+
+from invokeai.app.invocations.baseinvocation import InvocationRegistry
+from invokeai.app.invocations.call_saved_workflow import parse_call_saved_workflow_dynamic_input
+from invokeai.app.invocations.fields import ImageField
+from invokeai.app.services.session_processor.workflow_call_batch import build_child_workflow_sessions
+from invokeai.app.services.shared.graph import Graph, GraphExecutionState, WorkflowCallFrame
+from invokeai.app.services.shared.workflow_call_compatibility_common import (
+    WorkflowCallCompatibility,
+    WorkflowCallCompatibilityReason,
+)
+from invokeai.app.services.shared.workflow_graph_builder import (
+    InvalidWorkflowInputError,
+    UnsupportedWorkflowNodeError,
+    get_exposed_workflow_input_names,
+)
+
+
+def _count_workflow_return_nodes(workflow: dict[str, Any]) -> int:
+    workflow_return_count = 0
+    for node in workflow.get("nodes", []):
+        if not isinstance(node, dict) or node.get("type") != "invocation":
+            continue
+        data = node.get("data")
+        if isinstance(data, dict) and data.get("type") == "workflow_return":
+            workflow_return_count += 1
+    return workflow_return_count
+
+
+def _is_mapping(value: Any) -> bool:
+    return isinstance(value, Mapping)
+
+
+def _build_placeholder_model(annotation: type[BaseModel]) -> Any:
+    values: dict[str, Any] = {}
+    for field_name, field_info in annotation.model_fields.items():
+        if field_info.default is not PydanticUndefined:
+            values[field_name] = deepcopy(field_info.default)
+            continue
+        if field_info.default_factory is not None:
+            values[field_name] = field_info.default_factory()
+            continue
+        placeholder = _get_placeholder_for_annotation(field_info.annotation)
+        if placeholder is None:
+            return None
+        values[field_name] = placeholder
+    return annotation.model_construct(**values)
+
+
+def _get_placeholder_for_annotation(annotation: Any) -> Any:
+    origin = get_origin(annotation)
+    if origin is not None:
+        if origin is list:
+            return []
+        if origin is dict:
+            return {}
+        if origin is tuple:
+            return []
+        if origin is set:
+            return []
+        args = [arg for arg in get_args(annotation) if arg is not type(None)]
+        if args:
+            return _get_placeholder_for_annotation(args[0])
+        return None
+
+    if annotation is Any:
+        return {}
+    if annotation is str:
+        return ""
+    if annotation is int:
+        return 0
+    if annotation is float:
+        return 0.0
+    if annotation is bool:
+        return False
+    if annotation is ImageField:
+        return ImageField(image_name="compatibility-placeholder")
+    if isinstance(annotation, type) and issubclass(annotation, BaseModel):
+        return _build_placeholder_model(annotation)
+    return None
+
+
+def _build_compatibility_workflow_inputs(workflow: dict[str, Any]) -> dict[str, Any]:
+    workflow_inputs: dict[str, Any] = {}
+    workflow_nodes = workflow.get("nodes", [])
+    if not isinstance(workflow_nodes, list):
+        return workflow_inputs
+
+    nodes_by_id = {
+        node.get("id"): node
+        for node in workflow_nodes
+        if _is_mapping(node) and isinstance(node.get("id"), str) and _is_mapping(node.get("data"))
+    }
+
+    for input_name in get_exposed_workflow_input_names(workflow):
+        node_id, field_name = parse_call_saved_workflow_dynamic_input(input_name)
+        node = nodes_by_id.get(node_id)
+        if not _is_mapping(node):
+            continue
+        node_data = node.get("data")
+        if not _is_mapping(node_data):
+            continue
+        node_type = node_data.get("type")
+        if not isinstance(node_type, str):
+            continue
+        invocation_class = InvocationRegistry.get_invocation_for_type(node_type)
+        if invocation_class is None:
+            continue
+        field_info = invocation_class.model_fields.get(field_name)
+        if field_info is None:
+            continue
+        if field_info.default is not PydanticUndefined:
+            workflow_inputs[input_name] = deepcopy(field_info.default)
+            continue
+        if field_info.default_factory is not None:
+            workflow_inputs[input_name] = field_info.default_factory()
+            continue
+        placeholder = _get_placeholder_for_annotation(field_info.annotation)
+        if placeholder is not None:
+            workflow_inputs[input_name] = placeholder
+
+    return workflow_inputs
+
+
+def _is_unsupported_batch_input_message(message: str) -> bool:
+    return any(
+        marker in message
+        for marker in (
+            "batch child workflow",
+            "batch group",
+            "batch input",
+            "batch inputs",
+            "batch node",
+            "batch-special child workflow nodes",
+            "connected batch",
+            "generator-backed batch",
+        )
+    )
+
+
+def get_workflow_call_compatibility(
+    *,
+    workflow: dict[str, Any],
+    workflow_id: str,
+    services: Any,
+    user_id: str | None,
+    maximum_children: int,
+    resolve_generator_items: bool = True,
+) -> WorkflowCallCompatibility:
+    workflow_return_count = _count_workflow_return_nodes(workflow)
+    if workflow_return_count == 0:
+        return WorkflowCallCompatibility(
+            is_callable=False,
+            reason=WorkflowCallCompatibilityReason.MissingWorkflowReturn,
+            message="The workflow must contain exactly one workflow_return node.",
+        )
+    if workflow_return_count > 1:
+        return WorkflowCallCompatibility(
+            is_callable=False,
+            reason=WorkflowCallCompatibilityReason.MultipleWorkflowReturn,
+            message="The workflow must not contain more than one workflow_return node.",
+        )
+
+    try:
+        workflow_inputs = _build_compatibility_workflow_inputs(workflow)
+        build_child_workflow_sessions(
+            parent_session=GraphExecutionState(graph=Graph()),
+            workflow=workflow,
+            workflow_inputs=workflow_inputs,
+            call_frame=WorkflowCallFrame(
+                prepared_call_node_id="compatibility-call",
+                source_call_node_id="compatibility-call",
+                workflow_id=workflow_id,
+                depth=1,
+            ),
+            maximum_children=maximum_children,
+            services=services,
+            user_id=user_id,
+            resolve_generator_items=resolve_generator_items,
+        )
+    except InvalidWorkflowInputError as e:
+        return WorkflowCallCompatibility(
+            is_callable=False,
+            reason=WorkflowCallCompatibilityReason.InvalidInputs,
+            message=str(e),
+        )
+    except UnsupportedWorkflowNodeError as e:
+        message = str(e)
+        reason = WorkflowCallCompatibilityReason.UnsupportedNode
+        if _is_unsupported_batch_input_message(message):
+            reason = WorkflowCallCompatibilityReason.UnsupportedBatchInput
+        elif "exactly one workflow_return" in message and workflow_return_count == 0:
+            reason = WorkflowCallCompatibilityReason.MissingWorkflowReturn
+        elif "exactly one workflow_return" in message:
+            reason = WorkflowCallCompatibilityReason.InvalidGraph
+        return WorkflowCallCompatibility(
+            is_callable=False,
+            reason=reason,
+            message=message,
+        )
+    except ValueError as e:
+        return WorkflowCallCompatibility(
+            is_callable=False,
+            reason=WorkflowCallCompatibilityReason.InvalidGraph,
+            message=str(e),
+        )
+    except Exception as e:
+        return WorkflowCallCompatibility(
+            is_callable=False,
+            reason=WorkflowCallCompatibilityReason.Unknown,
+            message=str(e),
+        )
+
+    return WorkflowCallCompatibility(
+        is_callable=True,
+        reason=WorkflowCallCompatibilityReason.Ok,
+    )
diff --git a/invokeai/app/services/shared/workflow_call_compatibility_common.py b/invokeai/app/services/shared/workflow_call_compatibility_common.py
new
file mode 100644 index 00000000000..a28dc316727 --- /dev/null +++ b/invokeai/app/services/shared/workflow_call_compatibility_common.py @@ -0,0 +1,20 @@ +from enum import Enum + +from pydantic import BaseModel, Field + + +class WorkflowCallCompatibilityReason(str, Enum): + Ok = "ok" + MissingWorkflowReturn = "missing_workflow_return" + MultipleWorkflowReturn = "multiple_workflow_return" + UnsupportedNode = "unsupported_node" + UnsupportedBatchInput = "unsupported_batch_input" + InvalidGraph = "invalid_graph" + InvalidInputs = "invalid_inputs" + Unknown = "unknown" + + +class WorkflowCallCompatibility(BaseModel): + is_callable: bool = Field(description="Whether the workflow can currently be executed by call_saved_workflow.") + reason: WorkflowCallCompatibilityReason = Field(description="Structured compatibility result.") + message: str | None = Field(default=None, description="Human-readable compatibility detail when unavailable.") diff --git a/invokeai/app/services/shared/workflow_graph_builder.py b/invokeai/app/services/shared/workflow_graph_builder.py new file mode 100644 index 00000000000..a32d9d01f68 --- /dev/null +++ b/invokeai/app/services/shared/workflow_graph_builder.py @@ -0,0 +1,368 @@ +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any + +from invokeai.app.invocations.baseinvocation import Classification, InvocationRegistry +from invokeai.app.invocations.call_saved_workflow import ( + CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX, + parse_call_saved_workflow_dynamic_input, +) +from invokeai.app.services.shared.graph import Edge, EdgeConnection, Graph + +CONNECTOR_INPUT_HANDLE = "in" +CONNECTOR_OUTPUT_HANDLE = "out" + + +class UnsupportedWorkflowNodeError(ValueError): + pass + + +class InvalidWorkflowInputError(ValueError): + pass + + +def _is_mapping(value: Any) -> bool: + return isinstance(value, Mapping) + + +def _is_invocation_node(node: Any) -> bool: + return _is_mapping(node) and node.get("type") == "invocation" and 
_is_mapping(node.get("data")) + + +def _is_connector_node(node: Any) -> bool: + return _is_mapping(node) and node.get("type") == "connector" + + +def _build_dynamic_input_name(node_id: str, field_name: str) -> str: + return f"{CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX}{node_id}::{field_name}" + + +def _get_form_elements(workflow: Mapping[str, Any]) -> tuple[Mapping[str, Any], str | None]: + form = workflow.get("form") + if not _is_mapping(form): + return {}, None + + elements = form.get("elements") + root_element_id = form.get("rootElementId") + if not _is_mapping(elements) or not isinstance(root_element_id, str): + return {}, None + + return elements, root_element_id + + +def _collect_exposed_inputs_from_form(workflow: Mapping[str, Any]) -> set[str]: + elements, root_element_id = _get_form_elements(workflow) + if not elements or root_element_id is None: + return set() + + exposed_inputs: set[str] = set() + stack = [root_element_id] + visited: set[str] = set() + + while stack: + element_id = stack.pop() + if element_id in visited: + continue + visited.add(element_id) + + element = elements.get(element_id) + if not _is_mapping(element): + continue + + if element.get("type") == "node-field": + data = element.get("data") + if _is_mapping(data): + field_identifier = data.get("fieldIdentifier") + if _is_mapping(field_identifier): + node_id = field_identifier.get("nodeId") + field_name = field_identifier.get("fieldName") + if isinstance(node_id, str) and isinstance(field_name, str): + exposed_inputs.add(_build_dynamic_input_name(node_id, field_name)) + + data = element.get("data") + if _is_mapping(data): + children = data.get("children") + if isinstance(children, Sequence): + for child_id in reversed(children): + if isinstance(child_id, str): + stack.append(child_id) + + return exposed_inputs + + +def get_exposed_workflow_input_names(workflow: Mapping[str, Any]) -> set[str]: + exposed_inputs = _collect_exposed_inputs_from_form(workflow) + if exposed_inputs: + return 
exposed_inputs + + workflow_exposed_fields = workflow.get("exposedFields", []) + if not isinstance(workflow_exposed_fields, Sequence): + return set() + + fallback_inputs: set[str] = set() + for field in workflow_exposed_fields: + if not _is_mapping(field): + continue + node_id = field.get("nodeId") + field_name = field.get("fieldName") + if isinstance(node_id, str) and isinstance(field_name, str): + fallback_inputs.add(_build_dynamic_input_name(node_id, field_name)) + + return fallback_inputs + + +def apply_workflow_inputs_to_workflow(workflow: MutableMapping[str, Any], workflow_inputs: Mapping[str, Any]) -> None: + if not workflow_inputs: + return + + allowed_inputs = get_exposed_workflow_input_names(workflow) + for input_name, value in workflow_inputs.items(): + if input_name not in allowed_inputs: + raise InvalidWorkflowInputError( + f"call_saved_workflow input '{input_name}' is not exposed by the selected workflow" + ) + + node_id, field_name = parse_call_saved_workflow_dynamic_input(input_name) + workflow_nodes = workflow.get("nodes", []) + if not isinstance(workflow_nodes, list): + raise InvalidWorkflowInputError( + f"call_saved_workflow input '{input_name}' targets missing child workflow node '{node_id}'" + ) + matching_node = next( + ( + node + for node in workflow_nodes + if _is_mapping(node) + and _is_mapping(node.get("data")) + and node.get("id") == node_id + and node["data"].get("id") == node_id + ), + None, + ) + if matching_node is None: + raise InvalidWorkflowInputError( + f"call_saved_workflow input '{input_name}' targets missing child workflow node '{node_id}'" + ) + matching_node_data = matching_node["data"] + node_type = matching_node_data.get("type") + if not isinstance(node_type, str): + raise InvalidWorkflowInputError( + f"call_saved_workflow input '{input_name}' targets missing child workflow node '{node_id}'" + ) + invocation_class = InvocationRegistry.get_invocation_for_type(node_type) + if invocation_class is None or field_name not in 
invocation_class.model_fields: + raise InvalidWorkflowInputError( + f"call_saved_workflow input '{input_name}' targets missing child workflow field '{field_name}'" + ) + inputs = matching_node_data.setdefault("inputs", {}) + if not _is_mapping(inputs): + raise InvalidWorkflowInputError( + f"call_saved_workflow input '{input_name}' targets invalid child workflow inputs on '{node_id}'" + ) + inputs[field_name] = {"value": value} + + +def apply_workflow_inputs_to_graph( + graph: Graph, workflow: Mapping[str, Any], workflow_inputs: Mapping[str, Any] +) -> None: + if not workflow_inputs: + return + + mutable_workflow = dict(workflow) + apply_workflow_inputs_to_workflow(mutable_workflow, workflow_inputs) + for input_name, value in workflow_inputs.items(): + node_id, field_name = parse_call_saved_workflow_dynamic_input(input_name) + node = graph.nodes.get(node_id) + if node is None: + continue + setattr(node, field_name, value) + + +def _raise_if_unsupported_invocation_type(node_type: str, node_id: str) -> None: + invocation_class = InvocationRegistry.get_invocation_for_type(node_type) + if invocation_class is None: + return + + if ( + invocation_class.UIConfig.category == "batch" + and invocation_class.UIConfig.classification == Classification.Special + and not node_type.endswith("_generator") + ): + raise UnsupportedWorkflowNodeError( + f"call_saved_workflow does not yet support batch-special child workflow nodes such as " + f"'{node_type}' (node '{node_id}')" + ) + + +def _validate_callable_workflow_nodes(workflow_nodes: Sequence[Any]) -> None: + workflow_return_node_ids: list[str] = [] + + for node in workflow_nodes: + if not _is_invocation_node(node): + continue + + data = node["data"] + node_id = data.get("id") + node_type = data.get("type") + if not isinstance(node_id, str) or not isinstance(node_type, str): + continue + + _raise_if_unsupported_invocation_type(node_type, node_id) + + if node_type == "workflow_return": + workflow_return_node_ids.append(node_id) + + 
+    if len(workflow_return_node_ids) != 1:
+        raise UnsupportedWorkflowNodeError(
+            "call_saved_workflow requires the selected workflow to contain exactly one workflow_return node"
+        )
+
+
+def _get_default_edges(workflow_edges: Sequence[Any]) -> list[Mapping[str, Any]]:
+    return [edge for edge in workflow_edges if _is_mapping(edge) and edge.get("type") == "default"]
+
+
+def _get_connector_input_edge(
+    connector_id: str, workflow_edges: Sequence[Mapping[str, Any]]
+) -> Mapping[str, Any] | None:
+    return next(
+        (
+            edge
+            for edge in workflow_edges
+            if edge.get("target") == connector_id and edge.get("targetHandle") == CONNECTOR_INPUT_HANDLE
+        ),
+        None,
+    )
+
+
+def _resolve_connector_source(
+    connector_id: str, workflow_nodes: dict[str, Mapping[str, Any]], workflow_edges: Sequence[Mapping[str, Any]]
+) -> tuple[str, str] | None:
+    visited: set[str] = set()
+
+    def resolve(node_id: str) -> tuple[str, str] | None:
+        if node_id in visited:
+            return None
+        visited.add(node_id)
+
+        incoming_edge = _get_connector_input_edge(node_id, workflow_edges)
+        if incoming_edge is None:
+            return None
+
+        source_id = incoming_edge.get("source")
+        source_handle = incoming_edge.get("sourceHandle")
+        if not isinstance(source_id, str) or not isinstance(source_handle, str):
+            return None
+
+        source_node = workflow_nodes.get(source_id)
+        if source_node is None:
+            return None
+
+        if _is_invocation_node(source_node):
+            return (source_id, source_handle)
+
+        if _is_connector_node(source_node):
+            return resolve(source_id)
+
+        return None
+
+    return resolve(connector_id)
+
+
+def build_graph_from_workflow(workflow: Mapping[str, Any]) -> Graph:
+    workflow_nodes_raw = workflow.get("nodes", [])
+    workflow_edges_raw = workflow.get("edges", [])
+    _validate_callable_workflow_nodes(workflow_nodes_raw if isinstance(workflow_nodes_raw, Sequence) else [])
+
+    workflow_nodes = {
+        node["id"]: node for node in workflow_nodes_raw if _is_mapping(node) and isinstance(node.get("id"), str)
+    }
+    default_edges = _get_default_edges(workflow_edges_raw if isinstance(workflow_edges_raw, Sequence) else [])
+
+    parsed_nodes: dict[str, dict[str, Any]] = {}
+    for node in workflow_nodes.values():
+        if not _is_invocation_node(node):
+            continue
+
+        data = node["data"]
+        node_id = data.get("id")
+        node_type = data.get("type")
+        if not isinstance(node_id, str) or not isinstance(node_type, str):
+            continue
+
+        graph_node: dict[str, Any] = {
+            "id": node_id,
+            "type": node_type,
+            "use_cache": data.get("useCache", False),
+            "is_intermediate": data.get("isIntermediate", False),
+        }
+
+        inputs = data.get("inputs", {})
+        if _is_mapping(inputs):
+            for field_name, field_value in inputs.items():
+                if not isinstance(field_name, str) or not _is_mapping(field_value):
+                    continue
+                graph_node[field_name] = field_value.get("value")
+
+        parsed_nodes[node_id] = graph_node
+
+    parsed_edges: list[dict[str, dict[str, str]]] = []
+    seen_edges: set[tuple[str, str, str, str]] = set()
+
+    for edge in default_edges:
+        source_id = edge.get("source")
+        target_id = edge.get("target")
+        source_handle = edge.get("sourceHandle")
+        target_handle = edge.get("targetHandle")
+        if not all(isinstance(v, str) for v in (source_id, target_id, source_handle, target_handle)):
+            continue
+
+        target_node = workflow_nodes.get(target_id)
+        if not _is_invocation_node(target_node):
+            continue
+
+        source_node = workflow_nodes.get(source_id)
+        resolved_source: tuple[str, str] | None = None
+        if _is_invocation_node(source_node):
+            resolved_source = (source_id, source_handle)
+        elif _is_connector_node(source_node):
+            resolved_source = _resolve_connector_source(source_id, workflow_nodes, default_edges)
+
+        if resolved_source is None:
+            continue
+
+        resolved_source_id, resolved_source_handle = resolved_source
+        edge_key = (resolved_source_id, resolved_source_handle, target_id, target_handle)
+        if edge_key in seen_edges:
+            continue
+        seen_edges.add(edge_key)
+
+        parsed_edges.append(
+            {
+                "source": {
+                    "node_id": resolved_source_id,
+                    "field": resolved_source_handle,
+                },
+                "destination": {
+                    "node_id": target_id,
+                    "field": target_handle,
+                },
+            }
+        )
+
+    for edge in parsed_edges:
+        destination_node_id = edge["destination"]["node_id"]
+        destination_field = edge["destination"]["field"]
+        parsed_nodes[destination_node_id].pop(destination_field, None)
+
+    return Graph.model_validate(
+        {
+            "nodes": parsed_nodes,
+            "edges": [
+                Edge(
+                    source=EdgeConnection(**edge["source"]),
+                    destination=EdgeConnection(**edge["destination"]),
+                )
+                for edge in parsed_edges
+            ],
+        }
+    )
diff --git a/invokeai/app/services/workflow_records/workflow_records_common.py b/invokeai/app/services/workflow_records/workflow_records_common.py
index 9c505530c90..178dd5a8cff 100644
--- a/invokeai/app/services/workflow_records/workflow_records_common.py
+++ b/invokeai/app/services/workflow_records/workflow_records_common.py
@@ -5,6 +5,7 @@
 import semver
 from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator
 
+from invokeai.app.services.shared.workflow_call_compatibility_common import WorkflowCallCompatibility
 from invokeai.app.util.metaenum import MetaEnum
 
 __workflow_meta_version__ = semver.Version.parse("1.0.0")
@@ -73,6 +74,23 @@ class WorkflowWithoutID(BaseModel):
     model_config = ConfigDict(extra="ignore")
 
+    @field_validator("nodes")
+    @classmethod
+    def validate_workflow_return_node_uniqueness(cls, nodes: list[dict[str, JsonValue]]):
+        workflow_return_count = 0
+
+        for node in nodes:
+            if not isinstance(node, dict) or node.get("type") != "invocation":
+                continue
+            data = node.get("data")
+            if isinstance(data, dict) and data.get("type") == "workflow_return":
+                workflow_return_count += 1
+
+        if workflow_return_count > 1:
+            raise ValueError("A workflow may not contain more than one workflow_return node.")
+
+        return nodes
+
 
 WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)
@@ -131,7 +149,13 @@ class WorkflowRecordListItemDTO(WorkflowRecordDTOBase):
 class WorkflowRecordWithThumbnailDTO(WorkflowRecordDTO):
     thumbnail_url: str | None = Field(default=None, description="The URL of the workflow thumbnail.")
+    call_saved_workflow_compatibility: WorkflowCallCompatibility | None = Field(
+        default=None, description="Whether this workflow is currently callable by call_saved_workflow."
+    )
 
 
 class WorkflowRecordListItemWithThumbnailDTO(WorkflowRecordListItemDTO):
     thumbnail_url: str | None = Field(default=None, description="The URL of the workflow thumbnail.")
+    call_saved_workflow_compatibility: WorkflowCallCompatibility | None = Field(
+        default=None, description="Whether this workflow is currently callable by call_saved_workflow."
+    )
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index c62122db222..2b7104a0b6b 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -446,6 +446,7 @@
     "credits": "Credits",
     "pending": "Pending",
     "in_progress": "In Progress",
+    "waiting": "Waiting On Child Workflow",
     "paused": "Paused",
     "completed": "Completed",
     "failed": "Failed",
@@ -1469,6 +1470,12 @@
     "loadWorkflow": "Load Workflow",
     "noWorkflows": "No Workflows",
     "noMatchingWorkflows": "No Matching Workflows",
+    "savedWorkflowChoose": "Choose a workflow",
+    "savedWorkflowDefaultBadge": "Default",
+    "savedWorkflowMissing": "Missing or inaccessible workflow",
+    "savedWorkflowUnsupported": "Unsupported",
+    "savedWorkflowSelectExposedFields": "Select a workflow with exposed form fields",
+    "savedWorkflowUpdating": "Updating...",
     "noWorkflow": "No Workflow",
     "noWorkflowToSave": "No workflow to save",
     "unableToUpdateNode": "Node update failed: node {{node}} of type {{type}} (may require deleting and recreating)",
@@ -2399,6 +2406,7 @@
     "noRecentWorkflows": "No Recent Workflows",
     "private": "Private",
     "shared": "Shared",
+    "savedWorkflowUnsupportedDescription": "This workflow cannot currently be used by Call Saved Workflow.",
     "published": "Published",
     "browseWorkflows": "Browse Workflows",
     "deselectAll": "Deselect All",
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/state.ts b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/state.ts
index 6c16a8fdf22..4ee71acdff3 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/state.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/state.ts
@@ -70,7 +70,9 @@ const getQueueItemStatusRank = (status: S['SessionQueueItem']['status']): number
   switch (status) {
     case 'pending':
       return 0;
+    // Waiting items are suspended on child workflow execution, but they are still nonterminal.
     case 'in_progress':
+    case 'waiting':
       return 1;
     case 'completed':
     case 'failed':
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/CallSavedWorkflowNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/CallSavedWorkflowNode.tsx
new file mode 100644
index 00000000000..8db547a40d3
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/CallSavedWorkflowNode.tsx
@@ -0,0 +1,198 @@
+import type { SystemStyleObject } from '@invoke-ai/ui-library';
+import { Badge, Flex, Grid, GridItem } from '@invoke-ai/ui-library';
+import { useStore } from '@nanostores/react';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { InputFieldEditModeNodes } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes';
+import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
+import { OutputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldGate';
+import { OutputFieldNodesEditorView } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldNodesEditorView';
+import InvocationNodeFooter from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter';
+import InvocationNodeHeader from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader';
+import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
+import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
+import { useWithFooter } from 'features/nodes/hooks/useWithFooter';
+import { $templates, callSavedWorkflowDynamicFieldsChanged } from 'features/nodes/store/nodesSlice';
+import { selectNodesSlice } from 'features/nodes/store/selectors';
+import type { SavedWorkflowFieldInputInstance } from 'features/nodes/types/field';
+import { memo, useEffect, useMemo, useRef } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
+
+import {
+  getSavedWorkflowDynamicEdgeIdsToRemove,
+  getSavedWorkflowDynamicFields,
+  shouldSyncSavedWorkflowDynamicFields,
+} from './callSavedWorkflowFormUtils';
+
+const bodySx: SystemStyleObject = {
+  flexDirection: 'column',
+  w: 'full',
+  h: 'full',
+  py: 2,
+  gap: 1,
+  borderBottomRadius: 'base',
+  '&[data-with-footer="true"]': {
+    borderBottomRadius: 0,
+  },
+  '&[data-with-footer="false"]': {
+    pb: 4,
+  },
+};
+
+const dynamicFieldSx: SystemStyleObject = {
+  w: 'full',
+};
+
+type Props = {
+  nodeId: string;
+  isOpen: boolean;
+};
+
+const CallSavedWorkflowNode = ({ nodeId, isOpen }: Props) => {
+  const withFooter = useWithFooter();
+  const { t } = useTranslation();
+  const workflowIdField = useInputFieldInstance('workflow_id');
+  const templates = useStore($templates);
+  const dispatch = useAppDispatch();
+  const nodesState = useAppSelector(selectNodesSlice);
+
+  const { data: workflow } = useGetWorkflowQuery(workflowIdField.value, {
+    skip: !workflowIdField.value,
+  });
+
+  const shouldSyncDynamicFields = shouldSyncSavedWorkflowDynamicFields({ workflowId: workflowIdField.value, workflow });
+  const dynamicFields = useMemo(() => getSavedWorkflowDynamicFields(workflow, templates), [templates, workflow]);
+  const edgeIdsToRemove = useMemo(
+    () =>
+      getSavedWorkflowDynamicEdgeIdsToRemove({
+        nodeId,
+        fields: dynamicFields,
+        nodes: nodesState.nodes,
+        edges: nodesState.edges,
+        templates,
+      }),
+    [dynamicFields, nodeId, nodesState.edges, nodesState.nodes, templates]
+  );
+  const syncKey = useMemo(
+    () => JSON.stringify({ fields: dynamicFields, edgeIdsToRemove }),
+    [dynamicFields, edgeIdsToRemove]
+  );
+  const lastSyncKeyRef = useRef(null);
+
+  useEffect(() => {
+    if (!shouldSyncDynamicFields) {
+      return;
+    }
+    if (lastSyncKeyRef.current === syncKey) {
+      return;
+    }
+    lastSyncKeyRef.current = syncKey;
+    dispatch(callSavedWorkflowDynamicFieldsChanged({ nodeId, fields: dynamicFields, edgeIdsToRemove }));
+  }, [dispatch, dynamicFields, edgeIdsToRemove, nodeId, shouldSyncDynamicFields, syncKey]);
+
+  return (
+    <>
+
+      {isOpen && (
+        <>
+
+
+
+
+
+
+
+
+
+
+
+
+
+          {withFooter && }
+
+      )}
+
+  );
+};
+
+export default memo(CallSavedWorkflowNode);
+
+const DynamicFieldsSection = memo(
+  ({
+    nodeId,
+    fields,
+    emptyMessage,
+  }: {
+    nodeId: string;
+    fields: ReturnType;
+    emptyMessage: string;
+  }) => {
+    if (fields.length === 0) {
+      return (
+
+        {emptyMessage}
+
+      );
+    }
+
+    return (
+      <>
+        {fields.map((field) => (
+
+        ))}
+
+    );
+  }
+);
+DynamicFieldsSection.displayName = 'DynamicFieldsSection';
+
+const DynamicFieldRow = memo(
+  ({
+    nodeId,
+    fieldName,
+    settings,
+  }: {
+    nodeId: string;
+    fieldName: string;
+    settings: ReturnType[number]['settings'];
+  }) => {
+    return (
+
+
+
+
+    );
+  }
+);
+DynamicFieldRow.displayName = 'DynamicFieldRow';
+
+const OutputFields = memo(({ nodeId }: { nodeId: string }) => {
+  const fieldNames = useOutputFieldNames();
+
+  if (fieldNames.length === 0) {
+    return null;
+  }
+
+  return (
+    <>
+      {fieldNames.map((fieldName, i) => (
+
+
+
+
+      ))}
+
+  );
+});
+OutputFields.displayName = 'OutputFields';
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx
index 7a4ea9ca65c..1961e4dc751 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx
@@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
 import type { Node, NodeProps } from '@xyflow/react';
 import { useAppSelector } from 'app/store/storeHooks';
 import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
+import CallSavedWorkflowNode from 'features/nodes/components/flow/nodes/Invocation/CallSavedWorkflowNode';
 import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
 import { $templates } from 'features/nodes/store/nodesSlice';
 import { selectNodes } from 'features/nodes/store/selectors';
@@ -40,7 +41,11 @@ const InvocationNodeWrapper = (props: NodeProps>) => {
 
   return (
 
-
+      {type === 'call_saved_workflow' ? (
+
+      ) : (
+
+      )}
 
   );
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/callSavedWorkflowFormUtils.test.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/callSavedWorkflowFormUtils.test.ts
new file mode 100644
index 00000000000..68290298bb5
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/callSavedWorkflowFormUtils.test.ts
@@ -0,0 +1,292 @@
+import { buildEdge, buildNode, call_saved_workflow, templates } from 'features/nodes/store/util/testUtils';
+import type { BooleanFieldInputTemplate } from 'features/nodes/types/field';
+import {
+  type BuilderForm,
+  buildHeading,
+  buildNodeFieldElement,
+  getDefaultForm,
+  isContainerElement,
+} from 'features/nodes/types/workflow';
+import type { paths } from 'services/api/schema';
+import { describe, expect, it } from 'vitest';
+
+import {
+  getRenderableWorkflowForm,
+  getSavedWorkflowDynamicEdgeIdsToRemove,
+  getSavedWorkflowDynamicFields,
+  getSavedWorkflowFormFieldData,
+  shouldSyncSavedWorkflowDynamicFields,
+} from './callSavedWorkflowFormUtils';
+
+type WorkflowResponse =
+  paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'];
+
+const addTemplate = templates.add;
+if (!addTemplate) {
+  throw new Error('Expected add template');
+}
+const addInputA = addTemplate.inputs.a;
+const addInputB = addTemplate.inputs.b;
+if (!addInputA || !addInputB) {
+  throw new Error('Expected add template inputs');
+}
+
+const getRootChildren = (form: BuilderForm): string[] => {
+  const root = form.elements[form.rootElementId];
+
+  if (!root || !isContainerElement(root)) {
+    throw new Error('Expected root container');
+  }
+
+  return root.data.children;
+};
+
+const buildWorkflowResponse = (overrides?: {
+  exposedFields?: Array<{ nodeId: string; fieldName: string }>;
+  form?: BuilderForm | null;
+  inputs?: Record;
+}): WorkflowResponse =>
+  ({
+    workflow_id: 'workflow-1',
+    name: 'Workflow 1',
+    created_at: '2026-04-08T00:00:00Z',
+    updated_at: '2026-04-08T00:00:00Z',
+    opened_at: null,
+    user_id: 'user-1',
+    is_public: false,
+    thumbnail_url: null,
+    workflow: {
+      id: 'workflow-1',
+      name: 'Workflow 1',
+      author: 'InvokeAI',
+      description: 'A workflow',
+      version: '1.0.0',
+      contact: '',
+      tags: '',
+      notes: '',
+      exposedFields: overrides?.exposedFields ?? [],
+      meta: {
+        category: 'user',
+        version: '3.0.0',
+      },
+      nodes: [
+        {
+          id: 'node-1',
+          type: 'invocation',
+          data: {
+            id: 'node-1',
+            type: 'add',
+            inputs: overrides?.inputs ?? {
+              a: { name: 'a', label: 'Left Addend', value: 1 },
+              b: { name: 'b', label: '', value: 2 },
+            },
+          },
+        },
+      ],
+      edges: [],
+      form: overrides?.form ?? getDefaultForm(),
+    },
+  }) as WorkflowResponse;
+
+describe('callSavedWorkflowFormUtils', () => {
+  it('does not sync dynamic fields while a selected saved workflow is still loading', () => {
+    expect(
+      shouldSyncSavedWorkflowDynamicFields({
+        workflowId: 'workflow-1',
+        workflow: undefined,
+      })
+    ).toBe(false);
+  });
+
+  it('syncs dynamic fields when no workflow is selected or when the selected workflow is loaded', () => {
+    expect(
+      shouldSyncSavedWorkflowDynamicFields({
+        workflowId: '',
+        workflow: undefined,
+      })
+    ).toBe(true);
+    expect(
+      shouldSyncSavedWorkflowDynamicFields({
+        workflowId: 'workflow-1',
+        workflow: buildWorkflowResponse(),
+      })
+    ).toBe(true);
+  });
+
+  it('returns the stored form when it is non-empty and valid', () => {
+    const form = getDefaultForm();
+    const heading = buildHeading('Workflow Inputs');
+    form.elements[heading.id] = { ...heading, parentId: form.rootElementId };
+    getRootChildren(form).push(heading.id);
+
+    const workflow = buildWorkflowResponse({ form });
+
+    expect(getRenderableWorkflowForm(workflow, templates)).toBe(form);
+  });
+
+  it('builds a fallback form from exposed fields when the stored form is empty', () => {
+    const workflow = buildWorkflowResponse({
+      exposedFields: [{ nodeId: 'node-1', fieldName: 'a' }],
+    });
+
+    const form = getRenderableWorkflowForm(workflow, templates);
+
+    expect(form).not.toBeNull();
+    expect(form ? getRootChildren(form) : []).toHaveLength(1);
+    const childId = form ? getRootChildren(form)[0] : undefined;
+    expect(childId).toBeDefined();
+    expect(childId ? form?.elements[childId]?.type : undefined).toBe('node-field');
+  });
+
+  it('skips exposed fields that do not resolve to a known node field', () => {
+    const workflow = buildWorkflowResponse({
+      exposedFields: [{ nodeId: 'missing-node', fieldName: 'a' }],
+    });
+
+    const form = getRenderableWorkflowForm(workflow, templates);
+
+    expect(form).not.toBeNull();
+    expect(form ? getRootChildren(form) : []).toHaveLength(0);
+  });
+
+  it('uses the stored field label when available', () => {
+    const element = buildNodeFieldElement('node-1', 'a', addInputA.type);
+    const workflow = buildWorkflowResponse();
+
+    expect(getSavedWorkflowFormFieldData(workflow, templates, element)).toEqual(
+      expect.objectContaining({
+        label: 'Left Addend',
+        description: 'The first number',
+        typeName: 'IntegerField',
+        isMissing: false,
+      })
+    );
+  });
+
+  it('falls back to the template title when the stored field label is empty', () => {
+    const element = buildNodeFieldElement('node-1', 'b', addInputB.type);
+    const workflow = buildWorkflowResponse();
+
+    expect(getSavedWorkflowFormFieldData(workflow, templates, element)).toEqual(
+      expect.objectContaining({
+        label: 'B',
+        description: 'The second number',
+        typeName: 'IntegerField',
+        isMissing: false,
+      })
+    );
+  });
+
+  it('marks missing node field references as missing', () => {
+    const element = buildNodeFieldElement('missing-node', 'a', addInputA.type);
+    const workflow = buildWorkflowResponse();
+
+    expect(getSavedWorkflowFormFieldData(workflow, templates, element)).toEqual(
+      expect.objectContaining({
+        label: 'a',
+        description: '',
+        typeName: null,
+        isMissing: true,
+      })
+    );
+  });
+
+  it('builds ordered dynamic fields from the workflow form', () => {
+    const workflow = buildWorkflowResponse({
+      exposedFields: [
+        { nodeId: 'node-1', fieldName: 'a' },
+        { nodeId: 'node-1', fieldName: 'b' },
+      ],
+    });
+
+    const dynamicFields = getSavedWorkflowDynamicFields(workflow, templates);
+
+    expect(dynamicFields).toHaveLength(2);
+    expect(dynamicFields[0]?.fieldName).toBe('saved_workflow_input::node-1::a');
+    expect(dynamicFields[1]?.fieldName).toBe('saved_workflow_input::node-1::b');
+    expect(dynamicFields[0]?.fieldTemplate.title).toBe('Left Addend');
+    expect(dynamicFields[1]?.fieldTemplate.title).toBe('B');
+    expect(dynamicFields[0]?.fieldTemplate.name).toBe(dynamicFields[0]?.fieldName);
+    expect(dynamicFields[1]?.fieldTemplate.ui_order).toBe(1);
+    expect(dynamicFields[0]?.initialValue).toBe(1);
+    expect(dynamicFields[1]?.initialValue).toBe(2);
+  });
+
+  it('preserves inbound edges when the same exposed dynamic field remains compatible', () => {
+    const sourceNode = buildNode(addTemplate);
+    const targetNode = buildNode(call_saved_workflow);
+    const fields = getSavedWorkflowDynamicFields(
+      buildWorkflowResponse({
+        exposedFields: [{ nodeId: 'node-1', fieldName: 'a' }],
+      }),
+      templates
+    );
+
+    const edgeIdsToRemove = getSavedWorkflowDynamicEdgeIdsToRemove({
+      nodeId: targetNode.id,
+      fields,
+      nodes: [sourceNode, targetNode],
+      edges: [buildEdge(sourceNode.id, 'value', targetNode.id, 'saved_workflow_input::node-1::a')],
+      templates,
+    });
+
+    expect(edgeIdsToRemove).toEqual([]);
+  });
+
+  it('removes inbound edges when a previously connected field is no longer exposed', () => {
+    const sourceNode = buildNode(addTemplate);
+    const targetNode = buildNode(call_saved_workflow);
+    const edge = buildEdge(sourceNode.id, 'value', targetNode.id, 'saved_workflow_input::node-1::a');
+
+    const edgeIdsToRemove = getSavedWorkflowDynamicEdgeIdsToRemove({
+      nodeId: targetNode.id,
+      fields: [],
+      nodes: [sourceNode, targetNode],
+      edges: [edge],
+      templates,
+    });
+
+    expect(edgeIdsToRemove).toEqual([edge.id]);
+  });
+
it('removes inbound edges when an exposed field changes to an incompatible type', () => { + const sourceNode = buildNode(addTemplate); + const targetNode = buildNode(call_saved_workflow); + const edge = buildEdge(sourceNode.id, 'value', targetNode.id, 'saved_workflow_input::node-1::a'); + + const booleanTemplate: BooleanFieldInputTemplate = { + name: 'saved_workflow_input::node-1::a', + title: 'Enabled', + required: false, + description: 'Whether the workflow is enabled', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + default: false, + type: { + name: 'BooleanField', + cardinality: 'SINGLE', + batch: false, + }, + }; + + const edgeIdsToRemove = getSavedWorkflowDynamicEdgeIdsToRemove({ + nodeId: targetNode.id, + fields: [ + { + fieldName: 'saved_workflow_input::node-1::a', + fieldTemplate: booleanTemplate, + label: 'Enabled', + description: 'Whether the workflow is enabled', + initialValue: false, + settings: undefined, + }, + ], + nodes: [sourceNode, targetNode], + edges: [edge], + templates, + }); + + expect(edgeIdsToRemove).toEqual([edge.id]); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/callSavedWorkflowFormUtils.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/callSavedWorkflowFormUtils.ts new file mode 100644 index 00000000000..fbc600f85c2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/callSavedWorkflowFormUtils.ts @@ -0,0 +1,307 @@ +import { addElement, getIsFormEmpty } from 'features/nodes/components/sidePanel/builder/form-manipulation'; +import { CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX } from 'features/nodes/store/nodesSlice'; +import type { Templates } from 'features/nodes/store/types'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; +import type { FieldInputTemplate, StatefulFieldValue } from 'features/nodes/types/field'; +import { isStatefulFieldType } from 
'features/nodes/types/field';
+import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
+import { isInvocationNode } from 'features/nodes/types/invocation';
+import {
+  type BuilderForm,
+  buildNodeFieldElement,
+  type FormElement,
+  getDefaultForm,
+  isContainerElement,
+  type NodeFieldElement,
+  validateFormStructure,
+} from 'features/nodes/types/workflow';
+import type { paths } from 'services/api/schema';
+
+type WorkflowResponse =
+  paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'];
+
+type WorkflowNodeLike = {
+  data?: {
+    id?: string;
+    type?: string;
+    inputs?: Record<string, { label?: string; description?: string; value?: StatefulFieldValue }>;
+  };
+};
+
+type ExposedFieldLike = {
+  nodeId?: string;
+  fieldName?: string;
+};
+
+type SavedWorkflowFormFieldData = {
+  label: string;
+  description: string;
+  typeName: string | null;
+  isMissing: boolean;
+};
+
+type SavedWorkflowDynamicField = {
+  fieldName: string;
+  fieldTemplate: FieldInputTemplate;
+  label: string;
+  description: string;
+  initialValue: StatefulFieldValue;
+  settings: NodeFieldElement['data']['settings'];
+};
+
+const getStoredForm = (workflow: WorkflowResponse | undefined): BuilderForm | null => {
+  const form = workflow?.workflow.form;
+
+  if (!form || typeof form !== 'object' || !('elements' in form) || !('rootElementId' in form)) {
+    return null;
+  }
+
+  return form as unknown as BuilderForm;
+};
+
+const getWorkflowNodes = (workflow: WorkflowResponse | undefined): WorkflowNodeLike[] => {
+  return (workflow?.workflow.nodes ?? []) as WorkflowNodeLike[];
+};
+
+const buildFormFromExposedFields = (
+  workflow: WorkflowResponse | undefined,
+  templates: Templates
+): BuilderForm | null => {
+  const exposedFields = (workflow?.workflow.exposedFields ??
[]) as ExposedFieldLike[]; + + if (exposedFields.length === 0) { + return null; + } + + const nodes = getWorkflowNodes(workflow); + const form = getDefaultForm(); + + for (const { nodeId, fieldName } of [...exposedFields].reverse()) { + if (!nodeId || !fieldName) { + continue; + } + + const node = nodes.find((candidate) => candidate.data?.id === nodeId); + const nodeType = node?.data?.type; + if (!nodeType) { + continue; + } + + const fieldTemplate = templates[nodeType]?.inputs[fieldName]; + if (!fieldTemplate) { + continue; + } + + const element = buildNodeFieldElement(nodeId, fieldName, fieldTemplate.type); + element.data.showDescription = false; + addElement({ + form, + element, + parentId: form.rootElementId, + index: 0, + }); + } + + return form; +}; + +export const getRenderableWorkflowForm = ( + workflow: WorkflowResponse | undefined, + templates: Templates +): BuilderForm | null => { + const storedForm = getStoredForm(workflow); + + if (storedForm && validateFormStructure(storedForm) && !getIsFormEmpty(storedForm)) { + return storedForm; + } + + const fallbackForm = buildFormFromExposedFields(workflow, templates); + if (fallbackForm && !getIsFormEmpty(fallbackForm)) { + return fallbackForm; + } + + if (storedForm && validateFormStructure(storedForm)) { + return storedForm; + } + + return null; +}; + +export const getSavedWorkflowFormFieldData = ( + workflow: WorkflowResponse | undefined, + templates: Templates, + element: NodeFieldElement +): SavedWorkflowFormFieldData => { + const { nodeId, fieldName } = element.data.fieldIdentifier; + const node = getWorkflowNodes(workflow).find((candidate) => candidate.data?.id === nodeId); + const nodeType = node?.data?.type; + const field = node?.data?.inputs?.[fieldName]; + const fieldTemplate = nodeType ? templates[nodeType]?.inputs[fieldName] : undefined; + + return { + label: field?.label || fieldTemplate?.title || fieldName, + description: fieldTemplate?.description || '', + typeName: fieldTemplate?.type.name ?? 
null, + isMissing: !node || !field || !fieldTemplate, + }; +}; + +const getElementsInOrder = (form: BuilderForm): FormElement[] => { + const orderedElements: FormElement[] = []; + + const visit = (elementId: string) => { + const element = form.elements[elementId]; + if (!element) { + return; + } + + orderedElements.push(element); + if (isContainerElement(element)) { + for (const childId of element.data.children) { + visit(childId); + } + } + }; + + visit(form.rootElementId); + return orderedElements; +}; + +const buildDynamicFieldName = (nodeId: string, fieldName: string): string => { + return `${CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX}${nodeId}::${fieldName}`; +}; + +const cloneDynamicFieldTemplate = ({ + fieldName, + fieldTemplate, + label, + description, + uiOrder, +}: { + fieldName: string; + fieldTemplate: FieldInputTemplate; + label: string; + description: string; + uiOrder: number; +}): FieldInputTemplate => { + return { + ...fieldTemplate, + name: fieldName, + title: label, + description, + ui_order: uiOrder, + input: 'any', + ui_hidden: false, + } as FieldInputTemplate; +}; + +export const getSavedWorkflowDynamicFields = ( + workflow: WorkflowResponse | undefined, + templates: Templates +): SavedWorkflowDynamicField[] => { + const form = getRenderableWorkflowForm(workflow, templates); + if (!form) { + return []; + } + + const nodes = getWorkflowNodes(workflow); + const dynamicFields: SavedWorkflowDynamicField[] = []; + + for (const element of getElementsInOrder(form)) { + if (!('data' in element) || !('fieldIdentifier' in (element.data ?? {}))) { + continue; + } + if (!('type' in element) || element.type !== 'node-field') { + continue; + } + + const { nodeId, fieldName } = element.data.fieldIdentifier; + const node = nodes.find((candidate) => candidate.data?.id === nodeId); + const nodeType = node?.data?.type; + const field = node?.data?.inputs?.[fieldName]; + const fieldTemplate = nodeType ? 
templates[nodeType]?.inputs[fieldName] : undefined; + + if (!field || !fieldTemplate || !isStatefulFieldType(fieldTemplate.type)) { + continue; + } + + const dynamicFieldName = buildDynamicFieldName(nodeId, fieldName); + const label = field.label || fieldTemplate.title || fieldName; + const description = field.description || fieldTemplate.description || ''; + + dynamicFields.push({ + fieldName: dynamicFieldName, + fieldTemplate: cloneDynamicFieldTemplate({ + fieldName: dynamicFieldName, + fieldTemplate, + label, + description, + uiOrder: dynamicFields.length, + }), + label, + description, + initialValue: field.value, + settings: element.data.settings, + }); + } + + return dynamicFields; +}; + +export const shouldSyncSavedWorkflowDynamicFields = ({ + workflowId, + workflow, +}: { + workflowId: string | null | undefined; + workflow: WorkflowResponse | undefined; +}): boolean => { + return !workflowId || Boolean(workflow); +}; + +export const getSavedWorkflowDynamicEdgeIdsToRemove = ({ + nodeId, + fields, + nodes, + edges, + templates, +}: { + nodeId: string; + fields: SavedWorkflowDynamicField[]; + nodes: AnyNode[]; + edges: AnyEdge[]; + templates: Templates; +}): string[] => { + const nextFieldTemplates = new Map(fields.map((field) => [field.fieldName, field.fieldTemplate])); + + return edges.flatMap((edge) => { + if (edge.type !== 'default' || edge.target !== nodeId || !edge.targetHandle) { + return []; + } + + if (!edge.targetHandle.startsWith(CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX)) { + return []; + } + + const targetFieldTemplate = nextFieldTemplates.get(edge.targetHandle); + if (!targetFieldTemplate || targetFieldTemplate.input === 'direct' || !edge.sourceHandle) { + return [edge.id]; + } + + const sourceNode = nodes.find((node) => node.id === edge.source); + if (!sourceNode || !isInvocationNode(sourceNode)) { + return [edge.id]; + } + + const sourceTemplate = templates[sourceNode.data.type]; + const sourceFieldTemplate = 
sourceTemplate?.outputs[edge.sourceHandle]; + if (!sourceFieldTemplate) { + return [edge.id]; + } + + if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { + return [edge.id]; + } + + return []; + }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/context.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/context.tsx index 8618511d77d..af0a9fe4e95 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/context.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/context.tsx @@ -4,7 +4,12 @@ import { createSelector } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import { $templates } from 'features/nodes/store/nodesSlice'; import { selectEdges, selectNodeFieldElements, selectNodes } from 'features/nodes/store/selectors'; -import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation'; +import { + getInvocationNodeInputTemplate, + getInvocationNodeTemplateWithDynamicInputs, + type InvocationNode, + type InvocationTemplate, +} from 'features/nodes/types/invocation'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import type { PropsWithChildren } from 'react'; import { createContext, memo, useContext, useMemo } from 'react'; @@ -81,8 +86,12 @@ export const InvocationNodeContextProvider = memo(({ nodeId, children }: PropsWi }) ); const selectNodeTemplateSafe = getSelectorFromCache(cache, 'selectNodeTemplateSafe', () => - createSelector(selectNodeTypeSafe, (type) => { - return type ? (templates[type] ?? null) : null; + createSelector(selectNodeDataSafe, (data) => { + if (!data) { + return null; + } + const template = templates[data.type]; + return template ? 
getInvocationNodeTemplateWithDynamicInputs(data, template) : null; }) ); const selectNodeInputsSafe = getSelectorFromCache(cache, 'selectNodeInputsSafe', () => @@ -129,12 +138,12 @@ export const InvocationNodeContextProvider = memo(({ nodeId, children }: PropsWi }) ); const selectNodeTemplateOrThrow = getSelectorFromCache(cache, 'selectNodeTemplateOrThrow', () => - createSelector(selectNodeTypeOrThrow, (type) => { - const template = templates[type]; + createSelector(selectNodeDataOrThrow, (data) => { + const template = templates[data.type]; if (template === undefined) { - throw new Error(`Cannot find template for node with id ${nodeId} with type ${type}`); + throw new Error(`Cannot find template for node with id ${nodeId} with type ${data.type}`); } - return template; + return getInvocationNodeTemplateWithDynamicInputs(data, template); }) ); const selectNodeInputsOrThrow = getSelectorFromCache(cache, 'selectNodeInputsOrThrow', () => @@ -154,8 +163,8 @@ export const InvocationNodeContextProvider = memo(({ nodeId, children }: PropsWi ); const buildSelectInputFieldTemplateOrThrow = (fieldName: string) => getSelectorFromCache(cache, `buildSelectInputFieldTemplateOrThrow-${fieldName}`, () => - createSelector(selectNodeTemplateOrThrow, (template) => { - const fieldTemplate = template.inputs[fieldName]; + createSelector(selectNodeDataOrThrow, selectNodeTemplateOrThrow, (data, template) => { + const fieldTemplate = getInvocationNodeInputTemplate(data, template, fieldName); if (fieldTemplate === undefined) { throw new Error(`Cannot find input field template with name ${fieldName} in node ${nodeId}`); } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx index 1597f4ede16..60ce93fa619 100644 --- 
a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx @@ -9,6 +9,7 @@ import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInva import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow'; import { NO_DRAG_CLASS } from 'features/nodes/types/constants'; import type { FieldInputTemplate } from 'features/nodes/types/field'; +import type { NodeFieldElement } from 'features/nodes/types/workflow'; import { memo, useRef } from 'react'; import { InputFieldAddRemoveFormRoot } from './InputFieldAddRemoveFormRoot'; @@ -19,9 +20,10 @@ import { InputFieldWrapper } from './InputFieldWrapper'; interface Props { nodeId: string; fieldName: string; + settings?: NodeFieldElement['data']['settings']; } -export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => { +export const InputFieldEditModeNodes = memo(({ nodeId, fieldName, settings }: Props) => { const fieldTemplate = useInputFieldTemplateOrThrow(fieldName); const isInvalid = useInputFieldIsInvalid(fieldName); const isConnected = useInputFieldIsConnected(fieldName); @@ -45,6 +47,7 @@ export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => { isInvalid={isInvalid} isConnected={isConnected} fieldTemplate={fieldTemplate} + settings={settings} /> ); }); @@ -57,6 +60,7 @@ type CommonProps = { isInvalid: boolean; isConnected: boolean; fieldTemplate: FieldInputTemplate; + settings?: NodeFieldElement['data']['settings']; }; const ConnectedOrConnectionField = memo(({ nodeId, fieldName, isInvalid }: CommonProps) => { @@ -96,7 +100,7 @@ const directFieldSx: SystemStyleObject = { }, }; -const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemplate }: CommonProps) => { +const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, 
fieldTemplate, settings }: CommonProps) => {
   const draggableRef = useRef<HTMLDivElement>(null);
   const dragHandleRef = useRef<HTMLDivElement>(null);
@@ -116,7 +120,7 @@ const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemp
-      <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
+      <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} settings={settings} />
       {fieldTemplate.input !== 'direct' && <InputFieldAddRemoveFormRoot nodeId={nodeId} fieldName={fieldName} />}
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
index 60a3f8e472a..4fa020212ee 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
@@ -8,6 +8,7 @@ import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flo
 import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
 import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
 import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
+import SavedWorkflowFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/SavedWorkflowFieldInputComponent';
 import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
 import { StringGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent';
 import { IntegerFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/IntegerField/IntegerFieldInput';
@@ -47,6 +48,8 @@ import {
   isIntegerGeneratorFieldInputTemplate,
   isModelIdentifierFieldInputInstance,
   isModelIdentifierFieldInputTemplate,
+
isSavedWorkflowFieldInputInstance,
+  isSavedWorkflowFieldInputTemplate,
   isSchedulerFieldInputInstance,
   isSchedulerFieldInputTemplate,
   isStringFieldCollectionInputInstance,
@@ -223,6 +226,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
     return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
   }
+  if (isSavedWorkflowFieldInputTemplate(template)) {
+    if (!isSavedWorkflowFieldInputInstance(field)) {
+      return null;
+    }
+    return <SavedWorkflowFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
+  }
+
   if (isColorFieldInputTemplate(template)) {
     if (!isColorFieldInputInstance(field)) {
       return null;
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SavedWorkflowFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SavedWorkflowFieldInputComponent.tsx
new file mode 100644
index 00000000000..538eb4aac0c
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SavedWorkflowFieldInputComponent.tsx
@@ -0,0 +1,175 @@
+import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
+import { Badge, Combobox, Flex, Text } from '@invoke-ai/ui-library';
+import { EMPTY_ARRAY } from 'app/store/constants';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { fieldStringValueChanged } from 'features/nodes/store/nodesSlice';
+import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
+import type { SavedWorkflowFieldInputInstance, SavedWorkflowFieldInputTemplate } from 'features/nodes/types/field';
+import { memo, useCallback, useDeferredValue, useMemo, useState } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useGetWorkflowQuery, useListWorkflowsInfiniteInfiniteQuery } from 'services/api/endpoints/workflows';
+
+import {
+  buildSavedWorkflowOptions,
+  getSavedWorkflowDisplayState,
+  getSavedWorkflowListItemFromRecord,
+  getSavedWorkflowPickerOwnedQueryArg,
+  getSavedWorkflowPickerSharedQueryArg,
+  getSavedWorkflowSelectionOption,
+  getSavedWorkflowSelectionState,
+  mergeSavedWorkflowPickerItems,
+  MISSING_WORKFLOW_OPTION_VALUE,
+  shouldFetchNextSavedWorkflowPickerPage,
+} from './savedWorkflowFieldUtils';
+import type { FieldComponentProps } from './types';
+
+const queryOptions = {
+  selectFromResult: ({ data, ...rest }) => ({
+    items: data?.pages.flatMap(({ items }) => items) ?? EMPTY_ARRAY,
+    ...rest,
+  }),
+} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[1];
+
+const SavedWorkflowFieldInputComponent = (
+  props: FieldComponentProps<SavedWorkflowFieldInputInstance, SavedWorkflowFieldInputTemplate>
+) => {
+  const { nodeId, field } = props;
+  const dispatch = useAppDispatch();
+  const { t } = useTranslation();
+  const [workflowSearchQuery, setWorkflowSearchQuery] = useState('');
+  const deferredWorkflowSearchQuery = useDeferredValue(workflowSearchQuery);
+  const ownedQueryArg = useMemo(
+    () =>
+      getSavedWorkflowPickerOwnedQueryArg(deferredWorkflowSearchQuery) satisfies Parameters<
+        typeof useListWorkflowsInfiniteInfiniteQuery
+      >[0],
+    [deferredWorkflowSearchQuery]
+  );
+  const sharedQueryArg = useMemo(
+    () =>
+      getSavedWorkflowPickerSharedQueryArg(deferredWorkflowSearchQuery) satisfies Parameters<
+        typeof useListWorkflowsInfiniteInfiniteQuery
+      >[0],
+    [deferredWorkflowSearchQuery]
+  );
+  const {
+    items: ownedItems,
+    isLoading: isOwnedLoading,
+    isFetching: isOwnedFetching,
+    hasNextPage: hasNextOwnedPage,
+    fetchNextPage: fetchNextOwnedPage,
+  } = useListWorkflowsInfiniteInfiniteQuery(ownedQueryArg, queryOptions);
+  const {
+    items: sharedItems,
+    isLoading: isSharedLoading,
+    isFetching: isSharedFetching,
+    hasNextPage: hasNextSharedPage,
+    fetchNextPage: fetchNextSharedPage,
+  } = useListWorkflowsInfiniteInfiniteQuery(sharedQueryArg, queryOptions);
+  const items = useMemo(() => mergeSavedWorkflowPickerItems(ownedItems, sharedItems), [ownedItems, sharedItems]);
+  const isSelectedWorkflowInList = useMemo(
+    () => items.some((workflow) => workflow.workflow_id === field.value),
+    [field.value, items]
+  );
+  const { data: selectedWorkflowRecord } =
useGetWorkflowQuery(field.value, {
+    skip: !field.value || isSelectedWorkflowInList,
+  });
+  const selectedWorkflow = useMemo(
+    () => (selectedWorkflowRecord ? getSavedWorkflowListItemFromRecord(selectedWorkflowRecord) : undefined),
+    [selectedWorkflowRecord]
+  );
+
+  const options = useMemo(() => buildSavedWorkflowOptions(items), [items]);
+  const selectionState = useMemo(
+    () => getSavedWorkflowSelectionState(items, field.value, selectedWorkflow),
+    [field.value, items, selectedWorkflow]
+  );
+  const value = useMemo(() => {
+    const option = getSavedWorkflowSelectionOption(selectionState);
+    if (option?.value === MISSING_WORKFLOW_OPTION_VALUE) {
+      return {
+        ...option,
+        label: t('nodes.savedWorkflowMissing'),
+      };
+    }
+    return option;
+  }, [selectionState, t]);
+  const statusLabel = useMemo(() => {
+    const displayState = getSavedWorkflowDisplayState(selectionState);
+    return displayState.statusLabelKey ? t(displayState.statusLabelKey) : null;
+  }, [selectionState, t]);
+  const displayState = useMemo(() => getSavedWorkflowDisplayState(selectionState), [selectionState]);
+
+  const onChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      dispatch(
+        fieldStringValueChanged({
+          nodeId,
+          fieldName: field.name,
+          value: v?.value ??
'',
+        })
+      );
+    },
+    [dispatch, field.name, nodeId]
+  );
+  const onMenuScrollToBottom = useCallback(() => {
+    if (shouldFetchNextSavedWorkflowPickerPage({ hasNextPage: hasNextOwnedPage, isFetching: isOwnedFetching })) {
+      fetchNextOwnedPage();
+    }
+    if (shouldFetchNextSavedWorkflowPickerPage({ hasNextPage: hasNextSharedPage, isFetching: isSharedFetching })) {
+      fetchNextSharedPage();
+    }
+  }, [fetchNextOwnedPage, fetchNextSharedPage, hasNextOwnedPage, hasNextSharedPage, isOwnedFetching, isSharedFetching]);
+  const onInputChange = useCallback((inputValue: string) => {
+    setWorkflowSearchQuery(inputValue);
+    return inputValue;
+  }, []);
+
+  const noOptionsMessage = useCallback(() => t('nodes.noMatchingWorkflows'), [t]);
+  const isLoading = isOwnedLoading || isSharedLoading;
+  const isFetching = isOwnedFetching || isSharedFetching;
+
+  return (
+    <Flex flexDir="column" gap={1} className={NO_DRAG_CLASS}>
+      <Combobox
+        className={NO_WHEEL_CLASS}
+        value={value}
+        options={options}
+        isLoading={isLoading}
+        onChange={onChange}
+        onInputChange={onInputChange}
+        onMenuScrollToBottom={onMenuScrollToBottom}
+        noOptionsMessage={noOptionsMessage}
+      />
+      {selectionState.status === 'selected' ? (
+        <Flex alignItems="center" gap={2}>
+          <Text noOfLines={1}>
+            {selectionState.workflow.name}
+          </Text>
+          {displayState.badges.includes('unsupported') && (
+            <Badge>{t('nodes.savedWorkflowUnsupported')}</Badge>
+          )}
+          {displayState.badges.includes('default') && (
+            <Badge>{t('nodes.savedWorkflowDefaultBadge')}</Badge>
+          )}
+          {displayState.badges.includes('shared') && <Badge>{t('workflows.shared')}</Badge>}
+        </Flex>
+      ) : (
+        <Text>{statusLabel ?? t('nodes.savedWorkflowChoose')}</Text>
+      )}
+      {selectionState.status === 'selected' && displayState.compatibilityMessage && (
+        <Text>
+          {displayState.compatibilityMessage}
+        </Text>
+      )}
+      {isFetching && (
+        <Text>
+          {t('nodes.savedWorkflowUpdating')}
+        </Text>
+      )}
+    </Flex>
+  );
+};
+
+export default memo(SavedWorkflowFieldInputComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/savedWorkflowFieldUtils.test.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/savedWorkflowFieldUtils.test.ts
new file mode 100644
index 00000000000..3439cbf3747
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/savedWorkflowFieldUtils.test.ts
@@ -0,0 +1,203 @@
+import type { WorkflowRecordListItemWithThumbnailDTO } from 'services/api/types';
+import { describe, expect, it } from 'vitest';
+
+import {
+  buildSavedWorkflowOptions,
+  getSavedWorkflowDisplayState,
+  getSavedWorkflowListItemFromRecord,
+  getSavedWorkflowPickerOwnedQueryArg,
+  getSavedWorkflowPickerSharedQueryArg,
+  getSavedWorkflowSelectionOption,
+  getSavedWorkflowSelectionState,
+  mergeSavedWorkflowPickerItems,
+  MISSING_WORKFLOW_OPTION_VALUE,
+  shouldFetchNextSavedWorkflowPickerPage,
+} from './savedWorkflowFieldUtils';
+
+const workflows: WorkflowRecordListItemWithThumbnailDTO[] = [
+  {
+    workflow_id: 'workflow-a',
+    name: 'Alpha Workflow',
+    created_at: '',
+    updated_at: '',
+    opened_at: null,
+    description: '',
+    tags: '',
+    is_public: false,
+    thumbnail_url: null,
+    category: 'user',
+    user_id: 'user-a',
+    call_saved_workflow_compatibility: {
+      is_callable: true,
+      reason: 'ok',
+      message: null,
+    },
+  },
+  {
+    workflow_id: 'workflow-b',
+    name: 'Beta Workflow',
+    created_at: '',
+    updated_at: '',
+    opened_at: null,
+    description: '',
+    tags: '',
+    is_public: true,
+    thumbnail_url: null,
+    category: 'default',
+    user_id: 'system',
+    call_saved_workflow_compatibility: {
+      is_callable: false,
+      reason:
'missing_workflow_return', + message: 'The workflow must contain exactly one workflow_return node.', + }, + }, +]; + +describe('savedWorkflowFieldUtils', () => { + it('builds combobox options from visible workflows', () => { + expect(buildSavedWorkflowOptions(workflows)).toEqual([ + { label: 'Alpha Workflow', value: 'workflow-a', isDisabled: false }, + { label: 'Beta Workflow', value: 'workflow-b', isDisabled: true }, + ]); + }); + + it('returns an unselected state for the default empty value', () => { + const selectionState = getSavedWorkflowSelectionState(workflows, ''); + expect(selectionState).toEqual({ status: 'unselected' }); + expect(getSavedWorkflowSelectionOption(selectionState)).toBeNull(); + }); + + it('returns a selected state for a valid workflow id', () => { + const selectionState = getSavedWorkflowSelectionState(workflows, 'workflow-b'); + expect(selectionState).toEqual({ status: 'selected', workflow: workflows[1] }); + expect(getSavedWorkflowSelectionOption(selectionState)).toEqual({ + label: 'Beta Workflow', + value: 'workflow-b', + }); + }); + + it('returns a missing state for a stale or inaccessible workflow id', () => { + const selectionState = getSavedWorkflowSelectionState(workflows, 'missing-workflow'); + expect(selectionState).toEqual({ status: 'missing', workflowId: 'missing-workflow' }); + expect(getSavedWorkflowSelectionOption(selectionState)).toEqual({ + label: MISSING_WORKFLOW_OPTION_VALUE, + value: MISSING_WORKFLOW_OPTION_VALUE, + }); + }); + + it('builds display state for an already-selected unsupported workflow', () => { + const selectionState = getSavedWorkflowSelectionState(workflows, 'workflow-b'); + + expect(getSavedWorkflowDisplayState(selectionState)).toEqual({ + selection: 'selected', + statusLabelKey: null, + badges: ['unsupported', 'default'], + compatibilityMessage: 'The workflow must contain exactly one workflow_return node.', + }); + }); + + it('builds display state for missing and unselected workflows', () => { + 
expect(getSavedWorkflowDisplayState({ status: 'unselected' })).toEqual({ + selection: 'unselected', + statusLabelKey: 'nodes.savedWorkflowChoose', + badges: [], + compatibilityMessage: null, + }); + + expect(getSavedWorkflowDisplayState({ status: 'missing', workflowId: 'missing-workflow' })).toEqual({ + selection: 'missing', + statusLabelKey: 'nodes.savedWorkflowMissing', + badges: [], + compatibilityMessage: null, + }); + }); + + it('uses a fetched selected workflow when it is not present in the first list page', () => { + const selectedWorkflow = getSavedWorkflowListItemFromRecord({ + workflow_id: 'workflow-z', + name: 'Zeta Workflow', + created_at: '', + updated_at: '', + opened_at: null, + user_id: 'user-z', + is_public: true, + thumbnail_url: null, + call_saved_workflow_compatibility: { + is_callable: true, + reason: 'ok', + message: null, + }, + workflow: { + id: 'workflow-z', + name: 'Zeta Workflow', + author: '', + description: 'A workflow outside the first page', + version: '', + contact: '', + tags: 'paged', + notes: '', + exposedFields: [], + meta: { + category: 'user', + version: '1.0.0', + }, + nodes: [], + edges: [], + form: null, + }, + }); + + const selectionState = getSavedWorkflowSelectionState(workflows, 'workflow-z', selectedWorkflow); + + expect(selectionState).toEqual({ status: 'selected', workflow: selectedWorkflow }); + expect(getSavedWorkflowDisplayState(selectionState)).toEqual({ + selection: 'selected', + statusLabelKey: null, + badges: ['shared'], + compatibilityMessage: null, + }); + }); + + it('queries owned/default workflows and shared public workflows separately', () => { + expect(getSavedWorkflowPickerOwnedQueryArg('landscape')).toMatchObject({ + page: 0, + per_page: 50, + query: 'landscape', + categories: ['user', 'default'], + is_public: undefined, + }); + expect(getSavedWorkflowPickerSharedQueryArg('landscape')).toMatchObject({ + page: 0, + per_page: 50, + query: 'landscape', + categories: ['user'], + is_public: true, + }); + }); 
+ + it('merges paged owned and shared workflow picker results without duplicates', () => { + const ownedWorkflow = workflows[0]; + const defaultWorkflow = workflows[1]; + if (!ownedWorkflow || !defaultWorkflow) { + throw new Error('Expected workflow fixtures'); + } + const sharedWorkflow = { + ...ownedWorkflow, + workflow_id: 'workflow-shared', + name: 'Shared Workflow', + is_public: true, + }; + + expect(mergeSavedWorkflowPickerItems([ownedWorkflow], [defaultWorkflow, sharedWorkflow], [ownedWorkflow])).toEqual([ + ownedWorkflow, + defaultWorkflow, + sharedWorkflow, + ]); + }); + + it('fetches more workflow picker pages only when a query has another idle page', () => { + expect(shouldFetchNextSavedWorkflowPickerPage({ hasNextPage: true, isFetching: false })).toBe(true); + expect(shouldFetchNextSavedWorkflowPickerPage({ hasNextPage: false, isFetching: false })).toBe(false); + expect(shouldFetchNextSavedWorkflowPickerPage({ hasNextPage: true, isFetching: true })).toBe(false); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/savedWorkflowFieldUtils.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/savedWorkflowFieldUtils.ts new file mode 100644 index 00000000000..ca9ac0271a5 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/savedWorkflowFieldUtils.ts @@ -0,0 +1,175 @@ +import type { ComboboxOption } from '@invoke-ai/ui-library'; +import { getWorkflowCallCompatibilityState } from 'features/workflowLibrary/util/workflowCallCompatibility'; +import type { S, WorkflowRecordListItemWithThumbnailDTO } from 'services/api/types'; + +export const MISSING_WORKFLOW_OPTION_VALUE = '__missing_workflow__'; +const SAVED_WORKFLOW_PICKER_PAGE_SIZE = 50; +type SavedWorkflowBadge = 'unsupported' | 'default' | 'shared'; + +type SavedWorkflowSelectionState = + | { status: 'unselected' } + | { status: 'selected'; workflow: 
WorkflowRecordListItemWithThumbnailDTO } + | { status: 'missing'; workflowId: string }; + +export const buildSavedWorkflowOptions = (workflows: WorkflowRecordListItemWithThumbnailDTO[]): ComboboxOption[] => { + return workflows.map((workflow) => ({ + label: workflow.name, + value: workflow.workflow_id, + isDisabled: workflow.call_saved_workflow_compatibility?.is_callable === false, + })); +}; + +const baseSavedWorkflowPickerQueryArg = { + page: 0, + per_page: SAVED_WORKFLOW_PICKER_PAGE_SIZE, + order_by: 'name', + direction: 'ASC', + tags: [] as string[], + has_been_opened: undefined, +} as const; + +export const getSavedWorkflowPickerOwnedQueryArg = (query = '') => ({ + ...baseSavedWorkflowPickerQueryArg, + query, + categories: ['user', 'default'] as ('user' | 'default')[], + is_public: undefined, +}); + +export const getSavedWorkflowPickerSharedQueryArg = (query = '') => ({ + ...baseSavedWorkflowPickerQueryArg, + query, + categories: ['user'] as ('user' | 'default')[], + is_public: true, +}); + +export const mergeSavedWorkflowPickerItems = ( + ...workflowLists: WorkflowRecordListItemWithThumbnailDTO[][] +): WorkflowRecordListItemWithThumbnailDTO[] => { + const workflowsById = new Map(); + for (const workflows of workflowLists) { + for (const workflow of workflows) { + if (!workflowsById.has(workflow.workflow_id)) { + workflowsById.set(workflow.workflow_id, workflow); + } + } + } + return Array.from(workflowsById.values()); +}; + +export const shouldFetchNextSavedWorkflowPickerPage = ({ + hasNextPage, + isFetching, +}: { + hasNextPage: boolean; + isFetching: boolean; +}) => hasNextPage && !isFetching; + +export const getSavedWorkflowSelectionState = ( + workflows: WorkflowRecordListItemWithThumbnailDTO[], + workflowId: string, + selectedWorkflow?: WorkflowRecordListItemWithThumbnailDTO +): SavedWorkflowSelectionState => { + if (!workflowId) { + return { status: 'unselected' }; + } + + const workflow = workflows.find((workflow) => workflow.workflow_id === 
workflowId); + if (workflow) { + return { status: 'selected', workflow }; + } + + if (selectedWorkflow?.workflow_id === workflowId) { + return { status: 'selected', workflow: selectedWorkflow }; + } + + return { status: 'missing', workflowId }; +}; + +export const getSavedWorkflowListItemFromRecord = ( + workflow: S['WorkflowRecordWithThumbnailDTO'] +): WorkflowRecordListItemWithThumbnailDTO => ({ + workflow_id: workflow.workflow_id, + name: workflow.name, + created_at: workflow.created_at, + updated_at: workflow.updated_at, + opened_at: workflow.opened_at, + user_id: workflow.user_id, + is_public: workflow.is_public, + description: workflow.workflow.description, + category: workflow.workflow.meta.category, + tags: workflow.workflow.tags, + thumbnail_url: workflow.thumbnail_url, + call_saved_workflow_compatibility: workflow.call_saved_workflow_compatibility, +}); + +export const getSavedWorkflowSelectionOption = (selectionState: SavedWorkflowSelectionState): ComboboxOption | null => { + if (selectionState.status === 'unselected') { + return null; + } + + if (selectionState.status === 'selected') { + return { + label: selectionState.workflow.name, + value: selectionState.workflow.workflow_id, + }; + } + + return { + label: MISSING_WORKFLOW_OPTION_VALUE, + value: MISSING_WORKFLOW_OPTION_VALUE, + }; +}; + +type SavedWorkflowDisplayState = + | { + selection: 'unselected' | 'missing'; + statusLabelKey: 'nodes.savedWorkflowChoose' | 'nodes.savedWorkflowMissing'; + badges: SavedWorkflowBadge[]; + compatibilityMessage: null; + } + | { + selection: 'selected'; + statusLabelKey: null; + badges: SavedWorkflowBadge[]; + compatibilityMessage: string | null; + }; + +export const getSavedWorkflowDisplayState = ( + selectionState: SavedWorkflowSelectionState +): SavedWorkflowDisplayState => { + if (selectionState.status === 'unselected') { + return { + selection: 'unselected', + statusLabelKey: 'nodes.savedWorkflowChoose', + badges: [], + compatibilityMessage: null, + }; + } + + 
if (selectionState.status === 'missing') { + return { + selection: 'missing', + statusLabelKey: 'nodes.savedWorkflowMissing', + badges: [], + compatibilityMessage: null, + }; + } + + const compatibilityState = getWorkflowCallCompatibilityState(selectionState.workflow); + const badges: SavedWorkflowBadge[] = []; + if (compatibilityState.isUnsupported) { + badges.push('unsupported'); + } + if (selectionState.workflow.category === 'default') { + badges.push('default'); + } else if (selectionState.workflow.is_public) { + badges.push('shared'); + } + + return { + selection: 'selected', + statusLabelKey: null, + badges, + compatibilityMessage: compatibilityState.message, + }; +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx index 3291d75f59e..5c8dfd5dd45 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowListItem.tsx @@ -5,6 +5,7 @@ import { selectCurrentUser } from 'features/auth/store/authSlice'; import { selectWorkflowId } from 'features/nodes/store/selectors'; import { workflowModeChanged } from 'features/nodes/store/workflowLibrarySlice'; import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog'; +import { getWorkflowLibraryListItemState } from 'features/workflowLibrary/util/workflowLibraryListItemState'; import InvokeLogo from 'public/assets/images/invoke-symbol-wht-lrg.svg'; import { type ChangeEvent, memo, type MouseEvent, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -65,6 +66,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi .map((tag) => tag.trim()) .filter((tag) => tag.length 
> 0);
  }, [workflow.tags]);
+  const listItemState = useMemo(() => getWorkflowLibraryListItemState(workflow), [workflow]);
 
   const handleClickLoad = useCallback(() => {
     loadWorkflowWithDialog({
@@ -119,7 +121,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
           {t('workflows.opened')}
         )}
-        {setupStatus?.multiuser_enabled && workflow.is_public && workflow.category !== 'default' && (
+        {setupStatus?.multiuser_enabled && listItemState.showSharedBadge && (
           <Badge>{t('workflows.shared')}</Badge>
         )}
-        {workflow.category === 'default' && (
+        {listItemState.showUnsupportedBadge && (
+          <Badge>
+            {t('nodes.savedWorkflowUnsupported')}
+          </Badge>
+        )}
+        {listItemState.showDefaultIcon && (
           <Image src={InvokeLogo} alt="invoke-logo" />
         )}
         <Text>{workflow.description}</Text>
+        {listItemState.showUnsupportedBadge && (
+          <Text>
+            {listItemState.unsupportedMessage ?? t('workflows.savedWorkflowUnsupportedDescription')}
+          </Text>
+        )}
         {tags.length > 0 && (
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts
index f331692b43a..6c1d8ea9709 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts
@@ -1,10 +1,48 @@
 import { deepClone } from 'common/util/deepClone';
+import type { IntegerFieldInputTemplate, StringFieldInputTemplate } from 'features/nodes/types/field';
 import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode';
 import { describe, expect, it } from 'vitest';
 
-import { connectorInserted, nodesChanged, nodesSliceConfig } from './nodesSlice';
+import {
+  callSavedWorkflowDynamicFieldsChanged,
+  connectorInserted,
+  fieldIntegerValueChanged,
+  nodesChanged,
+  nodesSliceConfig,
+} from './nodesSlice';
 import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from './util/connectorTopology';
-import { add, buildEdge, buildNode, sub } from './util/testUtils';
+import { add, buildEdge, buildNode, sub, templates } from './util/testUtils';
+
+const
callSavedWorkflowTemplate = templates.call_saved_workflow; +const addTemplate = templates.add; + +if (!callSavedWorkflowTemplate || !addTemplate || !addTemplate.inputs.a) { + throw new Error('Expected saved workflow and add templates'); +} +const addIntegerInputTemplate = addTemplate.inputs.a as IntegerFieldInputTemplate; + +const buildDynamicIntegerTemplate = (fieldName: string): IntegerFieldInputTemplate => ({ + ...addIntegerInputTemplate, + name: fieldName, + title: 'Left Addend', + input: 'any' as const, +}); + +const buildDynamicStringTemplate = (fieldName: string): StringFieldInputTemplate => ({ + name: fieldName, + title: 'Prompt', + required: false, + description: 'Prompt text', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + default: 'new default', + type: { + name: 'StringField', + cardinality: 'SINGLE', + batch: false, + }, +}); const buildFixedConnectorNode = (id: string) => { const connectorNode = buildConnectorNode({ x: 0, y: 0 }); @@ -18,6 +56,226 @@ const buildFixedConnectorNode = (id: string) => { }; }; +describe('callSavedWorkflowDynamicFieldsChanged', () => { + it('seeds new dynamic fields with the source workflow values', () => { + const state = nodesSliceConfig.getInitialState(); + const node = buildNode(callSavedWorkflowTemplate); + state.nodes.push(node); + + const nextState = nodesSliceConfig.slice.reducer( + state, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: node.id, + fields: [ + { + fieldName: 'saved_workflow_input::node-1::a', + fieldTemplate: buildDynamicIntegerTemplate('saved_workflow_input::node-1::a'), + label: 'Left Addend', + description: 'The first number', + initialValue: 23, + }, + ], + edgeIdsToRemove: [], + }) + ); + + const dynamicField = nextState.nodes[0]; + if (!dynamicField || dynamicField.type !== 'invocation') { + throw new Error('Expected invocation node'); + } + + expect(dynamicField.data.inputs['saved_workflow_input::node-1::a']?.value).toBe(23); + 
expect(dynamicField.data.inputs['saved_workflow_input::node-1::a']?.label).toBe('Left Addend'); + expect(dynamicField.data.dynamicInputTemplates['saved_workflow_input::node-1::a']?.title).toBe('Left Addend'); + }); + + it('preserves existing dynamic field values on resync', () => { + const state = nodesSliceConfig.getInitialState(); + const node = buildNode(callSavedWorkflowTemplate); + state.nodes.push(node); + + const fieldName = 'saved_workflow_input::node-1::a'; + + let nextState = nodesSliceConfig.slice.reducer( + state, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: node.id, + fields: [ + { + fieldName, + fieldTemplate: buildDynamicIntegerTemplate(fieldName), + label: 'Left Addend', + description: 'The first number', + initialValue: 23, + }, + ], + edgeIdsToRemove: [], + }) + ); + + nextState = nodesSliceConfig.slice.reducer( + nextState, + fieldIntegerValueChanged({ + nodeId: node.id, + fieldName, + value: 99, + }) + ); + + nextState = nodesSliceConfig.slice.reducer( + nextState, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: node.id, + fields: [ + { + fieldName, + fieldTemplate: buildDynamicIntegerTemplate(fieldName), + label: 'Left Addend', + description: 'The first number', + initialValue: 23, + }, + ], + edgeIdsToRemove: [], + }) + ); + + const resyncedNode = nextState.nodes[0]; + if (!resyncedNode || resyncedNode.type !== 'invocation') { + throw new Error('Expected invocation node'); + } + + expect(resyncedNode.data.inputs[fieldName]?.value).toBe(99); + expect(resyncedNode.data.dynamicInputTemplates[fieldName]?.name).toBe(fieldName); + }); + + it('resets an existing dynamic field value when the exposed field type changes', () => { + const state = nodesSliceConfig.getInitialState(); + const node = buildNode(callSavedWorkflowTemplate); + state.nodes.push(node); + + const fieldName = 'saved_workflow_input::node-1::a'; + + let nextState = nodesSliceConfig.slice.reducer( + state, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: node.id, + 
fields: [ + { + fieldName, + fieldTemplate: buildDynamicIntegerTemplate(fieldName), + label: 'Left Addend', + description: 'The first number', + initialValue: 23, + }, + ], + edgeIdsToRemove: [], + }) + ); + + nextState = nodesSliceConfig.slice.reducer( + nextState, + fieldIntegerValueChanged({ + nodeId: node.id, + fieldName, + value: 99, + }) + ); + + nextState = nodesSliceConfig.slice.reducer( + nextState, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: node.id, + fields: [ + { + fieldName, + fieldTemplate: buildDynamicStringTemplate(fieldName), + label: 'Prompt', + description: 'Prompt text', + initialValue: 'new default', + }, + ], + edgeIdsToRemove: [], + }) + ); + + const resyncedNode = nextState.nodes[0]; + if (!resyncedNode || resyncedNode.type !== 'invocation') { + throw new Error('Expected invocation node'); + } + + expect(resyncedNode.data.inputs[fieldName]?.value).toBe('new default'); + expect(resyncedNode.data.dynamicInputTemplates[fieldName]?.type.name).toBe('StringField'); + }); + + it('removes stale dynamic field templates when the selected workflow fields change', () => { + const state = nodesSliceConfig.getInitialState(); + const node = buildNode(callSavedWorkflowTemplate); + state.nodes.push(node); + + const fieldName = 'saved_workflow_input::node-1::a'; + + let nextState = nodesSliceConfig.slice.reducer( + state, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: node.id, + fields: [ + { + fieldName, + fieldTemplate: buildDynamicIntegerTemplate(fieldName), + label: 'Left Addend', + description: 'The first number', + initialValue: 23, + }, + ], + edgeIdsToRemove: [], + }) + ); + + nextState = nodesSliceConfig.slice.reducer( + nextState, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: node.id, + fields: [], + edgeIdsToRemove: [], + }) + ); + + const resyncedNode = nextState.nodes[0]; + if (!resyncedNode || resyncedNode.type !== 'invocation') { + throw new Error('Expected invocation node'); + } + + 
expect(resyncedNode.data.inputs[fieldName]).toBeUndefined(); + expect(resyncedNode.data.dynamicInputTemplates[fieldName]).toBeUndefined(); + }); + + it('removes specified inbound edges during dynamic field resync', () => { + const state = nodesSliceConfig.getInitialState(); + const sourceNode = buildNode(addTemplate); + const targetNode = buildNode(callSavedWorkflowTemplate); + state.nodes.push(sourceNode, targetNode); + state.edges.push({ + id: 'edge-1', + type: 'default', + source: sourceNode.id, + sourceHandle: 'value', + target: targetNode.id, + targetHandle: 'saved_workflow_input::node-1::a', + }); + + const nextState = nodesSliceConfig.slice.reducer( + state, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: targetNode.id, + fields: [], + edgeIdsToRemove: ['edge-1'], + }) + ); + + expect(nextState.edges).toHaveLength(0); + }); +}); + describe('nodesSlice connector actions', () => { it('removes an unconnected connector', () => { const connector = buildFixedConnectorNode('connector-1'); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 6713ee8fb42..9c6a9eea9f0 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -92,6 +92,7 @@ import { isNodeFieldElement, isTextElement, } from 'features/nodes/types/workflow'; +import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance'; import { atom, computed } from 'nanostores'; import type { MouseEvent } from 'react'; import type { UndoableOptions } from 'redux-undo'; @@ -100,6 +101,8 @@ import type { z } from 'zod'; import type { PendingConnection, Templates } from './types'; +const CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX = 'saved_workflow_input::'; + export const getInitialWorkflow = (): Omit => { return { name: '', @@ -550,6 +553,69 @@ const slice = createSlice({ } field.description = val || ''; }, + 
callSavedWorkflowDynamicFieldsChanged: (
+      state,
+      action: PayloadAction<{
+        nodeId: string;
+        fields: Array<{
+          fieldName: string;
+          fieldTemplate: Parameters<typeof buildFieldInputInstance>[1];
+          label: string;
+          description: string;
+          initialValue: StatefulFieldValue;
+        }>;
+        edgeIdsToRemove: string[];
+      }>
+    ) => {
+      const { nodeId, fields, edgeIdsToRemove } = action.payload;
+      const node = state.nodes.find((n) => n.id === nodeId);
+      if (!isInvocationNode(node) || node.data.type !== 'call_saved_workflow') {
+        return;
+      }
+
+      const nextFieldNames = new Set(fields.map((field) => field.fieldName));
+
+      for (const fieldName of Object.keys(node.data.inputs)) {
+        if (fieldName.startsWith(CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX) && !nextFieldNames.has(fieldName)) {
+          delete node.data.inputs[fieldName];
+          delete node.data.dynamicInputTemplates[fieldName];
+        }
+      }
+
+      for (const { fieldName, fieldTemplate, label, description, initialValue } of fields) {
+        const existingTemplate = node.data.dynamicInputTemplates[fieldName];
+        node.data.dynamicInputTemplates[fieldName] = fieldTemplate;
+        const existing = node.data.inputs[fieldName];
+        if (existing) {
+          if (
+            existingTemplate?.type.name !== fieldTemplate.type.name ||
+            existingTemplate?.type.cardinality !== fieldTemplate.type.cardinality ||
+            existingTemplate?.type.batch !== fieldTemplate.type.batch
+          ) {
+            const instance = buildFieldInputInstance(fieldName, fieldTemplate);
+            instance.label = label;
+            instance.description = description;
+            instance.value = initialValue;
+            node.data.inputs[fieldName] = instance;
+            continue;
+          }
+          existing.label = label;
+          existing.description = description;
+          continue;
+        }
+
+        const instance = buildFieldInputInstance(fieldName, fieldTemplate);
+        instance.label = label;
+        instance.description = description;
+        instance.value = initialValue;
+        node.data.inputs[fieldName] = instance;
+      }
+
+      if (edgeIdsToRemove.length > 0) {
+        const edgeIdsToRemoveSet = new Set(edgeIdsToRemove);
+        state.edges = state.edges.filter((edge) =>
!edgeIdsToRemoveSet.has(edge.id)); + } + }, notesNodeValueChanged: (state, action: PayloadAction<{ nodeId: string; value: string }>) => { const { nodeId, value } = action.payload; const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); @@ -680,6 +746,7 @@ export const { fieldStringGeneratorValueChanged, fieldImageGeneratorValueChanged, fieldDescriptionChanged, + callSavedWorkflowDynamicFieldsChanged, nodeEditorReset, nodeIsIntermediateChanged, nodeIsOpenChanged, @@ -925,3 +992,5 @@ export const getFormFieldInitialValues = (form: BuilderForm, nodes: NodesState[' return formFieldInitialValues; }; + +export { CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.test.ts new file mode 100644 index 00000000000..dacc1434466 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.test.ts @@ -0,0 +1,53 @@ +import { callSavedWorkflowDynamicFieldsChanged, nodesSliceConfig } from 'features/nodes/store/nodesSlice'; +import { buildNode, templates } from 'features/nodes/store/util/testUtils'; +import type { IntegerFieldInputTemplate } from 'features/nodes/types/field'; +import { describe, expect, it } from 'vitest'; + +import { getInvocationNodeErrors } from './fieldValidators'; + +const callSavedWorkflowTemplate = templates.call_saved_workflow; +const addTemplate = templates.add; + +if (!callSavedWorkflowTemplate || !addTemplate || !addTemplate.inputs.a) { + throw new Error('Expected saved workflow and add templates'); +} + +const addIntegerInputTemplate = addTemplate.inputs.a as IntegerFieldInputTemplate; + +const buildDynamicIntegerTemplate = (fieldName: string): IntegerFieldInputTemplate => ({ + ...addIntegerInputTemplate, + name: fieldName, + title: 'Left Addend', + input: 'any', +}); + +describe('getInvocationNodeErrors', () => { + it('does not report missing field templates for dynamic 
saved workflow inputs', () => { + const state = nodesSliceConfig.getInitialState(); + const node = buildNode(callSavedWorkflowTemplate); + state.nodes.push(node); + + const nextState = nodesSliceConfig.slice.reducer( + state, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: node.id, + fields: [ + { + fieldName: 'saved_workflow_input::node-1::a', + fieldTemplate: buildDynamicIntegerTemplate('saved_workflow_input::node-1::a'), + label: 'Left Addend', + description: 'The first number', + initialValue: 23, + }, + ], + edgeIdsToRemove: [], + }) + ); + + const errors = getInvocationNodeErrors(node.id, templates, nextState); + + expect( + errors.find((error) => error.type === 'node-error' && error.issue === 'parameters.invoke.missingFieldTemplate') + ).toBeUndefined(); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts b/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts index 85738b357c9..b3ac3424fdf 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts @@ -36,7 +36,12 @@ import { isStringFieldCollectionInputInstance, isStringFieldCollectionInputTemplate, } from 'features/nodes/types/field'; -import { type InvocationNode, type InvocationTemplate, isInvocationNode } from 'features/nodes/types/invocation'; +import { + getInvocationNodeInputTemplate, + type InvocationNode, + type InvocationTemplate, + isInvocationNode, +} from 'features/nodes/types/invocation'; import { t } from 'i18next'; import { map } from 'nanostores'; import { useEffect } from 'react'; @@ -272,7 +277,7 @@ export const getInvocationNodeErrors = ( } for (const [fieldName, field] of Object.entries(node.data.inputs)) { - const fieldTemplate = nodeTemplate.inputs[fieldName]; + const fieldTemplate = getInvocationNodeInputTemplate(node.data, nodeTemplate, fieldName); if (!fieldTemplate) { errors.push({ type: 'node-error', nodeId, issue: 
t('parameters.invoke.missingFieldTemplate') }); @@ -307,7 +312,7 @@ const syncNodeErrors = (nodesState: NodesState, templates: Templates) => { } for (const [fieldName, field] of Object.entries(node.data.inputs)) { - const fieldTemplate = nodeTemplate.inputs[fieldName]; + const fieldTemplate = getInvocationNodeInputTemplate(node.data, nodeTemplate, fieldName); if (!fieldTemplate) { errors.push({ type: 'node-error', nodeId: node.id, issue: t('parameters.invoke.missingFieldTemplate') }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 2a466aae444..17b068ad0c5 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -9,7 +9,7 @@ import { import { validateConnection } from 'features/nodes/store/util/validateConnection'; import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; -import { isConnectorNode } from 'features/nodes/types/invocation'; +import { getInvocationNodeInputTemplate, isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation'; /** * @@ -118,6 +118,10 @@ export const getTargetCandidateFields = ( return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? 
[candidate] : []; } + if (!isInvocationNode(targetNode)) { + return []; + } + const targetTemplate = templates[targetNode.data.type]; if (!targetTemplate) { return []; @@ -136,10 +140,14 @@ export const getTargetCandidateFields = ( } } - const targetCandidateFields = map(targetTemplate.inputs).filter((field) => { + const targetCandidateFields = Object.entries(targetNode.data.inputs).flatMap(([fieldName, input]) => { + const field = getInvocationNodeInputTemplate(targetNode.data, targetTemplate, fieldName); + if (!field || field.name !== input.name) { + return []; + } const c = { source, sourceHandle, target, targetHandle: field.name }; const connectionErrorTKey = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); - return connectionErrorTKey === null; + return connectionErrorTKey === null ? [field] : []; }); return targetCandidateFields; @@ -162,8 +170,11 @@ export const getSourceCandidateFields = ( if (isConnectorNode(sourceNode)) { const sourceFieldType = resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates); - const targetTemplate = !isConnectorNode(targetNode) ? templates[targetNode.data.type] : null; - const targetFieldType = targetTemplate?.inputs[targetHandle]?.type; + const targetTemplate = isInvocationNode(targetNode) ? templates[targetNode.data.type] : null; + const targetFieldType = + isInvocationNode(targetNode) && targetTemplate + ? getInvocationNodeInputTemplate(targetNode.data, targetTemplate, targetHandle)?.type + : undefined; const candidateType = sourceFieldType ?? 
targetFieldType; if (!candidateType) { return []; @@ -188,13 +199,16 @@ export const getSourceCandidateFields = ( } if (!isConnectorNode(targetNode)) { + if (!isInvocationNode(targetNode)) { + return []; + } + const targetTemplate = templates[targetNode.data.type]; if (!targetTemplate) { return []; } - const targetField = targetTemplate.inputs[targetHandle]; - + const targetField = getInvocationNodeInputTemplate(targetNode.data, targetTemplate, targetHandle); if (!targetField) { return []; } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index 8706e199bbe..67c477408f3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -73,6 +73,102 @@ export const add: InvocationTemplate = { category: 'math', }; +export const call_saved_workflow: InvocationTemplate = { + title: 'Call Saved Workflow', + type: 'call_saved_workflow', + version: '1.0.0', + tags: ['workflow', 'saved', 'library'], + description: 'Displays and later executes against a selected saved workflow.', + category: 'workflow', + outputType: 'integer_output', + inputs: { + workflow_id: { + name: 'workflow_id', + title: 'Workflow Id', + required: false, + description: 'The selected saved workflow ID, managed by the workflow editor UI.', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + ui_type: 'SavedWorkflowField', + type: { + name: 'SavedWorkflowField', + cardinality: 'SINGLE', + batch: false, + originalType: { + name: 'StringField', + cardinality: 'SINGLE', + batch: false, + }, + }, + default: '', + }, + }, + outputs: { + value: { + fieldKind: 'output', + name: 'value', + title: 'Value', + description: 'The output integer', + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + batch: false, + }, + ui_hidden: false, + }, + }, + useCache: false, + nodePack: 'invokeai', + classification: 'beta', +}; + 
+export const workflow_return: InvocationTemplate = { + title: 'Workflow Return', + type: 'workflow_return', + version: '1.0.0', + tags: ['workflow', 'return', 'output'], + description: 'Defines the explicit collection result returned by a callable workflow.', + category: 'workflow', + outputType: 'workflow_return_output', + inputs: { + collection: { + name: 'collection', + title: 'Collection', + required: false, + description: 'The collection returned to a calling workflow.', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + ui_type: 'CollectionField', + type: { + name: 'CollectionField', + cardinality: 'COLLECTION', + batch: false, + }, + default: undefined, + }, + }, + outputs: { + collection: { + fieldKind: 'output', + name: 'collection', + title: 'Collection', + description: 'The workflow return collection', + type: { + name: 'CollectionField', + cardinality: 'COLLECTION', + batch: false, + }, + ui_hidden: false, + ui_type: 'CollectionField', + }, + }, + useCache: false, + nodePack: 'invokeai', + classification: 'beta', +}; + export const sub: InvocationTemplate = { title: 'Subtract Integers', type: 'sub', @@ -537,6 +633,8 @@ const iterate: InvocationTemplate = { export const templates: Templates = { add, + call_saved_workflow, + workflow_return, sub, collect, iterate, @@ -554,6 +652,146 @@ export const schema = { }, components: { schemas: { + CallSavedWorkflowInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: false, + field_kind: 'node_attribute', + }, + workflow_id: { + type: 'string', + title: 'Workflow Id', + description: 'The selected saved workflow ID, managed by the workflow editor UI.', + default: '', + field_kind: 'input', + input: 'any', + orig_default: '', + orig_required: false, + ui_hidden: false, + ui_type: 'SavedWorkflowField', + }, + type: { + type: 'string', + enum: ['call_saved_workflow'], + const: 'call_saved_workflow', + title: 'type', + default: 'call_saved_workflow', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Call Saved Workflow', + description: 'Displays and later executes against a selected saved workflow.', + category: 'workflow', + classification: 'beta', + node_pack: 'invokeai', + tags: ['workflow', 'saved', 'library'], + version: '1.0.0', + output: { + $ref: '#/components/schemas/IntegerOutput', + }, + class: 'invocation', + }, + WorkflowReturnInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: false, + field_kind: 'node_attribute', + }, + collection: { + type: 'array', + items: {}, + title: 'Collection', + description: 'The collection returned to a calling workflow.', + field_kind: 'input', + input: 'connection', + orig_required: false, + ui_hidden: false, + ui_type: 'CollectionField', + }, + type: { + type: 'string', + enum: ['workflow_return'], + const: 'workflow_return', + title: 'type', + default: 'workflow_return', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Workflow Return', + description: 'Defines the explicit collection result returned by a callable workflow.', + category: 'workflow', + classification: 'beta', + node_pack: 'invokeai', + tags: ['workflow', 'return', 'output'], + version: '1.0.0', + output: { + $ref: '#/components/schemas/WorkflowReturnOutput', + }, + class: 'invocation', + }, + WorkflowReturnOutput: { + properties: { + type: { + type: 'string', + enum: ['workflow_return_output'], + const: 'workflow_return_output', + title: 'type', + default: 'workflow_return_output', + field_kind: 'node_attribute', + }, + collection: { + type: 'array', + items: {}, + title: 'Collection', + description: 'The workflow return collection', + field_kind: 'output', + ui_hidden: false, + ui_type: 'CollectionField', + }, + }, + type: 'object', + required: ['type', 'collection'], + title: 'Workflow Return Output', + class: 'output', + }, AddInvocation: { properties: { id: { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts 
b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 730dced1d35..1eef0794436 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -1,5 +1,7 @@ import { deepClone } from 'common/util/deepClone'; import { set } from 'es-toolkit/compat'; +import { callSavedWorkflowDynamicFieldsChanged, nodesSliceConfig } from 'features/nodes/store/nodesSlice'; +import type { IntegerFieldInputTemplate } from 'features/nodes/types/field'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { describe, expect, it } from 'vitest'; @@ -8,7 +10,17 @@ import { CONNECTOR_OUTPUT_HANDLE, getConnectorDeletionSpliceConnections, } from './connectorTopology'; -import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils'; +import { + add, + buildEdge, + buildNode, + call_saved_workflow, + collect, + img_resize, + main_model_loader, + sub, + templates, +} from './testUtils'; import { validateConnection } from './validateConnection'; const ifTemplate: InvocationTemplate = { @@ -147,6 +159,66 @@ const integerCollectionOutputTemplate: InvocationTemplate = { classification: 'stable', }; +const workflowReturnValueTemplate: InvocationTemplate = { + title: 'Workflow Return Value', + type: 'workflow_return_value', + version: '1.0.0', + tags: ['workflow', 'return', 'output'], + category: 'workflow', + description: 'Creates one named value for a callable workflow return.', + outputType: 'workflow_return_value_output', + inputs: {}, + outputs: { + value: { + fieldKind: 'output', + name: 'value', + title: 'Return Value', + description: 'The named workflow return value.', + type: { + name: 'CollectionItemField', + cardinality: 'SINGLE', + batch: false, + }, + ui_hidden: false, + ui_type: 'CollectionItemField', + }, + }, + useCache: false, + nodePack: 'invokeai', + 
classification: 'beta', +}; + +const workflowReturnTemplate: InvocationTemplate = { + title: 'Workflow Return', + type: 'workflow_return', + version: '1.0.0', + tags: ['workflow', 'return', 'output'], + category: 'workflow', + description: 'Defines the explicit named result returned by a callable workflow.', + outputType: 'workflow_return_output', + inputs: { + values: { + name: 'values', + title: 'Values', + required: false, + description: 'The named values returned to a calling workflow.', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + type: { + name: 'WorkflowReturnValueField', + cardinality: 'SINGLE_OR_COLLECTION', + batch: false, + }, + default: undefined, + }, + }, + outputs: {}, + useCache: false, + nodePack: 'invokeai', + classification: 'beta', +}; + const buildConnectorNode = (id: string) => ({ id, type: 'connector' as const, @@ -217,6 +289,64 @@ describe(validateConnection.name, () => { }); }); + it('accepts connections to dynamic saved workflow input fields', () => { + const addIntegerInputTemplate = add.inputs.a as IntegerFieldInputTemplate; + const state = nodesSliceConfig.getInitialState(); + const sourceNode = buildNode(add); + const targetNode = buildNode(call_saved_workflow); + state.nodes.push(sourceNode, targetNode); + + const nextState = nodesSliceConfig.slice.reducer( + state, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: targetNode.id, + fields: [ + { + fieldName: 'saved_workflow_input::node-1::a', + fieldTemplate: { + ...addIntegerInputTemplate, + name: 'saved_workflow_input::node-1::a', + title: 'Left Addend', + input: 'any', + }, + label: 'Left Addend', + description: 'The first number', + initialValue: 23, + }, + ], + edgeIdsToRemove: [], + }) + ); + + const c = { + source: sourceNode.id, + sourceHandle: 'value', + target: targetNode.id, + targetHandle: 'saved_workflow_input::node-1::a', + }; + const r = validateConnection(c, nextState.nodes, [], templates, null); + expect(r).toEqual(null); + }); + + it('accepts 
a single workflow return value connected directly to workflow_return values', () => { + const sourceNode = buildNode(workflowReturnValueTemplate); + const targetNode = buildNode(workflowReturnTemplate); + const c = { source: sourceNode.id, sourceHandle: 'value', target: targetNode.id, targetHandle: 'values' }; + + const r = validateConnection( + c, + [sourceNode, targetNode], + [], + { + workflow_return_value: workflowReturnValueTemplate, + workflow_return: workflowReturnTemplate, + }, + null + ); + + expect(r).toEqual(null); + }); + describe('duplicate connections', () => { const n1 = buildNode(add); const n2 = buildNode(sub); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index bb98d472d31..710e49de7cf 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -11,7 +11,7 @@ import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { FieldType } from 'features/nodes/types/field'; import type { AnyEdge, AnyNode, InvocationNode } from 'features/nodes/types/invocation'; -import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation'; +import { getInvocationNodeInputTemplate, isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation'; import type { SetNonNullable } from 'type-fest'; type Connection = SetNonNullable; @@ -291,7 +291,7 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.missingInvocationTemplate'; } - const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; + const targetFieldTemplate = getInvocationNodeInputTemplate(targetNode.data, targetTemplate, c.targetHandle); if (!targetFieldTemplate) { return 'nodes.missingFieldTemplate'; } diff --git 
a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index ffd87ae3984..df2f909dc7c 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -188,6 +188,10 @@ const zSchedulerFieldType = zFieldTypeBase.extend({ name: z.literal('SchedulerField'), originalType: zStatelessFieldType.optional(), }); +const zSavedWorkflowFieldType = zFieldTypeBase.extend({ + name: z.literal('SavedWorkflowField'), + originalType: zStatelessFieldType.optional(), +}); const zFloatGeneratorFieldType = zFieldTypeBase.extend({ name: z.literal('FloatGeneratorField'), originalType: zStatelessFieldType.optional(), }); @@ -216,6 +220,7 @@ const zStatefulFieldType = z.union([ zModelIdentifierFieldType, zColorFieldType, zSchedulerFieldType, + zSavedWorkflowFieldType, zFloatGeneratorFieldType, zIntegerGeneratorFieldType, zStringGeneratorFieldType, @@ -697,6 +702,28 @@ export const isSchedulerFieldInputInstance = buildInstanceTypeGuard(zSchedulerFi export const isSchedulerFieldInputTemplate = buildTemplateTypeGuard<SchedulerFieldInputTemplate>('SchedulerField'); // #endregion +// #region SavedWorkflowField +const zSavedWorkflowFieldValue = z.string(); +const zSavedWorkflowFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zSavedWorkflowFieldValue, +}); +const zSavedWorkflowFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSavedWorkflowFieldType, + originalType: zFieldType.optional(), + default: zSavedWorkflowFieldValue, + maxLength: z.number().int().gte(0).optional(), + minLength: z.number().int().gte(0).optional(), +}); +const zSavedWorkflowFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSavedWorkflowFieldType, +}); +export type SavedWorkflowFieldInputInstance = z.infer<typeof zSavedWorkflowFieldInputInstance>; +export type SavedWorkflowFieldInputTemplate = z.infer<typeof zSavedWorkflowFieldInputTemplate>; +export const isSavedWorkflowFieldInputInstance = buildInstanceTypeGuard(zSavedWorkflowFieldInputInstance); +export const isSavedWorkflowFieldInputTemplate = + buildTemplateTypeGuard<SavedWorkflowFieldInputTemplate>('SavedWorkflowField'); +// #endregion + // #region FloatGeneratorField export const FloatGeneratorArithmeticSequenceType = 'float_generator_arithmetic_sequence'; const zFloatGeneratorArithmeticSequence = z.object({ @@ -1290,6 +1317,7 @@ export const zStatefulFieldValue = z.union([ zModelIdentifierFieldValue, zColorFieldValue, zSchedulerFieldValue, + zSavedWorkflowFieldValue, zFloatGeneratorFieldValue, zIntegerGeneratorFieldValue, zStringGeneratorFieldValue, @@ -1318,6 +1346,7 @@ const zStatefulFieldInputInstance = z.union([ zModelIdentifierFieldInputInstance, zColorFieldInputInstance, zSchedulerFieldInputInstance, + zSavedWorkflowFieldInputInstance, zFloatGeneratorFieldInputInstance, zIntegerGeneratorFieldInputInstance, zStringGeneratorFieldInputInstance, @@ -1345,7 +1374,7 @@ const zStatefulFieldInputTemplate = z.union([ zModelIdentifierFieldInputTemplate, zColorFieldInputTemplate, zSchedulerFieldInputTemplate, - zStatelessFieldInputTemplate, + zSavedWorkflowFieldInputTemplate, zFloatGeneratorFieldInputTemplate, zIntegerGeneratorFieldInputTemplate, zStringGeneratorFieldInputTemplate, @@ -1373,6 +1402,7 @@ const zStatefulFieldOutputTemplate = z.union([ zModelIdentifierFieldOutputTemplate, zColorFieldOutputTemplate, zSchedulerFieldOutputTemplate, + zSavedWorkflowFieldOutputTemplate, zFloatGeneratorFieldOutputTemplate, zIntegerGeneratorFieldOutputTemplate, zStringGeneratorFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 5d8d85dd87f..5ceb30f7794 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -32,6 +32,7 @@ export const zInvocationNodeData = z.object({ notes: z.string(), type: z.string().trim().min(1), inputs: z.record(z.string(), zFieldInputInstance), + dynamicInputTemplates: z.record(z.string(),
zFieldInputTemplate).default({}), isOpen: z.boolean(), isIntermediate: z.boolean(), useCache: z.boolean(), @@ -162,3 +163,28 @@ const isGeneratorNode = (node: InvocationNode) => isGeneratorNodeType(node.data. export const isExecutableNode = (node: InvocationNode) => { return !isBatchNode(node) && !isGeneratorNode(node); }; + +export const getInvocationNodeInputTemplate = ( + nodeData: Pick & Partial>, + template: InvocationTemplate, + fieldName: string +) => { + return nodeData.dynamicInputTemplates?.[fieldName] ?? template.inputs[fieldName]; +}; + +export const getInvocationNodeTemplateWithDynamicInputs = ( + nodeData: Pick & Partial>, + template: InvocationTemplate +): InvocationTemplate => { + if (!nodeData.dynamicInputTemplates || Object.keys(nodeData.dynamicInputTemplates).length === 0) { + return template; + } + + return { + ...template, + inputs: { + ...template.inputs, + ...nodeData.dynamicInputTemplates, + }, + }; +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index 34f98eb289f..bfb7b92b18f 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -369,24 +369,38 @@ const zValidatedBuilderForm = zBuilderForm //# endregion // #region Workflow -export const zWorkflowV3 = z.object({ - id: z.string().min(1).optional(), - name: z.string(), - author: z.string(), - description: z.string(), - version: z.string(), - contact: z.string(), - tags: z.string(), - notes: z.string(), - nodes: z.array(zWorkflowNode), - edges: z.array(zWorkflowEdge), - exposedFields: z.array(zFieldIdentifier), - meta: z.object({ - category: zWorkflowCategory.default('user'), - version: z.literal('4.0.0'), - }), - // Use the validated form schema! 
- form: zValidatedBuilderForm, -}); +export const zWorkflowV3 = z + .object({ + id: z.string().min(1).optional(), + name: z.string(), + author: z.string(), + description: z.string(), + version: z.string(), + contact: z.string(), + tags: z.string(), + notes: z.string(), + nodes: z.array(zWorkflowNode), + edges: z.array(zWorkflowEdge), + exposedFields: z.array(zFieldIdentifier), + meta: z.object({ + category: zWorkflowCategory.default('user'), + version: z.literal('4.0.0'), + }), + // Use the validated form schema! + form: zValidatedBuilderForm, + }) + .superRefine((workflow, ctx) => { + const workflowReturnCount = workflow.nodes.filter( + (node) => node.type === 'invocation' && node.data.type === 'workflow_return' + ).length; + + if (workflowReturnCount > 1) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: 'A workflow may not contain more than one workflow_return node.', + path: ['nodes'], + }); + } + }); export type WorkflowV3 = z.infer<typeof zWorkflowV3>; // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts index b44c6b38cd8..046487e40e7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts @@ -1,10 +1,28 @@ import { deepClone } from 'common/util/deepClone'; +import { callSavedWorkflowDynamicFieldsChanged, nodesSliceConfig } from 'features/nodes/store/nodesSlice'; import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils'; +import type { IntegerFieldInputTemplate } from 'features/nodes/types/field'; import { describe, expect, it } from 'vitest'; import { buildNodesGraph } from './buildNodesGraph'; +const callSavedWorkflowTemplate = templates.call_saved_workflow; +const
addTemplate = templates.add; + +if (!callSavedWorkflowTemplate || !addTemplate || !addTemplate.inputs.a) { + throw new Error('Expected saved workflow and add templates'); +} + +const addIntegerInputTemplate = addTemplate.inputs.a as IntegerFieldInputTemplate; + +const buildDynamicIntegerTemplate = (fieldName: string): IntegerFieldInputTemplate => ({ + ...addIntegerInputTemplate, + name: fieldName, + title: 'Left Addend', + input: 'any', +}); + const buildConnectorNode = (id: string) => ({ id, type: 'connector' as const, @@ -54,6 +72,139 @@ const buildState = (nodes: unknown[], edges: unknown[]) => }) as unknown as Parameters<typeof buildNodesGraph>[0]; describe('buildNodesGraph', () => { + it('serializes dynamic saved workflow inputs into workflow_inputs', () => { + const state = nodesSliceConfig.getInitialState(); + const node = buildNode(callSavedWorkflowTemplate); + state.nodes.push(node); + + const nextState = nodesSliceConfig.slice.reducer( + state, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: node.id, + fields: [ + { + fieldName: 'saved_workflow_input::node-1::a', + fieldTemplate: buildDynamicIntegerTemplate('saved_workflow_input::node-1::a'), + label: 'Left Addend', + description: 'The first number', + initialValue: 23, + }, + ], + edgeIdsToRemove: [], + }) + ); + + const rootState = { + nodes: { + past: [], + future: [], + present: nextState, + }, + gallery: { + autoAddBoardId: 'none', + }, + } as never; + + const graph = buildNodesGraph(rootState, templates); + + expect(graph.nodes[node.id]).toMatchObject({ + workflow_id: '', + workflow_inputs: { + ['saved_workflow_input::node-1::a']: 23, + }, + }); + }); + + it('omits connected dynamic saved workflow literal values from workflow_inputs while preserving the edge', () => { + const state = nodesSliceConfig.getInitialState(); + const sourceNode = buildNode(add); + const callNode = buildNode(callSavedWorkflowTemplate); + state.nodes.push(sourceNode, callNode); + + const nextState = deepClone( + nodesSliceConfig.slice.reducer(
+ state, + callSavedWorkflowDynamicFieldsChanged({ + nodeId: callNode.id, + fields: [ + { + fieldName: 'saved_workflow_input::node-1::a', + fieldTemplate: buildDynamicIntegerTemplate('saved_workflow_input::node-1::a'), + label: 'Left Addend', + description: 'The first number', + initialValue: 23, + }, + ], + edgeIdsToRemove: [], + }) + ) + ); + + nextState.edges.push(buildEdge(sourceNode.id, 'value', callNode.id, 'saved_workflow_input::node-1::a')); + + const rootState = { + nodes: { + past: [], + future: [], + present: nextState, + }, + gallery: { + autoAddBoardId: 'none', + }, + } as never; + + const graph = buildNodesGraph(rootState, templates); + + expect(graph.nodes[callNode.id]).toMatchObject({ + workflow_id: '', + workflow_inputs: {}, + }); + expect(graph.edges).toContainEqual({ + source: { node_id: sourceNode.id, field: 'value' }, + destination: { node_id: callNode.id, field: 'saved_workflow_input::node-1::a' }, + }); + }); + + it('does not serialize stale hidden saved workflow input values without matching dynamic fields', () => { + const state = nodesSliceConfig.getInitialState(); + const node = buildNode(callSavedWorkflowTemplate); + node.data.inputs.workflow_inputs = { + name: 'workflow_inputs', + type: 'workflow_inputs', + value: { + ['saved_workflow_input::old-node::a']: 23, + }, + } as never; + state.nodes.push(node); + const templatesWithWorkflowInputs = { + ...templates, + call_saved_workflow: { + ...callSavedWorkflowTemplate, + inputs: { + ...callSavedWorkflowTemplate.inputs, + workflow_inputs: buildDynamicIntegerTemplate('workflow_inputs'), + }, + }, + }; + + const rootState = { + nodes: { + past: [], + future: [], + present: state, + }, + gallery: { + autoAddBoardId: 'none', + }, + } as never; + + const graph = buildNodesGraph(rootState, templatesWithWorkflowInputs); + const graphNode = graph.nodes[node.id] as { workflow_id: string; workflow_inputs: Record<string, unknown> }; + + expect(graphNode.workflow_id).toBe(''); +
expect(graphNode.workflow_inputs).toEqual({}); + }); + it('flattens a single connector to one direct execution edge', () => { const source = buildNode(add); const target = buildNode(sub); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts index 50052c806ce..eb2ce06d152 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts @@ -8,11 +8,17 @@ import { resolveConnectorSource } from 'features/nodes/store/util/connectorTopol import type { BoardField } from 'features/nodes/types/common'; import type { BoardFieldInputInstance } from 'features/nodes/types/field'; import { isBoardFieldInputInstance, isBoardFieldInputTemplate } from 'features/nodes/types/field'; -import { isConnectorNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation'; +import { + getInvocationNodeInputTemplate, + isConnectorNode, + isExecutableNode, + isInvocationNode, +} from 'features/nodes/types/invocation'; import type { AnyInvocation, Graph } from 'services/api/types'; import { v4 as uuidv4 } from 'uuid'; const log = logger('workflows'); +const CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX = 'saved_workflow_input::'; const getBoardField = (field: BoardFieldInputInstance, state: RootState): BoardField | undefined => { // Translate the UI value to the graph value. See note in BoardFieldInputComponent for more info. 
@@ -59,7 +65,20 @@ export const buildNodesGraph = (state: RootState, templates: Templates): Require const transformedInputs = reduce( inputs, (inputsAccumulator, input, name) => { - const fieldTemplate = nodeTemplate.inputs[name]; + if (type === 'call_saved_workflow' && name === 'workflow_inputs') { + return inputsAccumulator; + } + + if (type === 'call_saved_workflow' && name.startsWith(CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX)) { + const workflowInputs = { + ...((inputsAccumulator['workflow_inputs'] as Record<string, unknown> | undefined) ?? {}), + }; + workflowInputs[name] = input.value; + inputsAccumulator['workflow_inputs'] = workflowInputs; + return inputsAccumulator; + } + + const fieldTemplate = getInvocationNodeInputTemplate(data, nodeTemplate, name); if (!fieldTemplate) { log.warn({ id, name }, 'Field template not found!'); return inputsAccumulator; @@ -75,6 +94,10 @@ {} as Record<Exclude<string, 'id'>, unknown> ); + if (type === 'call_saved_workflow' && transformedInputs['workflow_inputs'] === undefined) { + transformedInputs['workflow_inputs'] = {}; + } + // add reserved use_cache transformedInputs['use_cache'] = node.data.useCache; @@ -181,7 +204,20 @@ */ parsedEdges.forEach((edge) => { const destination_node = parsedNodes[edge.destination.node_id]; + if (!destination_node) { + return; + } const field = edge.destination.field; + const destinationNodeRecord = destination_node as Record<string, unknown>; + if ( + destination_node.type === 'call_saved_workflow' && + field.startsWith(CALL_SAVED_WORKFLOW_DYNAMIC_FIELD_PREFIX) && + typeof destinationNodeRecord['workflow_inputs'] === 'object' && + destinationNodeRecord['workflow_inputs'] !== null + ) { + delete (destinationNodeRecord['workflow_inputs'] as Record<string, unknown>)[field]; + return; + } parsedNodes[edge.destination.node_id] = omit(destination_node, field) as AnyInvocation; }); diff --git
a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts index 859e978b0fa..2af152e6bb8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts @@ -40,6 +40,7 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe useCache: template.useCache, nodePack: template.nodePack, inputs, + dynamicInputTemplates: {}, }, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index ef7b92efdd8..6bb3155f476 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -11,6 +11,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = IntegerField: 0, ModelIdentifierField: undefined, SchedulerField: 'dpmpp_3m_k', + SavedWorkflowField: '', StringField: '', StylePresetField: undefined, FloatGeneratorField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index adaa3f413ce..ca31b988ff7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -17,6 +17,7 @@ import type { IntegerFieldInputTemplate, IntegerGeneratorFieldInputTemplate, ModelIdentifierFieldInputTemplate, + SavedWorkflowFieldInputTemplate, SchedulerFieldInputTemplate, StatefulFieldType, StatelessFieldInputTemplate, @@ -408,6 +409,28 @@ const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder<SchedulerFieldInputTemplate> = ({ +const buildSavedWorkflowFieldInputTemplate: FieldInputTemplateBuilder<SavedWorkflowFieldInputTemplate> = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: SavedWorkflowFieldInputTemplate =
{ + ...baseField, + type: fieldType, + default: schemaObject.default ?? '', + }; + + if (schemaObject.minLength !== undefined) { + template.minLength = schemaObject.minLength; + } + + if (schemaObject.maxLength !== undefined) { + template.maxLength = schemaObject.maxLength; + } + + return template; +}; + const buildFloatGeneratorFieldInputTemplate: FieldInputTemplateBuilder = ({ // schemaObject, baseField, @@ -474,6 +497,7 @@ const TEMPLATE_BUILDER_MAP: Record { const parsed = parseSchema(schema, ['add']); expect(stripUndefinedDeep(parsed)).toEqual(stripUndefinedDeep(pick(templates, 'add'))); }); + it('should parse the call_saved_workflow node template', () => { + const parsed = parseSchema(schema); + expect(stripUndefinedDeep(parsed.call_saved_workflow)).toEqual(stripUndefinedDeep(call_saved_workflow)); + const template = parsed.call_saved_workflow; + if (!template) { + throw new Error('Expected call_saved_workflow template'); + } + const workflowIdInput = template.inputs.workflow_id; + if (!workflowIdInput) { + throw new Error('Expected workflow_id input'); + } + expect(workflowIdInput.type.name).toBe('SavedWorkflowField'); + expect(workflowIdInput.ui_type).toBe('SavedWorkflowField'); + }); + it('should parse the workflow_return node template', () => { + const parsed = parseSchema(schema); + expect(stripUndefinedDeep(parsed.workflow_return)).toEqual(stripUndefinedDeep(workflow_return)); + const template = parsed.workflow_return; + if (!template) { + throw new Error('Expected workflow_return template'); + } + const collectionInput = template.inputs.collection; + if (!collectionInput) { + throw new Error('Expected collection input'); + } + expect(collectionInput.type.name).toBe('CollectionField'); + expect(collectionInput.ui_type).toBe('CollectionField'); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.test.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.test.ts new file mode 100644 index 
00000000000..c7bdd80acea --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.test.ts @@ -0,0 +1,43 @@ +import { getInitialWorkflow } from 'features/nodes/store/nodesSlice'; +import { buildNode, call_saved_workflow } from 'features/nodes/store/util/testUtils'; +import { describe, expect, it } from 'vitest'; + +describe('buildWorkflowFast', () => { + it('persists the selected workflow id for call_saved_workflow nodes', async () => { + Object.assign(globalThis, { + window: { + location: { + origin: 'http://localhost', + }, + }, + }); + + const { buildWorkflowFast } = await import('features/nodes/util/workflow/buildWorkflow'); + const node = buildNode(call_saved_workflow); + const workflowIdInput = node.data.inputs.workflow_id; + if (!workflowIdInput) { + throw new Error('Expected workflow_id input'); + } + workflowIdInput.value = 'workflow-123'; + + const workflow = buildWorkflowFast({ + _version: 1, + formFieldInitialValues: {}, + ...getInitialWorkflow(), + nodes: [node], + edges: [], + }); + + expect(workflow.nodes).toHaveLength(1); + expect(workflow.nodes[0]?.type).toBe('invocation'); + if (workflow.nodes[0]?.type !== 'invocation') { + throw new Error('Expected invocation node'); + } + expect(workflow.nodes[0].data.type).toBe('call_saved_workflow'); + const serializedWorkflowIdInput = workflow.nodes[0].data.inputs.workflow_id; + if (!serializedWorkflowIdInput) { + throw new Error('Expected serialized workflow_id input'); + } + expect(serializedWorkflowIdInput.value).toBe('workflow-123'); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts index c09f4e1729d..28d07e9e170 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts @@ -84,6 +84,7 @@ export const graphToWorkflow = (graph: NonNullableGraph, 
autoLayout = true): Wor version: template.version, label: '', notes: '', + dynamicInputTemplates: {}, isOpen: true, isIntermediate: node.is_intermediate ?? false, useCache: node.use_cache ?? true, diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts index c1d08588315..28bc5915fcf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts @@ -1,10 +1,21 @@ import { get } from 'es-toolkit/compat'; import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; -import { img_resize, main_model_loader } from 'features/nodes/store/util/testUtils'; +import { img_resize, main_model_loader, workflow_return } from 'features/nodes/store/util/testUtils'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { getDefaultForm } from 'features/nodes/types/workflow'; +import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow'; -import { describe, expect, it } from 'vitest'; +import { describe, expect, it, vi } from 'vitest'; + +vi.mock('app/logging/logger', () => ({ + logger: () => ({ + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }), +})); //TODO(psyche): Test workflow validation for form builder fields describe('validateWorkflow', () => { @@ -40,6 +51,7 @@ describe('validateWorkflow', () => { version: '1.0.2', label: '', notes: '', + dynamicInputTemplates: {}, isOpen: true, isIntermediate: true, useCache: true, @@ -70,6 +82,7 @@ describe('validateWorkflow', () => { version: '1.2.2', label: '', notes: '', + dynamicInputTemplates: {}, isOpen: true, isIntermediate: true, useCache: true, @@ -140,6 +153,46 @@ 
describe('validateWorkflow', () => { expect(validationResult.warnings.length).toBe(1); expect(get(validationResult, 'workflow.nodes[0].data.inputs.model.value')).toBeUndefined(); }); + it('should reject workflows with duplicate workflow_return nodes at build time', async () => { + Object.assign(globalThis, { + window: { + location: { + origin: 'http://localhost', + }, + }, + }); + + const { buildWorkflowWithValidation } = await import('features/nodes/util/workflow/buildWorkflow'); + const returnNode1 = buildInvocationNode({ x: 0, y: 0 }, workflow_return); + const returnNode2 = buildInvocationNode({ x: 100, y: 0 }, workflow_return); + + const built = buildWorkflowWithValidation({ + _version: 1, + formFieldInitialValues: {}, + ...getWorkflow(), + nodes: [returnNode1, returnNode2], + edges: [], + }); + + expect(built).toBeNull(); + }); + it('should warn when loading a workflow with duplicate workflow_return nodes', async () => { + const returnNode1 = buildInvocationNode({ x: 0, y: 0 }, workflow_return); + const returnNode2 = buildInvocationNode({ x: 100, y: 0 }, workflow_return); + + await expect( + validateWorkflow({ + workflow: { + ...getWorkflow(), + nodes: [returnNode1, returnNode2], + }, + templates: { img_resize, main_model_loader, workflow_return }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }) + ).rejects.toThrow(/workflow_return/i); + }); it('should delete malformed connector edges with invalid handles', async () => { const workflow = getWorkflow(); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index 448214defe3..5e69ac6da49 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -9,6 +9,7 @@ import { isModelFieldType, isModelIdentifierFieldInputInstance, } from 
'features/nodes/types/field'; +import { getInvocationNodeInputTemplate } from 'features/nodes/types/invocation'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { buildNodeFieldElement, @@ -97,7 +98,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise id === nodeId); - if (!node) { + if (!node || !isWorkflowInvocationNode(node)) { continue; } const nodeTemplate = templates[node.data.type]; if (!nodeTemplate) { continue; } - const fieldTemplate = nodeTemplate.inputs[fieldName]; + const fieldTemplate = getInvocationNodeInputTemplate(node.data, nodeTemplate, fieldName); if (!fieldTemplate) { continue; } diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx index e1c5f4ec973..8add8c00145 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx @@ -15,6 +15,7 @@ import { PiArrowCounterClockwiseBold, PiXBold } from 'react-icons/pi'; import type { S } from 'services/api/types'; import { COLUMN_WIDTHS, SYSTEM_USER_ID } from './constants'; +import { getQueueItemActionVisibility } from './getQueueItemActionVisibility'; import QueueItemDetail from './QueueItemDetail'; const selectedStyles = { bg: 'base.700' }; @@ -97,6 +98,7 @@ const QueueItemComponent = ({ index, item }: InnerItemProps) => { const isCanceled = useMemo(() => ['canceled', 'completed', 'failed'].includes(item.status), [item.status]); const isFailed = useMemo(() => ['canceled', 'failed'].includes(item.status), [item.status]); + const { canShowCancelQueueItem, canShowRetryQueueItem } = useMemo(() => getQueueItemActionVisibility(item), [item]); const originText = useOriginText(item.origin); const destinationText = useDestinationText(item.destination); @@ -183,7 +185,7 @@ const QueueItemComponent = ({ 
index, item }: InnerItemProps) => { - {!isFailed && ( + {canShowCancelQueueItem && !isFailed && ( { icon={} /> )} - {isFailed && ( + {canShowRetryQueueItem && isFailed && ( { ); const isFailed = useMemo(() => !!queueItem && ['canceled', 'failed'].includes(queueItem.status), [queueItem]); + const { canShowCancelQueueItem, canShowRetryQueueItem } = useMemo( + () => getQueueItemActionVisibility(queueItemDTO), + [queueItemDTO] + ); const onCancelBatch = useCallback(() => { cancelBatch.trigger(batch_id); @@ -82,7 +88,7 @@ const QueueItemComponent = ({ queueItem: queueItemDTO }: Props) => { - {!isFailed && ( + {canShowCancelQueueItem && !isFailed && ( )} - {isFailed && ( + {canShowRetryQueueItem && isFailed && (