Commit 411e388
Add HF Auth mixin to Stable Diffusion (#1763)
Summary:
Right now stale diffusion and lit-llama are not actually running in CI because they get rate limited by huggingface. since we've now added an auth token as a github secret we can move stable diffusion out of canary and do things like include it in blueberries dashboard
We also added some nice errors so people running in torchbench locally know they will need to have a token to run these models
Anyways auth is a mixin which seems like the right abstraction
# Some relevant details about the model
Torchbench has a function `get_module()` that has the intent of testing a `nn.Module` on an actual `torch.Tensor`
Unfortunately a `StableDiffusionPipeline` is not an `nn.Module` it's a composition of a tokenizer and 3 seperate `nn.Modules` an encoder, vae and unet.
## text_encoder
```python
def get_module(self):
batch_size = 1
sequence_length = 10
vocab_size = 32000
# Generate random indices within the valid range
input_tensor = torch.randint(low=0, high=vocab_size, size=(batch_size, sequence_length))
# Make sure the tensor has the correct data type
input_tensor = input_tensor.long()
print(self.pipe.text_encoder(input_tensor))
return self.pipe.text_encoder, input_tensor
```
Text encoder outputs a `BaseModelOutputWithPooling` which has multiple nn modules https://gist.github.com/msaroufim/51f0038863c5cce4cc3045e4d9f9c399
```
======================================================================
FAIL: test_stable_diffusion_example_cuda (__main__.TestBenchmark)
----------------------------------------------------------------------
components._impl.workers.subprocess_rpc.ChildTraceException: Traceback (most recent call last):
File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_rpc.py", line 482, in _run_block
exec( # noqa: P204
File "<subprocess-worker>", line 35, in <module>
File "<subprocess-worker>", line 12, in _run_in_worker_f
File "/home/ubuntu/benchmark/torchbenchmark/util/model.py", line 26, in __call__
obj.__post__init__()
File "/home/ubuntu/benchmark/torchbenchmark/util/model.py", line 126, in __post__init__
self.accuracy = check_accuracy(self)
File "/home/ubuntu/benchmark/torchbenchmark/util/env_check.py", line 469, in check_accuracy
model, example_inputs = maybe_cast(tbmodel, model, example_inputs)
File "/home/ubuntu/benchmark/torchbenchmark/util/env_check.py", line 424, in maybe_cast
example_inputs = clone_inputs(example_inputs)
File "/home/ubuntu/benchmark/torchbenchmark/util/env_check.py", line 297, in clone_inputs
assert isinstance(value, torch.Tensor)
AssertionError
```
## vae
```python
def get_module(self):
print(self.pipe.vae(torch.randn(9,3,9,9)))
```
Same problem for vae
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L27
## unet
```python
def get_module(self):
# This will only benchmark the unet since that's the biggest layer
# Stable diffusion is a composition of a text encoder, unet and vae
encoder_hidden_states = torch.randn(320, 1024)
sample = torch.randn(4, 4, 4, 32)
timestep = 5
inputs_to_pipe = {'timestep': timestep, 'encoder_hidden_states': encoder_hidden_states, 'sample': sample}
result = self.pipe.unet(**inputs_to_pipe)
return self.pipe, inputs_to_pipe
```
Unet unfortunately does not have a tensor input
For VAE and encoder the test failure is particularly helpful
```
(sam) ubuntu@ip-172-31-9-217:~/benchmark$ python test.py -k "test_stable_diffusion_example_cuda"
F
======================================================================
FAIL: test_stable_diffusion_example_cuda (__main__.TestBenchmark)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/ubuntu/benchmark/test.py", line 75, in example_fn
assert accuracy == "pass" or accuracy == "eager_1st_run_OOM", f"Expected accuracy pass, get {accuracy}"
AssertionError: Expected accuracy pass, get eager_1st_run_fail
----------------------------------------------------------------------
Ran 1 test in 7.402s
FAILED (failures=1)
```
Pull Request resolved: #1763
Reviewed By: xuzhao9
Differential Revision: D47565523
Pulled By: msaroufim
fbshipit-source-id: c949ce8a31c0a4706658937fc6603a22a4bc3ec61 parent 09de70c commit 411e388
File tree
8 files changed
+55
-20
lines changed- .github/workflows
- torchbenchmark
- canary_models/stable_diffusion
- models/stable_diffusion
- util
- framework/huggingface
8 files changed
+55
-20
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
| 13 | + | |
| 14 | + | |
13 | 15 | | |
14 | 16 | | |
15 | 17 | | |
| |||
36 | 38 | | |
37 | 39 | | |
38 | 40 | | |
39 | | - | |
40 | | - | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
41 | 44 | | |
42 | 45 | | |
43 | 46 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
| 13 | + | |
13 | 14 | | |
14 | 15 | | |
15 | 16 | | |
| |||
This file was deleted.
Lines changed: 16 additions & 5 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
5 | 5 | | |
6 | 6 | | |
7 | 7 | | |
| 8 | + | |
8 | 9 | | |
9 | 10 | | |
10 | 11 | | |
11 | 12 | | |
12 | 13 | | |
13 | | - | |
| 14 | + | |
14 | 15 | | |
15 | 16 | | |
16 | 17 | | |
| |||
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
22 | | - | |
23 | 23 | | |
| 24 | + | |
24 | 25 | | |
25 | 26 | | |
26 | | - | |
27 | 27 | | |
28 | 28 | | |
29 | | - | |
| 29 | + | |
30 | 30 | | |
31 | 31 | | |
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
35 | 35 | | |
| 36 | + | |
36 | 37 | | |
37 | | - | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
38 | 49 | | |
39 | 50 | | |
40 | 51 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
Lines changed: 5 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
159 | 159 | | |
160 | 160 | | |
161 | 161 | | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
162 | 167 | | |
163 | 168 | | |
164 | 169 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
20 | 20 | | |
21 | 21 | | |
22 | 22 | | |
23 | | - | |
| 23 | + | |
0 commit comments