Use uint64_t for branch weights instead of uint32_t
CallInst::updateProfWeight() already creates branch_weights metadata with i64
operands rather than i32. Use i64 for branch_weights everywhere, both for
consistency and to remove the many casts from uint64_t to uint32_t.
Reviewed By: davidxl
Differential Revision: https://reviews.llvm.org/D88609
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index b8663cd..eb9100a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -4060,7 +4060,7 @@
"num of prof branch_weights must accord with num of successors");
bool AllZeroes =
- all_of(Weights.getValue(), [](uint32_t W) { return W == 0; });
+ all_of(Weights.getValue(), [](uint64_t W) { return W == 0; });
if (AllZeroes || Weights.getValue().size() < 2)
return nullptr;
@@ -4078,10 +4078,10 @@
"not correspond to number of succesors");
}
- SmallVector<uint32_t, 8> Weights;
+ SmallVector<uint64_t, 8> Weights;
for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
ConstantInt *C = mdconst::extract<ConstantInt>(ProfileData->getOperand(CI));
- uint32_t CW = C->getValue().getZExtValue();
+ uint64_t CW = C->getValue().getZExtValue();
Weights.push_back(CW);
}
this->Weights = std::move(Weights);
@@ -4109,7 +4109,7 @@
if (!Weights && W && *W) {
Changed = true;
- Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
+ Weights = SmallVector<uint64_t, 8>(SI.getNumSuccessors(), 0);
Weights.getValue()[SI.getNumSuccessors() - 1] = *W;
} else if (Weights) {
Changed = true;
@@ -4142,7 +4142,7 @@
return;
if (!Weights && *W)
- Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
+ Weights = SmallVector<uint64_t, 8>(SI.getNumSuccessors(), 0);
if (Weights) {
auto &OldW = Weights.getValue()[idx];
diff --git a/llvm/lib/IR/MDBuilder.cpp b/llvm/lib/IR/MDBuilder.cpp
index 1f3bed3..871c620 100644
--- a/llvm/lib/IR/MDBuilder.cpp
+++ b/llvm/lib/IR/MDBuilder.cpp
@@ -34,20 +34,20 @@
return MDNode::get(Context, Op);
}
-MDNode *MDBuilder::createBranchWeights(uint32_t TrueWeight,
- uint32_t FalseWeight) {
+MDNode *MDBuilder::createBranchWeights(uint64_t TrueWeight,
+ uint64_t FalseWeight) {
return createBranchWeights({TrueWeight, FalseWeight});
}
-MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights) {
+MDNode *MDBuilder::createBranchWeights(ArrayRef<uint64_t> Weights) {
assert(Weights.size() >= 1 && "Need at least one branch weights!");
SmallVector<Metadata *, 4> Vals(Weights.size() + 1);
Vals[0] = createString("branch_weights");
- Type *Int32Ty = Type::getInt32Ty(Context);
+ Type *Int64Ty = Type::getInt64Ty(Context);
for (unsigned i = 0, e = Weights.size(); i != e; ++i)
- Vals[i + 1] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
+ Vals[i + 1] = createConstant(ConstantInt::get(Int64Ty, Weights[i]));
return MDNode::get(Context, Vals);
}