[GlobalISel][CallLowering] Add support for splitting types according to calling conventions.

On AArch64, s128 types have to be split into s64 GPRs when passed as arguments.
This change adds the generic support in call lowering for dealing with multiple
registers, for incoming and outgoing args.

Support for splitting return types is not yet implemented.

Differential Revision: https://reviews.llvm.org/D66180

llvm-svn: 370822
diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
index 5e08361..d433155 100644
--- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
@@ -11,8 +11,9 @@
 ///
 //===----------------------------------------------------------------------===//
 
-#include "llvm/CodeGen/GlobalISel/CallLowering.h"
 #include "llvm/CodeGen/Analysis.h"
+#include "llvm/CodeGen/GlobalISel/CallLowering.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/CodeGen/MachineOperand.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -71,29 +72,30 @@
 void CallLowering::setArgFlags(CallLowering::ArgInfo &Arg, unsigned OpIdx,
                                const DataLayout &DL,
                                const FuncInfoTy &FuncInfo) const {
+  auto &Flags = Arg.Flags[0];
   const AttributeList &Attrs = FuncInfo.getAttributes();
   if (Attrs.hasAttribute(OpIdx, Attribute::ZExt))
-    Arg.Flags.setZExt();
+    Flags.setZExt();
   if (Attrs.hasAttribute(OpIdx, Attribute::SExt))
-    Arg.Flags.setSExt();
+    Flags.setSExt();
   if (Attrs.hasAttribute(OpIdx, Attribute::InReg))
-    Arg.Flags.setInReg();
+    Flags.setInReg();
   if (Attrs.hasAttribute(OpIdx, Attribute::StructRet))
-    Arg.Flags.setSRet();
+    Flags.setSRet();
   if (Attrs.hasAttribute(OpIdx, Attribute::SwiftSelf))
-    Arg.Flags.setSwiftSelf();
+    Flags.setSwiftSelf();
   if (Attrs.hasAttribute(OpIdx, Attribute::SwiftError))
-    Arg.Flags.setSwiftError();
+    Flags.setSwiftError();
   if (Attrs.hasAttribute(OpIdx, Attribute::ByVal))
-    Arg.Flags.setByVal();
+    Flags.setByVal();
   if (Attrs.hasAttribute(OpIdx, Attribute::InAlloca))
-    Arg.Flags.setInAlloca();
+    Flags.setInAlloca();
 
-  if (Arg.Flags.isByVal() || Arg.Flags.isInAlloca()) {
+  if (Flags.isByVal() || Flags.isInAlloca()) {
     Type *ElementTy = cast<PointerType>(Arg.Ty)->getElementType();
 
     auto Ty = Attrs.getAttribute(OpIdx, Attribute::ByVal).getValueAsType();
-    Arg.Flags.setByValSize(DL.getTypeAllocSize(Ty ? Ty : ElementTy));
+    Flags.setByValSize(DL.getTypeAllocSize(Ty ? Ty : ElementTy));
 
     // For ByVal, alignment should be passed from FE.  BE will guess if
     // this info is not there but there are cases it cannot get right.
@@ -102,11 +104,11 @@
       FrameAlign = FuncInfo.getParamAlignment(OpIdx - 2);
     else
       FrameAlign = getTLI()->getByValTypeAlignment(ElementTy, DL);
-    Arg.Flags.setByValAlign(FrameAlign);
+    Flags.setByValAlign(FrameAlign);
   }
   if (Attrs.hasAttribute(OpIdx, Attribute::Nest))
-    Arg.Flags.setNest();
-  Arg.Flags.setOrigAlign(DL.getABITypeAlignment(Arg.Ty));
+    Flags.setNest();
+  Flags.setOrigAlign(DL.getABITypeAlignment(Arg.Ty));
 }
 
 template void
@@ -161,7 +163,7 @@
 }
 
 bool CallLowering::handleAssignments(MachineIRBuilder &MIRBuilder,
-                                     ArrayRef<ArgInfo> Args,
+                                     SmallVectorImpl<ArgInfo> &Args,
                                      ValueHandler &Handler) const {
   MachineFunction &MF = MIRBuilder.getMF();
   const Function &F = MF.getFunction();
@@ -173,7 +175,7 @@
 bool CallLowering::handleAssignments(CCState &CCInfo,
                                      SmallVectorImpl<CCValAssign> &ArgLocs,
                                      MachineIRBuilder &MIRBuilder,
-                                     ArrayRef<ArgInfo> Args,
+                                     SmallVectorImpl<ArgInfo> &Args,
                                      ValueHandler &Handler) const {
   MachineFunction &MF = MIRBuilder.getMF();
   const Function &F = MF.getFunction();
@@ -182,14 +184,101 @@
   unsigned NumArgs = Args.size();
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT CurVT = MVT::getVT(Args[i].Ty);
-    if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i], CCInfo)) {
-      // Try to use the register type if we couldn't assign the VT.
-      if (!Handler.isIncomingArgumentHandler() || !CurVT.isValid())
+    if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i],
+                          Args[i].Flags[0], CCInfo)) {
+      if (!CurVT.isValid())
         return false;
-      CurVT = TLI->getRegisterTypeForCallingConv(
+      MVT NewVT = TLI->getRegisterTypeForCallingConv(
           F.getContext(), F.getCallingConv(), EVT(CurVT));
-      if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i], CCInfo))
-        return false;
+
+      // If we need to split the type over multiple regs, check it's a scenario
+      // we currently support.
+      unsigned NumParts = TLI->getNumRegistersForCallingConv(
+          F.getContext(), F.getCallingConv(), CurVT);
+      if (NumParts > 1) {
+        if (CurVT.isVector())
+          return false;
+        // For now only handle exact splits.
+        if (NewVT.getSizeInBits() * NumParts != CurVT.getSizeInBits())
+          return false;
+      }
+
+      // For incoming arguments (return values), we could have values in
+      // physregs (or memlocs) which we want to extract and copy to vregs.
+      // During this, we might have to deal with the LLT being split across
+      // multiple regs, so we have to record this information for later.
+      //
+      // If we have outgoing args, then we have the opposite case. We have a
+      // vreg with an LLT which we want to assign to a physical location, and
+      // we might have to record that the value has to be split later.
+      if (Handler.isIncomingArgumentHandler()) {
+        if (NumParts == 1) {
+          // Try to use the register type if we couldn't assign the VT.
+          if (Handler.assignArg(i, NewVT, NewVT, CCValAssign::Full, Args[i],
+                                Args[i].Flags[0], CCInfo))
+            return false;
+        } else {
+          // We're handling an incoming arg which is split over multiple regs.
+          // E.g. returning an s128 on AArch64.
+          ISD::ArgFlagsTy OrigFlags = Args[i].Flags[0];
+          Args[i].OrigRegs.push_back(Args[i].Regs[0]);
+          Args[i].Regs.clear();
+          Args[i].Flags.clear();
+          LLT NewLLT = getLLTForMVT(NewVT);
+          // For each split register, create and assign a vreg that will store
+          // the incoming component of the larger value. These will later be
+          // merged to form the final vreg.
+          for (unsigned Part = 0; Part < NumParts; ++Part) {
+            Register Reg =
+                MIRBuilder.getMRI()->createGenericVirtualRegister(NewLLT);
+            ISD::ArgFlagsTy Flags = OrigFlags;
+            if (Part == 0) {
+              Flags.setSplit();
+            } else {
+              Flags.setOrigAlign(1);
+              if (Part == NumParts - 1)
+                Flags.setSplitEnd();
+            }
+            Args[i].Regs.push_back(Reg);
+            Args[i].Flags.push_back(Flags);
+            if (Handler.assignArg(i, NewVT, NewVT, CCValAssign::Full, Args[i],
+                                  Args[i].Flags[Part], CCInfo)) {
+              // Still couldn't assign this smaller part type for some reason.
+              return false;
+            }
+          }
+        }
+      } else {
+        // Handling an outgoing arg that might need to be split.
+        if (NumParts < 2)
+          return false; // Don't know how to deal with this type combination.
+
+        // This type is passed via multiple registers in the calling convention.
+        // We need to extract the individual parts.
+        Register LargeReg = Args[i].Regs[0];
+        LLT SmallTy = LLT::scalar(NewVT.getSizeInBits());
+        auto Unmerge = MIRBuilder.buildUnmerge(SmallTy, LargeReg);
+        assert(Unmerge->getNumOperands() == NumParts + 1);
+        ISD::ArgFlagsTy OrigFlags = Args[i].Flags[0];
+        // We're going to replace the regs and flags with the split ones.
+        Args[i].Regs.clear();
+        Args[i].Flags.clear();
+        for (unsigned PartIdx = 0; PartIdx < NumParts; ++PartIdx) {
+          ISD::ArgFlagsTy Flags = OrigFlags;
+          if (PartIdx == 0) {
+            Flags.setSplit();
+          } else {
+            Flags.setOrigAlign(1);
+            if (PartIdx == NumParts - 1)
+              Flags.setSplitEnd();
+          }
+          Args[i].Regs.push_back(Unmerge.getReg(PartIdx));
+          Args[i].Flags.push_back(Flags);
+          if (Handler.assignArg(i, NewVT, NewVT, CCValAssign::Full, Args[i],
+                                Args[i].Flags[PartIdx], CCInfo))
+            return false;
+        }
+      }
     }
   }
 
