Skip to content

Commit 5dcba70

Browse files
authored
Add failing bitwise determinism test for aot_eager (#170)
1 parent 2b9ef0a commit 5dcba70

File tree

1 file changed: tests/test_aot_eager.py (+58 lines, -0 lines)

tests/test_aot_eager.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import pytest
2+
import torch
3+
from torch.utils._debug_mode import DebugMode
4+
5+
from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
6+
7+
# TODO: make device generic
8+
9+
10+
@pytest.fixture(scope="module")
def llama3_debug_model():
    """Build one small, seeded Llama3 Transformer on CUDA, shared by every test in this module."""
    # Fixed seed so the parameter initialization is identical across runs.
    torch.manual_seed(1999)
    args = TransformerModelArgs(
        dim=256,
        n_layers=6,
        n_heads=16,
        vocab_size=2048,
        rope_theta=500000,
    )
    model = Transformer(args)
    return model.cuda()
17+
18+
19+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_deterministic(llama3_debug_model):
    """Two eager forward passes with identical RNG state must be bitwise identical.

    Baseline sanity check: if plain eager mode is not deterministic, the
    compiler-equivalence tests below cannot be meaningful.
    """
    batch_size = 8
    seqlen = 2048
    vocab_size = llama3_debug_model.model_args.vocab_size

    # Fixed input batch of token ids.
    torch.manual_seed(2999)
    x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda")

    # Reset RNG state before each forward so any randomness inside the
    # forward (e.g. dropout) draws the same values both times.
    torch.manual_seed(3999)
    r1 = llama3_debug_model(x)
    torch.manual_seed(3999)
    r2 = llama3_debug_model(x)

    assert torch.equal(r1, r2)  # bitwise equal
30+
31+
32+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_debug_mode_bitwise_equivalent(llama3_debug_model):
    """A forward pass under DebugMode must be bitwise identical to plain eager.

    Guards against DebugMode's dispatch interception perturbing numerics.
    """
    batch_size = 8
    seqlen = 2048
    vocab_size = llama3_debug_model.model_args.vocab_size

    # Fixed input batch of token ids.
    torch.manual_seed(2999)
    x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda")

    # Same RNG state for both runs so in-forward randomness matches.
    torch.manual_seed(3999)
    r1 = llama3_debug_model(x)
    torch.manual_seed(3999)
    with DebugMode() as debug_mode:
        r2 = llama3_debug_model(x)
    # Printed so the op trace is visible in pytest output (with -s) on failure.
    print(debug_mode.debug_string())

    assert torch.equal(r1, r2)  # bitwise equal
45+
46+
47+
# Known failing (see commit message): aot_eager is not yet bitwise-deterministic
# vs. eager for this model, hence the xfail marker.
@pytest.mark.xfail
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_aot_eager_bitwise_equivalent(llama3_debug_model):
    """torch.compile(backend="aot_eager") must match plain eager bitwise."""
    batch_size = 8
    seqlen = 2048
    vocab_size = llama3_debug_model.model_args.vocab_size

    # Fixed input batch of token ids.
    torch.manual_seed(2999)
    x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda")

    # Same RNG state for both runs so in-forward randomness matches.
    torch.manual_seed(3999)
    r1 = llama3_debug_model(x)
    torch.manual_seed(3999)
    r2 = torch.compile(backend="aot_eager")(llama3_debug_model)(x)

    assert torch.equal(r1, r2)  # bitwise equal

0 commit comments

Comments
 (0)