55from vllm .model_executor .layers .quantization .base_config import \
66 QuantizationConfig
77
8+ from tpu_inference .layers .common import quant_methods
89from tpu_inference .layers .vllm .quantization .awq import VllmAWQConfig
910from tpu_inference .layers .vllm .quantization .common import JaxCommonConfig
1011from tpu_inference .layers .vllm .quantization .compressed_tensors .compressed_tensors import \
@@ -20,9 +21,9 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
2021 # TODO(kyuyeunk): Add support for "tpu_int8".
2122 method_to_config : dict [str , str ] = {
2223 None : VllmUnquantizedConfig ,
23- "compressed-tensors" : VllmCompressedTensorsConfig ,
24- "awq" : VllmAWQConfig ,
25- "mxfp4" : VllmMxfp4Config ,
24+ quant_methods . COMPRESSED_TENSORS : VllmCompressedTensorsConfig ,
25+ quant_methods . AWQ : VllmAWQConfig ,
26+ quant_methods . MXFP4 : VllmMxfp4Config ,
2627 }
2728 if model_config .quantization not in method_to_config :
2829 raise NotImplementedError (
@@ -32,7 +33,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
3233 assert issubclass (quant_config , JaxCommonConfig )
3334 quant_config .set_configs (vllm_config , mesh )
3435
35- # TODO(kyuyeunk): Create more programmatic way to handle this.
36- model_config . quantization = "tpu-" + quant_config .get_name ()
36+ model_config . quantization = quant_methods . get_tpu_quant_method (
37+ quant_config .get_name () )
3738 return VllmConfig .get_quantization_config (model_config ,
3839 vllm_config .load_config )