[X86][SSE] Improve support for 128-bit vector sign extension
This patch improves support for sign extension of the lower lanes of vectors of integers by making use of the SSE41 pmovsx* sign extension instructions where possible, and optimizing the sign extension by shifts on pre-SSE41 targets (avoiding the use of i64 arithmetic shifts which require scalarization).
It converts SIGN_EXTEND nodes to SIGN_EXTEND_VECTOR_INREG where necessary, which more closely matches the pmovsx* instructions than the default approach of using SIGN_EXTEND_INREG; that approach splits the operation (into an ANY_EXTEND lowered to a shuffle, followed by shifts), making instruction matching difficult during lowering. The necessary support for SIGN_EXTEND_VECTOR_INREG has been added to the DAGCombiner.
Differential Revision: http://reviews.llvm.org/D9848
llvm-svn: 237885
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 6d75a7c..eaba9ca 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3429,12 +3429,35 @@
assert(EVT.bitsLE(VT) && "Not extending!");
if (EVT == VT) return N1; // Not actually extending
+ auto SignExtendInReg = [&](APInt Val) {
+ unsigned FromBits = EVT.getScalarType().getSizeInBits();
+ Val <<= Val.getBitWidth() - FromBits;
+ Val = Val.ashr(Val.getBitWidth() - FromBits);
+ return getConstant(Val, DL, VT.getScalarType());
+ };
+
if (N1C) {
APInt Val = N1C->getAPIntValue();
- unsigned FromBits = EVT.getScalarType().getSizeInBits();
- Val <<= Val.getBitWidth()-FromBits;
- Val = Val.ashr(Val.getBitWidth()-FromBits);
- return getConstant(Val, DL, VT);
+ return SignExtendInReg(Val);
+ }
+ if (ISD::isBuildVectorOfConstantSDNodes(N1.getNode())) {
+ SmallVector<SDValue, 8> Ops;
+ for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
+ SDValue Op = N1.getOperand(i);
+ if (Op.getValueType() != VT.getScalarType()) break;
+ if (Op.getOpcode() == ISD::UNDEF) {
+ Ops.push_back(Op);
+ continue;
+ }
+ if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getNode())) {
+ APInt Val = C->getAPIntValue();
+ Ops.push_back(SignExtendInReg(Val));
+ continue;
+ }
+ break;
+ }
+ if (Ops.size() == VT.getVectorNumElements())
+ return getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
}
break;
}