[ARM][ParallelDSP] Enable multiple uses of loads
    
When choosing whether a pair of loads can be combined into a single
wide load, we check that each load has only a sext user and that the
sext also has only one user. But this can prevent the transformation
in cases where parallel MACs use the same loaded data multiple times.
    
To enable this, we need to fix up any other uses after creating the
wide load: generating a trunc and a shift + trunc pair to recreate
the narrow values. We also need to keep a record of which loads have
already been widened.

Differential Revision: https://reviews.llvm.org/D59215

llvm-svn: 356132
diff --git a/llvm/lib/Target/ARM/ARMParallelDSP.cpp b/llvm/lib/Target/ARM/ARMParallelDSP.cpp
index 9730c32..9b770dd 100644
--- a/llvm/lib/Target/ARM/ARMParallelDSP.cpp
+++ b/llvm/lib/Target/ARM/ARMParallelDSP.cpp
@@ -53,7 +53,7 @@
   using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
   using ReductionList   = SmallVector<Reduction, 8>;
   using ValueList       = SmallVector<Value*, 8>;
-  using MemInstList     = SmallVector<Instruction*, 8>;
+  using MemInstList     = SmallVector<LoadInst*, 8>;
   using PMACPair        = std::pair<BinOpChain*,BinOpChain*>;
   using PMACPairList    = SmallVector<PMACPair, 8>;
   using Instructions    = SmallVector<Instruction*,16>;
@@ -113,6 +113,21 @@
     Reduction (PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { };
   };
 
+  class WidenedLoad {
+    LoadInst *NewLd = nullptr;
+    SmallVector<LoadInst*, 4> Loads;
+
+  public:
+    WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
+      : NewLd(Wide) {
+      for (auto *I : Lds)
+        Loads.push_back(I);
+    }
+    LoadInst *getLoad() {
+      return NewLd;
+    }
+  };
+
   class ARMParallelDSP : public LoopPass {
     ScalarEvolution   *SE;
     AliasAnalysis     *AA;
@@ -123,13 +138,17 @@
     const DataLayout  *DL;
     Module            *M;
     std::map<LoadInst*, LoadInst*> LoadPairs;
-    std::map<LoadInst*, SmallVector<LoadInst*, 4>> SequentialLoads;
+    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;
 
-    bool RecordSequentialLoads(BasicBlock *Header);
+    bool RecordSequentialLoads(BasicBlock *BB);
     bool InsertParallelMACs(Reduction &Reduction);
     bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
+    LoadInst* CreateLoadIns(IRBuilder<NoFolder> &IRB,
+                            SmallVectorImpl<LoadInst*> &Loads,
+                            IntegerType *LoadTy);
     void CreateParallelMACPairs(Reduction &R);
-    Instruction *CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
+    Instruction *CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
+                                 SmallVectorImpl<LoadInst*> &VecLd1,
                                  Instruction *Acc, bool Exchange,
                                  Instruction *InsertAfter);
 
@@ -202,7 +221,6 @@
       }
 
       LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
-      bool Changes = false;
 
       LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
       LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
@@ -212,7 +230,7 @@
         return false;
       }
 
-      Changes = MatchSMLAD(F);
+      bool Changes = MatchSMLAD(F);
       return Changes;
     }
   };
@@ -225,7 +243,6 @@
 // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
 template<unsigned MaxBitWidth>
 static bool IsNarrowSequence(Value *V, ValueList &VL) {
-  LLVM_DEBUG(dbgs() << "Is narrow sequence? "; V->dump());
   ConstantInt *CInt;
 
   if (match(V, m_ConstantInt(CInt))) {
@@ -244,38 +261,25 @@
   } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
     // TODO: we need to implement sadd16/sadd8 for this, which enables to
     // also do the rewrite for smlad8.ll, but it is unsupported for now.
-    LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
     return false;
   } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
