LoopDistribute/LAA: Respect convergent

This case is slightly tricky, because loop distribution should be
allowed in some cases and not in others. As long as runtime dependency
checks don't need to be introduced, distribution should be OK. This is
further complicated by the fact that LoopDistribute partially ignores
whether LAA says that vectorization is safe, and then does its own
runtime pointer legality checks.

Note that this pass still does not handle noduplicate correctly, as
distribution should always be forbidden when noduplicate is present.
I'm not going to bother trying to fix it, as it would require more
effort and I think noduplicate should be removed.

https://reviews.llvm.org/D62607

llvm-svn: 363160
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index d6fbf6f..36bd9a8 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -1778,6 +1778,11 @@
   unsigned NumReads = 0;
   unsigned NumReadWrites = 0;
 
+  bool HasComplexMemInst = false;
+
+  // A runtime check is only legal to insert if there are no convergent calls.
+  HasConvergentOp = false;
+
   PtrRtChecking->Pointers.clear();
   PtrRtChecking->Need = false;
 
@@ -1785,8 +1790,25 @@
 
   // For each block.
   for (BasicBlock *BB : TheLoop->blocks()) {
-    // Scan the BB and collect legal loads and stores.
+    // Scan the BB and collect legal loads and stores. Also detect any
+    // convergent instructions.
     for (Instruction &I : *BB) {
+      if (auto *Call = dyn_cast<CallBase>(&I)) {
+        if (Call->isConvergent())
+          HasConvergentOp = true;
+      }
+
+      // With both a non-vectorizable memory instruction and a convergent
+      // operation found in this loop, there is no reason to keep searching.
+      if (HasComplexMemInst && HasConvergentOp) {
+        CanVecMem = false;
+        return;
+      }
+
+      // Avoid hitting recordAnalysis multiple times.
+      if (HasComplexMemInst)
+        continue;
+
       // If this is a load, save it. If this instruction can read from memory
       // but is not a load, then we quit. Notice that we don't handle function
       // calls that read or write.
@@ -1805,12 +1827,18 @@
           continue;
 
         auto *Ld = dyn_cast<LoadInst>(&I);
-        if (!Ld || (!Ld->isSimple() && !IsAnnotatedParallel)) {
+        if (!Ld) {
+          recordAnalysis("CantVectorizeInstruction", Ld)
+            << "instruction cannot be vectorized";
+          HasComplexMemInst = true;
+          continue;
+        }
+        if (!Ld->isSimple() && !IsAnnotatedParallel) {
           recordAnalysis("NonSimpleLoad", Ld)
               << "read with atomic ordering or volatile read";
           LLVM_DEBUG(dbgs() << "LAA: Found a non-simple load.\n");
-          CanVecMem = false;
-          return;
+          HasComplexMemInst = true;
+          continue;
         }
         NumLoads++;
         Loads.push_back(Ld);
@@ -1826,15 +1854,15 @@
         if (!St) {
           recordAnalysis("CantVectorizeInstruction", St)
               << "instruction cannot be vectorized";
-          CanVecMem = false;
-          return;
+          HasComplexMemInst = true;
+          continue;
         }
         if (!St->isSimple() && !IsAnnotatedParallel) {
           recordAnalysis("NonSimpleStore", St)
               << "write with atomic ordering or volatile write";
           LLVM_DEBUG(dbgs() << "LAA: Found a non-simple store.\n");
-          CanVecMem = false;
-          return;
+          HasComplexMemInst = true;
+          continue;
         }
         NumStores++;
         Stores.push_back(St);
@@ -1845,6 +1873,11 @@
     } // Next instr.
   } // Next block.
 
+  if (HasComplexMemInst) {
+    CanVecMem = false;
+    return;
+  }
+
   // Now we have two lists that hold the loads and the stores.
   // Next, we find the pointers that they use.
 
@@ -1962,7 +1995,7 @@
   }
 
   LLVM_DEBUG(
-      dbgs() << "LAA: We can perform a memory runtime check if needed.\n");
+    dbgs() << "LAA: May be able to perform a memory runtime check if needed.\n");
 
   CanVecMem = true;
   if (Accesses.isDependencyCheckNeeded()) {
@@ -1997,6 +2030,15 @@
     }
   }
 
+  if (HasConvergentOp) {
+    recordAnalysis("CantInsertRuntimeCheckWithConvergent")
+      << "cannot add control dependency to convergent operation";
+    LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because a runtime check "
+                         "would be needed with a convergent operation\n");
+    CanVecMem = false;
+    return;
+  }
+
   if (CanVecMem)
     LLVM_DEBUG(
         dbgs() << "LAA: No unsafe dependent memory operations in loop.  We"
@@ -2285,6 +2327,7 @@
       PtrRtChecking(llvm::make_unique<RuntimePointerChecking>(SE)),
       DepChecker(llvm::make_unique<MemoryDepChecker>(*PSE, L)), TheLoop(L),
       NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1), CanVecMem(false),
+      HasConvergentOp(false),
       HasDependenceInvolvingLoopInvariantAddress(false) {
   if (canAnalyzeLoop())
     analyzeLoop(AA, LI, TLI, DT);
@@ -2301,6 +2344,9 @@
     OS << "\n";
   }
 
+  if (HasConvergentOp)
+    OS.indent(Depth) << "Has convergent operation in loop\n";
+
   if (Report)
     OS.indent(Depth) << "Report: " << Report->getMsg() << "\n";
 
diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
index 4f5344b..f45e5fd 100644
--- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
@@ -766,8 +766,14 @@
                     "cannot isolate unsafe dependencies");
     }
 
-    // Don't distribute the loop if we need too many SCEV run-time checks.
+    // Don't distribute the loop if we need too many SCEV run-time checks, or
+    // any at all if inserting run-time checks is illegal.
     const SCEVUnionPredicate &Pred = LAI->getPSE().getUnionPredicate();
+    if (LAI->hasConvergentOp() && !Pred.isAlwaysTrue()) {
+      return fail("RuntimeCheckWithConvergent",
+                  "may not insert runtime check with convergent operation");
+    }
+
     if (Pred.getComplexity() > (IsForced.getValueOr(false)
                                     ? PragmaDistributeSCEVCheckThreshold
                                     : DistributeSCEVCheckThreshold))
@@ -795,7 +801,14 @@
     auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition,
                                                   RtPtrChecking);
 
+    if (LAI->hasConvergentOp() && !Checks.empty()) {
+      return fail("RuntimeCheckWithConvergent",
+                  "may not insert runtime check with convergent operation");
+    }
+
     if (!Pred.isAlwaysTrue() || !Checks.empty()) {
+      assert(!LAI->hasConvergentOp() && "inserting illegal loop versioning");
+
       MDNode *OrigLoopID = L->getLoopID();
 
       LLVM_DEBUG(dbgs() << "\nPointers:\n");