@@ -534,7 +534,7 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
534534 }
535535 }
536536
537- auto getLLVM = [&](Expr *E) -> mlir::Value {
537+ auto getLLVM = [&](Expr *E, bool isRef = false ) -> mlir::Value {
538538 auto sub = Visit (E);
539539 if (!sub.val ) {
540540 expr->dump ();
@@ -564,23 +564,46 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
564564 auto shape = std::vector<int64_t >(mt.getShape ());
565565 assert (shape.size () == 2 );
566566
567- OpBuilder abuilder (builder.getContext ());
568- abuilder.setInsertionPointToStart (allocationScope);
569- auto one = abuilder.create <ConstantIntOp>(loc, 1 , 64 );
570- auto alloc = abuilder.create <mlir::LLVM::AllocaOp>(
571- loc,
567+ auto PT =
572568 LLVM::LLVMPointerType::get (Glob.typeTranslator .translateType (
573569 anonymize (getLLVMType (E->getType ()))),
574- 0 ),
575- one, 0 );
576- ValueCategory (alloc, /* isRef*/ true )
577- .store (loc, builder, sub, /* isArray*/ isArray);
578- sub = ValueCategory (alloc, /* isRef*/ true );
570+ 0 );
571+ if (true ) {
572+ sub = ValueCategory (
573+ builder.create <polygeist::Memref2PointerOp>(loc, PT, sub.val ),
574+ sub.isReference );
575+ } else {
576+ OpBuilder abuilder (builder.getContext ());
577+ abuilder.setInsertionPointToStart (allocationScope);
578+ auto one = abuilder.create <ConstantIntOp>(loc, 1 , 64 );
579+ auto alloc = abuilder.create <mlir::LLVM::AllocaOp>(loc, PT, one, 0 );
580+ ValueCategory (alloc, /* isRef*/ true )
581+ .store (loc, builder, sub, /* isArray*/ isArray);
582+ sub = ValueCategory (alloc, /* isRef*/ true );
583+ }
584+ }
585+ mlir::Value val;
586+ clang::QualType ct;
587+ if (!isRef) {
588+ val = sub.getValue (loc, builder);
589+ ct = E->getType ();
590+ } else {
591+ if (!sub.isReference ) {
592+ OpBuilder abuilder (builder.getContext ());
593+ abuilder.setInsertionPointToStart (allocationScope);
594+ auto one = abuilder.create <ConstantIntOp>(loc, 1 , 64 );
595+ auto alloc = abuilder.create <mlir::LLVM::AllocaOp>(
596+ loc, LLVM::LLVMPointerType::get (sub.val .getType ()), one, 0 );
597+ ValueCategory (alloc, /* isRef*/ true )
598+ .store (loc, builder, sub, /* isArray*/ isArray);
599+ sub = ValueCategory (alloc, /* isRef*/ true );
600+ }
601+ assert (sub.isReference );
602+ val = sub.val ;
603+ ct = Glob.CGM .getContext ().getLValueReferenceType (E->getType ());
579604 }
580- auto val = sub.getValue (loc, builder);
581605 if (auto mt = val.getType ().dyn_cast <MemRefType>()) {
582- auto nt = Glob.typeTranslator
583- .translateType (anonymize (getLLVMType (E->getType ())))
606+ auto nt = Glob.typeTranslator .translateType (anonymize (getLLVMType (ct)))
584607 .cast <LLVM::LLVMPointerType>();
585608 val = builder.create <polygeist::Memref2PointerOp>(loc, nt, val);
586609 }
@@ -1483,7 +1506,7 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
14831506
14841507 std::vector<mlir::Value> args;
14851508 for (auto *a : expr->arguments ()) {
1486- args.push_back (getLLVM (a));
1509+ args.push_back (getLLVM (a, /* isRef */ false ));
14871510 }
14881511 mlir::Value called;
14891512
@@ -1492,7 +1515,8 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
14921515 called = builder.create <mlir::LLVM::CallOp>(loc, strcmpF, args)
14931516 .getResult ();
14941517 } else {
1495- args.insert (args.begin (), getLLVM (expr->getCallee ()));
1518+ args.insert (args.begin (),
1519+ getLLVM (expr->getCallee (), /* isRef*/ false ));
14961520 SmallVector<mlir::Type> RTs = {Glob.typeTranslator .translateType (
14971521 anonymize (getLLVMType (expr->getType ())))};
14981522 if (RTs[0 ].isa <LLVM::LLVMVoidType>())
@@ -1509,31 +1533,154 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
15091533 if (!callee || callee->isVariadic ()) {
15101534 bool isReference = expr->isLValue () || expr->isXValue ();
15111535 std::vector<mlir::Value> args;
1512- for (auto *a : expr->arguments ()) {
1513- args.push_back (getLLVM (a));
1514- }
15151536 mlir::Value called;
15161537 if (callee) {
15171538 auto strcmpF = Glob.GetOrCreateLLVMFunction (callee);
1539+ std::vector<clang::QualType> types;
1540+ if (auto CC = dyn_cast<CXXMethodDecl>(callee)) {
1541+ types.push_back (CC->getThisType ());
1542+ }
1543+ for (auto parm : callee->parameters ()) {
1544+ types.push_back (parm->getOriginalType ());
1545+ }
1546+ int i = 0 ;
1547+ for (auto *a : expr->arguments ()) {
1548+ bool isRef = false ;
1549+ if (i < types.size ())
1550+ isRef = types[i]->isReferenceType ();
1551+ i++;
1552+ args.push_back (getLLVM (a, isRef));
1553+ }
15181554 called =
15191555 builder.create <mlir::LLVM::CallOp>(loc, strcmpF, args).getResult ();
15201556 } else {
1521- args.insert (args.begin (), getLLVM (expr->getCallee ()));
1557+ mlir::Value fn = Visit (expr->getCallee ()).getValue (loc, builder);
1558+ if (auto MT = fn.getType ().dyn_cast <MemRefType>()) {
1559+ fn = builder.create <polygeist::Memref2PointerOp>(
1560+ loc, LLVM::LLVMPointerType::get (MT.getElementType (), 0 ), fn);
1561+ }
1562+ auto PTF = fn.getType ()
1563+ .cast <LLVM::LLVMPointerType>()
1564+ .getElementType ()
1565+ .cast <LLVM::LLVMFunctionType>();
1566+ SmallVector<mlir::Type, 1 > argtys;
1567+ bool needsChange = false ;
1568+ for (auto FT : PTF.getParams ()) {
1569+ if (auto mt = FT.dyn_cast <MemRefType>()) {
1570+ argtys.push_back (LLVM::LLVMPointerType::get (mt.getElementType (), 0 ));
1571+ needsChange = true ;
1572+ } else
1573+ argtys.push_back (FT);
1574+ }
1575+ auto rt = PTF.getReturnType ();
1576+ if (auto mt = rt.dyn_cast <MemRefType>()) {
1577+ rt = LLVM::LLVMPointerType::get (mt.getElementType (), 0 );
1578+ needsChange = true ;
1579+ }
1580+ if (needsChange)
1581+ fn = builder.create <LLVM::BitcastOp>(
1582+ loc,
1583+ LLVM::LLVMPointerType::get (
1584+ LLVM::LLVMFunctionType::get (rt, argtys, PTF.isVarArg ()), 0 ),
1585+ fn);
1586+
1587+ args.push_back (fn);
15221588 auto CT = expr->getType ();
1523- if (isReference)
1524- CT = Glob.CGM .getContext ().getLValueReferenceType (CT);
1525- SmallVector<mlir::Type> RTs = {
1526- Glob. typeTranslator . translateType ( anonymize ( getLLVMType (CT)) )};
1589+ // if (isReference)
1590+ // CT = Glob.CGM.getContext().getLValueReferenceType(CT);
1591+ SmallVector<mlir::Type> RTs = {rt};
1592+ // getMLIRType(CT )};
15271593
15281594 auto ft = args[0 ]
15291595 .getType ()
15301596 .cast <LLVM::LLVMPointerType>()
15311597 .getElementType ()
15321598 .cast <LLVM::LLVMFunctionType>();
1533- assert (RTs[0 ] == ft.getReturnType ());
1534- if (RTs[0 ].isa <LLVM::LLVMVoidType>())
1599+ auto ETy = expr->getCallee ()->getType ()->getUnqualifiedDesugaredType ();
1600+ ETy = cast<clang::PointerType>(ETy)
1601+ ->getPointeeType ()
1602+ ->getUnqualifiedDesugaredType ();
1603+ auto CFT = dyn_cast<clang::FunctionProtoType>(ETy);
1604+ std::vector<clang::QualType> types;
1605+ if (CFT) {
1606+ for (auto t : CFT->getParamTypes ())
1607+ types.push_back (t);
1608+ } else {
1609+ assert (isa<clang::FunctionNoProtoType>(ETy));
1610+ }
1611+
1612+ auto ETy2 = ETy->getCanonicalTypeUnqualified ();
1613+
1614+ const clang::CodeGen::CGFunctionInfo *FI;
1615+ if (const FunctionProtoType *FPT = dyn_cast<FunctionProtoType>(ETy2)) {
1616+ FI = &Glob.CGM .getTypes ().arrangeFreeFunctionType (
1617+ CanQual<FunctionProtoType>::CreateUnsafe (QualType (FPT, 0 )));
1618+ } else {
1619+ const FunctionNoProtoType *FNPT = cast<FunctionNoProtoType>(ETy2);
1620+ FI = &Glob.CGM .getTypes ().arrangeFreeFunctionType (
1621+ CanQual<FunctionNoProtoType>::CreateUnsafe (QualType (FNPT, 0 )));
1622+ }
1623+
1624+ int i = 0 ;
1625+ for (auto *a : expr->arguments ()) {
1626+ bool isRef = false ;
1627+ bool isArray = false ;
1628+ if (i < types.size ()) {
1629+ isRef = types[i]->isReferenceType ();
1630+ // auto inf = FI->arguments()[i].info;
1631+ // isRef |= inf.isIndirect();
1632+ Glob.getMLIRType (types[i], &isArray);
1633+ isRef |= isArray;
1634+ }
1635+
1636+ auto sub = Visit (a);
1637+ mlir::Value v;
1638+ if (isRef) {
1639+ if (!sub.isReference ) {
1640+ OpBuilder abuilder (builder.getContext ());
1641+ abuilder.setInsertionPointToStart (allocationScope);
1642+ auto one = abuilder.create <ConstantIntOp>(loc, 1 , 64 );
1643+ auto alloc = abuilder.create <mlir::LLVM::AllocaOp>(
1644+ loc, LLVM::LLVMPointerType::get (sub.val .getType ()), one, 0 );
1645+ ValueCategory (alloc, /* isRef*/ true )
1646+ .store (loc, builder, sub, /* isArray*/ false );
1647+ sub = ValueCategory (alloc, /* isRef*/ true );
1648+ }
1649+ assert (sub.isReference );
1650+ v = sub.val ;
1651+ } else {
1652+ v = sub.getValue (loc, builder);
1653+ }
1654+ if (i < FI->arg_size ()) {
1655+ // TODO expand full calling conv
1656+ /*
1657+ auto inf = FI->arguments()[i].info;
1658+ if (inf.isIgnore() || inf.isInAlloca()) {
1659+ i++;
1660+ continue;
1661+ }
1662+ if (inf.isExpand()) {
1663+ i++;
1664+ continue;
1665+ }
1666+ */
1667+ }
1668+ i++;
1669+ if (auto mt = v.getType ().dyn_cast <MemRefType>()) {
1670+ v = builder.create <polygeist::Memref2PointerOp>(
1671+ loc, LLVM::LLVMPointerType::get (mt.getElementType (), 0 ), v);
1672+ }
1673+ args.push_back (v);
1674+ }
1675+ if (RTs[0 ].isa <mlir::NoneType>() || RTs[0 ].isa <LLVM::LLVMVoidType>())
15351676 RTs.clear ();
1677+ else
1678+ assert (RTs[0 ] == ft.getReturnType ());
15361679 called = builder.create <mlir::LLVM::CallOp>(loc, RTs, args).getResult ();
1680+ if (PTF.getReturnType () != ft.getReturnType ()) {
1681+ called = builder.create <polygeist::Pointer2MemrefOp>(
1682+ loc, PTF.getReturnType (), called);
1683+ }
15371684 }
15381685 if (isReference) {
15391686 if (!(called.getType ().isa <LLVM::LLVMPointerType>() ||
0 commit comments