Fix PR3401: when using large integers, the type
returned by getShiftAmountTy may be too small
to hold shift values (it is an i8 on x86-32).
Before and during type legalization, use a large
but legal type for shift amounts (getPointerTy);
afterwards use getShiftAmountTy, and fix up any
shift amounts that still have the big type during
operation legalization.  Thanks to Dan for writing
the original patch (which I shamelessly pillaged).


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@63482 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7240fb9..1dbbf4f 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -240,7 +240,13 @@
     /// FindBetterChain - Walk up chain skipping non-aliasing memory nodes,
     /// looking for a better chain (aliasing node.)
     SDValue FindBetterChain(SDNode *N, SDValue Chain);
-    
+
+    /// getShiftAmountTy - Type for shift amounts: the target's shift type once
+    /// types are legal, else pointer type (amounts can be huge before then).
+    MVT getShiftAmountTy() {
+      return LegalTypes ?  TLI.getShiftAmountTy() : TLI.getPointerTy();
+    }
+
 public:
     DAGCombiner(SelectionDAG &D, AliasAnalysis &A, bool fast)
       : DAG(D),
@@ -1301,7 +1307,7 @@
   if (N1C && N1C->getAPIntValue().isPowerOf2())
     return DAG.getNode(ISD::SHL, N->getDebugLoc(), VT, N0,
                        DAG.getConstant(N1C->getAPIntValue().logBase2(),
-                                       TLI.getShiftAmountTy()));
+                                       getShiftAmountTy()));
   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
   if (N1C && isPowerOf2_64(-N1C->getSExtValue()))
     // FIXME: If the input is something that is easily negated (e.g. a 
@@ -1310,7 +1316,7 @@
                        DAG.getConstant(0, VT),
                        DAG.getNode(ISD::SHL, N->getDebugLoc(), VT, N0,
                             DAG.getConstant(Log2_64(-N1C->getSExtValue()),
-                                            TLI.getShiftAmountTy())));
+                                            getShiftAmountTy())));
   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
   if (N1C && N0.getOpcode() == ISD::SHL &&
       isa<ConstantSDNode>(N0.getOperand(1))) {
@@ -1406,18 +1412,18 @@
     // Splat the sign bit into the register
     SDValue SGN = DAG.getNode(ISD::SRA, N->getDebugLoc(), VT, N0,
                               DAG.getConstant(VT.getSizeInBits()-1,
-                                              TLI.getShiftAmountTy()));
+                                              getShiftAmountTy()));
     AddToWorkList(SGN.getNode());
 
     // Add (N0 < 0) ? abs2 - 1 : 0;
     SDValue SRL = DAG.getNode(ISD::SRL, N->getDebugLoc(), VT, SGN,
                               DAG.getConstant(VT.getSizeInBits() - lg2,
-                                              TLI.getShiftAmountTy()));
+                                              getShiftAmountTy()));
     SDValue ADD = DAG.getNode(ISD::ADD, N->getDebugLoc(), VT, N0, SRL);
     AddToWorkList(SRL.getNode());
     AddToWorkList(ADD.getNode());    // Divide by pow2
     SDValue SRA = DAG.getNode(ISD::SRA, N->getDebugLoc(), VT, ADD,
-                              DAG.getConstant(lg2, TLI.getShiftAmountTy()));
+                              DAG.getConstant(lg2, getShiftAmountTy()));
 
     // If we're dividing by a positive value, we're done.  Otherwise, we must
     // negate the result.
@@ -1467,7 +1473,7 @@
   if (N1C && N1C->getAPIntValue().isPowerOf2())
     return DAG.getNode(ISD::SRL, N->getDebugLoc(), VT, N0, 
                        DAG.getConstant(N1C->getAPIntValue().logBase2(),
-                                       TLI.getShiftAmountTy()));
+                                       getShiftAmountTy()));
   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
   if (N1.getOpcode() == ISD::SHL) {
     if (ConstantSDNode *SHC = dyn_cast<ConstantSDNode>(N1.getOperand(0))) {
@@ -1607,7 +1613,7 @@
   if (N1C && N1C->getAPIntValue() == 1)
     return DAG.getNode(ISD::SRA, N->getDebugLoc(), N0.getValueType(), N0,
                        DAG.getConstant(N0.getValueType().getSizeInBits() - 1,
-                                       TLI.getShiftAmountTy()));
+                                       getShiftAmountTy()));
   // fold (mulhs x, undef) -> 0
   if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF)
     return DAG.getConstant(0, VT);
