Commit 745644f
FIX SAM for bfloat16 (#1764)
Summary:
Ok this was kinda annoying
Basically the SAM codebase had a few places where it hardcodes `torch.float32` such that even if you convert the model to `torch.bfloat16` a few parts of the model won't be and will have type mismatch errors - this fixes the problem cpuhrsch desertfire - idk enough about floats and why there isn't some type promotion rule for bfloat16
I wonder whether we should add tests for multiple dtypes in torchbench to make checking for this kind of issue more robust especially now that bfloat16 seems to be the default for dynamo xuzhao9
## Logs
```
FAILED (errors=1)
(sam) ubuntu@ip-172-31-9-217:~/benchmark$ python test.py -k "test_sam_eval_cuda"
E
======================================================================
ERROR: test_sam_eval_cuda (__main__.TestBenchmark)
----------------------------------------------------------------------
components._impl.workers.subprocess_rpc.ChildTraceException: Traceback (most recent call last):
File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_rpc.py", line 482, in _run_block
exec( # noqa: P204
File "<subprocess-worker>", line 2, in <module>
File "/home/ubuntu/benchmark/torchbenchmark/util/model.py", line 280, in invoke
out = self.eval()
File "/home/ubuntu/benchmark/torchbenchmark/models/sam/__init__.py", line 65, in eval
masks, scores, logits = predictor.predict(
File "/home/ubuntu/benchmark/torchbenchmark/models/sam/predictor.py", line 164, in predict
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
TypeError: Got unsupported ScalarType BFloat16
working_dir: /tmp/tmpg5de41du
stdout:
[2023-07-13] 01:57:38.499061: TIMER_SUBPROCESS_BEGIN_EXEC
[2023-07-13] 01:57:39.002078: TIMER_SUBPROCESS_FAILED
[2023-07-13] 01:57:39.002141: TIMER_SUBPROCESS_FINISHED
[2023-07-13] 01:57:39.002153: TIMER_SUBPROCESS_BEGIN_READ
stderr:
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ubuntu/benchmark/test.py", line 104, in eval_fn
task.invoke()
File "/home/ubuntu/benchmark/torchbenchmark/__init__.py", line 402, in invoke
self.worker.run("""
File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_worker.py", line 155, in run
self._run(snippet)
File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_worker.py", line 320, in _run
subprocess_rpc.SerializedException.raise_from(
File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_rpc.py", line 458, in raise_from
raise e from ChildTraceException(traceback_str)
TypeError: Got unsupported ScalarType BFloat16
----------------------------------------------------------------------
Ran 1 test in 7.814s
FAILED (errors=1)
(sam) ubuntu@ip-172-31-9-217:~/benchmark$ python test.py -k "test_sam_eval_cuda"
.
----------------------------------------------------------------------
Ran 1 test in 8.315s
OK
```
Pull Request resolved: #1764
Reviewed By: drisspg, cpuhrsch
Differential Revision: D47441873
Pulled By: msaroufim
fbshipit-source-id: a60880fd7c0826cfd469ace39d76894469ca0e5e1 parent 2ea018e commit 745644f
File tree
4 files changed
+8
-3
lines changed- torchbenchmark/models/sam
4 files changed
+8
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
46 | | - | |
47 | 46 | | |
48 | 47 | | |
49 | 48 | | |
| |||
57 | 56 | | |
58 | 57 | | |
59 | 58 | | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
60 | 62 | | |
61 | 63 | | |
62 | 64 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
129 | 129 | | |
130 | 130 | | |
131 | 131 | | |
| 132 | + | |
132 | 133 | | |
133 | 134 | | |
134 | 135 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
160 | 160 | | |
161 | 161 | | |
162 | 162 | | |
163 | | - | |
164 | | - | |
| 163 | + | |
| 164 | + | |
165 | 165 | | |
166 | 166 | | |
167 | 167 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
186 | 186 | | |
187 | 187 | | |
188 | 188 | | |
| 189 | + | |
| 190 | + | |
189 | 191 | | |
190 | 192 | | |
191 | 193 | | |
| |||
0 commit comments