Skip to content

Commit 2f304a3

Browse files
committed
Fix loop restructure
1 parent be095b6 commit 2f304a3

File tree

8 files changed

+345
-86
lines changed

8 files changed

+345
-86
lines changed

lib/polygeist/Passes/CanonicalizeFor.cpp

Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,13 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
592592
if (!step)
593593
return failure();
594594

595+
// Cannot transform for if step is not loop-invariant
596+
if (auto op = step.getDefiningOp()) {
597+
if (loop->isAncestor(op)) {
598+
return failure();
599+
}
600+
}
601+
595602
bool negativeStep = false;
596603
if (auto cop = step.getDefiningOp<ConstantIntOp>()) {
597604
if (cop.value() < 0) {
@@ -905,6 +912,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
905912
return failure();
906913

907914
SmallVector<std::pair<BlockArgument, Value>, 2> m;
915+
908916
// The return results of the while which are used
909917
SmallVector<Value, 2> prevResults;
910918
// The corresponding value in the before which
@@ -958,7 +966,17 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
958966
}
959967
return failure();
960968
}
961-
m.emplace_back(std::get<2>(pair), thenYielded);
969+
// If the value yielded from then then is defined in the while before
970+
// but not being moved down with the if, don't change anything.
971+
if (!ifOp.getThenRegion().isAncestor(thenYielded.getParentRegion()) &&
972+
op.getBefore().isAncestor(thenYielded.getParentRegion())) {
973+
prevResults.push_back(std::get<0>(pair));
974+
condArgs.push_back(thenYielded);
975+
} else {
976+
// Otherwise, mark the corresponding after argument to be replaced
977+
// with the value yielded in the if statement.
978+
m.emplace_back(std::get<2>(pair), thenYielded);
979+
}
962980
} else {
963981
assert(prevResults.size() == condArgs.size());
964982
prevResults.push_back(std::get<0>(pair));
@@ -974,29 +992,35 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
974992
rewriter.updateRootInPlace(afterYield, [&] {
975993
afterYield.getResultsMutable().assign(yieldArgs);
976994
});
977-
978-
llvm::SetVector<Value> sv;
979-
findValuesUsedBelow(ifOp, sv);
980-
981995
Block *afterB = &op.getAfter().front();
982996

983-
for (auto v : sv) {
984-
condArgs.push_back(v);
985-
auto arg = afterB->addArgument(v.getType());
986-
for (OpOperand &use : llvm::make_early_inc_range(v.getUses())) {
987-
if (ifOp->isAncestor(use.getOwner()) || use.getOwner() == afterYield)
988-
rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(arg); });
997+
{
998+
llvm::SetVector<Value> sv;
999+
findValuesUsedBelow(ifOp, sv);
1000+
1001+
for (auto v : sv) {
1002+
condArgs.push_back(v);
1003+
auto arg = afterB->addArgument(v.getType());
1004+
for (OpOperand &use : llvm::make_early_inc_range(v.getUses())) {
1005+
if (ifOp->isAncestor(use.getOwner()) ||
1006+
use.getOwner() == afterYield)
1007+
rewriter.updateRootInPlace(use.getOwner(),
1008+
[&]() { use.set(arg); });
1009+
}
9891010
}
9901011
}
9911012

9921013
rewriter.setInsertionPoint(term);
9931014
rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
9941015
condArgs);
9951016

1017+
SmallVector<unsigned> indices;
9961018
for (int i = m.size() - 1; i >= 0; i--) {
1019+
assert(m[i].first.getType() == m[i].second.getType());
9971020
m[i].first.replaceAllUsesWith(m[i].second);
998-
afterB->eraseArgument(m[i].first.getArgNumber());
1021+
indices.push_back(m[i].first.getArgNumber());
9991022
}
1023+
afterB->eraseArguments(indices);
10001024

