[Hexagon] Implement HVX codegen for vector shifts

llvm-svn: 323914
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
index 9535d0f..b0781c1 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
@@ -2109,8 +2109,13 @@
       setOperationAction(ISD::ANY_EXTEND,         T, Custom);
       setOperationAction(ISD::SIGN_EXTEND,        T, Custom);
       setOperationAction(ISD::ZERO_EXTEND,        T, Custom);
-      if (T != ByteV)
+      if (T != ByteV) {
         setOperationAction(ISD::ANY_EXTEND_VECTOR_INREG, T, Custom);
+        // HVX only has shifts of words and halfwords.
+        setOperationAction(ISD::SRA,                     T, Custom);
+        setOperationAction(ISD::SHL,                     T, Custom);
+        setOperationAction(ISD::SRL,                     T, Custom);
+      }
     }
 
     for (MVT T : LegalV) {
@@ -2523,76 +2528,37 @@
   return SDValue();
 }
 
-// If BUILD_VECTOR has same base element repeated several times,
-// report true.
-static bool isCommonSplatElement(BuildVectorSDNode *BVN) {
-  unsigned NElts = BVN->getNumOperands();
-  SDValue V0 = BVN->getOperand(0);
-
-  for (unsigned i = 1, e = NElts; i != e; ++i) {
-    if (BVN->getOperand(i) != V0)
-      return false;
-  }
-  return true;
-}
-
 // Lower a vector shift. Try to convert
 // <VT> = SHL/SRA/SRL <VT> by <VT> to Hexagon specific
 // <VT> = SHL/SRA/SRL <VT> by <IT/i32>.
 SDValue
 HexagonTargetLowering::LowerVECTOR_SHIFT(SDValue Op, SelectionDAG &DAG) const {
-  BuildVectorSDNode *BVN = nullptr;
-  SDValue V1 = Op.getOperand(0);
-  SDValue V2 = Op.getOperand(1);
-  SDValue V3;
-  SDLoc dl(Op);
-  EVT VT = Op.getValueType();
+  const SDLoc dl(Op);
 
-  if ((BVN = dyn_cast<BuildVectorSDNode>(V1.getNode())) &&
-      isCommonSplatElement(BVN))
-    V3 = V2;
-  else if ((BVN = dyn_cast<BuildVectorSDNode>(V2.getNode())) &&
-           isCommonSplatElement(BVN))
-    V3 = V1;
-  else
-    return SDValue();
-
-  SDValue CommonSplat = BVN->getOperand(0);
-  SDValue Result;
-
-  if (VT.getSimpleVT() == MVT::v4i16) {
-    switch (Op.getOpcode()) {
-    case ISD::SRA:
-      Result = DAG.getNode(HexagonISD::VASR, dl, VT, V3, CommonSplat);
-      break;
-    case ISD::SHL:
-      Result = DAG.getNode(HexagonISD::VASL, dl, VT, V3, CommonSplat);
-      break;
-    case ISD::SRL:
-      Result = DAG.getNode(HexagonISD::VLSR, dl, VT, V3, CommonSplat);
-      break;
-    default:
-      return SDValue();
+  if (auto *BVN = dyn_cast<BuildVectorSDNode>(Op.getOperand(1).getNode())) {
+    if (SDValue S = BVN->getSplatValue()) {  // Only operand 1 is checked.
+      unsigned NewOpc;  // Hexagon-specific opcode taking a scalar amount.
+      switch (Op.getOpcode()) {
+        case ISD::SHL:
+          NewOpc = HexagonISD::VASL;
+          break;
+        case ISD::SRA:
+          NewOpc = HexagonISD::VASR;
+          break;
+        case ISD::SRL:
+          NewOpc = HexagonISD::VLSR;
+          break;
+        default:
+          llvm_unreachable("Unexpected shift opcode");
+      }
+      return DAG.getNode(NewOpc, dl, ty(Op), Op.getOperand(0), S);
     }
-  } else if (VT.getSimpleVT() == MVT::v2i32) {
-    switch (Op.getOpcode()) {
-    case ISD::SRA:
-      Result = DAG.getNode(HexagonISD::VASR, dl, VT, V3, CommonSplat);
-      break;
-    case ISD::SHL:
-      Result = DAG.getNode(HexagonISD::VASL, dl, VT, V3, CommonSplat);
-      break;
-    case ISD::SRL:
-      Result = DAG.getNode(HexagonISD::VLSR, dl, VT, V3, CommonSplat);
-      break;
-    default:
-      return SDValue();
-    }
-  } else {
-    return SDValue();
   }
 
-  return DAG.getNode(ISD::BITCAST, dl, VT, Result);
+  if (Subtarget.useHVXOps() && Subtarget.isHVXVectorType(ty(Op)))
+    return LowerHvxShift(Op, DAG);  // Non-splat amount on an HVX type.
+
+  return SDValue();  // No custom lowering produced for this shift.
 }
 
 SDValue
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.h b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
index 0bc2307..6a1105d 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.h
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
@@ -402,6 +402,7 @@
     SDValue LowerHvxMulh(SDValue Op, SelectionDAG &DAG) const;
     SDValue LowerHvxSetCC(SDValue Op, SelectionDAG &DAG) const;
     SDValue LowerHvxExtend(SDValue Op, SelectionDAG &DAG) const;
+    SDValue LowerHvxShift(SDValue Op, SelectionDAG &DAG) const;
 
     std::pair<const TargetRegisterClass*, uint8_t>
     findRepresentativeClass(const TargetRegisterInfo *TRI, MVT VT)
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
index 1be3ce8..b177473 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
@@ -1103,3 +1103,9 @@
   assert(Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG);
   return DAG.getZeroExtendVectorInReg(Op.getOperand(0), SDLoc(Op), ty(Op));
 }
+
+SDValue
+HexagonTargetLowering::LowerHvxShift(SDValue Op, SelectionDAG &DAG) const {
+  return Op;  // Returned unchanged; HVI16/HVI32 shl/sra/srl patterns match it.
+}
+
diff --git a/llvm/lib/Target/Hexagon/HexagonPatterns.td b/llvm/lib/Target/Hexagon/HexagonPatterns.td
index bcedf87..06e23e2 100644
--- a/llvm/lib/Target/Hexagon/HexagonPatterns.td
+++ b/llvm/lib/Target/Hexagon/HexagonPatterns.td
@@ -3095,19 +3095,33 @@
     def: Pat<(VecI16 (sext_inreg HVI16:$Vs, v32i8)),
              (V6_vasrh (V6_vaslh HVI16:$Vs, (A2_tfrsi 8)), (A2_tfrsi 8))>;
     def: Pat<(VecI32 (sext_inreg HVI32:$Vs, v16i8)),
-             (V6_vasrh (V6_vaslh HVI32:$Vs, (A2_tfrsi 24)), (A2_tfrsi 24))>;
+             (V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 24)), (A2_tfrsi 24))>;
     def: Pat<(VecI32 (sext_inreg HVI32:$Vs, v16i16)),
