[ARM][ParallelDSP] Change smlad insertion order

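Instead of inserting everything after the 'root' of the reduction,
insert all instructions as close to their operands as possible. This
can help reduce register pressure.

While here, use OrderedBasicBlock for the ordering queries made when
recording load pairs, and add a hidden option,
-arm-parallel-dsp-load-limit (default 16), which limits the number of
loads analysed.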

Differential Revision: https://reviews.llvm.org/D67392

llvm-svn: 374981
diff --git a/llvm/lib/Target/ARM/ARMParallelDSP.cpp b/llvm/lib/Target/ARM/ARMParallelDSP.cpp
index 8bda733..ae5657a 100644
--- a/llvm/lib/Target/ARM/ARMParallelDSP.cpp
+++ b/llvm/lib/Target/ARM/ARMParallelDSP.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/OrderedBasicBlock.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/NoFolder.h"
 #include "llvm/Transforms/Scalar.h"
@@ -42,6 +43,10 @@
 DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                    cl::desc("Disable the ARM Parallel DSP pass"));
 
+static cl::opt<unsigned>
+NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
+             cl::desc("Limit the number of loads analysed"));
+
 namespace {
   struct MulCandidate;
   class Reduction;
@@ -346,6 +351,7 @@
   SmallVector<Instruction*, 8> Writes;
   LoadPairs.clear();
   WideLoads.clear();
+  OrderedBasicBlock OrderedBB(BB);
 
   // Collect loads and instruction that may write to memory. For now we only
   // record loads which are simple, sign-extended and have a single user.
@@ -360,21 +366,24 @@
     Loads.push_back(Ld);
   }
 
+  if (Loads.empty() || Loads.size() > NumLoadLimit)
+    return false;
+
   using InstSet = std::set<Instruction*>;
   using DepMap = std::map<Instruction*, InstSet>;
   DepMap RAWDeps;
 
   // Record any writes that may alias a load.
   const auto Size = LocationSize::unknown();
-  for (auto Read : Loads) {
-    for (auto Write : Writes) {
+  for (auto Write : Writes) {
+    for (auto Read : Loads) {
       MemoryLocation ReadLoc =
         MemoryLocation(Read->getPointerOperand(), Size);
 
       if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
           ModRefInfo::ModRef)))
         continue;
-      if (DT->dominates(Write, Read))
+      if (OrderedBB.dominates(Write, Read))
         RAWDeps[Read].insert(Write);
     }
   }
@@ -382,8 +391,8 @@
   // Check whether there's not a write between the two loads which would
   // prevent them from being safely merged.
   auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
-    LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
-    LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;
+    LoadInst *Dominator = OrderedBB.dominates(Base, Offset) ? Base : Offset;
+    LoadInst *Dominated = OrderedBB.dominates(Base, Offset) ? Offset : Base;
 
     if (RAWDeps.count(Dominated)) {
       InstSet &WritesBefore = RAWDeps[Dominated];
@@ -391,7 +400,7 @@
       for (auto Before : WritesBefore) {
         // We can't move the second load backward, past a write, to merge
         // with the first load.
-        if (DT->dominates(Dominator, Before))
+        if (OrderedBB.dominates(Dominator, Before))
           return false;
       }
     }
@@ -401,7 +410,7 @@
   // Record base, offset load pairs.
   for (auto *Base : Loads) {
     for (auto *Offset : Loads) {
-      if (Base == Offset)
+      if (Base == Offset || OffsetLoads.count(Offset))
         continue;
 
       if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
@@ -613,7 +622,6 @@
   return !R.getMulPairs().empty();
 }
 
-
 void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
 
   auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
@@ -633,30 +641,45 @@
         Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
 
     IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
-                                ++BasicBlock::iterator(InsertAfter));
+                                BasicBlock::iterator(InsertAfter));
     Instruction *Call = Builder.CreateCall(SMLAD, Args);
     NumSMLAD++;
     return Call;
   };
 
-  Instruction *InsertAfter = R.getRoot();
+  // Return the instruction after the dominated (or only) instruction of A, B.
+  auto GetInsertPoint = [this](Value *A, Value *B) {
+    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
+           "expected at least one instruction");
+
+    Value *V = nullptr;
+    if (!isa<Instruction>(A))
+      V = B;
+    else if (!isa<Instruction>(B))
+      V = A;
+    else
+      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;
+
+    return &*++BasicBlock::iterator(cast<Instruction>(V));
+  };
+
   Value *Acc = R.getAccumulator();
 
   // For any muls that were discovered but not paired, accumulate their values
   // as before.
-  IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
-                              ++BasicBlock::iterator(InsertAfter));
+  IRBuilder<NoFolder> Builder(R.getRoot()->getParent());
   MulCandList &MulCands = R.getMuls();
   for (auto &MulCand : MulCands) {
     if (MulCand->Paired)
       continue;
 
-    Value *Mul = MulCand->Root;
+    Instruction *Mul = cast<Instruction>(MulCand->Root);
     LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");
 
     if (R.getType() != Mul->getType()) {
       assert(R.is64Bit() && "expected 64-bit result");
-      Mul = Builder.CreateSExt(Mul, R.getType());
+      Builder.SetInsertPoint(&*++BasicBlock::iterator(Mul));
+      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
     }
 
     if (!Acc) {
@@ -664,8 +687,11 @@
       continue;
     }
 
+    // If Acc is the original incoming value to the reduction, it could be a
+    // phi. But the phi will dominate Mul, meaning that the insertion point
+    // will be immediately after Mul.
+    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
     Acc = Builder.CreateAdd(Mul, Acc);
-    InsertAfter = cast<Instruction>(Acc);
   }
 
   if (!Acc) {
@@ -677,6 +703,14 @@
     Acc = Builder.CreateSExt(Acc, R.getType());
   }
 
+  // Roughly sort the mul pairs in their program order.
+  OrderedBasicBlock OrderedBB(R.getRoot()->getParent());
+  llvm::sort(R.getMulPairs(), [&OrderedBB](auto &PairA, auto &PairB) {
+               const Instruction *A = PairA.first->Root;
+               const Instruction *B = PairB.first->Root;
+               return OrderedBB.dominates(A, B);
+             });
+
   IntegerType *Ty = IntegerType::get(M->getContext(), 32);
   for (auto &Pair : R.getMulPairs()) {
     MulCandidate *LHSMul = Pair.first;
@@ -688,8 +722,9 @@
     LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
       WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);
 
+    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
+    InsertAfter = GetInsertPoint(InsertAfter, Acc);
     Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
-    InsertAfter = cast<Instruction>(Acc);
   }
   R.UpdateRoot(cast<Instruction>(Acc));
 }