@@ -204,9 +293,6 @@
       continue;
     }
 
-    assert(Args[i].Regs.size() == 1 &&
-           "Can't handle multiple virtual regs yet");
-
     // FIXME: Pack registers if we have more than one.
     Register ArgReg = Args[i].Regs[0];
 
@@ -214,8 +300,25 @@
       MVT OrigVT = MVT::getVT(Args[i].Ty);
       MVT VAVT = VA.getValVT();
       if (Handler.isIncomingArgumentHandler() && VAVT != OrigVT) {
-        if (VAVT.getSizeInBits() < OrigVT.getSizeInBits())
-          return false; // Can't handle this type of arg yet.
+        if (VAVT.getSizeInBits() < OrigVT.getSizeInBits()) {
+          // Expected to be multiple regs for a single incoming arg.
+          unsigned NumArgRegs = Args[i].Regs.size();
+          if (NumArgRegs < 2)
+            return false;
+
+          assert((j + (NumArgRegs - 1)) < ArgLocs.size() &&
+                 "Too many regs for number of args");
+          for (unsigned Part = 0; Part < NumArgRegs; ++Part) {
+            // There should be Regs.size() ArgLocs per argument.
+            VA = ArgLocs[j + Part];
+            Handler.assignValueToReg(Args[i].Regs[Part], VA.getLocReg(), VA);
+          }
+          j += NumArgRegs - 1;
+          // Merge the split registers into the expected larger result vreg
+          // of the original call.
+          MIRBuilder.buildMerge(Args[i].OrigRegs[0], Args[i].Regs);
+          continue;
+        }
         const LLT VATy(VAVT);
         Register NewReg =
             MIRBuilder.getMRI()->createGenericVirtualRegister(VATy);
@@ -236,6 +339,16 @@
         } else {
           MIRBuilder.buildTrunc(ArgReg, {NewReg}).getReg(0);
         }