-             (V6_vasrh (V6_vaslh HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
+             (V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
   }
   let Predicates = [UseHVX,UseHVX128B] in {
     def: Pat<(VecI16 (sext_inreg HVI16:$Vs, v64i8)),
              (V6_vasrh (V6_vaslh HVI16:$Vs, (A2_tfrsi 8)), (A2_tfrsi 8))>;
     def: Pat<(VecI32 (sext_inreg HVI32:$Vs, v32i8)),
-             (V6_vasrh (V6_vaslh HVI32:$Vs, (A2_tfrsi 24)), (A2_tfrsi 24))>;
+             (V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 24)), (A2_tfrsi 24))>;
     def: Pat<(VecI32 (sext_inreg HVI32:$Vs, v32i16)),
-             (V6_vasrh (V6_vaslh HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
+             (V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
   }
 
+  def: Pat<(HexagonVASL HVI16:$Vs, I32:$Rt), (V6_vaslh HvxVR:$Vs, I32:$Rt)>;
+  def: Pat<(HexagonVASL HVI32:$Vs, I32:$Rt), (V6_vaslw HvxVR:$Vs, I32:$Rt)>;
+  def: Pat<(HexagonVASR HVI16:$Vs, I32:$Rt), (V6_vasrh HvxVR:$Vs, I32:$Rt)>;
+  def: Pat<(HexagonVASR HVI32:$Vs, I32:$Rt), (V6_vasrw HvxVR:$Vs, I32:$Rt)>;
+  def: Pat<(HexagonVLSR HVI16:$Vs, I32:$Rt), (V6_vlsrh HvxVR:$Vs, I32:$Rt)>;
+  def: Pat<(HexagonVLSR HVI32:$Vs, I32:$Rt), (V6_vlsrw HvxVR:$Vs, I32:$Rt)>;
+
+  def: Pat<(shl HVI16:$Vs, HVI16:$Vt), (V6_vaslhv HvxVR:$Vs, HvxVR:$Vt)>;
+  def: Pat<(shl HVI32:$Vs, HVI32:$Vt), (V6_vaslwv HvxVR:$Vs, HvxVR:$Vt)>;
+  def: Pat<(sra HVI16:$Vs, HVI16:$Vt), (V6_vasrhv HvxVR:$Vs, HvxVR:$Vt)>;
+  def: Pat<(sra HVI32:$Vs, HVI32:$Vt), (V6_vasrwv HvxVR:$Vs, HvxVR:$Vt)>;
+  def: Pat<(srl HVI16:$Vs, HVI16:$Vt), (V6_vlsrhv HvxVR:$Vs, HvxVR:$Vt)>;
+  def: Pat<(srl HVI32:$Vs, HVI32:$Vt), (V6_vlsrwv HvxVR:$Vs, HvxVR:$Vt)>;
+
   def: Pat<(VecI8 (trunc HWI16:$Vss)),
            (V6_vpackeb (HiVec $Vss), (LoVec $Vss))>;
   def: Pat<(VecI16 (trunc HWI32:$Vss)),