[NewPM] Port Msan

Summary:
Keeping msan a function pass requires replacing the module-level initialization:
instead of defining a ctor function that calls __msan_init, just declare the
init function on first access and add it to the global ctors list.

Changes:
- Pull the actual sanitizer and the wrapper pass apart.
- Add a new pass manager (NewPM) msan pass. The function pass inserts calls to
  runtime library functions and inserts their declarations as necessary.
- Update tests.

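Caveats:
- One test was dropped because it specifically tested the definition of the
  ctor.
- Since no msan.module_ctor function is created anymore, the ClWithComdat
  option no longer places the ctor in a comdat.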

Reviewers: chandlerc, fedor.sergeev, leonardchan, vitalybuka

Subscribers: sdardis, nemanjai, javed.absar, hiraditya, kbarton, bollu, atanasyan, jsji

Differential Revision: https://reviews.llvm.org/D55647

llvm-svn: 350305
diff --git a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp
index eb6a373..5828019 100644
--- a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp
@@ -111,7 +111,7 @@
   initializePGOIndirectCallPromotionLegacyPassPass(Registry);
   initializePGOMemOPSizeOptLegacyPassPass(Registry);
   initializeInstrProfilingLegacyPassPass(Registry);
-  initializeMemorySanitizerPass(Registry);
+  initializeMemorySanitizerLegacyPassPass(Registry);
   initializeHWAddressSanitizerPass(Registry);
   initializeThreadSanitizerPass(Registry);
   initializeSanitizerCoverageModulePass(Registry);
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 1bac44c..493d22a 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -140,6 +140,7 @@
 ///
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Transforms/Instrumentation/MemorySanitizer.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DepthFirstIterator.h"
@@ -149,7 +150,6 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
@@ -187,6 +187,7 @@
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Instrumentation.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 #include <algorithm>
 #include <cassert>
@@ -320,7 +321,6 @@
        cl::desc("Define custom MSan OriginBase"),
        cl::Hidden, cl::init(0));
 
-static const char *const kMsanModuleCtorName = "msan.module_ctor";
 static const char *const kMsanInitName = "__msan_init";
 
 namespace {
@@ -446,19 +446,16 @@
 
 namespace {
 
-/// An instrumentation pass implementing detection of uninitialized
-/// reads.
+/// Instrument functions of a module to detect uninitialized reads.
 ///
-/// MemorySanitizer: instrument the code in module to find
-/// uninitialized reads.
-class MemorySanitizer : public FunctionPass {
+/// Instantiating MemorySanitizer runs module-level setup. For userspace
+/// (non-KMSAN) builds, this declares __msan_init if absent and adds it to the
+/// module's global constructors. Runtime API declarations are inserted later,
+/// via initializeCallbacks(), which each MemorySanitizerVisitor invokes.
+class MemorySanitizer {
 public:
-  // Pass identification, replacement for typeid.
-  static char ID;
-
-  MemorySanitizer(int TrackOrigins = 0, bool Recover = false,
-                  bool EnableKmsan = false)
-      : FunctionPass(ID) {
+  MemorySanitizer(Module &M, int TrackOrigins = 0, bool Recover = false,
+                  bool EnableKmsan = false) {
     this->CompileKernel =
         ClEnableKmsan.getNumOccurrences() > 0 ? ClEnableKmsan : EnableKmsan;
     if (ClTrackOrigins.getNumOccurrences() > 0)
@@ -468,15 +465,16 @@
     this->Recover = ClKeepGoing.getNumOccurrences() > 0
                         ? ClKeepGoing
                         : (this->CompileKernel | Recover);
-  }
-  StringRef getPassName() const override { return "MemorySanitizer"; }
-
-  void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    initializeModule(M);
   }
 
-  bool runOnFunction(Function &F) override;
-  bool doInitialization(Module &M) override;
+  // MSan cannot be moved or copied because of MapParams.
+  MemorySanitizer(MemorySanitizer &&) = delete;
+  MemorySanitizer &operator=(MemorySanitizer &&) = delete;
+  MemorySanitizer(const MemorySanitizer &) = delete;
+  MemorySanitizer &operator=(const MemorySanitizer &) = delete;
+
+  bool sanitizeFunction(Function &F, TargetLibraryInfo &TLI);
 
 private:
   friend struct MemorySanitizerVisitor;
@@ -485,13 +483,13 @@
   friend struct VarArgAArch64Helper;
   friend struct VarArgPowerPC64Helper;
 
+  void initializeModule(Module &M);
   void initializeCallbacks(Module &M);
   void createKernelApi(Module &M);
   void createUserspaceApi(Module &M);
 
   /// True if we're compiling the Linux kernel.
   bool CompileKernel;
-
   /// Track origins (allocation points) of uninitialized values.
   int TrackOrigins;
   bool Recover;
@@ -588,25 +586,61 @@
 
   /// An empty volatile inline asm that prevents callback merge.
   InlineAsm *EmptyAsm;
+};
 
-  Function *MsanCtorFunction;
+/// A legacy function pass for msan instrumentation.
+///
+/// Instruments functions to detect uninitialized reads.
+struct MemorySanitizerLegacyPass : public FunctionPass {
+  // Pass identification, replacement for typeid.
+  static char ID;
+
+  MemorySanitizerLegacyPass(int TrackOrigins = 0, bool Recover = false,
+                            bool EnableKmsan = false)
+      : FunctionPass(ID), TrackOrigins(TrackOrigins), Recover(Recover),
+        EnableKmsan(EnableKmsan) {}
+  StringRef getPassName() const override { return "MemorySanitizerLegacyPass"; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+  }
+
+  bool runOnFunction(Function &F) override {
+    return MSan->sanitizeFunction(
+        F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI());
+  }
+  bool doInitialization(Module &M) override;
+
+  Optional<MemorySanitizer> MSan;
+  int TrackOrigins;
+  bool Recover;
+  bool EnableKmsan;
 };
 
 } // end anonymous namespace
 
-char MemorySanitizer::ID = 0;
+PreservedAnalyses MemorySanitizerPass::run(Function &F,
+                                           FunctionAnalysisManager &FAM) {
+  MemorySanitizer Msan(*F.getParent(), TrackOrigins, Recover, EnableKmsan);
+  if (Msan.sanitizeFunction(F, FAM.getResult<TargetLibraryAnalysis>(F)))
+    return PreservedAnalyses::none();
+  return PreservedAnalyses::all();
+}
 
-INITIALIZE_PASS_BEGIN(
-    MemorySanitizer, "msan",
-    "MemorySanitizer: detects uninitialized reads.", false, false)
+char MemorySanitizerLegacyPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(MemorySanitizerLegacyPass, "msan",
+                      "MemorySanitizer: detects uninitialized reads.", false,
+                      false)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(
-    MemorySanitizer, "msan",
-    "MemorySanitizer: detects uninitialized reads.", false, false)
+INITIALIZE_PASS_END(MemorySanitizerLegacyPass, "msan",
+                    "MemorySanitizer: detects uninitialized reads.", false,
+                    false)
 
-FunctionPass *llvm::createMemorySanitizerPass(int TrackOrigins, bool Recover,
-                                              bool CompileKernel) {
-  return new MemorySanitizer(TrackOrigins, Recover, CompileKernel);
+FunctionPass *llvm::createMemorySanitizerLegacyPassPass(int TrackOrigins,
+                                                        bool Recover,
+                                                        bool CompileKernel) {
+  return new MemorySanitizerLegacyPass(TrackOrigins, Recover, CompileKernel);
 }
 
 /// Create a non-const global initialized with the given string.
@@ -683,6 +717,14 @@
       "__msan_unpoison_alloca", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy);
 }
 
+static Constant *getOrInsertGlobal(Module &M, StringRef Name, Type *Ty) {
+  return M.getOrInsertGlobal(Name, Ty, [&] {
+    return new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage,
+                              nullptr, Name, nullptr,
+                              GlobalVariable::InitialExecTLSModel);
+  });
+}
+
 /// Insert declarations for userspace-specific functions and globals.
 void MemorySanitizer::createUserspaceApi(Module &M) {
   IRBuilder<> IRB(*C);
@@ -694,42 +736,31 @@
   WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy());
 
   // Create the global TLS variables.
-  RetvalTLS = new GlobalVariable(
-      M, ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8), false,
-      GlobalVariable::ExternalLinkage, nullptr, "__msan_retval_tls", nullptr,
-      GlobalVariable::InitialExecTLSModel);
+  RetvalTLS =
+      getOrInsertGlobal(M, "__msan_retval_tls",
+                        ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8));
 
-  RetvalOriginTLS = new GlobalVariable(
-      M, OriginTy, false, GlobalVariable::ExternalLinkage, nullptr,
-      "__msan_retval_origin_tls", nullptr, GlobalVariable::InitialExecTLSModel);
+  RetvalOriginTLS = getOrInsertGlobal(M, "__msan_retval_origin_tls", OriginTy);
 
-  ParamTLS = new GlobalVariable(
-      M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false,
-      GlobalVariable::ExternalLinkage, nullptr, "__msan_param_tls", nullptr,
-      GlobalVariable::InitialExecTLSModel);
+  ParamTLS =
+      getOrInsertGlobal(M, "__msan_param_tls",
+                        ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8));
 
-  ParamOriginTLS = new GlobalVariable(
-      M, ArrayType::get(OriginTy, kParamTLSSize / 4), false,
-      GlobalVariable::ExternalLinkage, nullptr, "__msan_param_origin_tls",
-      nullptr, GlobalVariable::InitialExecTLSModel);
+  ParamOriginTLS =
+      getOrInsertGlobal(M, "__msan_param_origin_tls",
+                        ArrayType::get(OriginTy, kParamTLSSize / 4));
 
-  VAArgTLS = new GlobalVariable(
-      M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false,
-      GlobalVariable::ExternalLinkage, nullptr, "__msan_va_arg_tls", nullptr,
-      GlobalVariable::InitialExecTLSModel);
+  VAArgTLS =
+      getOrInsertGlobal(M, "__msan_va_arg_tls",
+                        ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8));
 
-  VAArgOriginTLS = new GlobalVariable(
-      M, ArrayType::get(OriginTy, kParamTLSSize / 4), false,
-      GlobalVariable::ExternalLinkage, nullptr, "__msan_va_arg_origin_tls",
-      nullptr, GlobalVariable::InitialExecTLSModel);
+  VAArgOriginTLS =
+      getOrInsertGlobal(M, "__msan_va_arg_origin_tls",
+                        ArrayType::get(OriginTy, kParamTLSSize / 4));
 
-  VAArgOverflowSizeTLS = new GlobalVariable(
-      M, IRB.getInt64Ty(), false, GlobalVariable::ExternalLinkage, nullptr,
-      "__msan_va_arg_overflow_size_tls", nullptr,
-      GlobalVariable::InitialExecTLSModel);
-  OriginTLS = new GlobalVariable(
-      M, IRB.getInt32Ty(), false, GlobalVariable::ExternalLinkage, nullptr,
-      "__msan_origin_tls", nullptr, GlobalVariable::InitialExecTLSModel);
+  VAArgOverflowSizeTLS =
+      getOrInsertGlobal(M, "__msan_va_arg_overflow_size_tls", IRB.getInt64Ty());
+  OriginTLS = getOrInsertGlobal(M, "__msan_origin_tls", IRB.getInt32Ty());
 
   for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes;
        AccessSizeIndex++) {
@@ -808,9 +839,7 @@
 }
 
 /// Module-level initialization.
-///
-/// inserts a call to __msan_init to the module's constructor list.
-bool MemorySanitizer::doInitialization(Module &M) {
+void MemorySanitizer::initializeModule(Module &M) {
   auto &DL = M.getDataLayout();
 
   bool ShadowPassed = ClShadowBase.getNumOccurrences() > 0;
@@ -884,27 +913,26 @@
   OriginStoreWeights = MDBuilder(*C).createBranchWeights(1, 1000);
 
   if (!CompileKernel) {
-    std::tie(MsanCtorFunction, std::ignore) =
-        createSanitizerCtorAndInitFunctions(M, kMsanModuleCtorName,
-                                            kMsanInitName,
-                                            /*InitArgTypes=*/{},
-                                            /*InitArgs=*/{});
-    if (ClWithComdat) {
-      Comdat *MsanCtorComdat = M.getOrInsertComdat(kMsanModuleCtorName);
-      MsanCtorFunction->setComdat(MsanCtorComdat);
-      appendToGlobalCtors(M, MsanCtorFunction, 0, MsanCtorFunction);
-    } else {
-      appendToGlobalCtors(M, MsanCtorFunction, 0);
-    }
+    getOrCreateInitFunction(M, kMsanInitName);
 
     if (TrackOrigins)
-      new GlobalVariable(M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage,
-                         IRB.getInt32(TrackOrigins), "__msan_track_origins");
+      M.getOrInsertGlobal("__msan_track_origins", IRB.getInt32Ty(), [&] {
+        return new GlobalVariable(
+            M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage,
+            IRB.getInt32(TrackOrigins), "__msan_track_origins");
+      });
 
     if (Recover)
-      new GlobalVariable(M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage,
-                         IRB.getInt32(Recover), "__msan_keep_going");
-  }
+      M.getOrInsertGlobal("__msan_keep_going", IRB.getInt32Ty(), [&] {
+        return new GlobalVariable(M, IRB.getInt32Ty(), true,
+                                  GlobalValue::WeakODRLinkage,
+                                  IRB.getInt32(Recover), "__msan_keep_going");
+      });
+}
+}
+
+bool MemorySanitizerLegacyPass::doInitialization(Module &M) {
+  MSan.emplace(M, TrackOrigins, Recover, EnableKmsan);
   return true;
 }
 
@@ -985,8 +1013,9 @@
   SmallVector<ShadowOriginAndInsertPoint, 16> InstrumentationList;
   SmallVector<StoreInst *, 16> StoreList;
 
-  MemorySanitizerVisitor(Function &F, MemorySanitizer &MS)
-      : F(F), MS(MS), VAHelper(CreateVarArgHelper(F, MS, *this)) {
+  MemorySanitizerVisitor(Function &F, MemorySanitizer &MS,
+                         const TargetLibraryInfo &TLI)
+      : F(F), MS(MS), VAHelper(CreateVarArgHelper(F, MS, *this)), TLI(&TLI) {
     bool SanitizeFunction = F.hasFnAttribute(Attribute::SanitizeMemory);
     InsertChecks = SanitizeFunction;
     PropagateShadow = SanitizeFunction;
@@ -995,7 +1024,6 @@
     // FIXME: Consider using SpecialCaseList to specify a list of functions that
     // must always return fully initialized values. For now, we hardcode "main".
     CheckReturnValue = SanitizeFunction && (F.getName() == "main");
-    TLI = &MS.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
 
     MS.initializeCallbacks(*F.getParent());
     if (MS.CompileKernel)
@@ -4430,10 +4458,8 @@
     return new VarArgNoOpHelper(Func, Msan, Visitor);
 }
 
-bool MemorySanitizer::runOnFunction(Function &F) {
-  if (!CompileKernel && (&F == MsanCtorFunction))
-    return false;
-  MemorySanitizerVisitor Visitor(F, *this);
+bool MemorySanitizer::sanitizeFunction(Function &F, TargetLibraryInfo &TLI) {
+  MemorySanitizerVisitor Visitor(F, *this, TLI);
 
   // Clear out readonly/readnone attributes.
   AttrBuilder B;
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index ba4b7f3..8040cc7 100644
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -174,6 +174,27 @@
   return std::make_pair(Ctor, InitFunction);
 }
 
+Function *llvm::getOrCreateInitFunction(Module &M, StringRef Name) {
+  assert(!Name.empty() && "Expected init function name");
+  if (Function *F = M.getFunction(Name)) {
+    if (F->arg_size() != 0 ||
+        F->getReturnType() != Type::getVoidTy(M.getContext())) {
+      std::string Err;
+      raw_string_ostream Stream(Err);
+      Stream << "Sanitizer interface function defined with wrong type: " << *F;
+      report_fatal_error(Err);
+    }
+    return F;
+  }
+  Function *F = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
+      Name, AttributeList(), Type::getVoidTy(M.getContext())));
+  F->setLinkage(Function::ExternalLinkage);
+
+  appendToGlobalCtors(M, F, 0);
+
+  return F;
+}
+
 void llvm::filterDeadComdatFunctions(
     Module &M, SmallVectorImpl<Function *> &DeadComdatFunctions) {
   // Build a map from the comdat to the number of entries in that comdat we