[Attributor] Use liveness during the creation of AAReturnedValues

Summary:
As one of the first attributes, and one of the more complex ones,
AAReturnedValues did not use liveness information; instead, its results
were filtered after the fact. This change uses liveness already during the
creation of the returned values. The algorithm is also improved and shorter.

The new algorithm collects returned values over time using the generic
facilities that already work with liveness, e.g., genericValueTraversal,
which does not look at PHI node operands coming from dead predecessors.
A test to show how this leads to better results is included.

Note: Unresolved returned calls are now tracked explicitly; all other
returned calls are considered resolved.

Reviewers: uenoku, sstefan1

Subscribers: hiraditya, bollu, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D66120

llvm-svn: 368922
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 2f062bf..32bc88e 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -121,34 +121,28 @@
 }
 ///}
 
-template <typename StateTy>
-using followValueCB_t = std::function<bool(Value *, StateTy &State)>;
-template <typename StateTy>
-using visitValueCB_t = std::function<void(Value *, StateTy &State)>;
-
-/// Recursively visit all values that might become \p InitV at some point. This
+/// Recursively visit all values that might become the value of \p IRP. This
 /// will be done by looking through cast instructions, selects, phis, and calls
-/// with the "returned" attribute. The callback \p FollowValueCB is asked before
-/// a potential origin value is looked at. If no \p FollowValueCB is passed, a
-/// default one is used that will make sure we visit every value only once. Once
-/// we cannot look through the value any further, the callback \p VisitValueCB
-/// is invoked and passed the current value and the \p State. To limit how much
-/// effort is invested, we will never visit more than \p MaxValues values.
-template <typename StateTy>
-static bool genericValueTraversal(
-    Value *InitV, StateTy &State, visitValueCB_t<StateTy> &VisitValueCB,
-    followValueCB_t<StateTy> *FollowValueCB = nullptr, int MaxValues = 8) {
+/// with the "returned" attribute. PHI operands coming from a block whose
+/// terminator is assumed dead are skipped. Once we cannot look through a value
+/// any further, \p VisitValueCB is invoked with it, the \p State, and a flag
+/// indicating if anything was looked through. At most \p MaxValues values are
+/// visited; false is returned if the traversal was aborted due to that limit.
+template <typename AAType, typename StateTy>
+bool genericValueTraversal(
+    Attributor &A, IRPosition IRP, const AAType &QueryingAA, StateTy &State,
+    const function_ref<void(Value &, StateTy &, bool)> &VisitValueCB,
+    int MaxValues = 8) {
 
+  const AAIsDead *LivenessAA = nullptr;
+  if (IRP.getAnchorScope())
+    LivenessAA = A.getAAFor<AAIsDead>(
+        QueryingAA, IRPosition::function(*IRP.getAnchorScope()));
+
+  // TODO: Use Positions here to allow context sensitivity in VisitValueCB
   SmallPtrSet<Value *, 16> Visited;
-  followValueCB_t<bool> DefaultFollowValueCB = [&](Value *Val, bool &) {
-    return Visited.insert(Val).second;
-  };
-
-  if (!FollowValueCB)
-    FollowValueCB = &DefaultFollowValueCB;
-
   SmallVector<Value *, 16> Worklist;
-  Worklist.push_back(InitV);
+  Worklist.push_back(&IRP.getAssociatedValue());
 
   int Iteration = 0;
   do {
@@ -156,7 +150,7 @@
 
     // Check if we should process the current value. To prevent endless
     // recursion keep a record of the values we followed!
-    if (!(*FollowValueCB)(V, State))
+    if (!Visited.insert(V).second)
       continue;
 
     // Make sure we limit the compile time for complex expressions.
@@ -165,23 +159,23 @@
 
     // Explicitly look through calls with a "returned" attribute if we do
     // not have a pointer as stripPointerCasts only works on them.
+    Value *NewV = nullptr;
     if (V->getType()->isPointerTy()) {
-      V = V->stripPointerCasts();
+      NewV = V->stripPointerCasts();
     } else {
       CallSite CS(V);
       if (CS && CS.getCalledFunction()) {
-        Value *NewV = nullptr;
         for (Argument &Arg : CS.getCalledFunction()->args())
           if (Arg.hasReturnedAttr()) {
             NewV = CS.getArgOperand(Arg.getArgNo());
             break;
           }
-        if (NewV) {
-          Worklist.push_back(NewV);
-          continue;
-        }
       }
     }
+    if (NewV && NewV != V) {
+      Worklist.push_back(NewV);
+      continue;
+    }
 
     // Look through select instructions, visit both potential values.
     if (auto *SI = dyn_cast<SelectInst>(V)) {
@@ -190,14 +184,19 @@
       continue;
     }
 
-    // Look through phi nodes, visit all operands.
+    // Look through phi nodes, visit all live operands.
     if (auto *PHI = dyn_cast<PHINode>(V)) {
-      Worklist.append(PHI->op_begin(), PHI->op_end());
+      for (unsigned u = 0, e = PHI->getNumIncomingValues(); u < e; u++) {
+        const BasicBlock *IncomingBB = PHI->getIncomingBlock(u);
+        if (!LivenessAA ||
+            !LivenessAA->isAssumedDead(IncomingBB->getTerminator()))
+          Worklist.push_back(PHI->getIncomingValue(u));
+      }
       continue;
     }
 
     // Once a leaf is reached we inform the user through the callback.
-    VisitValueCB(V, State);
+    VisitValueCB(*V, State, Iteration > 1);
   } while (!Worklist.empty());
 
   // All values have been visited.
@@ -494,45 +493,21 @@
 ///
 /// If there is a unique returned value R, the manifest method will:
 ///   - mark R with the "returned" attribute, if R is an argument.
-///
-/// TODO: We should use liveness during construction of the returned values map
-///       and before we set HasOverdefinedReturnedCalls.
 class AAReturnedValuesImpl : public AAReturnedValues, public AbstractState {
 
   /// Mapping of values potentially returned by the associated function to the
   /// return instructions that might return them.
   DenseMap<Value *, SmallPtrSet<ReturnInst *, 2>> ReturnedValues;
 
+  /// Returned calls (also of callees, transitively) that we could not resolve.
+  SmallPtrSet<CallBase *, 8> UnresolvedCalls;
   /// State flags
   ///
   ///{
   bool IsFixed;
   bool IsValidState;
-  bool HasOverdefinedReturnedCalls;
   ///}
 
-  /// Collect values that could become \p V in the set \p Values, each mapped to
-  /// \p ReturnInsts.
-  void collectValuesRecursively(
-      Attributor &A, Value *V, SmallPtrSetImpl<ReturnInst *> &ReturnInsts,
-      DenseMap<Value *, SmallPtrSet<ReturnInst *, 2>> &Values) {
-
-    visitValueCB_t<bool> VisitValueCB = [&](Value *Val, bool &) {
-      assert(!isa<Instruction>(Val) ||
-             &getAnchorScope() == cast<Instruction>(Val)->getFunction());
-      Values[Val].insert(ReturnInsts.begin(), ReturnInsts.end());
-    };
-
-    bool UnusedBool;
-    bool Success = genericValueTraversal(V, UnusedBool, VisitValueCB);
-
-    // If we did abort the above traversal we haven't see all the values.
-    // Consequently, we cannot know if the information we would derive is
-    // accurate so we give up early.
-    if (!Success)
-      indicatePessimisticFixpoint();
-  }
-
 public:
   AAReturnedValuesImpl(const IRPosition &IRP) : AAReturnedValues(IRP) {}
 
@@ -541,18 +516,22 @@
     // Reset the state.
     IsFixed = false;
     IsValidState = true;
-    HasOverdefinedReturnedCalls = false;
     ReturnedValues.clear();
+    UnresolvedCalls.clear();
 
-    Function &F = getAnchorScope();
+    // We can only reason about the returned values of an exact definition.
+    Function *F = getAssociatedFunction();
+    if (!F || !F->hasExactDefinition()) {
+      indicatePessimisticFixpoint();
+      return;
+    }
 
     // The map from instruction opcodes to those instructions in the function.
-    auto &OpcodeInstMap = A.getInfoCache().getOpcodeInstMapForFunction(F);
+    auto &OpcodeInstMap = A.getInfoCache().getOpcodeInstMapForFunction(*F);
 
     // Look through all arguments, if one is marked as returned we are done.
-    for (Argument &Arg : F.args()) {
+    for (Argument &Arg : F->args()) {
       if (Arg.hasReturnedAttr()) {
-
         auto &ReturnInstSet = ReturnedValues[&Arg];
         for (Instruction *RI : OpcodeInstMap[Instruction::Ret])
           ReturnInstSet.insert(cast<ReturnInst>(RI));
@@ -561,14 +540,6 @@
         return;
       }
     }
-
-    // If no argument was marked as returned we look at all return instructions
-    // and collect potentially returned values.
-    for (Instruction *RI : OpcodeInstMap[Instruction::Ret]) {
-      SmallPtrSet<ReturnInst *, 1> RISet({cast<ReturnInst>(RI)});
-      collectValuesRecursively(A, cast<ReturnInst>(RI)->getReturnValue(), RISet,
-                               ReturnedValues);
-    }
   }
 
   /// See AbstractAttribute::manifest(...).
@@ -583,8 +552,20 @@
   /// See AbstractAttribute::updateImpl(Attributor &A).
   ChangeStatus updateImpl(Attributor &A) override;
 
+  /// Iterate over returned values and the return instructions returning them.
+  llvm::iterator_range<iterator> returned_values() override {
+    return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end());
+  }
+  /// Const variant of returned_values().
+  llvm::iterator_range<const_iterator> returned_values() const override {
+    return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end());
+  }
+  /// Return the returned calls that could not be resolved.
+  const SmallPtrSetImpl<CallBase *> &getUnresolvedCalls() const override {
+    return UnresolvedCalls;
+  }
   /// Return the number of potential return values, -1 if unknown.
-  size_t getNumReturnValues() const {
+  size_t getNumReturnValues() const override {
     return isValidState() ? ReturnedValues.size() : -1;
   }
 
@@ -621,15 +602,6 @@
   }
 };
 
