diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cf77aaee8205..6884d3be9292 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -108,6 +108,7 @@ is_tensorboard_available, is_timm_available, is_torch_available, + is_torch_mlu_available, is_torch_npu_available, is_torch_version, is_torch_xla_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index adf8ed8b0694..57b0a337922a 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -192,6 +192,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") +_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") _kernels_available, _kernels_version = _is_package_available("kernels") @@ -243,6 +244,10 @@ def is_torch_npu_available(): return _torch_npu_available +def is_torch_mlu_available(): + return _torch_mlu_available + + def is_flax_available(): return _flax_available diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index a1ab8cda431f..f760a1bf7261 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union from . import logging -from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version +from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version if is_torch_available(): @@ -286,6 +286,8 @@ def get_device(): return "xpu" elif torch.backends.mps.is_available(): return "mps" + elif is_torch_mlu_available(): + return "mlu" else: return "cpu"