[Hexagon] Add/fix patterns for 32/64-bit vector compares and logical ops

llvm-svn: 330330
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
index 91af9e2..40740f1 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
@@ -803,63 +803,62 @@
                       MachinePointerInfo(SV));
 }
 
-static bool isSExtFree(SDValue N) {
-  // A sign-extend of a truncate of a sign-extend is free.
-  if (N.getOpcode() == ISD::TRUNCATE &&
-      N.getOperand(0).getOpcode() == ISD::AssertSext)
-    return true;
-  // We have sign-extended loads.
-  if (N.getOpcode() == ISD::LOAD)
-    return true;
-  return false;
-}
-
 SDValue HexagonTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
-  SDLoc dl(Op);
+  const SDLoc &dl(Op);
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
-  SDValue Cmp = Op.getOperand(2);
-  ISD::CondCode CC = cast<CondCodeSDNode>(Cmp)->get();
+  ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
+  MVT ResTy = ty(Op);
+  MVT OpTy = ty(LHS);
 
-  EVT VT = Op.getValueType();
-  EVT LHSVT = LHS.getValueType();
-  EVT RHSVT = RHS.getValueType();
-
-  if (LHSVT == MVT::v2i16) {
-    assert(CC == ISD::SETEQ || CC == ISD::SETNE ||
-           ISD::isSignedIntSetCC(CC) || ISD::isUnsignedIntSetCC(CC));
-    unsigned ExtOpc = ISD::isSignedIntSetCC(CC) ? ISD::SIGN_EXTEND
-                                                : ISD::ZERO_EXTEND;
-    SDValue LX = DAG.getNode(ExtOpc, dl, MVT::v2i32, LHS);
-    SDValue RX = DAG.getNode(ExtOpc, dl, MVT::v2i32, RHS);
-    SDValue SC = DAG.getNode(ISD::SETCC, dl, MVT::v2i1, LX, RX, Cmp);
-    return SC;
+  if (OpTy == MVT::v2i16 || OpTy == MVT::v4i8) {
+    MVT ElemTy = OpTy.getVectorElementType();
+    assert(ElemTy.isScalarInteger());
+    MVT WideTy = MVT::getVectorVT(MVT::getIntegerVT(2*ElemTy.getSizeInBits()),
+                                  OpTy.getVectorNumElements());
+    return DAG.getSetCC(dl, ResTy,
+                        DAG.getSExtOrTrunc(LHS, SDLoc(LHS), WideTy),
+                        DAG.getSExtOrTrunc(RHS, SDLoc(RHS), WideTy), CC);
   }
 
   // Treat all other vector types as legal.
-  if (VT.isVector())
+  if (ResTy.isVector())
     return Op;
 
-  // Equals and not equals should use sign-extend, not zero-extend, since
-  // we can represent small negative values in the compare instructions.
+  // Comparisons of short integers should use sign-extend, not zero-extend,
+  // since we can represent small negative values in the compare instructions.
   // The LLVM default is to use zero-extend arbitrarily in these cases.
-  if ((CC == ISD::SETEQ || CC == ISD::SETNE) &&
-      (RHSVT == MVT::i8 || RHSVT == MVT::i16) &&
-      (LHSVT == MVT::i8 || LHSVT == MVT::i16)) {
+  auto isSExtFree = [this](SDValue N) {
+    switch (N.getOpcode()) {
+      case ISD::TRUNCATE: {
+        // A sign-extend of a truncate of a sign-extend is free.
+        SDValue Op = N.getOperand(0);
+        if (Op.getOpcode() != ISD::AssertSext)
+          return false;
+        MVT OrigTy = cast<VTSDNode>(Op.getOperand(1))->getVT().getSimpleVT();
+        unsigned ThisBW = ty(N).getSizeInBits();
+        unsigned OrigBW = OrigTy.getSizeInBits();
+        // The type that was sign-extended to get the AssertSext must be
+        // no wider than the type of N (so that N still has the same value
+        // as the original).
+        return ThisBW >= OrigBW;
+      }
+      case ISD::LOAD:
+        // We have sign-extended loads.
+        return true;
+    }
+    return false;
+  };
+
+  if (OpTy == MVT::i8 || OpTy == MVT::i16) {
     ConstantSDNode *C = dyn_cast<ConstantSDNode>(RHS);
-    if (C && C->getAPIntValue().isNegative()) {
-      LHS = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i32, LHS);
-      RHS = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i32, RHS);
-      return DAG.getNode(ISD::SETCC, dl, Op.getValueType(),
-                         LHS, RHS, Op.getOperand(2));
-    }
-    if (isSExtFree(LHS) || isSExtFree(RHS)) {
-      LHS = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i32, LHS);
-      RHS = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i32, RHS);
-      return DAG.getNode(ISD::SETCC, dl, Op.getValueType(),
-                         LHS, RHS, Op.getOperand(2));
-    }
+    bool IsNegative = C && C->getAPIntValue().isNegative();
+    if (IsNegative || isSExtFree(LHS) || isSExtFree(RHS))
+      return DAG.getSetCC(dl, ResTy,
+                          DAG.getSExtOrTrunc(LHS, SDLoc(LHS), MVT::i32),
+                          DAG.getSExtOrTrunc(RHS, SDLoc(RHS), MVT::i32), CC);
   }
+
   return SDValue();
 }
 
@@ -1306,8 +1305,10 @@
   setOperationAction(ISD::BlockAddress,  MVT::i32, Custom);
 
   // Hexagon needs to optimize cases with negative constants.
-  setOperationAction(ISD::SETCC, MVT::i8,  Custom);
-  setOperationAction(ISD::SETCC, MVT::i16, Custom);
+  setOperationAction(ISD::SETCC, MVT::i8,    Custom);
+  setOperationAction(ISD::SETCC, MVT::i16,   Custom);
+  setOperationAction(ISD::SETCC, MVT::v4i8,  Custom);
+  setOperationAction(ISD::SETCC, MVT::v2i16, Custom);
 
   // VASTART needs to be custom lowered to use the VarArgsFrameIndex.
   setOperationAction(ISD::VASTART, MVT::Other, Custom);