|
| 1 | +from unittest.mock import MagicMock, patch |
| 2 | + |
| 3 | +import jax |
| 4 | +import jax.numpy as jnp |
| 5 | +import numpy as np |
| 6 | +import pytest |
| 7 | +from flax import nnx |
| 8 | +from flax.typing import PRNGKey |
| 9 | +from jax.sharding import Mesh |
| 10 | +from tpu_inference.layers.common.attention_metadata import AttentionMetadata |
| 11 | +from tpu_inference.layers.jax.attention.gpt_oss_attention import GptOssAttention |
| 12 | +from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE |
| 13 | +from tpu_inference.models.jax.gpt_oss import GptOss |
| 14 | +from tpu_inference.runner.kv_cache import create_kv_caches |
| 15 | + |
| 16 | + |
| 17 | +class MockHfConfig: |
| 18 | + """Mocks the HuggingFace config object with small values for testing.""" |
| 19 | + |
| 20 | + def __init__(self): |
| 21 | + self.num_hidden_layers: int = 2 |
| 22 | + self.num_local_experts: int = 4 |
| 23 | + self.vocab_size: int = 1024 |
| 24 | + self.num_attention_heads: int = 64 |
| 25 | + self.num_key_value_heads: int = 8 |
| 26 | + self.head_dim: int = 64 |
| 27 | + self.hidden_size: int = self.num_attention_heads * self.head_dim |
| 28 | + self.intermediate_size: int = 256 |
| 29 | + self.num_experts_per_tok: int = 2 |
| 30 | + self.rms_norm_eps: float = 1e-5 |
| 31 | + self.swiglu_limit: float = 0.0 |
| 32 | + self.rope_theta: float = 10000.0 |
| 33 | + self.rope_scaling = { |
| 34 | + "factor": 1.0, |
| 35 | + "beta_slow": 1.0, |
| 36 | + "beta_fast": 1.0, |
| 37 | + "original_max_position_embeddings": 2048, |
| 38 | + } |
| 39 | + self.sliding_window: int | None = None |
| 40 | + |
| 41 | + |
| 42 | +class MockVllmConfig: |
| 43 | + """ |
| 44 | + Mocks the VllmConfig object, providing a mock hf_config and |
| 45 | + setting 'random_weights' to True to avoid downloading real weights. |
| 46 | + """ |
| 47 | + |
| 48 | + def __init__(self, model: str, kv_cache_dtype: str): |
| 49 | + self.model_config = MagicMock() |
| 50 | + self.model_config.hf_config = MockHfConfig() |
| 51 | + self.model_config.model = model |
| 52 | + self.model_config.dtype = jnp.bfloat16 |
| 53 | + self.load_config = MagicMock(download_dir=None) |
| 54 | + self.cache_config = MagicMock(cache_dtype=kv_cache_dtype) |
| 55 | + self.additional_config = {"random_weights": True, "is_verbose": False} |
| 56 | + |
| 57 | + |
| 58 | +@pytest.fixture(scope="module") |
| 59 | +def mesh(): |
| 60 | + """ |
| 61 | + Creates and globally activates a mesh for the entire test module. |
| 62 | + This is necessary to satisfy the jax.jit in `create_param`. |
| 63 | + """ |
| 64 | + if not jax.devices(): |
| 65 | + pytest.skip("No JAX devices available for mesh creation.") |
| 66 | + |
| 67 | + devices = np.array(jax.local_devices()[:1]) |
| 68 | + num_devices = len(devices) |
| 69 | + assert num_devices == 1 |
| 70 | + device_mesh = devices.reshape((num_devices, 1)) |
| 71 | + |
| 72 | + # Create the mesh |
| 73 | + m = Mesh(device_mesh, axis_names=('data', 'model')) |
| 74 | + |
| 75 | + jax.set_mesh(m) |
| 76 | + |
| 77 | + yield m |
| 78 | + |
| 79 | + try: |
| 80 | + empty_devices = np.empty((0, 0), dtype=devices.dtype) |
| 81 | + jax.set_mesh(Mesh(empty_devices, axis_names=())) |
| 82 | + except Exception: |
| 83 | + pass |
| 84 | + |
| 85 | + |
| 86 | +@pytest.fixture |
| 87 | +def mock_model_inputs(): |
| 88 | + """Provides mock inputs for a forward pass.""" |
| 89 | + num_tokens = 8 |
| 90 | + num_reqs = 1 |
| 91 | + max_num_blocks_per_req = 4 |
| 92 | + input_ids = jnp.ones((num_tokens, ), dtype=jnp.int32) |
| 93 | + positions = jnp.arange(0, num_tokens, dtype=jnp.int32) |
| 94 | + block_tables = jnp.zeros((num_reqs, max_num_blocks_per_req), |
| 95 | + dtype=jnp.int32).reshape(-1) |
| 96 | + seq_lens = jnp.array([num_tokens], dtype=jnp.int32) |
| 97 | + query_start_loc = jnp.array([0, num_tokens], dtype=jnp.int32) |
| 98 | + request_distribution = jnp.array([0, 0, 0], dtype=jnp.int32) |
| 99 | + |
| 100 | + attention_metadata = AttentionMetadata( |
| 101 | + input_positions=positions, |
| 102 | + block_tables=block_tables, |
| 103 | + seq_lens=seq_lens, |
| 104 | + query_start_loc=query_start_loc, |
| 105 | + request_distribution=request_distribution, |
| 106 | + ) |
| 107 | + indices_do_sample = jnp.array([num_tokens - 1], dtype=jnp.int32) |
| 108 | + |
| 109 | + return (input_ids, attention_metadata, indices_do_sample) |
| 110 | + |
| 111 | + |
| 112 | +@pytest.fixture |
| 113 | +def rng() -> PRNGKey: |
| 114 | + """Provides a reusable JAX PRNGKey.""" |
| 115 | + return jax.random.PRNGKey(42) |
| 116 | + |
| 117 | + |
| 118 | +class TestGptOss: |
| 119 | + |
| 120 | + @pytest.mark.parametrize("mock_vllm_config", [ |
| 121 | + MockVllmConfig("mock/gpt-oss-small", "auto"), |
| 122 | + ]) |
| 123 | + def test_gpt_oss_init_and_forward(self, mock_vllm_config, rng, mesh, |
| 124 | + mock_model_inputs): |
| 125 | + """Tests model init, weight loading (mocked), and a forward pass.""" |
| 126 | + |
| 127 | + # Test model init |
| 128 | + hf_config = mock_vllm_config.model_config.hf_config |
| 129 | + |
| 130 | + model = GptOss(mock_vllm_config, rng, mesh) |
| 131 | + |
| 132 | + assert model.mesh.shape == {"data": 1, "model": 1} |
| 133 | + assert isinstance(model.rng, nnx.Rngs) |
| 134 | + assert len(model.layers) == hf_config.num_hidden_layers |
| 135 | + |
| 136 | + # Check key submodule shapes |
| 137 | + assert model.embedder.input_embedding_table_VD.shape == ( |
| 138 | + hf_config.vocab_size, hf_config.hidden_size) |
| 139 | + |
| 140 | + layer_0 = model.layers[0] |
| 141 | + attn = layer_0.attn |
| 142 | + assert isinstance(attn, GptOssAttention) |
| 143 | + assert attn.kernel_q_DNH.shape == (hf_config.hidden_size, |
| 144 | + hf_config.num_attention_heads, |
| 145 | + hf_config.head_dim) |
| 146 | + |
| 147 | + moe_mlp = layer_0.custom_module |
| 148 | + assert isinstance(moe_mlp, GptOssMoE) |
| 149 | + assert moe_mlp.mlp1_weight_EDF2.shape == (hf_config.num_local_experts, |
| 150 | + hf_config.hidden_size, |
| 151 | + hf_config.intermediate_size * |
| 152 | + 2) |
| 153 | + |
| 154 | + assert model.final_norm.scale.shape == (hf_config.hidden_size, ) |
| 155 | + assert model.lm_head.input_embedding_table_DV.shape == ( |
| 156 | + hf_config.hidden_size, hf_config.vocab_size) |
| 157 | + |
| 158 | + # Test model load |
| 159 | + with patch("tpu_inference.models.jax.gpt_oss.model_weights_generator", |
| 160 | + return_value=iter([])): |
| 161 | + model.load_weights(rng) |
| 162 | + |
| 163 | + # Test model forward |
| 164 | + num_key_value_heads = int(hf_config.num_key_value_heads / 2) |
| 165 | + kv_caches = create_kv_caches( |
| 166 | + num_blocks=4, |
| 167 | + block_size=32, |
| 168 | + num_kv_heads=num_key_value_heads, |
| 169 | + head_size=hf_config.head_dim, |
| 170 | + mesh=mesh, |
| 171 | + layer_names=["layer"] * hf_config.num_hidden_layers, |
| 172 | + cache_dtype=jnp.float8_e4m3fn |
| 173 | + if mock_vllm_config.cache_config.cache_dtype == "fp8" else |
| 174 | + jnp.bfloat16) |
| 175 | + |
| 176 | + input_ids, attention_metadata, indices_do_sample = mock_model_inputs |
| 177 | + |
| 178 | + kv_caches, hidden_states, aux_hidden_states = model( |
| 179 | + kv_caches, input_ids, attention_metadata) |
| 180 | + |
| 181 | + # Check output shapes |
| 182 | + assert hidden_states.shape == (8, hf_config.hidden_size) |
| 183 | + assert aux_hidden_states == [] |
| 184 | + |
| 185 | + # Test logits computation |
| 186 | + hidden_states = hidden_states[indices_do_sample, :] |
| 187 | + assert hidden_states.shape == (1, hf_config.hidden_size) |
| 188 | + |
| 189 | + logits = model.compute_logits(hidden_states) |
| 190 | + assert logits.shape == (1, hf_config.vocab_size) |
0 commit comments