Skip to content

Commit 92b7f11

Browse files
committed
Fix mem2reg bug
1 parent 4fc9c7a commit 92b7f11

File tree

13 files changed

+454
-100
lines changed

13 files changed

+454
-100
lines changed

include/polygeist/BarrierUtils.h

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#ifndef MLIR_LIB_DIALECT_SCF_TRANSFORMS_BARRIERUTILS_H_
1010
#define MLIR_LIB_DIALECT_SCF_TRANSFORMS_BARRIERUTILS_H_
1111

12+
#include "mlir/Analysis/DataLayoutAnalysis.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1214
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1315
#include "mlir/Dialect/SCF/SCF.h"
1416
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -42,10 +44,14 @@ emitIterationCounts(mlir::OpBuilder &rewriter, mlir::scf::ParallelOp op) {
4244
return iterationCounts;
4345
}
4446

47+
mlir::LLVM::LLVMFuncOp GetOrCreateMallocFunction(mlir::ModuleOp module);
48+
mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(mlir::ModuleOp module);
49+
4550
template <typename T>
4651
static T allocateTemporaryBuffer(mlir::OpBuilder &rewriter, mlir::Value value,
4752
mlir::ValueRange iterationCounts,
48-
bool alloca = true) {
53+
bool alloca = true,
54+
mlir::DataLayout *DLI = nullptr) {
4955
using namespace mlir;
5056
SmallVector<int64_t> bufferSize(iterationCounts.size(),
5157
ShapedType::kDynamicSize);
@@ -70,4 +76,30 @@ static T allocateTemporaryBuffer(mlir::OpBuilder &rewriter, mlir::Value value,
7076
auto type = MemRefType::get(bufferSize, ty);
7177
return rewriter.create<T>(value.getLoc(), type, iterationCounts);
7278
}
79+
80+
template <>
81+
mlir::LLVM::CallOp allocateTemporaryBuffer<mlir::LLVM::CallOp>(
82+
mlir::OpBuilder &rewriter, mlir::Value value,
83+
mlir::ValueRange iterationCounts, bool alloca, mlir::DataLayout *DLI) {
84+
using namespace mlir;
85+
auto val = value.getDefiningOp<LLVM::AllocaOp>();
86+
auto sz = val.getArraySize();
87+
assert(DLI);
88+
sz = rewriter.create<arith::MulIOp>(
89+
value.getLoc(), sz,
90+
rewriter.create<arith::ConstantIntOp>(
91+
value.getLoc(),
92+
DLI->getTypeSize(
93+
val.getType().cast<LLVM::LLVMPointerType>().getElementType()),
94+
sz.getType().cast<IntegerType>().getWidth()));
95+
for (auto iter : iterationCounts) {
96+
sz =
97+
rewriter.create<arith::MulIOp>(value.getLoc(), sz,
98+
rewriter.create<arith::IndexCastOp>(
99+
value.getLoc(), sz.getType(), iter));
100+
}
101+
auto m = val->getParentOfType<ModuleOp>();
102+
auto allocfn = GetOrCreateMallocFunction(m);
103+
return rewriter.create<LLVM::CallOp>(value.getLoc(), allocfn, sz);
104+
}
73105
#endif // MLIR_LIB_DIALECT_SCF_TRANSFORMS_BARRIERUTILS_H_

lib/polygeist/Ops.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,31 @@ class MetaPointer2Memref final : public OpRewritePattern<Op> {
810810
return failure();
811811

812812
auto mt = src.getType().cast<MemRefType>();
813+
814+
// Fantastic optimization, disabled for now to make a hard debug case easier
815+
// to find.
816+
if (auto before =
817+
src.source().getDefiningOp<polygeist::Memref2PointerOp>()) {
818+
auto mt0 = before.source().getType().cast<MemRefType>();
819+
if (mt0.getElementType() == mt.getElementType()) {
820+
auto sh0 = mt0.getShape();
821+
auto sh = mt.getShape();
822+
if (sh.size() == sh0.size()) {
823+
bool eq = true;
824+
for (size_t i = 1; i < sh.size(); i++) {
825+
if (sh[i] != sh0[i]) {
826+
eq = false;
827+
break;
828+
}
829+
}
830+
if (eq) {
831+
op.memrefMutable().assign(before.source());
832+
return success();
833+
}
834+
}
835+
}
836+
}
837+
813838
for (size_t i = 1; i < mt.getShape().size(); i++)
814839
if (mt.getShape()[i] == ShapedType::kDynamicSize)
815840
return failure();
@@ -905,10 +930,35 @@ void MetaPointer2Memref<AffineStoreOp>::rewrite(
905930
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, op.value(), ptr);
906931
}
907932

933+
// and(x, y) != 0 -> and(x != 0, y != 0)
934+
class CmpAnd final : public OpRewritePattern<arith::CmpIOp> {
935+
public:
936+
using OpRewritePattern<arith::CmpIOp>::OpRewritePattern;
937+
938+
LogicalResult matchAndRewrite(arith::CmpIOp op,
939+
PatternRewriter &rewriter) const override {
940+
auto src = op.getLhs().getDefiningOp<AndIOp>();
941+
if (!src)
942+
return failure();
943+
944+
if (!matchPattern(op.getRhs(), m_Zero()))
945+
return failure();
946+
if (op.getPredicate() != arith::CmpIPredicate::ne)
947+
return failure();
948+
949+
rewriter.replaceOpWithNewOp<arith::AndIOp>(
950+
op,
951+
rewriter.create<arith::CmpIOp>(op.getLoc(), CmpIPredicate::ne,
952+
src.getLhs(), op.getRhs()),
953+
rewriter.create<arith::CmpIOp>(op.getLoc(), CmpIPredicate::ne,
954+
src.getRhs(), op.getRhs()));
955+
return success();
956+
}
957+
};
908958
void Pointer2MemrefOp::getCanonicalizationPatterns(
909959
OwningRewritePatternList &results, MLIRContext *context) {
910960
results.insert<
911-
Pointer2MemrefCast, Pointer2Memref2PointerCast,
961+
CmpAnd, Pointer2MemrefCast, Pointer2Memref2PointerCast,
912962
MetaPointer2Memref<memref::LoadOp>, MetaPointer2Memref<memref::StoreOp>,
913963
MetaPointer2Memref<AffineLoadOp>, MetaPointer2Memref<AffineStoreOp>>(
914964
context);

lib/polygeist/Passes/CanonicalizeFor.cpp

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,36 +1100,67 @@ struct WhileLogicalNegation : public OpRewritePattern<WhileOp> {
11001100

11011101
LogicalResult matchAndRewrite(WhileOp op,
11021102
PatternRewriter &rewriter) const override {
1103-
SmallVector<BlockArgument, 2> origAfterArgs(op.getAfterArguments().begin(),
1104-
op.getAfterArguments().end());
11051103
bool changed = false;
11061104
scf::ConditionOp term =
11071105
cast<scf::ConditionOp>(op.getBefore().front().getTerminator());
1108-
assert(origAfterArgs.size() == op.getResults().size());
1109-
assert(origAfterArgs.size() == term.getArgs().size());
11101106

1111-
if (auto condCmp = term.getCondition().getDefiningOp<CmpIOp>()) {
1112-
for (auto pair :
1113-
llvm::zip(op.getResults(), term.getArgs(), origAfterArgs)) {
1114-
if (!std::get<0>(pair).use_empty()) {
1115-
if (auto termCmp = std::get<1>(pair).getDefiningOp<CmpIOp>()) {
1116-
if (termCmp.getLhs() == condCmp.getLhs() &&
1117-
termCmp.getRhs() == condCmp.getRhs()) {
1118-
// TODO generalize to logical negation of
1119-
if (condCmp.getPredicate() == CmpIPredicate::slt &&
1120-
termCmp.getPredicate() == CmpIPredicate::sge) {
1121-
1122-
rewriter.updateRootInPlace(op, [&] {
1123-
rewriter.setInsertionPoint(op);
1124-
auto truev =
1125-
rewriter.create<ConstantIntOp>(termCmp.getLoc(), true, 1);
1126-
std::get<0>(pair).replaceAllUsesWith(truev);
1127-
});
1128-
changed = true;
1107+
SmallPtrSet<Value, 1> condOps;
1108+
SmallVector<Value> todo = {term.getCondition()};
1109+
while (todo.size()) {
1110+
Value val = todo.back();
1111+
todo.pop_back();
1112+
condOps.insert(val);
1113+
if (auto ao = val.getDefiningOp<AndIOp>()) {
1114+
todo.push_back(ao.getLhs());
1115+
todo.push_back(ao.getRhs());
1116+
}
1117+
}
1118+
1119+
for (auto pair :
1120+
llvm::zip(op.getResults(), term.getArgs(), op.getAfterArguments())) {
1121+
auto termArg = std::get<1>(pair);
1122+
bool afterValue;
1123+
if (condOps.count(termArg)) {
1124+
afterValue = true;
1125+
} else {
1126+
bool found = false;
1127+
if (auto termCmp = termArg.getDefiningOp<arith::CmpIOp>()) {
1128+
for (auto cond : condOps) {
1129+
if (auto condCmp = cond.getDefiningOp<CmpIOp>()) {
1130+
if (termCmp.getLhs() == condCmp.getLhs() &&
1131+
termCmp.getRhs() == condCmp.getRhs()) {
1132+
// TODO generalize to logical negation of
1133+
if (condCmp.getPredicate() == CmpIPredicate::slt &&
1134+
termCmp.getPredicate() == CmpIPredicate::sge) {
1135+
found = true;
1136+
afterValue = false;
1137+
break;
1138+
}
11291139
}
11301140
}
11311141
}
11321142
}
1143+
if (!found)
1144+
continue;
1145+
}
1146+
1147+
if (!std::get<0>(pair).use_empty()) {
1148+
rewriter.updateRootInPlace(op, [&] {
1149+
rewriter.setInsertionPoint(op);
1150+
auto truev =
1151+
rewriter.create<ConstantIntOp>(op.getLoc(), !afterValue, 1);
1152+
std::get<0>(pair).replaceAllUsesWith(truev);
1153+
});
1154+
changed = true;
1155+
}
1156+
if (!std::get<2>(pair).use_empty()) {
1157+
rewriter.updateRootInPlace(op, [&] {
1158+
rewriter.setInsertionPointToStart(&op.getAfter().front());
1159+
auto truev =
1160+
rewriter.create<ConstantIntOp>(op.getLoc(), afterValue, 1);
1161+
std::get<2>(pair).replaceAllUsesWith(truev);
1162+
});
1163+
changed = true;
11331164
}
11341165
}
11351166

lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ struct SubIndexOpLowering : public ConvertOpToLLVMPattern<SubIndexOp> {
4343
ConversionPatternRewriter &rewriter) const override {
4444
auto loc = subViewOp.getLoc();
4545

46+
if (!subViewOp.source().getType().isa<MemRefType>()) {
47+
llvm::errs() << " func: " << subViewOp->getParentOfType<FuncOp>() << "\n";
48+
llvm::errs() << " sub: " << subViewOp << " - " << subViewOp.source()
49+
<< "\n";
50+
}
4651
auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
4752

4853
auto viewMemRefType = subViewOp.getType().cast<MemRefType>();

0 commit comments

Comments
 (0)