[ExpandMemCmp] Properly constant-fold all compares.
Summary:
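This factors the load / byte-swap / zero-extend sequence that was duplicated
across the expansion paths into a single helper, `getLoadPair`. Previously
only one path (the one that combines the loaded values with xor/or, or with
`icmp ne`) tried to constant-fold loads from constant pointers; the
byte-block, per-load-block and one-block paths always emitted loads. All
paths now treat constants the same way.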
Fixes PR45086.
Subscribers: hiraditya, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D75519
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp
index d0dd538..589a6d3 100644
--- a/llvm/lib/CodeGen/ExpandMemCmp.cpp
+++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp
@@ -103,8 +103,12 @@
Value *getMemCmpExpansionZeroCase();
Value *getMemCmpEqZeroOneBlock();
Value *getMemCmpOneBlock();
- Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
- uint64_t OffsetBytes);
+ struct LoadPair {
+ Value *Lhs = nullptr;
+ Value *Rhs = nullptr;
+ };
+ LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
+ unsigned OffsetBytes);
static LoadEntryVector
computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
@@ -261,18 +265,52 @@
EndBlock->getParent(), EndBlock);
}
-/// Return a pointer to an element of type `LoadSizeType` at offset
-/// `OffsetBytes`.
-Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
- Type *LoadSizeType,
- uint64_t OffsetBytes) {
+MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
+ bool NeedsBSwap,
+ Type *CmpSizeType,
+ unsigned OffsetBytes) {
+ // Get the memory source at offset `OffsetBytes`.
+ Value *LhsSource = CI->getArgOperand(0);
+ Value *RhsSource = CI->getArgOperand(1);
if (OffsetBytes > 0) {
auto *ByteType = Type::getInt8Ty(CI->getContext());
- Source = Builder.CreateConstGEP1_64(
- ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
+ LhsSource = Builder.CreateConstGEP1_64(
+ ByteType, Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()),
+ OffsetBytes);
+ RhsSource = Builder.CreateConstGEP1_64(
+ ByteType, Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()),
OffsetBytes);
}
- return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
+ LhsSource = Builder.CreateBitCast(LhsSource, LoadSizeType->getPointerTo());
+ RhsSource = Builder.CreateBitCast(RhsSource, LoadSizeType->getPointerTo());
+
+ // Create a constant or a load from the source.
+ Value *Lhs = nullptr;
+ if (auto *C = dyn_cast<Constant>(LhsSource))
+ Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
+ if (!Lhs)
+ Lhs = Builder.CreateLoad(LoadSizeType, LhsSource);
+
+ Value *Rhs = nullptr;
+ if (auto *C = dyn_cast<Constant>(RhsSource))
+ Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
+ if (!Rhs)
+ Rhs = Builder.CreateLoad(LoadSizeType, RhsSource);
+
+ // Swap bytes if required.
+ if (NeedsBSwap) {
+ Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
+ Intrinsic::bswap, LoadSizeType);
+ Lhs = Builder.CreateCall(Bswap, Lhs);
+ Rhs = Builder.CreateCall(Bswap, Rhs);
+ }
+
+ // Zero extend if required.
+ if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
+ Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
+ Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
+ }
+ return {Lhs, Rhs};
}
// This function creates the IR instructions for loading and comparing 1 byte.
@@ -282,18 +320,10 @@
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
unsigned OffsetBytes) {
Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
- Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
- Value *Source1 =
- getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
- Value *Source2 =
- getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);
-
- Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
- Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
-
- LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
- LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
- Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);
+ const LoadPair Loads =
+ getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
+ Type::getInt32Ty(CI->getContext()), OffsetBytes);
+ Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);
PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);
@@ -340,41 +370,19 @@
: IntegerType::get(CI->getContext(), MaxLoadSize * 8);
for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
-
- IntegerType *LoadSizeType =
- IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
-
- Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
- CurLoadEntry.Offset);
- Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
- CurLoadEntry.Offset);
-
- // Get a constant or load a value for each source address.
- Value *LoadSrc1 = nullptr;
- if (auto *Source1C = dyn_cast<Constant>(Source1))
- LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
- if (!LoadSrc1)
- LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
-
- Value *LoadSrc2 = nullptr;
- if (auto *Source2C = dyn_cast<Constant>(Source2))
- LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
- if (!LoadSrc2)
- LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
+ const LoadPair Loads = getLoadPair(
+ IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
+ /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);
if (NumLoads != 1) {
- if (LoadSizeType != MaxLoadType) {
- LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
- LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
- }
// If we have multiple loads per block, we need to generate a composite
// comparison using xor+or.
- Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
+ Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs);
Diff = Builder.CreateZExt(Diff, MaxLoadType);
XorList.push_back(Diff);
} else {
// If there's only one load per block, we just compare the loaded values.
- Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
+ Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs);
}
}
@@ -451,35 +459,18 @@
Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
- Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
- CurLoadEntry.Offset);
- Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
- CurLoadEntry.Offset);
-
- // Load LoadSizeType from the base address.
- Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
- Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
-
- if (DL.isLittleEndian()) {
- Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
- Intrinsic::bswap, LoadSizeType);
- LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
- LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
- }
-
- if (LoadSizeType != MaxLoadType) {
- LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
- LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
- }
+ const LoadPair Loads =
+ getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
+ CurLoadEntry.Offset);
// Add the loaded values to the phi nodes for calculating memcmp result only
// if result is not used in a zero equality.
if (!IsUsedForZeroCmp) {
- ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
- ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
+ ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]);
+ ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]);
}
- Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
+ Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs);
BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
? EndBlock
: LoadCmpBlocks[BlockIndex + 1];
@@ -568,42 +559,27 @@
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
- Value *Source1 = CI->getArgOperand(0);
- Value *Source2 = CI->getArgOperand(1);
+ bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
- // Cast source to LoadSizeType*.
- if (Source1->getType() != LoadSizeType)
- Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
- if (Source2->getType() != LoadSizeType)
- Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
-
- // Load LoadSizeType from the base address.
- Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
- Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
-
- if (DL.isLittleEndian() && Size != 1) {
- Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
- Intrinsic::bswap, LoadSizeType);
- LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
- LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
- }
-
+ // The i8 and i16 cases don't need compares. We zext the loaded values and
+ // subtract them to get the suitable negative, zero, or positive i32 result.
if (Size < 4) {
- // The i8 and i16 cases don't need compares. We zext the loaded values and
- // subtract them to get the suitable negative, zero, or positive i32 result.
- LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
- LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
- return Builder.CreateSub(LoadSrc1, LoadSrc2);
+ const LoadPair Loads =
+ getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
+ /*Offset*/ 0);
+ return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
}
+ const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
+ /*Offset*/ 0);
// The result of memcmp is negative, zero, or positive, so produce that by
// subtracting 2 extended compare bits: sub (ugt, ult).
// If a target prefers to use selects to get -1/0/1, they should be able
// to transform this later. The inverse transform (going from selects to math)
// may not be possible in the DAG because the selects got converted into
// branches before we got there.
- Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
- Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
+ Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs);
+ Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs);
Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
return Builder.CreateSub(ZextUGT, ZextULT);