Commit 760cc7d
Support of device ordering (#95)
* Trying out new version of redistribute_local_tensor
Taken from pytorch/pytorch#160266, but I'm hitting an assertion for now
* update ordered sharding
* relocate redistribute tensor function (JAX-style mapping of tensor dims to mesh dims)
* fix loss curve mismatch
* improve ordering logic and bring back _optimize_same_nd_sharding_as_1d
* lint
* fix small bug
* address review feedback
* fix CI
---------
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
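The "JAX-style mapping of tensor dims to mesh dims" mentioned above can be sketched as follows. This is a hypothetical illustration, not the actual autoparallel code: the function name and the placement encoding (one entry per mesh dim, giving the tensor dim sharded over it, or `None` for replication) are assumptions, loosely analogous to `jax.sharding.PartitionSpec`.

```python
def tensor_dim_to_mesh_dim(placements, tensor_ndim):
    """Build a map: tensor dim -> ordered list of mesh dims sharding it.

    `placements` is indexed by mesh dim; entry i is the tensor dim
    sharded over mesh dim i, or None for replication. Iterating in
    mesh-dim order preserves the device ordering this commit supports.
    """
    mapping = {d: [] for d in range(tensor_ndim)}
    for mesh_dim, tensor_dim in enumerate(placements):
        if tensor_dim is not None:
            mapping[tensor_dim].append(mesh_dim)
    return mapping

# A 2-D tensor on a 3-D mesh: mesh dims 0 and 2 both shard tensor
# dim 0 (their order matters), mesh dim 1 shards tensor dim 1.
print(tensor_dim_to_mesh_dim([0, 1, 0], tensor_ndim=2))
# -> {0: [0, 2], 1: [1]}
```

Recording the mesh dims in order, rather than as a set, is what distinguishes ordered sharding: redistributing `{0: [0, 2]}` and `{0: [2, 0]}` yields different device layouts for the same logical placement.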
File tree (5 files changed, +942 −65 lines)
- .github/workflows
- autoparallel
- dtensor_util