Skip to content

Commit aecf0c5

Browse files
Add MLU Support. (#12629)
* Add MLU Support. * fix comment. * rename is_mlu_available to is_torch_mlu_available * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 0c75892 commit aecf0c5

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
is_tensorboard_available,
109109
is_timm_available,
110110
is_torch_available,
111+
is_torch_mlu_available,
111112
is_torch_npu_available,
112113
is_torch_version,
113114
is_torch_xla_available,

src/diffusers/utils/import_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
192192

193193
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
194194
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
195+
_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
195196
_transformers_available, _transformers_version = _is_package_available("transformers")
196197
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
197198
_kernels_available, _kernels_version = _is_package_available("kernels")
@@ -243,6 +244,10 @@ def is_torch_npu_available():
243244
return _torch_npu_available
244245

245246

247+
def is_torch_mlu_available():
248+
return _torch_mlu_available
249+
250+
246251
def is_flax_available():
247252
return _flax_available
248253

src/diffusers/utils/torch_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Callable, Dict, List, Optional, Tuple, Union
2121

2222
from . import logging
23-
from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version
23+
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
2424

2525

2626
if is_torch_available():
@@ -286,6 +286,8 @@ def get_device():
286286
return "xpu"
287287
elif torch.backends.mps.is_available():
288288
return "mps"
289+
elif is_torch_mlu_available():
290+
return "mlu"
289291
else:
290292
return "cpu"
291293

0 commit comments

Comments (0)