[Hexagon] Implement HVX codegen for vector shifts

llvm-svn: 323914
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
index 9535d0f..b0781c1 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
@@ -2109,8 +2109,13 @@
       setOperationAction(ISD::ANY_EXTEND,         T, Custom);
       setOperationAction(ISD::SIGN_EXTEND,        T, Custom);
       setOperationAction(ISD::ZERO_EXTEND,        T, Custom);
-      if (T != ByteV)
+      if (T != ByteV) {
         setOperationAction(ISD::ANY_EXTEND_VECTOR_INREG, T, Custom);
+        // HVX only has shifts of words and halfwords, not of bytes.
+        setOperationAction(ISD::SRA,                     T, Custom);
+        setOperationAction(ISD::SHL,                     T, Custom);
+        setOperationAction(ISD::SRL,                     T, Custom);
+      }
     }
 
     for (MVT T : LegalV) {
@@ -2523,76 +2528,37 @@
   return SDValue();
 }
 
-// If BUILD_VECTOR has same base element repeated several times,
-// report true.
-static bool isCommonSplatElement(BuildVectorSDNode *BVN) {
-  unsigned NElts = BVN->getNumOperands();
-  SDValue V0 = BVN->getOperand(0);
-
-  for (unsigned i = 1, e = NElts; i != e; ++i) {
-    if (BVN->getOperand(i) != V0)
-      return false;
-  }
-  return true;
-}
-
 // Lower a vector shift. Try to convert
 // <VT> = SHL/SRA/SRL <VT> by <VT> to Hexagon specific
 // <VT> = SHL/SRA/SRL <VT> by <IT/i32>.
 SDValue
 HexagonTargetLowering::LowerVECTOR_SHIFT(SDValue Op, SelectionDAG &DAG) const {
-  BuildVectorSDNode *BVN = nullptr;
-  SDValue V1 = Op.getOperand(0);
-  SDValue V2 = Op.getOperand(1);
-  SDValue V3;
-  SDLoc dl(Op);
-  EVT VT = Op.getValueType();
+  const SDLoc dl(Op);
 
-  if ((BVN = dyn_cast<BuildVectorSDNode>(V1.getNode())) &&
-      isCommonSplatElement(BVN))
-    V3 = V2;
-  else if ((BVN = dyn_cast<BuildVectorSDNode>(V2.getNode())) &&
-           isCommonSplatElement(BVN))
-    V3 = V1;
-  else
-    return SDValue();
-
-  SDValue CommonSplat = BVN->getOperand(0);
-  SDValue Result;
-
-  if (VT.getSimpleVT() == MVT::v4i16) {
-    switch (Op.getOpcode()) {
-    case ISD::SRA:
-      Result = DAG.getNode(HexagonISD::VASR, dl, VT, V3, CommonSplat);
-      break;
-    case ISD::SHL:
-      Result = DAG.getNode(HexagonISD::VASL, dl, VT, V3, CommonSplat);
-      break;
-    case ISD::SRL:
-      Result = DAG.getNode(HexagonISD::VLSR, dl, VT, V3, CommonSplat);
-      break;
-    default:
-      return SDValue();
+  if (auto *BVN = dyn_cast<BuildVectorSDNode>(Op.getOperand(1).getNode())) {
+    if (SDValue S = BVN->getSplatValue()) {
+      unsigned NewOpc;
+      switch (Op.getOpcode()) {
+        case ISD::SHL:
+          NewOpc = HexagonISD::VASL;
+          break;
+        case ISD::SRA:
+          NewOpc = HexagonISD::VASR;
+          break;
+        case ISD::SRL:
+          NewOpc = HexagonISD::VLSR;
+          break;
+        default:
+          llvm_unreachable("Unexpected shift opcode");
+      }
+      return DAG.getNode(NewOpc, dl, ty(Op), Op.getOperand(0), S);
     }
-  } else if (VT.getSimpleVT() == MVT::v2i32) {
-    switch (Op.getOpcode()) {
-    case ISD::SRA:
-      Result = DAG.getNode(HexagonISD::VASR, dl, VT, V3, CommonSplat);
-      break;
-    case ISD::SHL:
-      Result = DAG.getNode(HexagonISD::VASL, dl, VT, V3, CommonSplat);
-      break;
-    case ISD::SRL:
-      Result = DAG.getNode(HexagonISD::VLSR, dl, VT, V3, CommonSplat);
-      break;
-    default:
-      return SDValue();
-    }
-  } else {
-    return SDValue();
   }
 
-  return DAG.getNode(ISD::BITCAST, dl, VT, Result);
+  if (Subtarget.useHVXOps() && Subtarget.isHVXVectorType(ty(Op)))
+    return LowerHvxShift(Op, DAG);
+
+  return SDValue();
 }
 
 SDValue