diff --git a/diffsynth/__init__.py b/diffsynth/__init__.py index bb67a43f..6a404feb 100644 --- a/diffsynth/__init__.py +++ b/diffsynth/__init__.py @@ -1 +1,7 @@ from .core import * + +# Distributed/Multi-GPU support (optional import) +try: + from . import distributed +except ImportError: + pass diff --git a/diffsynth/core/vram/layers.py b/diffsynth/core/vram/layers.py index 751792d0..b0179b91 100644 --- a/diffsynth/core/vram/layers.py +++ b/diffsynth/core/vram/layers.py @@ -64,7 +64,11 @@ def cast_to(self, weight, dtype, device): def check_free_vram(self): device = self.computation_device if self.computation_device != "npu" else "npu:0" - gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device) + device_module = getattr(torch, self.computation_device_type, None) + # Only CUDA and NPU have mem_get_info, for MPS/CPU assume enough memory + if device_module is None or not hasattr(device_module, "mem_get_info"): + return True + gpu_mem_state = device_module.mem_get_info(device) used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3) return used_memory < self.vram_limit diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index fa355a16..4d73acdb 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -2,13 +2,29 @@ import torch import numpy as np from einops import repeat, reduce -from typing import Union +from typing import Union, Optional, Dict, List from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type from ..utils.lora import GeneralLoRALoader from ..models.model_loader import ModelPool from ..utils.controlnet import ControlNetInput +def get_available_gpus() -> List[int]: + """Get list of available GPU device IDs.""" + if torch.cuda.is_available(): + return list(range(torch.cuda.device_count())) + return [] + + +def get_gpu_memory_map() -> Dict[int, float]: + """Get available memory (in GB) for each GPU.""" + memory_map = {} + for gpu_id in get_available_gpus(): + free, total = torch.cuda.mem_get_info(gpu_id) + memory_map[gpu_id] = free / (1024 ** 3) + return memory_map + + class PipelineUnit: def __init__( self, @@ -155,7 +171,10 @@ def load_models_to_device(self, model_names): for module in model.modules(): if hasattr(module, "offload"): module.offload() - getattr(torch, self.device_type).empty_cache() + # Clear cache if available (only CUDA has empty_cache) + device_module = getattr(torch, self.device_type, None) + if device_module is not None and hasattr(device_module, "empty_cache"): + device_module.empty_cache() # onload models for name, model in self.named_children(): if name in model_names: @@ -301,6 +320,127 @@ def check_vram_management_state(self): if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"): vram_management_enabled = True return vram_management_enabled + + + def enable_multi_gpu( + self, + mode: str = "auto", + device_map: Optional[Dict[str, str]] = None, + tensor_parallel_layers: Optional[List[str]] = None, + ): + """ + Enable multi-GPU support for this pipeline. 
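+
+        Example (a minimal sketch; assumes at least two visible CUDA devices):
+
+            >>> pipe.enable_multi_gpu(mode="model")
+            >>> pipe.print_model_distribution()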
+ + Args: + mode: Parallelism mode: + - "auto": Automatically select best strategy + - "model": Distribute different models to different GPUs + - "tensor": Split large layers across GPUs (requires torchrun) + - "data": Same model on all GPUs, different data (for batch processing) + device_map: Manual device mapping, e.g., {"dit": "cuda:0", "text_encoder": "cuda:1"} + tensor_parallel_layers: Layer names to apply tensor parallelism to + + Returns: + self for method chaining + """ + try: + from ..distributed import MultiGPUPipeline, enable_multi_gpu + except ImportError: + print("Warning: Distributed module not available. Multi-GPU support disabled.") + return self + + num_gpus = len(get_available_gpus()) + if num_gpus <= 1: + print("Only one GPU available. Multi-GPU support not needed.") + return self + + print(f"Enabling multi-GPU support with {num_gpus} GPUs, mode={mode}") + + if mode == "model" and device_map is None: + # Auto-create device map for model parallel + device_map = self._auto_create_device_map() + + if device_map is not None: + self._apply_device_map(device_map) + print(f"Applied device map: {device_map}") + + return self + + + def _auto_create_device_map(self) -> Dict[str, str]: + """ + Automatically create a device map for model parallelism. + + Distributes model components across available GPUs based on their size. + """ + num_gpus = len(get_available_gpus()) + if num_gpus <= 1: + return {} + + # Get model components and their sizes + components = {} + for name, module in self.named_children(): + if module is not None and isinstance(module, torch.nn.Module): + num_params = sum(p.numel() for p in module.parameters()) + if num_params > 0: + components[name] = num_params + + if not components: + return {} + + # Sort by size (largest first) and assign to GPUs + sorted_components = sorted(components.items(), key=lambda x: -x[1]) + device_map = {} + gpu_loads = [0] * num_gpus + + for name, size in sorted_components: + # Assign to GPU with least load + min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) + device_map[name] = f"cuda:{min_gpu}" + gpu_loads[min_gpu] += size + + return device_map + + + def _apply_device_map(self, device_map: Dict[str, str]): + """ + Apply a device map to distribute models across GPUs. + + Args: + device_map: Mapping of model names to devices + """ + for name, device in device_map.items(): + if hasattr(self, name): + module = getattr(self, name) + if module is not None: + module.to(device) + print(f" Moved {name} to {device}") + + + def get_model_distribution(self) -> Dict[str, str]: + """ + Get the current distribution of models across devices. 
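+
+        Example (illustrative output on two GPUs; actual keys depend on the loaded pipeline):
+
+            >>> pipe.get_model_distribution()
+            {'dit': 'cuda:0', 'text_encoder': 'cuda:1', 'vae': 'cuda:1'}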
+ + Returns: + Dictionary mapping model names to their current devices + """ + distribution = {} + for name, module in self.named_children(): + if module is not None and isinstance(module, torch.nn.Module): + try: + device = next(module.parameters()).device + distribution[name] = str(device) + except StopIteration: + distribution[name] = "no parameters" + return distribution + + + def print_model_distribution(self): + """Print the current distribution of models across devices.""" + distribution = self.get_model_distribution() + print("Model distribution:") + for name, device in distribution.items(): + print(f" {name}: {device}") def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): diff --git a/diffsynth/distributed/__init__.py b/diffsynth/distributed/__init__.py new file mode 100644 index 00000000..3b581c0d --- /dev/null +++ b/diffsynth/distributed/__init__.py @@ -0,0 +1,70 @@ +""" +Distributed computing support for DiffSynth-Studio. + +Provides multi-GPU support through: +- Data Parallel (DP): Batch-level parallelism +- Tensor Parallel (TP): Layer-level parallelism +- Pipeline utilities for distributed inference +""" + +from .parallel import ( + init_distributed, + cleanup_distributed, + get_rank, + get_world_size, + is_distributed, + is_main_process, + barrier, + broadcast, + all_reduce, + all_gather, +) + +from .tensor_parallel import ( + TensorParallelLinear, + ColumnParallelLinear, + RowParallelLinear, + split_tensor_along_dim, + gather_tensor_along_dim, +) + +from .data_parallel import ( + DataParallelPipeline, + scatter_batch, + gather_outputs, +) + +from .multi_gpu import ( + MultiGPUPipeline, + auto_device_map, + get_optimal_device_map, +) + +__all__ = [ + # Initialization + "init_distributed", + "cleanup_distributed", + "get_rank", + "get_world_size", + "is_distributed", + "is_main_process", + # Communication + "barrier", + "broadcast", + "all_reduce", + "all_gather", + # Tensor Parallel + "TensorParallelLinear", + "ColumnParallelLinear", + "RowParallelLinear", + "split_tensor_along_dim", + "gather_tensor_along_dim", + # Data Parallel + "DataParallelPipeline", + "scatter_batch", + "gather_outputs", + # Multi-GPU + "MultiGPUPipeline", + "auto_device_map", + "get_optimal_device_map", +] diff --git a/diffsynth/distributed/data_parallel.py b/diffsynth/distributed/data_parallel.py new file mode 100644 index 00000000..07ddb740 --- /dev/null +++ b/diffsynth/distributed/data_parallel.py @@ -0,0 +1,305 @@ +""" +Data Parallelism utilities for batch-level parallel processing. + +Enables processing multiple inputs across GPUs simultaneously. +""" + +import torch +import torch.nn as nn +import torch.distributed as dist +from typing import List, Optional, Dict, Any, Union, Tuple +from PIL import Image +import numpy as np + +from .parallel import ( + get_rank, + get_world_size, + is_distributed, + is_main_process, + barrier, + all_gather, +) + + +def scatter_batch( + batch: Union[torch.Tensor, List, Dict], + dim: int = 0, +) -> Union[torch.Tensor, List, Dict]: + """ + Scatter a batch across all ranks. 
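+
+    Example (sketch; with world_size=2, rank 0 receives the first half):
+
+        >>> scatter_batch(["a", "b", "c", "d"])
+        ['a', 'b']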
+ + Args: + batch: Input batch (tensor, list, or dict) + dim: Dimension to scatter along for tensors + + Returns: + Local portion of the batch for this rank + """ + if not is_distributed(): + return batch + + rank = get_rank() + world_size = get_world_size() + + if isinstance(batch, torch.Tensor): + # Split tensor + chunks = batch.chunk(world_size, dim=dim) + if rank < len(chunks): + return chunks[rank] + else: + # Edge case: more ranks than batch items + return chunks[-1][:0] # Empty tensor + + elif isinstance(batch, (list, tuple)): + # Split list + batch_size = len(batch) + per_rank = batch_size // world_size + remainder = batch_size % world_size + + start = rank * per_rank + min(rank, remainder) + end = start + per_rank + (1 if rank < remainder else 0) + + result = batch[start:end] + return type(batch)(result) if isinstance(batch, tuple) else result + + elif isinstance(batch, dict): + # Scatter each value + return {k: scatter_batch(v, dim) for k, v in batch.items()} + + else: + # Cannot scatter, return as-is + return batch + + +def gather_outputs( + output: Union[torch.Tensor, List, Dict], + dim: int = 0, + dst: int = 0, +) -> Union[torch.Tensor, List, Dict]: + """ + Gather outputs from all ranks. + + Args: + output: Local output from this rank + dim: Dimension to gather along for tensors + dst: Destination rank (only this rank gets the full output) + + Returns: + Gathered output on dst rank, original on others + """ + if not is_distributed(): + return output + + rank = get_rank() + world_size = get_world_size() + + if isinstance(output, torch.Tensor): + # Gather tensor + gathered = all_gather(output) + if rank == dst: + return torch.cat(gathered, dim=dim) + return output + + elif isinstance(output, list): + # Gather lists + gathered = all_gather_object(output) + if rank == dst: + result = [] + for sublist in gathered: + result.extend(sublist) + return result + return output + + elif isinstance(output, dict): + # Gather each value + if rank == dst: + return {k: gather_outputs(v, dim, dst) for k, v in output.items()} + return output + + elif isinstance(output, Image.Image): + # Convert PIL to tensor, gather, convert back + output_tensor = torch.from_numpy(np.array(output)) + gathered = all_gather(output_tensor) + if rank == dst: + return [Image.fromarray(t.numpy()) for t in gathered] + return output + + else: + # Cannot gather, return as-is + return output + + +def all_gather_object(obj: Any) -> List[Any]: + """Gather arbitrary Python objects from all ranks.""" + if not is_distributed(): + return [obj] + + world_size = get_world_size() + output = [None] * world_size + dist.all_gather_object(output, obj) + return output + + +class DataParallelPipeline: + """ + Wrapper to enable data parallelism for diffusion pipelines. + + Distributes batch processing across multiple GPUs. + """ + + def __init__( + self, + pipeline: nn.Module, + gather_on_main: bool = True, + ): + """ + Args: + pipeline: The diffusion pipeline to parallelize + gather_on_main: Whether to gather results on main process + """ + self.pipeline = pipeline + self.gather_on_main = gather_on_main + self.rank = get_rank() + self.world_size = get_world_size() + + def __call__( + self, + batch_size: int = 1, + **kwargs, + ) -> Union[List[Image.Image], torch.Tensor]: + """ + Run pipeline with data parallel batch processing. + + Args: + batch_size: Total batch size across all ranks + **kwargs: Pipeline arguments (prompt, negative_prompt, etc.) 
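+
+        Example (sketch; assumes a ``torchrun`` launch and a pipeline that accepts a list of prompts):
+
+            >>> dp_pipe = DataParallelPipeline(pipe)
+            >>> images = dp_pipe(batch_size=4, prompt=["a cat", "a dog", "a bird", "a fish"])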
+ + Returns: + Generated images (gathered on main process if gather_on_main=True) + """ + # Calculate local batch size + local_batch_size = batch_size // self.world_size + remainder = batch_size % self.world_size + if self.rank < remainder: + local_batch_size += 1 + + if local_batch_size == 0: + # This rank has no work + barrier() + if self.gather_on_main and not is_main_process(): + return [] + return [] + + # Scatter prompt if it's a list + prompt = kwargs.get("prompt", "") + if isinstance(prompt, list): + kwargs["prompt"] = scatter_batch(prompt) + + negative_prompt = kwargs.get("negative_prompt", "") + if isinstance(negative_prompt, list): + kwargs["negative_prompt"] = scatter_batch(negative_prompt) + + # Scatter seed if provided as list + seed = kwargs.get("seed", None) + if isinstance(seed, list): + kwargs["seed"] = scatter_batch(seed) + + # Run local inference + local_output = self.pipeline(**kwargs) + + # Gather results + if self.gather_on_main: + output = gather_outputs(local_output, dst=0) + barrier() + return output if is_main_process() else local_output + else: + barrier() + return local_output + + def __getattr__(self, name: str): + """Delegate attribute access to wrapped pipeline.""" + if name in ("pipeline", "gather_on_main", "rank", "world_size"): + return object.__getattribute__(self, name) + return getattr(self.pipeline, name) + + def to(self, device): + """Move pipeline to device.""" + self.pipeline.to(device) + return self + + +class BatchDistributor: + """ + Utility class for distributing work across GPUs. + + Handles batch creation and result collection for multi-GPU inference. + """ + + def __init__(self, world_size: Optional[int] = None): + self.world_size = world_size or get_world_size() + self.rank = get_rank() + + def distribute_prompts( + self, + prompts: List[str], + ) -> Tuple[List[str], List[int]]: + """ + Distribute prompts across ranks. + + Args: + prompts: List of prompts + + Returns: + Tuple of (local_prompts, original_indices) + """ + indices = list(range(len(prompts))) + local_prompts = scatter_batch(prompts) + local_indices = scatter_batch(indices) + return local_prompts, local_indices + + def collect_results( + self, + local_results: List[Any], + local_indices: List[int], + ) -> List[Any]: + """ + Collect results from all ranks and reorder. + + Args: + local_results: Results from this rank + local_indices: Original indices of local results + + Returns: + All results in original order + """ + # Gather results and indices + all_results = all_gather_object(local_results) + all_indices = all_gather_object(local_indices) + + if not is_main_process(): + return local_results + + # Flatten and reorder + flat_results = [] + flat_indices = [] + for results, indices in zip(all_results, all_indices): + flat_results.extend(results) + flat_indices.extend(indices) + + # Sort by original index + sorted_pairs = sorted(zip(flat_indices, flat_results), key=lambda x: x[0]) + return [result for _, result in sorted_pairs] + + def get_local_batch_size(self, total_batch_size: int) -> int: + """ + Calculate local batch size for this rank. 
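+
+        Example (sketch; world_size=4, so 10 items split as 3, 3, 2, 2):
+
+            >>> distributor.get_local_batch_size(10)  # on rank 0
+            3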
+ + Args: + total_batch_size: Total batch size across all ranks + + Returns: + Batch size for this rank + """ + base = total_batch_size // self.world_size + remainder = total_batch_size % self.world_size + return base + (1 if self.rank < remainder else 0) diff --git a/diffsynth/distributed/multi_gpu.py b/diffsynth/distributed/multi_gpu.py new file mode 100644 index 00000000..7bac82ba --- /dev/null +++ b/diffsynth/distributed/multi_gpu.py @@ -0,0 +1,436 @@ +""" +Multi-GPU Pipeline Support. + +Provides utilities for distributing models across multiple GPUs +and running inference efficiently. +""" + +import torch +import torch.nn as nn +from typing import Optional, Dict, List, Any, Union, Tuple +from dataclasses import dataclass +import math + +from .parallel import ( + init_distributed, + get_rank, + get_world_size, + is_distributed, + is_main_process, + barrier, + get_device, +) +from .tensor_parallel import apply_tensor_parallelism + + +@dataclass +class DeviceMapEntry: + """Entry in a device map.""" + device: str + dtype: torch.dtype = torch.bfloat16 + offload: bool = False + + +def get_gpu_memory_info() -> List[Tuple[int, int, int]]: + """ + Get memory info for all available GPUs. + + Returns: + List of (device_id, free_memory_bytes, total_memory_bytes) + """ + if not torch.cuda.is_available(): + return [] + + info = [] + for i in range(torch.cuda.device_count()): + free, total = torch.cuda.mem_get_info(i) + info.append((i, free, total)) + return info + + +def estimate_model_size(model: nn.Module, dtype: torch.dtype = torch.bfloat16) -> int: + """ + Estimate the memory size of a model in bytes. + + Args: + model: PyTorch model + dtype: Data type for estimation + + Returns: + Estimated size in bytes + """ + dtype_size = { + torch.float32: 4, + torch.float16: 2, + torch.bfloat16: 2, + torch.int8: 1, + }.get(dtype, 4) + + total_params = sum(p.numel() for p in model.parameters()) + return total_params * dtype_size + + +def auto_device_map( + models: Dict[str, nn.Module], + max_memory_per_gpu: Optional[Dict[int, int]] = None, + dtype: torch.dtype = torch.bfloat16, +) -> Dict[str, str]: + """ + Automatically create a device map for models. + + Distributes models across available GPUs based on their size + and available memory. 
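+
+    Example (illustrative; two GPUs, with the DiT being the largest component):
+
+        >>> auto_device_map({"dit": dit, "text_encoder": text_encoder, "vae": vae})
+        {'dit': 'cuda:0', 'text_encoder': 'cuda:1', 'vae': 'cuda:1'}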
+ + Args: + models: Dictionary of model name -> model + max_memory_per_gpu: Maximum memory to use per GPU (bytes) + dtype: Data type for memory estimation + + Returns: + Dictionary of model name -> device string + """ + if not torch.cuda.is_available(): + return {name: "cpu" for name in models} + + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + return {name: "cpu" for name in models} + + if num_gpus == 1: + return {name: "cuda:0" for name in models} + + # Get available memory per GPU + gpu_memory = get_gpu_memory_info() + + if max_memory_per_gpu is None: + max_memory_per_gpu = {i: int(free * 0.9) for i, free, _ in gpu_memory} + + # Estimate model sizes + model_sizes = {name: estimate_model_size(model, dtype) for name, model in models.items()} + + # Sort models by size (largest first) + sorted_models = sorted(model_sizes.items(), key=lambda x: -x[1]) + + # Greedy assignment + device_map = {} + gpu_usage = {i: 0 for i in range(num_gpus)} + + for name, size in sorted_models: + # Find GPU with most available space + best_gpu = min( + range(num_gpus), + key=lambda i: gpu_usage[i] if gpu_usage[i] + size <= max_memory_per_gpu.get(i, float('inf')) else float('inf') + ) + + if gpu_usage[best_gpu] + size <= max_memory_per_gpu.get(best_gpu, float('inf')): + device_map[name] = f"cuda:{best_gpu}" + gpu_usage[best_gpu] += size + else: + # No GPU has enough space, use CPU offload + device_map[name] = "cpu" + + return device_map + + +def get_optimal_device_map( + model_configs: List[Dict[str, Any]], + strategy: str = "balanced", +) -> Dict[str, str]: + """ + Get optimal device mapping for model configurations. + + Args: + model_configs: List of model configurations with estimated sizes + strategy: Distribution strategy ("balanced", "sequential", "largest_first") + + Returns: + Device map dictionary + """ + if not torch.cuda.is_available(): + return {cfg.get("name", f"model_{i}"): "cpu" for i, cfg in enumerate(model_configs)} + + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + return {cfg.get("name", f"model_{i}"): "cpu" for i, cfg in enumerate(model_configs)} + + device_map = {} + + if strategy == "balanced": + # Distribute models evenly across GPUs + for i, cfg in enumerate(model_configs): + name = cfg.get("name", f"model_{i}") + device_map[name] = f"cuda:{i % num_gpus}" + + elif strategy == "sequential": + # Fill GPUs sequentially + gpu_idx = 0 + gpu_memory = get_gpu_memory_info() + current_usage = 0 + + for i, cfg in enumerate(model_configs): + name = cfg.get("name", f"model_{i}") + size = cfg.get("size_bytes", 0) + + _, free, _ = gpu_memory[gpu_idx] + if current_usage + size > free * 0.9: + gpu_idx = min(gpu_idx + 1, num_gpus - 1) + current_usage = 0 + + device_map[name] = f"cuda:{gpu_idx}" + current_usage += size + + elif strategy == "largest_first": + # Sort by size and assign largest to first GPU + sorted_configs = sorted( + enumerate(model_configs), + key=lambda x: x[1].get("size_bytes", 0), + reverse=True + ) + + gpu_usage = [0] * num_gpus + + for i, cfg in sorted_configs: + name = cfg.get("name", f"model_{i}") + size = cfg.get("size_bytes", 0) + + # Assign to GPU with least usage + best_gpu = min(range(num_gpus), key=lambda g: gpu_usage[g]) + device_map[name] = f"cuda:{best_gpu}" + gpu_usage[best_gpu] += size + + return device_map + + +class MultiGPUPipeline: + """ + Pipeline wrapper that distributes model components across multiple GPUs. 
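+
+    Example (a minimal sketch; ``pipe`` is an already-loaded diffusion pipeline):
+
+        >>> mgpu_pipe = MultiGPUPipeline(pipe, parallel_mode="model")
+        >>> image = mgpu_pipe(prompt="a beautiful sunset", num_inference_steps=25)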
+ + Supports: + - Model parallel: Different models on different GPUs + - Tensor parallel: Single model split across GPUs + - Data parallel: Same model on all GPUs processing different data + """ + + def __init__( + self, + pipeline: nn.Module, + parallel_mode: str = "model", + tensor_parallel_layers: Optional[List[str]] = None, + device_map: Optional[Dict[str, str]] = None, + ): + """ + Args: + pipeline: The diffusion pipeline + parallel_mode: Parallelism mode ("model", "tensor", "data", "hybrid") + tensor_parallel_layers: Layers to apply tensor parallelism to + device_map: Manual device mapping for model parallel + """ + self.pipeline = pipeline + self.parallel_mode = parallel_mode + self.tensor_parallel_layers = tensor_parallel_layers + self.device_map = device_map + + self._setup_parallelism() + + def _setup_parallelism(self): + """Set up the parallelism strategy.""" + num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + + if num_gpus <= 1: + # Single GPU or CPU - no parallelism needed + self.parallel_mode = "none" + return + + if self.parallel_mode == "model": + self._setup_model_parallel() + elif self.parallel_mode == "tensor": + self._setup_tensor_parallel() + elif self.parallel_mode == "data": + self._setup_data_parallel() + elif self.parallel_mode == "hybrid": + self._setup_hybrid_parallel() + + def _setup_model_parallel(self): + """Set up model parallelism - different models on different GPUs.""" + # Get all model components + models = {} + for name, module in self.pipeline.named_children(): + if isinstance(module, nn.Module) and sum(p.numel() for p in module.parameters()) > 0: + models[name] = module + + # Create device map if not provided + if self.device_map is None: + self.device_map = auto_device_map(models) + + # Move models to assigned devices + for name, device in self.device_map.items(): + if hasattr(self.pipeline, name): + module = getattr(self.pipeline, name) + if module is not None: + module.to(device) + + def _setup_tensor_parallel(self): + """Set up tensor parallelism - layers split across GPUs.""" + # Initialize distributed + init_distributed() + + # Apply tensor parallelism to specified layers + for name, module in self.pipeline.named_children(): + if isinstance(module, nn.Module): + apply_tensor_parallelism( + module, + tp_layers=self.tensor_parallel_layers, + ) + + def _setup_data_parallel(self): + """Set up data parallelism - same model on all GPUs.""" + from .data_parallel import DataParallelPipeline + + # Wrap pipeline in data parallel wrapper + self._data_parallel = DataParallelPipeline(self.pipeline) + + def _setup_hybrid_parallel(self): + """Set up hybrid parallelism - combination of strategies.""" + # Use model parallel for large components + self._setup_model_parallel() + + # Use tensor parallel for transformer blocks + if self.tensor_parallel_layers: + self._setup_tensor_parallel() + + def __call__(self, **kwargs): + """Run the pipeline.""" + if self.parallel_mode == "data": + return self._data_parallel(**kwargs) + else: + return self.pipeline(**kwargs) + + def __getattr__(self, name: str): + """Delegate attribute access to wrapped pipeline.""" + if name in ("pipeline", "parallel_mode", "tensor_parallel_layers", "device_map", "_data_parallel"): + return object.__getattribute__(self, name) + return getattr(self.pipeline, name) + + def to(self, device): + """Move pipeline to device (respects device map if set).""" + if self.device_map: + # Already distributed, ignore + return self + self.pipeline.to(device) + return self + + @property 
+ def device(self): + """Get the primary device of the pipeline.""" + if self.device_map: + # Return the device of the main model (usually the transformer) + for key in ["dit", "transformer", "unet"]: + if key in self.device_map: + return torch.device(self.device_map[key]) + # Return first device + return torch.device(list(self.device_map.values())[0]) + return get_device() + + +def enable_multi_gpu( + pipeline: nn.Module, + mode: str = "auto", + tensor_parallel_layers: Optional[List[str]] = None, +) -> MultiGPUPipeline: + """ + Enable multi-GPU support for a pipeline. + + Args: + pipeline: The diffusion pipeline + mode: Parallelism mode ("auto", "model", "tensor", "data", "hybrid") + tensor_parallel_layers: Layers to apply tensor parallelism to + + Returns: + MultiGPUPipeline wrapper + """ + num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + + if num_gpus <= 1: + # No multi-GPU needed + return MultiGPUPipeline(pipeline, parallel_mode="none") + + if mode == "auto": + # Auto-select based on model size and GPU count + total_params = sum(p.numel() for p in pipeline.parameters()) + + if total_params > 10e9: # > 10B parameters + # Large model - use tensor or model parallel + mode = "tensor" if num_gpus >= 4 else "model" + else: + # Smaller model - use data parallel for throughput + mode = "data" + + return MultiGPUPipeline( + pipeline, + parallel_mode=mode, + tensor_parallel_layers=tensor_parallel_layers, + ) + + +def setup_multi_gpu_training( + pipeline: nn.Module, + use_ddp: bool = True, + use_fsdp: bool = False, + mixed_precision: bool = True, +) -> nn.Module: + """ + Set up multi-GPU training for a pipeline. + + Args: + pipeline: The model/pipeline to train + use_ddp: Whether to use DistributedDataParallel + use_fsdp: Whether to use FullyShardedDataParallel + mixed_precision: Whether to use mixed precision + + Returns: + Wrapped model ready for distributed training + """ + import torch.distributed as dist + from torch.nn.parallel import DistributedDataParallel as DDP + + if not dist.is_initialized(): + init_distributed() + + rank = get_rank() + device = torch.device(f"cuda:{rank}") + + # Move to device + pipeline = pipeline.to(device) + + if use_fsdp: + try: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import MixedPrecision + + mp_policy = None + if mixed_precision: + mp_policy = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ) + + pipeline = FSDP( + pipeline, + mixed_precision=mp_policy, + device_id=rank, + ) + except ImportError: + print("FSDP not available, falling back to DDP") + use_ddp = True + use_fsdp = False + + if use_ddp and not use_fsdp: + pipeline = DDP( + pipeline, + device_ids=[rank], + output_device=rank, + find_unused_parameters=True, + ) + + return pipeline diff --git a/diffsynth/distributed/parallel.py b/diffsynth/distributed/parallel.py new file mode 100644 index 00000000..b80d5101 --- /dev/null +++ b/diffsynth/distributed/parallel.py @@ -0,0 +1,218 @@ +""" +Basic distributed utilities for multi-GPU support. + +Provides initialization and communication primitives. 
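+
+Typical usage (a sketch; processes are expected to be launched via ``torchrun``):
+
+    init_distributed()
+    if is_main_process():
+        print(f"world size: {get_world_size()}")
+    barrier()
+    cleanup_distributed()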
+""" + +import os +import torch +import torch.distributed as dist +from typing import Optional, List, Any + + +# Global state +_DISTRIBUTED_INITIALIZED = False +_RANK = 0 +_WORLD_SIZE = 1 +_LOCAL_RANK = 0 +_DEVICE = None + + +def get_backend_for_device(device_type: str) -> str: + """Get the appropriate distributed backend for the device type.""" + if device_type == "cuda": + return "nccl" + elif device_type == "npu": + return "hccl" + else: + return "gloo" + + +def init_distributed( + backend: Optional[str] = None, + init_method: str = "env://", + world_size: Optional[int] = None, + rank: Optional[int] = None, + local_rank: Optional[int] = None, + device_type: str = "cuda", +) -> bool: + """ + Initialize distributed process group. + + Args: + backend: Distributed backend ("nccl", "gloo", "hccl"). Auto-detected if None. + init_method: URL for process group initialization. + world_size: Total number of processes. Read from env if None. + rank: Global rank of this process. Read from env if None. + local_rank: Local rank on this node. Read from env if None. + device_type: Device type for computation ("cuda", "npu", "cpu"). + + Returns: + True if distributed is successfully initialized, False otherwise. + """ + global _DISTRIBUTED_INITIALIZED, _RANK, _WORLD_SIZE, _LOCAL_RANK, _DEVICE + + if _DISTRIBUTED_INITIALIZED: + return True + + # Check if we're in a distributed environment + if world_size is None: + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if rank is None: + rank = int(os.environ.get("RANK", 0)) + if local_rank is None: + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + + # Single GPU - no distributed needed + if world_size == 1: + _RANK = 0 + _WORLD_SIZE = 1 + _LOCAL_RANK = 0 + if device_type == "cuda" and torch.cuda.is_available(): + _DEVICE = torch.device("cuda:0") + elif device_type == "mps" and torch.backends.mps.is_available(): + _DEVICE = torch.device("mps") + else: + _DEVICE = torch.device("cpu") + return False + + # Multi-GPU - initialize distributed + if backend is None: + backend = get_backend_for_device(device_type) + + if not dist.is_initialized(): + dist.init_process_group( + backend=backend, + init_method=init_method, + world_size=world_size, + rank=rank, + ) + + _RANK = rank + _WORLD_SIZE = world_size + _LOCAL_RANK = local_rank + _DISTRIBUTED_INITIALIZED = True + + # Set device for this process + if device_type == "cuda": + torch.cuda.set_device(local_rank) + _DEVICE = torch.device(f"cuda:{local_rank}") + elif device_type == "npu": + import torch_npu + torch.npu.set_device(local_rank) + _DEVICE = torch.device(f"npu:{local_rank}") + else: + _DEVICE = torch.device("cpu") + + return True + + +def cleanup_distributed(): + """Clean up distributed process group.""" + global _DISTRIBUTED_INITIALIZED + if dist.is_initialized(): + dist.destroy_process_group() + _DISTRIBUTED_INITIALIZED = False + + +def get_rank() -> int: + """Get the global rank of this process.""" + if dist.is_initialized(): + return dist.get_rank() + return _RANK + + +def get_world_size() -> int: + """Get the total number of processes.""" + if dist.is_initialized(): + return dist.get_world_size() + return _WORLD_SIZE + + +def get_local_rank() -> int: + """Get the local rank on this node.""" + return _LOCAL_RANK + + +def get_device() -> torch.device: + """Get the device for this process.""" + return _DEVICE + + +def is_distributed() -> bool: + """Check if running in distributed mode.""" + return _DISTRIBUTED_INITIALIZED and _WORLD_SIZE > 1 + + +def is_main_process() -> bool: + """Check if this is the 
main process (rank 0).""" + return get_rank() == 0 + + +def barrier(): + """Synchronize all processes.""" + if is_distributed(): + dist.barrier() + + +def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast tensor from source rank to all ranks.""" + if is_distributed(): + dist.broadcast(tensor, src=src) + return tensor + + +def all_reduce( + tensor: torch.Tensor, + op: dist.ReduceOp = dist.ReduceOp.SUM, +) -> torch.Tensor: + """Reduce tensor across all ranks.""" + if is_distributed(): + dist.all_reduce(tensor, op=op) + return tensor + + +def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]: + """Gather tensors from all ranks.""" + if not is_distributed(): + return [tensor] + + world_size = get_world_size() + gathered = [torch.zeros_like(tensor) for _ in range(world_size)] + dist.all_gather(gathered, tensor) + return gathered + + +def all_gather_into_tensor(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor: + """Gather tensors from all ranks and concatenate along dim.""" + if not is_distributed(): + return tensor + + gathered = all_gather(tensor) + return torch.cat(gathered, dim=dim) + + +def reduce_scatter(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor: + """Reduce and scatter tensor across ranks.""" + if not is_distributed(): + return tensor + + world_size = get_world_size() + rank = get_rank() + + # Split input tensor + chunks = tensor.chunk(world_size, dim=dim) + + # Create output tensor + output = torch.zeros_like(chunks[rank]) + + # Reduce-scatter + dist.reduce_scatter(output, list(chunks)) + + return output + + +def print_rank0(*args, **kwargs): + """Print only on rank 0.""" + if is_main_process(): + print(*args, **kwargs) diff --git a/diffsynth/distributed/tensor_parallel.py b/diffsynth/distributed/tensor_parallel.py new file mode 100644 index 00000000..8ab73735 --- /dev/null +++ b/diffsynth/distributed/tensor_parallel.py @@ -0,0 +1,516 @@ +""" +Tensor Parallelism for distributing large layers across GPUs. + +Implements column-parallel and row-parallel linear layers following +Megatron-LM style tensor parallelism. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from typing import Optional, Tuple + +from .parallel import get_rank, get_world_size, is_distributed + + +def split_tensor_along_dim( + tensor: torch.Tensor, + dim: int, + world_size: Optional[int] = None, + rank: Optional[int] = None, +) -> torch.Tensor: + """ + Split a tensor along a dimension and return the chunk for this rank. + + Args: + tensor: Input tensor to split + dim: Dimension to split along + world_size: Total number of ranks (auto-detected if None) + rank: This process rank (auto-detected if None) + + Returns: + The chunk of tensor for this rank + """ + if world_size is None: + world_size = get_world_size() + if rank is None: + rank = get_rank() + + if world_size == 1: + return tensor + + # Ensure tensor is divisible + size = tensor.size(dim) + assert size % world_size == 0, f"Tensor size {size} not divisible by world_size {world_size}" + + return tensor.chunk(world_size, dim=dim)[rank].contiguous() + + +def gather_tensor_along_dim( + tensor: torch.Tensor, + dim: int, + world_size: Optional[int] = None, +) -> torch.Tensor: + """ + Gather tensor chunks from all ranks along a dimension. 
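+
+    Example (sketch; world_size=2 and ``local_chunk`` is this rank's [4, 8] slice):
+
+        >>> full = gather_tensor_along_dim(local_chunk, dim=-1)  # shape [4, 16]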
+ + Args: + tensor: Local tensor chunk + dim: Dimension to gather along + world_size: Total number of ranks (auto-detected if None) + + Returns: + Full gathered tensor + """ + if world_size is None: + world_size = get_world_size() + + if world_size == 1: + return tensor + + # Prepare gather list + gathered = [torch.zeros_like(tensor) for _ in range(world_size)] + dist.all_gather(gathered, tensor) + + return torch.cat(gathered, dim=dim) + + +class _CopyToModelParallelRegion(torch.autograd.Function): + """Copy input to model parallel region (identity forward, all-reduce backward).""" + + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + if is_distributed(): + dist.all_reduce(grad_output) + return grad_output + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """Reduce output from model parallel region (all-reduce forward, identity backward).""" + + @staticmethod + def forward(ctx, input_): + if is_distributed(): + dist.all_reduce(input_) + return input_ + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _GatherFromModelParallelRegion(torch.autograd.Function): + """Gather output from model parallel region.""" + + @staticmethod + def forward(ctx, input_, dim): + ctx.dim = dim + if not is_distributed(): + return input_ + return gather_tensor_along_dim(input_, dim) + + @staticmethod + def backward(ctx, grad_output): + if not is_distributed(): + return grad_output, None + return split_tensor_along_dim(grad_output, ctx.dim), None + + +class _ScatterToModelParallelRegion(torch.autograd.Function): + """Scatter input to model parallel region.""" + + @staticmethod + def forward(ctx, input_, dim): + ctx.dim = dim + if not is_distributed(): + return input_ + return split_tensor_along_dim(input_, dim) + + @staticmethod + def backward(ctx, grad_output): + if not is_distributed(): + return grad_output, None + return gather_tensor_along_dim(grad_output, ctx.dim), None + + +def copy_to_tensor_parallel_region(input_: torch.Tensor) -> torch.Tensor: + """Copy input to tensor parallel region.""" + return _CopyToModelParallelRegion.apply(input_) + + +def reduce_from_tensor_parallel_region(input_: torch.Tensor) -> torch.Tensor: + """All-reduce input from tensor parallel region.""" + return _ReduceFromModelParallelRegion.apply(input_) + + +def gather_from_tensor_parallel_region(input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + """Gather input from tensor parallel region.""" + return _GatherFromModelParallelRegion.apply(input_, dim) + + +def scatter_to_tensor_parallel_region(input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + """Scatter input to tensor parallel region.""" + return _ScatterToModelParallelRegion.apply(input_, dim) + + +class ColumnParallelLinear(nn.Module): + """ + Linear layer with column parallelism. + + The weight matrix is split along the output dimension (columns). + Y = XA where A is split column-wise across GPUs. + + Each GPU computes a portion of the output features. 
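+
+    Example (sketch; out_features must be divisible by the world size):
+
+        >>> tp_fc = ColumnParallelLinear.from_linear(nn.Linear(1024, 4096))
+        >>> y = tp_fc(x)  # [..., 4096], since from_linear keeps gather_output=True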
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + gather_output: bool = True, + init_method: Optional[callable] = None, + ): + """ + Args: + in_features: Input feature dimension + out_features: Total output feature dimension + bias: Whether to use bias + gather_output: Whether to gather output from all ranks + init_method: Weight initialization function + """ + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + + world_size = get_world_size() + assert out_features % world_size == 0, \ + f"out_features ({out_features}) must be divisible by world_size ({world_size})" + + self.out_features_per_partition = out_features // world_size + + # Create local weight + self.weight = nn.Parameter( + torch.empty(self.out_features_per_partition, in_features) + ) + + if bias: + self.bias = nn.Parameter(torch.empty(self.out_features_per_partition)) + else: + self.register_parameter("bias", None) + + # Initialize + self.reset_parameters(init_method) + + def reset_parameters(self, init_method: Optional[callable] = None): + """Initialize weights.""" + if init_method is not None: + init_method(self.weight) + else: + nn.init.kaiming_uniform_(self.weight, a=5**0.5) + + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + """ + Forward pass. + + Args: + input_: Input tensor of shape [..., in_features] + + Returns: + Output tensor of shape [..., out_features] if gather_output, + else [..., out_features_per_partition] + """ + # Copy input to parallel region + input_parallel = copy_to_tensor_parallel_region(input_) + + # Local linear + output_parallel = F.linear(input_parallel, self.weight, self.bias) + + # Gather if needed + if self.gather_output: + output = gather_from_tensor_parallel_region(output_parallel, dim=-1) + else: + output = output_parallel + + return output + + @classmethod + def from_linear( + cls, + linear: nn.Linear, + gather_output: bool = True, + ) -> "ColumnParallelLinear": + """ + Create ColumnParallelLinear from existing nn.Linear. + + Splits the weight matrix along output dimension. + """ + layer = cls( + in_features=linear.in_features, + out_features=linear.out_features, + bias=linear.bias is not None, + gather_output=gather_output, + ) + + # Split weights + rank = get_rank() + world_size = get_world_size() + weight_chunks = linear.weight.data.chunk(world_size, dim=0) + layer.weight.data.copy_(weight_chunks[rank]) + + if linear.bias is not None: + bias_chunks = linear.bias.data.chunk(world_size, dim=0) + layer.bias.data.copy_(bias_chunks[rank]) + + return layer + + +class RowParallelLinear(nn.Module): + """ + Linear layer with row parallelism. + + The weight matrix is split along the input dimension (rows). + Y = XA where A is split row-wise across GPUs. + + Each GPU receives a portion of the input features. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + input_is_parallel: bool = False, + init_method: Optional[callable] = None, + ): + """ + Args: + in_features: Total input feature dimension + out_features: Output feature dimension + bias: Whether to use bias + input_is_parallel: Whether input is already split across ranks + init_method: Weight initialization function + """ + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.input_is_parallel = input_is_parallel + + world_size = get_world_size() + assert in_features % world_size == 0, \ + f"in_features ({in_features}) must be divisible by world_size ({world_size})" + + self.in_features_per_partition = in_features // world_size + + # Create local weight + self.weight = nn.Parameter( + torch.empty(out_features, self.in_features_per_partition) + ) + + # Bias is NOT split - only rank 0 has full bias + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter("bias", None) + + # Initialize + self.reset_parameters(init_method) + + def reset_parameters(self, init_method: Optional[callable] = None): + """Initialize weights.""" + if init_method is not None: + init_method(self.weight) + else: + nn.init.kaiming_uniform_(self.weight, a=5**0.5) + + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + """ + Forward pass. + + Args: + input_: Input tensor of shape [..., in_features] or + [..., in_features_per_partition] if input_is_parallel + + Returns: + Output tensor of shape [..., out_features] + """ + # Split input if not already parallel + if self.input_is_parallel: + input_parallel = input_ + else: + input_parallel = scatter_to_tensor_parallel_region(input_, dim=-1) + + # Local linear (no bias yet) + output_parallel = F.linear(input_parallel, self.weight) + + # All-reduce across ranks + output = reduce_from_tensor_parallel_region(output_parallel) + + # Add bias (only on full output after reduce) + if self.bias is not None: + output = output + self.bias + + return output + + @classmethod + def from_linear( + cls, + linear: nn.Linear, + input_is_parallel: bool = False, + ) -> "RowParallelLinear": + """ + Create RowParallelLinear from existing nn.Linear. + + Splits the weight matrix along input dimension. + """ + layer = cls( + in_features=linear.in_features, + out_features=linear.out_features, + bias=linear.bias is not None, + input_is_parallel=input_is_parallel, + ) + + # Split weights + rank = get_rank() + world_size = get_world_size() + weight_chunks = linear.weight.data.chunk(world_size, dim=1) + layer.weight.data.copy_(weight_chunks[rank]) + + # Bias is not split + if linear.bias is not None: + layer.bias.data.copy_(linear.bias.data) + + return layer + + +class TensorParallelLinear(nn.Module): + """ + Wrapper that automatically chooses Column or Row parallel based on dimensions. 
+ + For expanding layers (out > in): Use ColumnParallel + For contracting layers (in > out): Use RowParallel + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + parallel_mode: Optional[str] = None, + ): + """ + Args: + in_features: Input feature dimension + out_features: Output feature dimension + bias: Whether to use bias + parallel_mode: "column", "row", or None for auto + """ + super().__init__() + + if parallel_mode is None: + # Auto-select based on dimensions + parallel_mode = "column" if out_features >= in_features else "row" + + self.parallel_mode = parallel_mode + + if parallel_mode == "column": + self.linear = ColumnParallelLinear( + in_features, out_features, bias=bias, gather_output=True + ) + else: + self.linear = RowParallelLinear( + in_features, out_features, bias=bias, input_is_parallel=False + ) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return self.linear(input_) + + @classmethod + def from_linear( + cls, + linear: nn.Linear, + parallel_mode: Optional[str] = None, + ) -> "TensorParallelLinear": + """Create TensorParallelLinear from existing nn.Linear.""" + layer = cls.__new__(cls) + nn.Module.__init__(layer) + + if parallel_mode is None: + parallel_mode = "column" if linear.out_features >= linear.in_features else "row" + + layer.parallel_mode = parallel_mode + + if parallel_mode == "column": + layer.linear = ColumnParallelLinear.from_linear(linear, gather_output=True) + else: + layer.linear = RowParallelLinear.from_linear(linear, input_is_parallel=False) + + return layer + + +def apply_tensor_parallelism( + module: nn.Module, + tp_layers: Optional[list] = None, + min_features: int = 1024, +) -> nn.Module: + """ + Apply tensor parallelism to a module by replacing Linear layers. + + Args: + module: Module to parallelize + tp_layers: List of layer names to parallelize. If None, auto-detect large layers. 
+ min_features: Minimum feature size to consider for parallelism + + Returns: + Module with tensor parallel layers + """ + world_size = get_world_size() + if world_size == 1: + return module + + def should_parallelize(name: str, layer: nn.Module) -> bool: + if not isinstance(layer, nn.Linear): + return False + if tp_layers is not None: + return any(tp_name in name for tp_name in tp_layers) + # Auto-detect: parallelize large layers + return ( + layer.in_features >= min_features or + layer.out_features >= min_features + ) and ( + layer.in_features % world_size == 0 and + layer.out_features % world_size == 0 + ) + + def replace_linear(parent: nn.Module, name: str, layer: nn.Linear): + tp_layer = TensorParallelLinear.from_linear(layer) + setattr(parent, name, tp_layer) + + # Traverse and replace + for name, child in list(module.named_modules()): + if should_parallelize(name, child): + # Find parent + parts = name.rsplit(".", 1) + if len(parts) == 1: + parent = module + child_name = name + else: + parent = module.get_submodule(parts[0]) + child_name = parts[1] + replace_linear(parent, child_name, child) + + return module diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py index be2ee587..70eec5b2 100644 --- a/diffsynth/models/dinov3_image_encoder.py +++ b/diffsynth/models/dinov3_image_encoder.py @@ -70,7 +70,10 @@ def __init__(self): } ) - def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + def forward(self, image, torch_dtype=torch.bfloat16, device=None): + # Use model's device if not specified + if device is None: + device = next(self.parameters()).device inputs = self.processor(images=image, return_tensors="pt") pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device) bool_masked_pos = None diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py index 10184f85..dbc3a0e1 100644 --- a/diffsynth/models/siglip2_image_encoder.py +++ b/diffsynth/models/siglip2_image_encoder.py @@ -47,7 +47,10 @@ def __init__(self): } ) - def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"): + def forward(self, image, torch_dtype=torch.bfloat16, device=None): + # Use model's device if not specified + if device is None: + device = next(self.parameters()).device pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"] pixel_values = pixel_values.to(device=device, dtype=torch_dtype) output_attentions = False diff --git a/examples/multi_gpu_inference.py b/examples/multi_gpu_inference.py new file mode 100644 index 00000000..03a1ee54 --- /dev/null +++ b/examples/multi_gpu_inference.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python +""" +Multi-GPU Inference Example for DiffSynth-Studio + +This script demonstrates how to use multiple GPUs for inference. 
+ +Usage: + # Model Parallel (distribute models across GPUs): + python multi_gpu_inference.py --mode model --prompt "a beautiful sunset" + + # Data Parallel (same model on all GPUs, batch processing): + torchrun --nproc_per_node=2 multi_gpu_inference.py --mode data --batch_size 4 + + # Tensor Parallel (split layers across GPUs): + torchrun --nproc_per_node=2 multi_gpu_inference.py --mode tensor +""" + +import argparse +import torch +import os +from pathlib import Path + +# Add parent directory to path +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig + + +def check_multi_gpu(): + """Check available GPUs and print info.""" + if not torch.cuda.is_available(): + print("CUDA is not available. Multi-GPU requires CUDA.") + return 0 + + num_gpus = torch.cuda.device_count() + print(f"\n{'='*60}") + print(f"Multi-GPU Configuration") + print(f"{'='*60}") + print(f"Available GPUs: {num_gpus}") + + for i in range(num_gpus): + props = torch.cuda.get_device_properties(i) + free, total = torch.cuda.mem_get_info(i) + print(f" GPU {i}: {props.name}") + print(f" Memory: {free/1024**3:.1f} GB free / {total/1024**3:.1f} GB total") + + print(f"{'='*60}\n") + return num_gpus + + +def run_model_parallel(args): + """ + Model Parallel: Distribute different model components to different GPUs. + + Best for: Large models that don't fit on a single GPU. + """ + print("Running Model Parallel inference...") + + # Load pipeline + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda:0", # Initial device + model_configs=[ + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", + ), + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="text_encoder/model*.safetensors", + ), + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="vae/diffusion_pytorch_model.safetensors", + ), + ], + tokenizer_config=ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="tokenizer/", + ), + processor_config=ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="processor/", + ), + ) + + # Enable multi-GPU with model parallelism + # This will distribute dit, text_encoder, vae across available GPUs + pipe.enable_multi_gpu(mode="model") + + # Print model distribution + pipe.print_model_distribution() + + # Generate image + image = pipe( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + width=args.width, + height=args.height, + num_inference_steps=args.steps, + ) + + # Save output + output_path = Path(args.output_dir) / "model_parallel_output.png" + output_path.parent.mkdir(parents=True, exist_ok=True) + image.save(str(output_path)) + print(f"Saved to: {output_path}") + + +def run_data_parallel(args): + """ + Data Parallel: Same model on all GPUs, process different data in parallel. + + Best for: Batch inference to maximize throughput. + Must be launched with torchrun. 
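+
+    Example launch (two GPUs):
+
+        torchrun --nproc_per_node=2 multi_gpu_inference.py --mode data --batch_size 4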
+ """ + from diffsynth.distributed import ( + init_distributed, + get_rank, + get_world_size, + is_main_process, + DataParallelPipeline, + scatter_batch, + gather_outputs, + ) + + # Initialize distributed + init_distributed() + rank = get_rank() + world_size = get_world_size() + local_device = f"cuda:{rank}" + + print(f"[Rank {rank}/{world_size}] Running Data Parallel inference on {local_device}...") + + # Load pipeline on local device + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=local_device, + model_configs=[ + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", + ), + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="text_encoder/model*.safetensors", + ), + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="vae/diffusion_pytorch_model.safetensors", + ), + ], + tokenizer_config=ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="tokenizer/", + ), + processor_config=ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="processor/", + ), + ) + + # Create batch of prompts + prompts = [f"{args.prompt} - variation {i+1}" for i in range(args.batch_size)] + + # Scatter prompts across ranks + local_prompts = scatter_batch(prompts) + print(f"[Rank {rank}] Processing {len(local_prompts)} prompts: {local_prompts}") + + # Generate images locally + local_images = [] + for prompt in local_prompts: + image = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + width=args.width, + height=args.height, + num_inference_steps=args.steps, + ) + local_images.append(image) + + # Gather results on main process + all_images = gather_outputs(local_images) + + # Save on main process + if is_main_process(): + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + for i, img in enumerate(all_images): + output_path = output_dir / f"data_parallel_output_{i}.png" + img.save(str(output_path)) + print(f"Saved to: {output_path}") + + +def run_tensor_parallel(args): + """ + Tensor Parallel: Split large layers across GPUs. + + Best for: Very large models where even single components don't fit on one GPU. + Must be launched with torchrun. 
+ """ + from diffsynth.distributed import ( + init_distributed, + get_rank, + get_world_size, + is_main_process, + apply_tensor_parallelism, + ) + + # Initialize distributed + init_distributed() + rank = get_rank() + world_size = get_world_size() + local_device = f"cuda:{rank}" + + print(f"[Rank {rank}/{world_size}] Running Tensor Parallel inference...") + + # Load pipeline + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=local_device, + model_configs=[ + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", + ), + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="text_encoder/model*.safetensors", + ), + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="vae/diffusion_pytorch_model.safetensors", + ), + ], + tokenizer_config=ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="tokenizer/", + ), + processor_config=ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="processor/", + ), + ) + + # Apply tensor parallelism to large linear layers in the transformer + if hasattr(pipe, 'dit') and pipe.dit is not None: + apply_tensor_parallelism( + pipe.dit, + tp_layers=["linear", "proj", "mlp"], # Layer name patterns to parallelize + min_features=4096, # Only parallelize layers with >= 4096 features + ) + print(f"[Rank {rank}] Applied tensor parallelism to DiT") + + # Generate image + image = pipe( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + width=args.width, + height=args.height, + num_inference_steps=args.steps, + ) + + # Save on main process + if is_main_process(): + output_path = Path(args.output_dir) / "tensor_parallel_output.png" + output_path.parent.mkdir(parents=True, exist_ok=True) + image.save(str(output_path)) + print(f"Saved to: {output_path}") + + +def run_custom_device_map(args): + """ + Custom Device Map: Manually assign models to specific GPUs. + + Useful for fine-grained control over memory distribution. 
+ """ + print("Running with custom device map...") + + # Define custom device map + device_map = { + "dit": "cuda:0", # DiT (largest) on GPU 0 + "text_encoder": "cuda:1", # Text encoder on GPU 1 + "vae": "cuda:1", # VAE on GPU 1 (smaller) + } + + # Check if we have enough GPUs + num_gpus = torch.cuda.device_count() + if num_gpus < 2: + print(f"Warning: Custom device map requires 2 GPUs, but only {num_gpus} available.") + print("Falling back to single GPU...") + device_map = {k: "cuda:0" for k in device_map} + + # Load pipeline + pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda:0", # Initial device + model_configs=[ + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", + ), + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="text_encoder/model*.safetensors", + ), + ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="vae/diffusion_pytorch_model.safetensors", + ), + ], + tokenizer_config=ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="tokenizer/", + ), + processor_config=ModelConfig( + download_source="huggingface", + model_id="Qwen/Qwen-Image", + origin_file_pattern="processor/", + ), + ) + + # Apply custom device map + pipe.enable_multi_gpu(mode="model", device_map=device_map) + + # Print distribution + pipe.print_model_distribution() + + # Generate image + image = pipe( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + width=args.width, + height=args.height, + num_inference_steps=args.steps, + ) + + # Save output + output_path = Path(args.output_dir) / "custom_device_map_output.png" + output_path.parent.mkdir(parents=True, exist_ok=True) + image.save(str(output_path)) + print(f"Saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Multi-GPU inference example") + + parser.add_argument( + "--mode", + type=str, + default="model", + choices=["model", "data", "tensor", "custom"], + help="Parallelism mode: model, data, tensor, or custom", + ) + parser.add_argument( + "--prompt", + type=str, + default="a beautiful sunset over the ocean", + help="Text prompt for image generation", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="blurry, ugly, bad quality", + help="Negative prompt", + ) + parser.add_argument( + "--width", + type=int, + default=1024, + help="Image width", + ) + parser.add_argument( + "--height", + type=int, + default=1024, + help="Image height", + ) + parser.add_argument( + "--steps", + type=int, + default=25, + help="Number of inference steps", + ) + parser.add_argument( + "--batch_size", + type=int, + default=4, + help="Batch size for data parallel mode", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Output directory", + ) + + args = parser.parse_args() + + # Check GPUs + num_gpus = check_multi_gpu() + + if num_gpus == 0: + print("No GPUs available. Exiting.") + return + + # Run selected mode + if args.mode == "model": + run_model_parallel(args) + elif args.mode == "data": + run_data_parallel(args) + elif args.mode == "tensor": + run_tensor_parallel(args) + elif args.mode == "custom": + run_custom_device_map(args) + + +if __name__ == "__main__": + main()