-struct AAReturnedValuesFunction final : public AAReturnedValuesImpl {
-  AAReturnedValuesFunction(const IRPosition &IRP) : AAReturnedValuesImpl(IRP) {}
-
-  /// See AbstractAttribute::trackStatistics()
-  void trackStatistics() const override {
-    STATS_DECL_AND_TRACK_ARG_ATTR(returned)
-  }
-};
-
 ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) {
   ChangeStatus Changed = ChangeStatus::UNCHANGED;
 
@@ -660,7 +632,7 @@
 const std::string AAReturnedValuesImpl::getAsStr() const {
   return (isAtFixpoint() ? "returns(#" : "may-return(#") +
          (isValidState() ? std::to_string(getNumReturnValues()) : "?") +
-         ")[OD: " + std::to_string(HasOverdefinedReturnedCalls) + "]";
+         ")[#UC: " + std::to_string(UnresolvedCalls.size()) + "]";
 }
 
 Optional<Value *>
@@ -707,8 +679,8 @@
     Value *RV = It.first;
     const SmallPtrSetImpl<ReturnInst *> &RetInsts = It.second;
 
-    ImmutableCallSite ICS(RV);
-    if (ICS && !HasOverdefinedReturnedCalls)
+    CallBase *CB = dyn_cast<CallBase>(RV);
+    if (CB && !UnresolvedCalls.count(CB))
       continue;
 
     if (!Pred(*RV, RetInsts))
@@ -719,127 +691,144 @@
 }
 
 ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
