Skip to content

Commit be095b6

Browse files
committed
Extend OpenMP support
1 parent cdf36f5 commit be095b6

File tree

7 files changed

+315
-46
lines changed

7 files changed

+315
-46
lines changed

llvm-project

tools/mlir-clang/Lib/CGStmt.cc

Lines changed: 236 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "clang-mlir.h"
1111
#include "mlir/IR/Diagnostics.h"
1212
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
13+
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
1314
#include <mlir/Dialect/SCF/SCF.h>
1415

1516
#define DEBUG_TYPE "CGStmt"
@@ -283,67 +284,249 @@ ValueCategory MLIRScanner::VisitForStmt(clang::ForStmt *fors) {
283284
return nullptr;
284285
}
285286

287+
/// Emits IR for `#pragma omp single`.
///
/// The directive is emitted as an omp.barrier, an omp.master region, and a
/// second omp.barrier, in that order. The statement captured by the directive
/// is emitted into an scf.execute_region nested inside the master region, and
/// that block is used as the allocation scope while the body is visited.
///
/// \param par the `single` directive; its associated statement must be a
///            CapturedStmt.
/// \returns a null ValueCategory (the directive produces no value).
ValueCategory
MLIRScanner::VisitOMPSingleDirective(clang::OMPSingleDirective *par) {
  IfScope scope(*this);

  builder.create<omp::BarrierOp>(loc);
  auto masterOp = builder.create<omp::MasterOp>(loc);
  builder.create<omp::BarrierOp>(loc);

  // Emission resumes here (after the trailing barrier) once the body is done.
  auto resumeBlock = builder.getInsertionBlock();
  auto resumePoint = builder.getInsertionPoint();

  auto &masterRegion = masterOp.getRegion();
  masterRegion.push_back(new Block());
  builder.setInsertionPointToStart(&masterRegion.front());

  // The master block holds exactly: the execute_region, then the terminator.
  auto bodyOp =
      builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
  builder.create<omp::TerminatorOp>(loc);
  bodyOp.getRegion().push_back(new Block());
  auto *bodyBlock = &bodyOp.getRegion().back();
  builder.setInsertionPointToStart(bodyBlock);

  // Allocations made while visiting the body are scoped to the body block.
  auto savedScope = allocationScope;
  allocationScope = bodyBlock;

  auto *captured = cast<CapturedStmt>(par->getAssociatedStmt());
  Visit(captured->getCapturedDecl()->getBody());

  builder.create<scf::YieldOp>(loc);

  allocationScope = savedScope;
  builder.setInsertionPoint(resumeBlock, resumePoint);
  return nullptr;
}
319+
320+
/// Emits IR for `#pragma omp for` as an omp.wsloop.
///
/// For each associated loop, three operands are computed and cast to the MLIR
/// index type: the right-hand side of the directive's `inits()` expression,
/// the right-hand side of its `finals()` expression, and the right operand of
/// the multiplication found inside its `updates()` expression (which must have
/// the shape `lhs = a + b * step`, modulo casts). These three lists are passed
/// to the omp.wsloop builder. The loop body is emitted into an
/// scf.execute_region inside the wsloop block; each loop counter gets a fresh
/// allocation in that region, initialized from the matching block argument.
///
/// Any binding in `params` that a loop counter shadows is saved and restored
/// when the directive has been emitted.
///
/// \returns a null ValueCategory (the directive produces no value).
ValueCategory MLIRScanner::VisitOMPForDirective(clang::OMPForDirective *fors) {
  IfScope scope(*this);

  if (auto *preInits = fors->getPreInits())
    Visit(preInits);

  // Emits `e` and converts the resulting value to the index type.
  auto toIndex = [&](clang::Expr *e) -> mlir::Value {
    return builder.create<IndexCastOp>(loc, Visit(e).getValue(builder),
                                       builder.getIndexType());
  };

  // Peels every enclosing CastExpr off `e`.
  auto stripCasts = [](clang::Expr *e) {
    while (auto *ce = dyn_cast<clang::CastExpr>(e))
      e = ce->getSubExpr();
    return e;
  };

  SmallVector<mlir::Value> inits;
  for (auto *init : fors->inits()) {
    assert(init);
    inits.push_back(toIndex(cast<clang::BinaryOperator>(init)->getRHS()));
  }

  SmallVector<mlir::Value> finals;
  for (auto *fin : fors->finals())
    finals.push_back(toIndex(cast<clang::BinaryOperator>(fin)->getRHS()));

  SmallVector<mlir::Value> incs;
  for (auto *update : fors->updates()) {
    // Expected shape of the assigned value: `<base> + <iter> * <step>`.
    auto *sum = cast<clang::BinaryOperator>(
        stripCasts(cast<clang::BinaryOperator>(update)->getRHS()));
    assert(sum->getOpcode() == clang::BinaryOperator::Opcode::BO_Add);
    auto *product = cast<clang::BinaryOperator>(stripCasts(sum->getRHS()));
    assert(product->getOpcode() == clang::BinaryOperator::Opcode::BO_Mul);
    incs.push_back(toIndex(product->getRHS()));
  }

  auto wsLoop = builder.create<omp::WsLoopOp>(loc, inits, finals, incs);
  auto &loopRegion = wsLoop.getRegion();
  loopRegion.push_back(new Block());
  // One block argument per loop, typed like the corresponding init value.
  for (auto init : inits)
    loopRegion.front().addArgument(init.getType());
  auto inds = loopRegion.front().getArguments();

  // Emission resumes here (after the wsloop) once the body is done.
  auto resumeBlock = builder.getInsertionBlock();
  auto resumePoint = builder.getInsertionPoint();

  builder.setInsertionPointToStart(&loopRegion.front());

  // The wsloop block holds exactly: the execute_region, then omp.yield.
  auto bodyOp =
      builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
  builder.create<omp::YieldOp>(loc, ValueRange());
  bodyOp.getRegion().push_back(new Block());
  auto *bodyBlock = &bodyOp.getRegion().back();
  builder.setInsertionPointToStart(bodyBlock);

  auto savedScope = allocationScope;
  allocationScope = bodyBlock;

  // Bindings hidden by the loop counters; put back before returning.
  std::map<VarDecl *, ValueCategory> shadowed;
  for (auto zp : zip(inds, fors->counters())) {
    auto idx = builder.create<IndexCastOp>(
        loc, std::get<0>(zp),
        getMLIRType(fors->getIterationVariable()->getType()));
    VarDecl *name =
        cast<VarDecl>(cast<DeclRefExpr>(std::get<1>(zp))->getDecl());

    auto existing = params.find(name);
    if (existing != params.end()) {
      shadowed[name] = existing->second;
      params.erase(name);
    }

    bool isArray = false;
    const bool LLVMABI =
        Glob.getMLIRType(
                Glob.CGM.getContext().getLValueReferenceType(name->getType()))
            .isa<mlir::LLVM::LLVMPointerType>();
    if (!LLVMABI)
      Glob.getMLIRType(name->getType(), &isArray);

    auto allocop = createAllocOp(idx.getType(), name, /*memtype*/ 0,
                                 /*isArray*/ isArray, /*LLVMABI*/ LLVMABI);
    params[name] = ValueCategory(allocop, true);
    params[name].store(builder, idx);
  }

  // TODO: set loop context.
  Visit(fors->getBody());

  builder.create<scf::YieldOp>(loc, ValueRange());

  allocationScope = savedScope;

  // TODO: set the value of the iteration variable to the final bound at the
  // end of the loop.
  builder.setInsertionPoint(resumeBlock, resumePoint);

  for (auto &entry : shadowed)
    params[entry.first] = entry.second;

  return nullptr;
}
423+
424+
/// Emits IR for `#pragma omp parallel` as an omp.parallel region.
///
/// The captured body is emitted into an scf.execute_region nested inside the
/// parallel region, and that block is used as the allocation scope while the
/// body is visited.
///
/// `private` clauses are handled by giving each listed variable a fresh
/// allocation inside the region and rebinding it in `params`. If the variable
/// already had a binding, that binding is saved, its value is stored into the
/// new allocation, and the binding is restored afterwards. If it had none,
/// nothing is copied and the temporary binding is removed afterwards. Other
/// clause kinds are reported on stderr and otherwise ignored.
///
/// \param par the `parallel` directive; its associated statement must be a
///            CapturedStmt.
/// \returns a null ValueCategory (the directive produces no value).
ValueCategory
MLIRScanner::VisitOMPParallelDirective(clang::OMPParallelDirective *par) {
  IfScope scope(*this);

  auto parallelOp = builder.create<omp::ParallelOp>(loc);

  // Emission resumes here (after the omp.parallel) once the body is done.
  auto oldpoint = builder.getInsertionPoint();
  auto oldblock = builder.getInsertionBlock();

  parallelOp.getRegion().push_back(new Block());
  builder.setInsertionPointToStart(&parallelOp.getRegion().front());

  auto executeRegion =
      builder.create<scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
  executeRegion.getRegion().push_back(new Block());
  builder.create<omp::TerminatorOp>(loc);
  builder.setInsertionPointToStart(&executeRegion.getRegion().back());

  auto oldScope = allocationScope;
  allocationScope = &executeRegion.getRegion().back();

  // Bindings hidden by private copies; put back before returning.
  std::map<VarDecl *, ValueCategory> prevInduction;
  // Private variables that had no binding before this directive; their
  // temporary binding is dropped before returning.
  SmallVector<VarDecl *> newlyBound;

  for (auto f : par->clauses()) {
    switch (f->getClauseKind()) {
    case llvm::omp::OMPC_private:
      for (auto stmt : f->children()) {
        VarDecl *name = cast<VarDecl>(cast<DeclRefExpr>(stmt)->getDecl());

        // Look the variable up without operator[]: indexing a missing key
        // would insert an empty ValueCategory, which would then be used as
        // the store source below and written back into params on exit.
        auto prev = params.find(name);
        const bool hadBinding = prev != params.end();
        if (hadBinding) {
          prevInduction[name] = prev->second;
          params.erase(name);
        }

        bool LLVMABI = false;
        bool isArray = false;
        mlir::Type ty;
        if (Glob.getMLIRType(Glob.CGM.getContext().getLValueReferenceType(
                                 name->getType()))
                .isa<mlir::LLVM::LLVMPointerType>()) {
          LLVMABI = true;
          bool undef; // out-parameter whose value is not needed here
          ty = Glob.getMLIRType(name->getType(), &undef);
        } else
          ty = Glob.getMLIRType(name->getType(), &isArray);

        auto allocop = createAllocOp(ty, name, /*memtype*/ 0,
                                     /*isArray*/ isArray, /*LLVMABI*/ LLVMABI);
        params[name] = ValueCategory(allocop, true);
        if (hadBinding)
          // Seed the private copy with the value of the shadowed binding.
          params[name].store(builder, prevInduction[name], isArray);
        else
          newlyBound.push_back(name);
      }
      break;
    default:
      llvm::errs() << "may not handle omp clause " << (int)f->getClauseKind()
                   << "\n";
      break;
    }
  }

  Visit(cast<CapturedStmt>(par->getAssociatedStmt())
            ->getCapturedDecl()
            ->getBody());

  builder.create<scf::YieldOp>(loc);
  allocationScope = oldScope;
  builder.setInsertionPoint(oldblock, oldpoint);

  for (auto &pair : prevInduction)
    params[pair.first] = pair.second;
  // The private allocations live inside the region; do not leave bindings to
  // them behind for variables that were unbound before the directive.
  for (auto *name : newlyBound)
    params.erase(name);
  return nullptr;
}
491+
286492
ValueCategory MLIRScanner::VisitOMPParallelForDirective(
287493
clang::OMPParallelForDirective *fors) {
288494
IfScope scope(*this);
289495

290-
Visit(fors->getPreInits());
496+
if (fors->getPreInits()) {
497+
Visit(fors->getPreInits());
498+
}
291499

292500
SmallVector<mlir::Value> inits;
293501
for (auto f : fors->inits()) {
502+
assert(f);
294503
f = cast<clang::BinaryOperator>(f)->getRHS();
295-
if (auto ce = dyn_cast<CastExpr>(f))
296-
f = ce->getSubExpr();
297-
auto initV =
298-
cast<OMPCapturedExprDecl>(cast<DeclRefExpr>(f)->getDecl())->getInit();
299-
inits.push_back(builder.create<IndexCastOp>(
300-
loc, Visit(initV).getValue(builder), builder.getIndexType()));
504+
inits.push_back(builder.create<IndexCastOp>(loc, Visit(f).getValue(builder),
505+
builder.getIndexType()));
301506
}
302507

303508
SmallVector<mlir::Value> finals;
304509
for (auto f : fors->finals()) {
305510
f = cast<clang::BinaryOperator>(f)->getRHS();
306-
if (auto ce = dyn_cast<CastExpr>(f))
307-
f = ce->getSubExpr();
308-
auto bo =
309-
cast<clang::BinaryOperator>(cast<clang::BinaryOperator>(f)->getRHS());
310-
auto bo2 = cast<clang::BinaryOperator>(
311-
cast<clang::BinaryOperator>(
312-
cast<clang::BinaryOperator>(
313-
cast<ParenExpr>(cast<clang::BinaryOperator>(
314-
cast<ParenExpr>(bo->getLHS())->getSubExpr())
315-
->getLHS())
316-
->getSubExpr())
317-
->getLHS())
318-
->getLHS());
319-
auto rhs = cast<OMPCapturedExprDecl>(
320-
cast<DeclRefExpr>(
321-
cast<clang::CastExpr>(
322-
cast<ParenExpr>(
323-
cast<clang::CastExpr>(bo2->getLHS())->getSubExpr())
324-
->getSubExpr())
325-
->getSubExpr())
326-
->getDecl());
327511
finals.push_back(builder.create<IndexCastOp>(
328-
loc, Visit(rhs->getInit()).getValue(builder), builder.getIndexType()));
512+
loc, Visit(f).getValue(builder), builder.getIndexType()));
329513
}
330514

