[PM] Port Scalarizer to the new pass manager.
Patch by: markus (Markus Lavin)
Reviewers: chandlerc, fedor.sergeev
Reviewed By: fedor.sergeev
Subscribers: llvm-commits, Ka-Ka, bjope
Differential Revision: https://reviews.llvm.org/D54695
llvm-svn: 347392
diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp
index 1b140ac..89169c4 100644
--- a/llvm/lib/Transforms/Scalar/Scalar.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalar.cpp
@@ -25,6 +25,7 @@
 #include "llvm/IR/Verifier.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Transforms/Scalar/GVN.h"
+#include "llvm/Transforms/Scalar/Scalarizer.h"
 #include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h"
 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
 
@@ -43,7 +44,7 @@
   initializeDCELegacyPassPass(Registry);
   initializeDeadInstEliminationPass(Registry);
   initializeDivRemPairsLegacyPassPass(Registry);
-  initializeScalarizerPass(Registry);
+  initializeScalarizerLegacyPassPass(Registry);
   initializeDSELegacyPassPass(Registry);
   initializeGuardWideningLegacyPassPass(Registry);
   initializeLoopGuardWideningLegacyPassPass(Registry);
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 2f873ab..5816a52 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -39,6 +39,7 @@
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/Options.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/Scalarizer.h"
 #include <cassert>
 #include <cstdint>
 #include <iterator>
@@ -49,6 +50,13 @@
 
 #define DEBUG_TYPE "scalarizer"
 
+// This is disabled by default because having separate loads and stores
+// makes it more likely that the -combiner-alias-analysis limits will be
+// reached.
+static cl::opt<bool>
+    ScalarizeLoadStore("scalarize-load-store", cl::init(false), cl::Hidden,
+                       cl::desc("Allow the scalarizer pass to scalarize loads and store"));
+
 namespace {
 
 // Used to store the scattered form of a vector.
@@ -152,17 +160,13 @@
   uint64_t ElemSize = 0;
 };
 
-class Scalarizer : public FunctionPass,
-                   public InstVisitor<Scalarizer, bool> {
+class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
 public:
-  static char ID;
-
-  Scalarizer() : FunctionPass(ID) {
-    initializeScalarizerPass(*PassRegistry::getPassRegistry());
+  ScalarizerVisitor(unsigned ParallelLoopAccessMDKind)
+    : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind) {
   }
 
-  bool doInitialization(Module &M) override;
-  bool runOnFunction(Function &F) override;
+  bool visit(Function &F);
 
   // InstVisitor methods.  They return true if the instruction was scalarized,
   // false if nothing changed.
@@ -180,16 +184,6 @@
   bool visitStoreInst(StoreInst &SI);
   bool visitCallInst(CallInst &ICI);
 
-  static void registerOptions() {
-    // This is disabled by default because having separate loads and stores
-    // makes it more likely that the -combiner-alias-analysis limits will be
-    // reached.
-    OptionRegistry::registerOption<bool, Scalarizer,
-                                 &Scalarizer::ScalarizeLoadStore>(
-        "scalarize-load-store",
-        "Allow the scalarizer pass to scalarize loads and store", false);
-  }
-
 private:
   Scatterer scatter(Instruction *Point, Value *V);
   void gather(Instruction *Op, const ValueVector &CV);
@@ -205,16 +199,28 @@
 
   ScatterMap Scattered;
   GatherList Gathered;
+
   unsigned ParallelLoopAccessMDKind;
-  bool ScalarizeLoadStore;
+};
+
+class ScalarizerLegacyPass : public FunctionPass {
+public:
+  static char ID; // Handed to the FunctionPass base class by the constructor.
+
+  ScalarizerLegacyPass() : FunctionPass(ID) {
+    initializeScalarizerLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+  // Runs ScalarizerVisitor on F unless skipFunction(F) opts the function out.
+  bool runOnFunction(Function &F) override;
 };
 
 } // end anonymous namespace
 