+  size_t NumUnresolvedCalls = UnresolvedCalls.size();
+  bool Changed = false;
 
-  // Check if we know of any values returned by the associated function,
-  // if not, we are done.
-  if (getNumReturnValues() == 0) {
-    indicateOptimisticFixpoint();
-    return ChangeStatus::UNCHANGED;
-  }
+  // State used in the value traversals starting in returned values.
+  struct RVState {
+    // The map in which we collect return values -> return instrs.
+    decltype(ReturnedValues) &RetValsMap;
+    // The flag to indicate a change.
+    bool Changed;
+    // The return instrs we come from.
+    SmallPtrSet<ReturnInst *, 2> RetInsts;
+  };
 
-  // Check if any of the returned values is a call site we can refine.
-  decltype(ReturnedValues) AddRVs;
-  bool HasCallSite = false;
+  // Callback for a leaf value returned by the associated function.
+  auto VisitValueCB = [](Value &Val, RVState &RVS, bool) {
+    auto &ValRetInsts = RVS.RetValsMap[&Val];
+    auto Size = ValRetInsts.size();
+    ValRetInsts.insert(RVS.RetInsts.begin(), RVS.RetInsts.end());
+    bool Inserted = ValRetInsts.size() != Size;
+    RVS.Changed |= Inserted;
+    LLVM_DEBUG({
+      if (Inserted)
+        dbgs() << "[AAReturnedValues] 1 Add new returned value " << Val
+               << " => " << RVS.RetInsts.size() << "\n";
+    });
+  };
 