-    if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) {
-      LLVM_DEBUG(dbgs() << "No, wrong SrcTy size: " <<
-        cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() << "\n");
+    if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
       return false;
-    }
 
     if (match(Val, m_Load(m_Value()))) {
-      LLVM_DEBUG(dbgs() << "Yes, found narrow Load:\t"; Val->dump());
       VL.push_back(Val);
       VL.push_back(I);
       return true;
     }
   }
-  LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
   return false;
 }
 
 template<typename MemInst>
 static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                   const DataLayout &DL, ScalarEvolution &SE) {
-  if (!MemOp0->isSimple() || !MemOp1->isSimple()) {
-    LLVM_DEBUG(dbgs() << "No, not touching volatile access\n");
-    return false;
-  }
-  if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE)) {
-    LLVM_DEBUG(dbgs() << "OK: accesses are consecutive.\n");
+  if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
     return true;
-  }
-  LLVM_DEBUG(dbgs() << "No, accesses aren't consecutive.\n");
   return false;
 }
 
@@ -284,19 +288,14 @@
   if (!Ld0 || !Ld1)
     return false;
 
-  LLVM_DEBUG(dbgs() << "Are consecutive loads:\n";
+  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
     dbgs() << "Ld0:"; Ld0->dump();
     dbgs() << "Ld1:"; Ld1->dump();
   );
 
-  if (!Ld0->hasOneUse() || !Ld1->hasOneUse()) {
-    LLVM_DEBUG(dbgs() << "No, load has more than one use.\n");
-    return false;
-  }
-
-  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
-    return false;
-
   VecMem.clear();
   VecMem.push_back(Ld0);
   VecMem.push_back(Ld1);
@@ -305,17 +304,16 @@
 
 /// Iterate through the block and record base, offset pairs of loads as well as
 /// maximal sequences of sequential loads.
-bool ARMParallelDSP::RecordSequentialLoads(BasicBlock *Header) {
+bool ARMParallelDSP::RecordSequentialLoads(BasicBlock *BB) {
   SmallVector<LoadInst*, 8> Loads;
-  for (auto &I : *Header) {
+  for (auto &I : *BB) {
     auto *Ld = dyn_cast<LoadInst>(&I);
-    if (!Ld)
+    if (!Ld || !Ld->isSimple() ||
+        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
       continue;
     Loads.push_back(Ld);
   }
 
-  std::map<LoadInst*, LoadInst*> BaseLoads;
-
   for (auto *Ld0 : Loads) {
     for (auto *Ld1 : Loads) {
       if (Ld0 == Ld1)
@@ -323,17 +321,18 @@
 
       if (AreSequentialAccesses<LoadInst>(Ld0, Ld1, *DL, *SE)) {
         LoadPairs[Ld0] = Ld1;
-        if (BaseLoads.count(Ld0)) {
-          LoadInst *Base = BaseLoads[Ld0];
-          BaseLoads[Ld1] = Base;
-          SequentialLoads[Base].push_back(Ld1);
-        } else {
-          BaseLoads[Ld1] = Ld0;
-          SequentialLoads[Ld0].push_back(Ld1);
-        }
+        break;
       }
     }
   }
+
+  LLVM_DEBUG(if (!LoadPairs.empty()) {
+               dbgs() << "Consecutive load pairs:\n";
+               for (auto &MapIt : LoadPairs) {
+                 LLVM_DEBUG(dbgs() << *MapIt.first << ", "
+                            << *MapIt.second << "\n");
+               }
+             });
   return LoadPairs.size() > 1;
 }
 
@@ -362,12 +361,11 @@
       if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
         return false;
 
-      LLVM_DEBUG(dbgs() << "Looking at operands " << x << ":\n"
-                 << "\t Ld0: " << *Ld0 << "\n"
-                 << "\t Ld1: " << *Ld1 << "\n"
-                 << "and operands " << x + 2 << ":\n"
-                 << "\t Ld2: " << *Ld2 << "\n"
-                 << "\t Ld3: " << *Ld3 << "\n");
+      LLVM_DEBUG(dbgs() << "Loads:\n"
+                 << " - " << *Ld0 << "\n"
+                 << " - " << *Ld1 << "\n"
+                 << " - " << *Ld2 << "\n"
+                 << " - " << *Ld3 << "\n");
 
       if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
         if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
@@ -416,11 +414,6 @@
 
       assert(PMul0 != PMul1 && "expected different chains");
 
-      LLVM_DEBUG(dbgs() << "\nCheck parallel muls:\n";
-                 dbgs() << "- "; Mul0->dump();
-                 dbgs() << "- "; Mul1->dump());
-
-      LLVM_DEBUG(dbgs() << "OK: mul operands list match:\n");
       if (CanPair(PMul0, PMul1)) {
         Paired.insert(Mul0);
         Paired.insert(Mul1);
@@ -441,9 +434,8 @@
                dbgs() << "- "; PMul0->Root->dump();
                dbgs() << "- "; PMul1->Root->dump());
 
-    auto *VecLd0 = cast<LoadInst>(PMul0->VecLd[0]);
-    auto *VecLd1 = cast<LoadInst>(PMul1->VecLd[0]);
-    Acc = CreateSMLADCall(VecLd0, VecLd1, Acc, PMul1->Exchange, InsertAfter);
+    Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
+                          InsertAfter);
     InsertAfter = Acc;
   }
 
@@ -499,14 +491,12 @@
 static void AddMACCandidate(OpChainList &Candidates,
                             Instruction *Mul,
                             Value *MulOp0, Value *MulOp1) {
-  LLVM_DEBUG(dbgs() << "OK, found acc mul:\t"; Mul->dump());
   assert(Mul->getOpcode() == Instruction::Mul &&
          "expected mul instruction");
   ValueList LHS;
   ValueList RHS;
   if (IsNarrowSequence<16>(MulOp0, LHS) &&
       IsNarrowSequence<16>(MulOp1, RHS)) {
-    LLVM_DEBUG(dbgs() << "OK, found narrow mul: "; Mul->dump());
     Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
   }
 }
@@ -514,7 +504,7 @@
 static void MatchParallelMACSequences(Reduction &R,
                                       OpChainList &Candidates) {
   Instruction *Acc = R.AccIntAdd;
-  LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc);
+  LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc << "\n");
 
   // Returns false to signal the search should be stopped.
   std::function<bool(Value*)> Match =