@@ -2613,7 +2619,7 @@
           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
           TLI.isTruncateFree(VT, TruncVT)) {
 
-          SDValue Amt = DAG.getConstant(ShiftAmt, TLI.getShiftAmountTy());
+          SDValue Amt = DAG.getConstant(ShiftAmt, getShiftAmountTy());
           SDValue Shift = DAG.getNode(ISD::SRL, N0.getDebugLoc(), VT,
                                       N0.getOperand(0), Amt);
           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, N0.getDebugLoc(), TruncVT,
@@ -2740,7 +2746,7 @@
 
       if (ShAmt) {
         Op = DAG.getNode(ISD::SRL, N0.getDebugLoc(), VT, Op,
-                         DAG.getConstant(ShAmt, TLI.getShiftAmountTy()));
+                         DAG.getConstant(ShAmt, getShiftAmountTy()));
         AddToWorkList(Op.getNode());
       }
 
@@ -5722,7 +5728,7 @@
       if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue()-1)) == 0)) {
         unsigned ShCtV = N2C->getAPIntValue().logBase2();
         ShCtV = XType.getSizeInBits()-ShCtV-1;
-        SDValue ShCt = DAG.getConstant(ShCtV, TLI.getShiftAmountTy());
+        SDValue ShCt = DAG.getConstant(ShCtV, getShiftAmountTy());
         SDValue Shift = DAG.getNode(ISD::SRL, N0.getDebugLoc(),
                                     XType, N0, ShCt);
         AddToWorkList(Shift.getNode());
@@ -5738,7 +5744,7 @@
       SDValue Shift = DAG.getNode(ISD::SRA, N0.getDebugLoc(),
                                   XType, N0,
                                   DAG.getConstant(XType.getSizeInBits()-1,
-                                                  TLI.getShiftAmountTy()));
+                                                  getShiftAmountTy()));
       AddToWorkList(Shift.getNode());
 
       if (XType.bitsGT(AType)) {
@@ -5787,7 +5793,7 @@
     // shl setcc result by log2 n2c
     return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
                        DAG.getConstant(N2C->getAPIntValue().logBase2(),
-                                       TLI.getShiftAmountTy()));
+                                       getShiftAmountTy()));
   }
     
   // Check to see if this is the equivalent of setcc
@@ -5810,7 +5816,7 @@
       SDValue Ctlz = DAG.getNode(ISD::CTLZ, N0.getDebugLoc(), XType, N0);
       return DAG.getNode(ISD::SRL, DL, XType, Ctlz, 
                          DAG.getConstant(Log2_32(XType.getSizeInBits()),
-                                         TLI.getShiftAmountTy()));
+                                         getShiftAmountTy()));
     }
     // fold (setgt X, 0) -> (srl (and (-X, ~X), size(X)-1))
     if (N1C && N1C->isNullValue() && CC == ISD::SETGT) { 
@@ -5820,13 +5826,13 @@
       return DAG.getNode(ISD::SRL, DL, XType,
                          DAG.getNode(ISD::AND, XType, NegN0, NotN0),
                          DAG.getConstant(XType.getSizeInBits()-1,
-                                         TLI.getShiftAmountTy()));
+                                         getShiftAmountTy()));
     }
     // fold (setgt X, -1) -> (xor (srl (X, size(X)-1), 1))
     if (N1C && N1C->isAllOnesValue() && CC == ISD::SETGT) {
       SDValue Sign = DAG.getNode(ISD::SRL, N0.getDebugLoc(), XType, N0,
                                  DAG.getConstant(XType.getSizeInBits()-1,
-                                                 TLI.getShiftAmountTy()));
+                                                 getShiftAmountTy()));
       return DAG.getNode(ISD::XOR, DL, XType, Sign, DAG.getConstant(1, XType));
     }
   }
@@ -5839,7 +5845,7 @@
     MVT XType = N0.getValueType();
     SDValue Shift = DAG.getNode(ISD::SRA, N0.getDebugLoc(), XType, N0,
                                 DAG.getConstant(XType.getSizeInBits()-1,
-                                                TLI.getShiftAmountTy()));
+                                                getShiftAmountTy()));
     SDValue Add = DAG.getNode(ISD::ADD, N0.getDebugLoc(), XType,
                               N0, Shift);
     AddToWorkList(Shift.getNode());
@@ -5856,7 +5862,7 @@
         SDValue Shift = DAG.getNode(ISD::SRA, N0.getDebugLoc(), XType,
                                     N0,
                                     DAG.getConstant(XType.getSizeInBits()-1,
-                                                    TLI.getShiftAmountTy()));
+                                                    getShiftAmountTy()));
         SDValue Add = DAG.getNode(ISD::ADD, N0.getDebugLoc(),
                                   XType, N0, Shift);
         AddToWorkList(Shift.getNode());