-  // Keep track of any change to trigger updates on dependent attributes.
-  ChangeStatus Changed = ChangeStatus::UNCHANGED;
+  // Helper method to invoke the generic value traversal.
+  auto VisitReturnedValue = [&](Value &RV, RVState &RVS) {
+    IRPosition RetValPos = IRPosition::value(RV);
+    return genericValueTraversal<AAReturnedValues, RVState>(A, RetValPos, *this,
+                                                            RVS, VisitValueCB);
+  };
 
-  auto *LivenessAA = A.getAAFor<AAIsDead>(*this, getIRPosition());
+  // Callback for all "return instructions" live in the associated function.
+  auto CheckReturnInst = [this, &VisitReturnedValue, &Changed](Instruction &I) {
+    ReturnInst &Ret = cast<ReturnInst>(I);
+    RVState RVS({ReturnedValues, false});
+    RVS.RetInsts.insert(&Ret);
+    bool Success = VisitReturnedValue(*Ret.getReturnValue(), RVS);
+    // RVS.Changed only reflects new returned values after the traversal.
+    Changed |= RVS.Changed;
+    return Success;
+  };
 
-  // Look at all returned call sites.
+  // Start by discovering returned values from all live returned instructions in
+  // the associated function.
+  if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}))
+    return indicatePessimisticFixpoint();
+
+  // Once returned values "directly" present in the code are handled we try to
+  // resolve returned calls.
+  decltype(ReturnedValues) NewRVsMap;
   for (auto &It : ReturnedValues) {
-    SmallPtrSet<ReturnInst *, 2> &ReturnInsts = It.second;
-    Value *RV = It.first;
+    LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned value: " << *It.first
+                      << " by #" << It.second.size() << " RIs\n");
+    CallBase *CB = dyn_cast<CallBase>(It.first);
+    if (!CB || UnresolvedCalls.count(CB))
+      continue;
 
