Use uint64_t for branch weights instead of uint32_t

CallInst::updateProfWeight() creates branch_weights with i64 instead of i32.
To be more consistent everywhere and remove lots of casts from uint64_t
to uint32_t, use i64 for branch_weights.

Reviewed By: davidxl

Differential Revision: https://reviews.llvm.org/D88609
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index 267d415..3249468 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -38,6 +38,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cassert>
 #include <cstdint>
@@ -434,6 +435,28 @@
   return true;
 }
 
+// Returns the sum of Weights. If that sum overflows 64 bits, every weight is
+// first divided (in place) by UINT32_MAX and the sum of the scaled weights is
+// returned. FIXME: only scale by the minimum necessary to fit within 64 bits.
+static uint64_t ScaleWeights(MutableArrayRef<uint64_t> Weights) {
+  uint64_t Total = 0;
+  bool Overflowed = false;
+  for (uint64_t W : Weights) {
+    Total = SaturatingAdd(Total, W, &Overflowed);
+    if (Overflowed)
+      break;
+  }
+  if (Overflowed) {
+    uint64_t ScaledTotal = 0;
+    for (uint64_t &W : Weights) {
+      W /= UINT32_MAX;
+      ScaledTotal += W;
+    }
+    return ScaledTotal;
+  }
+  return Total;
+}
+
 // Propagate existing explicit probabilities from either profile data or
 // 'expect' intrinsic processing. Examine metadata against unreachable
 // heuristic. The probability of the edge coming to unreachable block is
@@ -458,10 +481,7 @@
     return false;
 
   // Build up the final weights that will be used in a temporary buffer.
-  // Compute the sum of all weights to later decide whether they need to
-  // be scaled to fit in 32 bits.
-  uint64_t WeightSum = 0;
-  SmallVector<uint32_t, 2> Weights;
+  SmallVector<uint64_t, 2> Weights;
   SmallVector<unsigned, 2> UnreachableIdxs;
   SmallVector<unsigned, 2> ReachableIdxs;
   Weights.reserve(TI->getNumSuccessors());
@@ -470,10 +490,10 @@
         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
     if (!Weight)
       return false;
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights.push_back(Weight->getZExtValue());
-    WeightSum += Weights.back();
+    // TODO: remove scaling by UINT32_MAX and use full uint64_t range.
+    uint64_t WeightVal = Weight->getZExtValue();
+    Weights.push_back(WeightVal);
+    // The total is computed after the loop by ScaleWeights().
     if (PostDominatedByUnreachable.count(TI->getSuccessor(I - 1)))
       UnreachableIdxs.push_back(I - 1);
     else
@@ -481,20 +501,7 @@
   }
   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
 
-  // If the sum of weights does not fit in 32 bits, scale every weight down
-  // accordingly.
-  uint64_t ScalingFactor =
-      (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
-
-  if (ScalingFactor > 1) {
-    WeightSum = 0;
-    for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
-      Weights[I] /= ScalingFactor;
-      WeightSum += Weights[I];
-    }
-  }
-  assert(WeightSum <= UINT32_MAX &&
-         "Expected weights to scale down to 32 bits");
+  uint64_t WeightSum = ScaleWeights(Weights);
 
   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
     for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
@@ -505,7 +512,8 @@
   // Set the probability.
   SmallVector<BranchProbability, 2> BP;
   for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
-    BP.push_back({ Weights[I], static_cast<uint32_t>(WeightSum) });
+    BP.push_back(
+        BranchProbability::getBranchProbability(Weights[I], WeightSum));
 
   // Examine the metadata against unreachable heuristic.
   // If the unreachable heuristic is more strong then we use it for this edge.
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index b8663cd..eb9100a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -4060,7 +4060,7 @@
          "num of prof branch_weights must accord with num of successors");
 
   bool AllZeroes =
-      all_of(Weights.getValue(), [](uint32_t W) { return W == 0; });
+      all_of(Weights.getValue(), [](uint64_t W) { return W == 0; });
 
   if (AllZeroes || Weights.getValue().size() < 2)
     return nullptr;
@@ -4078,10 +4078,10 @@
                      "not correspond to number of succesors");
   }
 
-  SmallVector<uint32_t, 8> Weights;
+  SmallVector<uint64_t, 8> Weights;
   for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
     ConstantInt *C = mdconst::extract<ConstantInt>(ProfileData->getOperand(CI));
