From caedbfec12735b274cd3b5754d8b57fbf10add7e Mon Sep 17 00:00:00 2001
From: longkeyy
Date: Mon, 22 Dec 2025 01:48:55 +0800
Subject: [PATCH] fix: Add MPS and CPU device support

This PR fixes several issues that prevent DiffSynth-Studio from running on
non-CUDA devices (Apple Silicon MPS and plain CPU):

1. base_pipeline.py: Check that empty_cache exists before calling it
   - Not every torch.<device> backend module exposes empty_cache(), so the
     call is now guarded with hasattr
2. siglip2_image_encoder.py: Remove the hardcoded device="cuda" default
   - The device is now auto-detected from the model's own parameters
   - An explicitly passed device is still respected
3. dinov3_image_encoder.py: Remove the hardcoded device="cuda" default
   - Same fix as siglip2_image_encoder.py
4. vram/layers.py: Check that mem_get_info exists before calling it
   - Only CUDA and NPU expose mem_get_info(); for MPS/CPU, assume enough
     memory is available

These changes enable running Qwen-Image pipelines on Apple Silicon Macs and on
CPU-only machines without any monkey-patching workarounds. A standalone sketch
of the guarded-capability pattern used here follows the diff.
---
 diffsynth/core/vram/layers.py             | 6 +++++-
 diffsynth/diffusion/base_pipeline.py      | 5 ++++-
 diffsynth/models/dinov3_image_encoder.py  | 5 ++++-
 diffsynth/models/siglip2_image_encoder.py | 5 ++++-
 4 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/diffsynth/core/vram/layers.py b/diffsynth/core/vram/layers.py
index 751792d0..b0179b91 100644
--- a/diffsynth/core/vram/layers.py
+++ b/diffsynth/core/vram/layers.py
@@ -64,7 +64,11 @@ def cast_to(self, weight, dtype, device):
 
     def check_free_vram(self):
         device = self.computation_device if self.computation_device != "npu" else "npu:0"
-        gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
+        device_module = getattr(torch, self.computation_device_type, None)
+        # Only CUDA and NPU have mem_get_info; for MPS/CPU assume enough memory
+        if device_module is None or not hasattr(device_module, "mem_get_info"):
+            return True
+        gpu_mem_state = device_module.mem_get_info(device)
         used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
         return used_memory < self.vram_limit
 
diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py
index fa355a16..4d90ff29 100644
--- a/diffsynth/diffusion/base_pipeline.py
+++ b/diffsynth/diffusion/base_pipeline.py
@@ -155,7 +155,10 @@ def load_models_to_device(self, model_names):
                 for module in model.modules():
                     if hasattr(module, "offload"):
                         module.offload()
-        getattr(torch, self.device_type).empty_cache()
+        # Clear cache if available (only CUDA has empty_cache)
+        device_module = getattr(torch, self.device_type, None)
+        if device_module is not None and hasattr(device_module, "empty_cache"):
+            device_module.empty_cache()
         # onload models
         for name, model in self.named_children():
             if name in model_names:
diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py
index be2ee587..70eec5b2 100644
--- a/diffsynth/models/dinov3_image_encoder.py
+++ b/diffsynth/models/dinov3_image_encoder.py
@@ -70,7 +70,10 @@ def __init__(self):
             }
         )
 
-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=None):
+        # Use model's device if not specified
+        if device is None:
+            device = next(self.parameters()).device
         inputs = self.processor(images=image, return_tensors="pt")
         pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
         bool_masked_pos = None
diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py
index 10184f85..dbc3a0e1 100644
--- a/diffsynth/models/siglip2_image_encoder.py
+++ b/diffsynth/models/siglip2_image_encoder.py
@@ -47,7 +47,10 @@ def __init__(self):
             }
         )
 
-    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+    def forward(self, image, torch_dtype=torch.bfloat16, device=None):
+        # Use model's device if not specified
+        if device is None:
+            device = next(self.parameters()).device
         pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
         pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
         output_attentions = False
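
For reviewers, the sketch below distills the pattern the patch applies in all
four files: look up the torch.<backend> module with getattr, guard optional
calls (empty_cache, mem_get_info) with hasattr, and fall back to the model's
own parameter device when none is given. It is an illustrative, self-contained
sketch only; the helper names safe_empty_cache, has_free_vram, and
resolve_device are hypothetical and do not exist in DiffSynth-Studio.

import torch


def safe_empty_cache(device_type: str) -> None:
    # Release cached memory only if the backend module exposes empty_cache().
    device_module = getattr(torch, device_type, None)  # e.g. torch.cuda, torch.mps
    if device_module is not None and hasattr(device_module, "empty_cache"):
        device_module.empty_cache()


def has_free_vram(device_type: str, device, vram_limit_gb: float) -> bool:
    # Report whether used memory is below the limit; if the backend cannot
    # measure it (MPS/CPU have no mem_get_info), assume there is enough.
    device_module = getattr(torch, device_type, None)
    if device_module is None or not hasattr(device_module, "mem_get_info"):
        return True
    free_bytes, total_bytes = device_module.mem_get_info(device)
    used_gb = (total_bytes - free_bytes) / (1024 ** 3)
    return used_gb < vram_limit_gb


def resolve_device(module: torch.nn.Module, device=None):
    # Fall back to the device the module's parameters already live on.
    return device if device is not None else next(module.parameters()).device

On a CUDA machine, has_free_vram("cuda", "cuda:0", 6.0) reflects real memory
headroom; on an Apple Silicon or CPU-only host the same call simply returns
True, mirroring the behaviour check_free_vram gets from this patch.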