From 8e34a1d8ec05f61a1287e1a6f385e7effff5f5ee Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 25 Apr 2026 15:00:59 -0400 Subject: [PATCH 01/12] fix(multiuser): redact other users' current-item identifiers from queue status events QueueItemStatusChangedEvent embeds the SessionQueueStatus, which includes the currently-running item's item_id, session_id, and batch_id. The event ships to user:{owner} and admin rooms. When user A's item changed status while user B's item was the one in progress, owner A's frontend received the event with B's identifiers exposed. In _set_queue_item_status, scrub item_id/session_id/batch_id from the embedded queue_status when the in-progress item belongs to a different user than the changed item. Aggregate counts remain global (not user-sensitive). Identified out-of-scope in the security audit of #127. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../session_queue/session_queue_sqlite.py | 12 ++ ...st_session_queue_status_event_isolation.py | 148 ++++++++++++++++++ 2 files changed, 160 insertions(+) create mode 100644 tests/app/services/session_queue/test_session_queue_status_event_isolation.py diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 95fb16fcbed..0a7029d2b0f 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -317,6 +317,18 @@ def _set_queue_item_status( queue_item = self.get_queue_item(item_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id) + + # get_queue_status embeds the currently-running item's identifiers (item_id, session_id, + # batch_id) into the SessionQueueStatus. 
The QueueItemStatusChangedEvent ships to + # user:{queue_item.user_id} and admin rooms; without this scrub, owner A would learn + # user B's identifiers whenever A's item changed status while B's item was the one in + # progress. Aggregate counts remain global (not user-sensitive). + current_item = self.get_current(queue_id=queue_item.queue_id) + if current_item is not None and current_item.user_id != queue_item.user_id: + queue_status = queue_status.model_copy( + update={"item_id": None, "session_id": None, "batch_id": None} + ) + self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status) return queue_item diff --git a/tests/app/services/session_queue/test_session_queue_status_event_isolation.py b/tests/app/services/session_queue/test_session_queue_status_event_isolation.py new file mode 100644 index 00000000000..ee8b976b46a --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_status_event_isolation.py @@ -0,0 +1,148 @@ +"""Regression tests for the cross-user identifier leak in QueueItemStatusChangedEvent. + +When user A's queue item changes status while user B's item is currently in_progress, +the embedded SessionQueueStatus inside the event must NOT expose B's item_id, +session_id, or batch_id. The full event ships to user:{A.user_id} and admin rooms, +so unredacted fields would let owner A learn user B's identifiers. 
+""" + +import uuid + +import pytest + +from invokeai.app.services.events.events_common import QueueItemStatusChangedEvent +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState +from tests.test_nodes import PromptTestInvocation, TestEventService + + +@pytest.fixture +def session_queue(mock_invoker: Invoker) -> SqliteSessionQueue: + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert_queue_item(session_queue: SqliteSessionQueue, user_id: str) -> int: + graph = Graph() + graph.add_node(PromptTestInvocation(id="prompt", prompt="test")) + session = GraphExecutionState(graph=graph) + session_json = session.model_dump_json(warnings=False, exclude_none=True) + batch_id = str(uuid.uuid4()) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue ( + queue_id, session, session_id, batch_id, field_values, + priority, workflow, origin, destination, retried_from_item_id, user_id + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ("default", session_json, session.id, batch_id, None, 0, None, None, None, None, user_id), + ) + return cursor.lastrowid # type: ignore[return-value] + + +def _last_status_event_for_item(event_bus: TestEventService, item_id: int) -> QueueItemStatusChangedEvent: + matches = [ + e for e in event_bus.events if isinstance(e, QueueItemStatusChangedEvent) and e.item_id == item_id + ] + assert matches, f"No QueueItemStatusChangedEvent found for item {item_id}" + return matches[-1] + + +def test_event_redacts_other_users_current_item_identifiers( + session_queue: SqliteSessionQueue, mock_invoker: Invoker +) -> None: + """When user A's pending item is canceled while user B's item is in_progress, the + embedded queue_status in A's status-changed event must not expose B's identifiers.""" + user_a = "user-a" + user_b = "user-b" + + a_item_id = _insert_queue_item(session_queue, user_id=user_a) + b_item_id = _insert_queue_item(session_queue, user_id=user_b) + + # Make user B's item the in-progress one. We must dequeue B first; FIFO would dequeue A + # because it was inserted first, so reverse the insertion: cancel A's, re-insert as new. + # Simpler: dequeue twice. First dequeue picks A (older); promote B by inserting in + # right order means we need B to be the in_progress item when A's event fires. + # Cancel A first to make it ineligible, then dequeue B. + # Actually we need A to be pending when its status changes — so we must dequeue B first. + # Re-do: insert B BEFORE A by temporarily inserting A second. 
Recreate cleanly: + session_queue.delete_queue_item(a_item_id) + session_queue.delete_queue_item(b_item_id) + b_item_id = _insert_queue_item(session_queue, user_id=user_b) + a_item_id = _insert_queue_item(session_queue, user_id=user_a) + + in_progress = session_queue.dequeue() + assert in_progress is not None and in_progress.item_id == b_item_id + assert in_progress.user_id == user_b + + event_bus: TestEventService = mock_invoker.services.events + event_bus.events.clear() + + # Now cancel user A's pending item. The emitted event for A must not leak B's + # current-item identifiers via the embedded queue_status. + canceled = session_queue.cancel_queue_item(a_item_id) + assert canceled.user_id == user_a + + a_event = _last_status_event_for_item(event_bus, a_item_id) + assert a_event.user_id == user_a + assert a_event.queue_status.item_id is None, "must not leak other user's current item_id" + assert a_event.queue_status.session_id is None, "must not leak other user's current session_id" + assert a_event.queue_status.batch_id is None, "must not leak other user's current batch_id" + # Aggregate counts in the embedded status are global and OK to share. 
+    assert a_event.queue_status.in_progress == 1
+    assert a_event.queue_status.canceled == 1
+
+
+def test_event_preserves_owner_current_item_identifiers(
+    session_queue: SqliteSessionQueue, mock_invoker: Invoker
+) -> None:
+    """Completing the owner's own in-progress item clears the queue's current item, so the
+    embedded queue_status carries no identifiers at all — None means absence, not redaction."""
+    user_a = "user-a"
+
+    a_item_id = _insert_queue_item(session_queue, user_id=user_a)
+
+    in_progress = session_queue.dequeue()
+    assert in_progress is not None and in_progress.item_id == a_item_id
+
+    event_bus: TestEventService = mock_invoker.services.events
+    event_bus.events.clear()
+
+    completed = session_queue.complete_queue_item(a_item_id)
+    assert completed.user_id == user_a
+
+    # The event for A's transition fires AFTER the row is marked completed, so by the time
+    # _set_queue_item_status reads get_current it returns None — there is no in-progress
+    # item to leak. queue_status fields should therefore be None.
+    a_event = _last_status_event_for_item(event_bus, a_item_id)
+    assert a_event.user_id == user_a
+    assert a_event.queue_status.item_id is None  # no in-progress item at all
+    assert a_event.queue_status.completed == 1
+
+
+def test_event_preserves_identifiers_when_current_item_is_the_changed_item(
+    session_queue: SqliteSessionQueue, mock_invoker: Invoker
+) -> None:
+    """The dequeue() transition makes the changed item itself the in-progress current item.
+ queue_status must expose its identifiers since they belong to the event's owner.""" + user_a = "user-a" + a_item_id = _insert_queue_item(session_queue, user_id=user_a) + + event_bus: TestEventService = mock_invoker.services.events + event_bus.events.clear() + + in_progress = session_queue.dequeue() + assert in_progress is not None and in_progress.item_id == a_item_id + + a_event = _last_status_event_for_item(event_bus, a_item_id) + assert a_event.status == "in_progress" + assert a_event.user_id == user_a + # Current item == changed item == owned by user_a → no redaction + assert a_event.queue_status.item_id == a_item_id + assert a_event.queue_status.session_id == in_progress.session_id + assert a_event.queue_status.batch_id == in_progress.batch_id From e77c0d125f672a91ca949808076862b7ae5d1e46 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Sun, 26 Apr 2026 21:29:22 +0200 Subject: [PATCH 02/12] feat: Add virtual boards that dynamically group images by date (#8971) * feat: Add virtual boards that dynamically group images by date Virtual boards are computed on-the-fly via backend queries, not stored in the database. The first virtual board type groups images by creation date into sub-boards per day. The feature is togglable via the board settings popover and the collapse state persists across sessions. * add missing Redux state migration * Chore ruff check * docs(gallery): document virtual boards by date Add a "Virtual Boards by Date" section to the gallery feature docs explaining how to enable the new By Date section, what each entry shows, and that virtual boards are a read-only view over existing images. * fix(ui): invalidate VirtualBoards tag on image generation Optimistic updates in onInvocationComplete cover ImageList and BoardImagesTotal but not VirtualBoards, so date-grouped counts and cover thumbnails would only refresh on the next mutation. Trigger an explicit invalidation when at least one non-intermediate image was added. 
--- docs-old/features/gallery.md | 20 +++ invokeai/app/api/routers/virtual_boards.py | 56 +++++++ invokeai/app/api_app.py | 2 + .../image_records/image_records_base.py | 24 +++ .../image_records/image_records_sqlite.py | 139 ++++++++++++++++++ .../app/services/virtual_boards/__init__.py | 0 .../virtual_boards/virtual_boards_common.py | 14 ++ .../Boards/BoardsList/BoardsList.tsx | 2 + .../Boards/BoardsList/VirtualBoardItem.tsx | 96 ++++++++++++ .../Boards/BoardsList/VirtualBoardSection.tsx | 62 ++++++++ .../Boards/BoardsSettingsPopover.tsx | 2 + .../ShowVirtualBoardsCheckbox.tsx | 29 ++++ .../components/use-gallery-image-names.ts | 60 ++++++-- .../features/gallery/store/gallerySlice.ts | 27 ++++ .../web/src/features/gallery/store/types.ts | 8 + .../services/api/endpoints/virtual_boards.ts | 56 +++++++ .../src/services/api/hooks/useBoardName.ts | 4 + .../frontend/web/src/services/api/index.ts | 1 + .../frontend/web/src/services/api/schema.ts | 137 +++++++++++++++++ .../src/services/api/util/tagInvalidation.ts | 2 +- .../services/events/onInvocationComplete.tsx | 5 + 21 files changed, 735 insertions(+), 11 deletions(-) create mode 100644 invokeai/app/api/routers/virtual_boards.py create mode 100644 invokeai/app/services/virtual_boards/__init__.py create mode 100644 invokeai/app/services/virtual_boards/virtual_boards_common.py create mode 100644 invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/VirtualBoardItem.tsx create mode 100644 invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/VirtualBoardSection.tsx create mode 100644 invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover/ShowVirtualBoardsCheckbox.tsx create mode 100644 invokeai/frontend/web/src/services/api/endpoints/virtual_boards.ts diff --git a/docs-old/features/gallery.md b/docs-old/features/gallery.md index 1c12f59c7a7..eb246b83b95 100644 --- a/docs-old/features/gallery.md +++ b/docs-old/features/gallery.md @@ -34,6 +34,26 @@ The settings 
button opens a list of options. Below these two buttons, you'll see the Search Boards text entry area. You use this to search for specific boards by the name of the board. Next to it is the Add Board (+) button which lets you add new boards. Boards can be renamed by clicking on the name of the board under its thumbnail and typing in the new name. +### Virtual Boards by Date + +In addition to the regular user-created boards, the Gallery can show **virtual boards** that group your images automatically by their creation date. Virtual boards are not stored in the database — they are computed on the fly from existing image metadata, so enabling or disabling them never moves or modifies your images. + +#### Enabling Virtual Boards + +Open the boards settings popover (the gear icon next to the boards search field) and toggle **Show Virtual Boards**. A new collapsible **By Date** section then appears in the boards list, with one entry per day on which images were generated (e.g. `2026-03-18`). + +Each virtual board entry shows: + +- a cover thumbnail (the most recent image of that day) +- the number of generated **images** on that date +- the number of uploaded **assets** on that date + +Selecting a virtual board filters the gallery to show only the images from that day. Search, category filters (Images / Assets), starred-first sorting and sort direction all work the same way as on regular boards. + +!!! note "Read-only" + + Virtual boards are a view over your existing images. You cannot rename, delete or auto-assign to them, and images cannot be "moved into" a virtual board — they appear there automatically based on their creation date. To organize images permanently, use regular boards. + ### Board Thumbnail Menu Each board has a context menu (ctrl+click / right-click). 
diff --git a/invokeai/app/api/routers/virtual_boards.py b/invokeai/app/api/routers/virtual_boards.py new file mode 100644 index 00000000000..f0c9e2edc51 --- /dev/null +++ b/invokeai/app/api/routers/virtual_boards.py @@ -0,0 +1,56 @@ +from fastapi import HTTPException, Path, Query +from fastapi.routing import APIRouter + +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageNamesResult +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.virtual_boards.virtual_boards_common import VirtualSubBoardDTO + +virtual_boards_router = APIRouter(prefix="/v1/virtual_boards", tags=["virtual_boards"]) + + +@virtual_boards_router.get( + "/by_date", + operation_id="list_virtual_boards_by_date", + response_model=list[VirtualSubBoardDTO], +) +async def list_virtual_boards_by_date( + current_user: CurrentUserOrDefault, +) -> list[VirtualSubBoardDTO]: + """Gets a list of virtual sub-boards grouped by date.""" + try: + return ApiDependencies.invoker.services.image_records.get_image_dates( + user_id=current_user.user_id, + is_admin=current_user.is_admin, + ) + except Exception: + raise HTTPException(status_code=500, detail="Failed to get virtual boards by date") + + +@virtual_boards_router.get( + "/by_date/{date}/image_names", + operation_id="list_virtual_board_image_names_by_date", + response_model=ImageNamesResult, +) +async def list_virtual_board_image_names_by_date( + current_user: CurrentUserOrDefault, + date: str = Path(description="The ISO date string, e.g. 
'2026-03-18'"), + starred_first: bool = Query(default=True, description="Whether to sort starred images first"), + order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The sort direction"), + categories: list[ImageCategory] | None = Query(default=None, description="The categories of images to include"), + search_term: str | None = Query(default=None, description="Search term to filter images"), +) -> ImageNamesResult: + """Gets ordered image names for a specific date.""" + try: + return ApiDependencies.invoker.services.image_records.get_image_names_by_date( + date=date, + starred_first=starred_first, + order_dir=order_dir, + categories=categories, + search_term=search_term, + user_id=current_user.user_id, + is_admin=current_user.is_admin, + ) + except Exception: + raise HTTPException(status_code=500, detail="Failed to get image names for date") diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 2ca6746b496..110fd757bd1 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -29,6 +29,7 @@ session_queue, style_presets, utilities, + virtual_boards, workflows, ) from invokeai.app.api.sockets import SocketIO @@ -177,6 +178,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") app.include_router(board_images.board_images_router, prefix="/api") +app.include_router(virtual_boards.virtual_boards_router, prefix="/api") app.include_router(model_relationships.model_relationships_router, prefix="/api") app.include_router(app_info.app_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api") diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 457cf2f4686..dd1e9fd4f37 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ 
b/invokeai/app/services/image_records/image_records_base.py @@ -12,6 +12,7 @@ ) from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.virtual_boards.virtual_boards_common import VirtualSubBoardDTO class ImageRecordStorageBase(ABC): @@ -122,3 +123,26 @@ def get_image_names( ) -> ImageNamesResult: """Gets ordered list of image names with metadata for optimistic updates.""" pass + + @abstractmethod + def get_image_dates( + self, + user_id: Optional[str] = None, + is_admin: bool = False, + ) -> list[VirtualSubBoardDTO]: + """Gets a list of dates with image counts, grouped by DATE(created_at).""" + pass + + @abstractmethod + def get_image_names_by_date( + self, + date: str, + starred_first: bool = True, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + categories: Optional[list[ImageCategory]] = None, + search_term: Optional[str] = None, + user_id: Optional[str] = None, + is_admin: bool = False, + ) -> ImageNamesResult: + """Gets ordered list of image names for a specific date.""" + pass diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 07126d53a9f..e88b49c56d3 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -19,6 +19,7 @@ from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +from invokeai.app.services.virtual_boards.virtual_boards_common import VirtualSubBoardDTO class SqliteImageRecordStorage(ImageRecordStorageBase): @@ -503,3 +504,141 @@ def get_image_names( image_names = [row[0] for row in result] return ImageNamesResult(image_names=image_names, starred_count=starred_count, 
total_count=len(image_names)) + + def get_image_dates( + self, + user_id: Optional[str] = None, + is_admin: bool = False, + ) -> list[VirtualSubBoardDTO]: + with self._db.transaction() as cursor: + query_conditions = "" + query_params: list[Union[int, str, bool]] = [] + + # Only non-intermediate images + query_conditions += """--sql + AND images.is_intermediate = 0 + """ + + # User isolation for non-admin users + if user_id is not None and not is_admin: + query_conditions += """--sql + AND images.user_id = ? + """ + query_params.append(user_id) + + query = f"""--sql + SELECT + DATE(images.created_at) as date, + SUM(CASE WHEN images.image_category = 'general' THEN 1 ELSE 0 END) as image_count, + SUM(CASE WHEN images.image_category != 'general' THEN 1 ELSE 0 END) as asset_count, + ( + SELECT i2.image_name FROM images i2 + WHERE DATE(i2.created_at) = DATE(images.created_at) + AND i2.is_intermediate = 0 + ORDER BY i2.created_at DESC LIMIT 1 + ) as cover_image_name + FROM images + WHERE 1=1 + {query_conditions} + GROUP BY DATE(images.created_at) + ORDER BY date DESC; + """ + + cursor.execute(query, query_params) + result = cast(list[sqlite3.Row], cursor.fetchall()) + + return [ + VirtualSubBoardDTO( + virtual_board_id=f"by_date:{dict(row)['date']}", + board_name=dict(row)["date"], + date=dict(row)["date"], + image_count=dict(row)["image_count"], + asset_count=dict(row)["asset_count"], + cover_image_name=dict(row)["cover_image_name"], + ) + for row in result + ] + + def get_image_names_by_date( + self, + date: str, + starred_first: bool = True, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + categories: Optional[list[ImageCategory]] = None, + search_term: Optional[str] = None, + user_id: Optional[str] = None, + is_admin: bool = False, + ) -> ImageNamesResult: + with self._db.transaction() as cursor: + query_conditions = "" + query_params: list[Union[int, str, bool]] = [] + + # Filter by date + query_conditions += """--sql + AND DATE(images.created_at) = ? 
+ """ + query_params.append(date) + + # Only non-intermediate images + query_conditions += """--sql + AND images.is_intermediate = 0 + """ + + if categories is not None: + category_strings = [c.value for c in set(categories)] + placeholders = ",".join("?" * len(category_strings)) + query_conditions += f"""--sql + AND images.image_category IN ( {placeholders} ) + """ + for c in category_strings: + query_params.append(c) + + # User isolation for non-admin users + if user_id is not None and not is_admin: + query_conditions += """--sql + AND images.user_id = ? + """ + query_params.append(user_id) + + if search_term: + query_conditions += """--sql + AND ( + images.metadata LIKE ? + OR images.created_at LIKE ? + ) + """ + query_params.append(f"%{search_term.lower()}%") + query_params.append(f"%{search_term.lower()}%") + + # Get starred count if starred_first is enabled + starred_count = 0 + if starred_first: + starred_count_query = f"""--sql + SELECT COUNT(*) + FROM images + WHERE images.starred = TRUE AND (1=1{query_conditions}) + """ + cursor.execute(starred_count_query, query_params) + starred_count = cast(int, cursor.fetchone()[0]) + + # Get all image names with proper ordering + if starred_first: + names_query = f"""--sql + SELECT images.image_name + FROM images + WHERE 1=1{query_conditions} + ORDER BY images.starred DESC, images.created_at {order_dir.value} + """ + else: + names_query = f"""--sql + SELECT images.image_name + FROM images + WHERE 1=1{query_conditions} + ORDER BY images.created_at {order_dir.value} + """ + + cursor.execute(names_query, query_params) + result = cast(list[sqlite3.Row], cursor.fetchall()) + image_names = [row[0] for row in result] + + return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names)) diff --git a/invokeai/app/services/virtual_boards/__init__.py b/invokeai/app/services/virtual_boards/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git 
a/invokeai/app/services/virtual_boards/virtual_boards_common.py b/invokeai/app/services/virtual_boards/virtual_boards_common.py new file mode 100644 index 00000000000..e1df5a81ca5 --- /dev/null +++ b/invokeai/app/services/virtual_boards/virtual_boards_common.py @@ -0,0 +1,14 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class VirtualSubBoardDTO(BaseModel): + """A virtual sub-board computed from image metadata, not stored in the database.""" + + virtual_board_id: str = Field(description="The virtual board ID, e.g. 'by_date:2026-03-18'.") + board_name: str = Field(description="The display name of the virtual sub-board, e.g. '2026-03-18'.") + date: str = Field(description="The ISO date string, e.g. '2026-03-18'.") + image_count: int = Field(description="The number of general images for this date.") + asset_count: int = Field(description="The number of asset images for this date.") + cover_image_name: Optional[str] = Field(default=None, description="The most recent image name for this date.") diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx index 2d37a03f69f..c05a2df84fa 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx @@ -14,6 +14,7 @@ import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import AddBoardButton from './AddBoardButton'; import GalleryBoard from './GalleryBoard'; import NoBoardBoard from './NoBoardBoard'; +import { VirtualBoardSection } from './VirtualBoardSection'; export const BoardsList = memo(() => { const { t } = useTranslation(); @@ -40,6 +41,7 @@ export const BoardsList = memo(() => { if (!boardSearchText.length) { elements.push(); + elements.push(); } filteredBoards.forEach((board) => { diff --git 
a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/VirtualBoardItem.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/VirtualBoardItem.tsx new file mode 100644 index 00000000000..d85c90f7dc1 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/VirtualBoardItem.tsx @@ -0,0 +1,96 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; +import { Box, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectSelectedBoardId } from 'features/gallery/store/gallerySelectors'; +import { boardIdSelected } from 'features/gallery/store/gallerySlice'; +import { memo, useCallback } from 'react'; +import { PiCalendarBold, PiImageSquare } from 'react-icons/pi'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import type { VirtualSubBoard } from 'services/api/endpoints/virtual_boards'; + +const _hover: SystemStyleObject = { + bg: 'base.850', +}; + +interface VirtualBoardItemProps { + board: VirtualSubBoard; +} + +const VirtualBoardItem = ({ board }: VirtualBoardItemProps) => { + const dispatch = useAppDispatch(); + const selectedBoardId = useAppSelector(selectSelectedBoardId); + const isSelected = selectedBoardId === board.virtual_board_id; + + const onClick = useCallback(() => { + if (selectedBoardId !== board.virtual_board_id) { + dispatch(boardIdSelected({ boardId: board.virtual_board_id })); + } + }, [selectedBoardId, board.virtual_board_id, dispatch]); + + return ( + + + + + + + {board.board_name} + + + + + + {board.image_count} | {board.asset_count} + + + + + + ); +}; + +export default memo(VirtualBoardItem); + +const CoverImage = ({ coverImageName }: { coverImageName: string | null }) => { + const { currentData: coverImage } = useGetImageDTOQuery(coverImageName ?? 
skipToken); + + if (coverImage) { + return ( + + ); + } + + return ( + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/VirtualBoardSection.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/VirtualBoardSection.tsx new file mode 100644 index 00000000000..bdadaf77fda --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/VirtualBoardSection.tsx @@ -0,0 +1,62 @@ +import { Collapse, Flex, Icon, IconButton, Text } from '@invoke-ai/ui-library'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectGallerySlice, virtualBoardsSectionOpenChanged } from 'features/gallery/store/gallerySlice'; +import { memo, useCallback } from 'react'; +import { PiCalendarBold, PiCaretDownBold, PiCaretRightBold } from 'react-icons/pi'; +import { useListVirtualBoardsByDateQuery } from 'services/api/endpoints/virtual_boards'; + +import VirtualBoardItem from './VirtualBoardItem'; + +const selectShowVirtualBoards = createSelector(selectGallerySlice, (gallery) => gallery.showVirtualBoards); +const selectVirtualBoardsSectionOpen = createSelector( + selectGallerySlice, + (gallery) => gallery.virtualBoardsSectionOpen +); + +export const VirtualBoardSection = memo(() => { + const dispatch = useAppDispatch(); + const showVirtualBoards = useAppSelector(selectShowVirtualBoards); + const isOpen = useAppSelector(selectVirtualBoardsSectionOpen); + + const { data: virtualBoards } = useListVirtualBoardsByDateQuery(undefined, { + skip: !showVirtualBoards, + }); + + const toggleOpen = useCallback(() => { + dispatch(virtualBoardsSectionOpenChanged(!isOpen)); + }, [dispatch, isOpen]); + + if (!showVirtualBoards || !virtualBoards?.length) { + return null; + } + + return ( + + + + + + By Date + + + : } + onClick={toggleOpen} + /> + + + + {virtualBoards.map((board) => ( + + ))} + + + + ); +}); + 
+VirtualBoardSection.displayName = 'VirtualBoardSection'; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsSettingsPopover.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsSettingsPopover.tsx index 3fef611f99b..814595e7f2e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsSettingsPopover.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsSettingsPopover.tsx @@ -13,6 +13,7 @@ import { import BoardAutoAddSelect from 'features/gallery/components/Boards/BoardAutoAddSelect'; import AutoAssignBoardCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoAssignBoardCheckbox'; import ShowArchivedBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowArchivedBoardsCheckbox'; +import ShowVirtualBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowVirtualBoardsCheckbox'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiGearSixFill } from 'react-icons/pi'; @@ -47,6 +48,7 @@ export const BoardsSettingsPopover = memo(() => { + diff --git a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover/ShowVirtualBoardsCheckbox.tsx b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover/ShowVirtualBoardsCheckbox.tsx new file mode 100644 index 00000000000..29e3e7ab3ce --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover/ShowVirtualBoardsCheckbox.tsx @@ -0,0 +1,29 @@ +import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectGallerySlice, showVirtualBoardsChanged } from 'features/gallery/store/gallerySlice'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback } from 'react'; + +const selectShowVirtualBoards = 
createSelector(selectGallerySlice, (gallery) => gallery.showVirtualBoards); + +const ShowVirtualBoardsCheckbox = () => { + const dispatch = useAppDispatch(); + const showVirtualBoards = useAppSelector(selectShowVirtualBoards); + + const onChange = useCallback( + (e: ChangeEvent) => { + dispatch(showVirtualBoardsChanged(e.target.checked)); + }, + [dispatch] + ); + + return ( + + Virtual Boards + + + ); +}; + +export default memo(ShowVirtualBoardsCheckbox); diff --git a/invokeai/frontend/web/src/features/gallery/components/use-gallery-image-names.ts b/invokeai/frontend/web/src/features/gallery/components/use-gallery-image-names.ts index c81728a1b21..487c5609062 100644 --- a/invokeai/frontend/web/src/features/gallery/components/use-gallery-image-names.ts +++ b/invokeai/frontend/web/src/features/gallery/components/use-gallery-image-names.ts @@ -1,21 +1,61 @@ +import { skipToken } from '@reduxjs/toolkit/query'; import { EMPTY_ARRAY } from 'app/store/constants'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectGetImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors'; +import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors'; +import { getDateFromVirtualBoardId, isVirtualBoardId } from 'features/gallery/store/types'; import { useGetImageNamesQuery } from 'services/api/endpoints/images'; +import { useGetVirtualBoardImageNamesByDateQuery } from 'services/api/endpoints/virtual_boards'; import { useDebounce } from 'use-debounce'; -const getImageNamesQueryOptions = { +const selectFromResult = ({ + currentData, + isLoading, + isFetching, +}: { + currentData?: { image_names: string[] }; + isLoading: boolean; + isFetching: boolean; +}) => ({ + imageNames: currentData?.image_names ?? EMPTY_ARRAY, + isLoading, + isFetching, +}); + +const queryOptions = { refetchOnReconnect: true, - selectFromResult: ({ currentData, isLoading, isFetching }) => ({ - imageNames: currentData?.image_names ?? 
EMPTY_ARRAY, - isLoading, - isFetching, - }), -} satisfies Parameters[1]; + selectFromResult, +}; export const useGalleryImageNames = () => { + const selectedBoardId = useAppSelector(selectSelectedBoardId); const _queryArgs = useAppSelector(selectGetImageNamesQueryArgs); const [queryArgs] = useDebounce(_queryArgs, 300); - const { imageNames, isLoading, isFetching } = useGetImageNamesQuery(queryArgs, getImageNamesQueryOptions); - return { imageNames, isLoading, isFetching, queryArgs }; + const isVirtual = isVirtualBoardId(selectedBoardId); + + // Regular board query + const regularResult = useGetImageNamesQuery(isVirtual ? skipToken : queryArgs, queryOptions); + + // Virtual board query + const date = isVirtual ? getDateFromVirtualBoardId(selectedBoardId) : ''; + const virtualResult = useGetVirtualBoardImageNamesByDateQuery( + isVirtual + ? { + date, + categories: queryArgs.categories ?? undefined, + search_term: queryArgs.search_term || undefined, + order_dir: queryArgs.order_dir, + starred_first: queryArgs.starred_first, + } + : skipToken, + queryOptions + ); + + const result = isVirtual ? 
virtualResult : regularResult; + + return { + imageNames: result.imageNames, + isLoading: result.isLoading, + isFetching: result.isFetching, + queryArgs, + }; }; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 6a25caadce4..e4894b60766 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -12,6 +12,7 @@ import { type ComparisonMode, type GalleryState, type GalleryView, + isVirtualBoardId, type OrderDir, zGalleryState, } from './types'; @@ -33,6 +34,8 @@ const getInitialState = (): GalleryState => ({ comparisonMode: 'slider', comparisonFit: 'fill', shouldShowArchivedBoards: false, + showVirtualBoards: false, + virtualBoardsSectionOpen: true, boardsListOrderBy: 'created_at', boardsListOrderDir: 'DESC', }); @@ -103,6 +106,10 @@ const slice = createSlice({ state.autoAddBoardId = 'none'; return; } + // Virtual boards cannot be auto-add targets + if (isVirtualBoardId(action.payload)) { + return; + } state.autoAddBoardId = action.payload; }, galleryViewChanged: (state, action: PayloadAction) => { @@ -127,6 +134,17 @@ const slice = createSlice({ shouldShowArchivedBoardsChanged: (state, action: PayloadAction) => { state.shouldShowArchivedBoards = action.payload; }, + showVirtualBoardsChanged: (state, action: PayloadAction) => { + state.showVirtualBoards = action.payload; + // If virtual boards are hidden and a virtual board is selected, reset to 'none' + if (!action.payload && isVirtualBoardId(state.selectedBoardId)) { + state.selectedBoardId = 'none'; + state.selection = []; + } + }, + virtualBoardsSectionOpenChanged: (state, action: PayloadAction) => { + state.virtualBoardsSectionOpen = action.payload; + }, starredFirstChanged: (state, action: PayloadAction) => { state.starredFirst = action.payload; }, @@ -172,6 +190,8 @@ export const { orderDirChanged, starredFirstChanged, 
shouldShowArchivedBoardsChanged, + showVirtualBoardsChanged, + virtualBoardsSectionOpenChanged, searchTermChanged, boardsListOrderByChanged, boardsListOrderDirChanged, @@ -189,6 +209,13 @@ export const gallerySliceConfig: SliceConfig = { if (!('_version' in state)) { state._version = 1; } + // Add virtual boards fields if missing (added in virtual boards feature) + if (!('showVirtualBoards' in state)) { + state.showVirtualBoards = false; + } + if (!('virtualBoardsSectionOpen' in state)) { + state.virtualBoardsSectionOpen = true; + } return zGalleryState.parse(state); }, persistDenylist: ['selection', 'galleryView', 'imageToCompare'], diff --git a/invokeai/frontend/web/src/features/gallery/store/types.ts b/invokeai/frontend/web/src/features/gallery/store/types.ts index addeefe870f..c040e5834d7 100644 --- a/invokeai/frontend/web/src/features/gallery/store/types.ts +++ b/invokeai/frontend/web/src/features/gallery/store/types.ts @@ -35,8 +35,16 @@ export const zGalleryState = z.object({ comparisonMode: zComparisonMode, comparisonFit: zComparisonFit, shouldShowArchivedBoards: z.boolean(), + showVirtualBoards: z.boolean(), + virtualBoardsSectionOpen: z.boolean(), boardsListOrderBy: zBoardRecordOrderBy, boardsListOrderDir: zOrderDir, }); export type GalleryState = z.infer; + +const VIRTUAL_BOARD_ID_PREFIX = 'by_date:'; + +export const isVirtualBoardId = (id: string): boolean => id.startsWith(VIRTUAL_BOARD_ID_PREFIX); + +export const getDateFromVirtualBoardId = (id: string): string => id.replace(VIRTUAL_BOARD_ID_PREFIX, ''); diff --git a/invokeai/frontend/web/src/services/api/endpoints/virtual_boards.ts b/invokeai/frontend/web/src/services/api/endpoints/virtual_boards.ts new file mode 100644 index 00000000000..b450bf84436 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/endpoints/virtual_boards.ts @@ -0,0 +1,56 @@ +import queryString from 'query-string'; +import type { ImageCategory } from 'services/api/types'; + +import type { ApiTagDescription } from '..'; 
+import { api, buildV1Url } from '..'; + +export type VirtualSubBoard = { + virtual_board_id: string; + board_name: string; + date: string; + image_count: number; + asset_count: number; + cover_image_name: string | null; +}; + +type ImageNamesResult = { + image_names: string[]; + starred_count: number; + total_count: number; +}; + +const buildVirtualBoardsUrl = (path: string = '') => buildV1Url(`virtual_boards/${path}`); + +const virtualBoardsApi = api.injectEndpoints({ + endpoints: (build) => ({ + listVirtualBoardsByDate: build.query({ + query: () => ({ + url: buildVirtualBoardsUrl('by_date'), + }), + providesTags: (): ApiTagDescription[] => ['VirtualBoards', 'FetchOnReconnect'], + }), + + getVirtualBoardImageNamesByDate: build.query< + ImageNamesResult, + { + date: string; + starred_first?: boolean; + order_dir?: 'ASC' | 'DESC'; + categories?: ImageCategory[]; + search_term?: string; + } + >({ + query: ({ date, ...params }) => ({ + url: buildVirtualBoardsUrl( + `by_date/${date}/image_names?${queryString.stringify(params, { arrayFormat: 'none', skipNull: true, skipEmptyString: true })}` + ), + }), + providesTags: (_result, _error, arg): ApiTagDescription[] => [ + { type: 'ImageNameList', id: `virtual_${arg.date}` }, + 'FetchOnReconnect', + ], + }), + }), +}); + +export const { useListVirtualBoardsByDateQuery, useGetVirtualBoardImageNamesByDateQuery } = virtualBoardsApi; diff --git a/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts b/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts index 5b741907662..eb847b6c93d 100644 --- a/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts +++ b/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts @@ -1,4 +1,5 @@ import type { BoardId } from 'features/gallery/store/types'; +import { getDateFromVirtualBoardId, isVirtualBoardId } from 'features/gallery/store/types'; import { t } from 'i18next'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; @@ -7,6 +8,9 @@ export 
const useBoardName = (board_id: BoardId) => { { include_archived: true }, { selectFromResult: ({ data }) => { + if (isVirtualBoardId(board_id)) { + return { boardName: getDateFromVirtualBoardId(board_id) }; + } const selectedBoard = data?.find((b) => b.board_id === board_id); const boardName = selectedBoard?.board_name || t('boards.uncategorized'); diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 85a5d320a1a..56fa307dd25 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -60,6 +60,7 @@ const tagTypes = [ 'FetchOnReconnect', 'ClientState', 'UserList', + 'VirtualBoards', ] as const; export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>; export const LIST_TAG = 'LIST'; diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 4b8e4da95a5..cb872bddfdd 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1457,6 +1457,46 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/virtual_boards/by_date": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * List Virtual Boards By Date + * @description Gets a list of virtual sub-boards grouped by date. + */ + get: operations["list_virtual_boards_by_date"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/v1/virtual_boards/by_date/{date}/image_names": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * List Virtual Board Image Names By Date + * @description Gets ordered image names for a specific date. 
+ */ + get: operations["list_virtual_board_image_names_by_date"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/model_relationships/i/{model_key}": { parameters: { query?: never; @@ -30025,6 +30065,42 @@ export type components = { /** Error Type */ type: string; }; + /** + * VirtualSubBoardDTO + * @description A virtual sub-board computed from image metadata, not stored in the database. + */ + VirtualSubBoardDTO: { + /** + * Virtual Board Id + * @description The virtual board ID, e.g. 'by_date:2026-03-18'. + */ + virtual_board_id: string; + /** + * Board Name + * @description The display name of the virtual sub-board, e.g. '2026-03-18'. + */ + board_name: string; + /** + * Date + * @description The ISO date string, e.g. '2026-03-18'. + */ + date: string; + /** + * Image Count + * @description The number of general images for this date. + */ + image_count: number; + /** + * Asset Count + * @description The number of asset images for this date. + */ + asset_count: number; + /** + * Cover Image Name + * @description The most recent image name for this date. 
+ */ + cover_image_name?: string | null; + }; /** Workflow */ Workflow: { /** @@ -34125,6 +34201,67 @@ export interface operations { }; }; }; + list_virtual_boards_by_date: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["VirtualSubBoardDTO"][]; + }; + }; + }; + }; + list_virtual_board_image_names_by_date: { + parameters: { + query?: { + /** @description Whether to sort starred images first */ + starred_first?: boolean; + /** @description The sort direction */ + order_dir?: components["schemas"]["SQLiteDirection"]; + /** @description The categories of images to include */ + categories?: components["schemas"]["ImageCategory"][] | null; + /** @description Search term to filter images */ + search_term?: string | null; + }; + header?: never; + path: { + /** @description The ISO date string, e.g. 
'2026-03-18' */ + date: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["ImageNamesResult"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; get_related_models: { parameters: { query?: never; diff --git a/invokeai/frontend/web/src/services/api/util/tagInvalidation.ts b/invokeai/frontend/web/src/services/api/util/tagInvalidation.ts index 477a5a03f87..bac3130d312 100644 --- a/invokeai/frontend/web/src/services/api/util/tagInvalidation.ts +++ b/invokeai/frontend/web/src/services/api/util/tagInvalidation.ts @@ -4,7 +4,7 @@ import { getListImagesUrl } from 'services/api/util'; import type { ApiTagDescription } from '..'; export const getTagsToInvalidateForBoardAffectingMutation = (affected_boards: string[]): ApiTagDescription[] => { - const tags: ApiTagDescription[] = ['ImageNameList']; + const tags: ApiTagDescription[] = ['ImageNameList', 'VirtualBoards']; for (const board_id of affected_boards) { tags.push({ diff --git a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx index 9a6fb3aab7f..ea6a237d4b2 100644 --- a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx +++ b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx @@ -154,6 +154,11 @@ export const buildOnInvocationComplete = ( // No need to invalidate tags since we're doing optimistic updates // Board totals are already updated above via upsertQueryEntries + // Exception: virtual board groupings aren't covered by the optimistic updates above, so + // their counts/cover thumbnails would otherwise lag behind until the next mutation. 
+ if (Object.keys(boardTotalAdditions).length > 0) { + dispatch(imagesApi.util.invalidateTags(['VirtualBoards'])); + } const autoSwitch = selectAutoSwitch(getState()); From f9f2a32e966193a8868840ebe412360ad5dcc380 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Mon, 27 Apr 2026 04:03:16 +0200 Subject: [PATCH 03/12] Feat(UI): Add LLM-powered prompt expansion and image-to-prompt features (#8899) * Add LLM-powered prompt expansion and image-to-prompt features Adds two new buttons to the positive prompt area: - "Expand Prompt" uses a local TextLLM model (AutoModelForCausalLM) to expand brief prompts into detailed image generation prompts - "Image to Prompt" uses an existing LLaVA OneVision model to generate descriptive prompts from uploaded images Backend: new TextLLM model type with config, loader, pipeline wrapper, workflow node, and two new API endpoints (expand-prompt, image-to-prompt). Also fixes HuggingFace metadata fetch assertion error when file size is None. Frontend: ExpandPromptButton and ImageToPromptButton components with model picker popovers, RTK Query mutations, and model type hooks. Buttons only appear when compatible models are installed. * chore fix windows paths * Fix device mismatch for LLM inference and add CPU-only toggle for Text LLM models Derive the execution device from the loaded model parameters instead of the global TorchDevice chooser so that cpu_only models no longer receive GPU-bound inputs. Also expose the existing cpu_only setting in the frontend Model Manager for Text LLM models. 
* Harden LLM endpoints and add tests - Bound max_tokens to 1-2048 on ExpandPromptRequest to prevent OOM - Replace asserts with explicit type checks and proper HTTP status codes (404 for unknown models, 422 for wrong model type, 500 for unexpected) - Use float32 dtype for cpu_only TextLLM models instead of global fp16 - Add 16 tests for TextLLMPipeline and API request validation * Add Ctrl+Z undo for LLM prompt changes Saves the previous prompt before LLM overwrites it (Expand Prompt and Image to Prompt). Pressing Ctrl+Z in the prompt textarea restores the original prompt. Undo state auto-expires after 30 seconds and is cleared when the user types manually. * Add documentation and What's New entry for LLM prompt tools - Add docs/features/prompt-tools.md covering Expand Prompt, Image to Prompt, compatible models, Ctrl+Z undo, and the workflow node - Register new doc page in mkdocs.yml under Features - Add What's New item in en.json for the LLM Prompt Tools feature * fix: resolve merge conflict in mkdocs.yml nav * feat(ui): allow dragging gallery images onto prompt box for Image to Prompt Add drop target on the positive prompt textarea so users can drag images from the gallery directly into the prompt area. When dropped, the Image to Prompt popover opens automatically with the image pre-loaded, ready for description generation. * chore typegen * Fix typo in Z-Image Turbo diversity description * Fix three bugs in LLM/VLM utility endpoints Move torch.no_grad() from async endpoint into worker functions where inference actually runs, since the context manager does not carry across the thread boundary used by asyncio.to_thread(). Add threading.Lock around load_model() calls to serialize access to the thread-unsafe model loader, preventing race conditions from concurrent HTTP requests. Catch ImageFileNotFoundException in image_to_prompt and return 404 instead of letting it fall through to the blanket 500 handler. 
* Fix tokenizer validation, drag-drop dead end, and i18n for LLM prompt tools Validate tokenizer files at model probe time instead of deferring to runtime. Guard image drag-drop on the prompt textarea behind LLaVA model availability. Add missing modelManager.textLLM i18n key and replace all hardcoded strings in ImageToPromptButton and ExpandPromptButton with translation calls. * Add unit tests for promptUndo module * Fix typo in Z-Image Turbo diversity description * Chore fix typegen --------- Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com> --- docs/features/prompt-tools.md | 50 +++ invokeai/app/api/routers/utilities.py | 180 +++++++++- invokeai/app/invocations/fields.py | 1 + invokeai/app/invocations/text_llm.py | 65 ++++ .../backend/model_manager/configs/factory.py | 2 + .../backend/model_manager/configs/text_llm.py | 52 +++ .../load/model_loaders/text_llm.py | 32 ++ invokeai/backend/model_manager/taxonomy.py | 1 + invokeai/backend/text_llm_pipeline.py | 56 ++++ invokeai/frontend/web/public/locales/en.json | 16 +- .../hooks/useEncoderModelSettings.ts | 4 +- .../web/src/features/modelManagerV2/models.ts | 7 + .../EncoderModelSettings.tsx | 4 +- .../subpanels/ModelPanel/ModelView.tsx | 7 +- .../web/src/features/nodes/types/common.ts | 1 + .../components/Core/ParamPositivePrompt.tsx | 83 ++++- .../features/prompt/ExpandPromptButton.tsx | 124 +++++++ .../features/prompt/ImageToPromptButton.tsx | 171 ++++++++++ .../src/features/prompt/promptUndo.test.ts | 87 +++++ .../web/src/features/prompt/promptUndo.ts | 56 ++++ .../src/services/api/endpoints/utilities.ts | 41 +++ .../src/services/api/hooks/modelsByType.ts | 4 + .../frontend/web/src/services/api/schema.ts | 314 +++++++++++++++++- .../frontend/web/src/services/api/types.ts | 5 + mkdocs.yml | 1 + tests/backend/text_llm/__init__.py | 0 .../text_llm/test_text_llm_api_models.py | 54 +++ .../text_llm/test_text_llm_pipeline.py | 130 ++++++++ 28 files changed, 1520 insertions(+), 28 deletions(-) create 
mode 100644 docs/features/prompt-tools.md create mode 100644 invokeai/app/invocations/text_llm.py create mode 100644 invokeai/backend/model_manager/configs/text_llm.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/text_llm.py create mode 100644 invokeai/backend/text_llm_pipeline.py create mode 100644 invokeai/frontend/web/src/features/prompt/ExpandPromptButton.tsx create mode 100644 invokeai/frontend/web/src/features/prompt/ImageToPromptButton.tsx create mode 100644 invokeai/frontend/web/src/features/prompt/promptUndo.test.ts create mode 100644 invokeai/frontend/web/src/features/prompt/promptUndo.ts create mode 100644 tests/backend/text_llm/__init__.py create mode 100644 tests/backend/text_llm/test_text_llm_api_models.py create mode 100644 tests/backend/text_llm/test_text_llm_pipeline.py diff --git a/docs/features/prompt-tools.md b/docs/features/prompt-tools.md new file mode 100644 index 00000000000..5b00bfa4956 --- /dev/null +++ b/docs/features/prompt-tools.md @@ -0,0 +1,50 @@ +# LLM Prompt Tools + +InvokeAI includes two built-in tools that use local language models to help you write better prompts. Both tools appear as small buttons in the top-right corner of the positive prompt area and are only visible when you have a compatible model installed. + +## Expand Prompt + +Takes your short prompt and expands it into a detailed, vivid description suitable for image generation. + +**How to use:** + +1. Type a brief prompt (e.g. "a cat in a garden") +2. Click the sparkle button in the prompt area +3. Select a Text LLM model from the dropdown +4. Click **Expand** +5. Your prompt is replaced with the expanded version + +**Compatible models:** Any HuggingFace model with a `ForCausalLM` architecture. 
Recommended options: + +| Model | Size | HuggingFace ID | +|-------|------|----------------| +| Qwen2.5 1.5B Instruct | ~3 GB | `Qwen/Qwen2.5-1.5B-Instruct` | +| Phi-3 Mini Instruct | ~7.5 GB | `microsoft/Phi-3-mini-4k-instruct` | +| TinyLlama Chat | ~2 GB | `TinyLlama/TinyLlama-1.1B-Chat-v1.0` | + +Install by pasting the HuggingFace ID into the Model Manager. The model is automatically detected as a **Text LLM** type. + +## Image to Prompt + +Upload an image and generate a descriptive prompt from it using a vision-language model. + +**How to use:** + +1. Click the image button in the prompt area +2. Select a LLaVA OneVision model from the dropdown +3. Click **Upload Image** and select an image +4. Click **Generate Prompt** +5. The generated description is set as your prompt + +**Compatible models:** LLaVA OneVision models (already supported by InvokeAI). + +## Undo + +Both tools overwrite your current prompt. You can undo this change: + +- Press **Ctrl+Z** (or **Cmd+Z** on macOS) in the prompt textarea within 30 seconds +- The undo state is cleared when you start typing manually + +## Workflow Node + +A **Text LLM** node is also available in the workflow editor for use in automated pipelines. It accepts a prompt string and model selection as inputs and outputs the expanded text as a string. 
diff --git a/invokeai/app/api/routers/utilities.py b/invokeai/app/api/routers/utilities.py index 921645b1d86..f77f77a8534 100644 --- a/invokeai/app/api/routers/utilities.py +++ b/invokeai/app/api/routers/utilities.py @@ -1,13 +1,32 @@ +import asyncio +import logging +import threading +from pathlib import Path from typing import Optional, Union +import torch from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator -from fastapi import Body +from fastapi import Body, HTTPException from fastapi.routing import APIRouter -from pydantic import BaseModel +from pydantic import BaseModel, Field from pyparsing import ParseException +from transformers import AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor + +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.services.image_files.image_files_common import ImageFileNotFoundException +from invokeai.app.services.model_records.model_records_base import UnknownModelException +from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline +from invokeai.backend.model_manager.taxonomy import ModelType +from invokeai.backend.text_llm_pipeline import DEFAULT_SYSTEM_PROMPT, TextLLMPipeline +from invokeai.backend.util.devices import TorchDevice + +logger = logging.getLogger(__name__) utilities_router = APIRouter(prefix="/v1/utilities", tags=["utilities"]) +# The underlying model loader is not thread-safe, so we serialize load_model calls. 
+_model_load_lock = threading.Lock() + class DynamicPromptsResponse(BaseModel): prompts: list[str] @@ -42,3 +61,160 @@ async def parse_dynamicprompts( prompts = [prompt] error = str(e) return DynamicPromptsResponse(prompts=prompts if prompts else [""], error=error) + + +# --- Expand Prompt --- + + +class ExpandPromptRequest(BaseModel): + prompt: str + model_key: str + max_tokens: int = Field(default=300, ge=1, le=2048) + system_prompt: str | None = None + + +class ExpandPromptResponse(BaseModel): + expanded_prompt: str + error: str | None = None + + +def _resolve_model_path(model_config_path: str) -> Path: + """Resolve a model config path to an absolute path.""" + model_path = Path(model_config_path) + if model_path.is_absolute(): + return model_path.resolve() + base_models_path = ApiDependencies.invoker.services.configuration.models_path + return (base_models_path / model_path).resolve() + + +def _run_expand_prompt(prompt: str, model_key: str, max_tokens: int, system_prompt: str | None) -> str: + """Run text LLM inference synchronously (called from thread).""" + model_manager = ApiDependencies.invoker.services.model_manager + model_config = model_manager.store.get_model(model_key) + + if model_config.type != ModelType.TextLLM: + raise ValueError(f"Model '{model_key}' is not a TextLLM model (got {model_config.type})") + + with _model_load_lock: + loaded_model = model_manager.load.load_model(model_config) + + with torch.no_grad(), loaded_model.model_on_device() as (_, model): + model_abs_path = _resolve_model_path(model_config.path) + tokenizer = AutoTokenizer.from_pretrained(model_abs_path, local_files_only=True) + + pipeline = TextLLMPipeline(model, tokenizer) + model_device = next(model.parameters()).device + output = pipeline.run( + prompt=prompt, + system_prompt=system_prompt or DEFAULT_SYSTEM_PROMPT, + max_new_tokens=max_tokens, + device=model_device, + dtype=TorchDevice.choose_torch_dtype(), + ) + + return output + + +@utilities_router.post( + 
"/expand-prompt", + operation_id="expand_prompt", + responses={ + 200: {"model": ExpandPromptResponse}, + }, +) +async def expand_prompt(body: ExpandPromptRequest) -> ExpandPromptResponse: + """Expand a brief prompt into a detailed image generation prompt using a text LLM.""" + try: + expanded = await asyncio.to_thread( + _run_expand_prompt, + body.prompt, + body.model_key, + body.max_tokens, + body.system_prompt, + ) + return ExpandPromptResponse(expanded_prompt=expanded) + except UnknownModelException: + raise HTTPException(status_code=404, detail=f"Model '{body.model_key}' not found") + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + except Exception as e: + logger.error(f"Error expanding prompt: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# --- Image to Prompt --- + + +class ImageToPromptRequest(BaseModel): + image_name: str + model_key: str + instruction: str = "Describe this image in detail for use as an AI image generation prompt." 
+ + +class ImageToPromptResponse(BaseModel): + prompt: str + error: str | None = None + + +def _run_image_to_prompt(image_name: str, model_key: str, instruction: str) -> str: + """Run LLaVA OneVision inference synchronously (called from thread).""" + model_manager = ApiDependencies.invoker.services.model_manager + model_config = model_manager.store.get_model(model_key) + + if model_config.type != ModelType.LlavaOnevision: + raise ValueError(f"Model '{model_key}' is not a LLaVA OneVision model (got {model_config.type})") + + with _model_load_lock: + loaded_model = model_manager.load.load_model(model_config) + + # Load the image from InvokeAI's image store + image = ApiDependencies.invoker.services.images.get_pil_image(image_name) + image = image.convert("RGB") + + with torch.no_grad(), loaded_model.model_on_device() as (_, model): + if not isinstance(model, LlavaOnevisionForConditionalGeneration): + raise TypeError(f"Expected LlavaOnevisionForConditionalGeneration, got {type(model).__name__}") + + model_abs_path = _resolve_model_path(model_config.path) + processor = AutoProcessor.from_pretrained(model_abs_path, local_files_only=True) + if not isinstance(processor, LlavaOnevisionProcessor): + raise TypeError(f"Expected LlavaOnevisionProcessor, got {type(processor).__name__}") + + pipeline = LlavaOnevisionPipeline(model, processor) + model_device = next(model.parameters()).device + output = pipeline.run( + prompt=instruction, + images=[image], + device=model_device, + dtype=TorchDevice.choose_torch_dtype(), + ) + + return output + + +@utilities_router.post( + "/image-to-prompt", + operation_id="image_to_prompt", + responses={ + 200: {"model": ImageToPromptResponse}, + }, +) +async def image_to_prompt(body: ImageToPromptRequest) -> ImageToPromptResponse: + """Generate a descriptive prompt from an image using a vision-language model.""" + try: + prompt = await asyncio.to_thread( + _run_image_to_prompt, + body.image_name, + body.model_key, + body.instruction, + ) + 
return ImageToPromptResponse(prompt=prompt) + except UnknownModelException: + raise HTTPException(status_code=404, detail=f"Model '{body.model_key}' not found") + except ImageFileNotFoundException: + raise HTTPException(status_code=404, detail=f"Image '{body.image_name}' not found") + except (ValueError, TypeError) as e: + raise HTTPException(status_code=422, detail=str(e)) + except Exception as e: + logger.error(f"Error generating prompt from image: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 2fc5fd5a3c0..e53aeb417b2 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -229,6 +229,7 @@ class FieldDescriptions: instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'." 
flux_redux_conditioning = "FLUX Redux conditioning tensor" vllm_model = "The VLLM model to use" + text_llm_model = "The text language model to use for text generation" flux_fill_conditioning = "FLUX Fill conditioning tensor" flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)" diff --git a/invokeai/app/invocations/text_llm.py b/invokeai/app/invocations/text_llm.py new file mode 100644 index 00000000000..789e65be018 --- /dev/null +++ b/invokeai/app/invocations/text_llm.py @@ -0,0 +1,65 @@ +import torch +from transformers import AutoTokenizer + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import FieldDescriptions, InputField, UIComponent +from invokeai.app.invocations.model import ModelIdentifierField +from invokeai.app.invocations.primitives import StringOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.taxonomy import ModelType +from invokeai.backend.text_llm_pipeline import DEFAULT_SYSTEM_PROMPT, TextLLMPipeline +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "text_llm", + title="Text LLM", + tags=["llm", "text", "prompt"], + category="llm", + version="1.0.0", + classification=Classification.Beta, +) +class TextLLMInvocation(BaseInvocation): + """Run a text language model to generate or expand text (e.g. 
for prompt expansion).""" + + prompt: str = InputField( + default="", + description="Input text prompt.", + ui_component=UIComponent.Textarea, + ) + system_prompt: str = InputField( + default=DEFAULT_SYSTEM_PROMPT, + description="System prompt that guides the model's behavior.", + ui_component=UIComponent.Textarea, + ) + text_llm_model: ModelIdentifierField = InputField( + title="Text LLM Model", + description=FieldDescriptions.text_llm_model, + ui_model_type=ModelType.TextLLM, + ) + max_tokens: int = InputField( + default=300, + ge=1, + le=2048, + description="Maximum number of tokens to generate.", + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> StringOutput: + model_config = context.models.get_config(self.text_llm_model) + + with context.models.load(self.text_llm_model).model_on_device() as (_, model): + model_abs_path = context.models.get_absolute_path(model_config) + tokenizer = AutoTokenizer.from_pretrained(model_abs_path, local_files_only=True) + + pipeline = TextLLMPipeline(model, tokenizer) + model_device = next(model.parameters()).device + output = pipeline.run( + prompt=self.prompt, + system_prompt=self.system_prompt, + max_new_tokens=self.max_tokens, + device=model_device, + dtype=TorchDevice.choose_torch_dtype(), + ) + + return StringOutput(value=output) diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py index 4d26b4c3347..9059aecebd9 100644 --- a/invokeai/backend/model_manager/configs/factory.py +++ b/invokeai/backend/model_manager/configs/factory.py @@ -97,6 +97,7 @@ T2IAdapter_Diffusers_SDXL_Config, ) from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config +from invokeai.backend.model_manager.configs.text_llm import TextLLM_Diffusers_Config from invokeai.backend.model_manager.configs.textual_inversion import ( TI_File_SD1_Config, TI_File_SD2_Config, @@ -269,6 +270,7 @@ 
class TextLLM_Diffusers_Config(Diffusers_Config_Base, Config_Base):
    """Config for text-only causal language models in diffusers/transformers folder
    layout (e.g. Llama, Phi, Qwen, Mistral)."""

    type: Literal[ModelType.TextLLM] = Field(default=ModelType.TextLLM)
    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        # Folder-format models only; single-file checkpoints never match this config.
        raise_if_not_dir(mod)
        raise_for_override_fields(cls, override_fields)

        # Accept only causal-LM architectures (LlamaForCausalLM, PhiForCausalLM,
        # Phi3ForCausalLM, Qwen2ForCausalLM, MistralForCausalLM, GemmaForCausalLM,
        # GPTNeoXForCausalLM, etc.).
        architecture = get_class_name_from_config_dict_or_raise(common_config_paths(mod.path))
        if not architecture.endswith("ForCausalLM"):
            raise NotAMatchError(f"model architecture '{architecture}' is not a causal language model")

        # A usable model directory must ship at least one tokenizer artifact;
        # reject at identification time rather than failing later at load time.
        # Tuple is pre-sorted so the error message matches sorted order.
        expected_tokenizers = ("tokenizer.json", "tokenizer.model", "tokenizer_config.json")
        has_tokenizer = any((mod.path / name).exists() for name in expected_tokenizers)
        if not has_tokenizer:
            raise NotAMatchError(
                f"no tokenizer files found in '{mod.path}' "
                f"(expected at least one of: {', '.join(expected_tokenizers)})"
            )

        return cls(**override_fields)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextLLM, format=ModelFormat.Diffusers)
class TextLLMModelLoader(ModelLoader):
    """Loader for text-only causal language models (Llama, Phi, Qwen, Mistral, etc.)."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        # Causal LMs are loaded whole; there are no submodels to hand out.
        if submodel_type is not None:
            raise ValueError("Unexpected submodel requested for TextLLM model.")

        # CPU-pinned models load in float32: CPU fp16 is emulated and slow.
        load_dtype = torch.float32 if getattr(config, "cpu_only", False) is True else self._torch_dtype

        return AutoModelForCausalLM.from_pretrained(
            Path(config.path),
            local_files_only=True,
            torch_dtype=load_dtype,
        )
DEFAULT_SYSTEM_PROMPT = (
    "You are an expert prompt writer for AI image generation. "
    "Given a brief description, expand it into a detailed, vivid prompt suitable for generating high-quality images. "
    "Only output the expanded prompt, nothing else."
)


class TextLLMPipeline:
    """A wrapper for a causal language model + tokenizer for text generation."""

    def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase):
        self._model = model
        self._tokenizer = tokenizer

    def run(
        self,
        prompt: str,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_new_tokens: int = 300,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float16,
    ) -> str:
        """Generate text from ``prompt``, steered by ``system_prompt``.

        Args:
            prompt: The user text to respond to / expand.
            system_prompt: Instruction that steers the model; an empty string disables it.
            max_new_tokens: Upper bound on generated tokens (prompt tokens excluded).
            device: Device the input tensors are moved to; should match the model's device.
            dtype: Currently unused; accepted for interface compatibility with callers
                that pass a target dtype. TODO(review): either apply it (e.g. via
                autocast) or deprecate the parameter.

        Returns:
            The generated text, with the prompt portion and special tokens stripped.
        """
        formatted_prompt = self._format_prompt(prompt, system_prompt)

        inputs = self._tokenizer(formatted_prompt, return_tensors="pt").to(device=device)

        # Without an explicit pad token, generate() falls back to a model-dependent
        # choice and emits a warning on many models. Use the tokenizer's pad token
        # when present, otherwise pad with EOS (the standard fallback for causal LMs).
        pad_token_id = getattr(self._tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = getattr(self._tokenizer, "eos_token_id", None)

        output = self._model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=pad_token_id,
        )

        # Decode only the newly generated tokens (exclude the input prompt tokens).
        input_length = inputs["input_ids"].shape[1]
        generated_tokens = output[0][input_length:]
        return self._tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    def _format_prompt(self, prompt: str, system_prompt: str) -> str:
        """Render via the tokenizer's chat template when one exists, else a plain fallback."""
        if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template is not None:
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": prompt})
            return self._tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Fallback for models without a chat template.
        if system_prompt:
            return f"{system_prompt}\n\nUser: {prompt}\nAssistant:"
        return prompt
