[Attributor] Manifest simplified (return) values properly

If we simplify the return value of a function, we have to rewrite the
operand of each of its return instructions.
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 1c25969..166c9dd 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -4484,11 +4484,13 @@
     Value &QueryingValueSimplifiedUnwrapped =
         *QueryingValueSimplified.getValue();
 
-    if (isa<UndefValue>(QueryingValueSimplifiedUnwrapped))
-      return true;
-
-    if (AccumulatedSimplifiedValue.hasValue())
+    if (AccumulatedSimplifiedValue.hasValue() &&
+        !isa<UndefValue>(AccumulatedSimplifiedValue.getValue()) &&
+        !isa<UndefValue>(QueryingValueSimplifiedUnwrapped))
       return AccumulatedSimplifiedValue == QueryingValueSimplified;
+    if (AccumulatedSimplifiedValue.hasValue() &&
+        isa<UndefValue>(QueryingValueSimplifiedUnwrapped))
+      return true;
 
     LLVM_DEBUG(dbgs() << "[ValueSimplify] " << QueryingValue
                       << " is assumed to be "
@@ -4522,13 +4524,16 @@
   ChangeStatus manifest(Attributor &A) override {
     ChangeStatus Changed = ChangeStatus::UNCHANGED;
 
-    if (!SimplifiedAssociatedValue.hasValue() ||
+    if (SimplifiedAssociatedValue.hasValue() &&
         !SimplifiedAssociatedValue.getValue())
       return Changed;
 
-    if (auto *C = dyn_cast<Constant>(SimplifiedAssociatedValue.getValue())) {
+    Value &V = getAssociatedValue();
+    auto *C = SimplifiedAssociatedValue.hasValue()
+                  ? dyn_cast<Constant>(SimplifiedAssociatedValue.getValue())
+                  : UndefValue::get(V.getType());
+    if (C) {
       // We can replace the AssociatedValue with the constant.
-      Value &V = getAssociatedValue();
       if (!V.user_empty() && &V != C && V.getType() == C->getType()) {
         LLVM_DEBUG(dbgs() << "[ValueSimplify] " << V << " -> " << *C
                           << " :: " << *this << "\n");
@@ -4638,6 +4643,44 @@
                ? ChangeStatus::UNCHANGED
                : ChangeStatus ::CHANGED;
   }
+
+  ChangeStatus manifest(Attributor &A) override {
+    ChangeStatus Changed = ChangeStatus::UNCHANGED;
+
+    if (SimplifiedAssociatedValue.hasValue() &&
+        !SimplifiedAssociatedValue.getValue())
+      return Changed;
+
+    Value &V = getAssociatedValue();
+    auto *C = SimplifiedAssociatedValue.hasValue()
+                  ? dyn_cast<Constant>(SimplifiedAssociatedValue.getValue())
+                  : UndefValue::get(V.getType());
+    if (C) {
+      auto PredForReturned =
+          [&](Value &V, const SmallSetVector<ReturnInst *, 4> &RetInsts) {
+            // Nothing to replace if V is C, has another type, or is undef.
+            if (&V == C || V.getType() != C->getType() || isa<UndefValue>(V))
+              return true;
+            if (auto *CI = dyn_cast<CallInst>(&V))
+              if (CI->isMustTailCall())
+                return true;
+
+            for (ReturnInst *RI : RetInsts) {
+              if (RI->getFunction() != getAnchorScope())
+                continue;
+              LLVM_DEBUG(dbgs() << "[ValueSimplify] " << V << " -> " << *C
+                                << " in " << *RI << " :: " << *this << "\n");
+              if (A.changeUseAfterManifest(RI->getOperandUse(0), *C))
+                Changed = ChangeStatus::CHANGED;
+            }
+            return true;
+          };
+      A.checkForAllReturnedValuesAndReturnInsts(PredForReturned, *this);
+    }
+
+    return Changed | AAValueSimplify::manifest(A);
+  }
+
   /// See AbstractAttribute::trackStatistics()
   void trackStatistics() const override {
     STATS_DECLTRACK_FNRET_ATTR(value_simplify)
@@ -4728,7 +4771,7 @@
     if (auto *CI = dyn_cast<CallInst>(&getAssociatedValue()))
       if (CI->isMustTailCall())
         return ChangeStatus::UNCHANGED;
-    return AAValueSimplifyReturned::manifest(A);
+    return AAValueSimplifyImpl::manifest(A);
   }
 
   void trackStatistics() const override {