-    uint32_t CW = C->getValue().getZExtValue();
+    uint64_t CW = C->getValue().getZExtValue();
     Weights.push_back(CW);
   }
   this->Weights = std::move(Weights);
@@ -4109,7 +4109,7 @@
 
   if (!Weights && W && *W) {
     Changed = true;
-    Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
+    Weights = SmallVector<uint64_t, 8>(SI.getNumSuccessors(), 0);
     Weights.getValue()[SI.getNumSuccessors() - 1] = *W;
   } else if (Weights) {
     Changed = true;
@@ -4142,7 +4142,7 @@
     return;
 
   if (!Weights && *W)
-    Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
+    Weights = SmallVector<uint64_t, 8>(SI.getNumSuccessors(), 0);
 
   if (Weights) {
     auto &OldW = Weights.getValue()[idx];
diff --git a/llvm/lib/IR/MDBuilder.cpp b/llvm/lib/IR/MDBuilder.cpp
index 1f3bed3..871c620 100644
--- a/llvm/lib/IR/MDBuilder.cpp
+++ b/llvm/lib/IR/MDBuilder.cpp
@@ -34,20 +34,20 @@
   return MDNode::get(Context, Op);
 }
 
-MDNode *MDBuilder::createBranchWeights(uint32_t TrueWeight,
-                                       uint32_t FalseWeight) {
+MDNode *MDBuilder::createBranchWeights(uint64_t TrueWeight,
+                                       uint64_t FalseWeight) {
   return createBranchWeights({TrueWeight, FalseWeight});
 }
 
-MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights) {
+MDNode *MDBuilder::createBranchWeights(ArrayRef<uint64_t> Weights) {
   assert(Weights.size() >= 1 && "Need at least one branch weights!");
 
   SmallVector<Metadata *, 4> Vals(Weights.size() + 1);
   Vals[0] = createString("branch_weights");
 
-  Type *Int32Ty = Type::getInt32Ty(Context);
+  Type *Int64Ty = Type::getInt64Ty(Context);
   for (unsigned i = 0, e = Weights.size(); i != e; ++i)
-    Vals[i + 1] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
+    Vals[i + 1] = createConstant(ConstantInt::get(Int64Ty, Weights[i]));
 
   return MDNode::get(Context, Vals);
 }
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index dbcf58f..fc95e0e 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1584,8 +1584,7 @@
                             SortedCallTargets.size());
         } else if (!isa<IntrinsicInst>(&I)) {
           I.setMetadata(LLVMContext::MD_prof,
-                        MDB.createBranchWeights(
-                            {static_cast<uint32_t>(BlockWeights[BB])}));
+                        MDB.createBranchWeights({BlockWeights[BB]}));
         }
       }
     }
@@ -1600,24 +1599,17 @@
                       << ((BranchLoc) ? Twine(BranchLoc.getLine())
                                       : Twine("<UNKNOWN LOCATION>"))
                       << ".\n");
