[Attributor] Add "free"-based heap2stack deduction

Summary:
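If the allocated memory has a unique free that is guaranteed to be executed
whenever the malloc is (i.e., the free is in the malloc's must-be-executed
context), we can apply the heap-2-stack transformation even if the pointer
escapes. Uses through PHI nodes and selects are now followed as well, but
frees reached through them are not attributed to the malloc.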

Reviewers: hfinkel, sstefan1, uenoku

Subscribers: hiraditya, bollu, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D68958
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 567ec78..11a4939 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -728,12 +728,10 @@
 
     SetVector<const Use *> NextUses;
 
+    auto EIt = Explorer.begin(CtxI), EEnd = Explorer.end(CtxI);
     for (const Use *U : Uses) {
       if (const Instruction *UserI = dyn_cast<Instruction>(U->getUser())) {
-        auto EIt = Explorer.begin(CtxI), EEnd = Explorer.end(CtxI);
-        bool Found = EIt.count(UserI);
-        while (!Found && ++EIt != EEnd)
-          Found = EIt.getCurrentInst() == UserI;
+        bool Found = Explorer.findInContextOf(UserI, EIt, EEnd);
         if (Found && Base::followUse(A, U, UserI))
           for (const Use &Us : UserI->uses())
             NextUses.insert(&Us);
@@ -3634,7 +3632,21 @@
   const Function *F = getAssociatedFunction();
   const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F);
 
+  MustBeExecutedContextExplorer &Explorer =
+      A.getInfoCache().getMustBeExecutedContextExplorer();
+
+  auto FreeCheck = [&](Instruction &I) {
+    const auto &Frees = FreesForMalloc.lookup(&I);
+    if (Frees.size() != 1)
+      return false;
+    Instruction *UniqueFree = *Frees.begin();
+    return Explorer.findInContextOf(UniqueFree, I.getNextNode());
+  };
+
   auto UsesCheck = [&](Instruction &I) {
+    bool ValidUsesOnly = true;
+    bool MustUse = true;
+
     SmallPtrSet<const Use *, 8> Visited;
     SmallVector<const Use *, 8> Worklist;
 
@@ -3652,10 +3664,12 @@
         continue;
       if (auto *SI = dyn_cast<StoreInst>(UserI)) {
         if (SI->getValueOperand() == U->get()) {
-          LLVM_DEBUG(dbgs() << "[H2S] escaping store to memory: " << *UserI << "\n");
-          return false;
+          LLVM_DEBUG(dbgs()
+                     << "[H2S] escaping store to memory: " << *UserI << "\n");
+          ValidUsesOnly = false;
+        } else {
+          // A store into the malloc'ed memory is fine.
         }
-        // A store into the malloc'ed memory is fine.
         continue;
       }
 
@@ -3673,8 +3687,14 @@
 
         // Record malloc.
         if (isFreeCall(UserI, TLI)) {
-          FreesForMalloc[&I].insert(
-              cast<Instruction>(const_cast<User *>(UserI)));
+          if (MustUse) {
+            FreesForMalloc[&I].insert(
+                cast<Instruction>(const_cast<User *>(UserI)));
+          } else {
+            LLVM_DEBUG(dbgs() << "[H2S] free potentially on different mallocs: "
+                              << *UserI << "\n");
+            ValidUsesOnly = false;
+          }
           continue;
         }
 
@@ -3688,22 +3708,25 @@
 
         if (!NoCaptureAA.isAssumedNoCapture() || !NoFreeAA.isAssumedNoFree()) {
           LLVM_DEBUG(dbgs() << "[H2S] Bad user: " << *UserI << "\n");
-          return false;
+          ValidUsesOnly = false;
         }
         continue;
       }
 
-      if (isa<GetElementPtrInst>(UserI) || isa<BitCastInst>(UserI)) {
+      if (isa<GetElementPtrInst>(UserI) || isa<BitCastInst>(UserI) ||
+          isa<PHINode>(UserI) || isa<SelectInst>(UserI)) {
+        MustUse &= !(isa<PHINode>(UserI) || isa<SelectInst>(UserI));
         for (Use &U : UserI->uses())
           Worklist.push_back(&U);
         continue;
       }
 
-      // Unknown user.
+      // Unknown user for which we can not track uses further (in a way that
+      // makes sense).
       LLVM_DEBUG(dbgs() << "[H2S] Unknown user: " << *UserI << "\n");
-      return false;
+      ValidUsesOnly = false;
     }
-    return true;
+    return ValidUsesOnly;
   };
 
   auto MallocCallocCheck = [&](Instruction &I) {
@@ -3720,7 +3743,7 @@
     if (IsMalloc) {
       if (auto *Size = dyn_cast<ConstantInt>(I.getOperand(0)))
         if (Size->getValue().sle(MaxHeapToStackSize))
-          if (UsesCheck(I)) {
+          if (UsesCheck(I) || FreeCheck(I)) {
             MallocCalls.insert(&I);
             return true;
           }
@@ -3730,7 +3753,7 @@
         if (auto *Size = dyn_cast<ConstantInt>(I.getOperand(1)))
           if ((Size->getValue().umul_ov(Num->getValue(), Overflow))
                    .sle(MaxHeapToStackSize))
-            if (!Overflow && UsesCheck(I)) {
+            if (!Overflow && (UsesCheck(I) || FreeCheck(I))) {
               MallocCalls.insert(&I);
               return true;
             }
@@ -3756,8 +3779,10 @@
   /// See AbstractAttribute::trackStatistics()
   void trackStatistics() const override {
     STATS_DECL(MallocCalls, Function,
-               "Number of MallocCalls converted to allocas");
-    BUILD_STAT_NAME(MallocCalls, Function) += MallocCalls.size();
+               "Number of malloc calls converted to allocas");
+    for (auto *C : MallocCalls)
+      if (!BadMallocCalls.count(C))
+        ++BUILD_STAT_NAME(MallocCalls, Function);
   }
 };