Use uint64_t for branch weights instead of uint32_t

CallInst::updateProfWeight() creates branch_weights with i64 instead of i32.
To be more consistent everywhere and remove lots of casts from uint64_t
to uint32_t, use i64 for branch_weights.

Reviewed By: davidxl

Differential Revision: https://reviews.llvm.org/D88609
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index 267d415..3249468 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -38,6 +38,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cassert>
 #include <cstdint>
@@ -434,6 +435,28 @@
   return true;
 }
 
+// Sums Weights and returns the total. If SaturatingAdd reports overflow while
+// summing, every weight is divided by UINT32_MAX and the new total returned.
+// FIXME: only scale by the minimum necessary to fit the total within 64 bits.
+static uint64_t ScaleWeights(MutableArrayRef<uint64_t> Weights) {
+  uint64_t Sum = 0;
+  bool SumOverflowed = false;
+  for (uint64_t Weight : Weights) {
+    Sum = SaturatingAdd(Sum, Weight, &SumOverflowed);
+    if (SumOverflowed)
+      break;
+  }
+  // No overflow: the weights are left untouched and their sum is the answer.
+  if (!SumOverflowed)
+    return Sum;
+  uint64_t ScaledSum = 0;
+  for (uint64_t &Weight : Weights) {
+    Weight /= UINT32_MAX;
+    ScaledSum += Weight;
+  }
+  return ScaledSum;
+}
+
 // Propagate existing explicit probabilities from either profile data or
 // 'expect' intrinsic processing. Examine metadata against unreachable
 // heuristic. The probability of the edge coming to unreachable block is
@@ -458,10 +481,7 @@
     return false;
 
   // Build up the final weights that will be used in a temporary buffer.
-  // Compute the sum of all weights to later decide whether they need to
-  // be scaled to fit in 32 bits.
-  uint64_t WeightSum = 0;
-  SmallVector<uint32_t, 2> Weights;
+  SmallVector<uint64_t, 2> Weights;
   SmallVector<unsigned, 2> UnreachableIdxs;
   SmallVector<unsigned, 2> ReachableIdxs;
   Weights.reserve(TI->getNumSuccessors());
@@ -470,10 +490,10 @@
         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
     if (!Weight)
       return false;
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights.push_back(Weight->getZExtValue());
-    WeightSum += Weights.back();
+    // TODO: remove scaling by UINT32_MAX and use full uint64_t range.
+    uint64_t WeightVal = Weight->getZExtValue();
+    Weights.push_back(WeightVal);
+    // The sum is computed after the loop by ScaleWeights().
     if (PostDominatedByUnreachable.count(TI->getSuccessor(I - 1)))
       UnreachableIdxs.push_back(I - 1);
     else
@@ -481,20 +501,7 @@
   }
   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
 
-  // If the sum of weights does not fit in 32 bits, scale every weight down
-  // accordingly.
-  uint64_t ScalingFactor =
-      (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
-
-  if (ScalingFactor > 1) {
-    WeightSum = 0;
-    for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
-      Weights[I] /= ScalingFactor;
-      WeightSum += Weights[I];
-    }
-  }
-  assert(WeightSum <= UINT32_MAX &&
-         "Expected weights to scale down to 32 bits");
+  uint64_t WeightSum = ScaleWeights(Weights);
 
   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
     for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
@@ -505,7 +512,8 @@
   // Set the probability.
   SmallVector<BranchProbability, 2> BP;
   for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
-    BP.push_back({ Weights[I], static_cast<uint32_t>(WeightSum) });
+    BP.push_back(
+        BranchProbability::getBranchProbability(Weights[I], WeightSum));
 
   // Examine the metadata against unreachable heuristic.
   // If the unreachable heuristic is more strong then we use it for this edge.