331515
SmallVector<mlir::Value> incs;
332516
for (auto f : fors->updates()) {
333-
auto bo = cast<clang::BinaryOperator>(
334-
cast<clang::CastExpr>(cast<clang::BinaryOperator>(f)->getRHS())
335-
->getSubExpr());
336-
auto rhs = cast<OMPCapturedExprDecl>(
337-
cast<DeclRefExpr>(
338-
cast<clang::CastExpr>(
339-
cast<clang::CastExpr>(
340-
cast<clang::BinaryOperator>(bo->getRHS())->getRHS())
341-
->getSubExpr())
342-
->getSubExpr())
343-
->getDecl());
344-
345-
incs.push_back(builder.create<IndexCastOp>(
346-
loc, Visit(rhs->getInit()).getValue(builder), builder.getIndexType()));
517+
f = cast<clang::BinaryOperator>(f)->getRHS();
518+
while (auto ce = dyn_cast<clang::CastExpr>(f))
519+
f = ce->getSubExpr();
520+
auto bo = cast<clang::BinaryOperator>(f);
521+
assert(bo->getOpcode() == clang::BinaryOperator::Opcode::BO_Add);
522+
f = bo->getRHS();
523+
while (auto ce = dyn_cast<clang::CastExpr>(f))
524+
f = ce->getSubExpr();
525+
bo = cast<clang::BinaryOperator>(f);
526+
assert(bo->getOpcode() == clang::BinaryOperator::Opcode::BO_Mul);
527+
f = bo->getRHS();
528+
incs.push_back(builder.create<IndexCastOp>(loc, Visit(f).getValue(builder),
529+
builder.getIndexType()));
347530
}
348531

