Add AVX2 vpbroadcast support

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144967 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 4986aac..6a14f22 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -5115,9 +5115,9 @@
 /// 1. A splat BUILD_VECTOR which uses a single scalar load.
 /// 2. A splat shuffle which uses a scalar_to_vector node which comes from
 /// a scalar load.
-/// The scalar load node is returned when a pattern is found, 
-/// or SDValue() otherwise. 
-static SDValue isVectorBroadcast(SDValue &Op) {
+/// The scalar load node is returned when a pattern is found, or SDValue()
+/// otherwise. \p hasAVX2 additionally enables the AVX2 VPBROADCAST cases.
+static SDValue isVectorBroadcast(SDValue &Op, bool hasAVX2) {
   EVT VT = Op.getValueType();
   SDValue V = Op;
 
@@ -5134,16 +5134,16 @@
 
     case ISD::BUILD_VECTOR: {
       // The BUILD_VECTOR node must be a splat.
-      if (!isSplatVector(V.getNode())) 
+      if (!isSplatVector(V.getNode()))
         return SDValue();
 
       Ld = V.getOperand(0);
-    
-      // The suspected load node has several users. Make sure that all 
+
+      // The suspected load node has several users. Make sure that all
       // of its users are from the BUILD_VECTOR node.
-      if (!Ld->hasNUsesOfValue(VT.getVectorNumElements(), 0)) 
+      if (!Ld->hasNUsesOfValue(VT.getVectorNumElements(), 0))
         return SDValue();
-      break; 
+      break;
     }
 
     case ISD::VECTOR_SHUFFLE: {
@@ -5151,11 +5151,11 @@
 
       // Shuffles must have a splat mask where the first element is
       // broadcasted.
-      if ((!SVOp->isSplat()) || SVOp->getMaskElt(0) != 0) 
+      if ((!SVOp->isSplat()) || SVOp->getMaskElt(0) != 0)
         return SDValue();
 
       SDValue Sc = Op.getOperand(0);
-      if (Sc.getOpcode() != ISD::SCALAR_TO_VECTOR) 
+      if (Sc.getOpcode() != ISD::SCALAR_TO_VECTOR)
         return SDValue();
 
       Ld = Sc.getOperand(0);
@@ -5167,15 +5167,27 @@
       break;
     }
   }
-  
+
   // The scalar source must be a normal load.
-  if (!ISD::isNormalLoad(Ld.getNode())) 
+  if (!ISD::isNormalLoad(Ld.getNode()))
     return SDValue();
-  
+
   bool Is256 = VT.getSizeInBits() == 256;
   bool Is128 = VT.getSizeInBits() == 128;
   unsigned ScalarSize = Ld.getValueType().getSizeInBits();
 
+  if (hasAVX2) {
+    // YMM: AVX2 adds 8/16-bit elements (VPBROADCASTB/W).
+    if (Is256 && (ScalarSize == 8  || ScalarSize == 16 ||
+                  ScalarSize == 32 || ScalarSize == 64))
+      return Ld;
+
+    // XMM: 64-bit only for integer VTs (VPBROADCASTQ); no v2f64 pattern.
+    if (Is128 && (ScalarSize == 8 || ScalarSize == 16 || ScalarSize == 32 ||
+                  (ScalarSize == 64 && VT.isInteger())))
+      return Ld;
+  }
+
   // VBroadcast to YMM
   if (Is256 && (ScalarSize == 32 || ScalarSize == 64))
     return Ld;
@@ -5184,6 +5196,7 @@
   if (Is128 && (ScalarSize == 32))
     return Ld;
 
+
   // Unsupported broadcast.
   return SDValue();
 }
@@ -5216,7 +5229,7 @@
     return getOnesVector(Op.getValueType(), DAG, dl);
   }
 
-  SDValue LD = isVectorBroadcast(Op);
+  SDValue LD = isVectorBroadcast(Op, Subtarget->hasAVX2());
   if (Subtarget->hasAVX() && LD.getNode())
       return DAG.getNode(X86ISD::VBROADCAST, dl, VT, LD);
 
@@ -6613,7 +6626,7 @@
       return Op;
 
     // Use vbroadcast whenever the splat comes from a foldable load
