[GlobalISel][AArch64] Allow CallLowering to handle types which are normally
required to be passed as different register types. E.g. <2 x i16> may need to
be passed as a larger <2 x i32> type, so formal arg lowering needs to be able
to truncate it back. Likewise, when dealing with returns of these types, they
need to be widened back in the appropriate way.

Differential Revision: https://reviews.llvm.org/D60425

llvm-svn: 358032
diff --git a/llvm/lib/Target/AArch64/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/AArch64CallLowering.cpp
index 8a00a3f..83054ee 100644
--- a/llvm/lib/Target/AArch64/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64CallLowering.cpp
@@ -44,6 +44,8 @@
 #include <cstdint>
 #include <iterator>
 
+#define DEBUG_TYPE "aarch64-call-lowering"
+
 using namespace llvm;
 
 AArch64CallLowering::AArch64CallLowering(const AArch64TargetLowering &TLI)
@@ -97,6 +99,8 @@
   /// (it's an implicit-def of the BL).
   virtual void markPhysRegUsed(unsigned PhysReg) = 0;
 
+  bool isArgumentHandler() const override { return true; }
+
   uint64_t StackUsed;
 };
 
@@ -250,18 +254,63 @@
            "For each split Type there should be exactly one VReg.");
 
     SmallVector<ArgInfo, 8> SplitArgs;
+    CallingConv::ID CC = F.getCallingConv();
+
     for (unsigned i = 0; i < SplitEVTs.size(); ++i) {
-      // We zero-extend i1s to i8.
-      unsigned CurVReg = VRegs[i];
-      if (MRI.getType(VRegs[i]).getSizeInBits() == 1) {
-        CurVReg = MIRBuilder.buildZExt(LLT::scalar(8), CurVReg)
-                       ->getOperand(0)
-                       .getReg();
+      if (TLI.getNumRegistersForCallingConv(Ctx, CC, SplitEVTs[i]) > 1) {
+        LLVM_DEBUG(dbgs() << "Can't handle extended arg types which need split");
+        return false;
       }
 
+      unsigned CurVReg = VRegs[i];
       ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVTs[i].getTypeForEVT(Ctx)};
       setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, F);
-      splitToValueTypes(CurArgInfo, SplitArgs, DL, MRI, F.getCallingConv(),
+
+      // i1 is a special case because SDAG i1 true is naturally zero extended
+      // when widened using ANYEXT. We need to do it explicitly here.
+      if (MRI.getType(CurVReg).getSizeInBits() == 1) {
+        CurVReg = MIRBuilder.buildZExt(LLT::scalar(8), CurVReg).getReg(0);
+      } else {
+        // Some types will need extending as specified by the CC.
+        MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CC, SplitEVTs[i]);
+        if (EVT(NewVT) != SplitEVTs[i]) {
+          unsigned ExtendOp = TargetOpcode::G_ANYEXT;
+          if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex,
+                                             Attribute::SExt))
+            ExtendOp = TargetOpcode::G_SEXT;
+          else if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex,
+                                                  Attribute::ZExt))
+            ExtendOp = TargetOpcode::G_ZEXT;
+
+          LLT NewLLT(NewVT);
+          LLT OldLLT(MVT::getVT(CurArgInfo.Ty));
+          CurArgInfo.Ty = EVT(NewVT).getTypeForEVT(Ctx);
+          // Instead of an extend, we might have a vector type which needs
+          // padding with more elements, e.g. <2 x half> -> <4 x half>
+          if (NewVT.isVector() &&
+              NewLLT.getNumElements() > OldLLT.getNumElements()) {
+            // We don't handle VA types which are not exactly twice the size,
+            // but this can easily be done in the future.
+            if (NewLLT.getNumElements() != OldLLT.getNumElements() * 2) {
+              LLVM_DEBUG(dbgs() << "Outgoing vector ret has too many elts");
+              return false;
+            }
+            auto Undef = MIRBuilder.buildUndef({OldLLT});
+            CurVReg =
+                MIRBuilder.buildMerge({NewLLT}, {CurVReg, Undef.getReg(0)})
+                    .getReg(0);
+          } else {
+            CurVReg =
+                MIRBuilder.buildInstr(ExtendOp, {NewLLT}, {CurVReg}).getReg(0);
+          }
+        }
+      }
+      if (CurVReg != CurArgInfo.Reg) {
+        CurArgInfo.Reg = CurVReg;
+        // Reset the arg flags after modifying CurVReg.
+        setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, F);
+      }
+     splitToValueTypes(CurArgInfo, SplitArgs, DL, MRI, CC,
                         [&](unsigned Reg, uint64_t Offset) {
                           MIRBuilder.buildExtract(Reg, CurVReg, Offset);
                         });
diff --git a/llvm/lib/Target/ARM/ARMCallLowering.cpp b/llvm/lib/Target/ARM/ARMCallLowering.cpp
index def7c5c..b70d55f 100644
--- a/llvm/lib/Target/ARM/ARMCallLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMCallLowering.cpp
@@ -301,6 +301,8 @@
                        CCAssignFn AssignFn)
       : ValueHandler(MIRBuilder, MRI, AssignFn) {}
 
+  bool isArgumentHandler() const override { return true; }
+
   unsigned getStackAddress(uint64_t Size, int64_t Offset,
                            MachinePointerInfo &MPO) override {
     assert((Size == 1 || Size == 2 || Size == 4 || Size == 8) &&
diff --git a/llvm/lib/Target/X86/X86CallLowering.cpp b/llvm/lib/Target/X86/X86CallLowering.cpp
index 048e4ca..5a623db 100644
--- a/llvm/lib/Target/X86/X86CallLowering.cpp
+++ b/llvm/lib/Target/X86/X86CallLowering.cpp
@@ -228,6 +228,8 @@
       : ValueHandler(MIRBuilder, MRI, AssignFn),
         DL(MIRBuilder.getMF().getDataLayout()) {}
 
+  bool isArgumentHandler() const override { return true; }
+
   unsigned getStackAddress(uint64_t Size, int64_t Offset,
                            MachinePointerInfo &MPO) override {
     auto &MFI = MIRBuilder.getMF().getFrameInfo();