[ARM][ParallelDSP] Relax alias checks

When deciding the safety of generating smlad, we checked for any
writes within the block that may alias with any of the loads that
need to be widened. This is overly conservative because aliasing only
matters when a potentially aliasing write lies between a pair of loads
that are to be merged.

Now we check for aliasing writes only once, during setup. If two
loads are found to have an aliasing write between them, we don't add
these loads to LoadPairs. This means that later during the transform,
we can safely widen a pair without worrying about aliasing.

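However, to maintain correctness, we also need to change the way that
wide loads are inserted because the order is now important: aliasing
writes may now exist elsewhere in the block, so the wide load is created
at the position of the earlier (dominating) load of the pair rather than
where the smlad call is inserted.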

The MatchSMLAD method has also been changed, absorbing MatchReductions
(now the FindReductions lambda), while AddMACCandidate has been inlined
into MatchParallelMACSequences, to hopefully improve readability.

Differential Revision: https://reviews.llvm.org/D6102

llvm-svn: 360567
diff --git a/llvm/lib/Target/ARM/ARMParallelDSP.cpp b/llvm/lib/Target/ARM/ARMParallelDSP.cpp
index 9017537..beb44fb 100644
--- a/llvm/lib/Target/ARM/ARMParallelDSP.cpp
+++ b/llvm/lib/Target/ARM/ARMParallelDSP.cpp
@@ -63,21 +63,16 @@
     Instruction   *Root;
     ValueList     AllValues;
     MemInstList   VecLd;    // List of all load instructions.
-    MemLocList    MemLocs;  // All memory locations read by this tree.
+    MemInstList   Loads;
     bool          ReadOnly = true;
 
     OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
     virtual ~OpChain() = default;
 
-    void SetMemoryLocations() {
-      const auto Size = LocationSize::unknown();
+    void PopulateLoads() {
       for (auto *V : AllValues) {
-        if (auto *I = dyn_cast<Instruction>(V)) {
-          if (I->mayWriteToMemory())
-            ReadOnly = false;
-          if (auto *Ld = dyn_cast<LoadInst>(V))
-            MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
-        }
+        if (auto *Ld = dyn_cast<LoadInst>(V))
+          Loads.push_back(Ld);
       }
     }
 
@@ -140,12 +135,11 @@
     std::map<LoadInst*, LoadInst*> LoadPairs;
     std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;
 
-    bool RecordSequentialLoads(BasicBlock *BB);
+    bool RecordMemoryOps(BasicBlock *BB);
     bool InsertParallelMACs(Reduction &Reduction);
     bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
-    LoadInst* CreateLoadIns(IRBuilder<NoFolder> &IRB,
-                            SmallVectorImpl<LoadInst*> &Loads,
-                            IntegerType *LoadTy);
+    LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
+                             IntegerType *LoadTy);
     void CreateParallelMACPairs(Reduction &R);
     Instruction *CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
                                  SmallVectorImpl<LoadInst*> &VecLd1,
@@ -164,6 +158,12 @@
 
     ARMParallelDSP() : LoopPass(ID) { }
 
+    bool doInitialization(Loop *L, LPPassManager &LPM) override {
+      LoadPairs.clear();
+      WideLoads.clear();
+      return true;
+    }
+
     void getAnalysisUsage(AnalysisUsage &AU) const override {
       LoopPass::getAnalysisUsage(AU);
       AU.addRequired<AssumptionCacheTracker>();
@@ -228,7 +228,7 @@
 
       if (!ST->isLittle()) {
         LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
-                             "ARMParallelDSP\n");
+                          << "ARMParallelDSP\n");
         return false;
       }
 
@@ -237,7 +237,7 @@
       LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
       LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
 
-      if (!RecordSequentialLoads(Header)) {
+      if (!RecordMemoryOps(Header)) {
         LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
         return false;
       }
@@ -314,11 +314,18 @@
   return true;
 }
 
-/// Iterate through the block and record base, offset pairs of loads as well as
-/// maximal sequences of sequential loads.
-bool ARMParallelDSP::RecordSequentialLoads(BasicBlock *BB) {
+/// Iterate through the block and record base, offset pairs of loads which can
+/// be widened into a single load.
+bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
   SmallVector<LoadInst*, 8> Loads;
+  SmallVector<Instruction*, 8> Writes;
+
+  // Collect loads and instructions that may write to memory. For now we only
+  // record loads which are simple, sign-extended and have a single user.
+  // TODO: Allow zero-extended loads.
   for (auto &I : *BB) {
+    if (I.mayWriteToMemory())
+      Writes.push_back(&I);
     auto *Ld = dyn_cast<LoadInst>(&I);
     if (!Ld || !Ld->isSimple() ||
         !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
@@ -326,13 +333,54 @@
     Loads.push_back(Ld);
   }
 
-  for (auto *Ld0 : Loads) {
-    for (auto *Ld1 : Loads) {
-      if (Ld0 == Ld1)
+  using InstSet = std::set<Instruction*>;
+  using DepMap = std::map<Instruction*, InstSet>;
+  DepMap RAWDeps;
+
+  // Record any writes that may alias a load.
+  const auto Size = LocationSize::unknown();
+  for (auto Read : Loads) {
+    for (auto Write : Writes) {
+      MemoryLocation ReadLoc =
+        MemoryLocation(Read->getPointerOperand(), Size);
+
+      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
+          ModRefInfo::ModRef)))
+        continue;
+      if (DT->dominates(Write, Read))
+        RAWDeps[Read].insert(Write);
+    }
+  }
+
+  // Check that there is no aliasing write between the two loads which would
+  // prevent them from being safely merged.
+  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
+    LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
+    LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;
+
+    if (RAWDeps.count(Dominated)) {
+      InstSet &WritesBefore = RAWDeps[Dominated];
+
+      for (auto Before : WritesBefore) {
+
+        // We can't move the second load backward, past a write, to merge
+        // with the first load.
+        if (DT->dominates(Dominator, Before))
+          return false;
+      }
+    }
+    return true;
+  };
+
+  // Record base, offset load pairs.
+  for (auto *Base : Loads) {
+    for (auto *Offset : Loads) {
+      if (Base == Offset)
         continue;
 
-      if (AreSequentialAccesses<LoadInst>(Ld0, Ld1, *DL, *SE)) {
-        LoadPairs[Ld0] = Ld1;
+      if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
+          SafeToPair(Base, Offset)) {
+        LoadPairs[Base] = Offset;
         break;
       }
     }
@@ -442,9 +490,9 @@
   for (auto &Pair : Reduction.PMACPairs) {
     BinOpChain *PMul0 = Pair.first;
     BinOpChain *PMul1 = Pair.second;
-    LLVM_DEBUG(dbgs() << "Found parallel MACs!!\n";
-               dbgs() << "- "; PMul0->Root->dump();
-               dbgs() << "- "; PMul1->Root->dump());
+    LLVM_DEBUG(dbgs() << "Found parallel MACs:\n"
+               << "- " << *PMul0->Root << "\n"
+               << "- " << *PMul1->Root << "\n");
 
     Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
                           InsertAfter);
@@ -459,54 +507,6 @@
   return false;
 }
 
-static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
-                            ReductionList &Reductions) {
-  RecurrenceDescriptor RecDesc;
-  const bool HasFnNoNaNAttr =
-    F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
-  const BasicBlock *Latch = TheLoop->getLoopLatch();
-
-  for (PHINode &Phi : Header->phis()) {
-    const auto *Ty = Phi.getType();
-    if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
-      continue;
-
-    const bool IsReduction =
-      RecurrenceDescriptor::AddReductionVar(&Phi,
-                                            RecurrenceDescriptor::RK_IntegerAdd,
-                                            TheLoop, HasFnNoNaNAttr, RecDesc);
-    if (!IsReduction)
-      continue;
-
-    Instruction *Acc = dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
-    if (!Acc)
-      continue;
-
-    Reductions.push_back(Reduction(&Phi, Acc));
-  }
-
-  LLVM_DEBUG(
-    dbgs() << "\nAccumulating integer additions (reductions) found:\n";
-    for (auto &R : Reductions) {
-      dbgs() << "-  "; R.Phi->dump();
-      dbgs() << "-> "; R.AccIntAdd->dump();
-    }
-  );
-}
-
-static void AddMACCandidate(OpChainList &Candidates,
-                            Instruction *Mul,
-                            Value *MulOp0, Value *MulOp1) {
-  assert(Mul->getOpcode() == Instruction::Mul &&
-         "expected mul instruction");
-  ValueList LHS;
-  ValueList RHS;
-  if (IsNarrowSequence<16>(MulOp0, LHS) &&
-      IsNarrowSequence<16>(MulOp1, RHS)) {
-    Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
-  }
-}
-
 static void MatchParallelMACSequences(Reduction &R,
                                       OpChainList &Candidates) {
   Instruction *Acc = R.AccIntAdd;
@@ -528,8 +528,14 @@
     case Instruction::Mul: {
       Value *MulOp0 = I->getOperand(0);
       Value *MulOp1 = I->getOperand(1);
-      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1))
-        AddMACCandidate(Candidates, I, MulOp0, MulOp1);
+      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
+        ValueList LHS;
+        ValueList RHS;
+        if (IsNarrowSequence<16>(MulOp0, LHS) &&
+            IsNarrowSequence<16>(MulOp1, RHS)) {
+          Candidates.push_back(make_unique<BinOpChain>(I, LHS, RHS));
+        }
+      }
       return false;
     }
     case Instruction::SExt:
@@ -543,52 +549,6 @@
              << Candidates.size() << " candidates.\n");
 }
 
-// Collects all instructions that are not part of the MAC chains, which is the
-// set of instructions that can potentially alias with the MAC operands.
-static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
-                            Instructions &Writes) {
-  for (auto &I : *Header) {
-    if (I.mayReadFromMemory())
-      Reads.push_back(&I);
-    if (I.mayWriteToMemory())
-      Writes.push_back(&I);
-  }
-}
-
-// Check whether statements in the basic block that write to memory alias with
-// the memory locations accessed by the MAC-chains.
-// TODO: we need the read statements when we accept more complicated chains.
-static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
-                       Instructions &Writes, OpChainList &MACCandidates) {
-  LLVM_DEBUG(dbgs() << "Alias checks:\n");
-  for (auto &MAC : MACCandidates) {
-    LLVM_DEBUG(dbgs() << "mul: "; MAC->Root->dump());
-
-    // At the moment, we allow only simple chains that only consist of reads,
-    // accumulate their result with an integer add, and thus that don't write
-    // memory, and simply bail if they do.
-    if (!MAC->ReadOnly)
-      return true;
-
-    // Now for all writes in the basic block, check that they don't alias with
-    // the memory locations accessed by our MAC-chain:
-    for (auto *I : Writes) {
-      LLVM_DEBUG(dbgs() << "- "; I->dump());
-      assert(MAC->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
-      for (auto &MemLoc : MAC->MemLocs) {
-        if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
-                                          ModRefInfo::ModRef))) {
-          LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
-          return true;
-        }
-      }
-    }
-  }
-
-  LLVM_DEBUG(dbgs() << "OK: no aliases found!\n");
-  return false;
-}
-
 static bool CheckMACMemory(OpChainList &Candidates) {
   for (auto &C : Candidates) {
     // A mul has 2 operands, and a narrow op consist of sext and a load; thus
@@ -597,7 +557,7 @@
       LLVM_DEBUG(dbgs() << "Operand list too short.\n");
       return false;
     }
-    C->SetMemoryLocations();
+    C->PopulateLoads();
     ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
     ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;
 
@@ -643,14 +603,36 @@
 // before the loop begins.
 //
 bool ARMParallelDSP::MatchSMLAD(Function &F) {
-  BasicBlock *Header = L->getHeader();
-  LLVM_DEBUG(dbgs() << "= Matching SMLAD =\n";
-             dbgs() << "Header block:\n"; Header->dump();
-             dbgs() << "Loop info:\n\n"; L->dump());
 
-  bool Changed = false;
+  auto FindReductions = [&](ReductionList &Reductions) {
+    RecurrenceDescriptor RecDesc;
+    const bool HasFnNoNaNAttr =
+      F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
+    BasicBlock *Latch = L->getLoopLatch();
+
+    for (PHINode &Phi : Latch->phis()) {
+      const auto *Ty = Phi.getType();
+      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
+        continue;
+
+      const bool IsReduction = RecurrenceDescriptor::AddReductionVar(
+        &Phi, RecurrenceDescriptor::RK_IntegerAdd, L, HasFnNoNaNAttr, RecDesc);
+
+      if (!IsReduction)
+        continue;
+
+      Instruction *Acc = dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
+      if (!Acc)
+        continue;
+
+      Reductions.push_back(Reduction(&Phi, Acc));
+    }
+    return !Reductions.empty();
+  };
+
   ReductionList Reductions;
-  MatchReductions(F, L, Header, Reductions);
+  if (!FindReductions(Reductions))
+    return false;
 
   for (auto &R : Reductions) {
     OpChainList MACCandidates;
@@ -666,72 +648,79 @@
       dbgs() << "\n";);
   }
 
-  // Collect all instructions that may read or write memory. Our alias
-  // analysis checks bail out if any of these instructions aliases with an
-  // instruction from the MAC-chain.
-  Instructions Reads, Writes;
-  AliasCandidates(Header, Reads, Writes);
-
+  bool Changed = false;
+  // Pair up each reduction's MAC candidates and insert the parallel MACs.
+  // Aliasing writes between paired loads were filtered out in RecordMemoryOps.
   for (auto &R : Reductions) {
-    if (AreAliased(AA, Reads, Writes, R.MACCandidates))
-      return false;
     CreateParallelMACPairs(R);
     Changed |= InsertParallelMACs(R);
   }
 
-  LLVM_DEBUG(if (Changed) dbgs() << "Header block:\n"; Header->dump(););
   return Changed;
 }
 
-LoadInst* ARMParallelDSP::CreateLoadIns(IRBuilder<NoFolder> &IRB,
-                                        SmallVectorImpl<LoadInst*> &Loads,
-                                        IntegerType *LoadTy) {
+LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
+                                         IntegerType *LoadTy) {
   assert(Loads.size() == 2 && "currently only support widening two loads");
- 
-  const unsigned AddrSpace = Loads[0]->getPointerAddressSpace();
-  Value *VecPtr = IRB.CreateBitCast(Loads[0]->getPointerOperand(),
+
+  LoadInst *Base = Loads[0];
+  LoadInst *Offset = Loads[1];
+
+  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
+  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());
+
+  assert((BaseSExt && OffsetSExt)
+         && "Loads should have a single, extending, user");
+
+  std::function<void(Value*, Value*)> MoveBefore =
+    [&](Value *A, Value *B) -> void {
+      if (!isa<Instruction>(A) || !isa<Instruction>(B))
+        return;
+
+      auto *Source = cast<Instruction>(A);
+      auto *Sink = cast<Instruction>(B);
+
+      if (DT->dominates(Source, Sink) ||
+          Source->getParent() != Sink->getParent() ||
+          isa<PHINode>(Source) || isa<PHINode>(Sink))
+        return;
+
+      Source->moveBefore(Sink);
+      for (auto &U : Source->uses())
+        MoveBefore(Source, U.getUser());
+    };
+
+  // Insert the load at the point of the original dominating load.
+  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
+  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
+                          ++BasicBlock::iterator(DomLoad));
+
+  // Bitcast the pointer to a wider type and create the wide load, while making
+  // sure to maintain the original alignment as this prevents ldrd from being
+  // generated when it could be illegal due to memory alignment.
+  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
+  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                     LoadTy->getPointerTo(AddrSpace));
   LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
-                                             Loads[0]->getAlignment());
-  // Fix up users, Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
-  Instruction *SExt0 = dyn_cast<SExtInst>(Loads[0]->user_back());
-  Instruction *SExt1 = dyn_cast<SExtInst>(Loads[1]->user_back());
+                                             Base->getAlignment());
 
-  assert((Loads[0]->hasOneUse() && Loads[1]->hasOneUse() && SExt0 && SExt1) &&
-         "Loads should have a single, extending, user");
-
-  std::function<void(Instruction*, Instruction*)> MoveAfter =
-    [&](Instruction* Source, Instruction* Sink) -> void {
-    if (DT->dominates(Source, Sink) ||
-        Source->getParent() != Sink->getParent() ||
-        isa<PHINode>(Source) || isa<PHINode>(Sink))
-      return;
-
-    Sink->moveAfter(Source);
-    for (auto &U : Sink->uses())
-      MoveAfter(Sink, cast<Instruction>(U.getUser()));
-  };
+  // Make sure everything is in the correct order in the basic block.
+  MoveBefore(Base->getPointerOperand(), VecPtr);
+  MoveBefore(VecPtr, WideLoad);
 
   // From the wide load, create two values that equal the original two loads.
-  Value *Bottom = IRB.CreateTrunc(WideLoad, Loads[0]->getType());
-  SExt0->setOperand(0, Bottom);
-  if (auto *I = dyn_cast<Instruction>(Bottom)) {
-    I->moveAfter(WideLoad);
-    MoveAfter(I, SExt0);
-  }
+  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
+  // TODO: Support big-endian as well.
+  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
+  BaseSExt->setOperand(0, Bottom);
 
-  IntegerType *Ld1Ty = cast<IntegerType>(Loads[1]->getType());
-  Value *ShiftVal = ConstantInt::get(LoadTy, Ld1Ty->getBitWidth());
+  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
+  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
   Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
-  if (auto *I = dyn_cast<Instruction>(Top))
-    MoveAfter(WideLoad, I);
+  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
+  OffsetSExt->setOperand(0, Trunc);
 
-  Value *Trunc = IRB.CreateTrunc(Top, Ld1Ty);
-  SExt1->setOperand(0, Trunc);
-  if (auto *I = dyn_cast<Instruction>(Trunc))
-    MoveAfter(I, SExt1);
-
-  WideLoads.emplace(std::make_pair(Loads[0],
+  WideLoads.emplace(std::make_pair(Base,
                                    make_unique<WidenedLoad>(Loads, WideLoad)));
   return WideLoad;
 }
@@ -748,15 +737,13 @@
              << "- " << *Acc << "\n"
              << "- Exchange: " << Exchange << "\n");
 
-  IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
-                              ++BasicBlock::iterator(InsertAfter));
-
   // Replace the reduction chain with an intrinsic call
   IntegerType *Ty = IntegerType::get(M->getContext(), 32);
   LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
-    WideLoads[VecLd0[0]]->getLoad() : CreateLoadIns(Builder, VecLd0, Ty);
+    WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
   LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
-    WideLoads[VecLd1[0]]->getLoad() : CreateLoadIns(Builder, VecLd1, Ty);
+    WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);
+
   Value* Args[] = { WideLd0, WideLd1, Acc };
   Function *SMLAD = nullptr;
   if (Exchange)
@@ -767,6 +754,9 @@
     SMLAD = Acc->getType()->isIntegerTy(32) ?
       Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
       Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
+
+  IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
+                              ++BasicBlock::iterator(InsertAfter));
   CallInst *Call = Builder.CreateCall(SMLAD, Args);
   NumSMLAD++;
   return Call;