6 changes: 6 additions & 0 deletions diffsynth/__init__.py
@@ -1 +1,7 @@
from .core import *

# Distributed/Multi-GPU support (optional import)
try:
from . import distributed
except ImportError:
pass
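
Since the new submodule is imported inside a try/except, downstream code cannot assume it exists. A minimal sketch (not part of the diff) of how a caller might guard on the optional import:

import diffsynth

# If the optional import above failed, the package simply has no
# `distributed` attribute, so hasattr is a safe capability check.
if hasattr(diffsynth, "distributed"):
    print("Multi-GPU helpers available")
else:
    print("diffsynth.distributed unavailable; single-GPU mode only")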
6 changes: 5 additions & 1 deletion diffsynth/core/vram/layers.py
@@ -64,7 +64,11 @@ def cast_to(self, weight, dtype, device):

def check_free_vram(self):
device = self.computation_device if self.computation_device != "npu" else "npu:0"
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
device_module = getattr(torch, self.computation_device_type, None)
# Only CUDA and NPU provide mem_get_info; for MPS/CPU, assume enough memory is free
if device_module is None or not hasattr(device_module, "mem_get_info"):
return True
gpu_mem_state = device_module.mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit

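For reference, the check above compares used VRAM against self.vram_limit in gigabytes. A standalone sketch of the same arithmetic, assuming a CUDA device and a hypothetical 6 GB limit:

import torch

if torch.cuda.is_available():
    # torch.cuda.mem_get_info returns (free_bytes, total_bytes)
    free_bytes, total_bytes = torch.cuda.mem_get_info("cuda:0")
    used_gb = (total_bytes - free_bytes) / (1024 ** 3)
    vram_limit_gb = 6.0  # hypothetical; in the class this is self.vram_limit
    print(f"used {used_gb:.2f} GB, under limit: {used_gb < vram_limit_gb}")
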
144 changes: 142 additions & 2 deletions diffsynth/diffusion/base_pipeline.py
@@ -2,13 +2,29 @@
import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from typing import Union, Optional, Dict, List
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput


def get_available_gpus() -> List[int]:
"""Get list of available GPU device IDs."""
if torch.cuda.is_available():
return list(range(torch.cuda.device_count()))
return []


def get_gpu_memory_map() -> Dict[int, float]:
"""Get available memory (in GB) for each GPU."""
memory_map = {}
for gpu_id in get_available_gpus():
free, total = torch.cuda.mem_get_info(gpu_id)
memory_map[gpu_id] = free / (1024 ** 3)
return memory_map


class PipelineUnit:
def __init__(
self,
@@ -155,7 +171,10 @@ def load_models_to_device(self, model_names):
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
getattr(torch, self.device_type).empty_cache()
# Clear the device cache if the backend provides empty_cache (not all do)
device_module = getattr(torch, self.device_type, None)
if device_module is not None and hasattr(device_module, "empty_cache"):
device_module.empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
@@ -301,6 +320,127 @@ def check_vram_management_state(self):
if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
vram_management_enabled = True
return vram_management_enabled


def enable_multi_gpu(
self,
mode: str = "auto",
device_map: Optional[Dict[str, str]] = None,
tensor_parallel_layers: Optional[List[str]] = None,
):
Comment on lines +325 to +330 (Contributor, severity: high)

The enable_multi_gpu function has some inconsistencies that could confuse users:

  1. The imports of MultiGPUPipeline and enable_multi_gpu from ..distributed on line 347 are unused within this function.
  2. The function's docstring states it supports auto, model, tensor, and data modes. However, the implementation only contains logic for mode="model". Other modes will silently do nothing.

This discrepancy between documentation and behavior can lead to unexpected results. Please consider either:

  • Removing the unused imports and updating the docstring to clarify that this method only handles model parallelism.
  • Refactoring this method to correctly handle all documented modes, possibly by delegating to the more comprehensive logic in diffsynth/distributed/multi_gpu.py.

"""
Enable multi-GPU support for this pipeline.

Args:
mode: Parallelism mode:
- "auto": Automatically select best strategy
- "model": Distribute different models to different GPUs
- "tensor": Split large layers across GPUs (requires torchrun)
- "data": Same model on all GPUs, different data (for batch processing)
device_map: Manual device mapping, e.g., {"dit": "cuda:0", "text_encoder": "cuda:1"}
tensor_parallel_layers: Layer names to apply tensor parallelism to

Returns:
self for method chaining
"""
try:
from ..distributed import MultiGPUPipeline, enable_multi_gpu
except ImportError:
print("Warning: Distributed module not available. Multi-GPU support disabled.")
return self

num_gpus = len(get_available_gpus())
if num_gpus <= 1:
print("Only one GPU available. Multi-GPU support not needed.")
return self

print(f"Enabling multi-GPU support with {num_gpus} GPUs, mode={mode}")

if mode == "model" and device_map is None:
# Auto-create device map for model parallel
device_map = self._auto_create_device_map()

if device_map is not None:
self._apply_device_map(device_map)
print(f"Applied device map: {device_map}")

return self


def _auto_create_device_map(self) -> Dict[str, str]:
"""
Automatically create a device map for model parallelism.

Distributes model components across available GPUs based on their size.
"""
num_gpus = len(get_available_gpus())
if num_gpus <= 1:
return {}

# Get model components and their sizes
components = {}
for name, module in self.named_children():
if module is not None and isinstance(module, torch.nn.Module):
num_params = sum(p.numel() for p in module.parameters())
if num_params > 0:
components[name] = num_params

if not components:
return {}

# Sort by size (largest first) and assign to GPUs
sorted_components = sorted(components.items(), key=lambda x: -x[1])
device_map = {}
gpu_loads = [0] * num_gpus

for name, size in sorted_components:
# Assign to GPU with least load
min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i])
device_map[name] = f"cuda:{min_gpu}"
gpu_loads[min_gpu] += size

return device_map


def _apply_device_map(self, device_map: Dict[str, str]):
"""
Apply a device map to distribute models across GPUs.

Args:
device_map: Mapping of model names to devices
"""
for name, device in device_map.items():
if hasattr(self, name):
module = getattr(self, name)
if module is not None:
module.to(device)
print(f" Moved {name} to {device}")


def get_model_distribution(self) -> Dict[str, str]:
"""
Get the current distribution of models across devices.

Returns:
Dictionary mapping model names to their current devices
"""
distribution = {}
for name, module in self.named_children():
if module is not None and isinstance(module, torch.nn.Module):
try:
device = next(module.parameters()).device
distribution[name] = str(device)
except StopIteration:
distribution[name] = "no parameters"
return distribution


def print_model_distribution(self):
"""Print the current distribution of models across devices."""
distribution = self.get_model_distribution()
print("Model distribution:")
for name, device in distribution.items():
print(f" {name}: {device}")


def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
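The new methods are meant to be called on a constructed pipeline. A hypothetical usage sketch follows: `pipe` stands for any pipeline built on this base class, and the component names in the device map ("dit", "text_encoder", "vae") are examples rather than guaranteed attributes. Note that, per the review comment above, only mode="model" currently results in a device map being applied.

# Explicit model parallelism: pin named components to specific GPUs.
device_map = {"dit": "cuda:0", "text_encoder": "cuda:1", "vae": "cuda:1"}
pipe.enable_multi_gpu(mode="model", device_map=device_map)

# Or let _auto_create_device_map balance components by parameter count.
pipe.enable_multi_gpu(mode="model")

# Inspect where each submodule ended up.
pipe.print_model_distribution()
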
70 changes: 70 additions & 0 deletions diffsynth/distributed/__init__.py
@@ -0,0 +1,70 @@
"""
Distributed computing support for DiffSynth-Studio.

Provides multi-GPU support through:
- Data Parallel (DP): Batch-level parallelism
- Tensor Parallel (TP): Layer-level parallelism
- Pipeline utilities for distributed inference
"""

from .parallel import (
init_distributed,
cleanup_distributed,
get_rank,
get_world_size,
is_distributed,
is_main_process,
barrier,
broadcast,
all_reduce,
all_gather,
)

from .tensor_parallel import (
TensorParallelLinear,
ColumnParallelLinear,
RowParallelLinear,
split_tensor_along_dim,
gather_tensor_along_dim,
)

from .data_parallel import (
DataParallelPipeline,
scatter_batch,
gather_outputs,
)

from .multi_gpu import (
MultiGPUPipeline,
auto_device_map,
get_optimal_device_map,
)

__all__ = [
# Initialization
"init_distributed",
"cleanup_distributed",
"get_rank",
"get_world_size",
"is_distributed",
"is_main_process",
# Communication
"barrier",
"broadcast",
"all_reduce",
"all_gather",
# Tensor Parallel
"TensorParallelLinear",
"ColumnParallelLinear",
"RowParallelLinear",
"split_tensor_along_dim",
"gather_tensor_along_dim",
# Data Parallel
"DataParallelPipeline",
"scatter_batch",
"gather_outputs",
# Multi-GPU
"MultiGPUPipeline",
"auto_device_map",
"get_optimal_device_map",
]
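
A usage sketch for the new package. The exact signatures live in diffsynth/distributed/parallel.py, which this diff does not include, so the zero-argument init_distributed() call and the assumption that it reads the usual torchrun environment variables (RANK, WORLD_SIZE, LOCAL_RANK) are guesses:

from diffsynth.distributed import (
    init_distributed, cleanup_distributed, is_main_process, get_world_size,
)

init_distributed()  # assumed to set up torch.distributed from the torchrun env
try:
    if is_main_process():
        print(f"Distributed inference across {get_world_size()} processes")
    # ... run the pipeline here ...
finally:
    cleanup_distributed()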