|
31 | 31 | from helion._testing import import_path |
32 | 32 | from helion._testing import skipIfCpu |
33 | 33 | from helion._testing import skipIfRocm |
| 34 | +from helion.autotuner import DESurrogateHybrid |
34 | 35 | from helion.autotuner import DifferentialEvolutionSearch |
35 | 36 | from helion.autotuner import PatternSearch |
36 | 37 | from helion.autotuner.base_search import BaseSearch |
@@ -381,6 +382,21 @@ def test_differential_evolution_search(self): |
381 | 382 | fn = bound_kernel.compile_config(best) |
382 | 383 | torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1) |
383 | 384 |
|
| 385 | + @skipIfRocm("too slow on rocm") |
| 386 | + @skip("too slow") |
| 387 | + def test_de_surrogate_hybrid(self): |
| 388 | + args = ( |
| 389 | + torch.randn([512, 512], device=DEVICE), |
| 390 | + torch.randn([512, 512], device=DEVICE), |
| 391 | + ) |
| 392 | + bound_kernel = examples_matmul.bind(args) |
| 393 | + random.seed(123) |
| 394 | + best = DESurrogateHybrid( |
| 395 | + bound_kernel, args, population_size=5, max_generations=3 |
| 396 | + ).autotune() |
| 397 | + fn = bound_kernel.compile_config(best) |
| 398 | + torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1) |
| 399 | + |
384 | 400 | @skip("too slow") |
385 | 401 | def test_pattern_search(self): |
386 | 402 | args = ( |
|
0 commit comments