+ "searchPrompts": "Search...", + "imageToPrompt": "Image to Prompt", + "selectVisionModel": "Select Vision Model...", + "changeImage": "Change Image", + "uploadImage": "Upload Image", + "generatePrompt": "Generate Prompt", + "expandPromptWithLLM": "Expand Prompt with LLM", + "expandPrompt": "Expand Prompt", + "selectTextLLM": "Select Text LLM...", + "expand": "Expand" }, "queue": { "queue": "Queue", @@ -1284,6 +1293,7 @@ }, "controlLora": "Control LoRA", "llavaOnevision": "LLaVA OneVision", + "textLLM": "Text LLM", "syncModels": "Sync Models", "syncModelsTooltip": "Identify and remove unused model files in the InvokeAI root directory.", "syncModelsDirectory": "Synchronize Models Directory", @@ -3324,6 +3334,10 @@ "whatsNew": { "whatsNewInInvoke": "What's New in Invoke", "items": [ + "LLM Prompt Tools: Use local language models to expand prompts or generate prompts from images. Install a Text LLM model (e.g. Qwen2.5-1.5B-Instruct) to get started.", + "FLUX.2 Klein Support: InvokeAI now supports the new FLUX.2 Klein models (4B and 9B variants) with GGUF, FP8, and Diffusers formats. Features include txt2img, img2img, inpainting, and outpainting. See 'Starter Models' to get started.", + "DyPE support for FLUX models improves high-resolution (>1536 px up to 4K) images. Go to the 'Advanced Options' section to activate.", + "Z-Image Turbo diversity: Active 'Seed Variance Enhancer' under 'Advanced Options' to add diversity to your ZiT gens.", "Multi-user mode supports multiple isolated users on the same server.", "Enhanced support for Z-Image and FLUX.2 Models.", "Multiple user interface enhancements and new canvas features." 
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useEncoderModelSettings.ts b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useEncoderModelSettings.ts index b1521f55fce..6b3e9d71010 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useEncoderModelSettings.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useEncoderModelSettings.ts @@ -7,6 +7,7 @@ import type { Qwen3EncoderModelConfig, SigLIPModelConfig, T5EncoderModelConfig, + TextLLMModelConfig, } from 'services/api/types'; type EncoderModelConfig = @@ -15,7 +16,8 @@ type EncoderModelConfig = | Qwen3EncoderModelConfig | CLIPVisionModelConfig | SigLIPModelConfig - | LlavaOnevisionModelConfig; + | LlavaOnevisionModelConfig + | TextLLMModelConfig; export const useEncoderModelSettings = (modelConfig: EncoderModelConfig) => { const encoderModelSettingsDefaults = useMemo(() => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 7cdba474bbf..1c0d2b20c7c 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -17,6 +17,7 @@ import { isSpandrelImageToImageModelConfig, isT2IAdapterModelConfig, isT5EncoderModelConfig, + isTextLLMModelConfig, isTIModelConfig, isUnknownModelConfig, isVAEModelConfig, @@ -122,6 +123,11 @@ const MODEL_CATEGORIES: Record = { i18nKey: 'modelManager.llavaOnevision', filter: isLLaVAModelConfig, }, + text_llm: { + category: 'text_llm', + i18nKey: 'modelManager.textLLM', + filter: isTextLLMModelConfig, + }, external_image_generator: { category: 'external_image_generator', i18nKey: 'modelManager.externalImageGenerator', @@ -176,6 +182,7 @@ export const MODEL_TYPE_TO_LONG_NAME: Record = { clip_embed: 'CLIP Embed', siglip: 'SigLIP', flux_redux: 'FLUX Redux', + text_llm: 'Text LLM', external_image_generator: 'External Image Generator', unknown: 'Unknown', }; diff --git 
a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/EncoderModelSettings/EncoderModelSettings.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/EncoderModelSettings/EncoderModelSettings.tsx index e10766214f4..9bfe3974ddf 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/EncoderModelSettings/EncoderModelSettings.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/EncoderModelSettings/EncoderModelSettings.tsx @@ -19,6 +19,7 @@ import type { Qwen3EncoderModelConfig, SigLIPModelConfig, T5EncoderModelConfig, + TextLLMModelConfig, } from 'services/api/types'; export type EncoderModelSettingsFormData = { @@ -31,7 +32,8 @@ type EncoderModelConfig = | Qwen3EncoderModelConfig | CLIPVisionModelConfig | SigLIPModelConfig - | LlavaOnevisionModelConfig; + | LlavaOnevisionModelConfig + | TextLLMModelConfig; type Props = { modelConfig: EncoderModelConfig; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx index d29d330facd..365f7cff4b8 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx @@ -21,6 +21,7 @@ import { type Qwen3EncoderModelConfig, type SigLIPModelConfig, type T5EncoderModelConfig, + type TextLLMModelConfig, } from 'services/api/types'; import { isExternalModel } from './isExternalModel'; @@ -37,7 +38,8 @@ type EncoderModelConfig = | Qwen3EncoderModelConfig | CLIPVisionModelConfig | SigLIPModelConfig - | LlavaOnevisionModelConfig; + | LlavaOnevisionModelConfig + | TextLLMModelConfig; const isEncoderModel = (modelConfig: AnyModelConfigWithExternal): modelConfig is EncoderModelConfig => { return ( @@ -46,7 +48,8 @@ const isEncoderModel = (modelConfig: AnyModelConfigWithExternal): modelConfig is 
modelConfig.type === 'qwen3_encoder' || modelConfig.type === 'clip_vision' || modelConfig.type === 'siglip' || - modelConfig.type === 'llava_onevision' + modelConfig.type === 'llava_onevision' || + modelConfig.type === 'text_llm' ); }; diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 75c3415cefb..d1aa0523a43 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -121,6 +121,7 @@ export const zModelType = z.enum([ 'vae', 'lora', 'llava_onevision', + 'text_llm', 'control_lora', 'controlnet', 't2i_adapter', diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index 89169b5ea54..5167dd1527b 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -1,3 +1,5 @@ +import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine'; +import { dropTargetForElements, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter'; import { Box, Flex, Textarea } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize'; @@ -7,6 +9,9 @@ import { selectPositivePrompt, selectPositivePromptHistory, } from 'features/controlLayers/store/paramsSlice'; +import { singleImageDndSource } from 'features/dnd/dnd'; +import { DndDropOverlay } from 'features/dnd/DndDropOverlay'; +import type { DndTargetState } from 'features/dnd/types'; import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton'; import { NegativePromptToggleButton } from 
'features/parameters/components/Core/NegativePromptToggleButton'; import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel'; @@ -14,7 +19,10 @@ import { PromptOverlayButtonWrapper } from 'features/parameters/components/Promp import { PromptResizeHandle } from 'features/parameters/components/Prompts/PromptResizeHandle'; import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt'; import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton'; +import { ExpandPromptButton } from 'features/prompt/ExpandPromptButton'; +import { ImageToPromptButton } from 'features/prompt/ImageToPromptButton'; import { PromptPopover } from 'features/prompt/PromptPopover'; +import { clearPromptUndo, consumePromptUndo } from 'features/prompt/promptUndo'; import { usePrompt } from 'features/prompt/usePrompt'; import { usePromptAttentionHotkeys } from 'features/prompt/usePromptAttentionHotkeys'; import { @@ -22,11 +30,13 @@ import { selectStylePresetViewMode, } from 'features/stylePresets/store/stylePresetSlice'; import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; -import React, { memo, useCallback, useRef } from 'react'; +import React, { memo, useCallback, useEffect, useRef, useState } from 'react'; import type { HotkeyCallback } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; import { useClickAway } from 'react-use'; import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets'; +import { useLlavaModels } from 'services/api/hooks/modelsByType'; +import type { ImageDTO } from 'services/api/types'; import { PositivePromptHistoryIconButton } from './PositivePromptHistory'; @@ -116,6 +126,8 @@ export const ParamPositivePrompt = memo(() => { const viewMode = useAppSelector(selectStylePresetViewMode); const activeStylePresetId = useAppSelector(selectStylePresetActivePresetId); const modelSupportsNegativePrompt = 
useAppSelector(selectModelSupportsNegativePrompt); + const [llavaModels] = useLlavaModels(); + const hasLlavaModels = llavaModels.length > 0; const promptHistoryApi = usePromptHistory(); @@ -139,15 +151,41 @@ export const ParamPositivePrompt = memo(() => { // When the user changes the prompt, reset the prompt history state. This event is not fired when the prompt is // changed via the prompt history navigation. promptHistoryApi.reset(); + // Clear LLM undo state when the user types manually + clearPromptUndo(); }, [dispatch, promptHistoryApi] ); - const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({ + const { + onChange, + isOpen, + onClose, + onOpen, + onSelect, + onKeyDown: onKeyDownPrompt, + onFocus, + } = usePrompt({ prompt, textareaRef: textareaRef, onChange: handleChange, }); + const onKeyDown = useCallback( + (e: React.KeyboardEvent) => { + // Intercept Ctrl+Z to undo LLM prompt changes + if (e.key === 'z' && (e.ctrlKey || e.metaKey) && !e.shiftKey) { + const previousPrompt = consumePromptUndo(); + if (previousPrompt !== null) { + e.preventDefault(); + dispatch(positivePromptChanged(previousPrompt)); + return; + } + } + onKeyDownPrompt(e); + }, + [dispatch, onKeyDownPrompt] + ); + // When the user clicks away from the textarea, reset the prompt history state. 
useClickAway(textareaRef, promptHistoryApi.reset); @@ -201,8 +239,44 @@ export const ParamPositivePrompt = memo(() => { onPromptChange: (prompt) => dispatch(positivePromptChanged(prompt)), }); + // Drop target for gallery images -> Image to Prompt + const dropTargetRef = useRef(null); + const [droppedImage, setDroppedImage] = useState(undefined); + const [dndState, setDndState] = useState('idle'); + + const clearDroppedImage = useCallback(() => { + setDroppedImage(undefined); + }, []); + + useEffect(() => { + const element = dropTargetRef.current; + if (!element || !hasLlavaModels) { + return; + } + + return combine( + dropTargetForElements({ + element, + canDrop: ({ source }) => singleImageDndSource.typeGuard(source.data), + onDragEnter: () => setDndState('over'), + onDragLeave: () => setDndState('potential'), + onDrop: ({ source }) => { + setDndState('idle'); + if (singleImageDndSource.typeGuard(source.data)) { + setDroppedImage(source.data.payload.imageDTO); + } + }, + }), + monitorForElements({ + canMonitor: ({ source }) => singleImageDndSource.typeGuard(source.data), + onDragStart: () => setDndState('potential'), + onDrop: () => setDndState('idle'), + }) + ); + }, [hasLlavaModels]); + return ( - +