@@ -687,32 +677,81 @@
   return Changed;
 }
 
-static LoadInst *CreateLoadIns(IRBuilder<NoFolder> &IRB, LoadInst &BaseLoad,
-                               Type *LoadTy) {
-  const unsigned AddrSpace = BaseLoad.getPointerAddressSpace();
-
-  Value *VecPtr = IRB.CreateBitCast(BaseLoad.getPointerOperand(),
+LoadInst* ARMParallelDSP::CreateLoadIns(IRBuilder<NoFolder> &IRB,
+                                        SmallVectorImpl<LoadInst*> &Loads,
+                                        IntegerType *LoadTy) {
+  assert(Loads.size() == 2 && "currently only support widening two loads");
+ 
+  const unsigned AddrSpace = Loads[0]->getPointerAddressSpace();
+  Value *VecPtr = IRB.CreateBitCast(Loads[0]->getPointerOperand(),
                                     LoadTy->getPointerTo(AddrSpace));
-  return IRB.CreateAlignedLoad(LoadTy, VecPtr, BaseLoad.getAlignment());
+  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
+                                             Loads[0]->getAlignment());
+  // Fix up users, Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
+  Instruction *SExt0 = dyn_cast<SExtInst>(Loads[0]->user_back());
+  Instruction *SExt1 = dyn_cast<SExtInst>(Loads[1]->user_back());
+
+  assert((Loads[0]->hasOneUse() && Loads[1]->hasOneUse() && SExt0 && SExt1) &&
+         "Loads should have a single, extending, user");
+
+  std::function<void(Instruction*, Instruction*)> MoveAfter =
+    [&](Instruction* Source, Instruction* Sink) -> void {
+    if (DT->dominates(Source, Sink) ||
+        Source->getParent() != Sink->getParent() ||
+        isa<PHINode>(Source) || isa<PHINode>(Sink))
+      return;
+
+    Sink->moveAfter(Source);
+    for (auto &U : Sink->uses())
+      MoveAfter(Sink, cast<Instruction>(U.getUser()));
+  };
+
+  // From the wide load, create two values that equal the original two loads.
+  Value *Bottom = IRB.CreateTrunc(WideLoad, Loads[0]->getType());
+  SExt0->setOperand(0, Bottom);
+  if (auto *I = dyn_cast<Instruction>(Bottom)) {
+    I->moveAfter(WideLoad);
+    MoveAfter(I, SExt0);
+  }
+
+  IntegerType *Ld1Ty = cast<IntegerType>(Loads[1]->getType());
+  Value *ShiftVal = ConstantInt::get(LoadTy, Ld1Ty->getBitWidth());
+  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
+  if (auto *I = dyn_cast<Instruction>(Top))
+    MoveAfter(WideLoad, I);
+
+  Value *Trunc = IRB.CreateTrunc(Top, Ld1Ty);
+  SExt1->setOperand(0, Trunc);
+  if (auto *I = dyn_cast<Instruction>(Trunc))
+    MoveAfter(I, SExt1);
+
+  WideLoads.emplace(std::make_pair(Loads[0],
+                                   make_unique<WidenedLoad>(Loads, WideLoad)));
+  return WideLoad;
 }
 
-Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
+Instruction *ARMParallelDSP::CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
+                                             SmallVectorImpl<LoadInst*> &VecLd1,
                                              Instruction *Acc, bool Exchange,
                                              Instruction *InsertAfter) {
   LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
-             << "- " << *VecLd0 << "\n"
-             << "- " << *VecLd1 << "\n"
+             << "- " << *VecLd0[0] << "\n"
+             << "- " << *VecLd0[1] << "\n"
+             << "- " << *VecLd1[0] << "\n"
+             << "- " << *VecLd1[1] << "\n"
              << "- " << *Acc << "\n"
-             << "Exchange: " << Exchange << "\n");
+             << "- Exchange: " << Exchange << "\n");
 
   IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                               ++BasicBlock::iterator(InsertAfter));
 
   // Replace the reduction chain with an intrinsic call
