[Orc] Update the Orc indirection utils and refactor the CompileOnDemand layer.

This patch replaces most of the Orc indirection utils API with a new class:
JITCompileCallbackManager, which creates and manages JIT callbacks.
Exposing this functionality directly allows the user to create callbacks that
are associated with user-supplied compilation actions. For example, you can
create a callback to lazily IR-gen something from an AST. (A Kaleidoscope
example demonstrating this will be committed shortly.)

This patch also refactors the CompileOnDemand layer to use the
JITCompileCallbackManager API.

llvm-svn: 229461
diff --git a/llvm/lib/ExecutionEngine/Orc/CloneSubModule.cpp b/llvm/lib/ExecutionEngine/Orc/CloneSubModule.cpp
index 54acb78..64a33c8 100644
--- a/llvm/lib/ExecutionEngine/Orc/CloneSubModule.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/CloneSubModule.cpp
@@ -27,27 +27,20 @@
   }
 }
 
-std::unique_ptr<Module>
-llvm::CloneSubModule(const Module &M,
+void llvm::CloneSubModule(llvm::Module &Dst, const Module &Src,
                      HandleGlobalVariableFtor HandleGlobalVariable,
-                     HandleFunctionFtor HandleFunction, bool KeepInlineAsm) {
+                     HandleFunctionFtor HandleFunction, bool CloneInlineAsm) {
 
   ValueToValueMapTy VMap;
 
-  // First off, we need to create the new module.
-  std::unique_ptr<Module> New =
-      llvm::make_unique<Module>(M.getModuleIdentifier(), M.getContext());
-
-  New->setDataLayout(M.getDataLayout());
-  New->setTargetTriple(M.getTargetTriple());
-  if (KeepInlineAsm)
-    New->setModuleInlineAsm(M.getModuleInlineAsm());
+  if (CloneInlineAsm)
+    Dst.appendModuleInlineAsm(Src.getModuleInlineAsm());
 
   // Copy global variables (but not initializers, yet).
-  for (Module::const_global_iterator I = M.global_begin(), E = M.global_end();
+  for (Module::const_global_iterator I = Src.global_begin(), E = Src.global_end();
        I != E; ++I) {
     GlobalVariable *GV = new GlobalVariable(
-        *New, I->getType()->getElementType(), I->isConstant(), I->getLinkage(),
+        Dst, I->getType()->getElementType(), I->isConstant(), I->getLinkage(),
         (Constant *)nullptr, I->getName(), (GlobalVariable *)nullptr,
         I->getThreadLocalMode(), I->getType()->getAddressSpace());
     GV->copyAttributesFrom(I);
@@ -55,21 +48,21 @@
   }
 
   // Loop over the functions in the module, making external functions as before
-  for (Module::const_iterator I = M.begin(), E = M.end(); I != E; ++I) {
+  for (Module::const_iterator I = Src.begin(), E = Src.end(); I != E; ++I) {
     Function *NF =
         Function::Create(cast<FunctionType>(I->getType()->getElementType()),
-                         I->getLinkage(), I->getName(), &*New);
+                         I->getLinkage(), I->getName(), &Dst);
     NF->copyAttributesFrom(I);
     VMap[I] = NF;
   }
 
   // Loop over the aliases in the module
-  for (Module::const_alias_iterator I = M.alias_begin(), E = M.alias_end();
+  for (Module::const_alias_iterator I = Src.alias_begin(), E = Src.alias_end();
        I != E; ++I) {
     auto *PTy = cast<PointerType>(I->getType());
     auto *GA =
         GlobalAlias::create(PTy->getElementType(), PTy->getAddressSpace(),
-                            I->getLinkage(), I->getName(), &*New);
+                            I->getLinkage(), I->getName(), &Dst);
     GA->copyAttributesFrom(I);
     VMap[I] = GA;
   }
@@ -77,7 +70,7 @@
   // Now that all of the things that global variable initializer can refer to
   // have been created, loop through and copy the global variable referrers
   // over...  We also set the attributes on the global now.
-  for (Module::const_global_iterator I = M.global_begin(), E = M.global_end();
+  for (Module::const_global_iterator I = Src.global_begin(), E = Src.global_end();
        I != E; ++I) {
     GlobalVariable &GV = *cast<GlobalVariable>(VMap[I]);
     HandleGlobalVariable(GV, *I, VMap);
@@ -85,13 +78,13 @@
 
   // Similarly, copy over function bodies now...
   //
-  for (Module::const_iterator I = M.begin(), E = M.end(); I != E; ++I) {
+  for (Module::const_iterator I = Src.begin(), E = Src.end(); I != E; ++I) {
     Function &F = *cast<Function>(VMap[I]);
     HandleFunction(F, *I, VMap);
   }
 
   // And aliases
-  for (Module::const_alias_iterator I = M.alias_begin(), E = M.alias_end();
+  for (Module::const_alias_iterator I = Src.alias_begin(), E = Src.alias_end();
        I != E; ++I) {
     GlobalAlias *GA = cast<GlobalAlias>(VMap[I]);
     if (const Constant *C = I->getAliasee())
@@ -99,14 +92,13 @@
   }
 
   // And named metadata....
-  for (Module::const_named_metadata_iterator I = M.named_metadata_begin(),
-                                             E = M.named_metadata_end();
+  for (Module::const_named_metadata_iterator I = Src.named_metadata_begin(),
+                                             E = Src.named_metadata_end();
        I != E; ++I) {
     const NamedMDNode &NMD = *I;
-    NamedMDNode *NewNMD = New->getOrInsertNamedMetadata(NMD.getName());
+    NamedMDNode *NewNMD = Dst.getOrInsertNamedMetadata(NMD.getName());
     for (unsigned i = 0, e = NMD.getNumOperands(); i != e; ++i)
       NewNMD->addOperand(MapMetadata(NMD.getOperand(i), VMap));
   }
 
-  return New;
 }
diff --git a/llvm/lib/ExecutionEngine/Orc/IndirectionUtils.cpp b/llvm/lib/ExecutionEngine/Orc/IndirectionUtils.cpp
index 2fcfb82..57616a5 100644
--- a/llvm/lib/ExecutionEngine/Orc/IndirectionUtils.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/IndirectionUtils.cpp
@@ -9,149 +9,101 @@
 
 namespace llvm {
 
-JITIndirections makeCallsSingleIndirect(
-    Module &M, const std::function<bool(const Function &)> &ShouldIndirect,
-    const char *JITImplSuffix, const char *JITAddrSuffix) {
-  std::vector<Function *> Worklist;
-  std::vector<std::string> FuncNames;
-
-  for (auto &F : M)
-    if (ShouldIndirect(F) && (F.user_begin() != F.user_end())) {
-      Worklist.push_back(&F);
-      FuncNames.push_back(F.getName());
-    }
-
-  for (auto *F : Worklist) {
-    GlobalVariable *FImplAddr = new GlobalVariable(
-        M, F->getType(), false, GlobalValue::ExternalLinkage,
-        Constant::getNullValue(F->getType()), F->getName() + JITAddrSuffix,
-        nullptr, GlobalValue::NotThreadLocal, 0, true);
-    FImplAddr->setVisibility(GlobalValue::HiddenVisibility);
-
-    for (auto *U : F->users()) {
-      assert(isa<Instruction>(U) && "Cannot indirect non-instruction use");
-      IRBuilder<> Builder(cast<Instruction>(U));
-      U->replaceUsesOfWith(F, Builder.CreateLoad(FImplAddr));
-    }
-  }
-
-  return JITIndirections(
-      FuncNames, [=](StringRef S) -> std::string { return std::string(S); },
-      [=](StringRef S)
-          -> std::string { return std::string(S) + JITAddrSuffix; });
+GlobalVariable* createImplPointer(Function &F, const Twine &Name,
+                                  Constant *Initializer) {
+  assert(F.getParent() && "Function isn't in a module.");
+  if (!Initializer)
+    Initializer = Constant::getNullValue(F.getType());
+  Module &M = *F.getParent();
+  return new GlobalVariable(M, F.getType(), false, GlobalValue::ExternalLinkage,
+                            Initializer, Name, nullptr,
+                            GlobalValue::NotThreadLocal, 0, true);
 }
 
-JITIndirections makeCallsDoubleIndirect(
-    Module &M, const std::function<bool(const Function &)> &ShouldIndirect,
-    const char *JITImplSuffix, const char *JITAddrSuffix) {
-
-  std::vector<Function *> Worklist;
-  std::vector<std::string> FuncNames;
-
-  for (auto &F : M)
-    if (!F.isDeclaration() && !F.hasAvailableExternallyLinkage() &&
-        ShouldIndirect(F))
-      Worklist.push_back(&F);
-
-  for (auto *F : Worklist) {
-    std::string OrigName = F->getName();
-    F->setName(OrigName + JITImplSuffix);
-    FuncNames.push_back(OrigName);
-
-    GlobalVariable *FImplAddr = new GlobalVariable(
-        M, F->getType(), false, GlobalValue::ExternalLinkage,
-        Constant::getNullValue(F->getType()), OrigName + JITAddrSuffix, nullptr,
-        GlobalValue::NotThreadLocal, 0, true);
-    FImplAddr->setVisibility(GlobalValue::HiddenVisibility);
-
-    Function *FRedirect =
-        Function::Create(F->getFunctionType(), F->getLinkage(), OrigName, &M);
-
-    F->replaceAllUsesWith(FRedirect);
-
-    BasicBlock *EntryBlock =
-        BasicBlock::Create(M.getContext(), "entry", FRedirect);
-
-    IRBuilder<> Builder(EntryBlock);
-    LoadInst *FImplLoadedAddr = Builder.CreateLoad(FImplAddr);
-
-    std::vector<Value *> CallArgs;
-    for (Value &Arg : FRedirect->args())
-      CallArgs.push_back(&Arg);
-    CallInst *Call = Builder.CreateCall(FImplLoadedAddr, CallArgs);
-    Call->setTailCall();
-    Builder.CreateRet(Call);
-  }
-
-  return JITIndirections(
-      FuncNames, [=](StringRef S)
-                     -> std::string { return std::string(S) + JITImplSuffix; },
-      [=](StringRef S)
-          -> std::string { return std::string(S) + JITAddrSuffix; });
+void makeStub(Function &F, GlobalVariable &ImplPointer) {
+  assert(F.isDeclaration() && "Can't turn a definition into a stub.");
+  assert(F.getParent() && "Function isn't in a module.");
+  Module &M = *F.getParent();
+  BasicBlock *EntryBlock = BasicBlock::Create(M.getContext(), "entry", &F);
+  IRBuilder<> Builder(EntryBlock);
+  LoadInst *ImplAddr = Builder.CreateLoad(&ImplPointer);
+  std::vector<Value*> CallArgs;
+  for (auto &A : F.args())
+    CallArgs.push_back(&A);
+  CallInst *Call = Builder.CreateCall(ImplAddr, CallArgs);
+  Call->setTailCall();
+  Builder.CreateRet(Call);
 }
 
-std::vector<std::unique_ptr<Module>>
-explode(const Module &OrigMod,
-        const std::function<bool(const Function &)> &ShouldExtract) {
+void partition(Module &M, const ModulePartitionMap &PMap) {
 
-  std::vector<std::unique_ptr<Module>> NewModules;
+  for (auto &KVPair : PMap) {
 
-  // Split all the globals, non-indirected functions, etc. into a single module.
-  auto ExtractGlobalVars = [&](GlobalVariable &New, const GlobalVariable &Orig,
-                               ValueToValueMapTy &VMap) {
-    copyGVInitializer(New, Orig, VMap);
-    if (New.getLinkage() == GlobalValue::PrivateLinkage) {
-      New.setLinkage(GlobalValue::ExternalLinkage);
-      New.setVisibility(GlobalValue::HiddenVisibility);
-    }
-  };
-
-  auto ExtractNonImplFunctions =
-      [&](Function &New, const Function &Orig, ValueToValueMapTy &VMap) {
-        if (!ShouldExtract(New))
-          copyFunctionBody(New, Orig, VMap);
+    auto ExtractGlobalVars =
+      [&](GlobalVariable &New, const GlobalVariable &Orig,
+          ValueToValueMapTy &VMap) {
+        if (KVPair.second.count(&Orig)) {
+          copyGVInitializer(New, Orig, VMap);
+        }
+        if (New.getLinkage() == GlobalValue::PrivateLinkage) {
+          New.setLinkage(GlobalValue::ExternalLinkage);
+          New.setVisibility(GlobalValue::HiddenVisibility);
+        }
       };
 
-  NewModules.push_back(CloneSubModule(OrigMod, ExtractGlobalVars,
-                                      ExtractNonImplFunctions, true));
+    auto ExtractFunctions =
+      [&](Function &New, const Function &Orig, ValueToValueMapTy &VMap) {
+        if (KVPair.second.count(&Orig))
+          copyFunctionBody(New, Orig, VMap);
+        if (New.getLinkage() == GlobalValue::InternalLinkage) {
+          New.setLinkage(GlobalValue::ExternalLinkage);
+          New.setVisibility(GlobalValue::HiddenVisibility);
+        }
+      };
 
-  // Preserve initializers for Common linkage vars, and make private linkage
-  // globals external: they are now provided by the globals module extracted
-  // above.
-  auto DropGlobalVars = [&](GlobalVariable &New, const GlobalVariable &Orig,
-                            ValueToValueMapTy &VMap) {
-    if (New.getLinkage() == GlobalValue::CommonLinkage)
-      copyGVInitializer(New, Orig, VMap);
-    else if (New.getLinkage() == GlobalValue::PrivateLinkage)
-      New.setLinkage(GlobalValue::ExternalLinkage);
-  };
+    CloneSubModule(*KVPair.first, M, ExtractGlobalVars, ExtractFunctions,
+                   false);
+  }
+}
 
-  // Split each 'impl' function out in to its own module.
-  for (const auto &Func : OrigMod) {
-    if (Func.isDeclaration() || !ShouldExtract(Func))
+FullyPartitionedModule fullyPartition(Module &M) {
+  FullyPartitionedModule MP;
+
+  ModulePartitionMap PMap;
+
+  for (auto &F : M) {
+
+    if (F.isDeclaration())
       continue;
 
-    auto ExtractNamedFunction =
-        [&](Function &New, const Function &Orig, ValueToValueMapTy &VMap) {
-          if (New.getName() == Func.getName())
-            copyFunctionBody(New, Orig, VMap);
-        };
-
-    NewModules.push_back(
-        CloneSubModule(OrigMod, DropGlobalVars, ExtractNamedFunction, false));
+    std::string NewModuleName = (M.getName() + "." + F.getName()).str();
+    MP.Functions.push_back(
+      llvm::make_unique<Module>(NewModuleName, M.getContext()));
+    MP.Functions.back()->setDataLayout(M.getDataLayout());
+    PMap[MP.Functions.back().get()].insert(&F);
   }
 
-  return NewModules;
+  MP.GlobalVars =
+    llvm::make_unique<Module>((M.getName() + ".globals_and_stubs").str(),
+                              M.getContext());
+  MP.GlobalVars->setDataLayout(M.getDataLayout());
+
+  MP.Commons =
+    llvm::make_unique<Module>((M.getName() + ".commons").str(), M.getContext());
+  MP.Commons->setDataLayout(M.getDataLayout());
+
+  // Make sure there's at least an empty set for the stubs map or we'll fail
+  // to clone anything for it (including the decls).
+  PMap[MP.GlobalVars.get()] = ModulePartitionMap::mapped_type();
+  for (auto &GV : M.globals())
+    if (GV.getLinkage() == GlobalValue::CommonLinkage)
+      PMap[MP.Commons.get()].insert(&GV);
+    else
+      PMap[MP.GlobalVars.get()].insert(&GV);
+
+  partition(M, PMap);
+
+  return MP;
 }
 
-std::vector<std::unique_ptr<Module>>
-explode(const Module &OrigMod, const JITIndirections &Indirections) {
-  std::set<std::string> ImplNames;
-
-  for (const auto &FuncName : Indirections.IndirectedNames)
-    ImplNames.insert(Indirections.GetImplName(FuncName));
-
-  return explode(
-      OrigMod, [&](const Function &F) { return ImplNames.count(F.getName()); });
-}
 }
diff --git a/llvm/lib/ExecutionEngine/Orc/OrcTargetSupport.cpp b/llvm/lib/ExecutionEngine/Orc/OrcTargetSupport.cpp
index 9f278f4..3f14645 100644
--- a/llvm/lib/ExecutionEngine/Orc/OrcTargetSupport.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/OrcTargetSupport.cpp
@@ -1,14 +1,11 @@
 #include "llvm/ADT/Triple.h"
-#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
+#include "llvm/ExecutionEngine/Orc/OrcTargetSupport.h"
 #include <array>
 
 using namespace llvm;
 
 namespace {
 
-const char *JITCallbackFuncName = "call_jit_for_lazy_compile";
-const char *JITCallbackIndexLabelPrefix = "jit_resolve_";
-
 std::array<const char *, 12> X86GPRsToSave = {{
     "rbp", "rbx", "r12", "r13", "r14", "r15", // Callee saved.
     "rdi", "rsi", "rdx", "rcx", "r8", "r9",   // Int args.
@@ -41,61 +38,90 @@
     OS << "  popq    %" << X86GPRsToSave[X86GPRsToSave.size() - i - 1] << "\n";
 }
 
-uint64_t call_jit_for_fn(JITResolveCallbackHandler *J, uint64_t FuncIdx) {
-  return J->resolve(FuncIdx);
+template <typename TargetT>
+uint64_t executeCompileCallback(JITCompileCallbackManagerBase<TargetT> *JCBM,
+                                TargetAddress CallbackID) {
+  return JCBM->executeCompileCallback(CallbackID);
 }
+
 }
 
 namespace llvm {
 
-std::string getJITResolveCallbackIndexLabel(unsigned I) {
-  std::ostringstream LabelStream;
-  LabelStream << JITCallbackIndexLabelPrefix << I;
-  return LabelStream.str();
-}
+const char* OrcX86_64::ResolverBlockName = "orc_resolver_block";
 
-void insertX86CallbackAsm(Module &M, JITResolveCallbackHandler &J) {
+void OrcX86_64::insertResolverBlock(
+                               Module &M,
+                               JITCompileCallbackManagerBase<OrcX86_64> &JCBM) {
   uint64_t CallbackAddr =
-      static_cast<uint64_t>(reinterpret_cast<uintptr_t>(call_jit_for_fn));
+      static_cast<uint64_t>(
+        reinterpret_cast<uintptr_t>(executeCompileCallback<OrcX86_64>));
 
-  std::ostringstream JITCallbackAsm;
+  std::ostringstream AsmStream;
   Triple TT(M.getTargetTriple());
 
   if (TT.getOS() == Triple::Darwin)
-    JITCallbackAsm << ".section __TEXT,__text,regular,pure_instructions\n"
-                   << ".align 4, 0x90\n";
+    AsmStream << ".section __TEXT,__text,regular,pure_instructions\n"
+              << ".align 4, 0x90\n";
   else
-    JITCallbackAsm << ".text\n"
-                   << ".align 16, 0x90\n";
+    AsmStream << ".text\n"
+              << ".align 16, 0x90\n";
 
-  JITCallbackAsm << "jit_object_addr:\n"
-                 << "  .quad " << &J << "\n" << JITCallbackFuncName << ":\n";
+  AsmStream << "jit_callback_manager_addr:\n"
+            << "  .quad " << &JCBM << "\n"
+            << ResolverBlockName << ":\n";
 
-  uint64_t ReturnAddrOffset = saveX86Regs(JITCallbackAsm);
+  uint64_t ReturnAddrOffset = saveX86Regs(AsmStream);
 
   // Compute index, load object address, and call JIT.
-  JITCallbackAsm << "  movq    " << ReturnAddrOffset << "(%rsp), %rax\n"
-                 << "  leaq    (jit_indices_start+5)(%rip), %rbx\n"
-                 << "  subq    %rbx, %rax\n"
-                 << "  xorq    %rdx, %rdx\n"
-                 << "  movq    $5, %rbx\n"
-                 << "  divq    %rbx\n"
-                 << "  movq    %rax, %rsi\n"
-                 << "  leaq    jit_object_addr(%rip), %rdi\n"
-                 << "  movq    (%rdi), %rdi\n"
-                 << "  movabsq $" << CallbackAddr << ", %rax\n"
-                 << "  callq   *%rax\n"
-                 << "  movq    %rax, " << ReturnAddrOffset << "(%rsp)\n";
+  AsmStream << "  leaq    jit_callback_manager_addr(%rip), %rdi\n"
+            << "  movq    (%rdi), %rdi\n"
+            << "  movq    " << ReturnAddrOffset << "(%rsp), %rsi\n"
+            << "  movabsq $" << CallbackAddr << ", %rax\n"
+            << "  callq   *%rax\n"
+            << "  movq    %rax, " << ReturnAddrOffset << "(%rsp)\n";
 
-  restoreX86Regs(JITCallbackAsm);
+  restoreX86Regs(AsmStream);
 
-  JITCallbackAsm << "  retq\n"
-                 << "jit_indices_start:\n";
+  AsmStream << "  retq\n";
 
-  for (JITResolveCallbackHandler::StubIndex I = 0; I < J.getNumFuncs(); ++I)
-    JITCallbackAsm << getJITResolveCallbackIndexLabel(I) << ":\n"
-                   << "  callq " << JITCallbackFuncName << "\n";
-
-  M.appendModuleInlineAsm(JITCallbackAsm.str());
+  M.appendModuleInlineAsm(AsmStream.str());
 }
+
+OrcX86_64::LabelNameFtor
+OrcX86_64::insertCompileCallbackTrampolines(Module &M,
+                                            TargetAddress ResolverBlockAddr,
+                                            unsigned NumCalls,
+                                            unsigned StartIndex) {
+  const char *ResolverBlockPtrName = "Lorc_resolve_block_addr";
+
+  std::ostringstream AsmStream;
+  Triple TT(M.getTargetTriple());
+
+  if (TT.getOS() == Triple::Darwin)
+    AsmStream << ".section __TEXT,__text,regular,pure_instructions\n"
+              << ".align 4, 0x90\n";
+  else
+    AsmStream << ".text\n"
+              << ".align 16, 0x90\n";
+
+  AsmStream << ResolverBlockPtrName << ":\n"
+            << "  .quad " << ResolverBlockAddr << "\n";
+
+  auto GetLabelName =
+    [=](unsigned I) {
+      std::ostringstream LabelStream;
+      LabelStream << "orc_jcc_" << (StartIndex + I);
+      return LabelStream.str();
+  };
+
+  for (unsigned I = 0; I < NumCalls; ++I)
+    AsmStream << GetLabelName(I) << ":\n"
+              << "  callq *" << ResolverBlockPtrName << "(%rip)\n";
+
+  M.appendModuleInlineAsm(AsmStream.str());
+
+  return GetLabelName;
+}
+
 }