From 3f48145f5a4ab67d7ce2a1cfbc6138f8d54d2043 Mon Sep 17 00:00:00 2001
From: jairoguo
Date: Thu, 5 Feb 2026 13:14:01 +0800
Subject: [PATCH 1/2] fix(diffusers): support large models with device_map for multi-GPU distribution

When loading very large models (e.g., Qwen-Image, ~95GB) on GPUs with limited
headroom, the model loads successfully but leaves no memory for inference.

This PR adds support for multi-GPU distribution via device_map when LowVRAM
is enabled:

1. Add low_cpu_mem_usage=True and device_map='balanced' during model loading
   to distribute large models across multiple GPUs
2. Skip enable_model_cpu_offload() when device_map is used, as they conflict
   with each other (ValueError: device mapping strategy doesn't allow
   enable_model_cpu_offload)
3. Skip .to(device) when device_map is used, as they also conflict
   (ValueError: device mapping strategy doesn't allow explicit device
   placement using to())

This enables running models like Qwen-Image on multi-GPU setups where a
single GPU doesn't have enough memory for both model weights and inference.

Tested with:
- Qwen-Image (~95GB) on 3x NVIDIA H20 (96GB each)
- Configuration: low_vram: true, pipeline_type: QwenImagePipeline
---
 backend/python/diffusers/backend.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py
index 032af60c4164..a43a9a008c4a 100755
--- a/backend/python/diffusers/backend.py
+++ b/backend/python/diffusers/backend.py
@@ -400,6 +400,12 @@ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant)
         # Build kwargs for dynamic loading
         load_kwargs = {"torch_dtype": torchType}
 
+        # For large models (e.g., >80GB), enable low_cpu_mem_usage and device_map
+        # to avoid OOM during loading by distributing across multiple GPUs
+        if request.LowVRAM:
+            load_kwargs["low_cpu_mem_usage"] = True
+            load_kwargs["device_map"] = "balanced"
+
         # Add variant if not loading from single file
         if not fromSingleFile and variant:
             load_kwargs["variant"] = variant
@@ -428,7 +434,8 @@ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant)
             ) from e
 
         # Apply LowVRAM optimization if supported and requested
-        if request.LowVRAM and hasattr(pipe, 'enable_model_cpu_offload'):
+        # Skip if device_map was used (they conflict with each other)
+        if request.LowVRAM and hasattr(pipe, 'enable_model_cpu_offload') and "device_map" not in load_kwargs:
             pipe.enable_model_cpu_offload()
 
         return pipe
@@ -582,9 +589,11 @@ def LoadModel(self, request, context):
                 self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)
 
             if device != "cpu":
-                self.pipe.to(device)
-                if self.controlnet:
-                    self.controlnet.to(device)
+                # Skip .to(device) if device_map was used (they conflict with each other)
+                if not hasattr(self.pipe, "hf_device_map") or self.pipe.hf_device_map is None:
+                    self.pipe.to(device)
+                    if self.controlnet:
+                        self.controlnet.to(device)
 
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
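
[Note, not part of the patch] A minimal standalone sketch of the loading path added in PATCH 1/2, for trying the behaviour outside the backend. The model id and dtype are illustrative assumptions, not values taken from the backend; the ValueError behaviour is the one quoted in the commit message above.

    import torch
    from diffusers import DiffusionPipeline

    # Shard the pipeline's components across all visible GPUs at load time.
    pipe = DiffusionPipeline.from_pretrained(
        "Qwen/Qwen-Image",           # illustrative model id
        torch_dtype=torch.bfloat16,  # illustrative dtype; the backend derives it from the request
        low_cpu_mem_usage=True,
        device_map="balanced",
    )

    # With a device map active, diffusers rejects explicit placement: both
    # pipe.to("cuda") and pipe.enable_model_cpu_offload() raise ValueError,
    # which is why the patch guards on load_kwargs / hf_device_map.
    if getattr(pipe, "hf_device_map", None) is None:
        pipe.to("cuda")
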
From e3a64e0784703bce15bcdba5f3650fce8425bdaf Mon Sep 17 00:00:00 2001
From: jairoguo
Date: Thu, 5 Feb 2026 14:17:34 +0800
Subject: [PATCH 2/2] feat(diffusers): add Shutdown method to release GPU memory

Add Shutdown method to the diffusers backend that properly releases GPU
memory when a model is unloaded. This enables dynamic model reloading with
different configurations (e.g., switching LoRA adapters) without restarting
the service.

The Shutdown method:
- Releases the pipeline, controlnet, and compel objects
- Clears CUDA cache with torch.cuda.empty_cache()
- Resets state flags (img2vid, txt2vid, ltx2_pipeline)

This works with LocalAI's existing /backend/shutdown API endpoint, which
terminates the gRPC process. The explicit cleanup ensures GPU memory is
properly released before process termination.

Tested with Qwen-Image (~95GB) on NVIDIA H20 GPUs.
---
 backend/python/diffusers/backend.py | 41 +++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py
index a43a9a008c4a..2b19d6d72c0f 100755
--- a/backend/python/diffusers/backend.py
+++ b/backend/python/diffusers/backend.py
@@ -443,6 +443,47 @@ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant)
     def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
 
+    def Shutdown(self, request, context):
+        """
+        Shutdown and release GPU memory for the loaded model.
+        This allows dynamic model reloading with different configurations (e.g., different LoRA adapters).
+        """
+        try:
+            print("Shutting down diffusers backend...", file=sys.stderr)
+
+            # Release pipeline
+            if hasattr(self, 'pipe') and self.pipe is not None:
+                del self.pipe
+                self.pipe = None
+
+            # Release controlnet
+            if hasattr(self, 'controlnet') and self.controlnet is not None:
+                del self.controlnet
+                self.controlnet = None
+
+            # Release compel
+            if hasattr(self, 'compel') and self.compel is not None:
+                del self.compel
+                self.compel = None
+
+            # Clear CUDA cache to release GPU memory
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
+                print("CUDA cache cleared", file=sys.stderr)
+
+            # Reset state flags
+            self.img2vid = False
+            self.txt2vid = False
+            self.ltx2_pipeline = False
+            self.options = {}
+
+            print("Diffusers backend shutdown complete", file=sys.stderr)
+            return backend_pb2.Result(message="Model unloaded successfully", success=True)
+        except Exception as err:
+            print(f"Error during shutdown: {err}", file=sys.stderr)
+            return backend_pb2.Result(success=False, message=f"Shutdown error: {err}")
+
     def LoadModel(self, request, context):
         try:
             print(f"Loading model {request.Model}...", file=sys.stderr)
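
[Note, not part of the patch] The release sequence in Shutdown can be exercised in isolation with plain PyTorch to confirm that memory actually returns to the driver. The tensor below merely stands in for pipeline weights, and the gc.collect() call is an extra precaution not present in the patch.

    import gc
    import torch

    # Stand-in allocation (~2 GiB) so the effect is visible in the allocator counters.
    weights = torch.empty(1024, 1024, 1024, dtype=torch.float16, device="cuda")
    print(torch.cuda.memory_reserved() // 2**20, "MiB reserved before cleanup")

    # Drop the Python references first, then hand cached blocks back to the driver,
    # mirroring the del / empty_cache() / synchronize() sequence in Shutdown.
    del weights
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print(torch.cuda.memory_reserved() // 2**20, "MiB reserved after cleanup")
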