-    SmallVector<uint32_t, 4> Weights;
-    uint32_t MaxWeight = 0;
+    SmallVector<uint64_t, 4> Weights;
+    uint64_t MaxWeight = 0;
     Instruction *MaxDestInst;
     for (unsigned I = 0; I < TI->getNumSuccessors(); ++I) {
       BasicBlock *Succ = TI->getSuccessor(I);
       Edge E = std::make_pair(BB, Succ);
       uint64_t Weight = EdgeWeights[E];
       LLVM_DEBUG(dbgs() << "\t"; printEdgeWeight(dbgs(), E));
-      // Use uint32_t saturated arithmetic to adjust the incoming weights,
-      // if needed. Sample counts in profiles are 64-bit unsigned values,
-      // but internally branch weights are expressed as 32-bit values.
-      if (Weight > std::numeric_limits<uint32_t>::max()) {
-        LLVM_DEBUG(dbgs() << " (saturated due to uint32_t overflow)");
-        Weight = std::numeric_limits<uint32_t>::max();
-      }
       // Weight is added by one to avoid propagation errors introduced by
       // 0 weights.
-      Weights.push_back(static_cast<uint32_t>(Weight + 1));
+      Weights.push_back(Weight + 1);
       if (Weight != 0) {
         if (Weight > MaxWeight) {
           MaxWeight = Weight;
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index a99c58b..1cb0171 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -1865,9 +1865,9 @@
         << " branches or selects";
   });
   MergedBR->setCondition(MergedCondition);
-  uint32_t Weights[] = {
-      static_cast<uint32_t>(CHRBranchBias.scale(1000)),
-      static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
+  uint64_t Weights[] = {
+      CHRBranchBias.scale(1000),
+      CHRBranchBias.getCompl().scale(1000),
   };
   MDBuilder MDB(F.getContext());
   MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index dd70c1f..b9530ae 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -116,6 +116,7 @@
 #include <cstdint>
 #include <memory>
 #include <numeric>
+#include <stdint.h>
 #include <string>
 #include <unordered_map>
 #include <utility>
@@ -1830,7 +1831,7 @@
   MDBuilder MDB(M->getContext());
   assert(MaxCount > 0 && "Bad max count");
   uint64_t Scale = calculateCountScale(MaxCount);
-  SmallVector<unsigned, 4> Weights;
+  SmallVector<uint64_t, 4> Weights;
   for (const auto &ECI : EdgeCounts)
     Weights.push_back(scaleBranchCount(ECI, Scale));
 
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 12deaaa..a2bb496 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -293,7 +293,7 @@
     if (BP >= BranchProbability(50, 100))
       continue;
 
-    SmallVector<uint32_t, 2> Weights;
+    SmallVector<uint64_t, 2> Weights;
     if (PredBr->getSuccessor(0) == PredOutEdge.second) {
       Weights.push_back(BP.getNumerator());
       Weights.push_back(BP.getCompl().getNumerator());
@@ -2541,7 +2541,7 @@
   // shouldn't make edges extremely likely or unlikely based solely on static
   // estimation.
   if (BBSuccProbs.size() >= 2 && doesBlockHaveProfileData(BB)) {
-    SmallVector<uint32_t, 4> Weights;
+    SmallVector<uint64_t, 4> Weights;
     for (auto Prob : BBSuccProbs)
       Weights.push_back(Prob.getNumerator());
 
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index 33f73f6..b61ada1 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -47,14 +47,14 @@
 // 'select' instructions. It may be worthwhile to hoist these values to some
 // shared space, so they can be used directly by other passes.
 
-cl::opt<uint32_t> llvm::LikelyBranchWeight(
+cl::opt<uint64_t> llvm::LikelyBranchWeight(
     "likely-branch-weight", cl::Hidden, cl::init(2000),
     cl::desc("Weight of the branch likely to be taken (default = 2000)"));
-cl::opt<uint32_t> llvm::UnlikelyBranchWeight(
+cl::opt<uint64_t> llvm::UnlikelyBranchWeight(
     "unlikely-branch-weight", cl::Hidden, cl::init(1),
     cl::desc("Weight of the branch unlikely to be taken (default = 1)"));
 
-static std::tuple<uint32_t, uint32_t>
+static std::tuple<uint64_t, uint64_t>
 getBranchWeight(Intrinsic::ID IntrinsicID, CallInst *CI, int BranchCount) {
   if (IntrinsicID == Intrinsic::expect) {
     // __builtin_expect
@@ -69,8 +69,8 @@
     assert((TrueProb >= 0.0 && TrueProb <= 1.0) &&
            "probability value must be in the range [0.0, 1.0]");
     double FalseProb = (1.0 - TrueProb) / (BranchCount - 1);
-    uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0);
-    uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0);
+    uint64_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0);
+    uint64_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0);
     return std::make_tuple(LikelyBW, UnlikelyBW);
   }
 }
@@ -92,11 +92,11 @@
 
   SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);
   unsigned n = SI.getNumCases(); // +1 for default case.
-  uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
+  uint64_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
   std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
       getBranchWeight(Fn->getIntrinsicID(), CI, n + 1);
 
-  SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal);
+  SmallVector<uint64_t, 16> Weights(n + 1, UnlikelyBranchWeightVal);
 
   uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1;
   Weights[Index] = LikelyBranchWeightVal;
@@ -248,7 +248,7 @@
         return true;
       return false;
     };
-    uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
+    uint64_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
     std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight(
         Expect->getCalledFunction()->getIntrinsicID(), Expect, 2);
 
@@ -318,7 +318,7 @@
   MDNode *Node;
   MDNode *ExpNode;
 
