Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/tir/transform/common_subexpr_elim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,18 @@ class CSEPlanner : public StmtExprVisitor {
current_scope_ = saved;
}

/*! \brief AttrStmt: value in parent scope, body in child scope. */
/*!
* \brief AttrStmt: value in parent scope, body in child scope.
*
* For thread_extent attrs, the value is NOT visited. This prevents CSE
* from inserting Bind statements between nested thread_extent AttrStmts,
* which breaks downstream passes (LowerDeviceKernelLaunch drops the
* Binds while the thread_extent value still references CSE variables).
*/
void VisitStmt_(const AttrStmtNode* op) override {
VisitExpr(op->value);
if (op->attr_key != attr::thread_extent) {
VisitExpr(op->value);
}
int saved = current_scope_;
current_scope_ = AllocScope(saved, ffi::GetRef<Stmt>(op));
VisitStmt(op->body);
Expand Down Expand Up @@ -747,6 +756,25 @@ class CSERewriter : public StmtExprMutator {
return visited;
}

/*!
* \brief For thread_extent AttrStmts, skip rewriting the value expression.
*
* The planner does not record expressions in thread_extent values, so
* no CSE variable is introduced on their behalf. The same sub-expression
* may still be extracted because it also occurs in the body, in which
* case its Bind is placed inside the body scope. Rewriting the value would
* then reference a CSE variable that is not defined at the value position,
* breaking downstream passes.
*/
Stmt VisitStmt_(const AttrStmtNode* op) override {
// Only thread_extent attributes are special-cased; every other AttrStmt
// falls through to the default StmtExprMutator handling at the bottom.
if (op->attr_key == attr::thread_extent) {
// Rewrite the body only; op->value is intentionally never visited here.
Stmt body = VisitStmt(op->body);
// Body came back unchanged: return the original node as-is.
if (body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
}
// Body changed: rebuild the AttrStmt, reusing the original node, key and
// (unrewritten) value together with the new body.
return AttrStmt(op->node, op->attr_key, op->value, body);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While this is functionally correct, it's more idiomatic in TVM's mutators to use the CopyOnWrite pattern. This also ensures that you can benefit from in-place updates when possible.

Consider changing this to:

auto n = CopyOnWrite(op);
n->body = std::move(body);
return Stmt(n);

}
return StmtExprMutator::VisitStmt_(op);
}

private:
/*! \brief Plan: stmts to insert before each target. */
InsertBeforeTable insert_before_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,38 @@ def test_let_floordiv_pattern():
assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}"


# =====================================================================
# T22: Thread extent value -- no CSE between launch_thread calls
# Expressions in thread_extent values must not be extracted to Bind
# statements between nested thread_extent AttrStmts. Doing so breaks
# LowerDeviceKernelLaunch which drops the Bind while the thread_extent
# value still references the CSE variable. The value stays inline;
# body occurrences are still CSE'd after all launch_thread calls.
# =====================================================================
def test_thread_extent_no_cse_between_launch():
# Regression test: `y + z` appears as the extent of the second launch_thread
# and twice in the kernel body. The extent must keep the inline expression,
# while the two body uses are still replaced by a single CSE variable.
# Input module: no CSE variables yet.
@tvm.script.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((50,), "int32"), y: T.int32, z: T.int32):
blockIdx_y = T.launch_thread("blockIdx.y", y)  # noqa: F841
blockIdx_x = T.launch_thread("blockIdx.x", y + z)  # noqa: F841
A[0] = y + z
A[1] = y + z

# Expected module: the binding of cse_v1 sits after both launch_thread
# calls, and the blockIdx.x extent still reads `y + z` directly.
@tvm.script.ir_module
class Expected:
@T.prim_func
def main(A: T.Buffer((50,), "int32"), y: T.int32, z: T.int32):
blockIdx_y = T.launch_thread("blockIdx.y", y)  # noqa: F841
blockIdx_x = T.launch_thread("blockIdx.x", y + z)  # noqa: F841
cse_v1: T.int32 = y + z
A[0] = cse_v1
A[1] = cse_v1

# Run the pass on the input and require an exact structural match.
after = tvm.tir.transform.CommonSubexprElim()(Before)
tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
test_basic()
test_if_single_branch()
Expand All @@ -735,3 +767,4 @@ def test_let_floordiv_pattern():
test_let_value_cse()
test_nested_let_no_extraction()
test_let_floordiv_pattern()
test_thread_extent_no_cse_between_launch()
Loading