Commit df4f5a4

[Misc] Fix model dtype not being configured correctly
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent: 7a9a4ae

File tree

3 files changed: 33 additions & 51 deletions


tpu_inference/platforms/tpu_platform.py

Lines changed: 5 additions & 12 deletions
@@ -4,7 +4,6 @@
 
 import jax.numpy as jnp
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
@@ -13,6 +12,7 @@
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 if TYPE_CHECKING:
     from vllm.attention.backends.registry import _Backend
@@ -150,18 +150,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # For mm model preprocessors, it may need the output dtype to be torch.
         # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
+            dtype = vllm_config.model_config.dtype
+            if impl == "vllm":
+                vllm_config.model_config.dtype = to_torch_dtype(dtype)
             else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
-            if impl == "vllm":
-                vllm_config.model_config.dtype = j2t_dtype(
-                    vllm_config.model_config.dtype.dtype)
+                vllm_config.model_config.dtype = to_jax_dtype(dtype)
 
         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
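
This hunk stops overwriting the configured dtype with jnp.bfloat16 and instead routes whatever the config holds through the shared converters: the "vllm" implementation gets a torch.dtype, everything else gets a JAX dtype. A minimal sketch of that dispatch, assuming tpu_inference.utils is importable; resolve_model_dtype and its arguments are hypothetical stand-ins for the inline config handling:

# Sketch only: mirrors the branch added to check_and_update_config.
from tpu_inference.utils import to_jax_dtype, to_torch_dtype

def resolve_model_dtype(model_dtype, impl: str):
    if impl == "vllm":
        # The PyTorch-based model implementation expects a torch.dtype.
        return to_torch_dtype(model_dtype)
    # The JAX-native implementation expects a jnp dtype.
    return to_jax_dtype(model_dtype)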

tpu_inference/runner/tpu_runner.py

Lines changed: 6 additions & 30 deletions
@@ -10,12 +10,11 @@
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import torch
 import vllm.envs as envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t, j2t_dtype
+from torchax.ops.mappings import j2t
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -64,7 +63,7 @@
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_torch_dtype)
 
 logger = init_logger(__name__)
 
@@ -78,17 +77,6 @@
     request_distribution=[0, 0, 0],
 )
 
-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-
 
 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -250,22 +238,10 @@ def __init__(
             self.uses_mrope, self.model_config)
         self.lora_utils = LoraUtils(self)
 
-        cache_config = self.cache_config
-        if cache_config.cache_dtype == "auto":
-            model_dtype = self.dtype
-            if isinstance(model_dtype, str):
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)
 
         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
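
With the module-level TPU_STR_DTYPE_TO_TORCH_DTYPE table removed, the runner resolves the KV-cache dtype in one place: "auto" falls back to the model dtype, and whatever remains (a string, jnp dtype, or torch dtype) is converted by to_torch_dtype. A minimal sketch of that flow, assuming the helper defined in tpu_inference/utils.py (next file in this diff); resolve_kv_cache_dtype is a hypothetical stand-in for the inline code in __init__:

# Sketch only: same flow as the replacement lines above.
from tpu_inference.utils import to_torch_dtype

def resolve_kv_cache_dtype(cache_dtype, model_dtype):
    if cache_dtype == "auto":
        # "auto" means: reuse whatever dtype the model itself runs in.
        cache_dtype = model_dtype
    # to_torch_dtype accepts a string, jnp dtype, or torch dtype.
    return to_torch_dtype(cache_dtype)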

tpu_inference/utils.py

Lines changed: 22 additions & 9 deletions
@@ -8,11 +8,14 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
 from vllm import envs as vllm_envs
 from vllm import utils
 
@@ -26,13 +29,23 @@
 # This is used to translate from a string name for a dtype
 # to formal jax.numpy DType. One use case for this is
 # converting the `--kv_cache_dtype` flag to a dtype.
-TPU_STR_DTYPE_TO_JAX_DTYPE = {
-    "bfloat16": jnp.bfloat16,
-    "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3,
-    "fp8_e5m2": jnp.float8_e5m2,
-    "int8": jnp.int8,
-}
+
+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype):
+    if isinstance(dtype, str):
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype):
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
+
 
 _megacore = False
 logger = init_logger(__name__)
@@ -295,8 +308,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-    str_dtype = str_dtype.lower().strip()
-    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
+    # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+    return to_jax_dtype(str_dtype)
 
 
 def time_function(func):
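
These two helpers replace both of the old string-to-dtype tables: to_jax_dtype normalizes a string, torch.dtype, jnp.dtype, or JAX scalar type (the _ScalarMeta branch covers values such as jnp.bfloat16) to a JAX dtype, and to_torch_dtype runs that normalization and then maps the result back through torchax's j2t_dtype. A hedged usage sketch, assuming torch and torchax are installed alongside tpu_inference; the exact printed representations may differ:

# Sketch only: expected conversions, not a test that ships with this commit.
import jax.numpy as jnp
import torch

from tpu_inference.utils import to_jax_dtype, to_torch_dtype

print(to_jax_dtype(torch.bfloat16))  # expected: bfloat16 (via t2j_dtype)
print(to_jax_dtype(jnp.bfloat16))    # expected: bfloat16 (via the _ScalarMeta branch)
print(to_torch_dtype("float32"))     # expected: torch.float32 (via jnp.dtype then j2t_dtype)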
