Reapply "AMDGPU: Fix handling of alignment padding in DAG argument lowering"

Reverts r337079 with a fix for the msan error.

llvm-svn: 337535
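
For illustration (not part of the committed patch), a minimal standalone
sketch of the offset computation the new code performs. The alignTo helper
here is a local stand-in for llvm::alignTo, and the (size, alignment) pairs
are assumed example arguments, e.g. an i8 followed by an i32:

    #include <cstdint>
    #include <cstdio>

    // Round Value up to the next multiple of Align.
    static uint64_t alignTo(uint64_t Value, uint64_t Align) {
      return (Value + Align - 1) / Align * Align;
    }

    int main() {
      struct Arg { uint64_t Size, Align; };
      const Arg Args[] = {{1, 1}, {4, 4}}; // i8, then i32

      uint64_t ExplicitArgOffset = 0;
      for (const Arg &A : Args) {
        // Pad to the argument's ABI alignment before claiming its slot.
        uint64_t ArgOffset = alignTo(ExplicitArgOffset, A.Align);
        ExplicitArgOffset = ArgOffset + A.Size;
        std::printf("argument placed at offset %llu\n",
                    (unsigned long long)ArgOffset);
      }
      // Prints 0 and 4: three bytes of padding separate the two arguments,
      // which a per-part allocation could fail to account for.
      return 0;
    }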
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 485927c..b201126 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -30,6 +30,7 @@
 #include "SIInstrInfo.h"
 #include "SIMachineFunctionInfo.h"
 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -40,18 +41,6 @@
 #include "llvm/Support/KnownBits.h"
 using namespace llvm;
 
-static bool allocateKernArg(unsigned ValNo, MVT ValVT, MVT LocVT,
-                            CCValAssign::LocInfo LocInfo,
-                            ISD::ArgFlagsTy ArgFlags, CCState &State) {
-  MachineFunction &MF = State.getMachineFunction();
-  AMDGPUMachineFunction *MFI = MF.getInfo<AMDGPUMachineFunction>();
-
-  uint64_t Offset = MFI->allocateKernArg(LocVT.getStoreSize(),
-                                         ArgFlags.getOrigAlign());
-  State.addLoc(CCValAssign::getCustomMem(ValNo, ValVT, Offset, LocVT, LocInfo));
-  return true;
-}
-
 static bool allocateCCRegs(unsigned ValNo, MVT ValVT, MVT LocVT,
                            CCValAssign::LocInfo LocInfo,
                            ISD::ArgFlagsTy ArgFlags, CCState &State,
@@ -910,74 +899,118 @@
 /// for each individual part is i8.  We pass the memory type as LocVT to the
 /// calling convention analysis function and the register type (Ins[x].VT) as
 /// the ValVT.
-void AMDGPUTargetLowering::analyzeFormalArgumentsCompute(CCState &State,
-                             const SmallVectorImpl<ISD::InputArg> &Ins) const {
-  for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
-    const ISD::InputArg &In = Ins[i];
-    EVT MemVT;
+void AMDGPUTargetLowering::analyzeFormalArgumentsCompute(
+  CCState &State,
+  const SmallVectorImpl<ISD::InputArg> &Ins) const {
+  const MachineFunction &MF = State.getMachineFunction();
+  const Function &Fn = MF.getFunction();
+  LLVMContext &Ctx = Fn.getParent()->getContext();
+  const AMDGPUSubtarget &ST = AMDGPUSubtarget::get(MF);
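+  // Offset of the first explicit kernel argument; non-zero when implicit
+  // arguments precede the explicit ones in the kernarg segment.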
+  const unsigned ExplicitOffset = ST.getExplicitKernelArgOffset(Fn);
 
-    unsigned NumRegs = getNumRegisters(State.getContext(), In.ArgVT);
+  unsigned MaxAlign = 1;
+  uint64_t ExplicitArgOffset = 0;
+  const DataLayout &DL = Fn.getParent()->getDataLayout();
 
-    if (!Subtarget->isAmdHsaOS() &&
-        (In.ArgVT == MVT::i16 || In.ArgVT == MVT::i8 || In.ArgVT == MVT::f16)) {
-      // The ABI says the caller will extend these values to 32-bits.
-      MemVT = In.ArgVT.isInteger() ? MVT::i32 : MVT::f32;
-    } else if (NumRegs == 1) {
-      // This argument is not split, so the IR type is the memory type.
-      assert(!In.Flags.isSplit());
-      if (In.ArgVT.isExtended()) {
-        // We have an extended type, like i24, so we should just use the register type
-        MemVT = In.VT;
+  unsigned InIndex = 0;
+
+  for (const Argument &Arg : Fn.args()) {
+    Type *BaseArgTy = Arg.getType();
+    unsigned Align = DL.getABITypeAlignment(BaseArgTy);
+    MaxAlign = std::max(Align, MaxAlign);
+    unsigned AllocSize = DL.getTypeAllocSize(BaseArgTy);
+
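+    // Pad the running offset to the ABI alignment of the IR argument type;
+    // this is where the in-memory padding between kernel arguments is
+    // accounted for.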
+    uint64_t ArgOffset = alignTo(ExplicitArgOffset, Align) + ExplicitOffset;
+    ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
+
+    // We're basically throwing away everything passed into us and starting over
+    // to get accurate in-memory offsets. The "PartOffset" is completely useless
+    // to us as computed in Ins.
+    //
+    // We also need to figure out what type legalization is trying to do to get
+    // the correct memory offsets.
+
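+    // ComputeValueVTs flattens aggregate types into their leaf values and,
+    // since we pass ArgOffset as the starting offset, records each leaf's
+    // byte offset relative to the kernarg base.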
+    SmallVector<EVT, 16> ValueVTs;
+    SmallVector<uint64_t, 16> Offsets;
+    ComputeValueVTs(*this, DL, BaseArgTy, ValueVTs, &Offsets, ArgOffset);
+
+    for (unsigned Value = 0, NumValues = ValueVTs.size();
+         Value != NumValues; ++Value) {
+      uint64_t BasePartOffset = Offsets[Value];
+
+      EVT ArgVT = ValueVTs[Value];
+      EVT MemVT = ArgVT;
+      MVT RegisterVT =
+        getRegisterTypeForCallingConv(Ctx, ArgVT);
+      unsigned NumRegs =
+        getNumRegistersForCallingConv(Ctx, ArgVT);
+
+      if (!Subtarget->isAmdHsaOS() &&
+          (ArgVT == MVT::i16 || ArgVT == MVT::i8 || ArgVT == MVT::f16)) {
+        // The ABI says the caller will extend these values to 32-bits.
+        MemVT = ArgVT.isInteger() ? MVT::i32 : MVT::f32;
+      } else if (NumRegs == 1) {
+        // This argument is not split, so the IR type is the memory type.
+        if (ArgVT.isExtended()) {
+          // We have an extended type, like i24, so we should just use the
+          // register type.
+          MemVT = RegisterVT;
+        } else {
+          MemVT = ArgVT;
+        }
+      } else if (ArgVT.isVector() && RegisterVT.isVector() &&
+                 ArgVT.getScalarType() == RegisterVT.getScalarType()) {
+        assert(ArgVT.getVectorNumElements() > RegisterVT.getVectorNumElements());
+        // We have a vector value which has been split into a vector with
+        // the same scalar type, but fewer elements.  This should handle
+        // all the floating-point vector types.
+        MemVT = RegisterVT;
+      } else if (ArgVT.isVector() &&
+                 ArgVT.getVectorNumElements() == NumRegs) {
+        // This arg has been split so that each element is stored in a separate
+        // register.
+        MemVT = ArgVT.getScalarType();
+      } else if (ArgVT.isExtended()) {
+        // We have an extended type, like i65.
+        MemVT = RegisterVT;
       } else {
-        MemVT = In.ArgVT;
+        unsigned MemoryBits = ArgVT.getStoreSizeInBits() / NumRegs;
+        assert(ArgVT.getStoreSizeInBits() % NumRegs == 0);
+        if (RegisterVT.isInteger()) {
+          MemVT = EVT::getIntegerVT(State.getContext(), MemoryBits);
+        } else if (RegisterVT.isVector()) {
+          assert(!RegisterVT.getScalarType().isFloatingPoint());
+          unsigned NumElements = RegisterVT.getVectorNumElements();
+          assert(MemoryBits % NumElements == 0);
+          // This vector type has been split into another vector type with
+          // a different element size.
+          EVT ScalarVT = EVT::getIntegerVT(State.getContext(),
+                                           MemoryBits / NumElements);
+          MemVT = EVT::getVectorVT(State.getContext(), ScalarVT, NumElements);
+        } else {
+          llvm_unreachable("cannot deduce memory type.");
+        }
       }
-    } else if (In.ArgVT.isVector() && In.VT.isVector() &&
-               In.ArgVT.getScalarType() == In.VT.getScalarType()) {
-      assert(In.ArgVT.getVectorNumElements() > In.VT.getVectorNumElements());
-      // We have a vector value which has been split into a vector with
-      // the same scalar type, but fewer elements.  This should handle
-      // all the floating-point vector types.
-      MemVT = In.VT;
-    } else if (In.ArgVT.isVector() &&
-               In.ArgVT.getVectorNumElements() == NumRegs) {
-      // This arg has been split so that each element is stored in a separate
-      // register.
-      MemVT = In.ArgVT.getScalarType();
-    } else if (In.ArgVT.isExtended()) {
-      // We have an extended type, like i65.
-      MemVT = In.VT;
-    } else {
-      unsigned MemoryBits = In.ArgVT.getStoreSizeInBits() / NumRegs;
-      assert(In.ArgVT.getStoreSizeInBits() % NumRegs == 0);
-      if (In.VT.isInteger()) {
-        MemVT = EVT::getIntegerVT(State.getContext(), MemoryBits);
-      } else if (In.VT.isVector()) {
-        assert(!In.VT.getScalarType().isFloatingPoint());
-        unsigned NumElements = In.VT.getVectorNumElements();
-        assert(MemoryBits % NumElements == 0);
-        // This vector type has been split into another vector type with
-        // a different elements size.
-        EVT ScalarVT = EVT::getIntegerVT(State.getContext(),
-                                         MemoryBits / NumElements);
-        MemVT = EVT::getVectorVT(State.getContext(), ScalarVT, NumElements);
-      } else {
-        llvm_unreachable("cannot deduce memory type.");
+
+      // Convert one element vectors to scalar.
+      if (MemVT.isVector() && MemVT.getVectorNumElements() == 1)
+        MemVT = MemVT.getScalarType();
+
+      if (MemVT.isExtended()) {
+        // This should really only happen if we have vec3 arguments.
+        assert(MemVT.isVector() && MemVT.getVectorNumElements() == 3);
+        MemVT = MemVT.getPow2VectorType(State.getContext());
+      }
+
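+      // Record one location per register-sized part; each part is MemVT in
+      // memory, laid out contiguously from the argument's base offset.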
+      unsigned PartOffset = 0;
+      for (unsigned i = 0; i != NumRegs; ++i) {
+        State.addLoc(CCValAssign::getCustomMem(InIndex++, RegisterVT,
+                                               BasePartOffset + PartOffset,
+                                               MemVT.getSimpleVT(),
+                                               CCValAssign::Full));
+        PartOffset += MemVT.getStoreSize();
       }
     }
-
-    // Convert one element vectors to scalar.
-    if (MemVT.isVector() && MemVT.getVectorNumElements() == 1)
-      MemVT = MemVT.getScalarType();
-
-    if (MemVT.isExtended()) {
-      // This should really only happen if we have vec3 arguments
-      assert(MemVT.isVector() && MemVT.getVectorNumElements() == 3);
-      MemVT = MemVT.getPow2VectorType(State.getContext());
-    }
-
-    assert(MemVT.isSimple());
-    allocateKernArg(i, In.VT, MemVT.getSimpleVT(), CCValAssign::Full, In.Flags,
-                    State);
   }
 }
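
For reference, a minimal standalone sketch (assumed example values, not the
LLVM API) of the per-part bookkeeping in the final loop above: an argument
split into NumRegs parts gets one location per part, spaced by the store
size of MemVT, all relative to the argument's aligned base offset:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Assumed example: a v8i8 argument legalized into 8 i8 parts.
      const uint64_t BasePartOffset = 16; // aligned base of the argument
      const unsigned NumRegs = 8;         // number of register-sized parts
      const uint64_t MemVTStoreSize = 1;  // store size of the per-part MemVT

      uint64_t PartOffset = 0;
      for (unsigned i = 0; i != NumRegs; ++i) {
        std::printf("part %u -> kernarg offset %llu\n", i,
                    (unsigned long long)(BasePartOffset + PartOffset));
        PartOffset += MemVTStoreSize;
      }
      return 0;
    }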