1010#include " clang-mlir.h"
1111#include " mlir/IR/Diagnostics.h"
1212#include < mlir/Dialect/Arithmetic/IR/Arithmetic.h>
13+ #include < mlir/Dialect/OpenMP/OpenMPDialect.h>
1314#include < mlir/Dialect/SCF/SCF.h>
1415
1516#define DEBUG_TYPE " CGStmt"
@@ -283,67 +284,249 @@ ValueCategory MLIRScanner::VisitForStmt(clang::ForStmt *fors) {
283284 return nullptr ;
284285}
285286
287+ ValueCategory
288+ MLIRScanner::VisitOMPSingleDirective (clang::OMPSingleDirective *par) {
289+ IfScope scope (*this );
290+
291+ builder.create <omp::BarrierOp>(loc);
292+ auto affineOp = builder.create <omp::MasterOp>(loc);
293+ builder.create <omp::BarrierOp>(loc);
294+
295+ auto oldpoint = builder.getInsertionPoint ();
296+ auto oldblock = builder.getInsertionBlock ();
297+
298+ affineOp.getRegion ().push_back (new Block ());
299+ builder.setInsertionPointToStart (&affineOp.getRegion ().front ());
300+
301+ auto executeRegion =
302+ builder.create <scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
303+ executeRegion.getRegion ().push_back (new Block ());
304+ builder.create <omp::TerminatorOp>(loc);
305+ builder.setInsertionPointToStart (&executeRegion.getRegion ().back ());
306+
307+ auto oldScope = allocationScope;
308+ allocationScope = &executeRegion.getRegion ().back ();
309+
310+ Visit (cast<CapturedStmt>(par->getAssociatedStmt ())
311+ ->getCapturedDecl ()
312+ ->getBody ());
313+
314+ builder.create <scf::YieldOp>(loc);
315+ allocationScope = oldScope;
316+ builder.setInsertionPoint (oldblock, oldpoint);
317+ return nullptr ;
318+ }
319+
320+ ValueCategory MLIRScanner::VisitOMPForDirective (clang::OMPForDirective *fors) {
321+ IfScope scope (*this );
322+
323+ if (fors->getPreInits ()) {
324+ Visit (fors->getPreInits ());
325+ }
326+
327+ SmallVector<mlir::Value> inits;
328+ for (auto f : fors->inits ()) {
329+ assert (f);
330+ f = cast<clang::BinaryOperator>(f)->getRHS ();
331+ inits.push_back (builder.create <IndexCastOp>(loc, Visit (f).getValue (builder),
332+ builder.getIndexType ()));
333+ }
334+
335+ SmallVector<mlir::Value> finals;
336+ for (auto f : fors->finals ()) {
337+ f = cast<clang::BinaryOperator>(f)->getRHS ();
338+ finals.push_back (builder.create <IndexCastOp>(
339+ loc, Visit (f).getValue (builder), builder.getIndexType ()));
340+ }
341+
342+ SmallVector<mlir::Value> incs;
343+ for (auto f : fors->updates ()) {
344+ f = cast<clang::BinaryOperator>(f)->getRHS ();
345+ while (auto ce = dyn_cast<clang::CastExpr>(f))
346+ f = ce->getSubExpr ();
347+ auto bo = cast<clang::BinaryOperator>(f);
348+ assert (bo->getOpcode () == clang::BinaryOperator::Opcode::BO_Add);
349+ f = bo->getRHS ();
350+ while (auto ce = dyn_cast<clang::CastExpr>(f))
351+ f = ce->getSubExpr ();
352+ bo = cast<clang::BinaryOperator>(f);
353+ assert (bo->getOpcode () == clang::BinaryOperator::Opcode::BO_Mul);
354+ f = bo->getRHS ();
355+ incs.push_back (builder.create <IndexCastOp>(loc, Visit (f).getValue (builder),
356+ builder.getIndexType ()));
357+ }
358+
359+ auto affineOp = builder.create <omp::WsLoopOp>(loc, inits, finals, incs);
360+ affineOp.getRegion ().push_back (new Block ());
361+ for (auto init : inits)
362+ affineOp.getRegion ().front ().addArgument (init.getType ());
363+ auto inds = affineOp.getRegion ().front ().getArguments ();
364+
365+ auto oldpoint = builder.getInsertionPoint ();
366+ auto oldblock = builder.getInsertionBlock ();
367+
368+ builder.setInsertionPointToStart (&affineOp.getRegion ().front ());
369+
370+ auto executeRegion =
371+ builder.create <scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
372+ builder.create <omp::YieldOp>(loc, ValueRange ());
373+ executeRegion.getRegion ().push_back (new Block ());
374+ builder.setInsertionPointToStart (&executeRegion.getRegion ().back ());
375+
376+ auto oldScope = allocationScope;
377+ allocationScope = &executeRegion.getRegion ().back ();
378+
379+ std::map<VarDecl *, ValueCategory> prevInduction;
380+ for (auto zp : zip (inds, fors->counters ())) {
381+ auto idx = builder.create <IndexCastOp>(
382+ loc, std::get<0 >(zp),
383+ getMLIRType (fors->getIterationVariable ()->getType ()));
384+ VarDecl *name =
385+ cast<VarDecl>(cast<DeclRefExpr>(std::get<1 >(zp))->getDecl ());
386+
387+ if (params.find (name) != params.end ()) {
388+ prevInduction[name] = params[name];
389+ params.erase (name);
390+ }
391+
392+ bool LLVMABI = false ;
393+ bool isArray = false ;
394+ if (Glob.getMLIRType (
395+ Glob.CGM .getContext ().getLValueReferenceType (name->getType ()))
396+ .isa <mlir::LLVM::LLVMPointerType>())
397+ LLVMABI = true ;
398+ else
399+ Glob.getMLIRType (name->getType (), &isArray);
400+
401+ auto allocop = createAllocOp (idx.getType (), name, /* memtype*/ 0 ,
402+ /* isArray*/ isArray, /* LLVMABI*/ LLVMABI);
403+ params[name] = ValueCategory (allocop, true );
404+ params[name].store (builder, idx);
405+ }
406+
407+ // TODO: set loop context.
408+ Visit (fors->getBody ());
409+
410+ builder.create <scf::YieldOp>(loc, ValueRange ());
411+
412+ allocationScope = oldScope;
413+
414+ // TODO: set the value of the iteration value to the final bound at the
415+ // end of the loop.
416+ builder.setInsertionPoint (oldblock, oldpoint);
417+
418+ for (auto pair : prevInduction)
419+ params[pair.first ] = pair.second ;
420+
421+ return nullptr ;
422+ }
423+
424+ ValueCategory
425+ MLIRScanner::VisitOMPParallelDirective (clang::OMPParallelDirective *par) {
426+ IfScope scope (*this );
427+
428+ auto affineOp = builder.create <omp::ParallelOp>(loc);
429+
430+ auto oldpoint = builder.getInsertionPoint ();
431+ auto oldblock = builder.getInsertionBlock ();
432+
433+ affineOp.getRegion ().push_back (new Block ());
434+ builder.setInsertionPointToStart (&affineOp.getRegion ().front ());
435+
436+ auto executeRegion =
437+ builder.create <scf::ExecuteRegionOp>(loc, ArrayRef<mlir::Type>());
438+ executeRegion.getRegion ().push_back (new Block ());
439+ builder.create <omp::TerminatorOp>(loc);
440+ builder.setInsertionPointToStart (&executeRegion.getRegion ().back ());
441+
442+ auto oldScope = allocationScope;
443+ allocationScope = &executeRegion.getRegion ().back ();
444+
445+ std::map<VarDecl *, ValueCategory> prevInduction;
446+ for (auto f : par->clauses ()) {
447+ switch (f->getClauseKind ()) {
448+ case llvm::omp::OMPC_private:
449+ for (auto stmt : f->children ()) {
450+ VarDecl *name = cast<VarDecl>(cast<DeclRefExpr>(stmt)->getDecl ());
451+
452+ prevInduction[name] = params[name];
453+ params.erase (name);
454+
455+ bool LLVMABI = false ;
456+ bool isArray = false ;
457+ mlir::Type ty;
458+ if (Glob.getMLIRType (Glob.CGM .getContext ().getLValueReferenceType (
459+ name->getType ()))
460+ .isa <mlir::LLVM::LLVMPointerType>()) {
461+ LLVMABI = true ;
462+ bool undef;
463+ ty = Glob.getMLIRType (name->getType (), &undef);
464+ } else
465+ ty = Glob.getMLIRType (name->getType (), &isArray);
466+
467+ auto allocop = createAllocOp (ty, name, /* memtype*/ 0 ,
468+ /* isArray*/ isArray, /* LLVMABI*/ LLVMABI);
469+ params[name] = ValueCategory (allocop, true );
470+ params[name].store (builder, prevInduction[name], isArray);
471+ }
472+ break ;
473+ default :
474+ llvm::errs () << " may not handle omp clause " << (int )f->getClauseKind ()
475+ << " \n " ;
476+ }
477+ }
478+
479+ Visit (cast<CapturedStmt>(par->getAssociatedStmt ())
480+ ->getCapturedDecl ()
481+ ->getBody ());
482+
483+ builder.create <scf::YieldOp>(loc);
484+ allocationScope = oldScope;
485+ builder.setInsertionPoint (oldblock, oldpoint);
486+
487+ for (auto pair : prevInduction)
488+ params[pair.first ] = pair.second ;
489+ return nullptr ;
490+ }
491+
286492ValueCategory MLIRScanner::VisitOMPParallelForDirective (
287493 clang::OMPParallelForDirective *fors) {
288494 IfScope scope (*this );
289495
290- Visit (fors->getPreInits ());
496+ if (fors->getPreInits ()) {
497+ Visit (fors->getPreInits ());
498+ }
291499
292500 SmallVector<mlir::Value> inits;
293501 for (auto f : fors->inits ()) {
502+ assert (f);
294503 f = cast<clang::BinaryOperator>(f)->getRHS ();
295- if (auto ce = dyn_cast<CastExpr>(f))
296- f = ce->getSubExpr ();
297- auto initV =
298- cast<OMPCapturedExprDecl>(cast<DeclRefExpr>(f)->getDecl ())->getInit ();
299- inits.push_back (builder.create <IndexCastOp>(
300- loc, Visit (initV).getValue (builder), builder.getIndexType ()));
504+ inits.push_back (builder.create <IndexCastOp>(loc, Visit (f).getValue (builder),
505+ builder.getIndexType ()));
301506 }
302507
303508 SmallVector<mlir::Value> finals;
304509 for (auto f : fors->finals ()) {
305510 f = cast<clang::BinaryOperator>(f)->getRHS ();
306- if (auto ce = dyn_cast<CastExpr>(f))
307- f = ce->getSubExpr ();
308- auto bo =
309- cast<clang::BinaryOperator>(cast<clang::BinaryOperator>(f)->getRHS ());
310- auto bo2 = cast<clang::BinaryOperator>(
311- cast<clang::BinaryOperator>(
312- cast<clang::BinaryOperator>(
313- cast<ParenExpr>(cast<clang::BinaryOperator>(
314- cast<ParenExpr>(bo->getLHS ())->getSubExpr ())
315- ->getLHS ())
316- ->getSubExpr ())
317- ->getLHS ())
318- ->getLHS ());
319- auto rhs = cast<OMPCapturedExprDecl>(
320- cast<DeclRefExpr>(
321- cast<clang::CastExpr>(
322- cast<ParenExpr>(
323- cast<clang::CastExpr>(bo2->getLHS ())->getSubExpr ())
324- ->getSubExpr ())
325- ->getSubExpr ())
326- ->getDecl ());
327511 finals.push_back (builder.create <IndexCastOp>(
328- loc, Visit (rhs-> getInit () ).getValue (builder), builder.getIndexType ()));
512+ loc, Visit (f ).getValue (builder), builder.getIndexType ()));
329513 }
330514
331515 SmallVector<mlir::Value> incs;
332516 for (auto f : fors->updates ()) {
333- auto bo = cast<clang::BinaryOperator>(
334- cast<clang::CastExpr>(cast<clang::BinaryOperator>(f)->getRHS ())
335- ->getSubExpr ());
336- auto rhs = cast<OMPCapturedExprDecl>(
337- cast<DeclRefExpr>(
338- cast<clang::CastExpr>(
339- cast<clang::CastExpr>(
340- cast<clang::BinaryOperator>(bo->getRHS ())->getRHS ())
341- ->getSubExpr ())
342- ->getSubExpr ())
343- ->getDecl ());
344-
345- incs.push_back (builder.create <IndexCastOp>(
346- loc, Visit (rhs->getInit ()).getValue (builder), builder.getIndexType ()));
517+ f = cast<clang::BinaryOperator>(f)->getRHS ();
518+ while (auto ce = dyn_cast<clang::CastExpr>(f))
519+ f = ce->getSubExpr ();
520+ auto bo = cast<clang::BinaryOperator>(f);
521+ assert (bo->getOpcode () == clang::BinaryOperator::Opcode::BO_Add);
522+ f = bo->getRHS ();
523+ while (auto ce = dyn_cast<clang::CastExpr>(f))
524+ f = ce->getSubExpr ();
525+ bo = cast<clang::BinaryOperator>(f);
526+ assert (bo->getOpcode () == clang::BinaryOperator::Opcode::BO_Mul);
527+ f = bo->getRHS ();
528+ incs.push_back (builder.create <IndexCastOp>(loc, Visit (f).getValue (builder),
529+ builder.getIndexType ()));
347530 }
348531
349532 auto affineOp = builder.create <scf::ParallelOp>(loc, inits, finals, incs);
@@ -363,14 +546,18 @@ ValueCategory MLIRScanner::VisitOMPParallelForDirective(
363546 auto oldScope = allocationScope;
364547 allocationScope = &executeRegion.getRegion ().back ();
365548
549+ std::map<VarDecl *, ValueCategory> prevInduction;
366550 for (auto zp : zip (inds, fors->counters ())) {
367551 auto idx = builder.create <IndexCastOp>(
368552 loc, std::get<0 >(zp),
369553 getMLIRType (fors->getIterationVariable ()->getType ()));
370554 VarDecl *name =
371555 cast<VarDecl>(cast<DeclRefExpr>(std::get<1 >(zp))->getDecl ());
372- assert (params.find (name) == params.end () &&
373- " OpenMP induction variable is dual initialized" );
556+
557+ if (params.find (name) != params.end ()) {
558+ prevInduction[name] = params[name];
559+ params.erase (name);
560+ }
374561
375562 bool LLVMABI = false ;
376563 bool isArray = false ;
@@ -397,6 +584,10 @@ ValueCategory MLIRScanner::VisitOMPParallelForDirective(
397584 // TODO: set the value of the iteration value to the final bound at the
398585 // end of the loop.
399586 builder.setInsertionPoint (oldblock, oldpoint);
587+
588+ for (auto pair : prevInduction)
589+ params[pair.first ] = pair.second ;
590+
400591 return nullptr ;
401592}
402593
0 commit comments