-    SDValue LD = isVectorBroadcast(Op);
+    SDValue LD = isVectorBroadcast(Op, Subtarget->hasAVX2());
     if (Subtarget->hasAVX() && LD.getNode())
       return DAG.getNode(X86ISD::VBROADCAST, dl, VT, LD);
 
diff --git a/lib/Target/X86/X86InstrSSE.td b/lib/Target/X86/X86InstrSSE.td
index 11f4785..e5957508 100644
--- a/lib/Target/X86/X86InstrSSE.td
+++ b/lib/Target/X86/X86InstrSSE.td
@@ -7189,19 +7189,6 @@
 def : Pat<(int_x86_avx_vbroadcastf128_ps_256 addr:$src),
           (VBROADCASTF128 addr:$src)>;
 
-def : Pat<(v8i32 (X86VBroadcast (loadi32 addr:$src))),
-          (VBROADCASTSSYrm addr:$src)>;
-def : Pat<(v4i64 (X86VBroadcast (loadi64 addr:$src))),
-          (VBROADCASTSDrm addr:$src)>;
-def : Pat<(v8f32 (X86VBroadcast (loadf32 addr:$src))),
-          (VBROADCASTSSYrm addr:$src)>;
-def : Pat<(v4f64 (X86VBroadcast (loadf64 addr:$src))),
-          (VBROADCASTSDrm addr:$src)>;
-
-def : Pat<(v4f32 (X86VBroadcast (loadf32 addr:$src))),
-          (VBROADCASTSSrm addr:$src)>;
-def : Pat<(v4i32 (X86VBroadcast (loadi32 addr:$src))),
-          (VBROADCASTSSrm addr:$src)>;
 
 //===----------------------------------------------------------------------===//
 // VINSERTF128 - Insert packed floating-point values
@@ -7557,6 +7544,40 @@
                                     int_x86_avx2_pbroadcastq_128,
                                     int_x86_avx2_pbroadcastq_256>;
 
+let Predicates = [HasAVX2] in {
+  def : Pat<(v16i8 (X86VBroadcast (loadi8 addr:$src))),
+          (VPBROADCASTBrm addr:$src)>;
+  def : Pat<(v32i8 (X86VBroadcast (loadi8 addr:$src))),
+          (VPBROADCASTBYrm addr:$src)>;
+  def : Pat<(v8i16 (X86VBroadcast (loadi16 addr:$src))),
+          (VPBROADCASTWrm addr:$src)>;
+  def : Pat<(v16i16 (X86VBroadcast (loadi16 addr:$src))),
+          (VPBROADCASTWYrm addr:$src)>;
+  def : Pat<(v4i32 (X86VBroadcast (loadi32 addr:$src))),
+          (VPBROADCASTDrm addr:$src)>;
+  def : Pat<(v8i32 (X86VBroadcast (loadi32 addr:$src))),
+          (VPBROADCASTDYrm addr:$src)>;
+  def : Pat<(v2i64 (X86VBroadcast (loadi64 addr:$src))),
+          (VPBROADCASTQrm addr:$src)>;
+  def : Pat<(v4i64 (X86VBroadcast (loadi64 addr:$src))),
+          (VPBROADCASTQYrm addr:$src)>;
+}
+
+let Predicates = [HasAVX] in { // AVX1 broadcast patterns
+  def : Pat<(v8i32 (X86VBroadcast (loadi32 addr:$src))),
+          (VBROADCASTSSYrm addr:$src)>;
+  def : Pat<(v4i64 (X86VBroadcast (loadi64 addr:$src))),
+          (VBROADCASTSDrm addr:$src)>;
+  def : Pat<(v8f32 (X86VBroadcast (loadf32 addr:$src))),
+          (VBROADCASTSSYrm addr:$src)>;
+  def : Pat<(v4f64 (X86VBroadcast (loadf64 addr:$src))),
+          (VBROADCASTSDrm addr:$src)>;
+  def : Pat<(v4f32 (X86VBroadcast (loadf32 addr:$src))),
+          (VBROADCASTSSrm addr:$src)>;
+  def : Pat<(v4i32 (X86VBroadcast (loadi32 addr:$src))),
+          (VBROADCASTSSrm addr:$src)>;
+}
+
 //===----------------------------------------------------------------------===//
 // VPERM - Permute instructions
 //