-    LLVM_DEBUG(dbgs() << "[AAReturnedValues] Potentially returned value " << *RV
+    const auto *RetValAAPtr =
+        A.getAAFor<AAReturnedValues>(*this, IRPosition::callsite_function(*CB));
+
+    // Skip dead ends, thus if we do not know anything about the returned
+    // call we mark it as unresolved and it will stay that way.
+    if (!RetValAAPtr || !RetValAAPtr->getState().isValidState()) {
+      LLVM_DEBUG(dbgs() << "[AAReturnedValues] Unresolved call: " << *CB
+                        << "\n");
+      UnresolvedCalls.insert(CB);
+      continue;
+    }
+
+    const auto &RetValAA = *RetValAAPtr;
+    LLVM_DEBUG(dbgs() << "[AAReturnedValues] Found another AAReturnedValues: "
+                      << static_cast<const AbstractAttribute &>(RetValAA)
                       << "\n");
 
-    // Only call sites can change during an update, ignore the rest.
-    CallSite RetCS(RV);
-    if (!RetCS)
-      continue;
+    // If we know something but not everything about the returned values, keep
+    // track of that too. Hence, remember transitively unresolved calls.
+    UnresolvedCalls.insert(RetValAA.getUnresolvedCalls().begin(),
+                           RetValAA.getUnresolvedCalls().end());
 
-    // For now, any call site we see will prevent us from directly fixing the
-    // state. However, if the information on the callees is fixed, the call
-    // sites will be removed and we will fix the information for this state.
-    HasCallSite = true;
-
-    // Ignore dead ReturnValues.
-    if (LivenessAA &&
-        !LivenessAA->isLiveInstSet(ReturnInsts.begin(), ReturnInsts.end())) {
-      LLVM_DEBUG(dbgs() << "[AAReturnedValues] all returns are assumed dead, "
-                           "skip it for now\n");
-      continue;
+    // Now track transitively returned values.
+    for (auto &RetValAAIt : RetValAA.returned_values()) {
+      Value *RetVal = RetValAAIt.first;
+      if (Argument *Arg = dyn_cast<Argument>(RetVal)) {
+        // Arguments are mapped to call site operands and we begin the traversal
+        // again. The operands are returned by the return instructions that
+        // return CB, i.e., It.second. If the traversal is aborted we cannot
+        // know all values CB might return, so it stays unresolved.
+        RVState RVS({NewRVsMap, false, It.second});
+        if (!VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS))
+          UnresolvedCalls.insert(CB);
+        continue;
+      } else if (isa<CallBase>(RetVal)) {
+        // Call sites are resolved by the callee attribute over time, no need to
+        // do anything for us.
+        continue;
+      } else if (isa<Constant>(RetVal)) {
+        // Constants are valid everywhere, we can simply take them.
+        NewRVsMap[RetVal].insert(It.second.begin(), It.second.end());
+        continue;
+      }
+      // Anything that did not fit in the above categories cannot be resolved,
+      // mark the call as unresolved.
+      LLVM_DEBUG(dbgs() << "[AAReturnedValues] transitively returned value "
+                           "cannot be translated: "
+                        << *RetVal << "\n");
+      UnresolvedCalls.insert(CB);
     }
-
-    // Try to find a assumed unique return value for the called function.
-    auto *RetCSAA = A.getAAFor<AAReturnedValuesImpl>(
-        *this, IRPosition::callsite_returned(RetCS));
-    if (!RetCSAA) {
-      if (!HasOverdefinedReturnedCalls)
-        Changed = ChangeStatus::CHANGED;
-      HasOverdefinedReturnedCalls = true;
-      LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned call site (" << *RV
-                        << ") with " << (RetCSAA ? "invalid" : "no")
-                        << " associated state\n");
-      continue;
-    }
-
-    // Try to find a assumed unique return value for the called function.
-    Optional<Value *> AssumedUniqueRV = RetCSAA->getAssumedUniqueReturnValue(A);
-
-    // If no assumed unique return value was found due to the lack of
-    // candidates, we may need to resolve more calls (through more update
-    // iterations) or the called function will not return. Either way, we
-    // simply stick with the call sites as return values. Because there were
-    // not multiple possibilities, we do not treat it as overdefined.
-    if (!AssumedUniqueRV.hasValue())
-      continue;
-
-    // If multiple, non-refinable values were found, there cannot be a unique
-    // return value for the called function. The returned call is overdefined!
-    if (!AssumedUniqueRV.getValue()) {
-      if (!HasOverdefinedReturnedCalls)
-        Changed = ChangeStatus::CHANGED;
-      HasOverdefinedReturnedCalls = true;
-      LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned call site has multiple "
-                           "potentially returned values\n");
-      continue;
-    }
-
-    LLVM_DEBUG({
-      bool UniqueRVIsKnown = RetCSAA->isAtFixpoint();
-      dbgs() << "[AAReturnedValues] Returned call site "
-             << (UniqueRVIsKnown ? "known" : "assumed")
-             << " unique return value: " << *AssumedUniqueRV << "\n";
-    });
-
-    // The assumed unique return value.
-    Value *AssumedRetVal = AssumedUniqueRV.getValue();
-
-    // If the assumed unique return value is an argument, lookup the matching
-    // call site operand and recursively collect new returned values.
-    // If it is not an argument, it is just put into the set of returned
-    // values as we would have already looked through casts, phis, and similar
-    // values.
-    if (Argument *AssumedRetArg = dyn_cast<Argument>(AssumedRetVal))
-      collectValuesRecursively(A,
-                               RetCS.getArgOperand(AssumedRetArg->getArgNo()),
-                               ReturnInsts, AddRVs);
-    else
-      AddRVs[AssumedRetVal].insert(ReturnInsts.begin(), ReturnInsts.end());
   }
 
