@@ -238,9 +238,13 @@ struct NormalizeLoop : public OpRewritePattern<scf::ForOp> {
238238
239239 Value difference = rewriter.create <SubIOp>(op.getLoc (), op.getUpperBound (),
240240 op.getLowerBound ());
241- Value tripCount = rewriter.create <AddIOp>(op.getLoc (), rewriter.create <DivUIOp>(op.getLoc (),
242- rewriter.create <SubIOp>(op.getLoc (), difference, one), op.getStep ()), one);
243- // rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());
241+ Value tripCount = rewriter.create <AddIOp>(
242+ op.getLoc (),
243+ rewriter.create <DivUIOp>(
244+ op.getLoc (), rewriter.create <SubIOp>(op.getLoc (), difference, one),
245+ op.getStep ()),
246+ one);
247+ // rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());
244248 auto newForOp =
245249 rewriter.create <scf::ForOp>(op.getLoc (), zero, tripCount, one);
246250 rewriter.setInsertionPointToStart (newForOp.getBody ());
@@ -455,34 +459,38 @@ static void moveBodies(PatternRewriter &rewriter, scf::ParallelOp op,
455459 scf::IfOp ifOp, scf::IfOp newIf) {
456460 rewriter.startRootUpdate (op);
457461 {
458- OpBuilder::InsertionGuard guard (rewriter);
459- rewriter.setInsertionPointToStart (newIf.thenBlock ());
460- auto newParallel = rewriter.create <scf::ParallelOp>(
461- op.getLoc (), op.getLowerBound (), op.getUpperBound (), op.getStep ());
462+ OpBuilder::InsertionGuard guard (rewriter);
463+ rewriter.setInsertionPointToStart (newIf.thenBlock ());
464+ auto newParallel = rewriter.create <scf::ParallelOp>(
465+ op.getLoc (), op.getLowerBound (), op.getUpperBound (), op.getStep ());
462466
463- for (auto tup : llvm::zip (newParallel.getInductionVars (), op.getInductionVars ())) {
467+ for (auto tup :
468+ llvm::zip (newParallel.getInductionVars (), op.getInductionVars ())) {
464469 std::get<1 >(tup).replaceUsesWithIf (std::get<0 >(tup), [&](OpOperand &op) {
465- return ifOp.getThenRegion ().isAncestor (op.getOwner ()->getParentRegion ());
470+ return ifOp.getThenRegion ().isAncestor (
471+ op.getOwner ()->getParentRegion ());
466472 });
467- }
473+ }
468474
469- rewriter.mergeBlockBefore (ifOp.thenBlock (), &newParallel.getBody ()->back ());
470- rewriter.eraseOp (&newParallel.getBody ()->back ());
475+ rewriter.mergeBlockBefore (ifOp.thenBlock (), &newParallel.getBody ()->back ());
476+ rewriter.eraseOp (&newParallel.getBody ()->back ());
471477 }
472478
473479 if (ifOp.getElseRegion ().getBlocks ().size () > 0 ) {
474- OpBuilder::InsertionGuard guard (rewriter);
475- rewriter.setInsertionPointToStart (newIf.elseBlock ());
476- auto newParallel = rewriter.create <scf::ParallelOp>(
477- op.getLoc (), op.getLowerBound (), op.getUpperBound (), op.getStep ());
480+ OpBuilder::InsertionGuard guard (rewriter);
481+ rewriter.setInsertionPointToStart (newIf.elseBlock ());
482+ auto newParallel = rewriter.create <scf::ParallelOp>(
483+ op.getLoc (), op.getLowerBound (), op.getUpperBound (), op.getStep ());
478484
479- for (auto tup : llvm::zip (newParallel.getInductionVars (), op.getInductionVars ())) {
485+ for (auto tup :
486+ llvm::zip (newParallel.getInductionVars (), op.getInductionVars ())) {
480487 std::get<1 >(tup).replaceUsesWithIf (std::get<0 >(tup), [&](OpOperand &op) {
481- return ifOp.getElseRegion ().isAncestor (op.getOwner ()->getParentRegion ());
488+ return ifOp.getElseRegion ().isAncestor (
489+ op.getOwner ()->getParentRegion ());
482490 });
483- }
484- rewriter.mergeBlockBefore (ifOp.elseBlock (), &newParallel.getBody ()->back ());
485- rewriter.eraseOp (&newParallel.getBody ()->back ());
491+ }
492+ rewriter.mergeBlockBefore (ifOp.elseBlock (), &newParallel.getBody ()->back ());
493+ rewriter.eraseOp (&newParallel.getBody ()->back ());
486494 }
487495
488496 rewriter.eraseOp (ifOp);
@@ -518,8 +526,9 @@ struct InterchangeIfPFor : public OpRewritePattern<scf::ParallelOp> {
518526 return failure ();
519527 }
520528
521- auto newIf =
522- rewriter.create <scf::IfOp>(ifOp.getLoc (), TypeRange (), ifOp.getCondition (), ifOp.getElseRegion ().getBlocks ().size () > 0 );
529+ auto newIf = rewriter.create <scf::IfOp>(
530+ ifOp.getLoc (), TypeRange (), ifOp.getCondition (),
531+ ifOp.getElseRegion ().getBlocks ().size () > 0 );
523532 moveBodies (rewriter, op, ifOp, newIf);
524533 return success ();
525534 }
@@ -563,9 +572,10 @@ struct InterchangeIfPForLoad : public OpRewritePattern<scf::ParallelOp> {
563572 Value condition = rewriter.create <memref::LoadOp>(
564573 loadOp.getLoc (), loadOp.getMemRef (),
565574 SmallVector<Value>(loadOp.getMemRefType ().getRank (), zero));
566-
575+
567576 auto newIf =
568- rewriter.create <scf::IfOp>(ifOp.getLoc (), TypeRange (), condition, ifOp.getElseRegion ().getBlocks ().size () > 0 );
577+ rewriter.create <scf::IfOp>(ifOp.getLoc (), TypeRange (), condition,
578+ ifOp.getElseRegion ().getBlocks ().size () > 0 );
569579 moveBodies (rewriter, op, ifOp, newIf);
570580 return success ();
571581 }
@@ -1072,9 +1082,11 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
10721082 allocated.reserve (op.getNumIterOperands ());
10731083 for (Value operand : op.getIterOperands ()) {
10741084 Value alloc = rewriter.create <memref::AllocaOp>(
1075- op.getLoc (), MemRefType::get (ArrayRef<int64_t >(), operand.getType ()), ValueRange ());
1085+ op.getLoc (), MemRefType::get (ArrayRef<int64_t >(), operand.getType ()),
1086+ ValueRange ());
10761087 allocated.push_back (alloc);
1077- rewriter.create <memref::StoreOp>(op.getLoc (), operand, alloc, ValueRange ());
1088+ rewriter.create <memref::StoreOp>(op.getLoc (), operand, alloc,
1089+ ValueRange ());
10781090 }
10791091
10801092 auto newOp = rewriter.create <scf::ForOp>(op.getLoc (), op.getLowerBound (),
@@ -1098,7 +1110,8 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
10981110 rewriter.setInsertionPointAfter (op);
10991111 SmallVector<Value> loaded;
11001112 for (Value alloc : allocated) {
1101- loaded.push_back (rewriter.create <memref::LoadOp>(op.getLoc (), alloc, ValueRange ()));
1113+ loaded.push_back (
1114+ rewriter.create <memref::LoadOp>(op.getLoc (), alloc, ValueRange ()));
11021115 }
11031116 rewriter.replaceOp (op, loaded);
11041117 return success ();
@@ -1112,25 +1125,26 @@ struct Reg2MemIf : public OpRewritePattern<scf::IfOp> {
11121125 PatternRewriter &rewriter) const override {
11131126 if (!op.getResults ().size () || !hasNestedBarrier (op))
11141127 return failure ();
1115-
11161128
11171129 SmallVector<Value> allocated;
11181130 allocated.reserve (op.getNumResults ());
11191131 for (Type opType : op.getResultTypes ()) {
11201132 Value alloc = rewriter.create <memref::AllocaOp>(
1121- op.getLoc (), MemRefType::get (ArrayRef<int64_t >(), opType), ValueRange ());
1133+ op.getLoc (), MemRefType::get (ArrayRef<int64_t >(), opType),
1134+ ValueRange ());
11221135 allocated.push_back (alloc);
11231136 }
1124-
1125- auto newOp = rewriter.create <scf::IfOp>(op.getLoc (), TypeRange (), op.getCondition (), true );
1137+
1138+ auto newOp = rewriter.create <scf::IfOp>(op.getLoc (), TypeRange (),
1139+ op.getCondition (), true );
11261140
11271141 rewriter.setInsertionPoint (op.thenYield ());
11281142 for (auto en : llvm::enumerate (op.thenYield ().getOperands ())) {
11291143 rewriter.create <memref::StoreOp>(op.getLoc (), en.value (),
11301144 allocated[en.index ()], ValueRange ());
11311145 }
11321146 op.thenYield ()->setOperands (ValueRange ());
1133-
1147+
11341148 rewriter.setInsertionPoint (op.elseYield ());
11351149 for (auto en : llvm::enumerate (op.elseYield ().getOperands ())) {
11361150 rewriter.create <memref::StoreOp>(op.getLoc (), en.value (),
@@ -1140,14 +1154,15 @@ struct Reg2MemIf : public OpRewritePattern<scf::IfOp> {
11401154
11411155 rewriter.eraseOp (&newOp.thenBlock ()->back ());
11421156 rewriter.mergeBlocks (op.thenBlock (), newOp.thenBlock ());
1143-
1157+
11441158 rewriter.eraseOp (&newOp.elseBlock ()->back ());
11451159 rewriter.mergeBlocks (op.elseBlock (), newOp.elseBlock ());
11461160
11471161 rewriter.setInsertionPointAfter (op);
11481162 SmallVector<Value> loaded;
11491163 for (Value alloc : allocated) {
1150- loaded.push_back (rewriter.create <memref::LoadOp>(op.getLoc (), alloc, ValueRange ()));
1164+ loaded.push_back (
1165+ rewriter.create <memref::LoadOp>(op.getLoc (), alloc, ValueRange ()));
11511166 }
11521167 rewriter.replaceOp (op, loaded);
11531168 return success ();
@@ -1157,16 +1172,19 @@ struct Reg2MemIf : public OpRewritePattern<scf::IfOp> {
11571172static void storeValues (Location loc, ValueRange values, ValueRange pointers,
11581173 PatternRewriter &rewriter) {
11591174 for (auto pair : llvm::zip (values, pointers)) {
1160- rewriter.create <memref::StoreOp>(loc, std::get<0 >(pair), std::get<1 >(pair), ValueRange ());
1175+ rewriter.create <memref::StoreOp>(loc, std::get<0 >(pair), std::get<1 >(pair),
1176+ ValueRange ());
11611177 }
11621178}
11631179
1164- static void allocaValues (Location loc, ValueRange values, PatternRewriter &rewriter,
1180+ static void allocaValues (Location loc, ValueRange values,
1181+ PatternRewriter &rewriter,
11651182 SmallVector<Value> &allocated) {
11661183 allocated.reserve (values.size ());
11671184 for (Value value : values) {
11681185 Value alloc = rewriter.create <memref::AllocaOp>(
1169- loc, MemRefType::get (ArrayRef<int64_t >(), value.getType ()), ValueRange ());
1186+ loc, MemRefType::get (ArrayRef<int64_t >(), value.getType ()),
1187+ ValueRange ());
11701188 allocated.push_back (alloc);
11711189 }
11721190}
@@ -1184,8 +1202,7 @@ struct Reg2MemWhile : public OpRewritePattern<scf::WhileOp> {
11841202 // Value stackPtr = rewriter.create<LLVM::StackSaveOp>(
11851203 // op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getIntegerType(8)));
11861204 SmallVector<Value> beforeAllocated, afterAllocated;
1187- allocaValues (op.getLoc (), op.getOperands (), rewriter,
1188- beforeAllocated);
1205+ allocaValues (op.getLoc (), op.getOperands (), rewriter, beforeAllocated);
11891206 storeValues (op.getLoc (), op.getOperands (), beforeAllocated, rewriter);
11901207 allocaValues (op.getLoc (), op.getResults (), rewriter, afterAllocated);
11911208
@@ -1194,8 +1211,7 @@ struct Reg2MemWhile : public OpRewritePattern<scf::WhileOp> {
11941211 Block *newBefore =
11951212 rewriter.createBlock (&newOp.getBefore (), newOp.getBefore ().begin ());
11961213 SmallVector<Value> newBeforeArguments;
1197- loadValues (op.getLoc (), beforeAllocated, rewriter,
1198- newBeforeArguments);
1214+ loadValues (op.getLoc (), beforeAllocated, rewriter, newBeforeArguments);
11991215 rewriter.mergeBlocks (&op.getBefore ().front (), newBefore,
12001216 newBeforeArguments);
12011217
@@ -1240,12 +1256,10 @@ struct CPUifyPass : public SCFCPUifyBase<CPUifyPass> {
12401256 OwningRewritePatternList patterns (&getContext ());
12411257 patterns
12421258 .insert <Reg2MemFor, Reg2MemWhile, Reg2MemIf,
1243- // ReplaceIfWithFors,
1244- WrapForWithBarrier, WrapIfWithBarrier,
1245- WrapWhileWithBarrier,
1246- InterchangeForPFor, InterchangeForPForLoad,
1247- InterchangeIfPFor, InterchangeIfPForLoad,
1248- InterchangeWhilePFor, NormalizeLoop,
1259+ // ReplaceIfWithFors,
1260+ WrapForWithBarrier, WrapIfWithBarrier, WrapWhileWithBarrier,
1261+ InterchangeForPFor, InterchangeForPForLoad, InterchangeIfPFor,
1262+ InterchangeIfPForLoad, InterchangeWhilePFor, NormalizeLoop,
12491263 NormalizeParallel, RotateWhile, DistributeAroundBarrier>(
12501264 &getContext ());
12511265 GreedyRewriteConfig config;
0 commit comments