Lower multiply with overflow checking to __mulo<mode>
calls if we haven't been able to lower them any
other way.

Fixes rdar://9090077 and rdar://9210061


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@133288 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index da75b8a..f0aa9d6 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -19,6 +19,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "LegalizeTypes.h"
+#include "llvm/DerivedTypes.h"
 #include "llvm/CodeGen/PseudoSourceValue.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
@@ -1072,6 +1073,8 @@
   case ISD::SSUBO: ExpandIntRes_SADDSUBO(N, Lo, Hi); break;
   case ISD::UADDO:
   case ISD::USUBO: ExpandIntRes_UADDSUBO(N, Lo, Hi); break;
+  case ISD::UMULO:
+  case ISD::SMULO: ExpandIntRes_XMULO(N, Lo, Hi); break;
   }
 
   // If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -2149,6 +2152,66 @@
   ReplaceValueWith(SDValue(N, 1), Ofl);
 }
 
+void DAGTypeLegalizer::ExpandIntRes_XMULO(SDNode *N,
+                                          SDValue &Lo, SDValue &Hi) {
+  EVT VT = N->getValueType(0);
+  const Type *RetTy = VT.getTypeForEVT(*DAG.getContext());
+  EVT PtrVT = TLI.getPointerTy();
+  const Type *PtrTy = PtrVT.getTypeForEVT(*DAG.getContext());
+  DebugLoc dl = N->getDebugLoc();
+
+  // Replace this node with a call to the multiply-with-overflow runtime
+  // routine (RTLIB::MULO_*), selected by the type of result 0.
+  RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+  if (VT == MVT::i32)
+    LC = RTLIB::MULO_I32;
+  else if (VT == MVT::i64)
+    LC = RTLIB::MULO_I64;
+  else if (VT == MVT::i128)
+    LC = RTLIB::MULO_I128;
+  assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported XMULO!");
+
+  SDValue Temp = DAG.CreateStackTemporary(PtrVT);
+  // Pointer-sized stack temporary for the overflow value; store zero first.
+  SDValue Chain = DAG.getStore(DAG.getEntryNode(), dl,
+			       DAG.getConstant(0, PtrVT), Temp,
+			       MachinePointerInfo(), false, false, 0);
+
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
+    EVT ArgVT = N->getOperand(i).getValueType();
+    const Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());
+    Entry.Node = N->getOperand(i);
+    Entry.Ty = ArgTy;
+    Entry.isSExt = true;  // NOTE(review): set for UMULO too -- confirm.
+    Entry.isZExt = false;
+    Args.push_back(Entry);
+  }
+
+  // Final argument: the address of the overflow temporary created above.
+  Entry.Node = Temp;
+  Entry.Ty = PtrTy->getPointerTo();  // NOTE(review): confirm pointee width.
+  Entry.isSExt = true;
+  Entry.isZExt = false;
+  Args.push_back(Entry);
+
+  SDValue Func = DAG.getExternalSymbol(TLI.getLibcallName(LC), PtrVT);
+  std::pair<SDValue, SDValue> CallInfo =
+    TLI.LowerCallTo(Chain, RetTy, true, false, false, false,
+		    0, TLI.getLibcallCallingConv(LC), false,
+		    true, Func, Args, DAG, dl);
+
+  SplitInteger(CallInfo.first, Lo, Hi);
+  SDValue Temp2 = DAG.getLoad(PtrVT, dl, CallInfo.second, Temp,
+			      MachinePointerInfo(), false, false, 0);
+  SDValue Ofl = DAG.getSetCC(dl, N->getValueType(1), Temp2,
+                             DAG.getConstant(0, PtrVT),
+                             ISD::SETNE);
+  // Result 1 becomes "temporary != 0", loaded back after the call.
+  ReplaceValueWith(SDValue(N, 1), Ofl);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_UDIV(SDNode *N,
                                          SDValue &Lo, SDValue &Hi) {
   EVT VT = N->getValueType(0);
diff --git a/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 06dc40f..4597ec9 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -318,6 +318,7 @@
 
   void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_XMULO             (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandShiftByConstant(SDNode *N, unsigned Amt,
                              SDValue &Lo, SDValue &Hi);
diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index efbfaa4..474dd7a 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -81,6 +81,9 @@
   Names[RTLIB::MUL_I32] = "__mulsi3";
   Names[RTLIB::MUL_I64] = "__muldi3";
   Names[RTLIB::MUL_I128] = "__multi3";
+  Names[RTLIB::MULO_I32] = "__mulosi4";
+  Names[RTLIB::MULO_I64] = "__mulodi4";
+  Names[RTLIB::MULO_I128] = "__muloti4";
   Names[RTLIB::SDIV_I8] = "__divqi3";
   Names[RTLIB::SDIV_I16] = "__divhi3";
   Names[RTLIB::SDIV_I32] = "__divsi3";
@@ -1914,7 +1917,7 @@
   // comparisons.
   if (isa<ConstantSDNode>(N0.getNode()))
     return DAG.getSetCC(dl, VT, N1, N0, ISD::getSetCCSwappedOperands(Cond));
-  
+
   if (ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode())) {
     const APInt &C1 = N1C->getAPIntValue();
 
@@ -2673,9 +2676,9 @@
                                                   std::string &Constraint,
                                                   std::vector<SDValue> &Ops,
                                                   SelectionDAG &DAG) const {
-  
+
   if (Constraint.length() > 1) return;
-  
+
   char ConstraintLetter = Constraint[0];
   switch (ConstraintLetter) {
   default: break;
@@ -2865,7 +2868,7 @@
           report_fatal_error("Indirect operand for inline asm not a pointer!");
         OpTy = PtrTy->getElementType();
       }
-      
+
       // Look for vector wrapped in a struct. e.g. { <16 x i8> }.
       if (const StructType *STy = dyn_cast<StructType>(OpTy))
         if (STy->getNumElements() == 1)