Add AVX2 vpbroadcast support
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144967 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 4986aac..6a14f22 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -5115,9 +5115,9 @@
/// 1. A splat BUILD_VECTOR which uses a single scalar load.
/// 2. A splat shuffle which uses a scalar_to_vector node which comes from
/// a scalar load.
-/// The scalar load node is returned when a pattern is found,
-/// or SDValue() otherwise.
-static SDValue isVectorBroadcast(SDValue &Op) {
+/// The scalar load node is returned when a pattern is found,
+/// or SDValue() otherwise.
+static SDValue isVectorBroadcast(SDValue &Op, bool hasAVX2) {
EVT VT = Op.getValueType();
SDValue V = Op;
@@ -5134,16 +5134,16 @@
case ISD::BUILD_VECTOR: {
// The BUILD_VECTOR node must be a splat.
- if (!isSplatVector(V.getNode()))
+ if (!isSplatVector(V.getNode()))
return SDValue();
Ld = V.getOperand(0);
-
- // The suspected load node has several users. Make sure that all
+
+ // The suspected load node has several users. Make sure that all
// of its users are from the BUILD_VECTOR node.
- if (!Ld->hasNUsesOfValue(VT.getVectorNumElements(), 0))
+ if (!Ld->hasNUsesOfValue(VT.getVectorNumElements(), 0))
return SDValue();
- break;
+ break;
}
case ISD::VECTOR_SHUFFLE: {
@@ -5151,11 +5151,11 @@
// Shuffles must have a splat mask where the first element is
// broadcasted.
- if ((!SVOp->isSplat()) || SVOp->getMaskElt(0) != 0)
+ if ((!SVOp->isSplat()) || SVOp->getMaskElt(0) != 0)
return SDValue();
SDValue Sc = Op.getOperand(0);
- if (Sc.getOpcode() != ISD::SCALAR_TO_VECTOR)
+ if (Sc.getOpcode() != ISD::SCALAR_TO_VECTOR)
return SDValue();
Ld = Sc.getOperand(0);
@@ -5167,15 +5167,27 @@
break;
}
}
-
+
// The scalar source must be a normal load.
- if (!ISD::isNormalLoad(Ld.getNode()))
+ if (!ISD::isNormalLoad(Ld.getNode()))
return SDValue();
-
+
bool Is256 = VT.getSizeInBits() == 256;
bool Is128 = VT.getSizeInBits() == 128;
unsigned ScalarSize = Ld.getValueType().getSizeInBits();
+ if (hasAVX2) {
+ // VBroadcast to YMM
+ if (Is256 && (ScalarSize == 8 || ScalarSize == 16 ||
+ ScalarSize == 32 || ScalarSize == 64 ))
+ return Ld;
+
+ // VBroadcast to XMM
+ if (Is128 && (ScalarSize == 8 || ScalarSize == 32 ||
+ ScalarSize == 16 || ScalarSize == 64 ))
+ return Ld;
+ }
+
// VBroadcast to YMM
if (Is256 && (ScalarSize == 32 || ScalarSize == 64))
return Ld;
@@ -5184,6 +5196,7 @@
if (Is128 && (ScalarSize == 32))
return Ld;
+
// Unsupported broadcast.
return SDValue();
}
@@ -5216,7 +5229,7 @@
return getOnesVector(Op.getValueType(), DAG, dl);
}
- SDValue LD = isVectorBroadcast(Op);
+ SDValue LD = isVectorBroadcast(Op, Subtarget->hasAVX2());
if (Subtarget->hasAVX() && LD.getNode())
return DAG.getNode(X86ISD::VBROADCAST, dl, VT, LD);
@@ -6613,7 +6626,7 @@
return Op;
// Use vbroadcast whenever the splat comes from a foldable load
- SDValue LD = isVectorBroadcast(Op);
+ SDValue LD = isVectorBroadcast(Op, Subtarget->hasAVX2());
if (Subtarget->hasAVX() && LD.getNode())
return DAG.getNode(X86ISD::VBROADCAST, dl, VT, LD);