-  uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
+  uint64_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
   std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
       getBranchWeight(Fn->getIntrinsicID(), CI, 2);
 
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 34a5499..7d40c87 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -67,6 +67,7 @@
 #include <iterator>
 #include <map>
 #include <set>
+#include <stdint.h>
 #include <utility>
 #include <vector>
 
@@ -1363,7 +1364,7 @@
 
   // Update the branch weights for the exit block.
   Instruction *TI = CodeReplacer->getTerminator();
-  SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
+  SmallVector<uint64_t, 8> BranchWeights(TI->getNumSuccessors(), 0);
 
   // Block Frequency distribution with dummy node.
   Distribution BranchDist;
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index afbba4c..c570035 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -207,7 +207,7 @@
         // left, unless the metadata doesn't match the switch.
         if (NCases > 1 && MD && MD->getNumOperands() == 2 + NCases) {
           // Collect branch weights into a vector.
-          SmallVector<uint32_t, 8> Weights;
+          SmallVector<uint64_t, 8> Weights;
           for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e;
                ++MD_i) {
             auto *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i));
@@ -2091,11 +2091,8 @@
   // If the invoke had profile metadata, try converting them for CallInst.
   uint64_t TotalWeight;
   if (NewCall->extractProfTotalWeight(TotalWeight)) {
-    // Set the total weight if it fits into i32, otherwise reset.
     MDBuilder MDB(NewCall->getContext());
-    auto NewWeights = uint32_t(TotalWeight) != TotalWeight
-                          ? nullptr
-                          : MDB.createBranchWeights({uint32_t(TotalWeight)});
+    auto NewWeights = MDB.createBranchWeights({TotalWeight});
     NewCall->setMetadata(LLVMContext::MD_prof, NewWeights);
   }
 
