Introduce a new VTSDNode class with the ultimate goal of eliminating the
MVTSDNode class. A VTSDNode wraps a value type in a node so that it can be
passed as an ordinary operand to operators that require an extra type. We
start by converting FP_ROUND_INREG and SIGN_EXTEND_INREG over to using it.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@22364 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index a9721e2..c6fc076 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -539,7 +539,7 @@
SDOperand ValRes;
if (Node->getOpcode() == ISD::SEXTLOAD)
ValRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, Result.getValueType(),
- Result, SrcVT);
+ Result, DAG.getValueType(SrcVT));
else
ValRes = DAG.getZeroExtendInReg(Result, SrcVT);
AddLegalizedOperand(SDOperand(Node, 0), ValRes);
@@ -808,8 +808,10 @@
case ISD::SETGT:
case ISD::SETLT:
case ISD::SETLE:
- Tmp1 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp1, VT);
- Tmp2 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp2, VT);
+ Tmp1 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp1,
+ DAG.getValueType(VT));
+ Tmp2 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp2,
+ DAG.getValueType(VT));
break;
}
@@ -1403,7 +1405,8 @@
// NOTE: Any extend would work here...
Result = DAG.getNode(ISD::ZERO_EXTEND, Op.getValueType(), Result);
Result = DAG.getNode(ISD::SIGN_EXTEND_INREG, Result.getValueType(),
- Result, Node->getOperand(0).getValueType());
+ Result,
+ DAG.getValueType(Node->getOperand(0).getValueType()));
break;
case ISD::TRUNCATE:
Result = PromoteOp(Node->getOperand(0));
@@ -1424,7 +1427,8 @@
case ISD::SINT_TO_FP:
Result = PromoteOp(Node->getOperand(0));
Result = DAG.getNode(ISD::SIGN_EXTEND_INREG, Result.getValueType(),
- Result, Node->getOperand(0).getValueType());
+ Result,
+ DAG.getValueType(Node->getOperand(0).getValueType()));
Result = DAG.getNode(ISD::SINT_TO_FP, Op.getValueType(), Result);
break;
case ISD::UINT_TO_FP:
@@ -1439,7 +1443,7 @@
case ISD::FP_ROUND_INREG:
case ISD::SIGN_EXTEND_INREG: {
Tmp1 = LegalizeOp(Node->getOperand(0));
- MVT::ValueType ExtraVT = cast<MVTSDNode>(Node)->getExtraValueType();
+ MVT::ValueType ExtraVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
// If this operation is not supported, convert it to a shl/shr or load/store
// pair.
@@ -1593,7 +1597,7 @@
// The high bits are not guaranteed to be anything. Insert an extend.
if (Node->getOpcode() == ISD::SIGN_EXTEND)
Result = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Result,
- Node->getOperand(0).getValueType());
+ DAG.getValueType(Node->getOperand(0).getValueType()));
else
Result = DAG.getZeroExtendInReg(Result,
Node->getOperand(0).getValueType());
@@ -1610,7 +1614,8 @@
case Legal:
// Input is legal? Do an FP_ROUND_INREG.
Result = LegalizeOp(Node->getOperand(0));
- Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result, VT);
+ Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result,
+ DAG.getValueType(VT));
break;
}
break;
@@ -1628,7 +1633,8 @@
Result = PromoteOp(Node->getOperand(0));
if (Node->getOpcode() == ISD::SINT_TO_FP)
Result = DAG.getNode(ISD::SIGN_EXTEND_INREG, Result.getValueType(),
- Result, Node->getOperand(0).getValueType());
+ Result,
+ DAG.getValueType(Node->getOperand(0).getValueType()));
else
Result = DAG.getZeroExtendInReg(Result,
Node->getOperand(0).getValueType());
@@ -1640,7 +1646,8 @@
Node->getOperand(0));
// Round if we cannot tolerate excess precision.
if (NoExcessFPPrecision)
- Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result, VT);
+ Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result,
+ DAG.getValueType(VT));
break;
}
break;
@@ -1679,7 +1686,8 @@
assert(Tmp1.getValueType() == NVT);
Result = DAG.getNode(Node->getOpcode(), NVT, Tmp1);
if(NoExcessFPPrecision)
- Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result, VT);
+ Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result,
+ DAG.getValueType(VT));
break;
case ISD::AND:
@@ -1702,7 +1710,8 @@
// FIXME: Why would we need to round FP ops more than integer ones?
// Is Round(Add(Add(A,B),C)) != Round(Add(Round(Add(A,B)), C))
if (MVT::isFloatingPoint(NVT) && NoExcessFPPrecision)
- Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result, VT);
+ Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result,
+ DAG.getValueType(VT));
break;
case ISD::SDIV:
@@ -1711,14 +1720,17 @@
Tmp1 = PromoteOp(Node->getOperand(0));
Tmp2 = PromoteOp(Node->getOperand(1));
if (MVT::isInteger(NVT)) {
- Tmp1 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp1, VT);
- Tmp2 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp2, VT);
+ Tmp1 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp1,
+ DAG.getValueType(VT));
+ Tmp2 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp2,
+ DAG.getValueType(VT));
}
Result = DAG.getNode(Node->getOpcode(), NVT, Tmp1, Tmp2);
// Perform FP_ROUND: this is probably overly pessimistic.
if (MVT::isFloatingPoint(NVT) && NoExcessFPPrecision)
- Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result, VT);
+ Result = DAG.getNode(ISD::FP_ROUND_INREG, NVT, Result,
+ DAG.getValueType(VT));
break;
case ISD::UDIV:
@@ -1740,7 +1752,8 @@
case ISD::SRA:
// The input value must be properly sign extended.
Tmp1 = PromoteOp(Node->getOperand(0));
- Tmp1 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp1, VT);
+ Tmp1 = DAG.getNode(ISD::SIGN_EXTEND_INREG, NVT, Tmp1,
+ DAG.getValueType(VT));
Tmp2 = LegalizeOp(Node->getOperand(1));
Result = DAG.getNode(ISD::SRA, NVT, Tmp1, Tmp2);
break;
@@ -2520,7 +2533,7 @@
In = PromoteOp(Node->getOperand(0));
// Emit the appropriate sign_extend_inreg to get the value we want.
In = DAG.getNode(ISD::SIGN_EXTEND_INREG, In.getValueType(), In,
- Node->getOperand(0).getValueType());
+ DAG.getValueType(Node->getOperand(0).getValueType()));
break;
}
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 8a34a59..9cf6e0f 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -220,7 +220,9 @@
case ISD::ExternalSymbol:
ExternalSymbols.erase(cast<ExternalSymbolSDNode>(N)->getSymbol());
break;
-
+ case ISD::VALUETYPE:
+ ValueTypeNodes[cast<VTSDNode>(N)->getVT()] = 0;
+ break;
case ISD::LOAD:
Loads.erase(std::make_pair(N->getOperand(1),
std::make_pair(N->getOperand(0),
@@ -234,8 +236,6 @@
N->getValueType(0))));
break;
case ISD::TRUNCSTORE:
- case ISD::SIGN_EXTEND_INREG:
- case ISD::FP_ROUND_INREG:
case ISD::EXTLOAD:
case ISD::SEXTLOAD:
case ISD::ZEXTLOAD: {
@@ -374,6 +374,17 @@
return SDOperand(N, 0);
}
+SDOperand SelectionDAG::getValueType(MVT::ValueType VT) { // Return the VTSDNode for VT, creating it on first request.
+ if ((unsigned)VT >= ValueTypeNodes.size()) // Table is indexed by VT; grow it so VT is a valid index.
+ ValueTypeNodes.resize(VT+1);
+ if (ValueTypeNodes[VT] == 0) { // A null slot means no node exists for this type yet.
+ ValueTypeNodes[VT] = new VTSDNode(VT);
+ AllNodes.push_back(ValueTypeNodes[VT]); // Record the new node in AllNodes as well as in the per-type table.
+ }
+
+ return SDOperand(ValueTypeNodes[VT], 0); // Hand back the per-type node wrapped in an SDOperand.
+}
+
SDOperand SelectionDAG::getExternalSymbol(const char *Sym, MVT::ValueType VT) {
SDNode *&N = ExternalSymbols[Sym];
if (N) return SDOperand(N, 0);
@@ -864,6 +875,22 @@
assert(MVT::isInteger(VT) && MVT::isInteger(N2.getValueType()) &&
VT != MVT::i1 && "Shifts only work on integers");
break;
+ case ISD::FP_ROUND_INREG: {
+ MVT::ValueType EVT = cast<VTSDNode>(N2)->getVT();
+ assert(VT == N1.getValueType() && "Not an inreg round!");
+ assert(MVT::isFloatingPoint(VT) && MVT::isFloatingPoint(EVT) &&
+ "Cannot FP_ROUND_INREG integer types");
+ assert(EVT <= VT && "Not rounding down!");
+ break;
+ }
+ case ISD::SIGN_EXTEND_INREG: {
+ MVT::ValueType EVT = cast<VTSDNode>(N2)->getVT();
+ assert(VT == N1.getValueType() && "Not an inreg extend!");
+ assert(MVT::isInteger(VT) && MVT::isInteger(EVT) &&
+ "Cannot *_EXTEND_INREG FP types");
+ assert(EVT <= VT && "Not extending!");
+ }
+
default: break;
}
#endif
@@ -918,6 +945,10 @@
case ISD::SRA: // sra -1, X -> -1
if (N1C->isAllOnesValue()) return N1;
break;
+ case ISD::SIGN_EXTEND_INREG: // SIGN_EXTEND_INREG N1C, EVT
+ // Extending a constant? Just return the extended constant.
+ SDOperand Tmp = getNode(ISD::TRUNCATE, cast<VTSDNode>(N2)->getVT(), N1);
+ return getNode(ISD::SIGN_EXTEND, VT, Tmp);
}
}
@@ -1026,7 +1057,7 @@
// If we are masking out the part of our input that was extended, just
// mask the input to the extension directly.
unsigned ExtendBits =
- MVT::getSizeInBits(cast<MVTSDNode>(N1)->getExtraValueType());
+ MVT::getSizeInBits(cast<VTSDNode>(N1.getOperand(1))->getVT());
if ((C2 & (~0ULL << ExtendBits)) == 0)
return getNode(ISD::AND, VT, N1.getOperand(0), N2);
}
@@ -1072,7 +1103,7 @@
ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1.Val);
ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(N2.Val);
- if (N1CFP)
+ if (N1CFP) {
if (N2CFP) {
double C1 = N1CFP->getValue(), C2 = N2CFP->getValue();
switch (Opcode) {
@@ -1095,6 +1126,11 @@
}
}
+ if (Opcode == ISD::FP_ROUND_INREG)
+ return getNode(ISD::FP_EXTEND, VT,
+ getNode(ISD::FP_ROUND, cast<VTSDNode>(N2)->getVT(), N1));
+ }
+
// Finally, fold operations that do not require constants.
switch (Opcode) {
case ISD::TokenFactor:
@@ -1199,6 +1235,42 @@
if (N2.getOpcode() == ISD::FNEG) // (A- (-B) -> A+B
return getNode(ISD::ADD, VT, N1, N2.getOperand(0));
break;
+ case ISD::FP_ROUND_INREG:
+ if (cast<VTSDNode>(N2)->getVT() == VT) return N1; // Not actually rounding.
+ break;
+ case ISD::SIGN_EXTEND_INREG: {
+ MVT::ValueType EVT = cast<VTSDNode>(N2)->getVT();
+ if (EVT == VT) return N1; // Not actually extending
+
+ // If we are sign extending an extension, use the original source.
+ if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG)
+ if (cast<VTSDNode>(N1.getOperand(1))->getVT() <= EVT)
+ return N1;
+
+ // If we are sign extending a sextload, return just the load.
+ if (N1.getOpcode() == ISD::SEXTLOAD)
+ if (cast<MVTSDNode>(N1)->getExtraValueType() <= EVT)
+ return N1;
+
+ // If we are extending the result of a setcc, and we already know the
+ // contents of the top bits, eliminate the extension.
+ if (N1.getOpcode() == ISD::SETCC &&
+ TLI.getSetCCResultContents() ==
+ TargetLowering::ZeroOrNegativeOneSetCCResult)
+ return N1;
+
+ // If we are sign extending the result of an (and X, C) operation, and we
+ // know the extended bits are zeros already, don't do the extend.
+ if (N1.getOpcode() == ISD::AND)
+ if (ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
+ uint64_t Mask = N1C->getValue();
+ unsigned NumBits = MVT::getSizeInBits(EVT);
+ if ((Mask & (~0ULL << (NumBits-1))) == 0)
+ return N1;
+ }
+ break;
+ }
+
// FIXME: figure out how to safely handle things like
// int foo(int x) { return 1 << (x & 255); }
// int bar() { return foo(256); }
@@ -1207,7 +1279,7 @@
case ISD::SRL:
case ISD::SRA:
if (N2.getOpcode() == ISD::SIGN_EXTEND_INREG &&
- cast<MVTSDNode>(N2)->getExtraValueType() != MVT::i1)
+ cast<VTSDNode>(N2.getOperand(1))->getVT() != MVT::i1)
return getNode(Opcode, VT, N1, N2.getOperand(0));
else if (N2.getOpcode() == ISD::AND)
if (ConstantSDNode *AndRHS = dyn_cast<ConstantSDNode>(N2.getOperand(1))) {
@@ -1450,7 +1522,7 @@
case ISD::SRL_PARTS:
case ISD::SHL_PARTS:
if (N3.getOpcode() == ISD::SIGN_EXTEND_INREG &&
- cast<MVTSDNode>(N3)->getExtraValueType() != MVT::i1)
+ cast<VTSDNode>(N3.getOperand(1))->getVT() != MVT::i1)
return getNode(Opcode, VT, N1, N2, N3.getOperand(0));
else if (N3.getOpcode() == ISD::AND)
if (ConstantSDNode *AndRHS = dyn_cast<ConstantSDNode>(N3.getOperand(1))) {
@@ -1477,61 +1549,6 @@
SDOperand SelectionDAG::getNode(unsigned Opcode, MVT::ValueType VT,SDOperand N1,
MVT::ValueType EVT) {
-
- switch (Opcode) {
- default: assert(0 && "Bad opcode for this accessor!");
- case ISD::FP_ROUND_INREG:
- assert(VT == N1.getValueType() && "Not an inreg round!");
- assert(MVT::isFloatingPoint(VT) && MVT::isFloatingPoint(EVT) &&
- "Cannot FP_ROUND_INREG integer types");
- if (EVT == VT) return N1; // Not actually rounding
- assert(EVT < VT && "Not rounding down!");
-
- if (isa<ConstantFPSDNode>(N1))
- return getNode(ISD::FP_EXTEND, VT, getNode(ISD::FP_ROUND, EVT, N1));
- break;
- case ISD::SIGN_EXTEND_INREG:
- assert(VT == N1.getValueType() && "Not an inreg extend!");
- assert(MVT::isInteger(VT) && MVT::isInteger(EVT) &&
- "Cannot *_EXTEND_INREG FP types");
- if (EVT == VT) return N1; // Not actually extending
- assert(EVT < VT && "Not extending!");
-
- // Extending a constant? Just return the extended constant.
- if (ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.Val)) {
- SDOperand Tmp = getNode(ISD::TRUNCATE, EVT, N1);
- return getNode(ISD::SIGN_EXTEND, VT, Tmp);
- }
-
- // If we are sign extending an extension, use the original source.
- if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG)
- if (cast<MVTSDNode>(N1)->getExtraValueType() <= EVT)
- return N1;
-
- // If we are sign extending a sextload, return just the load.
- if (N1.getOpcode() == ISD::SEXTLOAD && Opcode == ISD::SIGN_EXTEND_INREG)
- if (cast<MVTSDNode>(N1)->getExtraValueType() <= EVT)
- return N1;
-
- // If we are extending the result of a setcc, and we already know the
- // contents of the top bits, eliminate the extension.
- if (N1.getOpcode() == ISD::SETCC &&
- TLI.getSetCCResultContents() ==
- TargetLowering::ZeroOrNegativeOneSetCCResult)
- return N1;
-
- // If we are sign extending the result of an (and X, C) operation, and we
- // know the extended bits are zeros already, don't do the extend.
- if (N1.getOpcode() == ISD::AND)
- if (ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
- uint64_t Mask = N1C->getValue();
- unsigned NumBits = MVT::getSizeInBits(EVT);
- if ((Mask & (~0ULL << (NumBits-1))) == 0)
- return N1;
- }
- break;
- }
-
EVTStruct NN;
NN.Opcode = Opcode;
NN.VT = VT;