|
4 | 4 | import pytest |
5 | 5 | import torch |
6 | 6 |
|
| 7 | +from vllm._aiter_ops import rocm_aiter_ops |
7 | 8 | from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config |
8 | 9 | from vllm.model_executor.custom_op import CustomOp |
9 | 10 | from vllm.model_executor.layers.activation import ( |
|
15 | 16 | dispatch_topk_func, |
16 | 17 | vllm_topk_softmax, |
17 | 18 | ) |
18 | | -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( |
19 | | - is_rocm_aiter_moe_enabled, |
20 | | -) |
21 | 19 | from vllm.model_executor.layers.layernorm import ( |
22 | 20 | RMSNorm, |
23 | 21 | dispatch_rocm_rmsnorm_func, |
@@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str): |
126 | 124 | RMSNorm(1024).enabled() |
127 | 125 |
|
128 | 126 |
|
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_dispatch(use_rocm_aiter: bool):
    """The dispatcher selects AITER's top-k softmax only on ROCm with AITER on.

    Off ROCm (or with AITER disabled) it must fall back to vLLM's native
    ``vllm_topk_softmax`` implementation.
    """
    dispatched = dispatch_topk_func(use_rocm_aiter)

    # Expected kernel mirrors the dispatch condition: AITER kernel iff we are
    # on ROCm and the flag is set; otherwise the vLLM default.
    expected = (
        rocm_aiter_ops.topk_softmax
        if current_platform.is_rocm() and use_rocm_aiter
        else vllm_topk_softmax
    )
    assert dispatched == expected
142 | 137 |
|
143 | 138 |
|
144 | 139 | @pytest.mark.parametrize("add_residual", [True, False]) |
145 | 140 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) |
146 | | -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) |
147 | | -@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) |
| 141 | +@pytest.mark.parametrize("use_rocm_aiter", [True, False]) |
148 | 142 | @pytest.mark.skipif( |
149 | 143 | not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm" |
150 | 144 | ) |
151 | 145 | def test_rms_norm_dispatch( |
152 | | - add_residual: bool, |
153 | | - dtype: torch.dtype, |
154 | | - use_rocm_aiter: str, |
155 | | - use_rocm_aiter_norm: str, |
156 | | - monkeypatch, |
| 146 | + add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool |
157 | 147 | ): |
158 | | - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) |
159 | | - monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) |
160 | | - rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype) |
| 148 | + rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter) |
161 | 149 |
|
162 | 150 | should_use_rocm_aiter = ( |
163 | 151 | current_platform.is_rocm() |
164 | | - and int(use_rocm_aiter) |
165 | | - and int(use_rocm_aiter_norm) |
| 152 | + and use_rocm_aiter |
166 | 153 | and dtype in RMS_NORM_SUPPORTED_DTYPES |
167 | 154 | ) |
168 | 155 |
|
169 | 156 | if add_residual and should_use_rocm_aiter: |
170 | | - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add |
| 157 | + assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add |
171 | 158 | elif should_use_rocm_aiter: |
172 | | - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm |
| 159 | + assert rms_norm_func == rocm_aiter_ops.rms_norm |
173 | 160 | elif add_residual: |
174 | 161 | assert rms_norm_func == fused_add_rms_norm |
175 | 162 | else: |
|
0 commit comments