Conversation
I think this works well
perhaps something like this to preserve float precision on CUDA?

diff --git a/point_e/util/precision_compatibility.py b/point_e/util/precision_compatibility.py
new file mode 100644
--- /dev/null
+++ b/point_e/util/precision_compatibility.py
@@ -0,0 +1,5 @@
+import torch
+import numpy as np
+
+NP_FLOAT32_64 = np.float32 if torch.backends.mps.is_available() else np.float64
+TH_FLOAT32_64 = torch.float32 if torch.backends.mps.is_available() else torch.float64
\ No newline at end of file
diff --git a/point_e/diffusion/gaussian_diffusion.py b/point_e/diffusion/gaussian_diffusion.py
--- point_e/diffusion/gaussian_diffusion.py
+++ point_e/diffusion/gaussian_diffusion.py
@@ -6,8 +6,9 @@
from typing import Any, Dict, Iterable, Optional, Sequence, Union
import numpy as np
import torch as th
+from point_e.util.precision_compatibility import NP_FLOAT32_64, TH_FLOAT32_64
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
"""
@@ -15,9 +16,9 @@
See get_named_beta_schedule() for the new library of schedules.
"""
if beta_schedule == "linear":
- betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float32)
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=NP_FLOAT32_64)
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
@@ -159,9 +160,9 @@
self.channel_scales = channel_scales
self.channel_biases = channel_biases
# originally uses float64 for accuracy, moving to float32 for mps compatibility
- betas = np.array(betas, dtype=np.float32)
+ betas = np.array(betas, dtype=NP_FLOAT32_64)
self.betas = betas
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
@@ -1012,9 +1013,9 @@
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
- res = th.from_numpy(arr).to(dtype=th.float32, device=timesteps.device)[timesteps].to(th.float32)
+ res = th.from_numpy(arr).to(dtype=TH_FLOAT32_64, device=timesteps.device)[timesteps].to(TH_FLOAT32_64)
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
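For a quick sanity check of the new constants, a minimal sketch of how they resolve (float32 only when an MPS device is available, float64 everywhere else, including CUDA; the beta-schedule values below are illustrative, not taken from the PR):

    import numpy as np
    import torch

    from point_e.util.precision_compatibility import NP_FLOAT32_64, TH_FLOAT32_64

    # On Apple Silicon with MPS available these resolve to float32;
    # on CUDA or CPU-only machines they stay float64.
    print(NP_FLOAT32_64)  # <class 'numpy.float32'> or <class 'numpy.float64'>
    print(TH_FLOAT32_64)  # torch.float32 or torch.float64

    # Illustrative beta schedule, mirroring the linspace call in the diff above.
    betas = np.linspace(0.0001, 0.02, 1000, dtype=NP_FLOAT32_64)
    print(betas.dtype)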
I love it!
Co-Authored-By: henrycunh <henrycunh@gmail.com>
@henrycunh Added!
Tried it now on a MacBook Air M2 and it worked very well, for reference:
The only problem is the current PyTorch implementation for MPS, which gives this:
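If the error in question is the usual unsupported-operator failure on MPS, a common workaround (independent of this PR) is PyTorch's CPU fallback for missing ops, enabled before any tensors are created:

    import os

    # Assumption: the MPS issue is a missing-operator error. This environment
    # variable tells PyTorch to fall back to CPU for ops not implemented on MPS.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

    import torch

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")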
I apologize for my question, but how noticeable is the change to float32?
I'm pretty confident that using higher precision, like
Could we make that a parameter that defaults to float64, with another option for float32?
^ agree
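A minimal sketch of what that parameter could look like (the use_float64 flag and get_float_dtypes helper are hypothetical, not part of this PR):

    import numpy as np
    import torch

    def get_float_dtypes(use_float64: bool = True):
        """Return (numpy dtype, torch dtype); defaults to float64, opts into float32."""
        if use_float64:
            return np.float64, torch.float64
        return np.float32, torch.float32

    # float64 is unsupported on MPS, so callers there would pass use_float64=False.
    np_dtype, th_dtype = get_float_dtypes(use_float64=not torch.backends.mps.is_available())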
This PR introduces Metal GPU support, at the cost of slightly lowering accuracy on the gaussian_diffusion step (changing float64 to float32, only when running on MPS).
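For context, a hedged sketch of how a caller might select the device once this lands (the MPS-then-CUDA-then-CPU order is an assumption, not code from this PR):

    import torch

    from point_e.util.precision_compatibility import TH_FLOAT32_64

    # Prefer Metal (MPS) on Apple Silicon, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # The diffusion internals pick their dtype via TH_FLOAT32_64 (see diff above),
    # so float64 is only dropped when running on MPS.
    x = torch.zeros(4, 3, device=device, dtype=TH_FLOAT32_64)
    print(device, x.dtype)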