Use uint64_t for branch weights instead of uint32_t
CallInst::updateProfWeight() creates branch_weights with i64 instead of i32.
For greater consistency, and to remove many casts from uint64_t
to uint32_t, use i64 for branch_weights.
Reviewed By: davidxl
Differential Revision: https://reviews.llvm.org/D88609
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index 267d415..3249468 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -38,6 +38,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cstdint>
@@ -434,6 +435,28 @@
return true;
}
+// If the sum of all values in Weights overflows uint64_t, divides every weight
+// by UINT32_MAX so the total fits in 64 bits. Returns the (possibly scaled) total.
+// FIXME: only scale by the minimum necessary to fit the total within 64 bits.
+static uint64_t ScaleWeights(MutableArrayRef<uint64_t> Weights) {
+ uint64_t Total = 0;
+ bool Overflowed = false;
+ for (uint64_t W : Weights) {
+ Total = SaturatingAdd(Total, W, &Overflowed);
+ if (Overflowed)
+ break;
+ }
+ if (Overflowed) {
+ uint64_t ScaledTotal = 0;
+ for (uint64_t &W : Weights) {
+ W /= UINT32_MAX;
+ ScaledTotal += W;
+ }
+ return ScaledTotal;
+ }
+ return Total;
+}
+
// Propagate existing explicit probabilities from either profile data or
// 'expect' intrinsic processing. Examine metadata against unreachable
// heuristic. The probability of the edge coming to unreachable block is
@@ -458,10 +481,7 @@
return false;
// Build up the final weights that will be used in a temporary buffer.
- // Compute the sum of all weights to later decide whether they need to
- // be scaled to fit in 32 bits.
- uint64_t WeightSum = 0;
- SmallVector<uint32_t, 2> Weights;
+ SmallVector<uint64_t, 2> Weights;
SmallVector<unsigned, 2> UnreachableIdxs;
SmallVector<unsigned, 2> ReachableIdxs;
Weights.reserve(TI->getNumSuccessors());
@@ -470,10 +490,10 @@
mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
if (!Weight)
return false;
- assert(Weight->getValue().getActiveBits() <= 32 &&
- "Too many bits for uint32_t");
- Weights.push_back(Weight->getZExtValue());
- WeightSum += Weights.back();
+ // TODO: remove scaling by UINT32_MAX and use full uint64_t range.
+ uint64_t WeightVal = Weight->getZExtValue();
+ Weights.push_back(WeightVal);
+ // Note: the weight sum is computed later by ScaleWeights, which also
if (PostDominatedByUnreachable.count(TI->getSuccessor(I - 1)))
UnreachableIdxs.push_back(I - 1);
else
@@ -481,20 +501,7 @@
}
assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
- // If the sum of weights does not fit in 32 bits, scale every weight down
- // accordingly.
- uint64_t ScalingFactor =
- (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
-
- if (ScalingFactor > 1) {
- WeightSum = 0;
- for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
- Weights[I] /= ScalingFactor;
- WeightSum += Weights[I];
- }
- }
- assert(WeightSum <= UINT32_MAX &&
- "Expected weights to scale down to 32 bits");
+ uint64_t WeightSum = ScaleWeights(Weights);
if (WeightSum == 0 || ReachableIdxs.size() == 0) {
for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
@@ -505,7 +512,8 @@
// Set the probability.
SmallVector<BranchProbability, 2> BP;
for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
- BP.push_back({ Weights[I], static_cast<uint32_t>(WeightSum) });
+ BP.push_back(
+ BranchProbability::getBranchProbability(Weights[I], WeightSum));
// Examine the metadata against unreachable heuristic.
// If the unreachable heuristic is more strong then we use it for this edge.