
Commit 1c5ccc0

bzgoogle committed
[GPT-OSS] add unit test for GPT-OSS
1 parent 78131cf

File tree

1 file changed: +190 -0 lines changed

tests/models/jax/test_gpt_oss.py

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
from unittest.mock import MagicMock, patch

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.jax.attention.gpt_oss_attention import GptOssAttention
from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE
from tpu_inference.models.jax.gpt_oss import GptOss
from tpu_inference.runner.kv_cache import create_kv_caches


class MockHfConfig:
    """Mocks the HuggingFace config object with small values for testing."""

    def __init__(self):
        self.num_hidden_layers: int = 2
        self.num_local_experts: int = 4
        self.vocab_size: int = 1024
        self.num_attention_heads: int = 64
        self.num_key_value_heads: int = 8
        self.head_dim: int = 64
        self.hidden_size: int = self.num_attention_heads * self.head_dim
        self.intermediate_size: int = 256
        self.num_experts_per_tok: int = 2
        self.rms_norm_eps: float = 1e-5
        self.swiglu_limit: float = 0.0
        self.rope_theta: float = 10000.0
        self.rope_scaling = {
            "factor": 1.0,
            "beta_slow": 1.0,
            "beta_fast": 1.0,
            "original_max_position_embeddings": 2048,
        }
        self.sliding_window: int | None = None


class MockVllmConfig:
    """
    Mocks the VllmConfig object, providing a mock hf_config and
    setting 'random_weights' to True to avoid downloading real weights.
    """

    def __init__(self, model: str, kv_cache_dtype: str):
        self.model_config = MagicMock()
        self.model_config.hf_config = MockHfConfig()
        self.model_config.model = model
        self.model_config.dtype = jnp.bfloat16
        self.load_config = MagicMock(download_dir=None)
        self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
        self.additional_config = {"random_weights": True, "is_verbose": False}


@pytest.fixture(scope="module")
def mesh():
    """
    Creates and globally activates a mesh for the entire test module.
    This is necessary to satisfy the jax.jit in `create_param`.
    """
    if not jax.devices():
        pytest.skip("No JAX devices available for mesh creation.")

    devices = np.array(jax.local_devices()[:1])
    num_devices = len(devices)
    assert num_devices == 1
    device_mesh = devices.reshape((num_devices, 1))

    # Create the mesh
    m = Mesh(device_mesh, axis_names=('data', 'model'))

    jax.set_mesh(m)

    yield m

    try:
        empty_devices = np.empty((0, 0), dtype=devices.dtype)
        jax.set_mesh(Mesh(empty_devices, axis_names=()))
    except Exception:
        pass
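
# Note on the `mesh` fixture above: jax.sharding.Mesh can also be used as a
# context manager (`with m: ...`), which would scope the mesh to a block.
# The fixture activates it globally via jax.set_mesh instead, presumably so
# that jit-compiled helpers such as `create_param` (referenced in the
# docstring) pick up the mesh without it being threaded through every call.
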
@pytest.fixture
def mock_model_inputs():
    """Provides mock inputs for a forward pass."""
    num_tokens = 8
    num_reqs = 1
    max_num_blocks_per_req = 4
    input_ids = jnp.ones((num_tokens, ), dtype=jnp.int32)
    positions = jnp.arange(0, num_tokens, dtype=jnp.int32)
    block_tables = jnp.zeros((num_reqs, max_num_blocks_per_req),
                             dtype=jnp.int32).reshape(-1)
    seq_lens = jnp.array([num_tokens], dtype=jnp.int32)
    query_start_loc = jnp.array([0, num_tokens], dtype=jnp.int32)
    request_distribution = jnp.array([0, 0, 0], dtype=jnp.int32)

    attention_metadata = AttentionMetadata(
        input_positions=positions,
        block_tables=block_tables,
        seq_lens=seq_lens,
        query_start_loc=query_start_loc,
        request_distribution=request_distribution,
    )
    indices_do_sample = jnp.array([num_tokens - 1], dtype=jnp.int32)

    return (input_ids, attention_metadata, indices_do_sample)
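
# Note (an interpretation, not something this test asserts): the
# AttentionMetadata fields follow the usual vLLM paged-attention layout:
# block_tables maps a request's logical blocks to physical KV-cache blocks,
# seq_lens holds per-request sequence lengths, and query_start_loc gives
# cumulative token offsets per request ([0, num_tokens] for this single
# 8-token request).
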
@pytest.fixture
def rng() -> PRNGKey:
    """Provides a reusable JAX PRNGKey."""
    return jax.random.PRNGKey(42)


class TestGptOss:

    @pytest.mark.parametrize("mock_vllm_config", [
        MockVllmConfig("mock/gpt-oss-small", "auto"),
    ])
    def test_gpt_oss_init_and_forward(self, mock_vllm_config, rng, mesh,
                                      mock_model_inputs):
        """Tests model init, weight loading (mocked), and a forward pass."""

        # Test model init
        hf_config = mock_vllm_config.model_config.hf_config

        model = GptOss(mock_vllm_config, rng, mesh)

        assert model.mesh.shape == {"data": 1, "model": 1}
        assert isinstance(model.rng, nnx.Rngs)
        assert len(model.layers) == hf_config.num_hidden_layers

        # Check key submodule shapes
        assert model.embedder.input_embedding_table_VD.shape == (
            hf_config.vocab_size, hf_config.hidden_size)

        layer_0 = model.layers[0]
        attn = layer_0.attn
        assert isinstance(attn, GptOssAttention)
        assert attn.kernel_q_DNH.shape == (hf_config.hidden_size,
                                           hf_config.num_attention_heads,
                                           hf_config.head_dim)

        moe_mlp = layer_0.custom_module
        assert isinstance(moe_mlp, GptOssMoE)
        assert moe_mlp.mlp1_weight_EDF2.shape == (hf_config.num_local_experts,
                                                  hf_config.hidden_size,
                                                  hf_config.intermediate_size * 2)

        assert model.final_norm.scale.shape == (hf_config.hidden_size, )
        assert model.lm_head.input_embedding_table_DV.shape == (
            hf_config.hidden_size, hf_config.vocab_size)

        # Test model load
        with patch("tpu_inference.models.jax.gpt_oss.model_weights_generator",
                   return_value=iter([])):
            model.load_weights(rng)

        # Test model forward
        num_key_value_heads = int(hf_config.num_key_value_heads / 2)
        kv_caches = create_kv_caches(
            num_blocks=4,
            block_size=32,
            num_kv_heads=num_key_value_heads,
            head_size=hf_config.head_dim,
            mesh=mesh,
            layer_names=["layer"] * hf_config.num_hidden_layers,
            cache_dtype=jnp.float8_e4m3fn
            if mock_vllm_config.cache_config.cache_dtype == "fp8" else
            jnp.bfloat16)

        input_ids, attention_metadata, indices_do_sample = mock_model_inputs

        kv_caches, hidden_states, aux_hidden_states = model(
            kv_caches, input_ids, attention_metadata)

        # Check output shapes
        assert hidden_states.shape == (8, hf_config.hidden_size)
        assert aux_hidden_states == []

        # Test logits computation
        hidden_states = hidden_states[indices_do_sample, :]
        assert hidden_states.shape == (1, hf_config.hidden_size)

        logits = model.compute_logits(hidden_states)
        assert logits.shape == (1, hf_config.vocab_size)
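
Two notes on the file above. First, the dimension suffixes in the asserted parameter names (`input_embedding_table_VD`, `kernel_q_DNH`, `mlp1_weight_EDF2`) appear to encode tensor axes (V = vocab, D = hidden, N = attention heads, H = head dim, E = experts, F = MoE intermediate, with F2 covering the fused 2x-intermediate projection); this is inferred from the shape checks, not stated in the commit. Second, assuming a standard pytest setup (the runner configuration is not part of this commit), the new test can be run directly with `pytest tests/models/jax/test_gpt_oss.py -v` on a machine with at least one JAX device; the `mesh` fixture skips the test otherwise.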
