diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index 32140da667a..f0cc5c8961e 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -574,6 +574,20 @@ "type": "", "validation": {} }, + { + "category": "GENERATION", + "default": "round_robin", + "description": "Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.", + "env_var": "INVOKEAI_SESSION_QUEUE_MODE", + "literal_values": [ + "FIFO", + "round_robin" + ], + "name": "session_queue_mode", + "required": false, + "type": "typing.Literal['FIFO', 'round_robin']", + "validation": {} + }, { "category": "GENERATION", "default": false, diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 41a5a411c7a..d62cac5095f 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -141,12 +141,11 @@ async def get_queue_item_ids( queue_id: str = Path(description="The queue id to perform this operation on"), order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), ) -> ItemIdsResult: - """Gets all queue item ids that match the given parameters. Non-admin users only see their own items.""" + """Gets all queue item ids that match the given parameters. The IDs themselves are not sensitive; + per-item field redaction is performed when the items are fetched via list_all_queue_items or + get_queue_items_by_item_ids.""" try: - user_id = None if current_user.is_admin else current_user.user_id - return ApiDependencies.invoker.services.session_queue.get_queue_item_ids( - queue_id=queue_id, order_dir=order_dir, user_id=user_id - ) + return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}") @@ -436,10 +435,15 @@ async def get_queue_status( current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> SessionQueueAndProcessorStatus: - """Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it.""" + """Gets the status of the session queue. Returns global counts plus the calling user's own + pending/in_progress counts (so the UI can show an X/Y badge). Non-admin users cannot see the + current item's identifiers unless they own it.""" try: - user_id = None if current_user.is_admin else current_user.user_id - queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id) + queue = ApiDependencies.invoker.services.session_queue.get_queue_status( + queue_id, + user_id=current_user.user_id, + is_admin=current_user.is_admin, + ) processor = ApiDependencies.invoker.services.session_processor.get_status() return SessionQueueAndProcessorStatus(queue=queue, processor=processor) except Exception as e: diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 5783b804c0b..b02b5bbb067 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -260,20 +260,37 @@ async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) + def _owner_and_admin_sids(self, owner_user_id: str) -> list[str]: + """Sids belonging to the event's owner or to any admin. + + Used as `skip_sid` when broadcasting a sanitized companion event to the queue room, + so the owner and admins (who already received the full event) don't get a second + copy that would clobber their cache with redacted values. + """ + return [ + sid + for sid, info in self._socket_users.items() + if info.get("user_id") == owner_user_id or info.get("is_admin") + ] + async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): """Handle queue events with user isolation. - All queue item events (invocation events AND QueueItemStatusChangedEvent) are - private to the owning user and admins. They carry unsanitized user_id, batch_id, - session_id, origin, destination and error metadata, and must never be broadcast - to the whole queue room — otherwise any other authenticated subscriber could - observe cross-user queue activity. + Queue events split into two routing paths: - RecallParametersUpdatedEvent is also private to the owner + admins. + 1. The owner and admins receive the full unsanitized event in their `user:{id}` / + `admin` rooms. The full payload may include batch_id, session_id, origin, + destination, error metadata, etc. - BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and - is also routed privately. QueueClearedEvent is the only queue event that - is still broadcast to the whole queue room. + 2. For events that other authenticated users need to know about so their queue list + and badge counts stay in sync (QueueItemStatusChangedEvent and BatchEnqueuedEvent), + a sanitized companion event is also emitted to the full queue room with the + owner's and admins' sids in `skip_sid`. The companion uses `user_id="redacted"` + as a sentinel so the frontend handler knows to do tag invalidation only and skip + per-session side effects. + + InvocationEventBase events stay private (owner + admins only). RecallParametersUpdatedEvent + is also private. QueueClearedEvent has no user identity and is broadcast to the queue room. IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase inherits from QueueItemEventBase. The order of isinstance checks matters! @@ -302,10 +319,51 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room") - # Other queue item events (QueueItemStatusChangedEvent) carry unsanitized - # user_id, batch_id, session_id, origin, destination and error metadata. - # They are private to the owning user + admins — never broadcast to the - # full queue room. + # QueueItemStatusChangedEvent: full to owner+admin, sanitized to everyone else in + # the queue room so their queue list, badge, and item caches refresh. + elif isinstance(event_data, QueueItemStatusChangedEvent): + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + + sanitized = event_data.model_copy( + update={ + "user_id": "redacted", + "batch_id": "redacted", + "session_id": "redacted", + "origin": None, + "destination": None, + "error_type": None, + "error_message": None, + "error_traceback": None, + } + ) + # Strip identifying fields out of the embedded batch_status / queue_status too. + sanitized.batch_status = sanitized.batch_status.model_copy( + update={"batch_id": "redacted", "origin": None, "destination": None} + ) + sanitized.queue_status = sanitized.queue_status.model_copy( + update={ + "item_id": None, + "session_id": None, + "batch_id": None, + "user_pending": None, + "user_in_progress": None, + } + ) + await self._sio.emit( + event=event_name, + data=sanitized.model_dump(mode="json"), + room=event_data.queue_id, + skip_sid=self._owner_and_admin_sids(event_data.user_id), + ) + + logger.debug( + f"Emitted queue_item_status_changed: full to {user_room}+admin, sanitized to queue {event_data.queue_id}" + ) + + # Other queue item events (currently none beyond QueueItemStatusChangedEvent that + # carry user_id) stay private to owner + admins. elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"): user_room = f"user:{event_data.user_id}" await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) @@ -320,14 +378,25 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room") - # BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and - # enqueued counts. Route it privately to the owner + admins so other - # users do not observe cross-user batch activity. + # BatchEnqueuedEvent: full to owner+admin, sanitized to everyone else in the queue + # room so their badge total and queue list pick up the new items. elif isinstance(event_data, BatchEnqueuedEvent): user_room = f"user:{event_data.user_id}" await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") - logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room") + + sanitized = event_data.model_copy( + update={"user_id": "redacted", "batch_id": "redacted", "origin": None} + ) + await self._sio.emit( + event=event_name, + data=sanitized.model_dump(mode="json"), + room=event_data.queue_id, + skip_sid=self._owner_and_admin_sids(event_data.user_id), + ) + logger.debug( + f"Emitted batch_enqueued: full to {user_room}+admin, sanitized to queue {event_data.queue_id}" + ) else: # For remaining queue events (e.g. QueueClearedEvent) that do not diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 729eb1332c0..c99461b3fab 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -30,6 +30,7 @@ ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] +SESSION_QUEUE_MODE = Literal["FIFO", "round_robin"] CONFIG_SCHEMA_VERSION = "4.0.2" EXTERNAL_PROVIDER_CONFIG_FIELDS = ( "external_gemini_api_key", @@ -108,6 +109,7 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. + session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.
Valid values: `FIFO`, `round_robin` clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. allow_nodes: List of nodes to allow. Omit to allow all. @@ -203,6 +205,7 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") + session_queue_mode: SESSION_QUEUE_MODE = Field(default="round_robin", description="Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.") clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.") max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.") diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 14b93d97fc7..04bd81e3174 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -73,8 +73,19 @@ def is_full(self, queue_id: str) -> IsFullResult: pass @abstractmethod - def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus: - """Gets the status of the queue. If user_id is provided, also includes user-specific counts.""" + def get_queue_status( + self, + queue_id: str, + user_id: Optional[str] = None, + is_admin: bool = False, + ) -> SessionQueueStatus: + """Gets the status of the queue. + + Always returns global pending/in_progress/etc. counts. When user_id is provided, also + populates user_pending and user_in_progress with that user's own counts (so the UI can + render an X/Y badge). When is_admin is False, the current item's identifiers are hidden + unless the calling user owns the in-progress item. + """ pass @abstractmethod diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index d87221fbbae..7472ea07f63 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -309,6 +309,12 @@ class SessionQueueStatus(BaseModel): failed: int = Field(..., description="Number of queue items with status 'error'") canceled: int = Field(..., description="Number of queue items with status 'canceled'") total: int = Field(..., description="Total number of queue items") + user_pending: Optional[int] = Field( + default=None, description="Number of pending queue items for the calling user (multiuser only)" + ) + user_in_progress: Optional[int] = Field( + default=None, description="Number of in-progress queue items for the calling user (multiuser only)" + ) class SessionQueueCountsByDestination(BaseModel): diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 95fb16fcbed..326baed1b31 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -210,9 +210,45 @@ async def enqueue_batch( return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: - with self._db.transaction() as cursor: - cursor.execute( - """--sql + config = self.__invoker.services.configuration + use_round_robin = config.multiuser and config.session_queue_mode == "round_robin" + + if use_round_robin: + query = """--sql + WITH user_last_served AS ( + -- Track when each user last had an item started, to determine whose turn it is. + SELECT user_id, MAX(started_at) AS last_served_at + FROM session_queue + WHERE started_at IS NOT NULL + GROUP BY user_id + ), + user_next_item AS ( + -- For each user, select their single best pending item (highest priority, then oldest). + SELECT + user_id, + item_id, + ROW_NUMBER() OVER ( + PARTITION BY user_id + ORDER BY priority DESC, item_id ASC + ) AS rn + FROM session_queue + WHERE status = 'pending' + ) + SELECT + sq.*, + u.display_name AS user_display_name, + u.email AS user_email + FROM session_queue sq + LEFT JOIN users u ON sq.user_id = u.user_id + JOIN user_next_item uni ON sq.item_id = uni.item_id AND uni.rn = 1 + LEFT JOIN user_last_served uls ON sq.user_id = uls.user_id + ORDER BY + COALESCE(uls.last_served_at, '1970-01-01') ASC, + sq.item_id ASC + LIMIT 1 + """ + else: + query = """--sql SELECT sq.*, u.display_name as user_display_name, @@ -225,7 +261,9 @@ def dequeue(self) -> Optional[SessionQueueItem]: sq.item_id ASC LIMIT 1 """ - ) + + with self._db.transaction() as cursor: + cursor.execute(query) result = cast(Union[sqlite3.Row, None], cursor.fetchone()) if result is None: return None @@ -846,9 +884,25 @@ def get_queue_item_ids( return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids)) - def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus: + def get_queue_status( + self, + queue_id: str, + user_id: Optional[str] = None, + is_admin: bool = False, + ) -> SessionQueueStatus: with self._db.transaction() as cursor: - # When user_id is provided (non-admin), only count that user's items + cursor.execute( + """--sql + SELECT status, count(*) + FROM session_queue + WHERE queue_id = ? + GROUP BY status + """, + (queue_id,), + ) + counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + + user_counts_result: list[sqlite3.Row] = [] if user_id is not None: cursor.execute( """--sql @@ -859,24 +913,23 @@ def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> Sess """, (queue_id, user_id), ) - else: - cursor.execute( - """--sql - SELECT status, count(*) - FROM session_queue - WHERE queue_id = ? - GROUP BY status - """, - (queue_id,), - ) - counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + user_counts_result = cast(list[sqlite3.Row], cursor.fetchall()) current_item = self.get_current(queue_id=queue_id) total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} - # For non-admin users, hide current item details if they don't own it - show_current_item = current_item is not None and (user_id is None or current_item.user_id == user_id) + user_pending: Optional[int] = None + user_in_progress: Optional[int] = None + if user_id is not None: + user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result} + user_pending = user_counts.get("pending", 0) + user_in_progress = user_counts.get("in_progress", 0) + + # Non-admins cannot see the current item's identifiers unless they own it. + show_current_item = current_item is not None and ( + is_admin or user_id is None or current_item.user_id == user_id + ) return SessionQueueStatus( queue_id=queue_id, @@ -889,6 +942,8 @@ def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> Sess failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, + user_pending=user_pending, + user_in_progress=user_in_progress, ) def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus: diff --git a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx index e8636466066..1ba2ffd572d 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx @@ -1,4 +1,6 @@ import { Badge, Portal } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectIsAuthenticated } from 'features/auth/store/authSlice'; import type { RefObject } from 'react'; import { memo, useEffect, useMemo, useState } from 'react'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; @@ -10,14 +12,24 @@ type Props = { type SessionQueueStatus = components['schemas']['SessionQueueStatus']; +const hasUserCounts = (queueData: SessionQueueStatus): boolean => { + return ( + queueData.user_pending !== undefined && + queueData.user_pending !== null && + queueData.user_in_progress !== undefined && + queueData.user_in_progress !== null + ); +}; + /** - * Calculates the appropriate badge text based on queue status. + * Calculates the appropriate badge text based on queue status and authentication state. * Returns null if badge should be hidden. * - * In multiuser mode, the backend already scopes counts to the current user for non-admins, - * so pending + in_progress reflects the user's own queue items. + * In multiuser mode, the badge is "X/Y" where X is the calling user's pending+in_progress count + * and Y is the total across all users. In single-user mode (or when user counts are unavailable) + * the badge shows the total only. */ -const getBadgeText = (queueData: SessionQueueStatus | undefined): string | null => { +const getBadgeText = (queueData: SessionQueueStatus | undefined, isAuthenticated: boolean): string | null => { if (!queueData) { return null; } @@ -28,18 +40,24 @@ const getBadgeText = (queueData: SessionQueueStatus | undefined): string | null return null; } + if (isAuthenticated && hasUserCounts(queueData)) { + const userPending = queueData.user_pending! + queueData.user_in_progress!; + return `${userPending}/${totalPending}`; + } + return totalPending.toString(); }; export const QueueCountBadge = memo(({ targetRef }: Props) => { const [badgePos, setBadgePos] = useState<{ x: string; y: string } | null>(null); + const isAuthenticated = useAppSelector(selectIsAuthenticated); const { queueData } = useGetQueueStatusQuery(undefined, { selectFromResult: (res) => ({ queueData: res.data?.queue, }), }); - const badgeText = useMemo(() => getBadgeText(queueData), [queueData]); + const badgeText = useMemo(() => getBadgeText(queueData, isAuthenticated), [queueData, isAuthenticated]); useEffect(() => { if (!targetRef.current) { diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 4b8e4da95a5..f12ec2e538e 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1795,7 +1795,9 @@ export type paths = { }; /** * Get Queue Item Ids - * @description Gets all queue item ids that match the given parameters. Non-admin users only see their own items. + * @description Gets all queue item ids that match the given parameters. The IDs themselves are not sensitive; + * per-item field redaction is performed when the items are fetched via list_all_queue_items or + * get_queue_items_by_item_ids. */ get: operations["get_queue_item_ids"]; put?: never; @@ -2055,7 +2057,9 @@ export type paths = { }; /** * Get Queue Status - * @description Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it. + * @description Gets the status of the session queue. Returns global counts plus the calling user's own + * pending/in_progress counts (so the UI can show an X/Y badge). Non-admin users cannot see the + * current item's identifiers unless they own it. */ get: operations["get_queue_status"]; put?: never; @@ -15641,6 +15645,7 @@ export type components = { * force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). * pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. * max_queue_size: Maximum number of items in the session queue. + * session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.
Valid values: `FIFO`, `round_robin` * clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. * max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. * allow_nodes: List of nodes to allow. Omit to allow all. @@ -15972,6 +15977,13 @@ export type components = { * @default 10000 */ max_queue_size?: number; + /** + * Session Queue Mode + * @description Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting. + * @default round_robin + * @enum {string} + */ + session_queue_mode?: "FIFO" | "round_robin"; /** * Clear Queue On Startup * @description Empties session queue on startup. If true, disables `max_queue_history`. @@ -26807,6 +26819,16 @@ export type components = { * @description Total number of queue items */ total: number; + /** + * User Pending + * @description Number of pending queue items for the calling user (multiuser only) + */ + user_pending?: number | null; + /** + * User In Progress + * @description Number of in-progress queue items for the calling user (multiuser only) + */ + user_in_progress?: number | null; }; /** * SetupRequest diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 6771e9e7e00..d742ad09bf5 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -388,6 +388,24 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis }); socket.on('queue_item_status_changed', (data) => { + // Sanitized companion event sent to non-owner queue subscribers in multiuser mode. The + // backend sets user_id="redacted" and clears identifiers/error fields. We must not run + // payload-driven cache mutations or per-session side effects (node state reset, progress + // clear, completion bookkeeping) — those belong to the owner. Just invalidate queue tags + // so the non-owner's queue list and badge counts refetch with sanitized data. + if (data.user_id === 'redacted') { + log.trace({ data }, `Sanitized queue_item_status_changed for item ${data.item_id}`); + const tags: ApiTagDescription[] = [ + 'SessionQueueStatus', + 'SessionQueueItemIdList', + { type: 'SessionQueueItem', id: data.item_id }, + { type: 'SessionQueueItem', id: LIST_TAG }, + { type: 'SessionQueueItem', id: LIST_ALL_TAG }, + ]; + dispatch(queueApi.util.invalidateTags(tags)); + return; + } + if (finishedQueueItemIds.has(data.item_id)) { log.trace({ data }, `Received event for already-finished queue item ${data.item_id}`); return; diff --git a/tests/app/routers/test_multiuser_authorization.py b/tests/app/routers/test_multiuser_authorization.py index 85354c6a577..813b5170a09 100644 --- a/tests/app/routers/test_multiuser_authorization.py +++ b/tests/app/routers/test_multiuser_authorization.py @@ -1332,14 +1332,31 @@ def test_get_queue_status_hides_current_item_for_non_owner(self): assert status_obj.session_id is None assert status_obj.batch_id is None - def test_session_queue_status_no_user_fields(self): - """SessionQueueStatus should not have user_pending/user_in_progress fields anymore. - Non-admin users now get their own counts in the main pending/in_progress fields.""" + def test_session_queue_status_has_user_fields(self): + """SessionQueueStatus exposes user_pending/user_in_progress so the queue badge + can render an X/Y count (X = caller's jobs, Y = global total).""" from invokeai.app.services.session_queue.session_queue_common import SessionQueueStatus fields = set(SessionQueueStatus.model_fields.keys()) - assert "user_pending" not in fields - assert "user_in_progress" not in fields + assert "user_pending" in fields + assert "user_in_progress" in fields + + status_obj = SessionQueueStatus( + queue_id="default", + item_id=None, + session_id=None, + batch_id=None, + pending=5, + in_progress=1, + completed=0, + failed=0, + canceled=0, + total=6, + user_pending=2, + user_in_progress=1, + ) + assert status_obj.user_pending == 2 + assert status_obj.user_in_progress == 1 # =========================================================================== @@ -1707,8 +1724,11 @@ def test_batch_enqueued_event_carries_user_id(self) -> None: assert event.queue_id == "default" def test_queue_item_status_changed_routed_privately(self, socketio: Any) -> None: - """Verify that _handle_queue_event emits QueueItemStatusChangedEvent ONLY to - user:{user_id} and admin rooms, never to the queue_id room.""" + """_handle_queue_event must emit the FULL QueueItemStatusChangedEvent only to the + owner's user room and the admin room. A sanitized companion (user_id="redacted", + identifiers stripped) is also emitted to the queue_id room so other users' UIs can + refresh, with the owner's and admins' sids in skip_sid so they don't get a duplicate + that would clobber their cache.""" import asyncio from unittest.mock import AsyncMock @@ -1757,20 +1777,60 @@ def test_queue_item_status_changed_routed_privately(self, socketio: Any) -> None ), ) + # Track owner sid so we can verify skip_sid is honored + socketio._socket_users["sid-owner"] = {"user_id": "owner-xyz", "is_admin": False} + socketio._socket_users["sid-admin"] = {"user_id": "admin-1", "is_admin": True} + socketio._socket_users["sid-other"] = {"user_id": "other-user", "is_admin": False} + mock_emit = AsyncMock() socketio._sio.emit = mock_emit asyncio.run(socketio._handle_queue_event(("queue_item_status_changed", event))) - rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list] - assert "user:owner-xyz" in rooms_emitted_to - assert "admin" in rooms_emitted_to - # CRITICAL: must NOT emit to the queue_id room — that would leak to other users - assert "default" not in rooms_emitted_to + # Collect (room, payload, skip_sid) for each emit call + emits = [ + (c.kwargs.get("room"), c.kwargs.get("data"), c.kwargs.get("skip_sid")) for c in mock_emit.call_args_list + ] + + # Full event must go to owner room and admin room with original sensitive fields + owner_emits = [(p, s) for r, p, s in emits if r == "user:owner-xyz"] + admin_emits = [(p, s) for r, p, s in emits if r == "admin"] + assert len(owner_emits) == 1 and len(admin_emits) == 1 + for payload, _ in owner_emits + admin_emits: + assert payload["user_id"] == "owner-xyz" + assert payload["batch_id"] == "batch-private" + assert payload["session_id"] == "sess-private" + assert payload["destination"] == "canvas" + + # A sanitized companion event must go to the queue_id room with sensitive fields cleared + queue_emits = [(p, s) for r, p, s in emits if r == "default"] + assert len(queue_emits) == 1, "expected exactly one sanitized emit to queue room" + sanitized_payload, skip_sid = queue_emits[0] + assert sanitized_payload["user_id"] == "redacted" + assert sanitized_payload["batch_id"] == "redacted" + assert sanitized_payload["session_id"] == "redacted" + assert sanitized_payload["origin"] is None + assert sanitized_payload["destination"] is None + assert sanitized_payload["error_type"] is None + assert sanitized_payload["batch_status"]["batch_id"] == "redacted" + assert sanitized_payload["batch_status"]["destination"] is None + assert sanitized_payload["queue_status"]["item_id"] is None + assert sanitized_payload["queue_status"]["batch_id"] is None + assert sanitized_payload["queue_status"]["user_pending"] is None + # Owner and admin sids must be skipped so they don't receive the duplicate + assert "sid-owner" in skip_sid + assert "sid-admin" in skip_sid + # Third-party user must NOT be skipped — they need the sanitized event + assert "sid-other" not in skip_sid + # Status (non-sensitive) is preserved so the non-owner UI knows what changed + assert sanitized_payload["status"] == "in_progress" + assert sanitized_payload["item_id"] == 1 def test_batch_enqueued_routed_privately(self, socketio: Any) -> None: - """Verify that _handle_queue_event emits BatchEnqueuedEvent ONLY to - user:{user_id} and admin rooms, never to the queue_id room.""" + """_handle_queue_event must emit the FULL BatchEnqueuedEvent only to the owner's + user room and the admin room. A sanitized companion (user_id="redacted", batch_id + and origin stripped) is also emitted to the queue_id room so other users' badge + totals refresh, with owner/admin sids in skip_sid.""" import asyncio from unittest.mock import AsyncMock @@ -1791,15 +1851,39 @@ def test_batch_enqueued_routed_privately(self, socketio: Any) -> None: ) event = BatchEnqueuedEvent.build(enqueue_result, user_id="owner-zzz") + socketio._socket_users["sid-owner"] = {"user_id": "owner-zzz", "is_admin": False} + socketio._socket_users["sid-admin"] = {"user_id": "admin-1", "is_admin": True} + socketio._socket_users["sid-other"] = {"user_id": "other-user", "is_admin": False} + mock_emit = AsyncMock() socketio._sio.emit = mock_emit asyncio.run(socketio._handle_queue_event(("batch_enqueued", event))) - rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list] - assert "user:owner-zzz" in rooms_emitted_to - assert "admin" in rooms_emitted_to - assert "default" not in rooms_emitted_to + emits = [ + (c.kwargs.get("room"), c.kwargs.get("data"), c.kwargs.get("skip_sid")) for c in mock_emit.call_args_list + ] + + # Full event to owner + admin contains the real batch_id and origin + owner_emits = [(p, s) for r, p, s in emits if r == "user:owner-zzz"] + admin_emits = [(p, s) for r, p, s in emits if r == "admin"] + assert len(owner_emits) == 1 and len(admin_emits) == 1 + for payload, _ in owner_emits + admin_emits: + assert payload["user_id"] == "owner-zzz" + assert payload["batch_id"] == "batch-pvt" + assert payload["origin"] == "workflows" + + # Sanitized event to queue room: user/batch/origin redacted, owner+admin skipped + queue_emits = [(p, s) for r, p, s in emits if r == "default"] + assert len(queue_emits) == 1 + sanitized_payload, skip_sid = queue_emits[0] + assert sanitized_payload["user_id"] == "redacted" + assert sanitized_payload["batch_id"] == "redacted" + assert sanitized_payload["origin"] is None + assert sanitized_payload["enqueued"] == 5 # count is non-sensitive + assert "sid-owner" in skip_sid + assert "sid-admin" in skip_sid + assert "sid-other" not in skip_sid def test_queue_cleared_still_broadcast(self, socketio: Any) -> None: """QueueClearedEvent does not carry user identity and should still be broadcast diff --git a/tests/app/services/session_queue/test_session_queue_dequeue.py b/tests/app/services/session_queue/test_session_queue_dequeue.py new file mode 100644 index 00000000000..0f82f2babaa --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_dequeue.py @@ -0,0 +1,214 @@ +"""Tests for session queue dequeue() ordering: FIFO and round-robin modes.""" + +import json +import uuid +from typing import Optional + +import pytest +from pydantic_core import to_jsonable_python + +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState + +_EMPTY_SESSION_JSON = json.dumps(to_jsonable_python(GraphExecutionState(graph=Graph()).model_dump())) + + +@pytest.fixture +def session_queue_fifo(mock_invoker: Invoker) -> SqliteSessionQueue: + """Queue backed by a single-user (FIFO) invoker.""" + # Default config has multiuser=False, so FIFO is always used. + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +@pytest.fixture +def session_queue_round_robin(mock_invoker: Invoker) -> SqliteSessionQueue: + """Queue backed by a multiuser invoker with round_robin mode.""" + mock_invoker.services.configuration = InvokeAIAppConfig( + use_memory_db=True, + node_cache_size=0, + multiuser=True, + session_queue_mode="round_robin", + ) + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert_queue_item( + session_queue: SqliteSessionQueue, + queue_id: str, + user_id: str, + priority: int = 0, +) -> int: + """Directly insert a minimal queue item and return its item_id.""" + session_id = str(uuid.uuid4()) + batch_id = str(uuid.uuid4()) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id, user_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (queue_id, _EMPTY_SESSION_JSON, session_id, batch_id, None, priority, None, None, None, None, user_id), + ) + return cursor.lastrowid # type: ignore[return-value] + + +def _dequeue_user_ids(session_queue: SqliteSessionQueue, count: int) -> list[Optional[str]]: + """Dequeue `count` items and return the list of user_ids in dequeue order.""" + result = [] + for _ in range(count): + item = session_queue.dequeue() + result.append(item.user_id if item is not None else None) + return result + + +# --------------------------------------------------------------------------- +# FIFO tests +# --------------------------------------------------------------------------- + + +def test_fifo_single_user_order(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: items from a single user are dequeued in insertion order.""" + queue_id = "default" + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_fifo, 3) + assert user_ids == ["user_a", "user_a", "user_a"] + + +def test_fifo_multi_user_preserves_insertion_order(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: jobs from multiple users are dequeued in strict insertion order, not interleaved.""" + queue_id = "default" + # Insert A1, A2, B1, C1, C2, A3 – FIFO should preserve this exact order. + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_b") + _insert_queue_item(session_queue_fifo, queue_id, "user_c") + _insert_queue_item(session_queue_fifo, queue_id, "user_c") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_fifo, 6) + assert user_ids == ["user_a", "user_a", "user_b", "user_c", "user_c", "user_a"] + + +def test_fifo_priority_respected(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: higher-priority items are dequeued before lower-priority ones.""" + queue_id = "default" + _insert_queue_item(session_queue_fifo, queue_id, "user_a", priority=0) + _insert_queue_item(session_queue_fifo, queue_id, "user_a", priority=10) + + user_ids = _dequeue_user_ids(session_queue_fifo, 2) + # Both are user_a; second inserted item has higher priority and should come first. + assert user_ids == ["user_a", "user_a"] + + +def test_fifo_returns_none_when_empty(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: dequeue returns None when the queue is empty.""" + assert session_queue_fifo.dequeue() is None + + +# --------------------------------------------------------------------------- +# Round-robin tests +# --------------------------------------------------------------------------- + + +def test_round_robin_interleaves_users(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: jobs from multiple users are interleaved one per user per round. + + Queue insertion order (matching the issue example): + A job 1, A job 2, B job 1, C job 1, C job 2, A job 3 + + Expected dequeue order: + A job 1, B job 1, C job 1, A job 2, C job 2, A job 3 + """ + queue_id = "default" + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_b") + _insert_queue_item(session_queue_round_robin, queue_id, "user_c") + _insert_queue_item(session_queue_round_robin, queue_id, "user_c") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_round_robin, 6) + assert user_ids == ["user_a", "user_b", "user_c", "user_a", "user_c", "user_a"] + + +def test_round_robin_single_user_behaves_like_fifo(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin with only one user produces the same order as FIFO.""" + queue_id = "default" + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_round_robin, 3) + assert user_ids == ["user_a", "user_a", "user_a"] + + +def test_round_robin_handles_user_joining_mid_queue(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: a user who joins later is correctly interleaved.""" + queue_id = "default" + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_b") + + user_ids = _dequeue_user_ids(session_queue_round_robin, 3) + # Round 1: A (oldest rank-1 item), B (rank-1 item) + # Round 2: A (rank-2 item) + assert user_ids == ["user_a", "user_b", "user_a"] + + +def test_round_robin_returns_none_when_empty(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: dequeue returns None when the queue is empty.""" + assert session_queue_round_robin.dequeue() is None + + +def test_round_robin_priority_within_user_respected(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: within a single user's items, higher priority is dequeued first.""" + queue_id = "default" + # Insert low-priority item first, then high-priority for same user. + _insert_queue_item(session_queue_round_robin, queue_id, "user_a", priority=0) + _insert_queue_item(session_queue_round_robin, queue_id, "user_a", priority=10) + _insert_queue_item(session_queue_round_robin, queue_id, "user_b", priority=0) + + # Round 1: user_a's best item (priority 10), user_b's only item. + # Round 2: user_a's remaining item (priority 0). + items = [] + for _ in range(3): + item = session_queue_round_robin.dequeue() + assert item is not None + items.append((item.user_id, item.priority)) + + assert items[0] == ("user_a", 10) + assert items[1] == ("user_b", 0) + assert items[2] == ("user_a", 0) + + +def test_round_robin_ignored_in_single_user_mode(mock_invoker: Invoker) -> None: + """When multiuser=False, round_robin config is ignored and FIFO is used.""" + mock_invoker.services.configuration = InvokeAIAppConfig( + use_memory_db=True, + node_cache_size=0, + multiuser=False, + session_queue_mode="round_robin", + ) + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + + queue_id = "default" + _insert_queue_item(queue, queue_id, "user_a") + _insert_queue_item(queue, queue_id, "user_a") + _insert_queue_item(queue, queue_id, "user_b") + + # FIFO order: user_a, user_a, user_b + user_ids = _dequeue_user_ids(queue, 3) + assert user_ids == ["user_a", "user_a", "user_b"]