-  Type *Ty = IntegerType::get(M->getContext(), 32);
-  LoadInst *NewLd0 = CreateLoadIns(Builder, VecLd0[0], Ty);
-  LoadInst *NewLd1 = CreateLoadIns(Builder, VecLd1[0], Ty);
-  Value* Args[] = { NewLd0, NewLd1, Acc };
+  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
+  LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
+    WideLoads[VecLd0[0]]->getLoad() : CreateLoadIns(Builder, VecLd0, Ty);
+  LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
+    WideLoads[VecLd1[0]]->getLoad() : CreateLoadIns(Builder, VecLd1, Ty);
+  Value* Args[] = { WideLd0, WideLd1, Acc };
   Function *SMLAD = nullptr;
   if (Exchange)
     SMLAD = Acc->getType()->isIntegerTy(32) ?
@@ -740,7 +779,6 @@
     }
 
     const unsigned Pairs = VL0.size();
-    LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs << "\n");
 
     for (unsigned i = 0; i < Pairs; ++i) {
       const Value *V0 = VL0[i];
@@ -748,24 +786,17 @@
       const auto *Inst0 = dyn_cast<Instruction>(V0);
       const auto *Inst1 = dyn_cast<Instruction>(V1);
 
-      LLVM_DEBUG(dbgs() << "Pair " << i << ":\n";
-                dbgs() << "mul1: "; V0->dump();
-                dbgs() << "mul2: "; V1->dump());
-
       if (!Inst0 || !Inst1)
         return false;
 
-      if (Inst0->isSameOperationAs(Inst1)) {
-        LLVM_DEBUG(dbgs() << "OK: same operation found!\n");
+      if (Inst0->isSameOperationAs(Inst1))
         continue;
-      }
 
       const APInt *C0, *C1;
       if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
         return false;
     }
 
-    LLVM_DEBUG(dbgs() << "OK: found symmetrical operand lists.\n");
     return true;
   };