[Hexagon] Implement HVX codegen for vector shifts
llvm-svn: 323914
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
index 9535d0f..b0781c1 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
@@ -2109,8 +2109,13 @@
setOperationAction(ISD::ANY_EXTEND, T, Custom);
setOperationAction(ISD::SIGN_EXTEND, T, Custom);
setOperationAction(ISD::ZERO_EXTEND, T, Custom);
- if (T != ByteV)
+ if (T != ByteV) {
setOperationAction(ISD::ANY_EXTEND_VECTOR_INREG, T, Custom);
+ // HVX only has shifts of words and halfwords, so byte vectors are skipped.
+ setOperationAction(ISD::SRA, T, Custom);
+ setOperationAction(ISD::SHL, T, Custom);
+ setOperationAction(ISD::SRL, T, Custom);
+ }
}
for (MVT T : LegalV) {
@@ -2523,76 +2528,37 @@
return SDValue();
}
-// If BUILD_VECTOR has same base element repeated several times,
-// report true.
-static bool isCommonSplatElement(BuildVectorSDNode *BVN) {
- unsigned NElts = BVN->getNumOperands();
- SDValue V0 = BVN->getOperand(0);
-
- for (unsigned i = 1, e = NElts; i != e; ++i) {
- if (BVN->getOperand(i) != V0)
- return false;
- }
- return true;
-}
-
// Lower a vector shift. Try to convert
// <VT> = SHL/SRA/SRL <VT> by <VT> to Hexagon specific
// <VT> = SHL/SRA/SRL <VT> by <IT/i32>.
SDValue
HexagonTargetLowering::LowerVECTOR_SHIFT(SDValue Op, SelectionDAG &DAG) const {
- BuildVectorSDNode *BVN = nullptr;
- SDValue V1 = Op.getOperand(0);
- SDValue V2 = Op.getOperand(1);
- SDValue V3;
- SDLoc dl(Op);
- EVT VT = Op.getValueType();
+ const SDLoc dl(Op);
- if ((BVN = dyn_cast<BuildVectorSDNode>(V1.getNode())) &&
- isCommonSplatElement(BVN))
- V3 = V2;
- else if ((BVN = dyn_cast<BuildVectorSDNode>(V2.getNode())) &&
- isCommonSplatElement(BVN))
- V3 = V1;
- else
- return SDValue();
-
- SDValue CommonSplat = BVN->getOperand(0);
- SDValue Result;
-
- if (VT.getSimpleVT() == MVT::v4i16) {
- switch (Op.getOpcode()) {
- case ISD::SRA:
- Result = DAG.getNode(HexagonISD::VASR, dl, VT, V3, CommonSplat);
- break;
- case ISD::SHL:
- Result = DAG.getNode(HexagonISD::VASL, dl, VT, V3, CommonSplat);
- break;
- case ISD::SRL:
- Result = DAG.getNode(HexagonISD::VLSR, dl, VT, V3, CommonSplat);
- break;
- default:
- return SDValue();
+ if (auto *BVN = dyn_cast<BuildVectorSDNode>(Op.getOperand(1).getNode())) { // only the shift amount (operand 1) is inspected
+ if (SDValue S = BVN->getSplatValue()) { // taken only when a splat value S is returned
+ unsigned NewOpc;
+ switch (Op.getOpcode()) {
+ case ISD::SHL:
+ NewOpc = HexagonISD::VASL;
+ break;
+ case ISD::SRA:
+ NewOpc = HexagonISD::VASR;
+ break;
+ case ISD::SRL:
+ NewOpc = HexagonISD::VLSR;
+ break;
+ default:
+ llvm_unreachable("Unexpected shift opcode");
+ }
+ return DAG.getNode(NewOpc, dl, ty(Op), Op.getOperand(0), S); // operand 0 shifted by the single value S
}
- } else if (VT.getSimpleVT() == MVT::v2i32) {
- switch (Op.getOpcode()) {
- case ISD::SRA:
- Result = DAG.getNode(HexagonISD::VASR, dl, VT, V3, CommonSplat);
- break;
- case ISD::SHL:
- Result = DAG.getNode(HexagonISD::VASL, dl, VT, V3, CommonSplat);
- break;
- case ISD::SRL:
- Result = DAG.getNode(HexagonISD::VLSR, dl, VT, V3, CommonSplat);
- break;
- default:
- return SDValue();
- }
- } else {
- return SDValue();
}
- return DAG.getNode(ISD::BITCAST, dl, VT, Result);
+ if (Subtarget.useHVXOps() && Subtarget.isHVXVectorType(ty(Op)))
+ return LowerHvxShift(Op, DAG); // shift amount was not a splat BUILD_VECTOR: defer to the HVX path
+
+ return SDValue(); // neither case applied: return an empty SDValue
}
SDValue
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.h b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
index 0bc2307..6a1105d 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.h
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
@@ -402,6 +402,7 @@
SDValue LowerHvxMulh(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerHvxSetCC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerHvxExtend(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerHvxShift(SDValue Op, SelectionDAG &DAG) const;
std::pair<const TargetRegisterClass*, uint8_t>
findRepresentativeClass(const TargetRegisterInfo *TRI, MVT VT)
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
index 1be3ce8..b177473 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
@@ -1103,3 +1103,9 @@
assert(Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG);
return DAG.getZeroExtendVectorInReg(Op.getOperand(0), SDLoc(Op), ty(Op));
}
+
+SDValue
+HexagonTargetLowering::LowerHvxShift(SDValue Op, SelectionDAG &DAG) const {
+ return Op; // Hands the shift node back unchanged; DAG is not used here.
+}
+
diff --git a/llvm/lib/Target/Hexagon/HexagonPatterns.td b/llvm/lib/Target/Hexagon/HexagonPatterns.td
index bcedf87..06e23e2 100644
--- a/llvm/lib/Target/Hexagon/HexagonPatterns.td
+++ b/llvm/lib/Target/Hexagon/HexagonPatterns.td
@@ -3095,19 +3095,33 @@
def: Pat<(VecI16 (sext_inreg HVI16:$Vs, v32i8)),
(V6_vasrh (V6_vaslh HVI16:$Vs, (A2_tfrsi 8)), (A2_tfrsi 8))>;
def: Pat<(VecI32 (sext_inreg HVI32:$Vs, v16i8)),
- (V6_vasrh (V6_vaslh HVI32:$Vs, (A2_tfrsi 24)), (A2_tfrsi 24))>;
+ (V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 24)), (A2_tfrsi 24))>;
def: Pat<(VecI32 (sext_inreg HVI32:$Vs, v16i16)),
- (V6_vasrh (V6_vaslh HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
+ (V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
}
let Predicates = [UseHVX,UseHVX128B] in {
def: Pat<(VecI16 (sext_inreg HVI16:$Vs, v64i8)),
(V6_vasrh (V6_vaslh HVI16:$Vs, (A2_tfrsi 8)), (A2_tfrsi 8))>;
def: Pat<(VecI32 (sext_inreg HVI32:$Vs, v32i8)),
- (V6_vasrh (V6_vaslh HVI32:$Vs, (A2_tfrsi 24)), (A2_tfrsi 24))>;
+ (V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 24)), (A2_tfrsi 24))>;
def: Pat<(VecI32 (sext_inreg HVI32:$Vs, v32i16)),
- (V6_vasrh (V6_vaslh HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
+ (V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
}
+ def: Pat<(HexagonVASL HVI16:$Vs, I32:$Rt), (V6_vaslh HvxVR:$Vs, I32:$Rt)>;
+ def: Pat<(HexagonVASL HVI32:$Vs, I32:$Rt), (V6_vaslw HvxVR:$Vs, I32:$Rt)>;
+ def: Pat<(HexagonVASR HVI16:$Vs, I32:$Rt), (V6_vasrh HvxVR:$Vs, I32:$Rt)>;
+ def: Pat<(HexagonVASR HVI32:$Vs, I32:$Rt), (V6_vasrw HvxVR:$Vs, I32:$Rt)>;
+ def: Pat<(HexagonVLSR HVI16:$Vs, I32:$Rt), (V6_vlsrh HvxVR:$Vs, I32:$Rt)>;
+ def: Pat<(HexagonVLSR HVI32:$Vs, I32:$Rt), (V6_vlsrw HvxVR:$Vs, I32:$Rt)>;
+
+ def: Pat<(shl HVI16:$Vs, HVI16:$Vt), (V6_vaslhv HvxVR:$Vs, HvxVR:$Vt)>;
+ def: Pat<(shl HVI32:$Vs, HVI32:$Vt), (V6_vaslwv HvxVR:$Vs, HvxVR:$Vt)>;
+ def: Pat<(sra HVI16:$Vs, HVI16:$Vt), (V6_vasrhv HvxVR:$Vs, HvxVR:$Vt)>;
+ def: Pat<(sra HVI32:$Vs, HVI32:$Vt), (V6_vasrwv HvxVR:$Vs, HvxVR:$Vt)>;
+ def: Pat<(srl HVI16:$Vs, HVI16:$Vt), (V6_vlsrhv HvxVR:$Vs, HvxVR:$Vt)>;
+ def: Pat<(srl HVI32:$Vs, HVI32:$Vt), (V6_vlsrwv HvxVR:$Vs, HvxVR:$Vt)>;
+
def: Pat<(VecI8 (trunc HWI16:$Vss)),
(V6_vpackeb (HiVec $Vss), (LoVec $Vss))>;
def: Pat<(VecI16 (trunc HWI32:$Vss)),