Skip to content

Commit d1422d2

Browse files
authored
[CI] Fix head dim check (#1091)
1 parent c1b4188 commit d1422d2

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

tests/models/jax/test_llama3.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,14 @@ def test_llama32_1b(self, mock_vllm_config, rng, mesh, mock_model_inputs):
107107
num_heads = hf_config.num_attention_heads
108108
num_kv_heads = hf_config.num_key_value_heads
109109
rope_theta = hf_config.rope_theta
110-
original_head_dim = hf_config.head_dim
111-
head_dim = 128
110+
head_dim = hf_config.head_dim
112111
intermediate_size = hf_config.intermediate_size
113112

114113
assert attn.hidden_size == hidden_size
115114
assert attn.num_heads == num_heads
116115
assert attn.num_kv_heads == num_kv_heads
117116
assert attn.rope_theta == rope_theta
118-
assert attn.head_dim_original == original_head_dim
117+
assert attn.head_dim_original == head_dim
119118
assert attn.head_dim == head_dim
120119
assert attn.q_proj.kernel.shape == (hidden_size, num_heads, head_dim)
121120
assert attn.k_proj.kernel.shape == (hidden_size, num_kv_heads,

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
168168
"head_dim, expected_padded_head_dim",
169169
[
170170
(1, 128),
171-
(64, 128),
171+
(64, 64),
172172
(127, 128),
173173
(128, 128),
174174
(129, 256),

0 commit comments

Comments
 (0)