2727#include " mlir/Transforms/Passes.h"
2828#include " polygeist/Ops.h"
2929#include " polygeist/Passes/Passes.h"
30+ #include " llvm/ADT/SetVector.h"
3031#include " llvm/ADT/SmallPtrSet.h"
3132#include < algorithm>
3233#include < mutex>
@@ -198,18 +199,23 @@ mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(ModuleOp module) {
198199 lnk);
199200}
200201
// Forward declaration; the definition is not visible in this part of the file.
// ParallelLower::runOnOperation calls this for every LLVM::CallOp it walks,
// passing an OpBuilder constructed at that call and an initially empty vector.
// When the returned LogicalResult is not a failure, the caller erases the call
// and, if the vector is non-empty, first replaces all uses of the call's
// result with the vector's first element.
// NOTE(review): presumably the vector receives the replacement value(s) for
// the call's result -- confirm against the definition.
202+ LogicalResult fixupGetFunc (LLVM::CallOp, OpBuilder &rewriter,
203+ SmallVectorImpl<Value> &);
204+
201205void ParallelLower::runOnOperation () {
202206 // The inliner should only be run on operations that define a symbol table,
203207 // as the callgraph will need to resolve references.
204208
205209 SymbolTableCollection symbolTable;
206210 symbolTable.getSymbolTable (getOperation ());
211+ SymbolUserMap symbolUserMap (symbolTable, getOperation ());
207212
208213 getOperation ()->walk ([&](CallOp bidx) {
209214 if (bidx.getCallee () == " cudaThreadSynchronize" )
210215 bidx.erase ();
211216 });
212217
218+ std::function<void (LLVM::CallOp)> LLVMcallInliner;
213219 std::function<void (CallOp)> callInliner = [&](CallOp caller) {
214220 // Build the inliner interface.
215221 AlwaysInlinerInterface interface (&getContext ());
@@ -230,10 +236,72 @@ void ParallelLower::runOnOperation() {
230236 return ;
231237 if (targetRegion->empty ())
232238 return ;
233- SmallVector<CallOp> ops;
234- callableOp.walk ([&](CallOp caller) { ops.push_back (caller); });
235- for (auto op : ops)
236- callInliner (op);
239+ {
240+ SmallVector<CallOp> ops;
241+ callableOp.walk ([&](CallOp caller) { ops.push_back (caller); });
242+ for (auto op : ops)
243+ callInliner (op);
244+ }
245+ {
246+ SmallVector<LLVM::CallOp> ops;
247+ callableOp.walk ([&](LLVM::CallOp caller) { ops.push_back (caller); });
248+ for (auto op : ops)
249+ LLVMcallInliner (op);
250+ }
251+ OpBuilder b (caller);
252+ auto allocScope = b.create <memref::AllocaScopeOp>(caller.getLoc (),
253+ caller.getResultTypes ());
254+ allocScope.getRegion ().push_back (new Block ());
255+ b.setInsertionPointToStart (&allocScope.getRegion ().front ());
256+ auto exOp = b.create <scf::ExecuteRegionOp>(caller.getLoc (),
257+ caller.getResultTypes ());
258+ Block *blk = new Block ();
259+ exOp.getRegion ().push_back (blk);
260+ caller->moveBefore (blk, blk->begin ());
261+ caller.replaceAllUsesWith (allocScope.getResults ());
262+ b.setInsertionPointToEnd (blk);
263+ b.create <scf::YieldOp>(caller.getLoc (), caller.getResults ());
264+ if (inlineCall (interface, caller, callableOp, targetRegion,
265+ /* shouldCloneInlinedRegion=*/ true )
266+ .succeeded ()) {
267+ caller.erase ();
268+ }
269+ b.setInsertionPointToEnd (&allocScope.getRegion ().front ());
270+ b.create <memref::AllocaScopeReturnOp>(allocScope.getLoc (),
271+ exOp.getResults ());
272+ };
273+ LLVMcallInliner = [&](LLVM::CallOp caller) {
274+ // Build the inliner interface.
275+ AlwaysInlinerInterface interface (&getContext ());
276+
277+ auto callable = caller.getCallableForCallee ();
278+ CallableOpInterface callableOp;
279+ if (SymbolRefAttr symRef = callable.dyn_cast <SymbolRefAttr>()) {
280+ if (!symRef.isa <FlatSymbolRefAttr>())
281+ return ;
282+ auto *symbolOp =
283+ symbolTable.lookupNearestSymbolFrom (getOperation (), symRef);
284+ callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
285+ } else {
286+ return ;
287+ }
288+ Region *targetRegion = callableOp.getCallableRegion ();
289+ if (!targetRegion)
290+ return ;
291+ if (targetRegion->empty ())
292+ return ;
293+ {
294+ SmallVector<CallOp> ops;
295+ callableOp.walk ([&](CallOp caller) { ops.push_back (caller); });
296+ for (auto op : ops)
297+ callInliner (op);
298+ }
299+ {
300+ SmallVector<LLVM::CallOp> ops;
301+ callableOp.walk ([&](LLVM::CallOp caller) { ops.push_back (caller); });
302+ for (auto op : ops)
303+ LLVMcallInliner (op);
304+ }
237305 OpBuilder b (caller);
238306 auto allocScope = b.create <memref::AllocaScopeOp>(caller.getLoc (),
239307 caller.getResultTypes ());
@@ -256,6 +324,7 @@ void ParallelLower::runOnOperation() {
256324 b.create <memref::AllocaScopeReturnOp>(allocScope.getLoc (),
257325 exOp.getResults ());
258326 };
327+
259328 {
260329 SmallVector<CallOp> dimsToInline;
261330 getOperation ()->walk ([&](CallOp bidx) {
@@ -268,15 +337,68 @@ void ParallelLower::runOnOperation() {
268337 }
269338
270339 // Only supports single block functions at the moment.
340+
341+ SmallVector<std::pair<Operation *, size_t >> outlineOps;
342+ getOperation ().walk ([&](gpu::LaunchOp launchOp) {
343+ launchOp.walk ([&](LLVM::CallOp caller) {
344+ if (!caller.getCallee ()) {
345+ outlineOps.push_back (std::make_pair (caller, (size_t )0 ));
346+ }
347+ });
348+ });
349+ SetVector<FunctionOpInterface> toinl;
350+ while (outlineOps.size ()) {
351+ auto opv = outlineOps.back ();
352+ auto op = std::get<0 >(opv);
353+ auto idx = std::get<1 >(opv);
354+ outlineOps.pop_back ();
355+ if (Value fn = op->getOperand (idx)) {
356+ if (auto fn2 = fn.getDefiningOp <polygeist::Memref2PointerOp>())
357+ fn = fn2.getOperand ();
358+ if (auto ba = fn.dyn_cast <BlockArgument>()) {
359+ if (auto F =
360+ dyn_cast<FunctionOpInterface>(ba.getOwner ()->getParentOp ())) {
361+ if (toinl.count (F))
362+ continue ;
363+ toinl.insert (F);
364+ for (Operation *m : symbolUserMap.getUsers (F)) {
365+ outlineOps.push_back (std::make_pair (m, (size_t )ba.getArgNumber ()));
366+ }
367+ }
368+ }
369+ }
370+ }
371+ for (auto F : toinl) {
372+ for (Operation *m : symbolUserMap.getUsers (F)) {
373+ callInliner (cast<CallOp>(m));
374+ }
375+ }
376+ getOperation ().walk ([&](LLVM::CallOp caller) {
377+ OpBuilder builder (caller);
378+ SmallVector<Value> vals;
379+ if (fixupGetFunc (caller, builder, vals).failed ())
380+ return ;
381+ if (vals.size ())
382+ caller.getResult ().replaceAllUsesWith (vals[0 ]);
383+ caller.erase ();
384+ });
385+
271386 SmallVector<gpu::LaunchOp> toHandle;
272387 getOperation ().walk (
273388 [&](gpu::LaunchOp launchOp) { toHandle.push_back (launchOp); });
274-
275389 for (gpu::LaunchOp launchOp : toHandle) {
276- SmallVector<CallOp> ops;
277- launchOp.walk ([&](CallOp caller) { ops.push_back (caller); });
278- for (auto op : ops)
279- callInliner (op);
390+ {
391+ SmallVector<CallOp> ops;
392+ launchOp.walk ([&](CallOp caller) { ops.push_back (caller); });
393+ for (auto op : ops)
394+ callInliner (op);
395+ }
396+ {
397+ SmallVector<LLVM::CallOp> lops;
398+ launchOp.walk ([&](LLVM::CallOp caller) { lops.push_back (caller); });
399+ for (auto op : lops)
400+ LLVMcallInliner (op);
401+ }
280402
281403 mlir::IRRewriter builder (launchOp.getContext ());
282404 auto loc = launchOp.getLoc ();
0 commit comments