[LoopPredication] Add profitability check based on BPI
Summary:
LoopPredication is not profitable when the loop is known to always exit
through some block other than the latch block.
A coarse grained latch check can cause loop predication to predicate the
loop, and unconditionally deoptimize.
However, without predicating the loop, the guard may never fail within the
loop during the dynamic execution because the non-latch loop termination
condition exits the loop before the latch condition causes the loop to
exit.
We teach LP about this using BranchProfileInfo pass.
Reviewers: apilipenko, skatkov, mkazantsev, reames
Reviewed by: skatkov
Subscribers: llvm-commits
Differential Revision: https://reviews.llvm.org/D44667
llvm-svn: 328210
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
index 4d056d0..6102890 100644
--- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -178,6 +178,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/LoopPredication.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
@@ -202,6 +203,20 @@
static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
cl::Hidden, cl::init(true));
+
+static cl::opt<bool>
+ SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
+ cl::Hidden, cl::init(false));
+
+// This is the scale factor for the latch probability. We use this during
+// profitability analysis to find other exiting blocks that have a much higher
+// probability of exiting the loop instead of loop exiting via latch.
+// This value should be greater than 1 for a sane profitability check.
+static cl::opt<float> LatchExitProbabilityScale(
+ "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
+ cl::desc("scale factor for the latch probability. Value should be greater "
+ "than 1. Lower values are ignored"));
+
namespace {
class LoopPredication {
/// Represents an induction variable check:
@@ -221,6 +236,7 @@
};
ScalarEvolution *SE;
+ BranchProbabilityInfo *BPI;
Loop *L;
const DataLayout *DL;
@@ -254,6 +270,12 @@
IRBuilder<> &Builder);
bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
+ // If the loop always exits through another block in the loop, we should not
+ // predicate based on the latch check. For example, the latch check can be a
+ // very coarse grained check and there can be more fine grained exit checks
+ // within the loop. We identify such unprofitable loops through BPI.
+ bool isLoopProfitableToPredicate();
+
// When the IV type is wider than the range operand type, we can still do loop
// predication, by generating SCEVs for the range and latch that are of the
// same type. We achieve this by generating a SCEV truncate expression for the
@@ -272,7 +294,8 @@
Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType);
public:
- LoopPredication(ScalarEvolution *SE) : SE(SE){};
+ LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI)
+ : SE(SE), BPI(BPI){};
bool runOnLoop(Loop *L);
};
@@ -284,6 +307,7 @@
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<BranchProbabilityInfoWrapperPass>();
getLoopAnalysisUsage(AU);
}
@@ -291,7 +315,9 @@
if (skipLoop(L))
return false;
auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- LoopPredication LP(SE);
+ BranchProbabilityInfo &BPI =
+ getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
+ LoopPredication LP(SE, &BPI);
return LP.runOnLoop(L);
}
};
@@ -301,6 +327,7 @@
INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
"Loop predication", false, false)
+INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
"Loop predication", false, false)
@@ -312,7 +339,11 @@
PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &U) {
- LoopPredication LP(&AR.SE);
+ const auto &FAM =
+ AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
+ Function *F = L.getHeader()->getParent();
+ auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F);
+ LoopPredication LP(&AR.SE, BPI);
if (!LP.runOnLoop(&L))
return PreservedAnalyses::all();
@@ -690,6 +721,60 @@
Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
}
+bool LoopPredication::isLoopProfitableToPredicate() {
+ if (SkipProfitabilityChecks || !BPI)
+ return true;
+
+ SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8> ExitEdges;
+ L->getExitEdges(ExitEdges);
+ // If there is only one exiting edge in the loop, it is always profitable to
+ // predicate the loop.
+ if (ExitEdges.size() == 1)
+ return true;
+
+ // Calculate the exiting probabilities of all exiting edges from the loop,
+ // starting with the LatchExitProbability.
+ // Heuristic for profitability: If any of the exiting blocks' probability of
+ // exiting the loop is larger than exiting through the latch block, it's not
+ // profitable to predicate the loop.
+ auto *LatchBlock = L->getLoopLatch();
+ assert(LatchBlock && "Should have a single latch at this point!");
+ auto *LatchTerm = LatchBlock->getTerminator();
+ assert(LatchTerm->getNumSuccessors() == 2 &&
+ "expected to be an exiting block with 2 succs!");
+ unsigned LatchBrExitIdx =
+ LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
+ BranchProbability LatchExitProbability =
+ BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx);
+
+ // Protect against degenerate inputs provided by the user. Providing a value
+ // less than one, can invert the definition of profitable loop predication.
+ float ScaleFactor = LatchExitProbabilityScale;
+ if (ScaleFactor < 1) {
+ DEBUG(
+ dbgs()
+ << "Ignored user setting for loop-predication-latch-probability-scale: "
+ << LatchExitProbabilityScale << "\n");
+ DEBUG(dbgs() << "The value is set to 1.0\n");
+ ScaleFactor = 1.0;
+ }
+ const auto LatchProbabilityThreshold =
+ LatchExitProbability * ScaleFactor;
+
+ for (const auto &ExitEdge : ExitEdges) {
+ BranchProbability ExitingBlockProbability =
+ BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second);
+ // Some exiting edge has higher probability than the latch exiting edge.
+ // No longer profitable to predicate.
+ if (ExitingBlockProbability > LatchProbabilityThreshold)
+ return false;
+ }
+ // Using BPI, we have concluded that the most probable way to exit from the
+ // loop is through the latch (or there's no profile information and all
+ // exits are equally likely).
+ return true;
+}
+
bool LoopPredication::runOnLoop(Loop *Loop) {
L = Loop;
@@ -718,6 +803,10 @@
DEBUG(dbgs() << "Latch check:\n");
DEBUG(LatchCheck.dump());
+ if (!isLoopProfitableToPredicate()) {
+ DEBUG(dbgs()<< "Loop not profitable to predicate!\n");
+ return false;
+ }
// Collect all the guards into a vector and process later, so as not
// to invalidate the instruction iterator.
SmallVector<IntrinsicInst *, 4> Guards;
diff --git a/llvm/test/Transforms/LoopPredication/profitability.ll b/llvm/test/Transforms/LoopPredication/profitability.ll
new file mode 100644
index 0000000..ce01d3c
--- /dev/null
+++ b/llvm/test/Transforms/LoopPredication/profitability.ll
@@ -0,0 +1,120 @@
+; RUN: opt -S -loop-predication -loop-predication-skip-profitability-checks=false < %s 2>&1 | FileCheck %s
+; RUN: opt -S -loop-predication-skip-profitability-checks=false -passes='require<scalar-evolution>,require<branch-prob>,loop(loop-predication)' < %s 2>&1 | FileCheck %s
+
+; latch block exits to a speculation block. BPI already knows (without prof
+; data) that deopt is very rarely
+; taken. So we do not predicate this loop using that coarse latch check.
+; LatchExitProbability: 0x04000000 / 0x80000000 = 3.12%
+; ExitingBlockProbability: 0x7ffa572a / 0x80000000 = 99.98%
+define i64 @donot_predicate(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) {
+; CHECK-LABEL: donot_predicate(
+entry:
+ %length.ext = zext i32 %length to i64
+ %n.pre = load i64, i64* %n_addr, align 4
+ br label %Header
+
+; CHECK-LABEL: Header:
+; CHECK: %within.bounds = icmp ult i64 %j2, %length.ext
+; CHECK-NEXT: call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9)
+Header: ; preds = %entry, %Latch
+ %result.in3 = phi i64* [ %arg2, %entry ], [ %arg, %Latch ]
+ %j2 = phi i64 [ 0, %entry ], [ %j.next, %Latch ]
+ %within.bounds = icmp ult i64 %j2, %length.ext
+ call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
+ %innercmp = icmp eq i64 %j2, %n.pre
+ %j.next = add nuw nsw i64 %j2, 1
+ br i1 %innercmp, label %Latch, label %exit, !prof !0
+
+Latch: ; preds = %Header
+ %speculate_trip_count = icmp ult i64 %j.next, 1048576
+ br i1 %speculate_trip_count, label %Header, label %deopt
+
+deopt: ; preds = %Latch
+ %counted_speculation_failed = call i64 (...) @llvm.experimental.deoptimize.i64(i64 30) [ "deopt"(i32 0) ]
+ ret i64 %counted_speculation_failed
+
+exit: ; preds = %Header
+ %result.in3.lcssa = phi i64* [ %result.in3, %Header ]
+ %result.le = load i64, i64* %result.in3.lcssa, align 8
+ ret i64 %result.le
+}
+!0 = !{!"branch_weights", i32 18, i32 104200}
+
+; predicate loop since there's no profile information and BPI concluded all
+; exiting blocks have same probability of exiting from loop.
+define i64 @predicate(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) {
+; CHECK-LABEL: predicate(
+; CHECK-LABEL: entry:
+; CHECK: [[limit_check:[^ ]+]] = icmp ule i64 1048576, %length.ext
+; CHECK-NEXT: [[first_iteration_check:[^ ]+]] = icmp ult i64 0, %length.ext
+; CHECK-NEXT: [[wide_cond:[^ ]+]] = and i1 [[first_iteration_check]], [[limit_check]]
+entry:
+ %length.ext = zext i32 %length to i64
+ %n.pre = load i64, i64* %n_addr, align 4
+ br label %Header
+
+; CHECK-LABEL: Header:
+; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ]
+Header: ; preds = %entry, %Latch
+ %result.in3 = phi i64* [ %arg2, %entry ], [ %arg, %Latch ]
+ %j2 = phi i64 [ 0, %entry ], [ %j.next, %Latch ]
+ %within.bounds = icmp ult i64 %j2, %length.ext
+ call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
+ %innercmp = icmp eq i64 %j2, %n.pre
+ %j.next = add nuw nsw i64 %j2, 1
+ br i1 %innercmp, label %Latch, label %exit
+
+Latch: ; preds = %Header
+ %speculate_trip_count = icmp ult i64 %j.next, 1048576
+ br i1 %speculate_trip_count, label %Header, label %exitLatch
+
+exitLatch: ; preds = %Latch
+ ret i64 1
+
+exit: ; preds = %Header
+ %result.in3.lcssa = phi i64* [ %result.in3, %Header ]
+ %result.le = load i64, i64* %result.in3.lcssa, align 8
+ ret i64 %result.le
+}
+
+; Same as test above but with profiling data that the most probable exit from
+; the loop is the header exiting block (not the latch block). So do not predicate.
+; LatchExitProbability: 0x000020e1 / 0x80000000 = 0.00%
+; ExitingBlockProbability: 0x7ffcbb86 / 0x80000000 = 99.99%
+define i64 @donot_predicate_prof(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) {
+; CHECK-LABEL: donot_predicate_prof(
+; CHECK-LABEL: entry:
+entry:
+ %length.ext = zext i32 %length to i64
+ %n.pre = load i64, i64* %n_addr, align 4
+ br label %Header
+
+; CHECK-LABEL: Header:
+; CHECK: %within.bounds = icmp ult i64 %j2, %length.ext
+; CHECK-NEXT: call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9)
+Header: ; preds = %entry, %Latch
+ %result.in3 = phi i64* [ %arg2, %entry ], [ %arg, %Latch ]
+ %j2 = phi i64 [ 0, %entry ], [ %j.next, %Latch ]
+ %within.bounds = icmp ult i64 %j2, %length.ext
+ call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
+ %innercmp = icmp eq i64 %j2, %n.pre
+ %j.next = add nuw nsw i64 %j2, 1
+ br i1 %innercmp, label %Latch, label %exit, !prof !1
+
+Latch: ; preds = %Header
+ %speculate_trip_count = icmp ult i64 %j.next, 1048576
+ br i1 %speculate_trip_count, label %Header, label %exitLatch, !prof !2
+
+exitLatch: ; preds = %Latch
+ ret i64 1
+
+exit: ; preds = %Header
+ %result.in3.lcssa = phi i64* [ %result.in3, %Header ]
+ %result.le = load i64, i64* %result.in3.lcssa, align 8
+ ret i64 %result.le
+}
+declare i64 @llvm.experimental.deoptimize.i64(...)
+declare void @llvm.experimental.guard(i1, ...)
+
+!1 = !{!"branch_weights", i32 104, i32 1042861}
+!2 = !{!"branch_weights", i32 255129, i32 1}