-  for (auto &It : AddRVs) {
+  // To avoid modifications to the ReturnedValues map while we iterate over it
+  // we kept record of potential new entries in a copy map, NewRVsMap.
+  for (auto &It : NewRVsMap) {
     assert(!It.second.empty() && "Entry does not add anything.");
     auto &ReturnInsts = ReturnedValues[It.first];
     for (ReturnInst *RI : It.second)
       if (ReturnInsts.insert(RI).second) {
         LLVM_DEBUG(dbgs() << "[AAReturnedValues] Add new returned value "
                           << *It.first << " => " << *RI << "\n");
-        Changed = ChangeStatus::CHANGED;
+        Changed = true;
       }
   }
 
-  // If there is no call site in the returned values we are done.
-  if (!HasCallSite) {
-    indicateOptimisticFixpoint();
-    return ChangeStatus::CHANGED;
-  }
-
-  return Changed;
+  Changed |= (NumUnresolvedCalls != UnresolvedCalls.size());
+  return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
 }
 
+struct AAReturnedValuesFunction final : public AAReturnedValuesImpl {
+  AAReturnedValuesFunction(const IRPosition &IRP) : AAReturnedValuesImpl(IRP) {}
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {
+    STATS_DECL_AND_TRACK_ARG_ATTR(returned)
+  }
+};
+
 /// ------------------------ NoSync Function Attribute -------------------------
 
 struct AANoSyncImpl : AANoSync {
@@ -1239,7 +1222,7 @@
 bool containsPossiblyEndlessLoop(Function &F) { return containsCycle(F); }
 
 void AAWillReturnFunction::initialize(Attributor &A) {
-  Function &F = getAnchorScope();
+  Function &F = *getAnchorScope();
 
   if (containsPossiblyEndlessLoop(F))
     indicatePessimisticFixpoint();
@@ -1288,7 +1271,7 @@
 
   /// See AbstractAttriubute::initialize(...).
   void initialize(Attributor &A) override {
-    Function &F = getAnchorScope();
+    Function &F = *getAnchorScope();
 
     // Already noalias.
     if (F.returnDoesNotAlias()) {
@@ -1348,7 +1331,7 @@
   AAIsDeadImpl(const IRPosition &IRP) : AAIsDead(IRP) {}
 
   void initialize(Attributor &A) override {
-    const Function &F = getAnchorScope();
+    const Function &F = *getAnchorScope();
 
     ToBeExploredPaths.insert(&(F.getEntryBlock().front()));
     AssumedLiveBlocks.insert(&(F.getEntryBlock()));
@@ -1371,7 +1354,7 @@
   /// See AbstractAttribute::getAsStr().
   const std::string getAsStr() const override {
     return "Live[#BB " + std::to_string(AssumedLiveBlocks.size()) + "/" +
-           std::to_string(getAnchorScope().size()) + "][#NRI " +
+           std::to_string(getAnchorScope()->size()) + "][#NRI " +
            std::to_string(NoReturnCalls.size()) + "]";
   }
 
@@ -1381,7 +1364,7 @@
            "Attempted to manifest an invalid state!");
 
     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
-    const Function &F = getAnchorScope();
+    const Function &F = *getAnchorScope();
 
     // Flag to determine if we can change an invoke to a call assuming the
     // callee is nounwind. This is not possible if the personality of the
@@ -1443,7 +1426,7 @@
 
   /// See AAIsDead::isAssumedDead(BasicBlock *).
   bool isAssumedDead(const BasicBlock *BB) const override {
-    assert(BB->getParent() == &getAnchorScope() &&
+    assert(BB->getParent() == getAnchorScope() &&
            "BB must be in the same anchor scope function.");
 
     if (!getAssumed())
@@ -1458,7 +1441,7 @@
 
   /// See AAIsDead::isAssumed(Instruction *I).
   bool isAssumedDead(const Instruction *I) const override {
-    assert(I->getParent()->getParent() == &getAnchorScope() &&
+    assert(I->getParent()->getParent() == getAnchorScope() &&
            "Instruction must be in the same anchor scope function.");
 
     if (!getAssumed())
@@ -1504,7 +1487,7 @@
     STATS_DECL(DeadBlocks, Function,
                "Number of basic blocks classified as dead");
     BUILD_STAT_NAME(DeadBlocks, Function) +=
-        getAnchorScope().size() - AssumedLiveBlocks.size();
+        getAnchorScope()->size() - AssumedLiveBlocks.size();
     STATS_DECL(PartiallyDeadBlocks, Function,
                "Number of basic blocks classified as partially dead");
     BUILD_STAT_NAME(PartiallyDeadBlocks, Function) += NoReturnCalls.size();
@@ -1602,13 +1585,13 @@
     }
   }
 
-  LLVM_DEBUG(
-      dbgs() << "[AAIsDead] AssumedLiveBlocks: " << AssumedLiveBlocks.size()
-             << " Total number of blocks: " << getAnchorScope().size() << "\n");
+  LLVM_DEBUG(dbgs() << "[AAIsDead] AssumedLiveBlocks: "
+                    << AssumedLiveBlocks.size() << " Total number of blocks: "
+                    << getAnchorScope()->size() << "\n");
 
   // If we know everything is live there is no need to query for liveness.
   if (NoReturnCalls.empty() &&
-      getAnchorScope().size() == AssumedLiveBlocks.size()) {
+      getAnchorScope()->size() == AssumedLiveBlocks.size()) {
     // Indicating a pessimistic fixpoint will cause the state to be "invalid"
     // which will cause the Attributor to not return the AAIsDead on request,
     // which will prevent us from querying isAssumedDead().
@@ -1824,7 +1807,7 @@
     return calcDifferenceIfBaseIsNonNull(
         DL.getTypeStoreSize(Base->getType()->getPointerElementType()),
         Offset.getSExtValue(),
-        !NullPointerIsDefined(&getAnchorScope(),
+        !NullPointerIsDefined(getAnchorScope(),
                               V.getType()->getPointerAddressSpace()));
 
   IsNonNull = false;
@@ -2230,23 +2213,7 @@
   if (!AARetVal || !AARetVal->getState().isValidState())
     return false;
 
-  auto *LivenessAA =
-      getAAFor<AAIsDead>(QueryingAA, IRPosition::function(*AssociatedFunction));
-  if (!LivenessAA)
-    return AARetVal->checkForAllReturnedValuesAndReturnInsts(Pred);
-
-  auto LivenessFilter = [&](Value &RV,
-                            const SmallPtrSetImpl<ReturnInst *> &ReturnInsts) {
-    SmallPtrSet<ReturnInst *, 4> FilteredReturnInsts;
-    for (ReturnInst *RI : ReturnInsts)
-      if (!LivenessAA->isAssumedDead(RI))
-        FilteredReturnInsts.insert(RI);
-    if (!FilteredReturnInsts.empty())
-      return Pred(RV, FilteredReturnInsts);
-    return true;
-  };
-
-  return AARetVal->checkForAllReturnedValuesAndReturnInsts(LivenessFilter);
+  return AARetVal->checkForAllReturnedValuesAndReturnInsts(Pred);
 }
 
 bool Attributor::checkForAllReturnedValues(
@@ -2263,22 +2230,10 @@
   if (!AARetVal || !AARetVal->getState().isValidState())
     return false;
 
-  auto *LivenessAA =
-      getAAFor<AAIsDead>(QueryingAA, IRPosition::function(*AssociatedFunction));
-  if (!LivenessAA)
-    return AARetVal->checkForAllReturnedValuesAndReturnInsts(
-        [&](Value &RV, const SmallPtrSetImpl<ReturnInst *> &) {
-          return Pred(RV);
-        });
-
-  auto LivenessFilter = [&](Value &RV,
-                            const SmallPtrSetImpl<ReturnInst *> &ReturnInsts) {
-    if (LivenessAA->isLiveInstSet(ReturnInsts.begin(), ReturnInsts.end()))
-      return Pred(RV);
-    return true;
-  };
-
-  return AARetVal->checkForAllReturnedValuesAndReturnInsts(LivenessFilter);
+  return AARetVal->checkForAllReturnedValuesAndReturnInsts(
+      [&](Value &RV, const SmallPtrSetImpl<ReturnInst *> &) {
+        return Pred(RV); // The return instructions are not needed here.
+      });
 }
 
 bool Attributor::checkForAllInstructions(