Commit 65c4e40
Cross self attention switch (#251)
* skip flash block sizes setting for cross attention.
* change sharding based on cross/self attention.
* update sharding rules for attn.
* lint.
* ring attention rules are added at front if not present to shard sequence on fsdp axis
* test fix
* Add dense padded attention kernel and use unsafe rng key for generation
* Update
* Ignore history
* remove file
* Flag for using segment ids and masking padding tokens in attention
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Tokamax splash attn
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Flag for using same sequence sharding for self and cross
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* update requirements.txt
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Delete splash_attn_benchmark.py
* Delete padded_flash_attn.py
* Merge main
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Ruff format
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Ruff format
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Ruff format
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Address comments
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Address comments
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Address comments
Signed-off-by: Kunjan Patel <kunjanp@google.com>
* Fix pprint error, add description of attention configuration params
* Fix pprint error, add description of attention configuration params
* Fix pprint error, add description of attention configuration params
---------
Signed-off-by: Kunjan Patel <kunjanp@google.com>
Co-authored-by: Kunjan Patel <kunjan@ucla.edu>
Co-authored-by: Kunjan Patel <kunjanp@google.com>1 parent d843dc0 commit 65c4e40
File tree
25 files changed
+497
-212
lines changed- .github/workflows
- src/maxdiffusion
- configs
- models
- wan
- transformers
- pipelines/wan
- tests
- trainers
- tests/schedulers
25 files changed
+497
-212
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
58 | 58 | | |
59 | 59 | | |
60 | 60 | | |
61 | | - | |
| 61 | + | |
62 | 62 | | |
63 | 63 | | |
64 | 64 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
7 | | - | |
8 | 7 | | |
9 | 8 | | |
10 | 9 | | |
| |||
98 | 97 | | |
99 | 98 | | |
100 | 99 | | |
| 100 | + | |
101 | 101 | | |
102 | 102 | | |
103 | 103 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
13 | 13 | | |
14 | 14 | | |
15 | 15 | | |
| 16 | + | |
16 | 17 | | |
17 | 18 | | |
18 | 19 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
33 | 33 | | |
34 | 34 | | |
35 | 35 | | |
36 | | - | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
37 | 41 | | |
38 | 42 | | |
39 | 43 | | |
| |||
44 | 48 | | |
45 | 49 | | |
46 | 50 | | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
47 | 60 | | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
50 | 50 | | |
51 | 51 | | |
52 | 52 | | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
53 | 62 | | |
54 | 63 | | |
55 | 64 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
49 | 49 | | |
50 | 50 | | |
51 | 51 | | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
52 | 62 | | |
53 | 63 | | |
54 | 64 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
50 | 50 | | |
51 | 51 | | |
52 | 52 | | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
53 | 63 | | |
54 | 64 | | |
55 | 65 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
66 | 75 | | |
67 | 76 | | |
68 | 77 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
66 | 75 | | |
67 | 76 | | |
68 | 77 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
62 | 62 | | |
63 | 63 | | |
64 | 64 | | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
65 | 74 | | |
66 | 75 | | |
67 | 76 | | |
| |||
0 commit comments