
Commit ad9f204

Add MLU Support.
1 parent 04f9d2b commit ad9f204

File tree (4 files changed: +9 −0 lines changed)

    src/diffusers/pipelines/pipeline_utils.py
    src/diffusers/utils/__init__.py
    src/diffusers/utils/import_utils.py
    src/diffusers/utils/torch_utils.py

src/diffusers/pipelines/pipeline_utils.py
Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@
     is_accelerate_version,
     is_hpu_available,
     is_torch_npu_available,
+    is_mlu_available,
     is_torch_version,
     is_transformers_version,
     logging,

src/diffusers/utils/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@
     is_google_colab,
     is_hf_hub_version,
     is_hpu_available,
+    is_mlu_available,
     is_inflect_available,
     is_invisible_watermark_available,
     is_k_diffusion_available,

src/diffusers/utils/import_utils.py
Lines changed: 5 additions & 0 deletions

@@ -192,6 +192,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b

 _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
 _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
+_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
 _transformers_available, _transformers_version = _is_package_available("transformers")
 _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
 _kernels_available, _kernels_version = _is_package_available("kernels")

@@ -243,6 +244,10 @@ def is_torch_npu_available():
     return _torch_npu_available


+def is_mlu_available():
+    return _torch_mlu_available
+
+
 def is_flax_available():
     return _flax_available
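The new flag follows the same pattern as the other backends: the torch_mlu package is probed once at import time and is_mlu_available() simply returns the cached result. The sketch below approximates that pattern in standalone form; the real _is_package_available in import_utils.py also handles distribution-name lookups, so treat this as an illustration rather than the exact helper.

# Standalone approximation of the availability check added above; the real
# _is_package_available in diffusers has extra distribution-name handling.
import importlib.metadata
import importlib.util


def _is_package_available(pkg_name: str):
    # A package counts as available only if it can be found *and* has
    # installed distribution metadata (i.e. it is really installed).
    exists = importlib.util.find_spec(pkg_name) is not None
    version = "N/A"
    if exists:
        try:
            version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            exists = False
    return exists, version


_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")


def is_mlu_available():
    # True only when Cambricon's torch_mlu extension is installed.
    return _torch_mlu_available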

src/diffusers/utils/torch_utils.py
Lines changed: 2 additions & 0 deletions

@@ -286,6 +286,8 @@ def get_device():
         return "xpu"
     elif torch.backends.mps.is_available():
         return "mps"
+    elif torch.mlu.is_available():
+        return "mlu"
     else:
         return "cpu"
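With get_device() now able to return "mlu", code that routes device placement through it picks up Cambricon MLUs automatically. A minimal usage sketch, assuming a diffusers build that contains this commit and a machine with torch_mlu installed (importing torch_mlu is typically what exposes the torch.mlu namespace the new branch probes); the checkpoint id is a placeholder, not part of the diff:

# Usage sketch only: is_mlu_available and get_device are the helpers touched
# by this commit; the checkpoint id below is a placeholder.
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import is_mlu_available
from diffusers.utils.torch_utils import get_device

if is_mlu_available():
    # Importing the extension is what typically registers the torch.mlu
    # namespace that get_device() checks via torch.mlu.is_available().
    import torch_mlu  # noqa: F401

device = get_device()  # "mlu" on a Cambricon machine, otherwise another backend

pipe = DiffusionPipeline.from_pretrained("some/checkpoint")  # placeholder id
pipe = pipe.to(device)  # the same .to(device) call covers cuda, xpu, mps, or mlu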