diff --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp
index a16ca1f..5b68f1e 100644
--- a/llvm/lib/Transforms/Utils/MisExpect.cpp
+++ b/llvm/lib/Transforms/Utils/MisExpect.cpp
@@ -93,7 +93,7 @@
 namespace llvm {
 namespace misexpect {
 
-void verifyMisExpect(Instruction *I, const SmallVector<uint32_t, 4> &Weights,
+void verifyMisExpect(Instruction *I, const SmallVector<uint64_t, 4> &Weights,
                      LLVMContext &Ctx) {
   if (auto *MisExpectData = I->getMetadata(LLVMContext::MD_misexpect)) {
     auto *MisExpectDataName = dyn_cast<MDString>(MisExpectData->getOperand(0));
@@ -161,7 +161,7 @@
     // Operand 0 is a string tag "branch_weights"
     if (MDString *Tag = cast<MDString>(MD->getOperand(0))) {
       if (Tag->getString().equals("branch_weights")) {
-        SmallVector<uint32_t, 4> RealWeights(NOps - 1);
+        SmallVector<uint64_t, 4> RealWeights(NOps - 1);
         for (unsigned i = 1; i < NOps; i++) {
           ConstantInt *Value =
               mdconst::dyn_extract<ConstantInt>(MD->getOperand(i));
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index a48a335..4a74277 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -238,7 +238,7 @@
                               const TargetTransformInfo &TTI);
   bool SimplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond,
                                   BasicBlock *TrueBB, BasicBlock *FalseBB,
-                                  uint32_t TrueWeight, uint32_t FalseWeight);
+                                  uint64_t TrueWeight, uint64_t FalseWeight);
   bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder,
                                  const DataLayout &DL);
   bool SimplifySwitchOnSelect(SwitchInst *SI, SelectInst *Select);
@@ -825,19 +825,19 @@
 
 // Set branch weights on SwitchInst. This sets the metadata if there is at
 // least one non-zero weight.
-static void setBranchWeights(SwitchInst *SI, ArrayRef<uint32_t> Weights) {
+static void setBranchWeights(SwitchInst *SI, ArrayRef<uint64_t> Weights) {
   // Check that there is at least one non-zero weight. Otherwise, pass
   // nullptr to setMetadata which will erase the existing metadata.
   MDNode *N = nullptr;
-  if (llvm::any_of(Weights, [](uint32_t W) { return W != 0; }))
+  if (llvm::any_of(Weights, [](uint64_t W) { return W != 0; }))
     N = MDBuilder(SI->getParent()->getContext()).createBranchWeights(Weights);
   SI->setMetadata(LLVMContext::MD_prof, N);
 }
 
 // Similar to the above, but for branch and select instructions that take
 // exactly 2 weights.
-static void setBranchWeights(Instruction *I, uint32_t TrueWeight,
-                             uint32_t FalseWeight) {
+static void setBranchWeights(Instruction *I, uint64_t TrueWeight,
+                             uint64_t FalseWeight) {
   assert(isa<BranchInst>(I) || isa<SelectInst>(I));
   // Check that there is at least one non-zero weight. Otherwise, pass
   // nullptr to setMetadata which will erase the existing metadata.
@@ -1025,16 +1025,6 @@
   }
 }
 
-/// Keep halving the weights until all can fit in uint32_t.
-static void FitWeights(MutableArrayRef<uint64_t> Weights) {
-  uint64_t Max = *std::max_element(Weights.begin(), Weights.end());
-  if (Max > UINT_MAX) {
-    unsigned Offset = 32 - countLeadingZeros(Max);
-    for (uint64_t &I : Weights)
-      I >>= Offset;
-  }
-}
-
 /// The specified terminator is a value equality comparison instruction
 /// (either a switch or a branch on "X == c").
 /// See if any of the predecessors of the terminator block are value comparisons
@@ -1220,10 +1210,7 @@
         NewSI->addCase(V.Value, V.Dest);
 
       if (PredHasWeights || SuccHasWeights) {
-        // Halve the weights if any of them cannot fit in an uint32_t
-        FitWeights(Weights);
-
-        SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
+        SmallVector<uint64_t, 8> MDWeights(Weights.begin(), Weights.end());
 
         setBranchWeights(NewSI, MDWeights);
       }
@@ -2954,10 +2941,7 @@
         PBI->setSuccessor(1, FalseDest);
       }
       if (NewWeights.size() == 2) {
-        // Halve the weights if any of them cannot fit in an uint32_t
-        FitWeights(NewWeights);
-
-        SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(),
+        SmallVector<uint64_t, 8> MDWeights(NewWeights.begin(),
                                            NewWeights.end());
         setBranchWeights(PBI, MDWeights[0], MDWeights[1]);
       } else
@@ -3585,8 +3569,6 @@
     uint64_t NewWeights[2] = {PredCommon * (SuccCommon + SuccOther) +
                                   PredOther * SuccCommon,
                               PredOther * SuccOther};
-    // Halve the weights if any of them cannot fit in an uint32_t
-    FitWeights(NewWeights);
 
     setBranchWeights(PBI, NewWeights[0], NewWeights[1]);
   }
@@ -3622,8 +3604,6 @@
         uint64_t NewWeights[2] = {PredCommon * (SuccCommon + SuccOther),
                                   PredOther * SuccCommon};
 
-        FitWeights(NewWeights);
-
         setBranchWeights(NV, NewWeights[0], NewWeights[1]);
       }
     }
@@ -3645,8 +3625,8 @@
 bool SimplifyCFGOpt::SimplifyTerminatorOnSelect(Instruction *OldTerm,
                                                 Value *Cond, BasicBlock *TrueBB,
                                                 BasicBlock *FalseBB,
-                                                uint32_t TrueWeight,
-                                                uint32_t FalseWeight) {
+                                                uint64_t TrueWeight,
+                                                uint64_t FalseWeight) {
   // Remove any superfluous successor edges from the CFG.
   // First, figure out which successors to preserve.
   // If TrueBB and FalseBB are equal, only try to preserve one copy of that
@@ -3720,16 +3700,16 @@
   BasicBlock *FalseBB = SI->findCaseValue(FalseVal)->getCaseSuccessor();
 
   // Get weight for TrueBB and FalseBB.
-  uint32_t TrueWeight = 0, FalseWeight = 0;
+  uint64_t TrueWeight = 0, FalseWeight = 0;
   SmallVector<uint64_t, 8> Weights;
   bool HasWeights = HasBranchWeights(SI);
   if (HasWeights) {
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {
       TrueWeight =
-          (uint32_t)Weights[SI->findCaseValue(TrueVal)->getSuccessorIndex()];
+          (uint64_t)Weights[SI->findCaseValue(TrueVal)->getSuccessorIndex()];
       FalseWeight =
-          (uint32_t)Weights[SI->findCaseValue(FalseVal)->getSuccessorIndex()];
+          (uint64_t)Weights[SI->findCaseValue(FalseVal)->getSuccessorIndex()];
     }
   }