10011025
rewriter.eraseOp(ifOp.thenYield());
10021026
Block *thenB = ifOp.thenBlock();
@@ -1266,7 +1290,10 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
12661290
struct WhileLICM : public OpRewritePattern<WhileOp> {
12671291
using OpRewritePattern<WhileOp>::OpRewritePattern;
12681292
static bool canBeHoisted(Operation *op,
1269-
function_ref<bool(Value)> definedOutside) {
1293+
function_ref<bool(Value)> definedOutside,
1294+
bool isSpeculatable) {
1295+
// TODO consider requirement of isSpeculatable
1296+
12701297
// Check that dependencies are defined outside of loop.
12711298
if (!llvm::all_of(op->getOperands(), definedOutside))
12721299
return false;
@@ -1293,7 +1320,7 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
12931320
for (auto &region : op->getRegions()) {
12941321
for (auto &block : region) {
12951322
for (auto &innerOp : block.without_terminator())
1296-
if (!canBeHoisted(&innerOp, definedOutside))
1323+
if (!canBeHoisted(&innerOp, definedOutside, isSpeculatable))
12971324
return false;
12981325
}
12991326
}
@@ -1311,10 +1338,14 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
13111338
// properties.
13121339
auto isDefinedOutsideOfBody = [&](Value value) {
13131340
auto definingOp = value.getDefiningOp();
1314-
bool definedOutside =
1315-
(definingOp && !!willBeMovedSet.count(definingOp)) ||
1316-
!op.getBefore().isAncestor(value.getParentRegion());
1317-
return definedOutside;
1341+
if (!definingOp) {
1342+
if (auto ba = value.dyn_cast<BlockArgument>())
1343+
definingOp = ba.getOwner()->getParentOp();
1344+
assert(definingOp);
1345+
}
1346+
if (willBeMovedSet.count(definingOp))
1347+
return true;
1348+
return op != definingOp && !op->isAncestor(definingOp);
13181349
};
13191350

13201351
// Do not use walk here, as we do not want to go into nested regions and
@@ -1323,7 +1354,17 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
13231354
// processed.
13241355
for (auto &block : op.getBefore()) {
13251356
for (auto &op : block.without_terminator()) {
1326-
bool legal = canBeHoisted(&op, isDefinedOutsideOfBody);
1357+
bool legal = canBeHoisted(&op, isDefinedOutsideOfBody, false);
1358+
if (legal) {
1359+
opsToMove.push_back(&op);
1360+
willBeMovedSet.insert(&op);
1361+
}
1362+
}
1363+
}
1364+
1365+
for (auto &block : op.getAfter()) {
1366+
for (auto &op : block.without_terminator()) {
1367+
bool legal = canBeHoisted(&op, isDefinedOutsideOfBody, true);
13271368
if (legal) {
13281369
opsToMove.push_back(&op);
13291370
willBeMovedSet.insert(&op);

lib/polygeist/Passes/LoopRestructure.cpp

Lines changed: 67 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1515
#include "mlir/IR/Builders.h"
1616
#include "mlir/IR/Dominance.h"
17+
#include "mlir/IR/PatternMatch.h"
1718
#include "mlir/IR/RegionGraphTraits.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include "mlir/Transforms/Passes.h"
@@ -338,12 +339,8 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region &region,
338339
oldTerm->getOperands());
339340
oldTerm->erase();
340341

341-
SmallVector<Value, 4> res;
342-
for (size_t i = 1; i < ifOp->getNumResults(); ++i) {
343-
res.push_back(ifOp->getResult(i));
344-
}
345-
builder.create<scf::ConditionOp>(builder.getUnknownLoc(),
346-
ifOp->getResult(0), res);
342+
builder.create<scf::YieldOp>(builder.getUnknownLoc(),
343+
ifOp->getResults());
347344
condBr->erase();
348345

349346
pseudoExit->erase();
@@ -386,14 +383,12 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
386383
mlir::OpBuilder builder(wrapper, wrapper->begin());
387384

388385
// Copy the arguments across
389-
SmallVector<Type, 4> headerArgumentTypes;
390-
for (auto arg : header->getArguments()) {
391-
headerArgumentTypes.push_back(arg.getType());
392-
}
386+
SmallVector<Type, 4> headerArgumentTypes(header->getArgumentTypes());
393387
wrapper->addArguments(headerArgumentTypes);
394388

395-
SmallVector<Value> valsCallingLoop(wrapper->getArguments().begin(),
396-
wrapper->getArguments().end());
389+
SmallVector<Value> valsCallingLoop;
390+
for (auto a : wrapper->getArguments())
391+
valsCallingLoop.push_back(a);
397392

398393
SmallVector<std::pair<Value, size_t>> preservedVals;
399394
for (auto B : L->getBlocks()) {
@@ -415,15 +410,9 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
415410
}
416411
}
417412

418-
// TODO values used outside loop should be wrapped.
419-
420-
SmallVector<Type, 4> combinedTypes(headerArgumentTypes.begin(),
421-
headerArgumentTypes.end());
422-
SmallVector<Type, 4> returns;
423-
for (auto arg : target->getArguments()) {
424-
returns.push_back(arg.getType());
425-
combinedTypes.push_back(arg.getType());
426-
}
413+
SmallVector<Type, 4> combinedTypes = headerArgumentTypes;
414+
SmallVector<Type, 4> returns(target->getArgumentTypes());
415+
combinedTypes.append(returns);
427416

428417
auto loop = builder.create<mlir::scf::WhileOp>(
429418
builder.getUnknownLoc(), combinedTypes, valsCallingLoop);
@@ -451,28 +440,45 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
451440
Preds.push_back(block);
452441
}
453442

454-
loop.getBefore().getBlocks().splice(loop.getBefore().getBlocks().begin(),
455-
region.getBlocks(), header);
443+
Block *loopEntry = new Block();
444+
loop.getBefore().push_back(loopEntry);
445+
builder.setInsertionPointToEnd(loopEntry);
446+
SmallVector<Type, 4> tys = {builder.getI1Type()};
447+
for (auto t : combinedTypes)
448+
tys.push_back(t);
449+
auto exec =
450+
builder.create<scf::ExecuteRegionOp>(builder.getUnknownLoc(), tys);
451+
452+
{
453+
SmallVector<Value> yields;
454+
for (auto a : exec.getResults())
455+
yields.push_back(a);
456+
yields.erase(yields.begin());
457+
builder.create<scf::ConditionOp>(builder.getUnknownLoc(),
458+
exec.getResult(0), yields);
459+
}
460+
461+
Region &insertRegion = exec.getRegion();
462+
463+
insertRegion.getBlocks().splice(insertRegion.getBlocks().begin(),
464+
region.getBlocks(), header);
465+
assert(header->getParent() == &insertRegion);
456466
for (auto *w : L->getBlocks()) {
457467
Block *b = &**w;
458468
if (b != header) {
459-
loop.getBefore().getBlocks().splice(
460-
loop.getBefore().getBlocks().end(), region.getBlocks(), b);
469+
insertRegion.getBlocks().splice(insertRegion.getBlocks().end(),
470+
region.getBlocks(), b);
461471
}
462472
}
463473

464474
Block *pseudoExit = new Block();
465-
auto i1Ty = builder.getI1Type();
466475
{
467-
loop.getBefore().push_back(pseudoExit);
468-
SmallVector<Type, 4> tys = {i1Ty};
469-
for (auto t : combinedTypes)
470-
tys.push_back(t);
476+
insertRegion.push_back(pseudoExit);
471477
pseudoExit->addArguments(tys);
472478
OpBuilder builder(pseudoExit, pseudoExit->begin());
473479
tys.clear();
474-
builder.create<scf::ConditionOp>(builder.getUnknownLoc(), tys,
475-
pseudoExit->getArguments());
480+
builder.create<scf::YieldOp>(builder.getUnknownLoc(), tys,
481+
pseudoExit->getArguments());
476482
}
477483

478484
for (auto *w : exitingBlocks) {
@@ -486,7 +492,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
486492
auto vfalse = builder.create<arith::ConstantIntOp>(
487493
builder.getUnknownLoc(), false, 1);
488494

489-
std::vector<Value> args = {vfalse};
495+
SmallVector<Value> args = {vfalse};
490496
for (auto arg : header->getArguments())
491497
args.push_back(arg);
492498
for (auto v : preservedVals)
@@ -541,11 +547,10 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
541547
builder.getUnknownLoc(), true, 1);
542548

543549
if (auto op = dyn_cast<BranchOp>(terminator)) {
544-
std::vector<Value> args(op.getOperands().begin(),
545-
op.getOperands().end());
550+
SmallVector<Value> args(op.getOperands());
546551
args.insert(args.begin(), vtrue);
547-
for (auto pair : preservedVals)
548-
args.push_back(pair.first);
552+
for (auto p : preservedVals)
553+
args.push_back(p.first);
549554
for (auto ty : returns) {
550555
args.push_back(builder.create<mlir::LLVM::UndefOp>(
551556
builder.getUnknownLoc(), ty));
@@ -618,41 +623,38 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
618623
});
619624
}
620625

626+
for (auto pair :
627+
llvm::zip(header->getArguments(),
628+
loopEntry->addArguments(header->getArgumentTypes()))) {
629+
std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
630+
}
631+
header->eraseArguments([](BlockArgument) { return true; });
632+
621633
builder2.create<scf::YieldOp>(builder.getUnknownLoc(), yieldargs);
622-
domInfo.invalidate(&loop.getBefore());
623-
runOnRegion(domInfo, loop.getBefore());
624-
if (!removeIfFromRegion(domInfo, loop.getBefore(), pseudoExit)) {
634+
domInfo.invalidate(&insertRegion);
635+
636+
assert(header->getParent() == &insertRegion);
637+
638+
runOnRegion(domInfo, insertRegion);
639+
640+
if (!removeIfFromRegion(domInfo, insertRegion, pseudoExit)) {
625641
attemptToFoldIntoPredecessor(pseudoExit);
626642
}
627643

628644
attemptToFoldIntoPredecessor(wrapper);
629645
attemptToFoldIntoPredecessor(target);
630-
if (loop.getBefore().getBlocks().size() != 1) {
631-
Block *blk = new Block();
632-
OpBuilder B(loop.getContext());
633-
B.setInsertionPointToEnd(blk);
634-
auto cop =
635-
cast<scf::ConditionOp>(loop.getBefore().getBlocks().back().back());
636-
auto er = B.create<scf::ExecuteRegionOp>(loop.getLoc(),
637-
cop.getOperandTypes());
638-
er.getRegion().getBlocks().splice(er.getRegion().getBlocks().begin(),
639-
loop.getBefore().getBlocks());
640-
loop.getBefore().push_back(blk);
641-
SmallVector<Value> yields;
642-
for (auto a : er.getResults())
643-
yields.push_back(a);
644-
yields.erase(yields.begin());
645-
B.create<scf::ConditionOp>(cop.getLoc(), er.getResult(0), yields);
646-
B.setInsertionPoint(&*cop);
647-
for (auto arg : er.getRegion().front().getArguments()) {
648-
auto na = blk->addArgument(arg.getType());
649-
arg.replaceAllUsesWith(na);
650-
}
651-
er.getRegion().front().eraseArguments(
652-
[](BlockArgument) { return true; });
653-
B.create<scf::YieldOp>(cop.getLoc(), cop.getOperands());
654-
cop.erase();
646+
647+
if (llvm::hasSingleElement(insertRegion)) {
648+
Block *block = &insertRegion.front();
649+
IRRewriter B(exec->getContext());
650+
Operation *terminator = block->getTerminator();
651+
ValueRange results = terminator->getOperands();
652+
terminator->erase();
653+
B.mergeBlockBefore(block, exec);
654+
exec.replaceAllUsesWith(results);
655+
exec.erase();
655656
}
657+
656658
assert(loop.getBefore().getBlocks().size() == 1);
657659
runOnRegion(domInfo, loop.getAfter());
658660
assert(loop.getAfter().getBlocks().size() == 1);

lib/polygeist/Passes/Mem2Reg.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,17 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
579579
}
580580

581581
if (thenVal != nullptr && elseVal != nullptr) {
582+
if (ifOp.getElseRegion().getBlocks().size()) {
583+
for (auto tup : llvm::reverse(llvm::zip(
584+
ifOp.getResults(), ifOp.thenYield().getOperands(),
585+
ifOp.elseYield().getOperands()))) {
586+
if (std::get<1>(tup) == thenVal &&
587+
std::get<2>(tup) == elseVal) {
588+
lastVal = std::get<0>(tup);
589+
continue;
590+
}
591+
}
592+
}
582593
OpBuilder B(ifOp.getContext());
583594
B.setInsertionPoint(ifOp);
584595
SmallVector<mlir::Type, 4> tys(ifOp.getResultTypes().begin(),

lib/polygeist/Passes/ParallelLoopDistribute.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,8 @@ struct DistributeAroundBarrier : public OpRewritePattern<scf::ParallelOp> {
754754
u.set(buf);
755755
}
756756
} else if (auto ao = v.getDefiningOp<LLVM::AllocaOp>()) {
757+
llvm::errs() << ao->getParentOfType<FuncOp>() << "\n";
758+
llvm::errs() << ao << "\n";
757759
llvm_unreachable("split around llvm alloca unhandled\n");
758760
} else
759761
rewriter.create<memref::StoreOp>(v.getLoc(), v, alloc,

0 commit comments

Comments
 (0)