[ARM] Lower (select_cc k k (select_cc ~k ~k x)) into (SSAT l_k x)

Summary:
SSAT saturates an integer, making sure that its value lies within
an interval [~k, k], i.e. [-(k + 1), k]. Since the constant is given
to SSAT as the number of bits set to one, k + 1 must be a power of 2,
otherwise the optimization is not possible. Also, the two select_cc
nodes must use < (or <=) and > (or >=) respectively, so that together
they define an interval.

Reviewers: mcrosier, jmolloy, rengolin

Subscribers: aemerson, rengolin, llvm-commits

Differential Revision: http://reviews.llvm.org/D21372

llvm-svn: 273581
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 4c575bd..d55a1dd 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -1136,6 +1136,8 @@
 
   case ARMISD::CMOV:          return "ARMISD::CMOV";
 
+  case ARMISD::SSAT:          return "ARMISD::SSAT";
+
   case ARMISD::SRL_FLAG:      return "ARMISD::SRL_FLAG";
   case ARMISD::SRA_FLAG:      return "ARMISD::SRA_FLAG";
   case ARMISD::RRX:           return "ARMISD::RRX";
@@ -3728,14 +3730,144 @@
   }
 }
 
+// Condition-code classifiers: > or >= (isGTorGE) and < or <= (isLTorLE).
+bool isGTorGE(ISD::CondCode CC) { return CC == ISD::SETGT || CC == ISD::SETGE; }
+bool isLTorLE(ISD::CondCode CC) { return CC == ISD::SETLT || CC == ISD::SETLE; }
+
+// Returns true if the conditional (LHS CC RHS ? TrueVal : FalseVal) saturates
+// its variable operand from below at K, i.e. if it matches one of these
+// shapes or one of their <= and >= counterparts:
+//          x < k ? k : x
+//          x > k ? x : k
+//          k < x ? x : k
+//          k > x ? k : x
+bool isLowerSaturate(const SDValue LHS, const SDValue RHS,
+                     const SDValue TrueVal, const SDValue FalseVal,
+                     const ISD::CondCode CC, const SDValue K) {
+  const bool IfGT = (K == LHS && K == TrueVal) || (K == RHS && K == FalseVal);
+  const bool IfLT = (K == RHS && K == TrueVal) || (K == LHS && K == FalseVal);
+  return isGTorGE(CC) ? IfGT : (isLTorLE(CC) && IfLT);
+}
+
+// Counterpart of isLowerSaturate(): returns true if the conditional saturates
+// its variable operand from above at K (e.g. x > k ? k : x).
+bool isUpperSaturate(const SDValue LHS, const SDValue RHS,
+                     const SDValue TrueVal, const SDValue FalseVal,
+                     const ISD::CondCode CC, const SDValue K) {
+  const bool IfGT = (K == RHS && K == TrueVal) || (K == LHS && K == FalseVal);
+  const bool IfLT = (K == LHS && K == TrueVal) || (K == RHS && K == FalseVal);
+  return isGTorGE(CC) ? IfGT : (isLTorLE(CC) && IfLT);
+}
+
+// Check if two chained conditionals could be converted into SSAT.
+//
+// SSAT can replace a pair of nested conditional selects that bound a value to
+// the interval [~k, k], where k + 1 is a power of 2. Here are some examples:
+//
+//     x < ~k ? ~k : (x > k ? k : x)
+//     x < ~k ? ~k : (x < k ? x : k)
+//     x > ~k ? (x > k ? k : x) : ~k
+//     x < k ? (x < ~k ? ~k : x) : k
+//     etc.
+//
+// It returns true if the conversion can be done, false otherwise. On success,
+// the saturated variable is returned in V and the upper bound k in K.
+bool isSaturatingConditional(const SDValue &Op, SDValue &V, uint64_t &K) {
+
+  SDValue LHS1 = Op.getOperand(0);
+  SDValue RHS1 = Op.getOperand(1);
+  SDValue TrueVal1 = Op.getOperand(2);
+  SDValue FalseVal1 = Op.getOperand(3);
+  ISD::CondCode CC1 = cast<CondCodeSDNode>(Op.getOperand(4))->get();
+
+  const SDValue Op2 = isa<ConstantSDNode>(TrueVal1) ? FalseVal1 : TrueVal1;
+  if (Op2.getOpcode() != ISD::SELECT_CC)
+    return false;
+
+  SDValue LHS2 = Op2.getOperand(0);
+  SDValue RHS2 = Op2.getOperand(1);
+  SDValue TrueVal2 = Op2.getOperand(2);
+  SDValue FalseVal2 = Op2.getOperand(3);
+  ISD::CondCode CC2 = cast<CondCodeSDNode>(Op2.getOperand(4))->get();
+
+  // Find out which are the constants and which are the variables
+  // in each conditional (K1/K2 stay null if neither compared operand is one)
+  SDValue *K1 = isa<ConstantSDNode>(LHS1) ? &LHS1 : isa<ConstantSDNode>(RHS1)
+                                                        ? &RHS1
+                                                        : nullptr;
+  SDValue *K2 = isa<ConstantSDNode>(LHS2) ? &LHS2 : isa<ConstantSDNode>(RHS2)
+                                                        ? &RHS2
+                                                        : nullptr;
+  SDValue K2Tmp = isa<ConstantSDNode>(TrueVal2) ? TrueVal2 : FalseVal2;
+  SDValue V1Tmp = (K1 && *K1 == LHS1) ? RHS1 : LHS1;
+  SDValue V2Tmp = (K2 && *K2 == LHS2) ? RHS2 : LHS2;
+  SDValue V2 = (K2Tmp == TrueVal2) ? FalseVal2 : TrueVal2;
+
+  // We must detect cases where the original operations worked with 16- or
+  // 8-bit values. In such case, V2Tmp != V2 because the comparison operations
+  // must work with sign-extended values but the select operations return
+  // the original non-extended value.
+  SDValue V2TmpReg = V2Tmp;
+  if (V2Tmp->getOpcode() == ISD::SIGN_EXTEND_INREG)
+    V2TmpReg = V2Tmp->getOperand(0);
+
+  // Check that the registers and the constants have the correct values
+  // in both conditionals
+  if (!K1 || !K2 || *K1 == Op2 || *K2 != K2Tmp || V1Tmp != V2Tmp ||
+      V2TmpReg != V2)
+    return false;
+
+  // Figure out which conditional is saturating the lower/upper bound.
+  const SDValue *LowerCheckOp =
+      isLowerSaturate(LHS1, RHS1, TrueVal1, FalseVal1, CC1, *K1)
+          ? &Op
+          : isLowerSaturate(LHS2, RHS2, TrueVal2, FalseVal2, CC2, *K2) ? &Op2
+                                                                       : nullptr;
+  const SDValue *UpperCheckOp =
+      isUpperSaturate(LHS1, RHS1, TrueVal1, FalseVal1, CC1, *K1)
+          ? &Op
+          : isUpperSaturate(LHS2, RHS2, TrueVal2, FalseVal2, CC2, *K2) ? &Op2
+                                                                       : nullptr;
+
+  if (!UpperCheckOp || !LowerCheckOp || LowerCheckOp == UpperCheckOp)
+    return false;
+
+  // Check that the constant in the lower-bound check is
+  // the opposite of the constant in the upper-bound check
+  // in 1's complement.
+  int64_t Val1 = cast<ConstantSDNode>(*K1)->getSExtValue();
+  int64_t Val2 = cast<ConstantSDNode>(*K2)->getSExtValue();
+  int64_t PosVal = std::max(Val1, Val2);
+
+  if (((Val1 > Val2 && UpperCheckOp == &Op) ||
+       (Val1 < Val2 && UpperCheckOp == &Op2)) &&
+      Val1 == ~Val2 && isPowerOf2_64(PosVal + 1)) {
+
+    V = V2;
+    K = static_cast<uint64_t>(PosVal); // PosVal is non-negative at this point
+    return true;
+  }
+
+  return false;
+}
+
 SDValue ARMTargetLowering::LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const {
+
   EVT VT = Op.getValueType();
+  SDLoc dl(Op);
+
+  // Try to convert two saturating conditional selects into a single SSAT
+  SDValue SatValue;
+  uint64_t SatConstant;
+  if (isSaturatingConditional(Op, SatValue, SatConstant))
+    return DAG.getNode(ARMISD::SSAT, dl, VT, SatValue,
+                       DAG.getConstant(countTrailingOnes(SatConstant), dl, VT));
+
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get();
   SDValue TrueVal = Op.getOperand(2);
   SDValue FalseVal = Op.getOperand(3);
-  SDLoc dl(Op);
 
   if (Subtarget->isFPOnlySP() && LHS.getValueType() == MVT::f64) {
     DAG.getTargetLoweringInfo().softenSetCCOperands(DAG, MVT::f64, LHS, RHS, CC,