When checking for sret-demotion, SelectionDAGBuilder needs to use legalized types.  When using the return value of an sret-demoted call, it needs to use the possibly illegal types that match the declared return Type of the callee.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@93667 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5e3a3b5..a3fb345 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -800,18 +800,19 @@
                              SDNodeOrder, Chain, NULL);
 }
 
-/// Get the EVTs and ArgFlags collections that represent the return type
-/// of the given function.  This does not require a DAG or a return value, and
-/// is suitable for use before any DAGs for the function are constructed.
+/// Get the EVTs and ArgFlags collections that represent the legalized return 
+/// type of the given function.  This does not require a DAG or a return value,
+/// and is suitable for use before any DAGs for the function are constructed.
 static void getReturnInfo(const Type* ReturnType,
                    Attributes attr, SmallVectorImpl<EVT> &OutVTs,
                    SmallVectorImpl<ISD::ArgFlagsTy> &OutFlags,
                    TargetLowering &TLI,
                    SmallVectorImpl<uint64_t> *Offsets = 0) {
   SmallVector<EVT, 4> ValueVTs;
-  ComputeValueVTs(TLI, ReturnType, ValueVTs, Offsets);
+  ComputeValueVTs(TLI, ReturnType, ValueVTs);
   unsigned NumValues = ValueVTs.size();
-  if ( NumValues == 0 ) return;
+  if (NumValues == 0) return;
+  unsigned Offset = 0;
 
   for (unsigned j = 0, f = NumValues; j != f; ++j) {
     EVT VT = ValueVTs[j];
@@ -834,6 +835,9 @@
 
     unsigned NumParts = TLI.getNumRegisters(ReturnType->getContext(), VT);
     EVT PartVT = TLI.getRegisterType(ReturnType->getContext(), VT);
+    unsigned PartSize = TLI.getTargetData()->getTypeAllocSize(
+                        PartVT.getTypeForEVT(ReturnType->getContext()));
+
     // 'inreg' on function refers to return value
     ISD::ArgFlagsTy Flags = ISD::ArgFlagsTy();
     if (attr & Attribute::InReg)
@@ -848,6 +852,11 @@
     for (unsigned i = 0; i < NumParts; ++i) {
       OutVTs.push_back(PartVT);
       OutFlags.push_back(Flags);
+      if (Offsets)
+      {
+        Offsets->push_back(Offset);
+        Offset += PartSize;
+      }
     }
   }
 }
@@ -5098,16 +5107,37 @@
     SDValue Chain = DAG.getNode(ISD::TokenFactor, getCurDebugLoc(),
                                 MVT::Other, &Chains[0], NumValues);
     PendingLoads.push_back(Chain);
+    
+    // Collect the legal value parts into potentially illegal values
+    // that correspond to the original function's return values.
+    SmallVector<EVT, 4> RetTys;
+    RetTy = FTy->getReturnType();
+    ComputeValueVTs(TLI, RetTy, RetTys);
+    ISD::NodeType AssertOp = ISD::DELETED_NODE;
+    SmallVector<SDValue, 4> ReturnValues;
+    unsigned CurReg = 0;
+    for (unsigned I = 0, E = RetTys.size(); I != E; ++I) {
+      EVT VT = RetTys[I];
+      EVT RegisterVT = TLI.getRegisterType(RetTy->getContext(), VT);
+      unsigned NumRegs = TLI.getNumRegisters(RetTy->getContext(), VT);
+  
+      SDValue ReturnValue =
+        getCopyFromParts(DAG, getCurDebugLoc(), SDNodeOrder, &Values[CurReg], NumRegs,
+                         RegisterVT, VT, AssertOp);
+      ReturnValues.push_back(ReturnValue);
+      if (DisableScheduling)
+        DAG.AssignOrdering(ReturnValue.getNode(), SDNodeOrder);
+      CurReg += NumRegs;
+    }
+    SDValue Res = DAG.getNode(ISD::MERGE_VALUES, getCurDebugLoc(),
+                              DAG.getVTList(&RetTys[0], RetTys.size()),
+                              &ReturnValues[0], ReturnValues.size());
 
-    SDValue MV = DAG.getNode(ISD::MERGE_VALUES,
-                             getCurDebugLoc(),
-                             DAG.getVTList(&OutVTs[0], NumValues),
-                             &Values[0], NumValues);
-    setValue(CS.getInstruction(), MV);
+    setValue(CS.getInstruction(), Res);
 
     if (DisableScheduling) {
       DAG.AssignOrdering(Chain.getNode(), SDNodeOrder);
-      DAG.AssignOrdering(MV.getNode(), SDNodeOrder);
+      DAG.AssignOrdering(Res.getNode(), SDNodeOrder);
     }
   }