+      } else if (!Handler.isIncomingArgumentHandler()) {
+        assert((j + (Args[i].Regs.size() - 1)) < ArgLocs.size() &&
+               "Too many regs for number of args");
+        // This is an outgoing argument that might have been split.
+        for (unsigned Part = 0; Part < Args[i].Regs.size(); ++Part) {
+          // There should be Regs.size() ArgLocs per argument.
+          VA = ArgLocs[j + Part];
+          Handler.assignValueToReg(Args[i].Regs[Part], VA.getLocReg(), VA);
+        }
+        j += Args[i].Regs.size() - 1;
       } else {
         Handler.assignValueToReg(ArgReg, VA.getLocReg(), VA);
       }
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index b1bf825..da8898a 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -419,16 +419,6 @@
   return MF.getFunction().hasOptSize();
 }
 
-// Get a rough equivalent of an MVT for a given LLT.
-static MVT getMVTForLLT(LLT Ty) {
-  if (!Ty.isVector())
-    return MVT::getIntegerVT(Ty.getSizeInBits());
-
-  return MVT::getVectorVT(
-      MVT::getIntegerVT(Ty.getElementType().getSizeInBits()),
-      Ty.getNumElements());
-}
-
 // Returns a list of types to use for memory op lowering in MemOps. A partial
 // port of findOptimalMemOpLowering in TargetLowering.
 static bool findGISelOptimalMemOpLowering(
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index be09db1..a93e515 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -412,3 +412,20 @@
 void llvm::getSelectionDAGFallbackAnalysisUsage(AnalysisUsage &AU) {
   AU.addPreserved<StackProtector>();
 }
+
+MVT llvm::getMVTForLLT(LLT Ty) {
+  if (!Ty.isVector())
+    return MVT::getIntegerVT(Ty.getSizeInBits());
+
+  return MVT::getVectorVT(
+      MVT::getIntegerVT(Ty.getElementType().getSizeInBits()),
+      Ty.getNumElements());
+}
+
+LLT llvm::getLLTForMVT(MVT Ty) {
+  if (!Ty.isVector())
+    return LLT::scalar(Ty.getSizeInBits());
+
+  return LLT::vector(Ty.getVectorNumElements(),
+                     Ty.getVectorElementType().getSizeInBits());
+}