diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py
index 032af60c4164..2b19d6d72c0f 100755
--- a/backend/python/diffusers/backend.py
+++ b/backend/python/diffusers/backend.py
@@ -400,6 +400,12 @@ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant)
         # Build kwargs for dynamic loading
         load_kwargs = {"torch_dtype": torchType}
 
+        # For large models (e.g., >80GB), enable low_cpu_mem_usage and device_map
+        # to avoid OOM during loading by distributing across multiple GPUs
+        if request.LowVRAM:
+            load_kwargs["low_cpu_mem_usage"] = True
+            load_kwargs["device_map"] = "balanced"
+
         # Add variant if not loading from single file
         if not fromSingleFile and variant:
             load_kwargs["variant"] = variant
@@ -428,7 +434,8 @@ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant)
             ) from e
 
         # Apply LowVRAM optimization if supported and requested
-        if request.LowVRAM and hasattr(pipe, 'enable_model_cpu_offload'):
+        # Skip if device_map was used (they conflict with each other)
+        if request.LowVRAM and hasattr(pipe, 'enable_model_cpu_offload') and "device_map" not in load_kwargs:
             pipe.enable_model_cpu_offload()
 
         return pipe
@@ -436,6 +443,47 @@ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant)
     def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
 
+    def Shutdown(self, request, context):
+        """
+        Shutdown and release GPU memory for the loaded model.
+        This allows dynamic model reloading with different configurations (e.g., different LoRA adapters).
+        """
+        try:
+            print("Shutting down diffusers backend...", file=sys.stderr)
+
+            # Release pipeline
+            if hasattr(self, 'pipe') and self.pipe is not None:
+                del self.pipe
+                self.pipe = None
+
+            # Release controlnet
+            if hasattr(self, 'controlnet') and self.controlnet is not None:
+                del self.controlnet
+                self.controlnet = None
+
+            # Release compel
+            if hasattr(self, 'compel') and self.compel is not None:
+                del self.compel
+                self.compel = None
+
+            # Clear CUDA cache to release GPU memory
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
+                print("CUDA cache cleared", file=sys.stderr)
+
+            # Reset state flags
+            self.img2vid = False
+            self.txt2vid = False
+            self.ltx2_pipeline = False
+            self.options = {}
+
+            print("Diffusers backend shutdown complete", file=sys.stderr)
+            return backend_pb2.Result(message="Model unloaded successfully", success=True)
+        except Exception as err:
+            print(f"Error during shutdown: {err}", file=sys.stderr)
+            return backend_pb2.Result(success=False, message=f"Shutdown error: {err}")
+
     def LoadModel(self, request, context):
         try:
             print(f"Loading model {request.Model}...", file=sys.stderr)
@@ -582,9 +630,11 @@ def LoadModel(self, request, context):
                 self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)
 
             if device != "cpu":
-                self.pipe.to(device)
-                if self.controlnet:
-                    self.controlnet.to(device)
+                # Skip .to(device) if device_map was used (they conflict with each other)
+                if not hasattr(self.pipe, "hf_device_map") or self.pipe.hf_device_map is None:
+                    self.pipe.to(device)
+                    if self.controlnet:
+                        self.controlnet.to(device)
 
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
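
For reviewers, here is a minimal standalone sketch of the load/teardown pattern the patch adopts, assuming a recent diffusers release where pipeline-level `device_map="balanced"` and `low_cpu_mem_usage` are supported; the model id and the `low_vram` flag below are placeholders for illustration, not values taken from the patch:

```python
import gc

import torch
from diffusers import DiffusionPipeline

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"  # placeholder model id
low_vram = True  # stands in for request.LowVRAM

load_kwargs = {"torch_dtype": torch.float16}
if low_vram:
    # Shard the pipeline across available GPUs at load time instead of
    # materializing the full model in CPU RAM first.
    load_kwargs["low_cpu_mem_usage"] = True
    load_kwargs["device_map"] = "balanced"

pipe = DiffusionPipeline.from_pretrained(MODEL_ID, **load_kwargs)

# device_map, enable_model_cpu_offload() and .to(device) are mutually exclusive:
# only fall back to CPU offload or manual placement when no device map was applied.
if getattr(pipe, "hf_device_map", None) is None:
    if low_vram and hasattr(pipe, "enable_model_cpu_offload"):
        pipe.enable_model_cpu_offload()
    elif torch.cuda.is_available():
        pipe.to("cuda")

# Teardown mirrors the new Shutdown() handler: drop references, then reclaim VRAM.
del pipe
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
```

The "balanced" strategy spreads the pipeline's components across the visible GPUs, which is why the patch skips both `enable_model_cpu_offload()` and the later `.to(device)` whenever a device map is present.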