[GlobalISel][AArch64] Allow CallLowering to handle types which are normally
required to be passed as different register types. E.g. <2 x i16> may need to
be passed as a larger <2 x i32> type, so formal argument lowering needs to be
able to truncate it back. Likewise, when dealing with returns of these types,
they need to be widened back in the appropriate way.

Differential Revision: https://reviews.llvm.org/D60425

llvm-svn: 358032
diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
index a7c2a2e..47fdeed 100644
--- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
@@ -20,6 +20,8 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 
+#define DEBUG_TYPE "call-lowering"
+
 using namespace llvm;
 
 void CallLowering::anchor() {}
@@ -121,8 +123,15 @@
   unsigned NumArgs = Args.size();
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT CurVT = MVT::getVT(Args[i].Ty);
-    if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i], CCInfo))
-      return false;
+    if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i], CCInfo)) {
+      // Try to use the register type if we couldn't assign the VT.
+      if (!Handler.isArgumentHandler())
+        return false; 
+      CurVT = TLI->getRegisterTypeForCallingConv(
+          F.getContext(), F.getCallingConv(), EVT(CurVT));
+      if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i], CCInfo))
+        return false;
+    }
   }
 
   for (unsigned i = 0, e = Args.size(), j = 0; i != e; ++i, ++j) {
@@ -136,12 +145,39 @@
       continue;
     }
 
-    if (VA.isRegLoc())
-      Handler.assignValueToReg(Args[i].Reg, VA.getLocReg(), VA);
-    else if (VA.isMemLoc()) {
-      unsigned Size = VA.getValVT() == MVT::iPTR
-                          ? DL.getPointerSize()
-                          : alignTo(VA.getValVT().getSizeInBits(), 8) / 8;
+    if (VA.isRegLoc()) {
+      MVT OrigVT = MVT::getVT(Args[i].Ty);
+      MVT VAVT = VA.getValVT();
+      if (Handler.isArgumentHandler() && VAVT != OrigVT) {
+        if (VAVT.getSizeInBits() < OrigVT.getSizeInBits())
+          return false; // Can't handle this type of arg yet.
+        const LLT VATy(VAVT);
+        unsigned NewReg =
+            MIRBuilder.getMRI()->createGenericVirtualRegister(VATy);
+        Handler.assignValueToReg(NewReg, VA.getLocReg(), VA);
+        // If it's a vector type, we either need to truncate the elements
+        // or do an unmerge to get the lower block of elements.
+        if (VATy.isVector() &&
+            VATy.getNumElements() > OrigVT.getVectorNumElements()) {
+          const LLT OrigTy(OrigVT);
+          // Just handle the case where the VA type is 2 * original type.
+          if (VATy.getNumElements() != OrigVT.getVectorNumElements() * 2) {
+            LLVM_DEBUG(dbgs()
+                       << "Incoming promoted vector arg has too many elts");
+            return false;
+          }
+          auto Unmerge = MIRBuilder.buildUnmerge({OrigTy, OrigTy}, {NewReg});
+          MIRBuilder.buildCopy(Args[i].Reg, Unmerge.getReg(0));
+        } else {
+          MIRBuilder.buildTrunc(Args[i].Reg, {NewReg}).getReg(0);
+        }
+      } else {
+        Handler.assignValueToReg(Args[i].Reg, VA.getLocReg(), VA);
+      }
+    } else if (VA.isMemLoc()) {
+      MVT VT = MVT::getVT(Args[i].Ty);
+      unsigned Size = VT == MVT::iPTR ? DL.getPointerSize()
+                                      : alignTo(VT.getSizeInBits(), 8) / 8;
       unsigned Offset = VA.getLocMemOffset();
       MachinePointerInfo MPO;
       unsigned StackAddr = Handler.getStackAddress(Size, Offset, MPO);
@@ -157,6 +193,8 @@
 unsigned CallLowering::ValueHandler::extendRegister(unsigned ValReg,
                                                     CCValAssign &VA) {
   LLT LocTy{VA.getLocVT()};
+  if (LocTy.getSizeInBits() == MRI.getType(ValReg).getSizeInBits())
+    return ValReg;
   switch (VA.getLocInfo()) {
   default: break;
   case CCValAssign::Full: