[WebAssembly] General vector shift lowering
Summary: Adds support for lowering non-splat vector shifts by unrolling them, since wasm SIMD shifts must have a scalar shift amount.
Reviewers: aheejin, dschuff
Subscribers: sbc100, jgravelle-google, sunfish, llvm-commits
Differential Revision: https://reviews.llvm.org/D53625
llvm-svn: 345916
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index d182bd9..578d235 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -146,10 +146,15 @@
}
}
- // Custom lowering to avoid having to emit a wrap for 2xi64 constant shifts
- if (Subtarget->hasSIMD128() && EnableUnimplementedWasmSIMDInstrs)
- for (auto Op : {ISD::SHL, ISD::SRA, ISD::SRL})
- setOperationAction(Op, MVT::v2i64, Custom);
+ // Custom lowering since wasm shifts must have a scalar shift amount
+ if (Subtarget->hasSIMD128()) {
+ for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
+ for (auto Op : {ISD::SHL, ISD::SRA, ISD::SRL})
+ setOperationAction(Op, T, Custom);
+ if (EnableUnimplementedWasmSIMDInstrs)
+ for (auto Op : {ISD::SHL, ISD::SRA, ISD::SRL})
+ setOperationAction(Op, MVT::v2i64, Custom);
+ }
// There is no select instruction for vectors
if (Subtarget->hasSIMD128()) {
@@ -1082,13 +1087,23 @@
SDValue WebAssemblyTargetLowering::LowerShift(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
- auto *ShiftVec = dyn_cast<BuildVectorSDNode>(Op.getOperand(1).getNode());
- APInt SplatValue, SplatUndef;
- unsigned SplatBitSize;
- bool HasAnyUndefs;
- if (!ShiftVec || !ShiftVec->isConstantSplat(SplatValue, SplatUndef,
- SplatBitSize, HasAnyUndefs))
+
+ // Only manually lower vector shifts
+ assert(Op.getSimpleValueType().isVector());
+
+ // Unroll non-splat vector shifts
+ BuildVectorSDNode *ShiftVec;
+ SDValue SplatVal;
+ if (!(ShiftVec = dyn_cast<BuildVectorSDNode>(Op.getOperand(1).getNode())) ||
+ !(SplatVal = ShiftVec->getSplatValue()))
+ return DAG.UnrollVectorOp(Op.getNode());
+
+ // All splats except i64x2 const splats are handled by patterns
+ ConstantSDNode *SplatConst = dyn_cast<ConstantSDNode>(SplatVal);
+ if (!SplatConst || Op.getSimpleValueType() != MVT::v2i64)
return Op;
+
+ // i64x2 const splats are custom lowered to avoid unnecessary wraps
unsigned Opcode;
switch (Op.getOpcode()) {
case ISD::SHL:
@@ -1102,10 +1117,10 @@
break;
default:
llvm_unreachable("unexpected opcode");
- return Op;
}
+ APInt Shift = SplatConst->getAPIntValue().zextOrTrunc(32);
return DAG.getNode(Opcode, DL, Op.getValueType(), Op.getOperand(0),
- DAG.getConstant(SplatValue.trunc(32), DL, MVT::i32));
+ DAG.getConstant(Shift, DL, MVT::i32));
}
//===----------------------------------------------------------------------===//