Switch lowering: use profile info to build weight-balanced binary search trees

Splitting the clusters by branch weight instead of by cluster count causes hot
nodes to appear closer to the root.

According to the literature (e.g. Kurt Mehlhorn, "Nearly Optimal Binary Search
Trees", 1975), a tree built this way is a near-optimal binary search tree in
terms of search time given the key frequencies. In LLVM's case, we can do up to
3 comparisons in each leaf node, so it might be better to opt for lower tree
height in some cases; that's something to look into in the future.

Differential Revision: http://reviews.llvm.org/D9318

llvm-svn: 236192
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index db41c40..1c14d4d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7986,14 +7986,41 @@
   unsigned NumClusters = W.LastCluster - W.FirstCluster + 1;
   assert(NumClusters >= 2 && "Too small to split!");
 
-  // FIXME: When we have profile info, we might want to balance the tree based
-  // on weights instead of node count.
+  // Balance the tree based on branch weights to create a near-optimal (in terms
+  // of search time given key frequency) binary search tree. See e.g. Kurt
+  // Mehlhorn "Nearly Optimal Binary Search Trees" (1975).
+  CaseClusterIt LastLeft = W.FirstCluster;
+  CaseClusterIt FirstRight = W.LastCluster;
+  uint32_t LeftWeight = LastLeft->Weight;
+  uint32_t RightWeight = FirstRight->Weight;
 
-  CaseClusterIt PivotCluster = W.FirstCluster + NumClusters / 2;
+  // Move LastLeft and FirstRight towards each other from opposite directions to
+  // find a partitioning of the clusters which balances the weight on both
+  // sides.
+  while (LastLeft + 1 < FirstRight) {
+    // Zero-weight nodes would cause skewed trees since they don't affect
+    // LeftWeight or RightWeight.
+    assert(LastLeft->Weight != 0);
+    assert(FirstRight->Weight != 0);
+
+    if (LeftWeight < RightWeight)
+      LeftWeight += (++LastLeft)->Weight;
+    else
+      RightWeight += (--FirstRight)->Weight;
+  }
+  assert(LastLeft + 1 == FirstRight);
+  assert(LastLeft >= W.FirstCluster);
+  assert(FirstRight <= W.LastCluster);
+
+  // Use the first element on the right as pivot since we will make less-than
+  // comparisons against it.
+  CaseClusterIt PivotCluster = FirstRight;
+  assert(PivotCluster > W.FirstCluster);
+  assert(PivotCluster <= W.LastCluster);
+
   CaseClusterIt FirstLeft = W.FirstCluster;
-  CaseClusterIt LastLeft = PivotCluster - 1;
-  CaseClusterIt FirstRight = PivotCluster;
   CaseClusterIt LastRight = W.LastCluster;
+
   const ConstantInt *Pivot = PivotCluster->Low;
 
   // New blocks will be inserted immediately after the current one.
@@ -8032,7 +8059,8 @@
   }
 
   // Create the CaseBlock record that will be used to lower the branch.
-  CaseBlock CB(ISD::SETLT, Cond, Pivot, nullptr, LeftMBB, RightMBB, W.MBB);
+  CaseBlock CB(ISD::SETLT, Cond, Pivot, nullptr, LeftMBB, RightMBB, W.MBB,
+               LeftWeight, RightWeight);
 
   if (W.MBB == SwitchMBB)
     visitSwitchCase(CB, SwitchMBB);
@@ -8048,7 +8076,7 @@
   for (auto I : SI.cases()) {
     MachineBasicBlock *Succ = FuncInfo.MBBMap[I.getCaseSuccessor()];
     const ConstantInt *CaseVal = I.getCaseValue();
-    uint32_t Weight = 0; // FIXME: Use 1 instead?
+    uint32_t Weight = 1;
     if (BPI) {
       Weight = BPI->getEdgeWeight(SI.getParent(), I.getSuccessorIndex());
       assert(Weight <= UINT32_MAX / SI.getNumSuccessors());