[AMDGPU] Make printf lowering faster when there are no printfs

Summary:
Printf lowering unconditionally visited every instruction in the module.
To make it faster in the common case where there are no printfs, look up
the printf function (if any) and iterate over its uses instead, collecting
only the call instructions whose callee is printf.

Reviewers: rampitec, kzhuravl, alex-t, arsenm

Subscribers: jvesely, wdng, nhaehnle, yaxunl, dstuttard, tpr, t-tye, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D68145

llvm-svn: 373433
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPrintfRuntimeBinding.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPrintfRuntimeBinding.cpp
index 261d628..5250bf4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUPrintfRuntimeBinding.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUPrintfRuntimeBinding.cpp
@@ -30,7 +30,6 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/InstVisitor.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
@@ -45,20 +44,13 @@
 
 namespace {
 class LLVM_LIBRARY_VISIBILITY AMDGPUPrintfRuntimeBinding final
-    : public ModulePass,
-      public InstVisitor<AMDGPUPrintfRuntimeBinding> {
+    : public ModulePass {
 
 public:
   static char ID;
 
   explicit AMDGPUPrintfRuntimeBinding();
 
-  void visitCallSite(CallSite CS) {
-    Function *F = CS.getCalledFunction();
-    if (F && F->hasName() && F->getName() == "printf")
-      Printfs.push_back(CS.getInstruction());
-  }
-
 private:
   bool runOnModule(Module &M) override;
   void getConversionSpecifiers(SmallVectorImpl<char> &OpConvSpecifiers,
@@ -80,7 +72,7 @@
 
   const DataLayout *TD;
   const DominatorTree *DT;
-  SmallVector<Value *, 32> Printfs;
+  SmallVector<CallInst *, 32> Printfs;
 };
 } // namespace
 
@@ -162,8 +154,7 @@
   // NB: This is important for this string size to be divizable by 4
   const char NonLiteralStr[4] = "???";
 
-  for (auto P : Printfs) {
-    auto CI = cast<CallInst>(P);
+  for (auto CI : Printfs) {
     unsigned NumOps = CI->getNumArgOperands();
 
     SmallString<16> OpConvSpecifiers;
@@ -564,10 +555,8 @@
   }
 
   // erase the printf calls
-  for (auto P : Printfs) {
-    auto CI = cast<CallInst>(P);
+  for (auto CI : Printfs)
     CI->eraseFromParent();
-  }
 
   Printfs.clear();
   return true;
@@ -578,7 +567,16 @@
   if (TT.getArch() == Triple::r600)
     return false;
 
-  visit(M);
+  auto PrintfFunction = M.getFunction("printf");
+  if (!PrintfFunction)
+    return false;
+
+  for (auto &U : PrintfFunction->uses()) {
+    if (auto *CI = dyn_cast<CallInst>(U.getUser())) {
+      if (CI->isCallee(&U))
+        Printfs.push_back(CI);
+    }
+  }
 
   if (Printfs.empty())
     return false;