Optimizing MCJIT module state tracking

Patch co-developed with Yaron Keren.




git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@193291 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/ExecutionEngine/MCJIT/MCJIT.h b/lib/ExecutionEngine/MCJIT/MCJIT.h
index 0d15b9a..8583a19 100644
--- a/lib/ExecutionEngine/MCJIT/MCJIT.h
+++ b/lib/ExecutionEngine/MCJIT/MCJIT.h
@@ -11,15 +11,15 @@
 #define LLVM_LIB_EXECUTIONENGINE_MCJIT_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
 #include "llvm/ExecutionEngine/ObjectCache.h"
 #include "llvm/ExecutionEngine/ObjectImage.h"
 #include "llvm/ExecutionEngine/RuntimeDyld.h"
-#include "llvm/PassManager.h"
+#include "llvm/IR/Module.h"
 
 namespace llvm {
-
 class MCJIT;
 
 // This is a helper class that the MCJIT execution engine uses for linking
@@ -70,7 +70,7 @@
   OwningPtr<RTDyldMemoryManager> ClientMM;
 };
 
-// About Module states:
+// About Module states: added->loaded->finalized.
 //
 // The purpose of the "added" state is having modules in standby. (added=known
 // but not compiled). The idea is that you can add a module to provide function
@@ -94,27 +94,108 @@
   MCJIT(Module *M, TargetMachine *tm, RTDyldMemoryManager *MemMgr,
         bool AllocateGVsWithCode);
 
-  enum ModuleState {
-    ModuleAdded,
-    ModuleEmitted,
-    ModuleLoading,
-    ModuleLoaded,
-    ModuleFinalizing,
-    ModuleFinalized
-  };
+  typedef llvm::SmallPtrSet<Module *, 4> ModulePtrSet;
 
