InstCombine: Fold comparisons between unguessable allocas and other pointers

This will allow us to optimize code such as:

  int f(int *p) {
    int x;
    return p == &x;
  }

as well as:

  int *allocate(void);
  int f() {
    int x;
    int *p = allocate();
    return p == &x;
  }

The folding can only be done under certain circumstances. Even though p and &x
cannot alias, the comparison must still return true if the pointer
representations are equal: if a user successfully generates a p that is a
correct guess for &x, the comparison must return true even though p is an
invalid pointer.
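
To make "a correct guess" concrete, here is a minimal sketch (not part of this
patch or its tests; the numeric address is made up) of a caller that
round-trips a guessed address through an integer:

  #include <stdint.h>

  int f(int *p) {
    int x;
    return p == &x;
  }

  int main(void) {
    uintptr_t guess = 0x7ffee3a0;  /* hypothetical guessed address */
    return f((int *)guess);        /* must be 1 if the guess equals &x */
  }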

This patch argues that if the address of the alloca isn't observable outside
the function, the function can act as if the address is impossible to guess
from the outside. The tricky part is keeping the act consistent: if we fold
p == &x to false in one place, we must make sure to fold any other comparison
based on those pointers the same way. To ensure that, we only fold when &x is
involved in exactly one comparison instruction.
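
To illustrate the hazard the single-comparison-use restriction avoids, consider
this sketch (again not taken from the patch or its tests), where &x feeds two
comparisons and the fold therefore does not fire:

  int g(int *p) {
    int x;
    int a = (p == &x);  /* folding only this one to 0, while p happens   */
    int b = (p == &x);  /* to be a correct guess, would make g return 0, */
    return a == b;      /* which no consistent execution can produce.    */
  }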

Differential Revision: http://reviews.llvm.org/D13358

llvm-svn: 249490
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3afb5ed..a333c3e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -730,6 +730,83 @@
   return nullptr;
 }
 
+Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca,
+                                         Value *Other) {
+  assert(ICI.isEquality() && "Cannot fold non-equality comparison.");
+
+  // It would be tempting to fold away comparisons between allocas and any
+  // pointer not based on that alloca (e.g. an argument). However, even
+  // though such pointers cannot alias, they can still compare equal.
+  //
+  // But LLVM doesn't specify where allocas get their memory, so if the alloca
+  // doesn't escape we can argue that it's impossible to guess its value, and we
+  // can therefore act as if any such guesses are wrong.
+  //
+  // The code below checks that the alloca doesn't escape, and that it's only
+  // used in a comparison once (the current instruction). The
+  // single-comparison-use condition ensures that we're trivially folding all
+  // comparisons against the alloca consistently, and avoids the risk of
+  // erroneously folding a comparison of the pointer with itself.
+
+  unsigned MaxIter = 32; // Break cycles and bound to constant-time.
+
+  SmallVector<Use *, 32> Worklist;
+  for (Use &U : Alloca->uses()) {
+    if (Worklist.size() >= MaxIter)
+      return nullptr;
+    Worklist.push_back(&U);
+  }
+
+  unsigned NumCmps = 0;
+  while (!Worklist.empty()) {
+    assert(Worklist.size() <= MaxIter);
+    Use *U = Worklist.pop_back_val();
+    Value *V = U->getUser();
+    --MaxIter;
+
+    if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) ||
+        isa<SelectInst>(V)) {
+      // Track the uses.
+    } else if (isa<LoadInst>(V)) {
+      // Loading from the pointer doesn't escape it.
+      continue;
+    } else if (auto *SI = dyn_cast<StoreInst>(V)) {
+      // Storing *to* the pointer is fine, but storing the pointer escapes it.
+      if (SI->getValueOperand() == U->get())
+        return nullptr;
+      continue;
+    } else if (isa<ICmpInst>(V)) {
+      if (NumCmps++)
+        return nullptr; // Found more than one cmp.
+      continue;
+    } else if (auto *Intrin = dyn_cast<IntrinsicInst>(V)) {
+      switch (Intrin->getIntrinsicID()) {
+        // These intrinsics don't escape or compare the pointer. Memset is safe
+        // because we don't allow ptrtoint. Memcpy and memmove are safe because
+        // we don't allow stores, so src cannot point to V.
+        case Intrinsic::lifetime_start: case Intrinsic::lifetime_end:
+        case Intrinsic::dbg_declare: case Intrinsic::dbg_value:
+        case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset:
+          continue;
+        default:
+          return nullptr;
+      }
+    } else {
+      return nullptr;
+    }
+    for (Use &U : V->uses()) {
+      if (Worklist.size() >= MaxIter)
+        return nullptr;
+      Worklist.push_back(&U);
+    }
+  }
+
+  Type *CmpTy = CmpInst::makeCmpResultType(Other->getType());
+  return ReplaceInstUsesWith(
+      ICI,
+      ConstantInt::get(CmpTy, !CmpInst::isTrueWhenEqual(ICI.getPredicate())));
+}
+
 /// FoldICmpAddOpCst - Fold "icmp pred (X+CI), X".
 Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI,
                                             Value *X, ConstantInt *CI,
@@ -3211,6 +3288,17 @@
                            ICmpInst::getSwappedPredicate(I.getPredicate()), I))
       return NI;
 
+  // Try to optimize equality comparisons against alloca-based pointers.
+  if (Op0->getType()->isPointerTy() && I.isEquality()) {
+    assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
+    if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op0, DL)))
+      if (Instruction *New = FoldAllocaCmp(I, Alloca, Op1))
+        return New;
+    if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op1, DL)))
+      if (Instruction *New = FoldAllocaCmp(I, Alloca, Op0))
+        return New;
+  }
+
   // Test to see if the operands of the icmp are casted versions of other
   // values.  If the ptr->ptr cast can be stripped off both arguments, we do so
   // now.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 9e58c74..79cb5f2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -281,6 +281,7 @@
                                 ICmpInst::Predicate Pred);
   Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
                            ICmpInst::Predicate Cond, Instruction &I);
+  Instruction *FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, Value *Other);
   Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1,
                                    BinaryOperator &I);
   Instruction *commonCastTransforms(CastInst &CI);