@@ -208,7 +208,6 @@ void ParallelLower::runOnOperation() {
208208
209209 SymbolTableCollection symbolTable;
210210 symbolTable.getSymbolTable (getOperation ());
211- SymbolUserMap symbolUserMap (symbolTable, getOperation ());
212211
213212 getOperation ()->walk ([&](CallOp bidx) {
214213 if (bidx.getCallee () == " cudaThreadSynchronize" )
@@ -336,52 +335,94 @@ void ParallelLower::runOnOperation() {
336335 callInliner (op);
337336 }
338337
339- // Only supports single block functions at the moment.
338+ {
340339
341- SmallVector<std::pair<Operation *, size_t >> outlineOps;
342- getOperation ().walk ([&](gpu::LaunchOp launchOp) {
343- launchOp.walk ([&](LLVM::CallOp caller) {
344- if (!caller.getCallee ()) {
345- outlineOps.push_back (std::make_pair (caller, (size_t )0 ));
346- }
347- });
348- });
349- SetVector<FunctionOpInterface> toinl;
350- while (outlineOps.size ()) {
351- auto opv = outlineOps.back ();
352- auto op = std::get<0 >(opv);
353- auto idx = std::get<1 >(opv);
354- outlineOps.pop_back ();
355- if (Value fn = op->getOperand (idx)) {
356- if (auto fn2 = fn.getDefiningOp <polygeist::Memref2PointerOp>())
357- fn = fn2.getOperand ();
358- if (auto ba = fn.dyn_cast <BlockArgument>()) {
359- if (auto F =
360- dyn_cast<FunctionOpInterface>(ba.getOwner ()->getParentOp ())) {
361- if (toinl.count (F))
362- continue ;
363- toinl.insert (F);
364- for (Operation *m : symbolUserMap.getUsers (F)) {
365- outlineOps.push_back (std::make_pair (m, (size_t )ba.getArgNumber ()));
340+ SmallVector<Operation *> inlineOps;
341+ SmallVector<mlir::Value> toFollowOps;
342+ SetVector<FunctionOpInterface> toinl;
343+
344+ getOperation ().walk (
345+ [&](mlir::gpu::ThreadIdOp bidx) { inlineOps.push_back (bidx); });
346+ getOperation ().walk (
347+ [&](mlir::gpu::GridDimOp bidx) { inlineOps.push_back (bidx); });
348+ getOperation ().walk (
349+ [&](mlir::NVVM::Barrier0Op bidx) { inlineOps.push_back (bidx); });
350+
351+ SymbolUserMap symbolUserMap (symbolTable, getOperation ());
352+ while (inlineOps.size ()) {
353+ auto op = inlineOps.back ();
354+ inlineOps.pop_back ();
355+ auto lop = op->getParentOfType <gpu::LaunchOp>();
356+ auto fop = op->getParentOfType <FunctionOpInterface>();
357+ if (!lop || lop->isAncestor (fop)) {
358+ toinl.insert (fop);
359+ for (Operation *m : symbolUserMap.getUsers (fop)) {
360+ if (isa<LLVM::CallOp, func::CallOp>(m))
361+ inlineOps.push_back (m);
362+ else if (isa<polygeist::GetFuncOp>(m)) {
363+ toFollowOps.push_back (m->getResult (0 ));
366364 }
367365 }
368366 }
369367 }
370- }
371- for (auto F : toinl) {
372- for (Operation *m : symbolUserMap.getUsers (F)) {
373- callInliner (cast<CallOp>(m));
368+ for (auto F : toinl) {
369+ SmallVector<LLVM::CallOp> ltoinl;
370+ SmallVector<func::CallOp> mtoinl;
371+ SymbolUserMap symbolUserMap (symbolTable, getOperation ());
372+ for (Operation *m : symbolUserMap.getUsers (F)) {
373+ if (auto l = dyn_cast<LLVM::CallOp>(m))
374+ ltoinl.push_back (l);
375+ else if (auto mc = dyn_cast<func::CallOp>(m))
376+ mtoinl.push_back (mc);
377+ }
378+ for (auto l : ltoinl) {
379+ LLVMcallInliner (l);
380+ }
381+ for (auto m : mtoinl) {
382+ callInliner (m);
383+ }
384+ }
385+ while (toFollowOps.size ()) {
386+ auto op = toFollowOps.back ();
387+ toFollowOps.pop_back ();
388+ SmallVector<LLVM::CallOp> ltoinl;
389+ SmallVector<func::CallOp> mtoinl;
390+ bool inlined = false ;
391+ for (auto u : op.getUsers ()) {
392+ if (auto cop = dyn_cast<LLVM::CallOp>(u)) {
393+ if (!cop.getCallee () && cop->getOperand (0 ) == op) {
394+ OpBuilder builder (cop);
395+ SmallVector<Value> vals;
396+ if (fixupGetFunc (cop, builder, vals).succeeded ()) {
397+ if (vals.size ())
398+ cop.getResult ().replaceAllUsesWith (vals[0 ]);
399+ cop.erase ();
400+ inlined = true ;
401+ break ;
402+ }
403+ } else if (cop.getCallee ())
404+ ltoinl.push_back (cop);
405+ } else if (auto cop = dyn_cast<func::CallOp>(u)) {
406+ mtoinl.push_back (cop);
407+ } else {
408+ for (auto r : u->getResults ())
409+ toFollowOps.push_back (r);
410+ }
411+ }
412+ for (auto l : ltoinl) {
413+ LLVMcallInliner (l);
414+ inlined = true ;
415+ }
416+ for (auto m : mtoinl) {
417+ callInliner (m);
418+ inlined = true ;
419+ }
420+ if (inlined)
421+ toFollowOps.push_back (op);
374422 }
375423 }
376- getOperation ().walk ([&](LLVM::CallOp caller) {
377- OpBuilder builder (caller);
378- SmallVector<Value> vals;
379- if (fixupGetFunc (caller, builder, vals).failed ())
380- return ;
381- if (vals.size ())
382- caller.getResult ().replaceAllUsesWith (vals[0 ]);
383- caller.erase ();
384- });
424+
425+ // Only supports single block functions at the moment.
385426
386427 SmallVector<gpu::LaunchOp> toHandle;
387428 getOperation ().walk (
0 commit comments