Generalize BuildVectorSDNode::isConstantSplat to use APInts and handle
arbitrary vector sizes. Add an optional MinSplatBits parameter to specify
a minimum for the splat element size. Update the PPC target to use the
revised interface.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@65899 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 9dfd6be..886e726 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5555,93 +5555,64 @@
return Val.ConstVal->getType();
}
-// If this is a splat (repetition) of a value across the whole vector, return
-// the smallest size that splats it. For example, "0x01010101010101..." is a
-// splat of 0x01, 0x0101, and 0x01010101. We return SplatBits = 0x01 and
-// SplatSize = 1 byte.
-bool BuildVectorSDNode::isConstantSplat(unsigned &SplatBits,
- unsigned &SplatUndef,
- unsigned &SplatSize,
- bool &HasAnyUndefs) {
- uint64_t Bits128[2];
- uint64_t Undef128[2];
+bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue,
+ APInt &SplatUndef,
+ unsigned &SplatBitSize,
+ bool &HasAnyUndefs,
+ unsigned MinSplatBits) {
+ MVT VT = getValueType(0);
+ assert(VT.isVector() && "Expected a vector type");
+ unsigned sz = VT.getSizeInBits();
+ if (MinSplatBits > sz)
+ return false;
- // If this is a vector of constants or undefs, get the bits. A bit in
- // UndefBits is set if the corresponding element of the vector is an
- // ISD::UNDEF value. For undefs, the corresponding VectorBits values are
- // zero.
+ SplatValue = APInt(sz, 0);
+ SplatUndef = APInt(sz, 0);
- // Start with zero'd results.
- Bits128[0] = Bits128[1] = Undef128[0] = Undef128[1] = 0;
-
- unsigned EltBitSize = getOperand(0).getValueType().getSizeInBits();
- for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
+ // Get the bits. Bits with undefined values (when the corresponding element
+ // of the vector is an ISD::UNDEF value) are set in SplatUndef and cleared
+ // in SplatValue. If any of the values are not constant, give up and return
+ // false.
+ unsigned int nOps = getNumOperands();
+ assert(nOps > 0 && "isConstantSplat has 0-size build vector");
+ unsigned EltBitSize = VT.getVectorElementType().getSizeInBits();
+ for (unsigned i = 0; i < nOps; ++i) {
SDValue OpVal = getOperand(i);
+ unsigned BitPos = i * EltBitSize;
- unsigned PartNo = i >= e/2; // In the upper 128 bits?
- unsigned SlotNo = e/2 - (i & (e/2-1))-1; // Which subpiece of the uint64_t.
-
- uint64_t EltBits = 0;
- if (OpVal.getOpcode() == ISD::UNDEF) {
- uint64_t EltUndefBits = ~0U >> (32-EltBitSize);
- Undef128[PartNo] |= EltUndefBits << (SlotNo*EltBitSize);
- continue;
- } else if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(OpVal)) {
- EltBits = CN->getZExtValue() & (~0U >> (32-EltBitSize));
- } else if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(OpVal)) {
- assert(CN->getValueType(0) == MVT::f32 &&
- "Only one legal FP vector type!");
- EltBits = FloatToBits(CN->getValueAPF().convertToFloat());
- } else {
- // Nonconstant element.
+ if (OpVal.getOpcode() == ISD::UNDEF)
+ SplatUndef |= APInt::getBitsSet(sz, BitPos, BitPos +EltBitSize);
+ else if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(OpVal))
+ SplatValue |= APInt(CN->getAPIntValue()).zext(sz) << BitPos;
+ else if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(OpVal))
+ SplatValue |= CN->getValueAPF().bitcastToAPInt().zext(sz) << BitPos;
+ else
return false;
- }
-
- Bits128[PartNo] |= EltBits << (SlotNo*EltBitSize);
}
- // Don't let undefs prevent splats from matching. See if the top 64-bits are
- // the same as the lower 64-bits, ignoring undefs.
- if ((Bits128[0] & ~Undef128[1]) != (Bits128[1] & ~Undef128[0]))
- return false; // Can't be a splat if two pieces don't match.
+ // The build_vector is all constants or undefs. Find the smallest element
+ // size, no smaller than MinSplatBits, that splats the vector.
- uint64_t Bits64 = Bits128[0] | Bits128[1];
- uint64_t Undef64 = Undef128[0] & Undef128[1];
+ HasAnyUndefs = (SplatUndef != 0);
+ while (sz > 8) {
- // Check that the top 32-bits are the same as the lower 32-bits, ignoring
- // undefs.
- if ((Bits64 & (~Undef64 >> 32)) != ((Bits64 >> 32) & ~Undef64))
- return false; // Can't be a splat if two pieces don't match.
+ unsigned HalfSize = sz / 2;
+ APInt HighValue = APInt(SplatValue).lshr(HalfSize).trunc(HalfSize);
+ APInt LowValue = APInt(SplatValue).trunc(HalfSize);
+ APInt HighUndef = APInt(SplatUndef).lshr(HalfSize).trunc(HalfSize);
+ APInt LowUndef = APInt(SplatUndef).trunc(HalfSize);
- HasAnyUndefs = (Undef128[0] | Undef128[1]) != 0;
+ // Stop if halves differ (ignoring undefs) or HalfSize < MinSplatBits.
+ if ((HighValue & ~LowUndef) != (LowValue & ~HighUndef) ||
+ MinSplatBits > HalfSize)
+ break;
- uint32_t Bits32 = uint32_t(Bits64) | uint32_t(Bits64 >> 32);
- uint32_t Undef32 = uint32_t(Undef64) & uint32_t(Undef64 >> 32);
-
- // If the top 16-bits are different than the lower 16-bits, ignoring
- // undefs, we have an i32 splat.
- if ((Bits32 & (~Undef32 >> 16)) != ((Bits32 >> 16) & ~Undef32)) {
- SplatBits = Bits32;
- SplatUndef = Undef32;
- SplatSize = 4;
- return true;
+ SplatValue = HighValue | LowValue;
+ SplatUndef = HighUndef & LowUndef;
+
+ sz = HalfSize;
}
- uint16_t Bits16 = uint16_t(Bits32) | uint16_t(Bits32 >> 16);
- uint16_t Undef16 = uint16_t(Undef32) & uint16_t(Undef32 >> 16);
-
- // If the top 8-bits are different than the lower 8-bits, ignoring
- // undefs, we have an i16 splat.
- if ((Bits16 & (uint16_t(~Undef16) >> 8)) != ((Bits16 >> 8) & ~Undef16)) {
- SplatBits = Bits16;
- SplatUndef = Undef16;
- SplatSize = 2;
- return true;
- }
-
- // Otherwise, we have an 8-bit splat.
- SplatBits = uint8_t(Bits16) | uint8_t(Bits16 >> 8);
- SplatUndef = uint8_t(Undef16) & uint8_t(Undef16 >> 8);
- SplatSize = 1;
+ SplatBitSize = sz;
return true;
}