Avoid zero-extending bit test operands to pointer type if all the masks fit in
the original type of the switch statement key.
rdar://8781238


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@122935 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 8162422..88490b3 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1615,12 +1615,28 @@
                                   Sub, DAG.getConstant(B.Range, VT),
                                   ISD::SETUGT);
 
-  SDValue ShiftOp = DAG.getZExtOrTrunc(Sub, getCurDebugLoc(),
-                                       TLI.getPointerTy());
+  // Determine the type of the test operands.
+  bool UsePtrType = false;
+  if (!TLI.isTypeLegal(VT))
+    UsePtrType = true;
+  else {
+    for (unsigned i = 0, e = B.Cases.size(); i != e; ++i)
+      if ((uint64_t)((int64_t)B.Cases[i].Mask >> VT.getSizeInBits()) + 1 >= 2) {
+        // Switch table case ranges are encoded into a series of masks.
+        // Just use the pointer type; it's guaranteed to fit.
+        UsePtrType = true;
+        break;
+      }
+  }
+  if (UsePtrType) {
+    VT = TLI.getPointerTy();
+    Sub = DAG.getZExtOrTrunc(Sub, getCurDebugLoc(), VT);
+  }
 
-  B.Reg = FuncInfo.CreateReg(TLI.getPointerTy());
+  B.RegVT = VT;
+  B.Reg = FuncInfo.CreateReg(VT);
   SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), getCurDebugLoc(),
-                                    B.Reg, ShiftOp);
+                                    B.Reg, Sub);
 
   // Set NextBlock to be the MBB immediately after the current one, if any.
   // This is used to avoid emitting unnecessary branches to the next block.
@@ -1646,36 +1662,34 @@
 }
 
 /// visitBitTestCase - this function produces one "bit test"
-void SelectionDAGBuilder::visitBitTestCase(MachineBasicBlock* NextMBB,
+void SelectionDAGBuilder::visitBitTestCase(BitTestBlock &BB,
+                                           MachineBasicBlock* NextMBB,
                                            unsigned Reg,
                                            BitTestCase &B,
                                            MachineBasicBlock *SwitchBB) {
-  SDValue ShiftOp = DAG.getCopyFromReg(getControlRoot(), getCurDebugLoc(), Reg,
-                                       TLI.getPointerTy());
+  EVT VT = BB.RegVT;
+  SDValue ShiftOp = DAG.getCopyFromReg(getControlRoot(), getCurDebugLoc(),
+                                       Reg, VT);
   SDValue Cmp;
   if (CountPopulation_64(B.Mask) == 1) {
     // Testing for a single bit; just compare the shift count with what it
     // would need to be to shift a 1 bit in that position.
     Cmp = DAG.getSetCC(getCurDebugLoc(),
-                       TLI.getSetCCResultType(ShiftOp.getValueType()),
+                       TLI.getSetCCResultType(VT),
                        ShiftOp,
-                       DAG.getConstant(CountTrailingZeros_64(B.Mask),
-                                       TLI.getPointerTy()),
+                       DAG.getConstant(CountTrailingZeros_64(B.Mask), VT),
                        ISD::SETEQ);
   } else {
     // Make desired shift
-    SDValue SwitchVal = DAG.getNode(ISD::SHL, getCurDebugLoc(),
-                                    TLI.getPointerTy(),
-                                    DAG.getConstant(1, TLI.getPointerTy()),
-                                    ShiftOp);
+    SDValue SwitchVal = DAG.getNode(ISD::SHL, getCurDebugLoc(), VT,
+                                    DAG.getConstant(1, VT), ShiftOp);
 
     // Emit bit tests and jumps
     SDValue AndOp = DAG.getNode(ISD::AND, getCurDebugLoc(),
-                                TLI.getPointerTy(), SwitchVal,
-                                DAG.getConstant(B.Mask, TLI.getPointerTy()));
+                                VT, SwitchVal, DAG.getConstant(B.Mask, VT));
     Cmp = DAG.getSetCC(getCurDebugLoc(),
-                       TLI.getSetCCResultType(AndOp.getValueType()),
-                       AndOp, DAG.getConstant(0, TLI.getPointerTy()),
+                       TLI.getSetCCResultType(VT),
+                       AndOp, DAG.getConstant(0, VT),
                        ISD::SETNE);
   }
 
@@ -2219,7 +2233,7 @@
   }
 
   BitTestBlock BTB(lowBound, cmpRange, SV,
-                   -1U, (CR.CaseBB == SwitchBB),
+                   -1U, MVT::Other, (CR.CaseBB == SwitchBB),
                    CR.CaseBB, Default, BTC);
 
   if (CR.CaseBB == SwitchBB)