[X86][SSE] Improve support for 128-bit vector sign extension

This patch improves support for sign extension of the lower lanes of integer vectors by making use of the SSE41 pmovsx* sign extension instructions where possible, and by optimizing the shift-based sign extension on pre-SSE41 targets (avoiding the use of i64 arithmetic shifts, which require scalarization).

It converts SIGN_EXTEND nodes to SIGN_EXTEND_VECTOR_INREG where necessary, which matches the pmovsx* instructions more closely than the default approach of using SIGN_EXTEND_INREG. The default approach splits the operation into an ANY_EXTEND (lowered to a shuffle) followed by shifts, which makes instruction matching difficult during lowering. The necessary support for SIGN_EXTEND_VECTOR_INREG has been added to the DAGCombiner.

Differential Revision: http://reviews.llvm.org/D9848

llvm-svn: 237885
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 3ff4c40..77e648c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -268,6 +268,7 @@
     SDValue visitZERO_EXTEND(SDNode *N);
     SDValue visitANY_EXTEND(SDNode *N);
     SDValue visitSIGN_EXTEND_INREG(SDNode *N);
+    SDValue visitSIGN_EXTEND_VECTOR_INREG(SDNode *N);
     SDValue visitTRUNCATE(SDNode *N);
     SDValue visitBITCAST(SDNode *N);
     SDValue visitBUILD_PAIR(SDNode *N);
@@ -1347,6 +1348,7 @@
   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
+  case ISD::SIGN_EXTEND_VECTOR_INREG: return visitSIGN_EXTEND_VECTOR_INREG(N);
   case ISD::TRUNCATE:           return visitTRUNCATE(N);
   case ISD::BITCAST:            return visitBITCAST(N);
   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
@@ -5541,7 +5543,8 @@
   EVT VT = N->getValueType(0);
 
   assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
-         Opcode == ISD::ANY_EXTEND) && "Expected EXTEND dag node in input!");
+         Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
+         && "Expected EXTEND dag node in input!");
 
   // fold (sext c1) -> c1
   // fold (zext c1) -> c1
@@ -5563,7 +5566,7 @@
   unsigned EVTBits = N0->getValueType(0).getScalarType().getSizeInBits();
   unsigned ShAmt = VTBits - EVTBits;
   SmallVector<SDValue, 8> Elts;
-  unsigned NumElts = N0->getNumOperands();
+  unsigned NumElts = VT.getVectorNumElements();
   SDLoc DL(N);
 
   for (unsigned i=0; i != NumElts; ++i) {
@@ -5576,7 +5579,7 @@
     SDLoc DL(Op);
     ConstantSDNode *CurrentND = cast<ConstantSDNode>(Op);
     const APInt &C = APInt(VTBits, CurrentND->getAPIntValue().getZExtValue());
-    if (Opcode == ISD::SIGN_EXTEND)
+    if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
       Elts.push_back(DAG.getConstant(C.shl(ShAmt).ashr(ShAmt).getZExtValue(),
                                      DL, SVT));
     else
@@ -6805,6 +6808,20 @@
   return SDValue();
 }
 
+SDValue DAGCombiner::visitSIGN_EXTEND_VECTOR_INREG(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+
+  if (N0.getOpcode() == ISD::UNDEF) // fold (sext_vector_inreg undef) -> undef
+    return DAG.getUNDEF(VT);
+
+  if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes,
+                                              LegalOperations))
+    return SDValue(Res, 0); // helper returned a node: use its first result
+
+  return SDValue(); // neither fold applied: return an empty SDValue
+}
+
 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 6d75a7c..eaba9ca 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3429,12 +3429,35 @@
     assert(EVT.bitsLE(VT) && "Not extending!");
     if (EVT == VT) return N1;  // Not actually extending
 
+    auto SignExtendInReg = [&](APInt Val) {
+      unsigned FromBits = EVT.getScalarType().getSizeInBits();
+      Val <<= Val.getBitWidth() - FromBits;
+      Val = Val.ashr(Val.getBitWidth() - FromBits);
+      return getConstant(Val, DL, VT.getScalarType());
+    };
+
     if (N1C) {
       APInt Val = N1C->getAPIntValue();
-      unsigned FromBits = EVT.getScalarType().getSizeInBits();
-      Val <<= Val.getBitWidth()-FromBits;
-      Val = Val.ashr(Val.getBitWidth()-FromBits);
-      return getConstant(Val, DL, VT);
+      return SignExtendInReg(Val);
+    }
+    if (ISD::isBuildVectorOfConstantSDNodes(N1.getNode())) {
+      SmallVector<SDValue, 8> Ops;
+      for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
+        SDValue Op = N1.getOperand(i);
+        if (Op.getValueType() != VT.getScalarType()) break;
+        if (Op.getOpcode() == ISD::UNDEF) {
+          Ops.push_back(Op);
+          continue;
+        }
+        if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getNode())) {
+          APInt Val = C->getAPIntValue();
+          Ops.push_back(SignExtendInReg(Val));
+          continue;
+        }
+        break;
+      }
+      if (Ops.size() == VT.getVectorNumElements())
+        return getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
     }
     break;
   }