Re-land r349731 "[CodeGen][ExpandMemcmp] Add an option for allowing overlapping loads."

Update PPC IR following GEP->bitcast to bitcast->GEP->bitcast change.

llvm-svn: 349747
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp
index d7562cb..ee7683a 100644
--- a/llvm/lib/CodeGen/ExpandMemCmp.cpp
+++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp
@@ -66,23 +66,18 @@
   // Represents the decomposition in blocks of the expansion. For example,
   // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
   // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {32, 1}.
-  // TODO(courbet): Involve the target more in this computation. On X86, 7
-  // bytes can be done more efficiently with two overlaping 4-byte loads than
-  // covering the interval with [{4, 0},{2, 4},{1, 6}}.
   struct LoadEntry {
     LoadEntry(unsigned LoadSize, uint64_t Offset)
         : LoadSize(LoadSize), Offset(Offset) {
-      assert(Offset % LoadSize == 0 && "invalid load entry");
     }
 
-    uint64_t getGEPIndex() const { return Offset / LoadSize; }
-
     // The size of the load for this block, in bytes.
-    const unsigned LoadSize;
-    // The offset of this load WRT the base pointer, in bytes.
-    const uint64_t Offset;
+    unsigned LoadSize;
+    // The offset of this load from the base pointer, in bytes.
+    uint64_t Offset;
   };
-  SmallVector<LoadEntry, 8> LoadSequence;
+  using LoadEntryVector = SmallVector<LoadEntry, 8>;
+  LoadEntryVector LoadSequence;
 
   void createLoadCmpBlocks();
   void createResultBlock();
@@ -92,13 +87,23 @@
   void emitLoadCompareBlock(unsigned BlockIndex);
   void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                          unsigned &LoadIndex);
-  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
+  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
   void emitMemCmpResultBlock();
   Value *getMemCmpExpansionZeroCase();
   Value *getMemCmpEqZeroOneBlock();
   Value *getMemCmpOneBlock();
+  Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
+                                 uint64_t OffsetBytes);
 
- public:
+  static LoadEntryVector
+  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
+                            unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte);
+  static LoadEntryVector
+  computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
+                                 unsigned MaxNumLoads,
+                                 unsigned &NumLoadsNonOneByte);
+
+public:
   MemCmpExpansion(CallInst *CI, uint64_t Size,
                   const TargetTransformInfo::MemCmpExpansionOptions &Options,
                   unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
@@ -110,6 +115,76 @@
   Value *getMemCmpExpansion();
 };
 
+MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
+    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
+    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
+  NumLoadsNonOneByte = 0;
+  LoadEntryVector LoadSequence;
+  uint64_t Offset = 0;
+  while (Size && !LoadSizes.empty()) {
+    const unsigned LoadSize = LoadSizes.front();
+    const uint64_t NumLoadsForThisSize = Size / LoadSize;
+    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
+      // Return an empty sequence (no expansion) if this load size would push
+      // the total number of loads past MaxNumLoads. The check must come before
+      // the push_back loop below so that large sizes never materialize a huge
+      // sequence only to discard it.
+      return {};
+    }
+    if (NumLoadsForThisSize > 0) {
+      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
+        LoadSequence.push_back({LoadSize, Offset});
+        Offset += LoadSize;
+      }
+      if (LoadSize > 1)
+        ++NumLoadsNonOneByte;
+      Size = Size % LoadSize;
+    }
+    LoadSizes = LoadSizes.drop_front();
+  }
+  return LoadSequence;
+}
+
+MemCmpExpansion::LoadEntryVector
+MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
+                                                const unsigned MaxLoadSize,
+                                                const unsigned MaxNumLoads,
+                                                unsigned &NumLoadsNonOneByte) {
+  // Sizes/loads under 2 bytes are already handled by the greedy approach.
+  if (Size < 2 || MaxLoadSize < 2)
+    return {};
+
+  // We try to do as many non-overlapping loads as possible starting from the
+  // beginning.
+  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
+  assert(NumNonOverlappingLoads && "there must be at least one load");
+  // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
+  // an overlapping load.
+  Size = Size - NumNonOverlappingLoads * MaxLoadSize;
+  // Bail if we do not need an overlapping load, this is already handled by
+  // the greedy approach.
+  if (Size == 0)
+    return {};
+  // Bail if the number of loads (non-overlapping + potential overlapping one)
+  // is larger than the max allowed.
+  if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
+    return {};
+
+  // Add non-overlapping loads.
+  LoadEntryVector LoadSequence;
+  uint64_t Offset = 0;
+  for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
+    LoadSequence.push_back({MaxLoadSize, Offset});
+    Offset += MaxLoadSize;
+  }
+
+  // Add the last load, shifted back so that it ends on the last byte.
+  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
+  LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
+  NumLoadsNonOneByte = 1;
+  return LoadSequence;
+}
+
 // Initialize the basic block structure required for expansion of memcmp call
 // with given maximum load size and memcmp size parameter.
 // This structure includes:
