Implement expansion in type legalization for add/sub with overflow.  The
expansion is the same as that used by LegalizeDAG.

The resulting code is poor in terms of performance and code size on x86-32 for
a 64-bit operation; I haven't investigated whether a different expansion might
be better in general.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@105378 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 8b382bc..341fefb 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -990,6 +990,11 @@
   case ISD::SHL:
   case ISD::SRA:
   case ISD::SRL: ExpandIntRes_Shift(N, Lo, Hi); break;
+
+  case ISD::SADDO:
+  case ISD::SSUBO: ExpandIntRes_SADDSUBO(N, Lo, Hi); break;
+  case ISD::UADDO:
+  case ISD::USUBO: ExpandIntRes_UADDSUBO(N, Lo, Hi); break;
   }
 
   // If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -1716,6 +1721,48 @@
   SplitInteger(MakeLibCall(LC, VT, Ops, 2, true/*irrelevant*/, dl), Lo, Hi);
 }
 
+void DAGTypeLegalizer::ExpandIntRes_SADDSUBO(SDNode *Node,
+                                             SDValue &Lo, SDValue &Hi) {
+  SDValue LHS = Node->getOperand(0);
+  SDValue RHS = Node->getOperand(1);
+  DebugLoc dl = Node->getDebugLoc();
+
+  // Expand the value result (result 0) of SADDO/SSUBO: emit the equivalent
+  // non-overflow-checking ADD/SUB on the operand type and split it into Lo/Hi.
+  SDValue Sum = DAG.getNode(Node->getOpcode() == ISD::SADDO ?
+                            ISD::ADD : ISD::SUB, dl, LHS.getValueType(),
+                            LHS, RHS);
+  SplitInteger(Sum, Lo, Hi);
+
+  // Compute the overflow (result 1) from the signs of LHS, RHS and Sum.
+  //
+  //   LHSSign -> LHS >= 0
+  //   RHSSign -> RHS >= 0
+  //   SumSign -> Sum >= 0
+  //
+  //   Add:
+  //   Overflow -> (LHSSign == RHSSign) && (LHSSign != SumSign)
+  //   Sub:
+  //   Overflow -> (LHSSign != RHSSign) && (LHSSign != SumSign)
+  // For SUB, SignsMatch below holds (LHSSign != RHSSign) despite its name.
+  EVT OType = Node->getValueType(1);  // Type of the overflow result.
+  SDValue Zero = DAG.getConstant(0, LHS.getValueType());
+
+  SDValue LHSSign = DAG.getSetCC(dl, OType, LHS, Zero, ISD::SETGE);
+  SDValue RHSSign = DAG.getSetCC(dl, OType, RHS, Zero, ISD::SETGE);
+  SDValue SignsMatch = DAG.getSetCC(dl, OType, LHSSign, RHSSign,
+                                    Node->getOpcode() == ISD::SADDO ?
+                                    ISD::SETEQ : ISD::SETNE);
+
+  SDValue SumSign = DAG.getSetCC(dl, OType, Sum, Zero, ISD::SETGE);
+  SDValue SumSignNE = DAG.getSetCC(dl, OType, LHSSign, SumSign, ISD::SETNE);
+
+  SDValue Cmp = DAG.getNode(ISD::AND, dl, OType, SignsMatch, SumSignNE);
+
+  // Redirect every user of the node's overflow result (value 1) to Cmp.
+  ReplaceValueWith(SDValue(Node, 1), Cmp);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_SDIV(SDNode *N,
                                          SDValue &Lo, SDValue &Hi) {
   EVT VT = N->getValueType(0);
@@ -1912,6 +1959,29 @@
   Hi = DAG.getNode(ISD::TRUNCATE, dl, NVT, Hi);
 }
 
+void DAGTypeLegalizer::ExpandIntRes_UADDSUBO(SDNode *N,
+                                             SDValue &Lo, SDValue &Hi) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  DebugLoc dl = N->getDebugLoc();
+
+  // Expand the value result (result 0) of UADDO/USUBO: emit the equivalent
+  // non-overflow-checking ADD/SUB on the operand type and split it into Lo/Hi.
+  SDValue Sum = DAG.getNode(N->getOpcode() == ISD::UADDO ?
+                            ISD::ADD : ISD::SUB, dl, LHS.getValueType(),
+                            LHS, RHS);
+  SplitInteger(Sum, Lo, Hi);
+
+  // Calculate the overflow with an unsigned compare of the result against LHS:
+  // addition overflows iff a + b < a, and subtraction overflows iff a - b > a.
+  SDValue Ofl = DAG.getSetCC(dl, N->getValueType(1), Sum, LHS,
+                             N->getOpcode () == ISD::UADDO ?
+                             ISD::SETULT : ISD::SETUGT);
+
+  // Redirect every user of the node's overflow result (value 1) to Ofl.
+  ReplaceValueWith(SDValue(N, 1), Ofl);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_UDIV(SDNode *N,
                                          SDValue &Lo, SDValue &Hi) {
   EVT VT = N->getValueType(0);
diff --git a/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index db5e2a1..bd86694 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -345,6 +345,9 @@
   void ExpandIntRes_UREM              (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_Shift             (SDNode *N, SDValue &Lo, SDValue &Hi);
 
+  void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
+
   void ExpandShiftByConstant(SDNode *N, unsigned Amt,
                              SDValue &Lo, SDValue &Hi);
   bool ExpandShiftWithKnownAmountBit(SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/test/CodeGen/Generic/add-with-overflow-128.ll b/test/CodeGen/Generic/add-with-overflow-128.ll
new file mode 100644
index 0000000..c46c820
--- /dev/null
+++ b/test/CodeGen/Generic/add-with-overflow-128.ll
@@ -0,0 +1,42 @@
+; RUN: llc < %s
+
+@ok = internal constant [4 x i8] c"%d\0A\00"
+@no = internal constant [4 x i8] c"no\0A\00"
+
+define i1 @func1(i128 signext %v1, i128 signext %v2) nounwind {
+entry:                                    ; signed i128 add with overflow check
+  %t = call {i128, i1} @llvm.sadd.with.overflow.i128(i128 %v1, i128 %v2)
+  %sum = extractvalue {i128, i1} %t, 0    ; field 0: the i128 sum
+  %sum32 = trunc i128 %sum to i32         ; truncated to i32 for printf
+  %obit = extractvalue {i128, i1} %t, 1   ; field 1: the overflow bit
+  br i1 %obit, label %overflow, label %normal
+
+normal:                                   ; overflow bit clear: print sum, return true
+  %t1 = tail call i32 (i8*, ...)* @printf( i8* getelementptr ([4 x i8]* @ok, i32 0, i32 0), i32 %sum32 ) nounwind
+  ret i1 true
+
+overflow:                                 ; overflow bit set: print @no, return false
+  %t2 = tail call i32 (i8*, ...)* @printf( i8* getelementptr ([4 x i8]* @no, i32 0, i32 0) ) nounwind
+  ret i1 false
+}
+
+define i1 @func2(i128 zeroext %v1, i128 zeroext %v2) nounwind {
+entry:                                    ; unsigned i128 add with overflow check
+  %t = call {i128, i1} @llvm.uadd.with.overflow.i128(i128 %v1, i128 %v2)
+  %sum = extractvalue {i128, i1} %t, 0    ; field 0: the i128 sum
+  %sum32 = trunc i128 %sum to i32         ; truncated to i32 for printf
+  %obit = extractvalue {i128, i1} %t, 1   ; field 1: the overflow (carry) bit
+  br i1 %obit, label %carry, label %normal
+
+normal:                                   ; overflow bit clear: print sum, return true
+  %t1 = tail call i32 (i8*, ...)* @printf( i8* getelementptr ([4 x i8]* @ok, i32 0, i32 0), i32 %sum32 ) nounwind
+  ret i1 true
+
+carry:                                    ; overflow bit set: print @no, return false
+  %t2 = tail call i32 (i8*, ...)* @printf( i8* getelementptr ([4 x i8]* @no, i32 0, i32 0) ) nounwind
+  ret i1 false
+}
+
+declare i32 @printf(i8*, ...) nounwind
+declare {i128, i1} @llvm.sadd.with.overflow.i128(i128, i128)
+declare {i128, i1} @llvm.uadd.with.overflow.i128(i128, i128)