
Commit 359da68

[Misc] Fix dtype not being configured correctly
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent eb90df0 commit 359da68

5 files changed: +52 additions, −67 deletions

tests/test_utils.py (1 addition, 2 deletions)

@@ -231,6 +231,5 @@ def test_get_jax_dtype_from_str_dtype():
     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
-    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
-    assert get_jax_dtype_from_str_dtype("auto") is None
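The dropped "auto" assertion reflects the new behavior: get_jax_dtype_from_str_dtype now delegates to to_jax_dtype (see the tpu_inference/utils.py hunk below), which has no special case for "auto", so callers branch on "auto" before converting. A minimal sketch of the updated expectation, assuming pytest is available and the helper is importable from tpu_inference.utils:

import jax.numpy as jnp
import pytest

from tpu_inference.utils import get_jax_dtype_from_str_dtype


def test_fp8_e4m3_maps_to_the_fn_variant():
    # After this commit, the vLLM-style "fp8_e4m3" string resolves to the
    # float8_e4m3fn variant rather than jnp.float8_e4m3.
    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn


def test_auto_is_no_longer_special_cased():
    # "auto" is not a real dtype name, so conversion now raises instead of
    # returning None; the exact exception type comes from jnp.dtype, hence
    # the deliberately broad expectation here.
    with pytest.raises((TypeError, ValueError)):
        get_jax_dtype_from_str_dtype("auto")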

tpu_inference/models/jax/utils/quantization/quantization_utils.py (3 additions, 6 deletions)

@@ -154,12 +154,9 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
     logger.info(f"Memory usage before applying quantization of params: "
                 f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
 
-    # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-    kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
-
-    # Handle the case where kv_cache_dtype is "auto"
-    if kv_cache_jnp_dtype is None:
-        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+    if kv_cache_dtype != "auto":
+        kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
+    else:
         kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
 
     kv_caches = create_kv_caches(
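For illustration, the selection above in isolation. This is a sketch only: pick_kv_cache_jnp_dtype is a hypothetical name, utils is assumed to refer to tpu_inference.utils as in the surrounding file, and DEFAULT_KV_CACHE_DTYPE is assumed to be jnp.bfloat16 purely for this example (the real default is defined in quantization_utils.py).

import jax.numpy as jnp

from tpu_inference import utils

DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16  # assumed value, for this sketch only


def pick_kv_cache_jnp_dtype(kv_cache_dtype: str):
    # Non-"auto" strings ("fp8", "fp8_e4m3", "fp8_e5m2", "bfloat16", ...) are
    # resolved by the new utils.to_jax_dtype helper.
    if kv_cache_dtype != "auto":
        return utils.to_jax_dtype(kv_cache_dtype)
    # "auto" keeps the module-level default KV-cache dtype.
    return DEFAULT_KV_CACHE_DTYPE


print(pick_kv_cache_jnp_dtype("fp8_e4m3"))  # float8_e4m3fn
print(pick_kv_cache_jnp_dtype("auto"))      # the default (bfloat16 in this sketch)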

tpu_inference/platforms/tpu_platform.py (12 additions, 20 deletions)

@@ -5,7 +5,6 @@
 import jax.numpy as jnp
 import torch
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
@@ -14,6 +13,7 @@
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 if TYPE_CHECKING:
     from vllm.attention.backends.registry import _Backend
@@ -28,12 +28,6 @@
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -158,20 +152,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # NOTE(xiang): convert dtype to jnp.dtype
         # NOTE(wenlong): skip this logic for mm model preprocessing
         # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
+        # In order to avoid a PR to vLLM, we postpone the dtype checking during
+        # tpu_worker initialization
         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
-            if impl == "vllm":
-                vllm_config.model_config.dtype = j2t_dtype(
-                    vllm_config.model_config.dtype.dtype)
+            try:
+                dtype = to_jax_dtype(vllm_config.model_config.dtype)
+            except ValueError:
+                logger.warning("The model dtype is not set properly."
+                               "Falling back to jnp.bfloat16")
+                dtype = jnp.bfloat16
+            if impl == "vllm":
+                dtype = to_torch_dtype(dtype)
+            vllm_config.model_config.dtype = dtype
 
         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
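The net effect of this hunk is easier to see in isolation. A sketch under the following assumptions: normalize_model_dtype is a hypothetical helper that mirrors the new in-place logic, and any impl value other than "vllm" stands in for the JAX model implementations.

import jax.numpy as jnp

from tpu_inference.utils import to_jax_dtype, to_torch_dtype


def normalize_model_dtype(configured_dtype, impl):
    # Mirrors the try/except fallback in check_and_update_config above.
    try:
        dtype = to_jax_dtype(configured_dtype)
    except ValueError:
        dtype = jnp.bfloat16  # fallback when the configured dtype is unusable
    # The torch-native ("vllm") model implementation expects a torch dtype;
    # everything else keeps the JAX dtype.
    if impl == "vllm":
        dtype = to_torch_dtype(dtype)
    return dtype


print(normalize_model_dtype("float32", impl="jax"))    # a JAX float32 dtype
print(normalize_model_dtype("bfloat16", impl="vllm"))  # torch.bfloat16

Compared to the removed _DTYPE table, valid dtype strings it did not list (for example "float16") now resolve through jnp.dtype instead of silently falling back to bfloat16.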

tpu_inference/runner/tpu_runner.py (5 additions, 30 deletions)

@@ -9,12 +9,10 @@
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import torch
 import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -64,7 +62,7 @@
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_torch_dtype)
 
 logger = init_logger(__name__)
 
@@ -78,17 +76,6 @@
     request_distribution=[0, 0, 0],
 )
 
-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-
 
 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -262,22 +249,10 @@ def __init__(
             self.uses_mrope, self.model_config)
         self.lora_utils = LoraUtils(self)
 
-        cache_config = self.cache_config
-        if cache_config.cache_dtype == "auto":
-            model_dtype = self.dtype
-            if isinstance(model_dtype, str):
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)
 
         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
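As above, a standalone sketch of the simplified KV-cache dtype resolution; resolve_kv_cache_dtype is an illustrative name, not a function in the repository, and the printed results are the expected outcomes rather than verified output.

import jax.numpy as jnp

from tpu_inference.utils import to_torch_dtype


def resolve_kv_cache_dtype(cache_dtype, model_dtype):
    # "auto" means: reuse the model dtype for the KV cache.
    if cache_dtype == "auto":
        cache_dtype = model_dtype
    # to_torch_dtype accepts strings, JAX dtypes, and torch dtypes, which is
    # what makes the old isinstance ladder and the local
    # TPU_STR_DTYPE_TO_TORCH_DTYPE table unnecessary.
    return to_torch_dtype(cache_dtype)


print(resolve_kv_cache_dtype("auto", jnp.bfloat16))      # torch.bfloat16
print(resolve_kv_cache_dtype("fp8_e5m2", jnp.bfloat16))  # expected: torch.float8_e5m2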

tpu_inference/utils.py (31 additions, 9 deletions)

@@ -8,11 +8,14 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
 from vllm import envs as vllm_envs
 from vllm import utils
 
@@ -23,17 +26,36 @@
 TPU_HEAD_SIZE_ALIGNMENT = 128
 TPU_SECOND_LAST_MINOR = 8
 
-# This is used to translate from a string name for a dtype
-# to formal jax.numpy DType. One use case for this is
-# converting the `--kv_cache_dtype` flag to a dtype.
-TPU_STR_DTYPE_TO_JAX_DTYPE = {
-    "bfloat16": jnp.bfloat16,
+# Map vllm dtype string that doesn't exactly match jax dtype string name.
+_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
     "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3,
+    "fp8_e4m3": jnp.float8_e4m3fn,
     "fp8_e5m2": jnp.float8_e5m2,
-    "int8": jnp.int8,
 }
 
+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+    if isinstance(dtype, str):
+        if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+            return dict_dtype
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+    else:
+        raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+    # Use jax dtype as an intermediate dtype which we'll be used to convert it
+    # into torch dtype.
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
+
+
 _megacore = False
 logger = init_logger(__name__)
 
@@ -295,8 +317,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-    str_dtype = str_dtype.lower().strip()
-    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
+    # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+    return to_jax_dtype(str_dtype)
 
 
 def time_function(func):
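Taken together, the two helpers give a single conversion path for every dtype spelling used elsewhere in this commit. A usage sketch of the expected behavior, assuming a JAX install with float8 support and torchax available:

import jax.numpy as jnp
import torch

from tpu_inference.utils import to_jax_dtype, to_torch_dtype

# vLLM-style strings that do not match a JAX dtype name go through
# _VLLM_DTYPE_STR_TO_JAX_DTYPE; note that "fp8_e4m3" now resolves to the
# float8_e4m3fn variant.
assert to_jax_dtype("fp8_e4m3") == jnp.float8_e4m3fn

# Plain dtype names fall through to jnp.dtype(...), so entries such as
# "bfloat16" and "int8" no longer need to live in the mapping.
assert to_jax_dtype("bfloat16") == jnp.bfloat16
assert to_jax_dtype("int8") == jnp.int8

# torch dtypes are converted via torchax's t2j_dtype, and jnp scalar types
# (e.g. jnp.bfloat16 itself) are unwrapped through their .dtype attribute.
assert to_jax_dtype(torch.bfloat16) == jnp.bfloat16
assert to_jax_dtype(jnp.bfloat16) == jnp.bfloat16

# to_torch_dtype normalizes to a JAX dtype first, then maps with j2t_dtype.
assert to_torch_dtype("bfloat16") == torch.bfloat16
assert to_torch_dtype(jnp.float32) == torch.float32

Note that "auto" is deliberately not handled here; both call sites in this commit (quantization_utils.py and tpu_runner.py) resolve "auto" before calling the helpers.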
