[Attributor] AAUndefinedBehavior: Check for branches on undef value.

A conditional branch is considered UB if its condition is an undefined /
uninitialized value. For now, this handles simple UB branches of the form
`br i1 undef, ...`. Because we query `AAValueSimplify` for a (known)
simplified value of the branch condition, the condition itself does not have
to be a literal `undef`; it only has to simplify to one.

Patch By: Stefanos Baziotis (@baziotis)

Reviewers: jdoerfert, sstefan1, uenoku

Reviewed By: uenoku

Differential Revision: https://reviews.llvm.org/D71799
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 65f02a7..d909f38 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -332,30 +332,13 @@
 
   llvm_unreachable("Expected enum or string attribute!");
 }
-static const Value *getPointerOperand(const Instruction *I) {
-  if (auto *LI = dyn_cast<LoadInst>(I))
-    if (!LI->isVolatile())
-      return LI->getPointerOperand();
 
-  if (auto *SI = dyn_cast<StoreInst>(I))
-    if (!SI->isVolatile())
-      return SI->getPointerOperand();
-
-  if (auto *CXI = dyn_cast<AtomicCmpXchgInst>(I))
-    if (!CXI->isVolatile())
-      return CXI->getPointerOperand();
-
-  if (auto *RMWI = dyn_cast<AtomicRMWInst>(I))
-    if (!RMWI->isVolatile())
-      return RMWI->getPointerOperand();
-
-  return nullptr;
-}
 static const Value *
 getBasePointerOfAccessPointerOperand(const Instruction *I, int64_t &BytesOffset,
                                      const DataLayout &DL,
                                      bool AllowNonInbounds = false) {
-  const Value *Ptr = getPointerOperand(I);
+  const Value *Ptr =
+      Attributor::getPointerOperand(I, /* AllowVolatile */ false);
   if (!Ptr)
     return nullptr;
 
@@ -1734,7 +1717,8 @@
 
   int64_t Offset;
   if (const Value *Base = getBasePointerOfAccessPointerOperand(I, Offset, DL)) {
-    if (Base == &AssociatedValue && getPointerOperand(I) == UseV) {
+    if (Base == &AssociatedValue &&
+        Attributor::getPointerOperand(I, /* AllowVolatile */ false) == UseV) {
       int64_t DerefBytes =
           (int64_t)DL.getTypeStoreSize(PtrTy->getPointerElementType()) + Offset;
 
@@ -1747,7 +1731,7 @@
   if (const Value *Base = getBasePointerOfAccessPointerOperand(
           I, Offset, DL, /*AllowNonInbounds*/ true)) {
     if (Offset == 0 && Base == &AssociatedValue &&
-        getPointerOperand(I) == UseV) {
+        Attributor::getPointerOperand(I, /* AllowVolatile */ false) == UseV) {
       int64_t DerefBytes =
           (int64_t)DL.getTypeStoreSize(PtrTy->getPointerElementType());
       IsNonNull |= !NullPointerIsDefined;
@@ -1993,28 +1977,29 @@
   AAUndefinedBehaviorImpl(const IRPosition &IRP) : AAUndefinedBehavior(IRP) {}
 
   /// See AbstractAttribute::updateImpl(...).
-  // TODO: We should not only check instructions that access memory
-  // through a pointer (i.e. also branches etc.)
+  // Checks memory accesses and conditional branches for UB.
   ChangeStatus updateImpl(Attributor &A) override {
-    const size_t PrevSize = NoUBMemAccessInsts.size();
+    const size_t UBPrevSize = KnownUBInsts.size();
+    const size_t NoUBPrevSize = AssumedNoUBInsts.size();
 
     auto InspectMemAccessInstForUB = [&](Instruction &I) {
       // Skip instructions that are already saved.
-      if (NoUBMemAccessInsts.count(&I) || UBMemAccessInsts.count(&I))
+      if (AssumedNoUBInsts.count(&I) || KnownUBInsts.count(&I))
         return true;
 
-      // `InspectMemAccessInstForUB` is only called on instructions
-      // for which getPointerOperand() should give us their
-      // pointer operand unless they're volatile.
-      const Value *PtrOp = getPointerOperand(&I);
-      if (!PtrOp)
-        return true;
+      // If we reach here, we know we have an instruction
+      // that accesses memory through a pointer operand,
+      // for which getPointerOperand() should give it to us.
+      const Value *PtrOp =
+          Attributor::getPointerOperand(&I, /* AllowVolatile */ true);
+      assert(PtrOp &&
+             "Expected pointer operand of memory accessing instruction");
 
       // A memory access through a pointer is considered UB
       // only if the pointer has constant null value.
       // TODO: Expand it to not only check constant values.
       if (!isa<ConstantPointerNull>(PtrOp)) {
-        NoUBMemAccessInsts.insert(&I);
+        AssumedNoUBInsts.insert(&I);
         return true;
       }
       const Type *PtrTy = PtrOp->getType();
@@ -2025,10 +2010,35 @@
 
       // A memory access using constant null pointer is only considered UB
       // if null pointer is _not_ defined for the target platform.
-      if (!llvm::NullPointerIsDefined(F, PtrTy->getPointerAddressSpace()))
-        UBMemAccessInsts.insert(&I);
+      if (llvm::NullPointerIsDefined(F, PtrTy->getPointerAddressSpace()))
+        AssumedNoUBInsts.insert(&I);
       else
-        NoUBMemAccessInsts.insert(&I);
+        KnownUBInsts.insert(&I);
+      return true;
+    };
+
+    auto InspectBrInstForUB = [&](Instruction &I) {
+      // A conditional branch instruction is considered UB if it has `undef`
+      // condition.
+
+      // Skip instructions that are already saved.
+      if (AssumedNoUBInsts.count(&I) || KnownUBInsts.count(&I))
+        return true;
+
+      // We know we have a branch instruction.
+      auto BrInst = cast<BranchInst>(&I);
+
+      // Unconditional branches are never considered UB.
+      if (BrInst->isUnconditional())
+        return true;
+
+      // Either we stopped and the appropriate action was taken,
+      // or we got back a simplified value to continue.
+      Optional<Value *> SimplifiedCond =
+          stopOnUndefOrAssumed(A, BrInst->getCondition(), BrInst);
+      if (!SimplifiedCond.hasValue())
+        return true;
+      AssumedNoUBInsts.insert(&I);
       return true;
     };
 
@@ -2036,19 +2046,46 @@
                               {Instruction::Load, Instruction::Store,
                                Instruction::AtomicCmpXchg,
                                Instruction::AtomicRMW});
-    if (PrevSize != NoUBMemAccessInsts.size())
+    A.checkForAllInstructions(InspectBrInstForUB, *this, {Instruction::Br});
+    if (NoUBPrevSize != AssumedNoUBInsts.size() ||
+        UBPrevSize != KnownUBInsts.size())
       return ChangeStatus::CHANGED;
     return ChangeStatus::UNCHANGED;
   }
 
+  bool isKnownToCauseUB(Instruction *I) const override {
+    return KnownUBInsts.count(I);
+  }
+
   bool isAssumedToCauseUB(Instruction *I) const override {
-    return UBMemAccessInsts.count(I);
+    // Put simply: an inspected instruction that is not in AssumedNoUBInsts
+    // is assumed to cause UB (this includes everything in KnownUBInsts).
+    // The switch below only makes sure that I is one of the instruction
+    // kinds we test for UB (memory accesses and conditional branches); all
+    // other instructions are never assumed to cause UB.
+
+    switch (I->getOpcode()) {
+    case Instruction::Load:
+    case Instruction::Store:
+    case Instruction::AtomicCmpXchg:
+    case Instruction::AtomicRMW:
+      return !AssumedNoUBInsts.count(I);
+    case Instruction::Br: {
+      auto BrInst = cast<BranchInst>(I);
+      if (BrInst->isUnconditional())
+        return false;
+      return !AssumedNoUBInsts.count(I);
+    } break; // Unreachable: both paths of this case return.
+    default:
+      return false;
+    }
+    return false; // Not reached: every case of the switch returns.
   }
 
   ChangeStatus manifest(Attributor &A) override {
-    if (!UBMemAccessInsts.size())
+    if (KnownUBInsts.empty())
       return ChangeStatus::UNCHANGED;
-    for (Instruction *I : UBMemAccessInsts)
+    for (Instruction *I : KnownUBInsts) // Only *known* UB is manifested.
       A.changeToUnreachableAfterManifest(I);
     return ChangeStatus::CHANGED;
   }
@@ -2058,22 +2095,69 @@
     return getAssumed() ? "undefined-behavior" : "no-ub";
   }
 
+  /// Note: The correctness of this analysis depends on the fact that the
+  /// following 2 sets will stop changing after some point.
+  /// "Change" here means that their size changes.
+  /// The size of each set is monotonically increasing
+  /// (we only add items to them) and it is upper bounded by the number of
+  /// instructions in the processed function (we can never save more
+  /// elements in either set than this number). Hence, at some point,
+  /// they will stop increasing.
+  /// Consequently, at some point, both sets will have stopped
+  /// changing, effectively making the analysis reach a fixpoint.
+
+  /// Note: These 2 sets are disjoint, and each inspected instruction is in
+  /// one of 3 states:
+  /// 1) Known to cause UB: AAUndefinedBehavior could prove it, and the
+  ///    instruction is in the KnownUBInsts set.
+  /// 2) Assumed to cause UB: in every updateImpl, AAUndefinedBehavior has a
+  ///    reason to assume it. These instructions are in neither set.
+  /// 3) Assumed to not cause UB: every other inspected instruction, i.e.
+  ///    AAUndefinedBehavior could not find a reason to assume or prove that
+  ///    it can cause UB, hence it assumes it doesn't. We keep these in
+  ///    AssumedNoUBInsts so that we don't reprocess them in every update.
+  ///    Note, however, that instructions in this set may still cause UB.
+
 protected:
-  // A set of all the (live) memory accessing instructions that _are_ assumed to
-  // cause UB.
-  SmallPtrSet<Instruction *, 8> UBMemAccessInsts;
+  /// A set of all live instructions _known_ to cause UB.
+  SmallPtrSet<Instruction *, 8> KnownUBInsts;
 
 private:
-  // A set of all the (live) memory accessing instructions
-  // that are _not_ assumed to cause UB.
-  //   Note: The correctness of the procedure depends on the fact that this
-  //   set stops changing after some point. "Change" here means that the size
-  //   of the set changes. The size of this set is monotonically increasing
-  //   (we only add items to it) and is upper bounded by the number of memory
-  //   accessing instructions in the processed function (we can never save more
-  //   elements in this set than this number). Hence, the size of this set, at
-  //   some point, will stop increasing, effectively reaching a fixpoint.
-  SmallPtrSet<Instruction *, 8> NoUBMemAccessInsts;
+  /// A set of all the (live) instructions that are assumed to _not_ cause UB.
+  SmallPtrSet<Instruction *, 8> AssumedNoUBInsts;
+
+  /// Helper for update steps that process an instruction \p I depending on
+  /// the value \p V. It queries AAValueSimplify for \p V and then:
+  /// - If the simplification of \p V is not known yet, nothing is recorded
+  ///   for \p I, since we do not want to depend on assumed information.
+  /// - If it is known and \p V simplifies to `undef` (or to no value at
+  ///   all), \p I is inserted into KnownUBInsts.
+  /// In both cases None is returned and the caller should stop processing
+  /// \p I. Otherwise, the known simplified value is returned so that the
+  /// caller can continue with instruction-specific processing.
+  Optional<Value *> stopOnUndefOrAssumed(Attributor &A, const Value *V,
+                                         Instruction *I) {
+    const auto &ValueSimplifyAA =
+        A.getAAFor<AAValueSimplify>(*this, IRPosition::value(*V));
+    Optional<Value *> SimplifiedV =
+        ValueSimplifyAA.getAssumedSimplifiedValue(A);
+    if (!ValueSimplifyAA.isKnown()) {
+      // Don't depend on assumed values.
+      return llvm::None;
+    }
+    if (!SimplifiedV.hasValue()) {
+      // Known, but without a simplified value: treat V as `undef`, so the
+      // instruction is known to cause UB.
+      KnownUBInsts.insert(I);
+      return llvm::None;
+    }
+    Value *Val = SimplifiedV.getValue();
+    if (isa<UndefValue>(Val)) {
+      KnownUBInsts.insert(I);
+      return llvm::None;
+    }
+    return Val;
+  }
 };
 
 struct AAUndefinedBehaviorFunction final : AAUndefinedBehaviorImpl {
@@ -2085,7 +2169,7 @@
     STATS_DECL(UndefinedBehaviorInstruction, Instruction,
                "Number of instructions known to have UB");
     BUILD_STAT_NAME(UndefinedBehaviorInstruction, Instruction) +=
-        UBMemAccessInsts.size();
+        KnownUBInsts.size();
   }
 };
 
@@ -3101,7 +3185,8 @@
     int64_t Offset;
     if (const Value *Base = getBasePointerOfAccessPointerOperand(
             I, Offset, DL, /*AllowNonInbounds*/ true)) {
-      if (Base == &getAssociatedValue() && getPointerOperand(I) == UseV) {
+      if (Base == &getAssociatedValue() &&
+          Attributor::getPointerOperand(I, /* AllowVolatile */ false) == UseV) {
         uint64_t Size = DL.getTypeStoreSize(PtrTy->getPointerElementType());
         addAccessedBytes(Offset, Size);
       }
@@ -5592,6 +5677,7 @@
     case Instruction::CatchSwitch:
     case Instruction::AtomicRMW:
     case Instruction::AtomicCmpXchg:
+    case Instruction::Br:
     case Instruction::Resume:
     case Instruction::Ret:
       IsInterestingOpcode = true;