1414#include " mlir/Dialect/StandardOps/IR/Ops.h"
1515#include " mlir/IR/Builders.h"
1616#include " mlir/IR/Dominance.h"
17+ #include " mlir/IR/PatternMatch.h"
1718#include " mlir/IR/RegionGraphTraits.h"
1819#include " mlir/Pass/Pass.h"
1920#include " mlir/Transforms/Passes.h"
@@ -338,12 +339,8 @@ bool LoopRestructure::removeIfFromRegion(DominanceInfo &domInfo, Region ®ion,
338339 oldTerm->getOperands ());
339340 oldTerm->erase ();
340341
341- SmallVector<Value, 4 > res;
342- for (size_t i = 1 ; i < ifOp->getNumResults (); ++i) {
343- res.push_back (ifOp->getResult (i));
344- }
345- builder.create <scf::ConditionOp>(builder.getUnknownLoc (),
346- ifOp->getResult (0 ), res);
342+ builder.create <scf::YieldOp>(builder.getUnknownLoc (),
343+ ifOp->getResults ());
347344 condBr->erase ();
348345
349346 pseudoExit->erase ();
@@ -386,14 +383,12 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
386383 mlir::OpBuilder builder (wrapper, wrapper->begin ());
387384
388385 // Copy the arguments across
389- SmallVector<Type, 4 > headerArgumentTypes;
390- for (auto arg : header->getArguments ()) {
391- headerArgumentTypes.push_back (arg.getType ());
392- }
386+ SmallVector<Type, 4 > headerArgumentTypes (header->getArgumentTypes ());
393387 wrapper->addArguments (headerArgumentTypes);
394388
395- SmallVector<Value> valsCallingLoop (wrapper->getArguments ().begin (),
396- wrapper->getArguments ().end ());
389+ SmallVector<Value> valsCallingLoop;
390+ for (auto a : wrapper->getArguments ())
391+ valsCallingLoop.push_back (a);
397392
398393 SmallVector<std::pair<Value, size_t >> preservedVals;
399394 for (auto B : L->getBlocks ()) {
@@ -415,15 +410,9 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
415410 }
416411 }
417412
418- // TODO values used outside loop should be wrapped.
419-
420- SmallVector<Type, 4 > combinedTypes (headerArgumentTypes.begin (),
421- headerArgumentTypes.end ());
422- SmallVector<Type, 4 > returns;
423- for (auto arg : target->getArguments ()) {
424- returns.push_back (arg.getType ());
425- combinedTypes.push_back (arg.getType ());
426- }
413+ SmallVector<Type, 4 > combinedTypes = headerArgumentTypes;
414+ SmallVector<Type, 4 > returns (target->getArgumentTypes ());
415+ combinedTypes.append (returns);
427416
428417 auto loop = builder.create <mlir::scf::WhileOp>(
429418 builder.getUnknownLoc (), combinedTypes, valsCallingLoop);
@@ -451,28 +440,45 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
451440 Preds.push_back (block);
452441 }
453442
454- loop.getBefore ().getBlocks ().splice (loop.getBefore ().getBlocks ().begin (),
455- region.getBlocks (), header);
443+ Block *loopEntry = new Block ();
444+ loop.getBefore ().push_back (loopEntry);
445+ builder.setInsertionPointToEnd (loopEntry);
446+ SmallVector<Type, 4 > tys = {builder.getI1Type ()};
447+ for (auto t : combinedTypes)
448+ tys.push_back (t);
449+ auto exec =
450+ builder.create <scf::ExecuteRegionOp>(builder.getUnknownLoc (), tys);
451+
452+ {
453+ SmallVector<Value> yields;
454+ for (auto a : exec.getResults ())
455+ yields.push_back (a);
456+ yields.erase (yields.begin ());
457+ builder.create <scf::ConditionOp>(builder.getUnknownLoc (),
458+ exec.getResult (0 ), yields);
459+ }
460+
461+ Region &insertRegion = exec.getRegion ();
462+
463+ insertRegion.getBlocks ().splice (insertRegion.getBlocks ().begin (),
464+ region.getBlocks (), header);
465+ assert (header->getParent () == &insertRegion);
456466 for (auto *w : L->getBlocks ()) {
457467 Block *b = &**w;
458468 if (b != header) {
459- loop. getBefore ().getBlocks ().splice (
460- loop. getBefore (). getBlocks (). end (), region.getBlocks (), b);
469+ insertRegion. getBlocks ().splice (insertRegion. getBlocks ().end (),
470+ region.getBlocks (), b);
461471 }
462472 }
463473
464474 Block *pseudoExit = new Block ();
465- auto i1Ty = builder.getI1Type ();
466475 {
467- loop.getBefore ().push_back (pseudoExit);
468- SmallVector<Type, 4 > tys = {i1Ty};
469- for (auto t : combinedTypes)
470- tys.push_back (t);
476+ insertRegion.push_back (pseudoExit);
471477 pseudoExit->addArguments (tys);
472478 OpBuilder builder (pseudoExit, pseudoExit->begin ());
473479 tys.clear ();
474- builder.create <scf::ConditionOp >(builder.getUnknownLoc (), tys,
475- pseudoExit->getArguments ());
480+ builder.create <scf::YieldOp >(builder.getUnknownLoc (), tys,
481+ pseudoExit->getArguments ());
476482 }
477483
478484 for (auto *w : exitingBlocks) {
@@ -486,7 +492,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
486492 auto vfalse = builder.create <arith::ConstantIntOp>(
487493 builder.getUnknownLoc (), false , 1 );
488494
489- std::vector <Value> args = {vfalse};
495+ SmallVector <Value> args = {vfalse};
490496 for (auto arg : header->getArguments ())
491497 args.push_back (arg);
492498 for (auto v : preservedVals)
@@ -541,11 +547,10 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
541547 builder.getUnknownLoc (), true , 1 );
542548
543549 if (auto op = dyn_cast<BranchOp>(terminator)) {
544- std::vector<Value> args (op.getOperands ().begin (),
545- op.getOperands ().end ());
550+ SmallVector<Value> args (op.getOperands ());
546551 args.insert (args.begin (), vtrue);
547- for (auto pair : preservedVals)
548- args.push_back (pair .first );
552+ for (auto p : preservedVals)
553+ args.push_back (p .first );
549554 for (auto ty : returns) {
550555 args.push_back (builder.create <mlir::LLVM::UndefOp>(
551556 builder.getUnknownLoc (), ty));
@@ -618,41 +623,38 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region ®ion) {
618623 });
619624 }
620625
626+ for (auto pair :
627+ llvm::zip (header->getArguments (),
628+ loopEntry->addArguments (header->getArgumentTypes ()))) {
629+ std::get<0 >(pair).replaceAllUsesWith (std::get<1 >(pair));
630+ }
631+ header->eraseArguments ([](BlockArgument) { return true ; });
632+
621633 builder2.create <scf::YieldOp>(builder.getUnknownLoc (), yieldargs);
622- domInfo.invalidate (&loop.getBefore ());
623- runOnRegion (domInfo, loop.getBefore ());
624- if (!removeIfFromRegion (domInfo, loop.getBefore (), pseudoExit)) {
634+ domInfo.invalidate (&insertRegion);
635+
636+ assert (header->getParent () == &insertRegion);
637+
638+ runOnRegion (domInfo, insertRegion);
639+
640+ if (!removeIfFromRegion (domInfo, insertRegion, pseudoExit)) {
625641 attemptToFoldIntoPredecessor (pseudoExit);
626642 }
627643
628644 attemptToFoldIntoPredecessor (wrapper);
629645 attemptToFoldIntoPredecessor (target);
630- if (loop.getBefore ().getBlocks ().size () != 1 ) {
631- Block *blk = new Block ();
632- OpBuilder B (loop.getContext ());
633- B.setInsertionPointToEnd (blk);
634- auto cop =
635- cast<scf::ConditionOp>(loop.getBefore ().getBlocks ().back ().back ());
636- auto er = B.create <scf::ExecuteRegionOp>(loop.getLoc (),
637- cop.getOperandTypes ());
638- er.getRegion ().getBlocks ().splice (er.getRegion ().getBlocks ().begin (),
639- loop.getBefore ().getBlocks ());
640- loop.getBefore ().push_back (blk);
641- SmallVector<Value> yields;
642- for (auto a : er.getResults ())
643- yields.push_back (a);
644- yields.erase (yields.begin ());
645- B.create <scf::ConditionOp>(cop.getLoc (), er.getResult (0 ), yields);
646- B.setInsertionPoint (&*cop);
647- for (auto arg : er.getRegion ().front ().getArguments ()) {
648- auto na = blk->addArgument (arg.getType ());
649- arg.replaceAllUsesWith (na);
650- }
651- er.getRegion ().front ().eraseArguments (
652- [](BlockArgument) { return true ; });
653- B.create <scf::YieldOp>(cop.getLoc (), cop.getOperands ());
654- cop.erase ();
646+
647+ if (llvm::hasSingleElement (insertRegion)) {
648+ Block *block = &insertRegion.front ();
649+ IRRewriter B (exec->getContext ());
650+ Operation *terminator = block->getTerminator ();
651+ ValueRange results = terminator->getOperands ();
652+ terminator->erase ();
653+ B.mergeBlockBefore (block, exec);
654+ exec.replaceAllUsesWith (results);
655+ exec.erase ();
655656 }
657+
656658 assert (loop.getBefore ().getBlocks ().size () == 1 );
657659 runOnRegion (domInfo, loop.getAfter ());
658660 assert (loop.getAfter ().getBlocks ().size () == 1 );
0 commit comments