349532
auto affineOp = builder.create<scf::ParallelOp>(loc, inits, finals, incs);
@@ -363,14 +546,18 @@ ValueCategory MLIRScanner::VisitOMPParallelForDirective(
363546
auto oldScope = allocationScope;
364547
allocationScope = &executeRegion.getRegion().back();
365548

549+
std::map<VarDecl *, ValueCategory> prevInduction;
366550
for (auto zp : zip(inds, fors->counters())) {
367551
auto idx = builder.create<IndexCastOp>(
368552
loc, std::get<0>(zp),
369553
getMLIRType(fors->getIterationVariable()->getType()));
370554
VarDecl *name =
371555
cast<VarDecl>(cast<DeclRefExpr>(std::get<1>(zp))->getDecl());
372-
assert(params.find(name) == params.end() &&
373-
"OpenMP induction variable is dual initialized");
556+
557+
if (params.find(name) != params.end()) {
558+
prevInduction[name] = params[name];
559+
params.erase(name);
560+
}
374561

375562
bool LLVMABI = false;
376563
bool isArray = false;
@@ -397,6 +584,10 @@ ValueCategory MLIRScanner::VisitOMPParallelForDirective(
397584
// TODO: set the value of the iteration variable to the final bound at the
398585
// end of the loop.
399586
builder.setInsertionPoint(oldblock, oldpoint);
587+
588+
for (auto pair : prevInduction)
589+
params[pair.first] = pair.second;
590+
400591
return nullptr;
401592
}
402593

tools/mlir-clang/Lib/clang-mlir.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ class MLIRScanner : public StmtVisitor<MLIRScanner, ValueCategory> {
253253

254254
ValueCategory VisitForStmt(clang::ForStmt *fors);
255255

256+
ValueCategory VisitOMPSingleDirective(clang::OMPSingleDirective *);
257+
258+
ValueCategory VisitOMPForDirective(clang::OMPForDirective *);
259+
260+
ValueCategory VisitOMPParallelDirective(clang::OMPParallelDirective *);
261+
256262
ValueCategory
257263
VisitOMPParallelForDirective(clang::OMPParallelForDirective *fors);
258264

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: mlir-clang %s --function=* -fopenmp -S | FileCheck %s
2+
3+
// Test input: a `parallel for` whose loop variable `i` is declared outside the
// loop, listed in private(i), and additionally incremented inside the loop
// body on top of the `i += 2` loop step.
void square(double* x) {
4+
int i;
5+
#pragma omp parallel for private(i)
6+
for(i=3; i < 10; i+= 2) {
7+
x[i] = i;
8+
i++;
9+
x[i] = i;
10+
}
11+
}
12+
13+
// CHECK: func @square(%arg0: memref<?xf64>, %arg1: i32, %arg2: i32, %arg3: i32) attributes {llvm.linkage = #llvm.linkage<external>} {
14+
// CHECK-NEXT: %0 = arith.index_cast %arg1 : i32 to index
15+
// CHECK-NEXT: %1 = arith.index_cast %arg2 : i32 to index
16+
// CHECK-NEXT: %2 = arith.index_cast %arg3 : i32 to index
17+
// CHECK-NEXT: scf.parallel (%arg4) = (%0) to (%1) step (%2) {
18+
// CHECK-NEXT: %3 = arith.index_cast %arg4 : index to i32
19+
// CHECK-NEXT: %4 = arith.sitofp %3 : i32 to f64
20+
// CHECK-NEXT: memref.store %4, %arg0[%arg4] : memref<?xf64>
21+
// CHECK-NEXT: scf.yield
22+
// CHECK-NEXT: }
23+
// CHECK-NEXT: return
24+
// CHECK-NEXT: }
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: mlir-clang %s --function=* -fopenmp -S | FileCheck %s
2+
3+
int get(int);
4+
// Test input: a `parallel for` whose loop variable `i` is declared and
// initialized (to 7) outside the loop, listed in private(i), and whose loop
// initial value comes from a call to the external function get(). The body
// also increments `i` on top of the `i += 2` loop step.
void square(double* x, int ss) {
5+
int i=7;
6+
#pragma omp parallel for private(i)
7+
for(i=get(ss); i < 10; i+= 2) {
8+
x[i] = i;
9+
i++;
10+
x[i] = i;
11+
}
12+
}
13+
14+
// CHECK: func @square(%arg0: memref<?xf64>, %arg1: i32, %arg2: i32, %arg3: i32) attributes {llvm.linkage = #llvm.linkage<external>} {
15+
// CHECK-NEXT: %0 = arith.index_cast %arg1 : i32 to index
16+
// CHECK-NEXT: %1 = arith.index_cast %arg2 : i32 to index
17+
// CHECK-NEXT: %2 = arith.index_cast %arg3 : i32 to index
18+
// CHECK-NEXT: scf.parallel (%arg4) = (%0) to (%1) step (%2) {
19+
// CHECK-NEXT: %3 = arith.index_cast %arg4 : index to i32
20+
// CHECK-NEXT: %4 = arith.sitofp %3 : i32 to f64
21+
// CHECK-NEXT: memref.store %4, %arg0[%arg4] : memref<?xf64>
22+
// CHECK-NEXT: scf.yield
23+
// CHECK-NEXT: }
24+
// CHECK-NEXT: return
25+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)