[ExpandMemCmp] Properly constant-fold all compares.

Summary:
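Factor the duplicated address computation, load, byte-swap and zero-extend
code into a single helper, getLoadPair(), so that every expansion path
constant-folds loads from constant memory. Previously only the
multiple-loads-per-block path did this, so behaviour diverged depending on
which expansion was chosen when an operand was a constant.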
Fixes PR45086.

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75519
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp
index d0dd538..589a6d3 100644
--- a/llvm/lib/CodeGen/ExpandMemCmp.cpp
+++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp
@@ -103,8 +103,12 @@
   Value *getMemCmpExpansionZeroCase();
   Value *getMemCmpEqZeroOneBlock();
   Value *getMemCmpOneBlock();
-  Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
-                                 uint64_t OffsetBytes);
+  struct LoadPair {
+    Value *Lhs = nullptr; // Value for the first memcmp operand.
+    Value *Rhs = nullptr; // Value for the second memcmp operand.
+  };
+  LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
+                       uint64_t OffsetBytes);
 
   static LoadEntryVector
   computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
@@ -261,18 +265,52 @@
                                    EndBlock->getParent(), EndBlock);
 }
 
-/// Return a pointer to an element of type `LoadSizeType` at offset
-/// `OffsetBytes`.
-Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
-                                                Type *LoadSizeType,
-                                                uint64_t OffsetBytes) {
+MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
+                                                       bool NeedsBSwap,
+                                                       Type *CmpSizeType,
+                                                       uint64_t OffsetBytes) {
+  // Compute the address `OffsetBytes` bytes into each memcmp operand.
+  Value *LhsSource = CI->getArgOperand(0);
+  Value *RhsSource = CI->getArgOperand(1);
   if (OffsetBytes > 0) {
     auto *ByteType = Type::getInt8Ty(CI->getContext());
-    Source = Builder.CreateConstGEP1_64(
-        ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
+    LhsSource = Builder.CreateConstGEP1_64(
+        ByteType, Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()),
+        OffsetBytes);
+    RhsSource = Builder.CreateConstGEP1_64(
+        ByteType, Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()),
         OffsetBytes);
   }
-  return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
+  LhsSource = Builder.CreateBitCast(LhsSource, LoadSizeType->getPointerTo());
+  RhsSource = Builder.CreateBitCast(RhsSource, LoadSizeType->getPointerTo());
+
+  // Constant-fold loads from constant memory; otherwise emit a real load.
+  Value *Lhs = nullptr;
+  if (auto *C = dyn_cast<Constant>(LhsSource))
+    Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
+  if (!Lhs)
+    Lhs = Builder.CreateLoad(LoadSizeType, LhsSource);
+
+  Value *Rhs = nullptr;
+  if (auto *C = dyn_cast<Constant>(RhsSource))
+    Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
+  if (!Rhs)
+    Rhs = Builder.CreateLoad(LoadSizeType, RhsSource);
+
+  // Byte-swap on request so that integer order matches memcmp's byte order.
+  if (NeedsBSwap) {
+    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
+                                                Intrinsic::bswap, LoadSizeType);
+    Lhs = Builder.CreateCall(Bswap, Lhs);
+    Rhs = Builder.CreateCall(Bswap, Rhs);
+  }
+
+  // Zero-extend to CmpSizeType when it is set and differs from LoadSizeType.
+  if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
+    Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
+    Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
+  }
+  return {Lhs, Rhs};
 }
 
 // This function creates the IR instructions for loading and comparing 1 byte.
@@ -282,18 +320,10 @@
 void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                                unsigned OffsetBytes) {
   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
-  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
-  Value *Source1 =
-      getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
-  Value *Source2 =
-      getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);
-
-  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
-  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
-
-  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
-  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
-  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);
+  const LoadPair Loads =
+      getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
+                  Type::getInt32Ty(CI->getContext()), OffsetBytes);
+  Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);
 
   PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);
 
@@ -340,41 +370,19 @@
                     : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
   for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
     const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
-
-    IntegerType *LoadSizeType =
-        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
-
-    Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
-                                             CurLoadEntry.Offset);
-    Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
-                                             CurLoadEntry.Offset);
-
-    // Get a constant or load a value for each source address.
-    Value *LoadSrc1 = nullptr;
-    if (auto *Source1C = dyn_cast<Constant>(Source1))
-      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
-    if (!LoadSrc1)
-      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
-
-    Value *LoadSrc2 = nullptr;
-    if (auto *Source2C = dyn_cast<Constant>(Source2))
-      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
-    if (!LoadSrc2)
-      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
+    const LoadPair Loads = getLoadPair(
+        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
+        /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);
 
     if (NumLoads != 1) {
-      if (LoadSizeType != MaxLoadType) {
-        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
-        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
-      }
       // If we have multiple loads per block, we need to generate a composite
       // comparison using xor+or.
-      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
+      Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs);
       Diff = Builder.CreateZExt(Diff, MaxLoadType);
       XorList.push_back(Diff);
     } else {
       // If there's only one load per block, we just compare the loaded values.
-      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
+      Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs);
     }
   }
 
@@ -451,35 +459,18 @@
 
   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
 
-  Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
-                                           CurLoadEntry.Offset);
-  Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
-                                           CurLoadEntry.Offset);
-
-  // Load LoadSizeType from the base address.
-  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
-  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
-
-  if (DL.isLittleEndian()) {
-    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
-                                                Intrinsic::bswap, LoadSizeType);
-    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
-    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
-  }
-
-  if (LoadSizeType != MaxLoadType) {
-    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
-    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
-  }
+  const LoadPair Loads =
+      getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
+                  CurLoadEntry.Offset);
 
   // Add the loaded values to the phi nodes for calculating memcmp result only
   // if result is not used in a zero equality.
   if (!IsUsedForZeroCmp) {
-    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
-    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
+    ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]);
+    ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]);
   }
 
-  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
+  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs);
   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                            ? EndBlock
                            : LoadCmpBlocks[BlockIndex + 1];
@@ -568,42 +559,27 @@
 /// the compare, branch, and phi IR that is required in the general case.
 Value *MemCmpExpansion::getMemCmpOneBlock() {
   Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
-  Value *Source1 = CI->getArgOperand(0);
-  Value *Source2 = CI->getArgOperand(1);
+  bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
 
-  // Cast source to LoadSizeType*.
-  if (Source1->getType() != LoadSizeType)
-    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
-  if (Source2->getType() != LoadSizeType)
-    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
-
-  // Load LoadSizeType from the base address.
-  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
-  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
-
-  if (DL.isLittleEndian() && Size != 1) {
-    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
-                                                Intrinsic::bswap, LoadSizeType);
-    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
-    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
-  }
-
+  // The i8 and i16 cases don't need compares. We zext the loaded values and
+  // subtract them to get the suitable negative, zero, or positive i32 result.
   if (Size < 4) {
-    // The i8 and i16 cases don't need compares. We zext the loaded values and
-    // subtract them to get the suitable negative, zero, or positive i32 result.
-    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
-    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
-    return Builder.CreateSub(LoadSrc1, LoadSrc2);
+    const LoadPair Loads =
+        getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
+                    /*Offset*/ 0);
+    return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
   }
 
+  const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
+                                     /*Offset*/ 0);
   // The result of memcmp is negative, zero, or positive, so produce that by
   // subtracting 2 extended compare bits: sub (ugt, ult).
   // If a target prefers to use selects to get -1/0/1, they should be able
   // to transform this later. The inverse transform (going from selects to math)
   // may not be possible in the DAG because the selects got converted into
   // branches before we got there.
-  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
-  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
+  Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs);
+  Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs);
   Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
   Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
   return Builder.CreateSub(ZextUGT, ZextULT);