Add wrapper functions to set branch weights metadata.

Summary:
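These wrappers (one for SwitchInst, one for two-weight branch/select
instructions) attach the !prof metadata only if at least one weight is
non-zero; if all weights are zero, any existing !prof metadata is erased
instead. Use them at the SimplifyCFG sites that previously built branch
weights directly with MDBuilder.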

Reviewers: davidxl

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D39872

llvm-svn: 317845
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index e0045e9..e26f382 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -776,6 +776,31 @@
   return false;
 }
 
+// Attach branch weights (!prof metadata) to SI, but only if at least one
+// weight is non-zero; otherwise any existing !prof metadata is removed.
+static void setBranchWeights(SwitchInst *SI, ArrayRef<uint32_t> Weights) {
+  // Build the node only if some weight is non-zero. Passing nullptr to
+  // setMetadata erases the existing metadata instead.
+  MDNode *N = nullptr;
+  if (llvm::any_of(Weights, [](uint32_t W) { return W != 0; }))
+    N = MDBuilder(SI->getContext()).createBranchWeights(Weights);
+  SI->setMetadata(LLVMContext::MD_prof, N);
+}
+
+// Two-weight variant of the above for branch and select instructions; the
+// same rule applies (all-zero weights remove the existing !prof metadata).
+static void setBranchWeights(Instruction *I, uint32_t TrueWeight,
+                             uint32_t FalseWeight) {
+  assert((isa<BranchInst>(I) || isa<SelectInst>(I)) &&
+         "setBranchWeights expects a branch or select instruction");
+  // Build the node only if some weight is non-zero; nullptr erases it.
+  MDNode *N = nullptr;
+  if (TrueWeight || FalseWeight)
+    N = MDBuilder(I->getContext())
+            .createBranchWeights(TrueWeight, FalseWeight);
+  I->setMetadata(LLVMContext::MD_prof, N);
+}
+
 /// If TI is known to be a terminator instruction and its block is known to
 /// only have a single predecessor block, check to see if that predecessor is
 /// also a value comparison with the same value, and if that comparison
@@ -865,9 +890,7 @@
       }
     }
     if (HasWeight && Weights.size() >= 2)
-      SI->setMetadata(LLVMContext::MD_prof,
-                      MDBuilder(SI->getParent()->getContext())
-                          .createBranchWeights(Weights));
+      setBranchWeights(SI, Weights);
 
     DEBUG(dbgs() << "Leaving: " << *TI << "\n");
     return true;
@@ -1172,9 +1195,7 @@
 
         SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
 
-        NewSI->setMetadata(
-            LLVMContext::MD_prof,
-            MDBuilder(BB->getContext()).createBranchWeights(MDWeights));
+        setBranchWeights(NewSI, MDWeights);
       }
 
       EraseTerminatorInstAndDCECond(PTI);
@@ -2738,9 +2759,7 @@
 
         SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(),
                                            NewWeights.end());
-        PBI->setMetadata(
-            LLVMContext::MD_prof,
-            MDBuilder(BI->getContext()).createBranchWeights(MDWeights));
+        setBranchWeights(PBI, MDWeights[0], MDWeights[1]);
       } else
         PBI->setMetadata(LLVMContext::MD_prof, nullptr);
     } else {
@@ -3309,9 +3328,7 @@
     // Halve the weights if any of them cannot fit in an uint32_t
     FitWeights(NewWeights);
 
-    PBI->setMetadata(LLVMContext::MD_prof,
-                     MDBuilder(BI->getContext())
-                         .createBranchWeights(NewWeights[0], NewWeights[1]));
+    setBranchWeights(PBI, NewWeights[0], NewWeights[1]);
   }
 
   // OtherDest may have phi nodes.  If so, add an entry from PBI's
@@ -3349,9 +3366,7 @@
 
         FitWeights(NewWeights);
 
-        NV->setMetadata(LLVMContext::MD_prof,
-                        MDBuilder(BI->getContext())
-                            .createBranchWeights(NewWeights[0], NewWeights[1]));
+        setBranchWeights(NV, NewWeights[0], NewWeights[1]);
       }
     }
   }
@@ -3406,9 +3421,7 @@
       // Create a conditional branch sharing the condition of the select.
       BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB);
       if (TrueWeight != FalseWeight)
-        NewBI->setMetadata(LLVMContext::MD_prof,
-                           MDBuilder(OldTerm->getContext())
-                               .createBranchWeights(TrueWeight, FalseWeight));
+        setBranchWeights(NewBI, TrueWeight, FalseWeight);
     }
   } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) {
     // Neither of the selected blocks were successors, so this
@@ -3594,9 +3607,7 @@
       Weights.push_back(Weights[0]);
 
       SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
-      SI->setMetadata(
-          LLVMContext::MD_prof,
-          MDBuilder(SI->getContext()).createBranchWeights(MDWeights));
+      setBranchWeights(SI, MDWeights);
     }
   }
   SI->addCase(Cst, NewBB);
@@ -4323,10 +4334,7 @@
         TrueWeight /= 2;
         FalseWeight /= 2;
       }
-      NewBI->setMetadata(LLVMContext::MD_prof,
-                         MDBuilder(SI->getContext())
-                             .createBranchWeights((uint32_t)TrueWeight,
-                                                  (uint32_t)FalseWeight));
+      setBranchWeights(NewBI, TrueWeight, FalseWeight);
     }
   }
 
@@ -4423,9 +4431,7 @@
   }
   if (HasWeight && Weights.size() >= 2) {
     SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
-    SI->setMetadata(LLVMContext::MD_prof,
-                    MDBuilder(SI->getParent()->getContext())
-                        .createBranchWeights(MDWeights));
+    setBranchWeights(SI, MDWeights);
   }
 
   return !DeadCases.empty();