Codegen failure for vmull with small vectors
Codegen was failing with an assertion because of unexpected vector
operands when legalizing the selection DAG for a MUL instruction.
The asserting code was legalizing multiplies for vectors of size 128
bits. It uses a custom lowering to try and detect cases where it can
use a VMULL instruction instead of a VMOVL + VMUL. The code was
looking for input operands to the MUL that had been sign or zero
extended. If it found the extended operands it would drop the
sign/zero extension and use the original vector size as input to a
VMULL instruction.
The code assumed that the original input vector was 64 bits so that
after dropping the extension it would fit directly into a D register
and could be used as an operand of a VMULL instruction. The input
code that triggered the failure used a vector of <4 x i8> that was
sign extended to <4 x i32>. It was not safe to drop the sign
extension in this case because the original vector is only 32 bits
wide. The fix is to insert a sign extension for the vector to reach
the required 64 bit size. In this particular example, the vector would
need to be sign extended to <4 x i16>.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@169024 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp
index d139a56..8ccb3c3 100644
--- a/lib/Target/ARM/ARMISelLowering.cpp
+++ b/lib/Target/ARM/ARMISelLowering.cpp
@@ -4939,16 +4939,76 @@
return false;
}
-/// SkipExtension - For a node that is a SIGN_EXTEND, ZERO_EXTEND, extending
-/// load, or BUILD_VECTOR with extended elements, return the unextended value.
-static SDValue SkipExtension(SDNode *N, SelectionDAG &DAG) {
+/// AddRequiredExtensionForVMULL - Add a sign/zero extension to extend the total
+/// value size to 64 bits. We need a 64-bit D register as an operand to VMULL.
+/// We insert the required extension here to get the vector to fill a D register.
+static SDValue AddRequiredExtensionForVMULL(SDValue N, SelectionDAG &DAG,
+ const EVT &OrigTy,
+ const EVT &ExtTy,
+ unsigned ExtOpcode) {
+ // The vector originally had a size of OrigTy. It was then extended to ExtTy.
+ // We expect the ExtTy to be 128-bits total. If the OrigTy is less than
+ // 64-bits we need to insert a new extension so that it will be 64-bits.
+ assert(ExtTy.is128BitVector() && "Unexpected extension size");
+ if (OrigTy.getSizeInBits() >= 64)
+ return N;
+
+ // Must extend size to at least 64 bits to be used as an operand for VMULL.
+ MVT::SimpleValueType OrigSimpleTy = OrigTy.getSimpleVT().SimpleTy;
+ EVT NewVT;
+ switch (OrigSimpleTy) {
+ default: llvm_unreachable("Unexpected Orig Vector Type");
+ case MVT::v2i8:
+ case MVT::v2i16:
+ NewVT = MVT::v2i32;
+ break;
+ case MVT::v4i8:
+ NewVT = MVT::v4i16;
+ break;
+ }
+ return DAG.getNode(ExtOpcode, N->getDebugLoc(), NewVT, N);
+}
+
+/// SkipLoadExtensionForVMULL - return a load of the original vector size that
+/// does not do any sign/zero extension. If the original vector is less
+/// than 64 bits, an appropriate extension will be added after the load to
+/// reach a total size of 64 bits. We have to add the extension separately
+/// because ARM does not have a sign/zero extending load for vectors.
+static SDValue SkipLoadExtensionForVMULL(LoadSDNode *LD, SelectionDAG& DAG) {
+ SDValue NonExtendingLoad =
+ DAG.getLoad(LD->getMemoryVT(), LD->getDebugLoc(), LD->getChain(),
+ LD->getBasePtr(), LD->getPointerInfo(), LD->isVolatile(),
+ LD->isNonTemporal(), LD->isInvariant(),
+ LD->getAlignment());
+ unsigned ExtOp = 0;
+ switch (LD->getExtensionType()) {
+ default: llvm_unreachable("Unexpected LoadExtType");
+ case ISD::EXTLOAD:
+ case ISD::SEXTLOAD: ExtOp = ISD::SIGN_EXTEND; break;
+ case ISD::ZEXTLOAD: ExtOp = ISD::ZERO_EXTEND; break;
+ }
+ MVT::SimpleValueType MemType = LD->getMemoryVT().getSimpleVT().SimpleTy;
+ MVT::SimpleValueType ExtType = LD->getValueType(0).getSimpleVT().SimpleTy;
+ return AddRequiredExtensionForVMULL(NonExtendingLoad, DAG,
+ MemType, ExtType, ExtOp);
+}
+
+/// SkipExtensionForVMULL - For a node that is a SIGN_EXTEND, ZERO_EXTEND,
+/// extending load, or BUILD_VECTOR with extended elements, return the
+/// unextended value. The unextended vector should be 64 bits so that it can
+/// be used as an operand to a VMULL instruction. If the original vector size
+/// before extension is less than 64 bits we add an extension to resize
+/// the vector to 64 bits.
+static SDValue SkipExtensionForVMULL(SDNode *N, SelectionDAG &DAG) {
if (N->getOpcode() == ISD::SIGN_EXTEND || N->getOpcode() == ISD::ZERO_EXTEND)
- return N->getOperand(0);
+ return AddRequiredExtensionForVMULL(N->getOperand(0), DAG,
+ N->getOperand(0)->getValueType(0),
+ N->getValueType(0),
+ N->getOpcode());
+
if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N))
- return DAG.getLoad(LD->getMemoryVT(), N->getDebugLoc(), LD->getChain(),
- LD->getBasePtr(), LD->getPointerInfo(), LD->isVolatile(),
- LD->isNonTemporal(), LD->isInvariant(),
- LD->getAlignment());
+ return SkipLoadExtensionForVMULL(LD, DAG);
+
// Otherwise, the value must be a BUILD_VECTOR. For v2i64, it will
// have been legalized as a BITCAST from v4i32.
if (N->getOpcode() == ISD::BITCAST) {
@@ -5003,7 +5063,8 @@
// Multiplications are only custom-lowered for 128-bit vectors so that
// VMULL can be detected. Otherwise v2i64 multiplications are not legal.
EVT VT = Op.getValueType();
- assert(VT.is128BitVector() && "unexpected type for custom-lowering ISD::MUL");
+ assert(VT.is128BitVector() && VT.isInteger() &&
+ "unexpected type for custom-lowering ISD::MUL");
SDNode *N0 = Op.getOperand(0).getNode();
SDNode *N1 = Op.getOperand(1).getNode();
unsigned NewOpc = 0;
@@ -5046,9 +5107,9 @@
// Legalize to a VMULL instruction.
DebugLoc DL = Op.getDebugLoc();
SDValue Op0;
- SDValue Op1 = SkipExtension(N1, DAG);
+ SDValue Op1 = SkipExtensionForVMULL(N1, DAG);
if (!isMLA) {
- Op0 = SkipExtension(N0, DAG);
+ Op0 = SkipExtensionForVMULL(N0, DAG);
assert(Op0.getValueType().is64BitVector() &&
Op1.getValueType().is64BitVector() &&
"unexpected types for extended operands to VMULL");
@@ -5063,8 +5124,8 @@
// vaddl q0, d4, d5
// vmovl q1, d6
// vmul q0, q0, q1
- SDValue N00 = SkipExtension(N0->getOperand(0).getNode(), DAG);
- SDValue N01 = SkipExtension(N0->getOperand(1).getNode(), DAG);
+ SDValue N00 = SkipExtensionForVMULL(N0->getOperand(0).getNode(), DAG);
+ SDValue N01 = SkipExtensionForVMULL(N0->getOperand(1).getNode(), DAG);
EVT Op1VT = Op1.getValueType();
return DAG.getNode(N0->getOpcode(), DL, VT,
DAG.getNode(NewOpc, DL, VT,