-char Scalarizer::ID = 0;
-
-INITIALIZE_PASS_WITH_OPTIONS(Scalarizer, "scalarizer",
-                             "Scalarize vector operations", false, false)
+char ScalarizerLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
+                      "Scalarize vector operations", false, false)
+INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
+                    "Scalarize vector operations", false, false)
 
 Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
                      ValueVector *cachePtr)
@@ -278,17 +284,22 @@
   return CV[I];
 }
 
-bool Scalarizer::doInitialization(Module &M) {
-  ParallelLoopAccessMDKind =
-      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
-  ScalarizeLoadStore =
-      M.getContext().getOption<bool, Scalarizer, &Scalarizer::ScalarizeLoadStore>();
-  return false;
-}
-
-bool Scalarizer::runOnFunction(Function &F) {
+bool ScalarizerLegacyPass::runOnFunction(Function &F) {
   if (skipFunction(F))
     return false;
+  // Look up the metadata kind per run; the visitor takes it as a ctor argument.
+  Module &M = *F.getParent();
+  unsigned ParallelLoopAccessMDKind =
+      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
+  ScalarizerVisitor Impl(ParallelLoopAccessMDKind);
+  return Impl.visit(F);
+}
+
+FunctionPass *llvm::createScalarizerPass() {
+  return new ScalarizerLegacyPass();
+}
+
+bool ScalarizerVisitor::visit(Function &F) {
   assert(Gathered.empty() && Scattered.empty());
 
   // To ensure we replace gathered components correctly we need to do an ordered
@@ -297,7 +308,7 @@
   for (BasicBlock *BB : RPOT) {
     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
       Instruction *I = &*II;
-      bool Done = visit(I);
+      bool Done = InstVisitor::visit(I);
       ++II;
       if (Done && I->getType()->isVoidTy())
         I->eraseFromParent();
@@ -308,7 +319,7 @@
 
 // Return a scattered form of V that can be accessed by Point.  V must be a
 // vector or a pointer to a vector.
-Scatterer Scalarizer::scatter(Instruction *Point, Value *V) {
+Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) {
   if (Argument *VArg = dyn_cast<Argument>(V)) {
     // Put the scattered form of arguments in the entry block,
     // so that it can be used everywhere.
@@ -332,7 +343,7 @@
 // deletion of Op and creation of the gathered form to the end of the pass,
 // so that we can avoid creating the gathered form if all uses of Op are
 // replaced with uses of CV.
-void Scalarizer::gather(Instruction *Op, const ValueVector &CV) {
+void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
   // Since we're not deleting Op yet, stub out its operands, so that it
   // doesn't make anything live unnecessarily.
   for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I)
@@ -361,7 +372,7 @@
 
 // Return true if it is safe to transfer the given metadata tag from
 // vector to scalar instructions.
-bool Scalarizer::canTransferMetadata(unsigned Tag) {
+bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
   return (Tag == LLVMContext::MD_tbaa
           || Tag == LLVMContext::MD_fpmath
           || Tag == LLVMContext::MD_tbaa_struct
@@ -373,7 +384,7 @@
 
 // Transfer metadata from Op to the instructions in CV if it is known
 // to be safe to do so.
-void Scalarizer::transferMetadata(Instruction *Op, const ValueVector &CV) {
+void ScalarizerVisitor::transferMetadata(Instruction *Op, const ValueVector &CV) {
   SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
   Op->getAllMetadataOtherThanDebugLoc(MDs);
   for (unsigned I = 0, E = CV.size(); I != E; ++I) {
@@ -389,7 +400,7 @@
 
 // Try to fill in Layout from Ty, returning true on success.  Alignment is
 // the alignment of the vector, or 0 if the ABI default should be used.
-bool Scalarizer::getVectorLayout(Type *Ty, unsigned Alignment,
+bool ScalarizerVisitor::getVectorLayout(Type *Ty, unsigned Alignment,
                                  VectorLayout &Layout, const DataLayout &DL) {
   // Make sure we're dealing with a vector.
   Layout.VecTy = dyn_cast<VectorType>(Ty);
@@ -413,7 +424,7 @@
 // Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
 // to create an instruction like I with operands X and Y and name Name.
 template<typename Splitter>
-bool Scalarizer::splitBinary(Instruction &I, const Splitter &Split) {
+bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
   VectorType *VT = dyn_cast<VectorType>(I.getType());
   if (!VT)
     return false;
@@ -446,7 +457,7 @@
 
 /// If a call to a vector typed intrinsic function, split into a scalar call per
 /// element if possible for the intrinsic.
-bool Scalarizer::splitCall(CallInst &CI) {
+bool ScalarizerVisitor::splitCall(CallInst &CI) {
   VectorType *VT = dyn_cast<VectorType>(CI.getType());
   if (!VT)
     return false;
@@ -504,7 +515,7 @@
   return true;
 }
 
-bool Scalarizer::visitSelectInst(SelectInst &SI) {
+bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
   VectorType *VT = dyn_cast<VectorType>(SI.getType());
   if (!VT)
     return false;
@@ -534,19 +545,19 @@
   return true;
 }
 
-bool Scalarizer::visitICmpInst(ICmpInst &ICI) {
+bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
   return splitBinary(ICI, ICmpSplitter(ICI));
 }
 
-bool Scalarizer::visitFCmpInst(FCmpInst &FCI) {
+bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
   return splitBinary(FCI, FCmpSplitter(FCI));
 }
 
-bool Scalarizer::visitBinaryOperator(BinaryOperator &BO) {
+bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
   return splitBinary(BO, BinarySplitter(BO));
 }
 
-bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
+bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
   VectorType *VT = dyn_cast<VectorType>(GEPI.getType());
   if (!VT)
     return false;
@@ -592,7 +603,7 @@
   return true;
 }
 
-bool Scalarizer::visitCastInst(CastInst &CI) {
+bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
   VectorType *VT = dyn_cast<VectorType>(CI.getDestTy());
   if (!VT)
     return false;
@@ -610,7 +621,7 @@
   return true;
 }
 
-bool Scalarizer::visitBitCastInst(BitCastInst &BCI) {
+bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
   VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy());
   VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy());
   if (!DstVT || !SrcVT)
@@ -665,7 +676,7 @@
   return true;
 }
 
-bool Scalarizer::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
+bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
   VectorType *VT = dyn_cast<VectorType>(SVI.getType());
   if (!VT)
     return false;
@@ -689,7 +700,7 @@
   return true;
 }
 
-bool Scalarizer::visitPHINode(PHINode &PHI) {
+bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
   VectorType *VT = dyn_cast<VectorType>(PHI.getType());
   if (!VT)
     return false;
@@ -714,7 +725,7 @@
   return true;
 }
 
-bool Scalarizer::visitLoadInst(LoadInst &LI) {
+bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
   if (!ScalarizeLoadStore)
     return false;
   if (!LI.isSimple())
@@ -738,7 +749,7 @@
   return true;
 }
 
-bool Scalarizer::visitStoreInst(StoreInst &SI) {
+bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
   if (!ScalarizeLoadStore)
     return false;
   if (!SI.isSimple())
@@ -765,13 +776,13 @@
   return true;
 }
 
-bool Scalarizer::visitCallInst(CallInst &CI) {
+bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
   return splitCall(CI);
 }
 
 // Delete the instructions that we scalarized.  If a full vector result
 // is still needed, recreate it using InsertElements.
-bool Scalarizer::finish() {
+bool ScalarizerVisitor::finish() {
   // The presence of data in Gathered or Scattered indicates changes
   // made to the Function.
   if (Gathered.empty() && Scattered.empty())
@@ -802,6 +813,11 @@
   return true;
 }
 
-FunctionPass *llvm::createScalarizerPass() {
-  return new Scalarizer();
+PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
+  unsigned ParallelLoopAccessMDKind =
+      F.getContext().getMDKindID("llvm.mem.parallel_loop_access");
+  ScalarizerVisitor Impl(ParallelLoopAccessMDKind);
+  // visit() reports whether F was modified; keep all analyses when it was not.
+  bool Changed = Impl.visit(F);
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
 }