Teach dag combine to fold the following transformation more aggressively:
(OP (trunc x), (trunc y)) -> (trunc (OP x, y))
Unfortunately this simple change sends dag combine into an infinite loop: the shrink-demanded-ops optimization tends to canonicalize expressions in the opposite direction. That is badness. This patch disables those optimizations in dag combine and instead performs them as a late pass in SelectionDAGISel.
This also exposes some deficiencies in dag combine and x86 setcc / brcond lowering. Teach them to look past ISD::TRUNCATE in various places.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@92849 91177308-0d34-0410-b5e6-96231b3b80d8
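
For reference, a standalone sketch (plain C++, not part of the patch, with illustrative names) of why the new fold is sound for the bitwise ops it targets: truncation commutes with AND/OR/XOR, so hoisting the truncate past the operation preserves the result.

    #include <cassert>
    #include <cstdint>

    // (OP (trunc x), (trunc y)): truncate the operands, then operate narrow.
    static uint8_t TruncThenOp(uint32_t X, uint32_t Y) {
      return static_cast<uint8_t>(X) & static_cast<uint8_t>(Y);
    }

    // (trunc (OP x, y)): operate at the wide width, truncate once at the end.
    static uint8_t OpThenTrunc(uint32_t X, uint32_t Y) {
      return static_cast<uint8_t>(X & Y);
    }

    int main() {
      assert(TruncThenOp(0x12345678u, 0xABCDEF01u) ==
             OpThenTrunc(0x12345678u, 0xABCDEF01u));
      return 0;
    }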
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 759fa0e..5ab9280 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1688,18 +1688,18 @@
// fold (OP (zext x), (zext y)) -> (zext (OP x, y))
// fold (OP (sext x), (sext y)) -> (sext (OP x, y))
// fold (OP (aext x), (aext y)) -> (aext (OP x, y))
- // fold (OP (trunc x), (trunc y)) -> (trunc (OP x, y)) (if trunc isn't free)
+ // fold (OP (trunc x), (trunc y)) -> (trunc (OP x, y))
//
// do not sink logical op inside of a vector extend, since it may combine
// into a vsetcc.
- if ((N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND||
+ EVT Op0VT = N0.getOperand(0).getValueType();
+ if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
+ N0.getOpcode() == ISD::ANY_EXTEND ||
N0.getOpcode() == ISD::SIGN_EXTEND ||
- (N0.getOpcode() == ISD::TRUNCATE &&
- !TLI.isTruncateFree(N0.getOperand(0).getValueType(), VT))) &&
+ (N0.getOpcode() == ISD::TRUNCATE && TLI.isTypeLegal(Op0VT))) &&
!VT.isVector() &&
- N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
- (!LegalOperations ||
- TLI.isOperationLegal(N->getOpcode(), N0.getOperand(0).getValueType()))) {
+ Op0VT == N1.getOperand(0).getValueType() &&
+ (!LegalOperations || TLI.isOperationLegal(N->getOpcode(), Op0VT))) {
SDValue ORNode = DAG.getNode(N->getOpcode(), N0.getDebugLoc(),
N0.getOperand(0).getValueType(),
N0.getOperand(0), N1.getOperand(0));
@@ -1839,6 +1839,7 @@
if (!VT.isVector() &&
SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
+
// fold (zext_inreg (extload x)) -> (zextload x)
if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode())) {
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
@@ -1885,48 +1886,89 @@
// fold (and (load x), 255) -> (zextload x, i8)
// fold (and (extload x, i16), 255) -> (zextload x, i8)
- if (N1C && N0.getOpcode() == ISD::LOAD) {
- LoadSDNode *LN0 = cast<LoadSDNode>(N0);
+ // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
+ if (N1C && (N0.getOpcode() == ISD::LOAD ||
+ (N0.getOpcode() == ISD::ANY_EXTEND &&
+ N0.getOperand(0).getOpcode() == ISD::LOAD))) {
+ bool HasAnyExt = N0.getOpcode() == ISD::ANY_EXTEND;
+ LoadSDNode *LN0 = HasAnyExt
+ ? cast<LoadSDNode>(N0.getOperand(0))
+ : cast<LoadSDNode>(N0);
if (LN0->getExtensionType() != ISD::SEXTLOAD &&
- LN0->isUnindexed() && N0.hasOneUse() &&
- // Do not change the width of a volatile load.
- !LN0->isVolatile()) {
- EVT ExtVT = MVT::Other;
+ LN0->isUnindexed() && N0.hasOneUse()) {
uint32_t ActiveBits = N1C->getAPIntValue().getActiveBits();
- if (ActiveBits > 0 && APIntOps::isMask(ActiveBits, N1C->getAPIntValue()))
- ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
+ if (ActiveBits > 0 && APIntOps::isMask(ActiveBits, N1C->getAPIntValue())){
+ EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
+ EVT LoadedVT = LN0->getMemoryVT();
- EVT LoadedVT = LN0->getMemoryVT();
+ if (ExtVT == LoadedVT &&
+ (!LegalOperations || TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT))) {
+ if (HasAnyExt) {
+ SDValue Load =
+ DAG.getExtLoad(ISD::ZEXTLOAD, LN0->getDebugLoc(),
+ LN0->getValueType(0),
+ LN0->getChain(), LN0->getBasePtr(),
+ LN0->getSrcValue(), LN0->getSrcValueOffset(),
+ ExtVT, LN0->isVolatile(), LN0->getAlignment());
+ AddToWorkList(N);
+ CombineTo(N0.getOperand(0).getNode(), Load, Load.getValue(1));
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ } else {
+ SDValue Load =
+ DAG.getExtLoad(ISD::ZEXTLOAD, LN0->getDebugLoc(), VT,
+ LN0->getChain(), LN0->getBasePtr(),
+ LN0->getSrcValue(), LN0->getSrcValueOffset(),
+ ExtVT, LN0->isVolatile(), LN0->getAlignment());
+ AddToWorkList(N);
+ CombineTo(N0.getNode(), Load, Load.getValue(1));
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ }
+ } else if (!LN0->isVolatile()) {
+ // Do not change the width of a volatile load.
+ // Do not generate loads of non-round integer types since these can
+ // be expensive (and would be wrong if the type is not byte sized).
+ if (LoadedVT.bitsGT(ExtVT) && ExtVT.isRound() &&
+ (!LegalOperations || TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT))) {
+ EVT PtrType = LN0->getOperand(1).getValueType();
- // Do not generate loads of non-round integer types since these can
- // be expensive (and would be wrong if the type is not byte sized).
- if (ExtVT != MVT::Other && LoadedVT.bitsGT(ExtVT) && ExtVT.isRound() &&
- (!LegalOperations || TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT))) {
- EVT PtrType = N0.getOperand(1).getValueType();
+ // For big endian targets, we need to add an offset to the pointer
+ // to load the correct bytes. For little endian systems, we merely
+ // need to read fewer bytes from the same pointer.
+ unsigned LVTStoreBytes = LoadedVT.getStoreSize();
+ unsigned EVTStoreBytes = ExtVT.getStoreSize();
+ unsigned PtrOff = LVTStoreBytes - EVTStoreBytes;
+ unsigned Alignment = LN0->getAlignment();
+ SDValue NewPtr = LN0->getBasePtr();
- // For big endian targets, we need to add an offset to the pointer to
- // load the correct bytes. For little endian systems, we merely need to
- // read fewer bytes from the same pointer.
- unsigned LVTStoreBytes = LoadedVT.getStoreSize();
- unsigned EVTStoreBytes = ExtVT.getStoreSize();
- unsigned PtrOff = LVTStoreBytes - EVTStoreBytes;
- unsigned Alignment = LN0->getAlignment();
- SDValue NewPtr = LN0->getBasePtr();
+ if (TLI.isBigEndian()) {
+ NewPtr = DAG.getNode(ISD::ADD, LN0->getDebugLoc(), PtrType,
+ NewPtr, DAG.getConstant(PtrOff, PtrType));
+ Alignment = MinAlign(Alignment, PtrOff);
+ }
- if (TLI.isBigEndian()) {
- NewPtr = DAG.getNode(ISD::ADD, LN0->getDebugLoc(), PtrType,
- NewPtr, DAG.getConstant(PtrOff, PtrType));
- Alignment = MinAlign(Alignment, PtrOff);
+ AddToWorkList(NewPtr.getNode());
+ if (HasAnyExt) {
+ SDValue Load =
+ DAG.getExtLoad(ISD::ZEXTLOAD, LN0->getDebugLoc(),
+ LN0->getValueType(0),
+ LN0->getChain(), NewPtr,
+ LN0->getSrcValue(), LN0->getSrcValueOffset(),
+ ExtVT, LN0->isVolatile(), Alignment);
+ AddToWorkList(N);
+ CombineTo(N0.getOperand(0).getNode(), Load, Load.getValue(1));
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ } else {
+ SDValue Load =
+ DAG.getExtLoad(ISD::ZEXTLOAD, LN0->getDebugLoc(), VT,
+ LN0->getChain(), NewPtr,
+ LN0->getSrcValue(), LN0->getSrcValueOffset(),
+ ExtVT, LN0->isVolatile(), Alignment);
+ AddToWorkList(N);
+ CombineTo(N0.getNode(), Load, Load.getValue(1));
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ }
+ }
}
-
- AddToWorkList(NewPtr.getNode());
- SDValue Load =
- DAG.getExtLoad(ISD::ZEXTLOAD, LN0->getDebugLoc(), VT, LN0->getChain(),
- NewPtr, LN0->getSrcValue(), LN0->getSrcValueOffset(),
- ExtVT, LN0->isVolatile(), Alignment);
- AddToWorkList(N);
- CombineTo(N0.getNode(), Load, Load.getValue(1));
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
}
@@ -2778,9 +2820,17 @@
// However when after the source operand of SRL is optimized into AND, the SRL
// itself may not be optimized further. Look for it and add the BRCOND into
// the worklist.
- if (N->hasOneUse() &&
- N->use_begin()->getOpcode() == ISD::BRCOND)
- AddToWorkList(*N->use_begin());
+ if (N->hasOneUse()) {
+ SDNode *Use = *N->use_begin();
+ if (Use->getOpcode() == ISD::BRCOND)
+ AddToWorkList(Use);
+ else if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse()) {
+ // Also look past the truncate.
+ Use = *Use->use_begin();
+ if (Use->getOpcode() == ISD::BRCOND)
+ AddToWorkList(Use);
+ }
+ }
return SDValue();
}
@@ -3198,7 +3248,10 @@
// fold (zext (truncate x)) -> (and x, mask)
if (N0.getOpcode() == ISD::TRUNCATE &&
- (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT))) {
+ (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) &&
+ (!TLI.isTruncateFree(N0.getOperand(0).getValueType(),
+ N0.getValueType()) ||
+ !TLI.isZExtFree(N0.getValueType(), VT))) {
SDValue Op = N0.getOperand(0);
if (Op.getValueType().bitsLT(VT)) {
Op = DAG.getNode(ISD::ANY_EXTEND, N->getDebugLoc(), VT, Op);
@@ -3704,7 +3757,7 @@
return DAG.getNode(ISD::TRUNCATE, N->getDebugLoc(), VT, N0.getOperand(0));
else
// if the source and dest are the same type, we can drop both the extend
- // and the truncate
+ // and the truncate.
return N0.getOperand(0);
}
@@ -4515,6 +4568,13 @@
N1.getOperand(0), N1.getOperand(1), N2);
}
+ SDNode *Trunc = 0;
+ if (N1.getOpcode() == ISD::TRUNCATE && N1.hasOneUse()) {
+ // Look past the truncate.
+ Trunc = N1.getNode();
+ N1 = N1.getOperand(0);
+ }
+
if (N1.hasOneUse() && N1.getOpcode() == ISD::SRL) {
// Match this pattern so that we can generate simpler code:
//
@@ -4526,7 +4586,7 @@
// into
//
// %a = ...
- // %b = and %a, 2
+ // %b = and i32 %a, 2
// %c = setcc eq %b, 0
// brcond %c ...
//
@@ -4537,7 +4597,6 @@
SDValue Op1 = N1.getOperand(1);
if (Op0.getOpcode() == ISD::AND &&
- Op0.hasOneUse() &&
Op1.getOpcode() == ISD::Constant) {
SDValue AndOp1 = Op0.getOperand(1);
@@ -4552,12 +4611,21 @@
Op0, DAG.getConstant(0, Op0.getValueType()),
ISD::SETNE);
+ SDValue NewBRCond = DAG.getNode(ISD::BRCOND, N->getDebugLoc(),
+ MVT::Other, Chain, SetCC, N2);
+ // Don't add the new BRCond into the worklist or else SimplifySelectCC
+ // will convert it back to (X & C1) >> C2.
+ CombineTo(N, NewBRCond, false);
+ // Truncate is dead.
+ if (Trunc) {
+ removeFromWorkList(Trunc);
+ DAG.DeleteNode(Trunc);
+ }
// Replace the uses of SRL with SETCC
DAG.ReplaceAllUsesOfValueWith(N1, SetCC);
removeFromWorkList(N1.getNode());
DAG.DeleteNode(N1.getNode());
- return DAG.getNode(ISD::BRCOND, N->getDebugLoc(),
- MVT::Other, Chain, SetCC, N2);
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
}
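
The big-endian pointer adjustment in the hunk above is the subtle part of the zextload fold: masking a wide load with 255 is a narrow load of the low-order byte, which sits at offset 0 on little-endian targets but at the end of the wide value on big-endian ones (the PtrOff computation). A standalone sketch (plain C++, not part of the patch) of the equivalence, assuming a little-endian host:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // (and (load x, i16), 255): load wide, then mask.
    static uint32_t LoadWideThenMask(const unsigned char *P) {
      uint16_t V;
      std::memcpy(&V, P, sizeof(V));
      return V & 255u;
    }

    // (zextload x, i8): load only the byte holding the demanded bits.
    // Offset 0 assumes a little-endian host; a big-endian one would need
    // P + 1, which is the offset adjustment the combine performs.
    static uint32_t LoadNarrow(const unsigned char *P) {
      return P[0];
    }

    int main() {
      const unsigned char Buf[2] = {0xCD, 0xAB}; // 0xABCD on little-endian
      assert(LoadWideThenMask(Buf) == LoadNarrow(Buf));
      return 0;
    }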
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f1c8650..ca8c17b 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2656,6 +2656,8 @@
// size of the value, the shift/rotate count is guaranteed to be zero.
if (VT == MVT::i1)
return N1;
+ if (N2C && N2C->isNullValue())
+ return N1;
break;
case ISD::FP_ROUND_INREG: {
EVT EVT = cast<VTSDNode>(N2)->getVT();
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 3073dfe..8ed24cc 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -438,6 +438,75 @@
SDB->clear();
}
+void SelectionDAGISel::ShrinkDemandedOps() {
+ SmallVector<SDNode*, 128> Worklist;
+
+ // Add all the dag nodes to the worklist.
+ Worklist.reserve(CurDAG->allnodes_size());
+ for (SelectionDAG::allnodes_iterator I = CurDAG->allnodes_begin(),
+ E = CurDAG->allnodes_end(); I != E; ++I)
+ Worklist.push_back(I);
+
+ APInt Mask;
+ APInt KnownZero;
+ APInt KnownOne;
+
+ TargetLowering::TargetLoweringOpt TLO(*CurDAG, true);
+ while (!Worklist.empty()) {
+ SDNode *N = Worklist.back();
+ Worklist.pop_back();
+
+ if (N->use_empty() && N != CurDAG->getRoot().getNode()) {
+ CurDAG->DeleteNode(N);
+ continue;
+ }
+
+ // Run ShrinkDemandedOp on scalar binary operations.
+ if (N->getNumValues() == 1 &&
+ N->getValueType(0).isSimple() && N->getValueType(0).isInteger()) {
+ DebugLoc dl = N->getDebugLoc();
+ unsigned BitWidth = N->getValueType(0).getScalarType().getSizeInBits();
+ APInt Demanded = APInt::getAllOnesValue(BitWidth);
+ APInt KnownZero, KnownOne;
+ if (TLI.SimplifyDemandedBits(SDValue(N, 0), Demanded,
+ KnownZero, KnownOne, TLO)) {
+ // Revisit the node.
+ Worklist.erase(std::remove(Worklist.begin(), Worklist.end(), N),
+ Worklist.end());
+ Worklist.push_back(N);
+
+ // Replace the old value with the new one.
+ DEBUG(errs() << "\nReplacing ";
+ TLO.Old.getNode()->dump(CurDAG);
+ errs() << "\nWith: ";
+ TLO.New.getNode()->dump(CurDAG);
+ errs() << '\n');
+
+ Worklist.push_back(TLO.New.getNode());
+ CurDAG->ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);
+
+ if (TLO.Old.getNode()->use_empty()) {
+ for (unsigned i = 0, e = TLO.Old.getNode()->getNumOperands();
+ i != e; ++i) {
+ SDNode *OpNode = TLO.Old.getNode()->getOperand(i).getNode();
+ if (OpNode->hasOneUse()) {
+ Worklist.erase(std::remove(Worklist.begin(), Worklist.end(),
+ OpNode),
+ Worklist.end());
+ Worklist.push_back(TLO.Old.getNode()->getOperand(i).getNode());
+ }
+ }
+
+ Worklist.erase(std::remove(Worklist.begin(), Worklist.end(),
+ TLO.Old.getNode()),
+ Worklist.end());
+ CurDAG->DeleteNode(TLO.Old.getNode());
+ }
+ }
+ }
+ }
+}
+
void SelectionDAGISel::ComputeLiveOutVRegInfo() {
SmallPtrSet<SDNode*, 128> VisitedNodes;
SmallVector<SDNode*, 128> Worklist;
@@ -609,8 +678,10 @@
if (ViewISelDAGs) CurDAG->viewGraph("isel input for " + BlockName);
- if (OptLevel != CodeGenOpt::None)
+ if (OptLevel != CodeGenOpt::None) {
+ ShrinkDemandedOps();
ComputeLiveOutVRegInfo();
+ }
// Third, instruction select all of the operations to machine code, adding the
// code to the MachineBasicBlock.
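
ShrinkDemandedOps is the opposite canonicalization the commit message mentions: ShrinkDemandedOp narrows a wide operation when only its low bits are demanded, pushing truncates inward, while the new combine hoists them outward, so running both inside dag combine ping-pongs forever. A standalone sketch (plain C++, not part of the patch) of why the narrowing is legal for ops like add, where each low result bit depends only on low operand bits:

    #include <cassert>
    #include <cstdint>

    // Wide add, with only the low 16 bits of the result demanded.
    static uint16_t WideAddThenTrunc(uint32_t X, uint32_t Y) {
      return static_cast<uint16_t>(X + Y);
    }

    // The same computation performed at the narrow width.
    static uint16_t NarrowAdd(uint32_t X, uint32_t Y) {
      return static_cast<uint16_t>(static_cast<uint16_t>(X) +
                                   static_cast<uint16_t>(Y));
    }

    int main() {
      assert(WideAddThenTrunc(0x1FFFFu, 0x2BEEFu) ==
             NarrowAdd(0x1FFFFu, 0x2BEEFu));
      return 0;
    }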
diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index d9a5a13..f7694db 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -990,7 +990,7 @@
if (TLO.ShrinkDemandedConstant(Op, ~KnownZero2 & NewMask))
return true;
// If the operation can be done in a smaller type, do so.
- if (TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
+ if (TLO.ShrinkOps && TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
return true;
// Output known-1 bits are only known if set in both the LHS & RHS.
@@ -1024,7 +1024,7 @@
if (TLO.ShrinkDemandedConstant(Op, NewMask))
return true;
// If the operation can be done in a smaller type, do so.
- if (TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
+ if (TLO.ShrinkOps && TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
return true;
// Output known-0 bits are only known if clear in both the LHS & RHS.
@@ -1049,7 +1049,7 @@
if ((KnownZero2 & NewMask) == NewMask)
return TLO.CombineTo(Op, Op.getOperand(1));
// If the operation can be done in a smaller type, do so.
- if (TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
+ if (TLO.ShrinkOps && TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
return true;
// If all of the unknown bits are known to be zero on one side or the other
@@ -1480,7 +1480,7 @@
KnownOne2, TLO, Depth+1))
return true;
// See if the operation should be performed at a smaller bit width.
- if (TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
+ if (TLO.ShrinkOps && TLO.ShrinkDemandedOp(Op, BitWidth, NewMask, dl))
return true;
}
// FALL THROUGH
@@ -1876,7 +1876,9 @@
// Fold bit comparisons when we can.
if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
- VT == N0.getValueType() && N0.getOpcode() == ISD::AND)
+ (VT == N0.getValueType() ||
+ (isTypeLegal(VT) && VT.bitsLE(N0.getValueType()))) &&
+ N0.getOpcode() == ISD::AND)
if (ConstantSDNode *AndRHS =
dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
EVT ShiftTy = DCI.isBeforeLegalize() ?
@@ -1884,16 +1886,18 @@
if (Cond == ISD::SETNE && C1 == 0) {// (X & 8) != 0 --> (X & 8) >> 3
// Perform the xform if the AND RHS is a single bit.
if (isPowerOf2_64(AndRHS->getZExtValue())) {
- return DAG.getNode(ISD::SRL, dl, VT, N0,
+ return DAG.getNode(ISD::TRUNCATE, dl, VT,
+ DAG.getNode(ISD::SRL, dl, N0.getValueType(), N0,
DAG.getConstant(Log2_64(AndRHS->getZExtValue()),
- ShiftTy));
+ ShiftTy)));
}
} else if (Cond == ISD::SETEQ && C1 == AndRHS->getZExtValue()) {
// (X & 8) == 8 --> (X & 8) >> 3
// Perform the xform if C1 is a single bit.
if (C1.isPowerOf2()) {
- return DAG.getNode(ISD::SRL, dl, VT, N0,
- DAG.getConstant(C1.logBase2(), ShiftTy));
+ return DAG.getNode(ISD::TRUNCATE, dl, VT,
+ DAG.getNode(ISD::SRL, dl, N0.getValueType(), N0,
+ DAG.getConstant(C1.logBase2(), ShiftTy)));
}
}
}
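
A standalone sketch (plain C++, not part of the patch) of the identity behind the new trunc(srl) form above: shifting the tested bit down to position 0 and truncating to i1 yields exactly the (X & 8) != 0 result, so the setcc can be lowered without a compare even when its result type is narrower than the operand type.

    #include <assert.h>
    #include <stdint.h>

    // setcc ne (and X, 8), 0 lowered as (trunc (srl (and X, 8), 3)):
    // the truncate to i1 keeps only bit 0.
    static bool ViaShiftAndTrunc(uint32_t X) {
      return (((X & 8u) >> 3) & 1u) != 0;
    }

    static bool ViaCompare(uint32_t X) {
      return (X & 8u) != 0;
    }

    int main() {
      for (uint32_t X = 0; X != 32; ++X)
        assert(ViaShiftAndTrunc(X) == ViaCompare(X));
      return 0;
    }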