diff --git a/docs/src/content/docs/configuration/invokeai-yaml.mdx b/docs/src/content/docs/configuration/invokeai-yaml.mdx index 987c8eb98a2..6ac56053928 100644 --- a/docs/src/content/docs/configuration/invokeai-yaml.mdx +++ b/docs/src/content/docs/configuration/invokeai-yaml.mdx @@ -114,6 +114,39 @@ Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These These options set the paths of various directories and files used by InvokeAI. Any user-defined paths should be absolute paths. +#### Multi-GPU Generation + +On a machine with more than one GPU, InvokeAI can run several generation sessions at the same time — one per GPU — instead of processing the queue one job at a time. Jobs are distributed fairly across users, so a single user's large batch cannot monopolize every GPU while others wait. + +This is controlled by the `generation_devices` setting: + +```yaml +generation_devices: auto # default value +``` + +| Value | Behavior | +| -------------------------- | ----------------------------------------------------------------------------------------------------------------------- | +| `auto` | Use every available CUDA GPU, running one generation session per GPU concurrently. This is the default. | +| `[cuda:0,cuda:1]` | Use the specific devices listed, one session per device. Useful for reserving a GPU for other work. | +| `[cuda:0]` | Use a single specific device. Generation runs serially, as it did before multi-GPU support. | +| `[]` | Use the first detected device. Generation runs serially, as it did before multi-GPU support. | + +Each entry in the list must be one of `cpu`, `cuda`, `mps`, or `cuda:N`, where `N` is a zero-based device number (`cuda:0` is the first GPU, `cuda:1` the second, and so on). + +```yaml +# Use the first and third GPUs, leaving the second free for other tasks +generation_devices: [cuda:0, cuda:2] +``` + +Notes: + +- On a system without a CUDA GPU, `auto` resolves to the single best available device (`mps` on Apple Silicon, otherwise `cpu`), so generation runs serially. +- Each active GPU gets its own model cache, and model weights are duplicated in system RAM for every device. Running many GPUs in parallel therefore increases RAM usage — ensure you have ample system memory before enabling a large device list. +- Duplicate entries are ignored; `[cuda:0, cuda:0]` is treated as `[cuda:0]`. +- You can restrict which physical GPUs InvokeAI sees with the `CUDA_VISIBLE_DEVICES` environment variable. When set, `auto` only enumerates the visible subset, and `cuda:N` indices refer to positions within that subset. + +During parallel generation, the progress display shows one progress bar per active session, stacked vertically, each disappearing as its session completes. + #### Image Subfolder Strategy By default, generated images are stored in a single flat directory under `outputs/images/`. The `image_subfolder_strategy` setting lets you organize newly-created images into subfolders automatically. You can edit this setting in `invokeai.yaml` or, as an admin user, in the Settings panel. diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index fcb47dbfb23..1987a90abce 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -490,6 +490,17 @@ "type": "", "validation": {} }, + { + "category": "DEVICE", + "default": "auto", + "description": "Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)", + "env_var": "INVOKEAI_GENERATION_DEVICES", + "literal_values": [], + "name": "generation_devices", + "required": false, + "type": "typing.Union[typing.Literal['auto'], list[str]]", + "validation": {} + }, { "category": "DEVICE", "default": "auto", diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index 832e58f5e24..a8e0c68d781 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -1,15 +1,16 @@ import locale +import re from enum import Enum from importlib.metadata import distributions from pathlib import Path as FilePath from threading import Lock -from typing import Any +from typing import Any, Literal, Union import torch import yaml from fastapi import Body, HTTPException, Path from fastapi.routing import APIRouter -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.api.auth_dependencies import AdminUserOrDefault from invokeai.app.api.dependencies import ApiDependencies @@ -118,6 +119,16 @@ def _remove_nullable_default_from_schema(schema: dict[str, Any]) -> None: schema.update(non_null_schemas[0]) +_GENERATION_DEVICE_PATTERN = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") + + +class GenerationDeviceOption(BaseModel): + """A device that may be selected for generation.""" + + device: str = Field(description="The device identifier, e.g. 'cuda:0', 'mps', or 'cpu'") + name: str = Field(description="Human-readable device name") + + class UpdateAppGenerationSettingsRequest(BaseModel): """Writable generation-related app settings.""" @@ -131,14 +142,59 @@ class UpdateAppGenerationSettingsRequest(BaseModel): ge=0, description="Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items.", ) + generation_devices: Union[Literal["auto"], list[str]] | None = Field( + default=None, + description="Devices to use for parallel generation. `auto` uses every available GPU; provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices. Takes effect after restarting InvokeAI.", + json_schema_extra=_remove_nullable_default_from_schema, + ) + + @field_validator("generation_devices") + @classmethod + def validate_generation_devices( + cls, v: Union[Literal["auto"], list[str], None] + ) -> Union[Literal["auto"], list[str], None]: + if v is None or v == "auto": + return v + for device in v: + if not _GENERATION_DEVICE_PATTERN.match(device): + raise ValueError( + f"Invalid generation device '{device}'. Valid values are 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." + ) + return v @model_validator(mode="after") def validate_explicit_nulls(self) -> "UpdateAppGenerationSettingsRequest": if "image_subfolder_strategy" in self.model_fields_set and self.image_subfolder_strategy is None: raise ValueError("image_subfolder_strategy may not be null") + if "generation_devices" in self.model_fields_set and self.generation_devices is None: + raise ValueError("generation_devices may not be null") return self +@app_router.get( + "/generation_device_options", + operation_id="get_generation_device_options", + status_code=200, + response_model=list[GenerationDeviceOption], +) +async def get_generation_device_options() -> list[GenerationDeviceOption]: + """List the devices available for generation, for use with the `generation_devices` setting.""" + options: list[GenerationDeviceOption] = [] + if torch.cuda.is_available(): + for index in range(torch.cuda.device_count()): + device = f"cuda:{index}" + try: + name = torch.cuda.get_device_name(index) + except Exception: + name = device + options.append(GenerationDeviceOption(device=device, name=name)) + elif torch.backends.mps.is_available(): + options.append(GenerationDeviceOption(device="mps", name="Apple MPS")) + else: + options.append(GenerationDeviceOption(device="cpu", name="CPU")) + return options + + @app_router.get( "/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields ) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index bdd2e406444..53c4c68981f 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -443,7 +443,11 @@ async def update_model_record( # nn.Module at load time, so toggling them on a cached model is otherwise silently a no-op until # the entry is evicted. Drop any unlocked cached entries for this model so the next load rebuilds. if _load_settings_changed(previous_config, config): - dropped = ApiDependencies.invoker.services.model_manager.load.ram_cache.drop_model(key) + # Drop the model from every per-device cache so the next load on any GPU rebuilds it. + dropped = sum( + cache.drop_model(key) + for cache in ApiDependencies.invoker.services.model_manager.load.ram_caches.values() + ) if dropped: logger.info( f"Dropped {dropped} cached entr{'y' if dropped == 1 else 'ies'} for model {key} after settings change." @@ -1304,9 +1308,10 @@ async def get_stats() -> Optional[CacheStats]: ) async def empty_model_cache(current_admin: AdminUserOrDefault) -> None: """Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped.""" - # Request 1000GB of room in order to force the cache to drop all models. + # Request 1000GB of room in order to force each per-device cache to drop all models. ApiDependencies.invoker.services.logger.info("Emptying model cache.") - ApiDependencies.invoker.services.model_manager.load.ram_cache.make_room(1000 * 2**30) + for cache in ApiDependencies.invoker.services.model_manager.load.ram_caches.values(): + cache.make_room(1000 * 2**30) class HFTokenStatus(str, Enum): diff --git a/invokeai/app/invocations/anima_denoise.py b/invokeai/app/invocations/anima_denoise.py index 9fa4b3fb07a..b301e817f9c 100644 --- a/invokeai/app/invocations/anima_denoise.py +++ b/invokeai/app/invocations/anima_denoise.py @@ -608,7 +608,7 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor if driver is not None: user_step = 0 - pbar = tqdm(total=total_steps, desc="Denoising (Anima)") + pbar = tqdm(total=total_steps, desc=f"Denoising (Anima){TorchDevice.get_session_device_label()}") for it in driver.iterations(): timestep = torch.tensor( [it.sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype @@ -655,7 +655,9 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor pbar.close() else: # Built-in Euler implementation (default for Anima) - for step_idx in tqdm(range(total_steps), desc="Denoising (Anima)"): + for step_idx in tqdm( + range(total_steps), desc=f"Denoising (Anima){TorchDevice.get_session_device_label()}" + ): sigma_curr = sigmas[step_idx] sigma_prev = sigmas[step_idx + 1] diff --git a/invokeai/app/invocations/cogview4_denoise.py b/invokeai/app/invocations/cogview4_denoise.py index c04210401be..cb06d2b3ff6 100644 --- a/invokeai/app/invocations/cogview4_denoise.py +++ b/invokeai/app/invocations/cogview4_denoise.py @@ -294,7 +294,7 @@ def _run_diffusion( assert isinstance(transformer, CogView4Transformer2DModel) # Denoising loop - for step_idx in tqdm(range(total_steps)): + for step_idx in tqdm(range(total_steps), desc=f"Denoising{TorchDevice.get_session_device_label()}"): t_curr = timesteps[step_idx] sigma_curr = sigmas[step_idx] sigma_prev = sigmas[step_idx + 1] diff --git a/invokeai/app/invocations/sd3_denoise.py b/invokeai/app/invocations/sd3_denoise.py index f6c90b9690c..10c9080ac5e 100644 --- a/invokeai/app/invocations/sd3_denoise.py +++ b/invokeai/app/invocations/sd3_denoise.py @@ -284,7 +284,10 @@ def _run_diffusion( assert isinstance(transformer, SD3Transformer2DModel) # 6. Denoising loop - for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + for step_idx, (t_curr, t_prev) in tqdm( + list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))), + desc=f"Denoising{TorchDevice.get_session_device_label()}", + ): # Expand the latents if we are doing CFG. latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # Expand the timestep to match the latent model input. diff --git a/invokeai/app/invocations/z_image_denoise.py b/invokeai/app/invocations/z_image_denoise.py index c1e864ea179..c6887840df8 100644 --- a/invokeai/app/invocations/z_image_denoise.py +++ b/invokeai/app/invocations/z_image_denoise.py @@ -569,7 +569,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: # Use diffusers scheduler for stepping # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps) # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps - pbar = tqdm(total=total_steps, desc="Denoising") + pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}") for step_index in range(num_scheduler_steps): sched_timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized sigma (0-1) @@ -686,7 +686,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: pbar.close() else: # Original Euler implementation (default, optimized for Z-Image) - for step_idx in tqdm(range(total_steps)): + for step_idx in tqdm(range(total_steps), desc=f"Denoising{TorchDevice.get_session_device_label()}"): sigma_curr = sigmas[step_idx] sigma_prev = sigmas[step_idx + 1] diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index e6cc7c2798c..15d447dd182 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -11,7 +11,7 @@ import shutil from functools import lru_cache from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union import yaml from pydantic import BaseModel, Field, PrivateAttr, field_validator @@ -205,6 +205,7 @@ class InvokeAIAppConfig(BaseSettings): # DEVICE device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$") + generation_devices: Union[Literal["auto"], list[str]] = Field(default="auto", description="Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)") precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.") # GENERATION @@ -257,6 +258,19 @@ class InvokeAIAppConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True) + @field_validator("generation_devices") + @classmethod + def validate_generation_devices(cls, v: Union[str, list[str]]) -> Union[str, list[str]]: + if v == "auto": + return v + pattern = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") + for device in v: + if not pattern.match(device): + raise ValueError( + f"Invalid generation device '{device}'. Valid values are 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." + ) + return v + def update_config(self, config: dict[str, Any] | InvokeAIAppConfig, clobber: bool = True) -> None: """Updates the config, overwriting existing values. diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 0c530f9a2f7..c30fa31b75c 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -138,6 +138,10 @@ class InvocationProgressEvent(InvocationEventBase): image: ProgressImage | None = Field( default=None, description="An image representing the current state of the progress" ) + device: str | None = Field( + default=None, + description="The device processing this session, e.g. 'cuda:1' (set only when running on a CUDA GPU)", + ) @classmethod def build( @@ -148,6 +152,13 @@ def build( percentage: float | None = None, image: ProgressImage | None = None, ) -> "InvocationProgressEvent": + # This is emitted from the session-processor worker thread, which pins its CUDA device via + # TorchDevice.set_session_device(). Resolve that here so the UI can label progress by GPU. + from invokeai.backend.util.devices import TorchDevice + + session_device = TorchDevice.get_session_device() + device = str(session_device) if session_device is not None and session_device.type == "cuda" else None + return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, @@ -161,6 +172,7 @@ def build( percentage=percentage, image=image, message=message, + device=device, ) diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 12b737a7cf1..ec84439547a 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -1,4 +1,5 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team +import threading from pathlib import Path from queue import Queue from typing import Optional, Union @@ -23,6 +24,9 @@ def __init__(self, output_folder: Union[str, Path]): self.__cache: dict[Path, PILImageType] = {} self.__cache_ids = Queue[Path]() self.__max_cache_size = 10 # TODO: get this from config + # Guards the cache structures (__cache / __cache_ids), which are read and mutated from + # multiple session-processor worker threads in multi-GPU parallel mode. + self.__cache_lock = threading.Lock() self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__thumbnails_folder = self.__output_folder / "thumbnails" @@ -41,6 +45,13 @@ def get(self, image_name: str, image_subfolder: str = "") -> PILImageType: return cache_item image = Image.open(image_path) + # Image.open() is lazy: it reads the header but defers pixel decoding (and holds the + # file handle open) until the first .load()/.copy()/.convert(). The opened object is + # cached and the SAME object is handed to every caller, so in multi-GPU parallel mode + # two worker threads can call .copy() on it concurrently and race on the shared file + # handle and decoder state, producing "broken data stream" / "self.png is not None" + # errors. Forcing the decode here makes the cached object safe for concurrent reads. + image.load() self.__set_cache(image_path, image) return image except FileNotFoundError as e: @@ -105,16 +116,18 @@ def delete(self, image_name: str, image_subfolder: str = "") -> None: if image_path.exists(): image_path.unlink() - if image_path in self.__cache: - del self.__cache[image_path] thumbnail_name = get_thumbnail_name(image_name) thumbnail_path = self.get_path(thumbnail_name, True, image_subfolder=image_subfolder) if thumbnail_path.exists(): thumbnail_path.unlink() - if thumbnail_path in self.__cache: - del self.__cache[thumbnail_path] + + with self.__cache_lock: + if image_path in self.__cache: + del self.__cache[image_path] + if thumbnail_path in self.__cache: + del self.__cache[thumbnail_path] except Exception as e: raise ImageFileDeleteException from e @@ -185,13 +198,15 @@ def __validate_storage_folders(self) -> None: folder.mkdir(parents=True, exist_ok=True) def __get_cache(self, image_name: Path) -> Optional[PILImageType]: - return None if image_name not in self.__cache else self.__cache[image_name] + with self.__cache_lock: + return None if image_name not in self.__cache else self.__cache[image_name] def __set_cache(self, image_name: Path, image: PILImageType): - if image_name not in self.__cache: - self.__cache[image_name] = image - self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache - if len(self.__cache) > self.__max_cache_size: - cache_id = self.__cache_ids.get() - if cache_id in self.__cache: - del self.__cache[cache_id] + with self.__cache_lock: + if image_name not in self.__cache: + self.__cache[image_name] = image + self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache + if len(self.__cache) > self.__max_cache_size: + cache_id = self.__cache_ids.get() + if cache_id in self.__cache: + del self.__cache[cache_id] diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 3baf11029ff..7c9fdeee11b 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -116,6 +116,11 @@ def __init__( self._restore_completed_event.set() self._download_queue = download_queue self._download_cache: Dict[int, ModelInstallJob] = {} + # Per-source locks serializing download_and_cache_model() so parallel (multi-GPU) sessions + # that need the same remote model (e.g. the LaMa infill model) don't race to download into + # the same cache directory. _download_cache_locks_guard protects the dict itself. + self._download_cache_locks: Dict[str, threading.Lock] = {} + self._download_cache_locks_guard = threading.Lock() self._running = False self._session = session self._install_thread: Optional[threading.Thread] = None @@ -724,27 +729,47 @@ def download_and_cache_model( if len(contents) > 0: return contents[0] - model_path.mkdir(parents=True, exist_ok=True) - model_source = self._guess_source(str(source)) - remote_files, _ = self._remote_files_from_source(model_source) - # Handle multiple subfolders for HFModelSource - subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else [] - job = self._multifile_download( - dest=model_path, - remote_files=remote_files, - subfolder=model_source.subfolder - if isinstance(model_source, HFModelSource) and len(subfolders) <= 1 - else None, - subfolders=subfolders if len(subfolders) > 1 else None, - ) - files_string = "file" if len(remote_files) == 1 else "files" - self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") - self._download_queue.wait_for_job(job) - if job.complete: - assert job.download_path is not None - return job.download_path - else: - raise Exception(job.error) + # Serialize concurrent downloads of the same source. Parallel multi-GPU sessions can each + # request the same remote model (e.g. the LaMa infill model) at once; without this lock they + # both download into the same cache directory and collide on the final rename, which fails on + # Windows with "WinError 32: the file is being used by another process". The other waiters + # find the completed download on the post-lock re-check below and skip downloading. + with self._download_cache_lock(str(source)): + if model_path.exists(): + contents = list(model_path.iterdir()) + if len(contents) > 0: + return contents[0] + + model_path.mkdir(parents=True, exist_ok=True) + model_source = self._guess_source(str(source)) + remote_files, _ = self._remote_files_from_source(model_source) + # Handle multiple subfolders for HFModelSource + subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else [] + job = self._multifile_download( + dest=model_path, + remote_files=remote_files, + subfolder=model_source.subfolder + if isinstance(model_source, HFModelSource) and len(subfolders) <= 1 + else None, + subfolders=subfolders if len(subfolders) > 1 else None, + ) + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") + self._download_queue.wait_for_job(job) + if job.complete: + assert job.download_path is not None + return job.download_path + else: + raise Exception(job.error) + + def _download_cache_lock(self, source: str) -> threading.Lock: + """Return the lock that serializes downloads for a given source, creating it on first use.""" + with self._download_cache_locks_guard: + lock = self._download_cache_locks.get(source) + if lock is None: + lock = threading.Lock() + self._download_cache_locks[source] = lock + return lock def _remote_files_from_source( self, source: ModelSource diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 87a405b4ea4..8fc9823328d 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -26,7 +26,21 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo @property @abstractmethod def ram_cache(self) -> ModelCache: - """Return the RAM cache used by this loader.""" + """Return the RAM cache for the current thread's execution device. + + In multi-GPU mode, each session-processor worker is pinned to a device and gets its own + cache; this resolves to the calling thread's cache. Outside a worker (e.g. API threads), + it resolves to the default device's cache. + """ + + @property + @abstractmethod + def ram_caches(self) -> dict[str, ModelCache]: + """Return all per-device RAM caches, keyed by normalized device string. + + Use this for maintenance operations that must apply to every device (clear cache, drop a + model from all devices, shutdown). + """ @abstractmethod def load_model_from_path( diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 2e2d2ae219d..33c7ef6108c 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -18,7 +18,7 @@ ModelLoaderRegistry, ModelLoaderRegistryBase, ) -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.model_cache import MODEL_LOAD_LOCK, ModelCache from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType from invokeai.backend.util.devices import TorchDevice @@ -33,13 +33,25 @@ def __init__( app_config: InvokeAIAppConfig, ram_cache: ModelCache, registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry, + ram_caches: Optional[dict[str, ModelCache]] = None, ): - """Initialize the model load service.""" + """Initialize the model load service. + + Args: + ram_cache: The default RAM cache, used when no per-device cache matches the calling + thread (e.g. single-device installs, or API threads). + ram_caches: Optional map of normalized device string -> ModelCache for multi-GPU mode. + One cache per generation device. The default `ram_cache` is always included. + """ logger = InvokeAILogger.get_logger(self.__class__.__name__) logger.setLevel(app_config.log_level.upper()) self._logger = logger self._app_config = app_config - self._ram_cache = ram_cache + self._default_ram_cache = ram_cache + # Map normalized device string -> cache. Always includes the default cache so that callers + # without a pinned device (API threads) resolve to a valid cache. + self._ram_caches: dict[str, ModelCache] = dict(ram_caches) if ram_caches else {} + self._ram_caches.setdefault(str(TorchDevice.normalize(ram_cache.execution_device)), ram_cache) self._registry = registry def start(self, invoker: Invoker) -> None: @@ -47,8 +59,18 @@ def start(self, invoker: Invoker) -> None: @property def ram_cache(self) -> ModelCache: - """Return the RAM cache used by this loader.""" - return self._ram_cache + """Return the RAM cache for the calling thread's execution device. + + `choose_torch_device()` is thread-local-aware: a session-processor worker pinned to a GPU + gets that GPU's cache; everything else falls back to the default cache. + """ + key = str(TorchDevice.choose_torch_device()) + return self._ram_caches.get(key, self._default_ram_cache) + + @property + def ram_caches(self) -> dict[str, ModelCache]: + """Return all per-device RAM caches, keyed by normalized device string.""" + return dict(self._ram_caches) def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ @@ -67,7 +89,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo loaded_model: LoadedModel = implementation( app_config=self._app_config, logger=self._logger, - ram_cache=self._ram_cache, + ram_cache=self.ram_cache, ).load_model(model_config, submodel_type) if hasattr(self, "_invoker"): @@ -78,9 +100,11 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None ) -> LoadedModelWithoutConfig: + # Resolve the calling thread's cache once so the whole load uses a single device's cache. + ram_cache = self.ram_cache cache_key = str(model_path) try: - return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) except IndexError: pass @@ -110,7 +134,7 @@ def diffusers_load_directory(directory: Path) -> AnyModel: load_class = GenericDiffusersLoader( app_config=self._app_config, logger=self._logger, - ram_cache=self._ram_cache, + ram_cache=ram_cache, convert_cache=self.convert_cache, ).get_hf_load_class(directory) return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) @@ -123,6 +147,15 @@ def diffusers_load_directory(directory: Path) -> AnyModel: else lambda path: safetensors_load_file(path, device="cpu") ) assert loader is not None - raw_model = loader(model_path) - self._ram_cache.put(key=cache_key, model=raw_model) - return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache) + # Serialize construction (see MODEL_LOAD_LOCK): the diffusers loader path uses the same + # process-global, non-thread-safe monkey-patches as the main loader, so it takes the write + # lock to exclude concurrent VRAM moves. Re-check the cache after acquiring the lock in case + # a worker sharing this cache built it while we waited. + with MODEL_LOAD_LOCK.write_lock(): + try: + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) + except IndexError: + pass + raw_model = loader(model_path) + ram_cache.put(key=cache_key, model=raw_model) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 6141a635f4d..b7680524a34 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -60,9 +60,10 @@ def start(self, invoker: Invoker) -> None: service.start(invoker) def stop(self, invoker: Invoker) -> None: - # Shutdown the model cache to cancel any pending timers - if hasattr(self._load, "ram_cache"): - self._load.ram_cache.shutdown() + # Shutdown every per-device model cache to cancel any pending keep-alive timers. + if hasattr(self._load, "ram_caches"): + for cache in self._load.ram_caches.values(): + cache.shutdown() for service in [self._store, self._install, self._load]: if hasattr(service, "stop"): @@ -85,22 +86,38 @@ def build_model_manager( logger = InvokeAILogger.get_logger(cls.__name__) logger.setLevel(app_config.log_level.upper()) - ram_cache = ModelCache( - execution_device_working_mem_gb=app_config.device_working_mem_gb, - enable_partial_loading=app_config.enable_partial_loading, - keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights, - max_ram_cache_size_gb=app_config.max_cache_ram_gb, - max_vram_cache_size_gb=app_config.max_cache_vram_gb, - execution_device=execution_device or TorchDevice.choose_torch_device(), - storage_device="cpu", - log_memory_usage=app_config.log_memory_usage, - logger=logger, - keep_alive_minutes=app_config.model_cache_keep_alive_min, - ) + def build_cache(device: torch.device) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=app_config.device_working_mem_gb, + enable_partial_loading=app_config.enable_partial_loading, + keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights, + max_ram_cache_size_gb=app_config.max_cache_ram_gb, + max_vram_cache_size_gb=app_config.max_cache_vram_gb, + execution_device=device, + storage_device="cpu", + log_memory_usage=app_config.log_memory_usage, + logger=logger, + keep_alive_minutes=app_config.model_cache_keep_alive_min, + ) + + # The default cache for callers without a pinned device (API threads, single-device installs). + default_device = execution_device or TorchDevice.choose_torch_device() + ram_cache = build_cache(default_device) + + # In multi-GPU mode, build one independent cache per generation device. Each session-processor + # worker is pinned to a device (see TorchDevice.set_session_device) and resolves to its own + # cache. The default cache is always included by ModelLoadService. + ram_caches: dict[str, ModelCache] = {str(TorchDevice.normalize(default_device)): ram_cache} + for device in TorchDevice.get_generation_devices(app_config.generation_devices): + key = str(device) + if key not in ram_caches: + ram_caches[key] = build_cache(device) + loader = ModelLoadService( app_config=app_config, ram_cache=ram_cache, registry=ModelLoaderRegistry, + ram_caches=ram_caches, ) installer = ModelInstallService( app_config=app_config, diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index b361259a4b1..ae00173e422 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -1,4 +1,5 @@ from queue import Queue +from threading import Lock from typing import TYPE_CHECKING, Optional, TypeVar from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase @@ -21,6 +22,9 @@ def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: self._cache: dict[str, T] = {} self._cache_ids = Queue[str]() self._max_cache_size = max_cache_size + # Guards the in-memory cache so concurrent session-processor workers (multi-GPU) can't race + # the check-then-evict in `_set_cache` (which could otherwise raise KeyError on eviction). + self._cache_lock = Lock() def start(self, invoker: "Invoker") -> None: self._invoker = invoker @@ -50,16 +54,19 @@ def save(self, obj: T) -> str: def delete(self, name: str) -> None: self._underlying_storage.delete(name) - if name in self._cache: - del self._cache[name] + with self._cache_lock: + if name in self._cache: + del self._cache[name] self._on_deleted(name) def _get_cache(self, name: str) -> Optional[T]: - return None if name not in self._cache else self._cache[name] + with self._cache_lock: + return None if name not in self._cache else self._cache[name] def _set_cache(self, name: str, data: T): - if name not in self._cache: - self._cache[name] = data - self._cache_ids.put(name) - if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) + with self._cache_lock: + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 7159c19e746..27c1f2a8632 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -5,6 +5,8 @@ from threading import Event as ThreadEvent from typing import Optional +import torch + from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, @@ -31,6 +33,7 @@ from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler +from invokeai.backend.util.devices import TorchDevice class DefaultSessionRunner(SessionRunnerBase): @@ -305,6 +308,26 @@ def _on_node_error( ) +class _SessionWorker: + """A single generation worker: one thread, optionally pinned to one device. + + In single-device (legacy) mode there is exactly one worker with `device=None`. In multi-GPU + mode there is one worker per configured device, each with its own session runner and cancel + event so concurrent sessions can be canceled independently. + """ + + def __init__(self, device: Optional[torch.device], runner: SessionRunnerBase) -> None: + self.device = device + self.runner = runner + self.cancel_event = ThreadEvent() + self.queue_item: Optional[SessionQueueItem] = None + self.thread: Optional[Thread] = None + + @property + def label(self) -> str: + return str(self.device) if self.device is not None else "default device" + + class DefaultSessionProcessor(SessionProcessorBase): def __init__( self, @@ -319,57 +342,113 @@ def __init__( self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] self._thread_limit = thread_limit self._polling_interval = polling_interval + self._workers: list[_SessionWorker] = [] + + def _resolve_devices(self) -> list[Optional[torch.device]]: + """Determine the per-worker devices from config. + + Resolves `generation_devices` (which defaults to `"auto"` — every available GPU) into one + normalized device per worker. Returns a single `None` (legacy single-worker, device chosen by + the global config) only if the resolution is empty (e.g. `generation_devices` set to an empty + list). + """ + generation_devices = self._invoker.services.configuration.generation_devices + devices = TorchDevice.get_generation_devices(generation_devices) + if not devices: + return [None] + return list(devices) + + def _clone_session_runner(self, template: SessionRunnerBase) -> SessionRunnerBase: + """Create an independent runner for an additional worker. + + Each worker needs its own runner because the runner stores its session's cancel event. + We carry over the template's callbacks so all workers behave identically. + """ + if isinstance(template, DefaultSessionRunner): + return DefaultSessionRunner( + on_before_run_session_callbacks=list(template._on_before_run_session_callbacks), + on_before_run_node_callbacks=list(template._on_before_run_node_callbacks), + on_after_run_node_callbacks=list(template._on_after_run_node_callbacks), + on_node_error_callbacks=list(template._on_node_error_callbacks), + on_after_run_session_callbacks=list(template._on_after_run_session_callbacks), + ) + # Unknown runner implementation — only safe to reuse in single-worker mode. + return template def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker - self._queue_item: Optional[SessionQueueItem] = None - self._invocation: Optional[BaseInvocation] = None self._resume_event = ThreadEvent() self._stop_event = ThreadEvent() self._poll_now_event = ThreadEvent() - self._cancel_event = ThreadEvent() register_events(QueueClearedEvent, self._on_queue_cleared) register_events(BatchEnqueuedEvent, self._on_batch_enqueued) register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed) - self._thread_semaphore = BoundedSemaphore(self._thread_limit) + devices = self._resolve_devices() # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, - # the profiler will create a new profile for each session. + # the profiler will create a new profile for each session. Profiling uses a process-global cProfile, which + # cannot cleanly attribute work when multiple sessions run concurrently, so it is disabled in multi-GPU mode. + profiler_enabled = self._invoker.services.configuration.profile_graphs + if profiler_enabled and len(devices) > 1: + self._invoker.services.logger.warning( + "Graph profiling is disabled because multiple generation devices are configured." + ) + profiler_enabled = False self._profiler = ( Profiler( logger=self._invoker.services.logger, output_dir=self._invoker.services.configuration.profiles_path, prefix=self._invoker.services.configuration.profile_prefix, ) - if self._invoker.services.configuration.profile_graphs + if profiler_enabled else None ) - self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) - self._thread = Thread( - name="session_processor", - target=self._process, - daemon=True, - kwargs={ - "stop_event": self._stop_event, - "poll_now_event": self._poll_now_event, - "resume_event": self._resume_event, - "cancel_event": self._cancel_event, - }, - ) - self._thread.start() + self._thread_semaphore = BoundedSemaphore(len(devices)) + + # Start in the running (resumed) state. + self._stop_event.clear() + self._resume_event.set() + + self._workers = [] + for index, device in enumerate(devices): + runner = self.session_runner if index == 0 else self._clone_session_runner(self.session_runner) + worker = _SessionWorker(device=device, runner=runner) + runner.start(services=invoker.services, cancel_event=worker.cancel_event, profiler=self._profiler) + self._workers.append(worker) + + if len(self._workers) > 1: + self._invoker.services.logger.info( + f"Starting session processor with {len(self._workers)} parallel workers on devices: " + f"{', '.join(w.label for w in self._workers)}" + ) + + for index, worker in enumerate(self._workers): + worker.thread = Thread( + name=f"session_processor_{index}", + target=self._process, + daemon=True, + kwargs={ + "worker": worker, + "stop_event": self._stop_event, + "poll_now_event": self._poll_now_event, + "resume_event": self._resume_event, + }, + ) + worker.thread.start() def stop(self, *args, **kwargs) -> None: self._stop_event.set() # Cancel any in-progress generation so that long-running nodes (e.g. denoising) stop at - # the next step boundary instead of running to completion. Without this, the generation + # the next step boundary instead of running to completion. Without this, a generation # thread may still be executing CUDA operations when Python teardown begins, which can # cause a C++ std::terminate() crash ("terminate called without an active exception"). - self._cancel_event.set() - # Wake the thread if it is sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused). + for worker in self._workers: + worker.cancel_event.set() + # Wake any worker sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused). self._poll_now_event.set() self._resume_event.set() @@ -377,28 +456,31 @@ def _poll_now(self) -> None: self._poll_now_event.set() async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None: - if self._queue_item and self._queue_item.queue_id == event[1].queue_id: - self._cancel_event.set() + # Cancel every worker currently running an item from the cleared queue. + canceled = False + for worker in self._workers: + if worker.queue_item and worker.queue_item.queue_id == event[1].queue_id: + worker.cancel_event.set() + canceled = True + if canceled: self._poll_now() async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None: self._poll_now() async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None: - # Make sure the cancel event is for the currently processing queue item - if self._queue_item and self._queue_item.item_id != event[1].item_id: - return - if self._queue_item and event[1].status in ["completed", "failed", "canceled"]: - # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is - # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel - # event, which the session runner checks between invocations. If set, the session runner loop is broken. - # - # Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such - # node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item - # is canceled, and if it is, raises a `CanceledException` to stop execution immediately. - if event[1].status == "canceled": - self._cancel_event.set() - self._poll_now() + # Find the worker (if any) currently running the item whose status changed. + for worker in self._workers: + if worker.queue_item and worker.queue_item.item_id == event[1].item_id: + if event[1].status in ["completed", "failed", "canceled"]: + # When the queue item is canceled via HTTP, the status is set to "canceled" and this event is + # emitted. We respond by setting that worker's cancel event, which its session runner checks + # between invocations (and which denoise_latents' step callback checks mid-node, raising + # CanceledException to stop immediately). + if event[1].status == "canceled": + worker.cancel_event.set() + self._poll_now() + return def resume(self) -> SessionProcessorStatus: if not self._resume_event.is_set(): @@ -413,22 +495,28 @@ def pause(self) -> SessionProcessorStatus: def get_status(self) -> SessionProcessorStatus: return SessionProcessorStatus( is_started=self._resume_event.is_set(), - is_processing=self._queue_item is not None, + is_processing=any(worker.queue_item is not None for worker in self._workers), ) def _process( self, + worker: _SessionWorker, stop_event: ThreadEvent, poll_now_event: ThreadEvent, resume_event: ThreadEvent, - cancel_event: ThreadEvent, ): try: - # Any unhandled exception in this block is a fatal processor error and will stop the processor. + # Any unhandled exception in this block is a fatal processor error and will stop this worker. self._thread_semaphore.acquire() - stop_event.clear() - resume_event.set() - cancel_event.clear() + + # Pin this worker thread to its device so all device-selecting code (TorchDevice.choose_torch_device, + # which nodes and the model loader consult) resolves to this GPU. CUDA's current device is per-thread. + if worker.device is not None: + TorchDevice.set_session_device(worker.device) + if worker.device.type == "cuda": + torch.cuda.set_device(worker.device) + + worker.cancel_event.clear() while not stop_event.is_set(): poll_now_event.clear() @@ -437,10 +525,17 @@ def _process( # If we are paused, wait for resume event resume_event.wait() - # Get the next session to process - self._queue_item = self._invoker.services.session_queue.dequeue() + if stop_event.is_set(): + break + + # Get the next session to process. dequeue() atomically claims the item, so concurrent + # workers never receive the same item. Pass this worker's device so the item is + # tagged with the GPU that ran it (None in single-device/legacy mode). + worker.queue_item = self._invoker.services.session_queue.dequeue( + device=str(worker.device) if worker.device is not None else None + ) - if self._queue_item is None: + if worker.queue_item is None: # The queue was empty, wait for next polling interval or event to try again self._invoker.services.logger.debug("Waiting for next polling interval or event") poll_now_event.wait(self._polling_interval) @@ -453,19 +548,20 @@ def _process( gc.collect() self._invoker.services.logger.info( - f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}" + f"Executing queue item {worker.queue_item.item_id}, session {worker.queue_item.session_id} " + f"on {worker.label}" ) - cancel_event.clear() + worker.cancel_event.clear() # Run the graph - self.session_runner.run(queue_item=self._queue_item) + worker.runner.run(queue_item=worker.queue_item) except Exception as e: error_type = e.__class__.__name__ error_message = str(e) error_traceback = traceback.format_exc() self._on_non_fatal_processor_error( - queue_item=self._queue_item, + queue_item=worker.queue_item, error_type=error_type, error_message=error_message, error_traceback=error_traceback, @@ -474,7 +570,7 @@ def _process( poll_now_event.wait(self._polling_interval) continue except Exception as e: - # Fatal error in processor, log and pass - we're done here + # Fatal error in this worker, log and pass - we're done here error_type = e.__class__.__name__ error_message = str(e) error_traceback = traceback.format_exc() @@ -482,9 +578,9 @@ def _process( self._invoker.services.logger.error(error_traceback) pass finally: - stop_event.clear() - poll_now_event.clear() - self._queue_item = None + worker.queue_item = None + if worker.device is not None: + TorchDevice.clear_session_device() self._thread_semaphore.release() def _on_non_fatal_processor_error( diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 73acf9c31aa..07f4be1fded 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -31,8 +31,8 @@ class SessionQueueBase(ABC): """Base class for session queue""" @abstractmethod - def dequeue(self) -> Optional[SessionQueueItem]: - """Dequeues the next session queue item.""" + def dequeue(self, device: Optional[str] = None) -> Optional[SessionQueueItem]: + """Dequeues the next session queue item, recording the processing device (e.g. 'cuda:1') if given.""" pass @abstractmethod diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index d87221fbbae..8e149af3afe 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -262,6 +262,10 @@ class SessionQueueItem(BaseModel): retried_from_item_id: Optional[int] = Field( default=None, description="The item_id of the queue item that this item was retried from" ) + device: Optional[str] = Field( + default=None, + description="The device that processed this queue item, e.g. 'cuda:1' (set only when running on a CUDA GPU)", + ) session: GraphExecutionState = Field(description="The fully-populated session to be executed") workflow: Optional[WorkflowWithoutID] = Field( default=None, description="The workflow associated with this queue item" diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index a05ed468857..aa2ccc371a6 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -1,6 +1,7 @@ import asyncio import json import sqlite3 +import threading from typing import Optional, Union, cast from pydantic_core import to_jsonable_python @@ -42,6 +43,12 @@ class SqliteSessionQueue(SessionQueueBase): __invoker: Invoker + # Serializes the select-candidate-then-claim sequence in `dequeue()`. The DB connection's + # RLock serializes individual statements, but the gap between selecting the next pending item + # and marking it 'in_progress' is a race: with multiple session-processor workers (multi-GPU), + # two workers could select the same item. Holding this lock across the whole claim prevents it. + _dequeue_lock = threading.Lock() + def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() @@ -209,28 +216,34 @@ async def enqueue_batch( self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id) return enqueue_result - def dequeue(self) -> Optional[SessionQueueItem]: - with self._db.transaction() as cursor: - cursor.execute( - """--sql - SELECT - sq.*, - u.display_name as user_display_name, - u.email as user_email - FROM session_queue sq - LEFT JOIN users u ON sq.user_id = u.user_id - WHERE sq.status = 'pending' - ORDER BY - sq.priority DESC, - sq.item_id ASC - LIMIT 1 - """ - ) - result = cast(Union[sqlite3.Row, None], cursor.fetchone()) - if result is None: - return None - queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress") + def dequeue(self, device: Optional[str] = None) -> Optional[SessionQueueItem]: + # Hold the dequeue lock across the select-then-claim so concurrent workers (multi-GPU) + # cannot select and claim the same pending item. `_set_queue_item_status` already no-ops + # if the item was concurrently moved to a terminal state (e.g. canceled), so we only need + # to guard against two dequeues racing for the same pending row. + with self._dequeue_lock: + with self._db.transaction() as cursor: + cursor.execute( + """--sql + SELECT + sq.*, + u.display_name as user_display_name, + u.email as user_email + FROM session_queue sq + LEFT JOIN users u ON sq.user_id = u.user_id + WHERE sq.status = 'pending' + ORDER BY + sq.priority DESC, + sq.item_id ASC + LIMIT 1 + """ + ) + result = cast(Union[sqlite3.Row, None], cursor.fetchone()) + if result is None: + return None + queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) + # Record the claiming worker's device so the UI can label the item by GPU. + queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress", device=device) return queue_item def get_next(self, queue_id: str) -> Optional[SessionQueueItem]: @@ -287,6 +300,7 @@ def _set_queue_item_status( error_type: Optional[str] = None, error_message: Optional[str] = None, error_traceback: Optional[str] = None, + device: Optional[str] = None, ) -> SessionQueueItem: with self._db.transaction() as cursor: cursor.execute( @@ -308,10 +322,10 @@ def _set_queue_item_status( cursor.execute( """--sql UPDATE session_queue - SET status = ?, status_sequence = COALESCE(status_sequence, 0) + 1, error_type = ?, error_message = ?, error_traceback = ? + SET status = ?, status_sequence = COALESCE(status_sequence, 0) + 1, error_type = ?, error_message = ?, error_traceback = ?, device = COALESCE(?, device) WHERE item_id = ? """, - (status, error_type, error_message, error_traceback, item_id), + (status, error_type, error_message, error_traceback, device, item_id), ) queue_item = self.get_queue_item(item_id) diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 12642610c8c..3e1d5c53f3e 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -34,6 +34,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_30 import build_migration_30 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_31 import build_migration_31 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_32 import build_migration_32 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -85,6 +86,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_29()) migrator.register_migration(build_migration_30()) migrator.register_migration(build_migration_31()) + migrator.register_migration(build_migration_32()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_32.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_32.py new file mode 100644 index 00000000000..fe60433463d --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_32.py @@ -0,0 +1,36 @@ +"""Migration 32: Add device column to session_queue table. + +This records which device (e.g. 'cuda:1') processed a queue item, so the UI can show a per-item +GPU number in the Session Queue. Existing rows get NULL (unknown device). +""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration32Callback: + """Migration to add a device column to the session_queue table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='session_queue';") + if cursor.fetchone() is None: + return + + cursor.execute("PRAGMA table_info(session_queue);") + columns = [row[1] for row in cursor.fetchall()] + + if "device" not in columns: + cursor.execute("ALTER TABLE session_queue ADD COLUMN device TEXT;") + + +def build_migration_32() -> Migration: + """Builds the migration object for migrating from version 31 to version 32. + + This migration adds a device column to the session_queue table. + """ + return Migration( + from_version=31, + to_version=32, + callback=Migration32Callback(), + ) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index 0f4cf07ee5b..7b29a58d44f 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -15,6 +15,7 @@ from invokeai.backend.flux.model import Flux from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.util.devices import TorchDevice def denoise( @@ -95,7 +96,7 @@ def denoise( # Use diffusers scheduler for stepping # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps) # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps - pbar = tqdm(total=total_steps, desc="Denoising") + pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}") for step_index in range(num_scheduler_steps): timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized (0-1) for the model @@ -266,7 +267,10 @@ def denoise( return img # Original Euler implementation (when scheduler is None) - for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + for step_index, (t_curr, t_prev) in tqdm( + list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))), + desc=f"Denoising{TorchDevice.get_session_device_label()}", + ): # DyPE: Update step state for timestep-dependent scaling if dype_extension is not None and dype_embedder is not None: dype_extension.update_step_state( diff --git a/invokeai/backend/flux2/denoise.py b/invokeai/backend/flux2/denoise.py index 2ff66236ce8..cd84b14b99d 100644 --- a/invokeai/backend/flux2/denoise.py +++ b/invokeai/backend/flux2/denoise.py @@ -14,6 +14,7 @@ from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.util.devices import TorchDevice def denoise( @@ -118,7 +119,7 @@ def denoise( is_heun = hasattr(scheduler, "state_in_first_order") user_step = 0 - pbar = tqdm(total=total_steps, desc="Denoising") + pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}") for step_index in range(num_scheduler_steps): timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized (0-1) for the model @@ -226,7 +227,10 @@ def denoise( pbar.close() else: # Manual Euler stepping (original behavior) - for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + for step_index, (t_curr, t_prev) in tqdm( + list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))), + desc=f"Denoising{TorchDevice.get_session_device_label()}", + ): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 4609a2e92ab..984362f185d 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -17,7 +17,7 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.model_cache import MODEL_LOAD_LOCK, ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType @@ -57,7 +57,12 @@ def __init__(self, cache_record: CacheRecord, cache: ModelCache): self._cache = cache def __enter__(self) -> AnyModel: - self._cache.lock(self._cache_record, None) + # Hold the MODEL_LOAD_LOCK read lock across the VRAM load (lock() runs + # load_state_dict(assign=True), which calls register_parameter) so it can't overlap a + # concurrent model construction that has the global register_parameter -> meta patch active. + # Acquired before the cache's own lock to keep a consistent lock order (see MODEL_LOAD_LOCK). + with MODEL_LOAD_LOCK.read_lock(): + self._cache.lock(self._cache_record, None) try: self.repair_required_tensors_on_device() return self.model @@ -77,7 +82,9 @@ def model_on_device( :param working_mem_bytes: The amount of working memory to keep available on the compute device when loading the model. """ - self._cache.lock(self._cache_record, working_mem_bytes) + # See __enter__ for why the VRAM load is wrapped in the read lock. + with MODEL_LOAD_LOCK.read_lock(): + self._cache.lock(self._cache_record, working_mem_bytes) try: self.repair_required_tensors_on_device() yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model) @@ -94,7 +101,12 @@ def repair_required_tensors_on_device(self) -> int: cached_model = self._cache_record.cached_model if not isinstance(cached_model, CachedModelWithPartialLoad): return 0 - return cached_model.repair_required_tensors_on_compute_device() + # Repair runs load_state_dict(assign=True) -> register_parameter, so it must hold the read + # lock to avoid being hijacked onto the `meta` device by a concurrent construction. This is + # also called directly (outside __enter__/model_on_device) by some text-encoder invocations, + # so the guard lives here rather than only at the call sites. + with MODEL_LOAD_LOCK.read_lock(): + return cached_model.repair_required_tensors_on_compute_device() class LoadedModel(LoadedModelWithoutConfig): diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 040b55cb6ec..02929ff6132 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -13,7 +13,11 @@ from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key +from invokeai.backend.model_manager.load.model_cache.model_cache import ( + MODEL_LOAD_LOCK, + ModelCache, + get_model_cache_key, +) from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.model_manager.taxonomy import ( @@ -52,7 +56,9 @@ ) -# TO DO: The loader is not thread safe! +# The construction path is not thread-safe on its own; it monkey-patches process-global torch state +# (see MODEL_LOAD_LOCK). Concurrent callers must hold the MODEL_LOAD_LOCK write lock (see +# _load_and_cache). class ModelLoader(ModelLoaderBase): """Default implementation of ModelLoaderBase.""" @@ -85,8 +91,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo if not model_path.exists(): raise FileNotFoundError(f"Files for model '{model_config.name}' not found at {model_path}") - with skip_torch_weight_init(): - cache_record = self._load_and_cache(model_config, submodel_type) + cache_record = self._load_and_cache(model_config, submodel_type) return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache) @property @@ -124,25 +129,46 @@ def _get_execution_device( def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord: stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")]) + cache_key = get_model_cache_key(config.key, submodel_type) try: - return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name) + return self._ram_cache.get(key=cache_key, stats_name=stats_name) except IndexError: pass - config.path = str(self._get_model_path(config)) - self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) - loaded_model = self._load_model(config, submodel_type) - - # Determine execution device from model config, considering submodel type - execution_device = self._get_execution_device(config, submodel_type) - - self._ram_cache.put( - get_model_cache_key(config.key, submodel_type), - model=loaded_model, - execution_device=execution_device, - ) - - return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name) + # Cache miss: construct the model from disk. This path holds the MODEL_LOAD_LOCK *write* + # lock because it relies on process-global, non-thread-safe monkey-patches + # (skip_torch_weight_init and, inside the loaders, accelerate.init_empty_weights / diffusers + # low_cpu_mem_usage). The write lock excludes both other constructions AND concurrent VRAM + # load/unload on other workers (which take the read lock); without that, a concurrent move's + # load_state_dict(assign=True) -> register_parameter gets hijacked onto the `meta` device. + # See MODEL_LOAD_LOCK for the full explanation. + # + # Lock-ordering: the write lock is acquired before any ModelCache._lock taken below + # (get/make_room/put), matching the readers' order, so there is no AB-BA deadlock. + with MODEL_LOAD_LOCK.write_lock(): + # Double-checked locking: another worker sharing this cache may have loaded the same + # entry while we waited for the mutex. (Workers on other devices use a different cache, + # so they will still miss here and construct their own copy — which is intended.) + try: + return self._ram_cache.get(key=cache_key, stats_name=stats_name) + except IndexError: + pass + + config.path = str(self._get_model_path(config)) + self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) + with skip_torch_weight_init(): + loaded_model = self._load_model(config, submodel_type) + + # Determine execution device from model config, considering submodel type + execution_device = self._get_execution_device(config, submodel_type) + + self._ram_cache.put( + cache_key, + model=loaded_model, + execution_device=execution_device, + ) + + return self._ram_cache.get(key=cache_key, stats_name=stats_name) def get_size_fs( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index e3a0928e52b..762bbe167cb 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -2,10 +2,11 @@ import logging import threading import time +from contextlib import contextmanager from dataclasses import dataclass from functools import wraps from logging import Logger -from typing import Any, Callable, Dict, List, Optional, Protocol +from typing import Any, Callable, Dict, Generator, List, Optional, Protocol import psutil import torch @@ -35,6 +36,73 @@ MB = 2**20 +class _ModelLoadReadWriteLock: + """A write-preferring readers-writer lock that serializes model construction against VRAM moves. + + The model load machinery depends on PROCESS-GLOBAL monkey-patches that are not thread-safe: + model CONSTRUCTION (diffusers `from_pretrained` / `accelerate.init_empty_weights`) temporarily + replaces `torch.nn.Module.register_parameter` so that every newly-registered parameter is routed + to the `meta` device. While that patch is installed, ANY `register_parameter` call in ANY thread + is hijacked onto `meta`. VRAM load/unload uses `nn.Module.load_state_dict(assign=True)`, which + assigns `Parameter`s via `__setattr__` -> `register_parameter` — so if it runs concurrently with + a construction on another worker thread, its real weights get stranded on `meta`. That surfaces + later as "Cannot copy out of meta tensor; no data!" or "unrecognized device meta". + + - Construction takes the WRITE lock (exclusive — no reader and no other writer may run). + - VRAM load/unload takes the READ lock (shared, so concurrent moves on different GPUs still + overlap each other; they only block while a construction holds the write lock). + + Write-preferring: once a construction is waiting, new readers queue behind it, so a steady stream + of VRAM moves from busy workers can't starve a pending load. + + Lock-ordering contract: callers MUST acquire this lock *before* any `ModelCache._lock`, never + after. Readers do so by taking the read lock around the outer `ModelCache.lock()` call (see + `LoadedModelWithoutConfig`), and writers around the whole construction (see + `ModelLoader._load_and_cache`). Acquiring it in the other order — cache lock first, then this + lock — would risk an AB-BA deadlock with a writer that takes a cache lock during `put()`. + """ + + def __init__(self) -> None: + self._cond = threading.Condition(threading.Lock()) + self._readers = 0 + self._writers_waiting = 0 + self._writer_active = False + + @contextmanager + def read_lock(self) -> Generator[None, None, None]: + with self._cond: + # Defer to any active or waiting writer (write-preferring). + while self._writer_active or self._writers_waiting > 0: + self._cond.wait() + self._readers += 1 + try: + yield + finally: + with self._cond: + self._readers -= 1 + if self._readers == 0: + self._cond.notify_all() + + @contextmanager + def write_lock(self) -> Generator[None, None, None]: + with self._cond: + self._writers_waiting += 1 + while self._writer_active or self._readers > 0: + self._cond.wait() + self._writers_waiting -= 1 + self._writer_active = True + try: + yield + finally: + with self._cond: + self._writer_active = False + self._cond.notify_all() + + +# Process-global lock guarding the non-thread-safe model load machinery. See _ModelLoadReadWriteLock. +MODEL_LOAD_LOCK = _ModelLoadReadWriteLock() + + # TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels. def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str: """Get the cache key for a model based on the optional submodel type.""" @@ -229,6 +297,11 @@ def unsubscribe() -> None: return unsubscribe + @property + def execution_device(self) -> torch.device: + """Return the default execution device this cache loads models onto.""" + return self._execution_device + @property @synchronized def stats(self) -> Optional[CacheStats]: @@ -546,9 +619,13 @@ def _load_locked_model(self, cache_entry: CacheRecord, working_mem_bytes: Option loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0 # Use the model's actual compute_device for logging, not the cache's default model_device = cache_entry.cached_model.compute_device + if model_device.type == "cuda": + device_label = f"cuda device #{model_device.index}" if model_device.index is not None else "cuda device" + else: + device_label = f"{model_device.type} device" self._logger.info( f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto " - f"{model_device.type} device in {(time.time() - start_time):.2f}s. " + f"{device_label} in {(time.time() - start_time):.2f}s. " f"Total model size: {model_total_bytes / MB:.2f}MB, " f"VRAM: {model_cur_vram_bytes / MB:.2f}MB ({loaded_percent:.1%})" ) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 4191db734f9..be3800411ad 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -10,6 +10,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager +from invokeai.backend.util.devices import TorchDevice class StableDiffusionBackend: @@ -44,7 +45,9 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa # ext: preview[pre_denoise_loop, priority=low] ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx) - for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020 + for ctx.step_index, ctx.timestep in enumerate( # noqa: B020 + tqdm(ctx.inputs.timesteps, desc=f"Denoising{TorchDevice.get_session_device_label()}") + ): # ext: inpaint (apply mask to latents on non-inpaint models) ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx) diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 359ce45dc4f..7f5d9a96feb 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,3 +1,4 @@ +import threading from typing import Dict, Literal, Optional, Union import torch @@ -46,9 +47,52 @@ class TorchDevice: CUDA_DEVICE = torch.device("cuda") MPS_DEVICE = torch.device("mps") + # Per-thread execution device. When set (by a session-processor worker thread bound to a + # specific GPU), `choose_torch_device()` returns it instead of consulting the global config. + # This is the lynchpin that makes the ~79 `choose_torch_device()` call sites (nodes, model + # patcher, etc.) resolve to the calling worker's GPU without per-call-site changes. + _session_device = threading.local() + + @classmethod + def set_session_device(cls, device: Union[str, torch.device]) -> None: + """Pin the calling thread's execution device. Used by multi-GPU session workers.""" + cls._session_device.device = cls.normalize(device) + + @classmethod + def get_session_device(cls) -> Optional[torch.device]: + """Return the calling thread's pinned execution device, or None if unset.""" + return getattr(cls._session_device, "device", None) + + @classmethod + def clear_session_device(cls) -> None: + """Remove the calling thread's pinned execution device, reverting to global config.""" + if hasattr(cls._session_device, "device"): + del cls._session_device.device + + @classmethod + def get_session_device_index(cls) -> Optional[int]: + """Return the CUDA index of the calling thread's effective device, or None if not on CUDA. + + Resolves the thread-local session device when a worker has pinned one (multi-GPU), otherwise + falls back to the globally-configured device. Used to annotate logs/progress with the GPU + number so concurrent sessions can be told apart. + """ + device = cls.get_session_device() or cls.choose_torch_device() + return device.index if device.type == "cuda" else None + + @classmethod + def get_session_device_label(cls) -> str: + """Return a ``" (#N)"`` suffix for the calling thread's CUDA device, or ``""`` when not on CUDA.""" + index = cls.get_session_device_index() + return f" (#{index})" if index is not None else "" + @classmethod def choose_torch_device(cls) -> torch.device: """Return the torch.device to use for accelerated inference.""" + # A worker thread pinned to a specific GPU takes precedence over the global config. + session_device = cls.get_session_device() + if session_device is not None: + return session_device app_config = get_config() if app_config.device != "auto": device = torch.device(app_config.device) @@ -93,6 +137,34 @@ def get_torch_device_name(cls) -> str: device = cls.choose_torch_device() return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() + @classmethod + def get_generation_devices(cls, generation_devices: Union[str, list[str], None]) -> list[torch.device]: + """Resolve the configured `generation_devices` into a concrete, deduplicated device list. + + - ``"auto"`` (the default) expands to every visible CUDA device, or the single best available + device (mps/cpu) when CUDA is unavailable. + - An explicit list is normalized and deduplicated, with order preserved. + - ``None`` or an empty list yields an empty list; the caller decides the single-device fallback. + """ + if generation_devices == "auto": + if torch.cuda.is_available(): + device_strs: list[str] = [f"cuda:{index}" for index in range(torch.cuda.device_count())] + else: + device_strs = [str(cls.choose_torch_device())] + elif not generation_devices: + return [] + else: + device_strs = list(generation_devices) + + devices: list[torch.device] = [] + seen: set[str] = set() + for device_str in device_strs: + device = cls.normalize(device_str) + if str(device) not in seen: + seen.add(str(device)) + devices.append(device) + return devices + @classmethod def normalize(cls, device: Union[str, torch.device]) -> torch.device: """Add the device index to CUDA devices.""" diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 2c9526c59a9..ce5405d269a 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -6431,6 +6431,30 @@ } } }, + "/api/v1/app/generation_device_options": { + "get": { + "tags": ["app"], + "summary": "Get Generation Device Options", + "description": "List the devices available for generation, for use with the `generation_devices` setting.", + "operationId": "get_generation_device_options", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/GenerationDeviceOption" + }, + "type": "array", + "title": "Response Get Generation Device Options" + } + } + } + } + } + } + }, "/api/v1/app/runtime_config": { "get": { "tags": ["app"], @@ -28429,6 +28453,24 @@ "title": "GeneratePasswordResponse", "description": "Response containing a generated password." }, + "GenerationDeviceOption": { + "properties": { + "device": { + "type": "string", + "title": "Device", + "description": "The device identifier, e.g. 'cuda:0', 'mps', or 'cpu'" + }, + "name": { + "type": "string", + "title": "Name", + "description": "Human-readable device name" + } + }, + "type": "object", + "required": ["device", "name"], + "title": "GenerationDeviceOption", + "description": "A device that may be selected for generation." + }, "GetMaskBoundingBoxInvocation": { "category": "mask", "class": "invocation", @@ -39892,6 +39934,19 @@ ], "default": null, "description": "An image representing the current state of the progress" + }, + "device": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "The device processing this session, e.g. 'cuda:1' (set only when running on a CUDA GPU)", + "title": "Device" } }, "required": [ @@ -39907,7 +39962,8 @@ "invocation_source_id", "message", "percentage", - "image" + "image", + "device" ], "title": "InvocationProgressEvent", "type": "object" @@ -41119,6 +41175,23 @@ "description": "Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", "default": "auto" }, + "generation_devices": { + "anyOf": [ + { + "type": "string", + "const": "auto" + }, + { + "items": { + "type": "string" + }, + "type": "array" + } + ], + "title": "Generation Devices", + "description": "Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)", + "default": "auto" + }, "precision": { "type": "string", "enum": ["auto", "float16", "bfloat16", "float32"], @@ -65129,6 +65202,18 @@ "title": "Retried From Item Id", "description": "The item_id of the queue item that this item was retried from" }, + "device": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Device", + "description": "The device that processed this queue item, e.g. 'cuda:1' (set only when running on a CUDA GPU)" + }, "session": { "$ref": "#/components/schemas/GraphExecutionState", "description": "The fully-populated session to be executed" @@ -70106,6 +70191,10 @@ ], "title": "Max Queue History", "description": "Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items." + }, + "generation_devices": { + "title": "Generation Devices", + "description": "Devices to use for parallel generation. `auto` uses every available GPU; provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices. Takes effect after restarting InvokeAI." } }, "type": "object", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 75367a502db..bfdb0853ec0 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -444,6 +444,7 @@ "next": "Next", "status": "Status", "total": "Total", + "gpu": "GPU #", "time": "Time", "credits": "Credits", "pending": "Pending", @@ -1841,6 +1842,11 @@ "enableNSFWChecker": "Enable NSFW Checker", "general": "General", "generation": "Generation", + "generationDevices": "Generation Devices", + "generationDevicesAuto": "Auto (all GPUs)", + "generationDevicesHelp": "Select which devices to use for parallel generation, one session per device. \"Auto\" uses every available GPU.", + "generationDevicesRestart": "Changes take effect after restarting InvokeAI.", + "generationDevicesSaveFailed": "Failed to save Generation Devices", "imageSubfolderStrategy": "Image Subfolder Strategy", "imageSubfolderStrategyDate": "Date", "imageSubfolderStrategyFlat": "Flat", diff --git a/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts new file mode 100644 index 00000000000..3348ae14a2f --- /dev/null +++ b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.test.ts @@ -0,0 +1,29 @@ +import { describe, expect, it } from 'vitest'; + +import { getCudaDeviceIndex } from './getCudaDeviceIndex'; + +describe('getCudaDeviceIndex', () => { + it('parses the index from a cuda device string', () => { + expect(getCudaDeviceIndex('cuda:0')).toBe(0); + expect(getCudaDeviceIndex('cuda:1')).toBe(1); + expect(getCudaDeviceIndex('cuda:11')).toBe(11); + }); + + it('returns null for non-cuda devices', () => { + expect(getCudaDeviceIndex('cpu')).toBeNull(); + expect(getCudaDeviceIndex('mps')).toBeNull(); + }); + + it('returns null for null/undefined/empty', () => { + expect(getCudaDeviceIndex(null)).toBeNull(); + expect(getCudaDeviceIndex(undefined)).toBeNull(); + expect(getCudaDeviceIndex('')).toBeNull(); + }); + + it('returns null for malformed cuda strings', () => { + expect(getCudaDeviceIndex('cuda')).toBeNull(); + expect(getCudaDeviceIndex('cuda:')).toBeNull(); + expect(getCudaDeviceIndex('cuda:x')).toBeNull(); + expect(getCudaDeviceIndex('cuda:0:0')).toBeNull(); + }); +}); diff --git a/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts new file mode 100644 index 00000000000..d4a394b48fc --- /dev/null +++ b/invokeai/frontend/web/src/common/util/getCudaDeviceIndex.ts @@ -0,0 +1,13 @@ +/** + * Parse the CUDA device index from a device string (e.g. `"cuda:1"` → `1`). + * + * Returns `null` when the device is null/undefined or is not a CUDA device (e.g. `"cpu"`, `"mps"`). + * Used to label progress previews and queue items with their GPU number in multi-GPU setups. + */ +export const getCudaDeviceIndex = (device: string | null | undefined): number | null => { + if (!device) { + return null; + } + const match = /^cuda:(\d+)$/.exec(device); + return match ? Number(match[1]) : null; +}; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx index ceaf6c5f435..1978a7fc1ab 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx @@ -22,6 +22,7 @@ import type { ImageDTO } from 'services/api/types'; import { useImageViewerContext } from './context'; import { NoContentForViewer } from './NoContentForViewer'; import { ProgressImage } from './ProgressImage2'; +import { ProgressImageTiles } from './ProgressImageTiles'; import { ProgressIndicator } from './ProgressIndicator2'; export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | null }) => { @@ -30,10 +31,17 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu const shouldShowItemDetails = useAppSelector(selectShouldShowItemDetails); const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer); const { goToPreviousImage, goToNextImage, isFetching } = useNextPrevItemNavigation(); - const { onLoadImage, $progressEvent, $progressImage, $isProgressImageResolving, $isTemporarilyShowingSelectedImage } = - useImageViewerContext(); + const { + onLoadImage, + $progressEvent, + $progressImage, + $activeProgressData, + $isProgressImageResolving, + $isTemporarilyShowingSelectedImage, + } = useImageViewerContext(); const progressEvent = useStore($progressEvent); const progressImage = useStore($progressImage); + const activeProgressData = useStore($activeProgressData); const isProgressImageResolving = useStore($isProgressImageResolving); const isTemporarilyShowingSelectedImage = useStore($isTemporarilyShowingSelectedImage); const [imageToRender, setImageToRender] = useState(null); @@ -186,6 +194,9 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu }); const withProgress = shouldShowProgressInViewer && hasProgressImage && !isTemporarilyShowingSelectedImage; + // When more than one session is generating concurrently (multi-GPU), tile their previews instead of + // showing only the most recent one. + const withTiledProgress = withProgress && activeProgressData.length > 1; return ( } {withProgress && ( - - {progressEvent && ( - + {withTiledProgress ? ( + + ) : ( + <> + + {progressEvent && ( + + )} + )} )} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx new file mode 100644 index 00000000000..6f66c02e929 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx @@ -0,0 +1,39 @@ +import { Flex, Grid, GridItem } from '@invoke-ai/ui-library'; +import { memo, useMemo } from 'react'; + +import type { ViewerProgressDatum } from './context'; +import { ProgressImage } from './ProgressImage2'; +import { ProgressIndicator } from './ProgressIndicator2'; + +/** + * Renders one tile per concurrently-running session (multi-GPU). Each tile shows that session's live + * preview image plus a small progress indicator. Used by the viewer when more than one session is + * active; a single active session uses the full-size preview instead. + */ +export const ProgressImageTiles = memo(({ data }: { data: ViewerProgressDatum[] }) => { + // Lay the tiles out in a roughly-square grid that grows with the number of active sessions. + const columns = useMemo(() => Math.ceil(Math.sqrt(data.length)), [data.length]); + + return ( + + {data.map((datum) => ( + + + + + + + ))} + + ); +}); +ProgressImageTiles.displayName = 'ProgressImageTiles'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx index b635c37d804..f5ca94f732d 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressIndicator2.tsx @@ -1,18 +1,37 @@ import type { CircularProgressProps, SystemStyleObject } from '@invoke-ai/ui-library'; -import { CircularProgress, Tooltip } from '@invoke-ai/ui-library'; +import { CircularProgress, Text, Tooltip } from '@invoke-ai/ui-library'; +import { getCudaDeviceIndex } from 'common/util/getCudaDeviceIndex'; import { memo } from 'react'; import type { S } from 'services/api/types'; import { formatProgressMessage } from 'services/events/stores'; const circleStyles: SystemStyleObject = { + // The callers position this circle with `position="absolute"`, which makes it the containing + // block for the absolutely-centered GPU label below. Do NOT set `position` here — an `sx` value + // would override the caller's prop and break the circle's corner anchoring. circle: { transitionProperty: 'none', transitionDuration: '0s', }, }; +// Centered GPU-number label drawn inside the ring (CircularProgressLabel isn't exported by the ui-library). +const labelStyles: SystemStyleObject = { + position: 'absolute', + top: '50%', + left: '50%', + transform: 'translate(-50%, -50%)', + fontSize: '0.6rem', + lineHeight: 1, + fontWeight: 'bold', + color: 'invokeBlue.300', + textShadow: '0 0 3px var(--invoke-colors-base-900)', + pointerEvents: 'none', +}; + export const ProgressIndicator = memo( ({ progressEvent, ...rest }: { progressEvent: S['InvocationProgressEvent'] } & CircularProgressProps) => { + const gpuIndex = getCudaDeviceIndex(progressEvent?.device); return ( + > + {gpuIndex !== null && {gpuIndex}} + ); } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx index d502deb4498..6f6a95d4f29 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { selectAutoSwitch } from 'features/gallery/store/gallerySelectors'; import type { ProgressImage as ProgressImageType } from 'features/nodes/types/common'; import { LRUCache } from 'lru-cache'; -import { type Atom, atom, computed, type WritableAtom } from 'nanostores'; +import { type Atom, atom, computed, map, type MapStore, type WritableAtom } from 'nanostores'; import type { PropsWithChildren } from 'react'; import { createContext, memo, useCallback, useContext, useEffect, useMemo, useRef, useState } from 'react'; import type { S } from 'services/api/types'; @@ -12,10 +12,24 @@ import { $socket } from 'services/events/stores'; import { assert } from 'tsafe'; import type { JsonObject } from 'type-fest'; +/** Live progress for a single in-flight session (queue item). Used to tile the viewer when several + * sessions run concurrently (multi-GPU). Only items that have produced a preview image are tracked. */ +export type ViewerProgressDatum = { + itemId: number; + progressEvent: S['InvocationProgressEvent']; + progressImage: ProgressImageType; +}; + +type ViewerProgressDataMap = Record; + type ImageViewerContextValue = { $progressEvent: Atom; $progressImage: Atom; $hasProgressImage: Atom; + /** Per-session progress, keyed by queue item id. Drives the tiled multi-session preview. */ + $progressData: MapStore; + /** Active sessions (those with a preview image), sorted by item id for a stable tile order. */ + $activeProgressData: Atom; $isProgressImageResolving: Atom; $isTemporarilyShowingSelectedImage: WritableAtom; onLoadImage: () => void; @@ -31,6 +45,15 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { const $progressEvent = useState(() => atom(null))[0]; const $progressImage = useState(() => atom(null))[0]; const $hasProgressImage = useState(() => computed($progressImage, (progressImage) => progressImage !== null))[0]; + // Per-session progress, keyed by queue item id, for the tiled multi-session preview (multi-GPU). + const $progressData = useState(() => map({}))[0]; + const $activeProgressData = useState(() => + computed($progressData, (progressData) => + Object.values(progressData) + .filter((datum): datum is ViewerProgressDatum => datum !== undefined) + .sort((a, b) => a.itemId - b.itemId) + ) + )[0]; const $isProgressImageResolving = useState(() => atom(false))[0]; const $isTemporarilyShowingSelectedImage = useState(() => atom(false))[0]; const shouldClearProgressImageOnLoadRef = useRef(false); @@ -56,6 +79,12 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { $progressEvent.set(data); if (data.image) { $progressImage.set(data.image); + // Track per-session so the viewer can tile concurrent sessions (multi-GPU). + $progressData.setKey(data.item_id, { + itemId: data.item_id, + progressEvent: data, + progressImage: data.image, + }); } }; @@ -64,7 +93,7 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { return () => { socket.off('invocation_progress', onInvocationProgress); }; - }, [$isProgressImageResolving, $progressEvent, $progressImage, finishedQueueItemIds, socket]); + }, [$isProgressImageResolving, $progressData, $progressEvent, $progressImage, finishedQueueItemIds, socket]); useEffect(() => { if (!socket) { @@ -81,6 +110,9 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { } if (data.status === 'completed' || data.status === 'canceled' || data.status === 'failed') { finishedQueueItemIds.set(data.item_id, true); + // Remove this session's tile from the multi-session preview as soon as it reaches a terminal + // state. The single-image "resolve" illusion below is handled separately via onLoadImage. + $progressData.setKey(data.item_id, undefined); // Completed queue items have the progress event cleared by the onLoadImage callback. This allows the viewer to // create the illusion of the progress image "resolving" into the final image. If we cleared the progress image // now, there would be a flicker where the progress image disappears before the final image appears, and the @@ -115,7 +147,15 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { return () => { socket.off('queue_item_status_changed', onQueueItemStatusChanged); }; - }, [$isProgressImageResolving, $progressEvent, $progressImage, autoSwitch, finishedQueueItemIds, socket]); + }, [ + $isProgressImageResolving, + $progressData, + $progressEvent, + $progressImage, + autoSwitch, + finishedQueueItemIds, + socket, + ]); const onLoadImage = useCallback(() => { if (!shouldClearProgressImageOnLoadRef.current) { @@ -133,12 +173,16 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { $progressEvent, $progressImage, $hasProgressImage, + $progressData, + $activeProgressData, $isProgressImageResolving, $isTemporarilyShowingSelectedImage, onLoadImage, }), [ $hasProgressImage, + $progressData, + $activeProgressData, $isProgressImageResolving, $isTemporarilyShowingSelectedImage, $progressEvent, diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx index e1c5f4ec973..6d3f773a2e9 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemComponent.tsx @@ -1,6 +1,7 @@ import type { ChakraProps, CollapseProps, FlexProps } from '@invoke-ai/ui-library'; import { ButtonGroup, Collapse, Flex, IconButton, Text } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; +import { getCudaDeviceIndex } from 'common/util/getCudaDeviceIndex'; import { selectCurrentUser } from 'features/auth/store/authSlice'; import QueueStatusBadge from 'features/queue/components/common/QueueStatusBadge'; import { useDestinationText } from 'features/queue/components/QueueList/useDestinationText'; @@ -95,6 +96,8 @@ const QueueItemComponent = ({ index, item }: InnerItemProps) => { return `${seconds}s`; }, [item]); + const gpuIndex = useMemo(() => getCudaDeviceIndex(item.device), [item.device]); + const isCanceled = useMemo(() => ['canceled', 'completed', 'failed'].includes(item.status), [item.status]); const isFailed = useMemo(() => ['canceled', 'failed'].includes(item.status), [item.status]); const originText = useOriginText(item.origin); @@ -140,6 +143,9 @@ const QueueItemComponent = ({ index, item }: InnerItemProps) => { + + {gpuIndex !== null ? gpuIndex : '-'} + {executionTime || '-'} diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx index 4cd3397d217..9f6e2fa5458 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueListHeader.tsx @@ -33,6 +33,7 @@ const QueueListHeader = () => { w={COLUMN_WIDTHS.statusBadge} alignItems="center" /> + { +// In "fit" mode (e.g. the strip below a dockview tab label) the stack is constrained to a fixed height. +// Bars stay at FIT_BAR_HEIGHT_PX while they fit, then shrink to share the available space so they never +// overlap the label, no matter how many sessions are running. +const FIT_BAR_HEIGHT_PX = 4; +const FIT_BAR_GAP_PX = 1; + +type ProgressBarProps = ProgressProps & { + /** Applied to the Flex that stacks the per-session bars. Use for positioning (e.g. absolute). */ + containerProps?: FlexProps; + /** + * When set, the stacked bars are constrained to this total height (in px) and shrink to share it, so + * they never grow past the available space (e.g. the strip below a dockview tab label). + */ + fitHeightPx?: number; +}; + +type BarDescriptor = { + key: number | string; + value: number; + isIndeterminate: boolean; +}; + +const ProgressBar = ({ containerProps, fitHeightPx, ...props }: ProgressBarProps) => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); const isConnected = useStore($isConnected); - const lastProgressEvent = useStore($lastProgressEvent); + const activeProgressEvents = useStore($activeProgressEvents); const loadingModelsCount = useStore($loadingModelsCount); - const value = useMemo(() => { - if (!lastProgressEvent) { - return 0; - } - return (lastProgressEvent.percentage ?? 0) * 100; - }, [lastProgressEvent]); - - const isIndeterminate = useMemo(() => { - if (!isConnected) { - return false; - } - - if (loadingModelsCount > 0) { - return true; - } - - if (!queueStatus?.queue.in_progress) { - return false; - } - if (!lastProgressEvent) { - return true; + const bars = useMemo(() => { + // One bar per in-flight session (multi-GPU). Each session's progress is tracked independently, so + // the bars no longer jump back and forth when several sessions render simultaneously. + if (activeProgressEvents.length > 0) { + return activeProgressEvents.map((event) => ({ + key: event.item_id, + value: (event.percentage ?? 0) * 100, + isIndeterminate: isConnected && (loadingModelsCount > 0 || event.percentage === null || event.percentage === 0), + })); } - if (lastProgressEvent.percentage === null) { - return true; + // Fallback single bar: idle, or generation has started but no progress event has arrived yet (e.g. + // while models are loading). Mirrors the previous single-bar indeterminate behavior. + let isIndeterminate = false; + if (isConnected && (loadingModelsCount > 0 || Boolean(queueStatus?.queue.in_progress))) { + isIndeterminate = true; } + return [{ key: 'idle', value: 0, isIndeterminate }]; + }, [activeProgressEvents, isConnected, loadingModelsCount, queueStatus?.queue.in_progress]); - if (lastProgressEvent.percentage === 0) { - return true; + // In fit mode, cap the whole stack to the available strip and let the bars flex to share it. When the + // bars fit at their natural height the stack is shorter than the cap; once they don't, they shrink. + const isFit = fitHeightPx !== undefined; + const fitContainerProps = useMemo(() => { + if (!isFit) { + return undefined; } + const naturalHeight = bars.length * FIT_BAR_HEIGHT_PX + Math.max(0, bars.length - 1) * FIT_BAR_GAP_PX; + return { h: `${Math.min(naturalHeight, fitHeightPx)}px`, gap: `${FIT_BAR_GAP_PX}px` }; + }, [bars.length, fitHeightPx, isFit]); - return false; - }, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress, loadingModelsCount]); + const fitBarProps: ProgressProps | undefined = isFit ? { flex: '1 1 0', minH: 0, h: 'auto' } : undefined; return ( - + + {bars.map((bar) => ( + + ))} + ); }; diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx new file mode 100644 index 00000000000..2980fb85c73 --- /dev/null +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsGenerationDevices.tsx @@ -0,0 +1,249 @@ +import { + Flex, + FormControl, + FormHelperText, + FormLabel, + Tag, + TagCloseButton, + Text, + Tooltip, +} from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; +import { toast } from 'features/toast/toast'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + useGetGenerationDeviceOptionsQuery, + useGetRuntimeConfigQuery, + useUpdateRuntimeConfigMutation, +} from 'services/api/endpoints/appInfo'; + +const AUTO = 'auto'; + +type GenerationDevicesValue = 'auto' | string[]; + +/** Drop the verbose vendor prefix so e.g. "NVIDIA GeForce RTX 3090" reads as "RTX 3090". */ +const shortenDeviceName = (name: string): string => name.replace(/^NVIDIA GeForce /, '').replace(/^NVIDIA /, ''); + +type DeviceBadge = { + /** The device identifier, or 'auto' for the special "use all GPUs" badge. */ + device: string; + /** The label shown on the badge. */ + label: string; + /** A human-readable description shown on hover (e.g. the GPU model name). */ + tooltip?: string; +}; + +export const SettingsGenerationDevices = memo(() => { + const { t } = useTranslation(); + const currentUser = useAppSelector(selectCurrentUser); + const { data: runtimeConfig } = useGetRuntimeConfigQuery(); + const { data: deviceOptions } = useGetGenerationDeviceOptionsQuery(); + const [updateRuntimeConfig, { isLoading }] = useUpdateRuntimeConfigMutation(); + + const generationDevices: GenerationDevicesValue = runtimeConfig?.config.generation_devices ?? AUTO; + const isAuto = generationDevices === AUTO; + const selectedDevices = useMemo(() => (isAuto ? [] : [...generationDevices]), [generationDevices, isAuto]); + + const canEditRuntimeConfig = runtimeConfig ? !runtimeConfig.config.multiuser || currentUser?.is_admin : false; + const isDisabled = !runtimeConfig || !canEditRuntimeConfig || isLoading; + + const save = useCallback( + async (value: GenerationDevicesValue) => { + try { + await updateRuntimeConfig({ generation_devices: value }).unwrap(); + } catch { + toast({ + id: 'SETTINGS_GENERATION_DEVICES_SAVE_FAILED', + title: t('settings.generationDevicesSaveFailed'), + status: 'error', + }); + } + }, + [t, updateRuntimeConfig] + ); + + const autoBadge = useMemo(() => ({ device: AUTO, label: t('settings.generationDevicesAuto') }), [t]); + + // Build a per-device badge (label + tooltip) keyed by device id, e.g. "cuda:0 (RTX 3090 #1)". + // Cards sharing a name get a 1-based "#N" suffix so identical GPUs can be told apart. + const deviceBadges = useMemo>(() => { + const options = deviceOptions ?? []; + const nameCounts = new Map(); + for (const option of options) { + const name = shortenDeviceName(option.name); + nameCounts.set(name, (nameCounts.get(name) ?? 0) + 1); + } + const ordinals = new Map(); + const badges: Record = {}; + for (const option of options) { + const name = shortenDeviceName(option.name); + const ordinal = (ordinals.get(name) ?? 0) + 1; + ordinals.set(name, ordinal); + const namePart = (nameCounts.get(name) ?? 0) > 1 ? `${name} #${ordinal}` : name; + badges[option.device] = { device: option.device, label: `${option.device} (${namePart})`, tooltip: option.name }; + } + return badges; + }, [deviceOptions]); + + // Fall back to a bare device id when a configured device isn't in the current options (e.g. a + // GPU that's no longer present). + const getDeviceBadge = useCallback( + (device: string): DeviceBadge => deviceBadges[device] ?? { device, label: device }, + [deviceBadges] + ); + + // The active badges: the `auto` pseudo-device, or the explicitly-selected devices in config order. + const activeBadges = useMemo(() => { + if (isAuto) { + return [autoBadge]; + } + return selectedDevices.map(getDeviceBadge); + }, [autoBadge, getDeviceBadge, isAuto, selectedDevices]); + + // The inactive badges: `auto` (when an explicit list is active) plus any unselected devices. + const inactiveBadges = useMemo(() => { + const badges: DeviceBadge[] = []; + if (!isAuto) { + badges.push(autoBadge); + } + for (const option of deviceOptions ?? []) { + if (!selectedDevices.includes(option.device)) { + badges.push(getDeviceBadge(option.device)); + } + } + return badges; + }, [autoBadge, deviceOptions, getDeviceBadge, isAuto, selectedDevices]); + + const onActivate = useCallback( + (device: string) => { + if (isDisabled) { + return; + } + if (device === AUTO) { + save(AUTO); + return; + } + // Switching from `auto` starts a fresh explicit list; otherwise append to the current selection. + const next = isAuto ? [device] : Array.from(new Set([...selectedDevices, device])); + save(next); + }, + [isAuto, isDisabled, save, selectedDevices] + ); + + const onDeactivate = useCallback( + (device: string) => { + if (isDisabled) { + return; + } + const next = selectedDevices.filter((d) => d !== device); + // Never leave an empty selection — fall back to `auto`, which is always meaningful. + save(next.length > 0 ? next : AUTO); + }, + [isDisabled, save, selectedDevices] + ); + + return ( + + {t('settings.generationDevices')} + + {activeBadges.map((badge) => ( + + ))} + + {inactiveBadges.length > 0 && ( + + {inactiveBadges.map((badge) => ( + + ))} + + )} + + {t('settings.generationDevicesHelp')}{' '} + + {t('settings.generationDevicesRestart')} + + + + ); +}); + +SettingsGenerationDevices.displayName = 'SettingsGenerationDevices'; + +type DeviceTagProps = { + badge: DeviceBadge; + isActive: boolean; + isClosable: boolean; + isDisabled: boolean; + onActivate: (device: string) => void; + onDeactivate: (device: string) => void; +}; + +const DeviceTag = memo(({ badge, isActive, isClosable, isDisabled, onActivate, onDeactivate }: DeviceTagProps) => { + const onClick = useCallback(() => { + if (isDisabled) { + return; + } + if (isActive) { + // An active, non-closable badge (the exclusive `auto`) is a no-op when clicked. + if (isClosable) { + onDeactivate(badge.device); + } + } else { + onActivate(badge.device); + } + }, [badge.device, isActive, isClosable, isDisabled, onActivate, onDeactivate]); + + const isInteractive = !isDisabled && (!isActive || isClosable); + + const tag = ( + + + {badge.label} + + {isActive && isClosable && } + + ); + + if (!badge.tooltip) { + return tag; + } + + return ( + + {tag} + + ); +}); + +DeviceTag.displayName = 'DeviceTag'; diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx index 64478953a37..62604ba0eab 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx @@ -28,6 +28,7 @@ import { useRefreshAfterResetModal } from 'features/system/components/SettingsMo import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled'; import { SettingsDeveloperLogLevel } from 'features/system/components/SettingsModal/SettingsDeveloperLogLevel'; import { SettingsDeveloperLogNamespaces } from 'features/system/components/SettingsModal/SettingsDeveloperLogNamespaces'; +import { SettingsGenerationDevices } from 'features/system/components/SettingsModal/SettingsGenerationDevices'; import { SettingsImageSubfolderStrategySelect } from 'features/system/components/SettingsModal/SettingsImageSubfolderStrategySelect'; import { useClearIntermediates } from 'features/system/components/SettingsModal/useClearIntermediates'; import { StickyScrollable } from 'features/system/components/StickyScrollable'; @@ -321,6 +322,7 @@ const SettingsModal = (props: { children: ReactElement }) => { + diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx index 80f851ab7af..62246faa0f8 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx @@ -34,7 +34,11 @@ export const DockviewTabCanvasViewer = memo((props: IDockviewPanelHeaderProps {currentQueueItemDestination === 'canvas' && isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx index 440847d7451..285afa3a1b6 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx @@ -37,7 +37,11 @@ export const DockviewTabCanvasWorkspace = memo((props: IDockviewPanelHeaderProps {t(props.params.i18nKey)} {currentQueueItemDestination === canvasSessionId && isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx index 1d997caaf78..c89f682e66a 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx @@ -32,7 +32,11 @@ export const DockviewTabProgress = memo((props: IDockviewPanelHeaderProps {isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts index 653f458dde8..d8801fe9845 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts @@ -58,6 +58,16 @@ export const appInfoApi = api.injectEndpoints({ }), providesTags: ['AppConfig'], }), + getGenerationDeviceOptions: build.query< + paths['/api/v1/app/generation_device_options']['get']['responses']['200']['content']['application/json'], + void + >({ + query: () => ({ + url: buildAppInfoUrl('generation_device_options'), + method: 'GET', + }), + providesTags: ['FetchOnReconnect'], + }), updateRuntimeConfig: build.mutation< paths['/api/v1/app/runtime_config']['patch']['responses']['200']['content']['application/json'], paths['/api/v1/app/runtime_config']['patch']['requestBody']['content']['application/json'] @@ -149,6 +159,7 @@ export const { useGetAppDepsQuery, useGetPatchmatchStatusQuery, useGetRuntimeConfigQuery, + useGetGenerationDeviceOptionsQuery, useGetExternalProviderStatusesQuery, useGetExternalProviderConfigsQuery, useSetExternalProviderConfigMutation, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 5726458dc3a..8ac110d4e39 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1652,6 +1652,26 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/app/generation_device_options": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Generation Device Options + * @description List the devices available for generation, for use with the `generation_devices` setting. + */ + get: operations["get_generation_device_options"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/app/runtime_config": { parameters: { query?: never; @@ -12176,6 +12196,22 @@ export type components = { */ password: string; }; + /** + * GenerationDeviceOption + * @description A device that may be selected for generation. + */ + GenerationDeviceOption: { + /** + * Device + * @description The device identifier, e.g. 'cuda:0', 'mps', or 'cpu' + */ + device: string; + /** + * Name + * @description Human-readable device name + */ + name: string; + }; /** * Get Image Mask Bounding Box * @description Gets the bounding box of the given mask image. @@ -16079,6 +16115,12 @@ export type components = { * @default null */ image: components["schemas"]["ProgressImage"] | null; + /** + * Device + * @description The device processing this session, e.g. 'cuda:1' (set only when running on a CUDA GPU) + * @default null + */ + device: string | null; }; /** * InvocationStartedEvent @@ -16492,6 +16534,12 @@ export type components = { * @default auto */ device?: string; + /** + * Generation Devices + * @description Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number) + * @default auto + */ + generation_devices?: "auto" | string[]; /** * Precision * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. @@ -28137,6 +28185,11 @@ export type components = { * @description The item_id of the queue item that this item was retried from */ retried_from_item_id?: number | null; + /** + * Device + * @description The device that processed this queue item, e.g. 'cuda:1' (set only when running on a CUDA GPU) + */ + device?: string | null; /** @description The fully-populated session to be executed */ session: components["schemas"]["GraphExecutionState"]; /** @description The workflow associated with this queue item */ @@ -30851,6 +30904,11 @@ export type components = { * @description Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items. */ max_queue_history?: number | null; + /** + * Generation Devices + * @description Devices to use for parallel generation. `auto` uses every available GPU; provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices. Takes effect after restarting InvokeAI. + */ + generation_devices?: unknown; }; /** * UserDTO @@ -36301,6 +36359,26 @@ export interface operations { }; }; }; + get_generation_device_options: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["GenerationDeviceOption"][]; + }; + }; + }; + }; get_runtime_config: { parameters: { query?: never; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 4d5b5901321..f5e1d20676d 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -43,7 +43,13 @@ import { createWorkflowExecutionCoordinator } from 'services/events/workflowExec import type { Socket } from 'socket.io-client'; import type { JsonObject } from 'type-fest'; -import { $lastProgressEvent, $loadingModelsCount } from './stores'; +import { + $lastProgressEvent, + $loadingModelsCount, + clearAllProgressEvents, + clearProgressEvent, + setProgressEvent, +} from './stores'; const log = logger('events'); @@ -86,6 +92,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.emit('subscribe_queue', { queue_id: 'default' }); socket.emit('subscribe_bulk_download', { bulk_download_id: 'default' }); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); }); @@ -93,6 +100,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.debug('Connect error'); setIsConnected(false); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); if (error && error.message) { const data: string | undefined = (error as unknown as { data: string | undefined }).data; @@ -111,6 +119,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.debug('Disconnected'); workflowExecutionCoordinator.cancelPendingWorkflowReconciliations(); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); setIsConnected(false); }); @@ -140,6 +149,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.trace({ data } as JsonObject, _message); $lastProgressEvent.set(data); + setProgressEvent(data); }); socket.on('invocation_error', (data) => { @@ -448,11 +458,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis } // If the queue item is completed, failed, or cancelled, we want to clear the last progress event $lastProgressEvent.set(null); + // Also remove this session's per-item progress so its stacked progress bar disappears. + clearProgressEvent(item_id); } }); socket.on('queue_cleared', (data) => { log.debug({ data }, 'Queue cleared'); + clearAllProgressEvents(); dispatch( queueApi.util.invalidateTags([ 'SessionQueueStatus', diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts index 180f4a3a636..7c7630e2019 100644 --- a/invokeai/frontend/web/src/services/events/stores.ts +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -1,5 +1,5 @@ import { round } from 'es-toolkit/compat'; -import { atom, computed } from 'nanostores'; +import { atom, computed, map } from 'nanostores'; import type { S } from 'services/api/types'; import type { AppSocket } from 'services/events/types'; @@ -8,6 +8,33 @@ export const $isConnected = atom(false); export const $lastProgressEvent = atom(null); export const $loadingModelsCount = atom(0); +/** + * Live progress events keyed by queue item id. Unlike `$lastProgressEvent` (a single global value that + * is overwritten by whichever session reported last), this tracks each in-flight session separately so + * the UI can render one progress bar per concurrent session (multi-GPU). Entries are added as progress + * events arrive and removed when the session reaches a terminal state. + */ +const $progressEvents = map>({}); + +/** In-flight sessions sorted by queue item id, for a stable top-to-bottom bar order. */ +export const $activeProgressEvents = computed($progressEvents, (events) => + Object.values(events) + .filter((event): event is S['InvocationProgressEvent'] => event !== undefined) + .sort((a, b) => a.item_id - b.item_id) +); + +export const setProgressEvent = (event: S['InvocationProgressEvent']) => { + $progressEvents.setKey(event.item_id, event); +}; + +export const clearProgressEvent = (itemId: number) => { + $progressEvents.setKey(itemId, undefined); +}; + +export const clearAllProgressEvents = () => { + $progressEvents.set({}); +}; + export const $lastProgressMessage = computed($lastProgressEvent, (val) => { if (!val) { return null; diff --git a/tests/app/routers/test_app_info.py b/tests/app/routers/test_app_info.py index da493cee457..96eb23f1342 100644 --- a/tests/app/routers/test_app_info.py +++ b/tests/app/routers/test_app_info.py @@ -225,6 +225,64 @@ def test_update_runtime_config_image_subfolder_strategy_schema() -> None: } +def test_update_runtime_config_persists_generation_devices( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": ["cuda:0", "cuda:1"]}) + + assert response.status_code == 200 + assert response.json()["config"]["generation_devices"] == ["cuda:0", "cuda:1"] + + config_path = get_config().config_file_path + file_config = load_and_migrate_config(config_path) + assert file_config.generation_devices == ["cuda:0", "cuda:1"] + assert get_config().generation_devices == ["cuda:0", "cuda:1"] + + # "auto" round-trips back to the default. + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": "auto"}) + assert response.status_code == 200 + assert response.json()["config"]["generation_devices"] == "auto" + assert get_config().generation_devices == "auto" + + +def test_update_runtime_config_rejects_invalid_generation_device( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": ["gpu0"]}) + + assert response.status_code == 422 + + +def test_update_runtime_config_rejects_null_generation_devices( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.patch("/api/v1/app/runtime_config", json={"generation_devices": None}) + + assert response.status_code == 422 + + +def test_get_generation_device_options_lists_devices( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: + monkeypatch.setattr(app_info.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(app_info.torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr(app_info.torch.cuda, "get_device_name", lambda index: f"GPU {index}") + + response = client.get("/api/v1/app/generation_device_options") + + assert response.status_code == 200 + assert response.json() == [ + {"device": "cuda:0", "name": "GPU 0"}, + {"device": "cuda:1", "name": "GPU 1"}, + ] + + def test_update_runtime_config_reads_and_writes_yaml_under_config_lock( monkeypatch: Any, mock_invoker: Invoker, client: TestClient ) -> None: diff --git a/tests/app/services/model_load/test_model_load_device_routing.py b/tests/app/services/model_load/test_model_load_device_routing.py new file mode 100644 index 00000000000..85b3868b92f --- /dev/null +++ b/tests/app/services/model_load/test_model_load_device_routing.py @@ -0,0 +1,96 @@ +"""Tests that ModelLoadService routes to the per-device cache for the calling thread (multi-GPU).""" + +import threading +from collections.abc import Iterator + +import pytest +import torch + +from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config +from invokeai.app.services.model_load.model_load_default import ModelLoadService +from invokeai.backend.util.devices import TorchDevice + + +@pytest.fixture(autouse=True) +def restore_global_device() -> Iterator[None]: + """`get_config()` is a process-wide singleton; restore `device` so we don't leak a CUDA device + into later CPU-only tests (e.g. the model-loading suite on the CUDA-less CI runner).""" + config = get_config() + original_device = config.device + try: + yield + finally: + config.device = original_device + TorchDevice.clear_session_device() + + +class _FakeCache: + """Stand-in for ModelCache; ModelLoadService only needs `.execution_device` for keying.""" + + def __init__(self, device: str): + self.execution_device = torch.device(device) + + +def _build_service() -> tuple[ModelLoadService, _FakeCache, _FakeCache]: + cache0 = _FakeCache("cuda:0") + cache1 = _FakeCache("cuda:1") + service = ModelLoadService( + app_config=InvokeAIAppConfig(), + ram_cache=cache0, # type: ignore[arg-type] + ram_caches={"cuda:0": cache0, "cuda:1": cache1}, # type: ignore[arg-type] + ) + return service, cache0, cache1 + + +def test_ram_cache_routes_to_pinned_device(): + """A thread pinned to cuda:1 resolves to that device's cache; the default thread to cuda:0.""" + service, cache0, cache1 = _build_service() + + # The default thread has no session device; point config.device at cuda:0 so it resolves there. + get_config().device = "cuda:0" + assert service.ram_cache is cache0 + + results: dict[str, object] = {} + + def worker(): + TorchDevice.set_session_device("cuda:1") + try: + results["cache"] = service.ram_cache + finally: + TorchDevice.clear_session_device() + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert results["cache"] is cache1 + # Main thread is unaffected by the worker's pinning. + assert service.ram_cache is cache0 + + +def test_ram_caches_exposes_all_devices(): + service, cache0, cache1 = _build_service() + caches = service.ram_caches + assert set(caches.keys()) == {"cuda:0", "cuda:1"} + assert caches["cuda:0"] is cache0 + assert caches["cuda:1"] is cache1 + + +def test_unknown_device_falls_back_to_default(): + """A thread pinned to a device with no cache falls back to the default cache.""" + service, cache0, _ = _build_service() + + results: dict[str, object] = {} + + def worker(): + TorchDevice.set_session_device("cuda:7") + try: + results["cache"] = service.ram_cache + finally: + TorchDevice.clear_session_device() + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert results["cache"] is cache0 diff --git a/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py b/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py new file mode 100644 index 00000000000..2d61bf0f37a --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py @@ -0,0 +1,90 @@ +"""Tests that concurrent dequeue() calls (multi-GPU session workers) never claim the same item twice.""" + +import threading +import uuid + +import pytest + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState +from tests.test_nodes import PromptTestInvocation + + +@pytest.fixture +def session_queue(mock_invoker: Invoker) -> SqliteSessionQueue: + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert_queue_item(session_queue: SqliteSessionQueue, user_id: str = "system") -> int: + graph = Graph() + graph.add_node(PromptTestInvocation(id="prompt", prompt="test")) + session = GraphExecutionState(graph=graph) + session_json = session.model_dump_json(warnings=False, exclude_none=True) + batch_id = str(uuid.uuid4()) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue ( + queue_id, session, session_id, batch_id, field_values, priority, + workflow, origin, destination, retried_from_item_id, user_id + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ("default", session_json, session.id, batch_id, None, 0, None, None, None, None, user_id), + ) + return cursor.lastrowid + + +def test_concurrent_dequeue_never_claims_same_item_twice(session_queue: SqliteSessionQueue) -> None: + item_count = 50 + worker_count = 8 + for _ in range(item_count): + _insert_queue_item(session_queue) + + claimed_ids: list[int] = [] + claimed_lock = threading.Lock() + start_barrier = threading.Barrier(worker_count) + + def worker() -> None: + # Release all workers at once to maximize contention on the dequeue path. + start_barrier.wait() + while True: + item = session_queue.dequeue() + if item is None: + break + with claimed_lock: + claimed_ids.append(item.item_id) + + threads = [threading.Thread(target=worker) for _ in range(worker_count)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Every item is claimed exactly once: no duplicates, none lost. + assert len(claimed_ids) == item_count + assert len(set(claimed_ids)) == item_count + + +def test_dequeue_records_processing_device(session_queue: SqliteSessionQueue) -> None: + _insert_queue_item(session_queue) + + item = session_queue.dequeue(device="cuda:1") + assert item is not None + assert item.device == "cuda:1" + + # The device persists across later status transitions (which pass device=None). + completed = session_queue._set_queue_item_status(item.item_id, "completed") + assert completed.device == "cuda:1" + + +def test_dequeue_without_device_leaves_device_unset(session_queue: SqliteSessionQueue) -> None: + _insert_queue_item(session_queue) + + item = session_queue.dequeue() + assert item is not None + assert item.device is None diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py index 3f134e3c3da..aa8433c632e 100644 --- a/tests/backend/util/test_devices.py +++ b/tests/backend/util/test_devices.py @@ -2,6 +2,7 @@ Test abstract device class. """ +import threading from unittest.mock import patch import pytest @@ -24,6 +25,91 @@ def test_device_choice(device_name): assert torch_device == torch.device(device_name) +# ===== per-thread session device (multi-GPU worker pinning) ================ + + +def test_session_device_overrides_config(): + """A per-thread session device takes precedence over the global config.device.""" + config = get_config() + config.device = "cpu" + try: + TorchDevice.set_session_device("cuda:1") + assert TorchDevice.choose_torch_device() == torch.device("cuda:1") + finally: + TorchDevice.clear_session_device() + # Once cleared, we fall back to the global config. + assert TorchDevice.choose_torch_device() == torch.device("cpu") + + +def test_session_device_is_thread_local(): + """Each thread sees only its own pinned device; the main thread is unaffected.""" + config = get_config() + config.device = "cpu" + results: dict[str, torch.device] = {} + barrier = threading.Barrier(2) + + def worker(name: str, device: str): + TorchDevice.set_session_device(device) + # Wait so both threads have set their device before either reads it, proving isolation. + barrier.wait() + results[name] = TorchDevice.choose_torch_device() + TorchDevice.clear_session_device() + + t0 = threading.Thread(target=worker, args=("a", "cuda:0")) + t1 = threading.Thread(target=worker, args=("b", "cuda:1")) + t0.start() + t1.start() + t0.join() + t1.join() + + assert results["a"] == torch.device("cuda:0") + assert results["b"] == torch.device("cuda:1") + # The main thread never set a session device, so it still uses the global config. + assert TorchDevice.get_session_device() is None + assert TorchDevice.choose_torch_device() == torch.device("cpu") + + +# ===== generation_devices resolution (config -> concrete device list) ======= + + +def test_get_generation_devices_auto_expands_to_all_cuda(): + """`auto` enumerates every visible CUDA device.""" + with ( + patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=True), + patch("invokeai.backend.util.devices.torch.cuda.device_count", return_value=3), + ): + assert TorchDevice.get_generation_devices("auto") == [ + torch.device("cuda:0"), + torch.device("cuda:1"), + torch.device("cuda:2"), + ] + + +def test_get_generation_devices_auto_without_cuda(): + """`auto` falls back to the single best device when CUDA is unavailable.""" + config = get_config() + config.device = "cpu" + with ( + patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=False), + patch("invokeai.backend.util.devices.torch.backends.mps.is_available", return_value=False), + ): + assert TorchDevice.get_generation_devices("auto") == [torch.device("cpu")] + + +def test_get_generation_devices_explicit_list_is_deduplicated(): + """An explicit list is normalized and deduplicated, preserving order.""" + assert TorchDevice.get_generation_devices(["cuda:0", "cuda:0", "cuda:1"]) == [ + torch.device("cuda:0"), + torch.device("cuda:1"), + ] + + +@pytest.mark.parametrize("value", [None, []]) +def test_get_generation_devices_empty(value): + """`None` or an empty list resolves to an empty list (caller handles the single-device fallback).""" + assert TorchDevice.get_generation_devices(value) == [] + + @pytest.mark.parametrize("device_dtype_pair", device_types_cpu) def test_device_dtype_cpu(device_dtype_pair): with (