Skip to content

Commit 7f7be82

Browse files
authored
Consolidate quant method names into a single file (#1101)
1 parent d1422d2 commit 7f7be82

File tree

6 files changed

+38
-10
lines changed

6 files changed

+38
-10
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
# Canonical base names for the quantization methods supported on TPU.
# These are shared by the per-method config classes so the string values
# are defined in exactly one place.
UNQUANTIZED = "unquantized"
MXFP4 = "mxfp4"
AWQ = "awq"
COMPRESSED_TENSORS = "compressed-tensors"
5+
6+
7+
def get_tpu_quant_method(quant_method: str) -> str:
    """Return the TPU-registered variant of a base quantization method name.

    For example, ``"awq"`` becomes ``"tpu-awq"`` — the name under which the
    corresponding config class is registered.
    """
    return f"tpu-{quant_method}"

tpu_inference/layers/vllm/quantization/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from vllm.model_executor.layers.quantization.base_config import \
66
QuantizationConfig
77

8+
from tpu_inference.layers.common import quant_methods
89
from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
910
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
1011
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
@@ -20,9 +21,9 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
2021
# TODO(kyuyeunk): Add support for "tpu_int8".
2122
method_to_config: dict[str, str] = {
2223
None: VllmUnquantizedConfig,
23-
"compressed-tensors": VllmCompressedTensorsConfig,
24-
"awq": VllmAWQConfig,
25-
"mxfp4": VllmMxfp4Config,
24+
quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
25+
quant_methods.AWQ: VllmAWQConfig,
26+
quant_methods.MXFP4: VllmMxfp4Config,
2627
}
2728
if model_config.quantization not in method_to_config:
2829
raise NotImplementedError(
@@ -32,7 +33,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
3233
assert issubclass(quant_config, JaxCommonConfig)
3334
quant_config.set_configs(vllm_config, mesh)
3435

35-
# TODO(kyuyeunk): Create more programmatic way to handle this.
36-
model_config.quantization = "tpu-" + quant_config.get_name()
36+
model_config.quantization = quant_methods.get_tpu_quant_method(
37+
quant_config.get_name())
3738
return VllmConfig.get_quantization_config(model_config,
3839
vllm_config.load_config)

tpu_inference/layers/vllm/quantization/awq.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
is_layer_skipped, unpack_quantized_values_into_int32)
1919
from vllm.scalar_type import scalar_types
2020

21+
from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
2122
from tpu_inference.layers.vllm.linear_common import (
2223
slice_sharded_tensor_for_concatenation, torch_to_jax_param)
2324
from tpu_inference.layers.vllm.quantization.common import (
@@ -29,9 +30,13 @@
2930
logger = init_logger(__name__)
3031

3132

32-
@register_quantization_config("tpu-awq")
33+
@register_quantization_config(get_tpu_quant_method(AWQ))
3334
class VllmAWQConfig(AWQConfig, JaxCommonConfig):
3435

    @classmethod
    def get_name(cls) -> str:
        """Return the base quantization method name ("awq")."""
        return AWQ
39+
3540
def get_supported_act_dtypes(self) -> list[torch.dtype]:
3641
# NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
3742
bfloat16 is significantly preferred over float16. This might lead to

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
1717
find_matched_target, should_ignore_layer)
1818

19+
from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
20+
get_tpu_quant_method)
1921
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
2022
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
2123
VllmCompressedTensorsW8A8Fp8MoEMethod
@@ -30,9 +32,13 @@
3032
logger = init_logger(__name__)
3133

3234

33-
@register_quantization_config("tpu-compressed-tensors")
35+
@register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
3436
class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
3537

    @classmethod
    def get_name(cls) -> str:
        """Return the base quantization method name ("compressed-tensors")."""
        return COMPRESSED_TENSORS
41+
3642
def get_scheme(self,
3743
layer: torch.nn.Module,
3844
layer_name: Optional[str] = None

tpu_inference/layers/vllm/quantization/mxfp4.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from vllm.model_executor.layers.quantization.utils.quant_utils import \
2525
is_layer_skipped
2626

27+
from tpu_inference.layers.common.quant_methods import (MXFP4,
28+
get_tpu_quant_method)
2729
from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
2830
from tpu_inference.layers.vllm.linear_common import \
2931
reorder_concatenated_tensor_for_sharding
@@ -64,9 +66,13 @@ def dequantize_block_weight(weight: jax.Array,
6466
return weight_dequantized.reshape(orig_shape).astype(out_dtype)
6567

6668

67-
@register_quantization_config("tpu-mxfp4")
69+
@register_quantization_config(get_tpu_quant_method(MXFP4))
6870
class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
6971

    @classmethod
    def get_name(cls) -> str:
        """Return the base quantization method name ("mxfp4")."""
        return MXFP4
75+
7076
def get_quant_method(self, layer: torch.nn.Module,
7177
prefix: str) -> Optional["QuantizeMethodBase"]:
7278
from vllm.attention.layer import Attention # Avoid circular import

tpu_inference/layers/vllm/quantization/unquantized.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
from tpu_inference import envs
2525
from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
26+
from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
27+
get_tpu_quant_method)
2628
from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
2729
from tpu_inference.layers.vllm.linear_common import (
2830
reorder_concatenated_tensor_for_sharding,
@@ -34,12 +36,12 @@
3436
logger = init_logger(__name__)
3537

3638

37-
@register_quantization_config("tpu-unquantized")
39+
@register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
3840
class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
3941

    @classmethod
    def get_name(cls) -> str:
        """Return the base quantization method name ("unquantized")."""
        return UNQUANTIZED
4345

4446
@classmethod
4547
def get_supported_act_dtypes(cls) -> list[torch.dtype]:

0 commit comments

Comments
 (0)