@@ -133,38 +208,31 @@
       Builder(CI) {
   assert(Size > 0 && "zero blocks");
   // Scale the max size down if the target can load more bytes than we need.
-  size_t LoadSizeIndex = 0;
-  while (LoadSizeIndex < Options.LoadSizes.size() &&
-         Options.LoadSizes[LoadSizeIndex] > Size) {
-    ++LoadSizeIndex;
+  llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
+  while (!LoadSizes.empty() && LoadSizes.front() > Size) {
+    LoadSizes = LoadSizes.drop_front();
   }
-  this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
+  assert(!LoadSizes.empty() && "cannot load Size bytes");
+  MaxLoadSize = LoadSizes.front();
   // Compute the decomposition.
-  uint64_t CurSize = Size;
-  uint64_t Offset = 0;
-  while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
-    const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
-    assert(LoadSize > 0 && "zero load size");
-    const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
-    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
-      // Do not expand if the total number of loads is larger than what the
-      // target allows. Note that it's important that we exit before completing
-      // the expansion to avoid using a ton of memory to store the expansion for
-      // large sizes.
-      LoadSequence.clear();
-      return;
+  unsigned GreedyNumLoadsNonOneByte = 0;
+  LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, MaxNumLoads,
+                                           GreedyNumLoadsNonOneByte);
+  NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
+  assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
+  // If we allow overlapping loads and the load sequence is not already optimal,
+  // use overlapping loads.
+  if (Options.AllowOverlappingLoads &&
+      (LoadSequence.empty() || LoadSequence.size() > 2)) {
+    unsigned OverlappingNumLoadsNonOneByte = 0;
+    auto OverlappingLoads = computeOverlappingLoadSequence(
+        Size, MaxLoadSize, MaxNumLoads, OverlappingNumLoadsNonOneByte);
+    if (!OverlappingLoads.empty() &&
+        (LoadSequence.empty() ||
+         OverlappingLoads.size() < LoadSequence.size())) {
+      LoadSequence = OverlappingLoads;
+      NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
     }
-    if (NumLoadsForThisSize > 0) {
-      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
-        LoadSequence.push_back({LoadSize, Offset});
-        Offset += LoadSize;
-      }
-      if (LoadSize > 1) {
-        ++NumLoadsNonOneByte;
-      }
-      CurSize = CurSize % LoadSize;
-    }
-    ++LoadSizeIndex;
   }
   assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
 }
@@ -189,30 +257,32 @@
                                    EndBlock->getParent(), EndBlock);
 }
 
+/// Return a `LoadSizeType*` pointing `OffsetBytes` bytes past `Source`.
+Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
+                                                Type *LoadSizeType,
+                                                uint64_t OffsetBytes) {
+  if (OffsetBytes > 0) {
+    auto *ByteType = Type::getInt8Ty(CI->getContext());
+    // Index with i64; an i8 constant would wrap for offsets >= 128 bytes.
+    Source = Builder.CreateGEP(
+        ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
+        Builder.getInt64(OffsetBytes));
+  }
+  return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
+}
+
 // This function creates the IR instructions for loading and comparing 1 byte.
 // It loads 1 byte from each source of the memcmp parameters with the given
 // GEPIndex. It then subtracts the two loaded values and adds this result to the
 // final phi node for selecting the memcmp result.
 void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
-                                               unsigned GEPIndex) {
-  Value *Source1 = CI->getArgOperand(0);
-  Value *Source2 = CI->getArgOperand(1);
-
+                                               unsigned OffsetBytes) {
   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
   Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
-  // Cast source to LoadSizeType*.
-  if (Source1->getType() != LoadSizeType)
-    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
-  if (Source2->getType() != LoadSizeType)
-    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
-
-  // Get the base address using the GEPIndex.
-  if (GEPIndex != 0) {
-    Source1 = Builder.CreateGEP(LoadSizeType, Source1,
-                                ConstantInt::get(LoadSizeType, GEPIndex));
-    Source2 = Builder.CreateGEP(LoadSizeType, Source2,
-                                ConstantInt::get(LoadSizeType, GEPIndex));
-  }
+  Value *Source1 =
+      getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
+  Value *Source2 =
+      getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);
 
   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
@@ -270,24 +340,10 @@
     IntegerType *LoadSizeType =
         IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
 
-    Value *Source1 = CI->getArgOperand(0);
-    Value *Source2 = CI->getArgOperand(1);
-
-    // Cast source to LoadSizeType*.
-    if (Source1->getType() != LoadSizeType)
-      Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
-    if (Source2->getType() != LoadSizeType)
-      Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
-
-    // Get the base address using a GEP.
-    if (CurLoadEntry.Offset != 0) {
-      Source1 = Builder.CreateGEP(
-          LoadSizeType, Source1,
-          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
-      Source2 = Builder.CreateGEP(
-          LoadSizeType, Source2,
-          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
-    }
+    Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
+                                             CurLoadEntry.Offset);
+    Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
+                                             CurLoadEntry.Offset);
 
     // Get a constant or load a value for each source address.
     Value *LoadSrc1 = nullptr;
@@ -378,8 +434,7 @@
   const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];
 
   if (CurLoadEntry.LoadSize == 1) {
-    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
-                                              CurLoadEntry.getGEPIndex());
+    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
     return;
   }
 
@@ -388,25 +443,12 @@
   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
   assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
 
-  Value *Source1 = CI->getArgOperand(0);
-  Value *Source2 = CI->getArgOperand(1);
-
   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
-  // Cast source to LoadSizeType*.
-  if (Source1->getType() != LoadSizeType)
-    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
-  if (Source2->getType() != LoadSizeType)
-    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
 
-  // Get the base address using a GEP.
-  if (CurLoadEntry.Offset != 0) {
-    Source1 = Builder.CreateGEP(
-        LoadSizeType, Source1,
-        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
-    Source2 = Builder.CreateGEP(
-        LoadSizeType, Source2,
-        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
-  }
+  Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
+                                           CurLoadEntry.Offset);
+  Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
+                                           CurLoadEntry.Offset);
 
   // Load LoadSizeType from the base address.
   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
@@ -694,7 +736,6 @@
   if (SizeVal == 0) {
     return false;
   }
-
   // TTI call to check if target would like to expand memcmp. Also, get the
   // available load sizes.
   const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);