From 74cd12655027a2d4f59de65d13a70de851e6a123 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 18 Mar 2026 13:42:45 -0400 Subject: [PATCH] [TIR][CSE] Fix crash with thread_extent values CSE previously inserted Bind statements between nested thread_extent AttrStmts when the extent value (e.g. 511 // B + 1) also appeared in the body. LowerDeviceKernelLaunch later dropped these Binds while the thread_extent value still referenced CSE variables, causing LLVM host codegen to fail with "cannot find variable cse_v3". Fix: skip recording and rewriting expressions in thread_extent values. Body expressions are still CSE'd normally. --- src/tir/transform/common_subexpr_elim.cc | 32 ++++++++++++++++-- .../test_tir_transform_common_subexpr_elim.py | 33 +++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/src/tir/transform/common_subexpr_elim.cc b/src/tir/transform/common_subexpr_elim.cc index fb4e01a1371c..a1106f7cf922 100644 --- a/src/tir/transform/common_subexpr_elim.cc +++ b/src/tir/transform/common_subexpr_elim.cc @@ -554,9 +554,18 @@ class CSEPlanner : public StmtExprVisitor { current_scope_ = saved; } - /*! \brief AttrStmt: value in parent scope, body in child scope. */ + /*! + * \brief AttrStmt: value in parent scope, body in child scope. + * + * For thread_extent attrs, the value is NOT visited. This prevents CSE + * from inserting Bind statements between nested thread_extent AttrStmts, + * which breaks downstream passes (LowerDeviceKernelLaunch drops the + * Binds while the thread_extent value still references CSE variables). + */ void VisitStmt_(const AttrStmtNode* op) override { - VisitExpr(op->value); + if (op->attr_key != attr::thread_extent) { + VisitExpr(op->value); + } int saved = current_scope_; current_scope_ = AllocScope(saved, ffi::GetRef(op)); VisitStmt(op->body); @@ -747,6 +756,25 @@ class CSERewriter : public StmtExprMutator { return visited; } + /*! + * \brief For thread_extent AttrStmts, skip rewriting the value expression. + * + * The planner does not record expressions in thread_extent values, so + * no CSE variable is defined for them. Rewriting the value could replace + * sub-expressions with CSE variables whose Bind is inside the body scope + * (unreachable from the value position), breaking downstream passes. + */ + Stmt VisitStmt_(const AttrStmtNode* op) override { + if (op->attr_key == attr::thread_extent) { + Stmt body = VisitStmt(op->body); + if (body.same_as(op->body)) { + return ffi::GetRef(op); + } + return AttrStmt(op->node, op->attr_key, op->value, body); + } + return StmtExprMutator::VisitStmt_(op); + } + private: /*! \brief Plan: stmts to insert before each target. */ InsertBeforeTable insert_before_; diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index 237ee8ba2492..5b91e5960534 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -713,6 +713,38 @@ def test_let_floordiv_pattern(): assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}" +# ===================================================================== +# T22: Thread extent value -- no CSE between launch_thread calls +# Expressions in thread_extent values must not be extracted to Bind +# statements between nested thread_extent AttrStmts. Doing so breaks +# LowerDeviceKernelLaunch which drops the Bind while the thread_extent +# value still references the CSE variable. The value stays inline; +# body occurrences are still CSE'd after all launch_thread calls. +# ===================================================================== +def test_thread_extent_no_cse_between_launch(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((50,), "int32"), y: T.int32, z: T.int32): + blockIdx_y = T.launch_thread("blockIdx.y", y) # noqa: F841 + blockIdx_x = T.launch_thread("blockIdx.x", y + z) # noqa: F841 + A[0] = y + z + A[1] = y + z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer((50,), "int32"), y: T.int32, z: T.int32): + blockIdx_y = T.launch_thread("blockIdx.y", y) # noqa: F841 + blockIdx_x = T.launch_thread("blockIdx.x", y + z) # noqa: F841 + cse_v1: T.int32 = y + z + A[0] = cse_v1 + A[1] = cse_v1 + + after = tvm.tir.transform.CommonSubexprElim()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": test_basic() test_if_single_branch() @@ -735,3 +767,4 @@ def test_let_floordiv_pattern(): test_let_value_cse() test_nested_let_no_extraction() test_let_floordiv_pattern() + test_thread_extent_no_cse_between_launch()