-  class MCJITModuleState {
+  class OwningModuleContainer { // Owns (deletes) its Modules; do not copy.
   public:
-    MCJITModuleState() : State(ModuleAdded) {}
+    OwningModuleContainer() {
+    }
+    ~OwningModuleContainer() { // Deletes every Module still held.
+      freeModulePtrSet(AddedModules);
+      freeModulePtrSet(LoadedModules);
+      freeModulePtrSet(FinalizedModules);
+    }
 
-    MCJITModuleState & operator=(ModuleState s) { State = s; return *this; }
-    bool hasBeenEmitted() { return State != ModuleAdded; }
-    bool hasBeenLoaded() { return State != ModuleAdded &&
-                                  State != ModuleEmitted; }
-    bool hasBeenFinalized() { return State == ModuleFinalized; }
+    ModulePtrSet::iterator begin_added() { return AddedModules.begin(); }
+    ModulePtrSet::iterator end_added() { return AddedModules.end(); }
+
+    ModulePtrSet::iterator begin_loaded() { return LoadedModules.begin(); }
+    ModulePtrSet::iterator end_loaded() { return LoadedModules.end(); }
+
+    ModulePtrSet::iterator begin_finalized() { return FinalizedModules.begin(); }
+    ModulePtrSet::iterator end_finalized() { return FinalizedModules.end(); }
+
+    void addModule(Module *M) {
+      AddedModules.insert(M);
+    }
+
+    bool removeModule(Module *M) { // Forgets M without deleting it.
+      return AddedModules.erase(M) || LoadedModules.erase(M) ||
+             FinalizedModules.erase(M);
+    }
+
+    bool hasModuleBeenAddedButNotLoaded(Module *M) {
+      return AddedModules.count(M) != 0;
+    }
+
+    bool hasModuleBeenLoaded(Module *M) {
+      // A module counts as loaded once it is in either the "loaded" set or
+      // the "finalized" set; modules only reach "finalized" via "loaded".
+      return (LoadedModules.count(M) != 0 ) || (FinalizedModules.count(M) != 0);
+    }
+
+    bool hasModuleBeenFinalized(Module *M) {
+      return FinalizedModules.count(M) != 0;
+    }
+
+    bool ownsModule(Module* M) {
+      return (AddedModules.count(M) != 0) || (LoadedModules.count(M) != 0) ||
+             (FinalizedModules.count(M) != 0);
+    }
+
+    void markModuleAsLoaded(Module *M) {
+      // Guard against logic errors in the MCJIT implementation: M must be in
+      // the "added" set, i.e. it is held by this container and has not yet
+      // been marked as loaded or finalized.
+      assert(AddedModules.count(M) &&
+             "markModuleAsLoaded: Module not found in AddedModules");
+
+      // Move the module out of the "added" set...
+      AddedModules.erase(M);
+
+      // ...and into the "loaded" set.
+      LoadedModules.insert(M);
+    }
+
+    void markModuleAsFinalized(Module *M) {
+      // Guard against logic errors in the MCJIT implementation: M must be in
+      // the "loaded" set. A module this container does not hold, one that is
+      // still only "added", or one that is already "finalized" trips the
+      // assertion below.
+      assert(LoadedModules.count(M) &&
+             "markModuleAsFinalized: Module not found in LoadedModules");
+
+      // Remove the module from the "loaded" set.
+      LoadedModules.erase(M);
+
+      // Add the module to the "finalized" set; membership in that set is
+      // what hasModuleBeenFinalized() reports.
+      FinalizedModules.insert(M);
+    }
+
+    void markAllLoadedModulesAsFinalized() {
+      for (ModulePtrSet::iterator I = LoadedModules.begin(),
+                                  E = LoadedModules.end();
+           I != E; ++I) {
+        Module *M = *I;
+        FinalizedModules.insert(M);
+      }
+      LoadedModules.clear();
+    }
 
   private:
-    ModuleState State;
+    ModulePtrSet AddedModules;
+    ModulePtrSet LoadedModules;
+    ModulePtrSet FinalizedModules;
+
+    void freeModulePtrSet(ModulePtrSet& MPS) {
+      // Delete every Module in the set, then drop the now-dangling pointers.
+      for (ModulePtrSet::iterator I = MPS.begin(), E = MPS.end(); I != E; ++I) {
+        Module *M = *I;
+        delete M;
+      }
+      MPS.clear();
+    }
   };
 
   TargetMachine *TM;
@@ -123,8 +204,7 @@
   RuntimeDyld Dyld;
   SmallVector<JITEventListener*, 2> EventListeners;
 
-  typedef DenseMap<Module *, MCJITModuleState> ModuleStateMap;
-  ModuleStateMap  ModuleStates;
+  OwningModuleContainer OwnedModules;
 
   typedef DenseMap<Module *, ObjectImage *> LoadedObjectMap;
   LoadedObjectMap  LoadedObjects;
@@ -133,12 +213,26 @@
   // perform lookup of pre-compiled code to avoid re-compilation.
   ObjectCache *ObjCache;
 
+  Function *FindFunctionNamedInModulePtrSet(const char *FnName,
+                                            ModulePtrSet::iterator I,
+                                            ModulePtrSet::iterator E);
+
+  void runStaticConstructorsDestructorsInModulePtrSet(bool isDtors,
+                                                      ModulePtrSet::iterator I,
+                                                      ModulePtrSet::iterator E);
+
 public:
   ~MCJIT();
 
   /// @name ExecutionEngine interface implementation
   /// @{
   virtual void addModule(Module *M);
+  virtual bool removeModule(Module *M);
+
+  /// FindFunctionNamed - Search all of the active modules to find the one that
+  /// defines FnName.  This is a very slow operation and shouldn't be used for
+  /// general code.
+  virtual Function *FindFunctionNamed(const char *FnName);
 
   /// Sets the object manager that MCJIT should use to avoid compilation.
   virtual void setObjectCache(ObjectCache *manager);
@@ -158,6 +252,12 @@
   virtual void finalizeModule(Module *);
   void finalizeLoadedModules();
 
+  /// runStaticConstructorsDestructors - This method is used to execute all of
+  /// the static constructors or destructors for a program.
+  ///
+  /// \param isDtors - Run the destructors instead of constructors.
+  void runStaticConstructorsDestructors(bool isDtors);
+
   virtual void *getPointerToBasicBlock(BasicBlock *BB);
 
   virtual void *getPointerToFunction(Function *F);