[Attributor] Allow multiple uses of a casted function pointer

If a function pointer is cast to a different type, the resulting
expression can be a constant. If so, it can be used multiple times, which
cannot be handled by the AbstractCallSite constructor alone. Instead, we
now explicitly follow the uses of the cast expression during the call
site traversal.
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 0c806eb..aaa912b 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -4593,21 +4593,24 @@
     bool HasValueBefore = SimplifiedAssociatedValue.hasValue();
 
     auto PredForCallSite = [&](AbstractCallSite ACS) {
-      // Check if we have an associated argument or not (which can happen for
-      // callback calls).
-      Value *ArgOp = ACS.getCallArgOperand(getArgNo());
-      if (!ArgOp)
+      const IRPosition &ACSArgPos =
+          IRPosition::callsite_argument(ACS, getArgNo());
+      // Check if a corresponding argument was found or if it is not
+      // associated (which can happen for callback calls).
+      if (ACSArgPos.getPositionKind() == IRPosition::IRP_INVALID)
         return false;
+
       // We can only propagate thread independent values through callbacks.
       // This is different to direct/indirect call sites because for them we
       // know the thread executing the caller and callee is the same. For
       // callbacks this is not guaranteed, thus a thread dependent value could
       // be different for the caller and callee, making it invalid to propagate.
+      Value &ArgOp = ACSArgPos.getAssociatedValue();
       if (ACS.isCallbackCall())
-        if (auto *C = dyn_cast<Constant>(ArgOp))
+        if (auto *C = dyn_cast<Constant>(&ArgOp))
           if (C->isThreadDependent())
             return false;
-      return checkAndUpdate(A, *this, *ArgOp, SimplifiedAssociatedValue);
+      return checkAndUpdate(A, *this, ArgOp, SimplifiedAssociatedValue);
     };
 
     bool AllCallSitesKnown;
@@ -7289,13 +7292,23 @@
   // If we do not require all call sites we might not see all.
   AllCallSitesKnown = RequireAllCallSites;
 
-  for (const Use &U : Fn.uses()) {
+  SmallVector<const Use *, 8> Uses(make_pointer_range(Fn.uses()));
+  for (unsigned u = 0; u < Uses.size(); ++u) {
+    const Use &U = *Uses[u];
     LLVM_DEBUG(dbgs() << "[Attributor] Check use: " << *U << " in "
                       << *U.getUser() << "\n");
     if (isAssumedDead(U, QueryingAA, nullptr, /* CheckBBLivenessOnly */ true)) {
       LLVM_DEBUG(dbgs() << "[Attributor] Dead use, skip!\n");
       continue;
     }
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U.getUser())) {
+      if (CE->isCast() && CE->getType()->isPointerTy() &&
+          CE->getType()->getPointerElementType()->isFunctionTy()) {
+        for (const Use &CEU : CE->uses())
+          Uses.push_back(&CEU);
+        continue;
+      }
+    }
 
     AbstractCallSite ACS(&U);
     if (!ACS) {