Allow targets to specify the type of the RHS of a shift, parameterized on the type of the LHS.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@126518 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 691390e..35b847c 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -563,7 +563,7 @@
   setOperationAction(ISD::TRAP, MVT::Other, Expand);
 
   IsLittleEndian = TD->isLittleEndian();
-  ShiftAmountTy = PointerTy = MVT::getIntegerVT(8*TD->getPointerSize());
+  PointerTy = MVT::getIntegerVT(8*TD->getPointerSize());
   memset(RegClassForVT, 0,MVT::LAST_VALUETYPE*sizeof(TargetRegisterClass*));
   memset(TargetDAGCombineArray, 0, array_lengthof(TargetDAGCombineArray));
   maxStoresPerMemset = maxStoresPerMemcpy = maxStoresPerMemmove = 8;
@@ -596,6 +596,10 @@
   delete &TLOF;
 }
 
+MVT TargetLowering::getShiftAmountTy(EVT LHSTy) const { // Pointer-width int.
+  return MVT::getIntegerVT(8*TD->getPointerSize()); // LHSTy is unused here.
+}
+
 /// canOpTrap - Returns true if the operation can trap for the value type.
 /// VT must be a legal type.
 bool TargetLowering::canOpTrap(unsigned Op, EVT VT) const {
@@ -1401,7 +1405,7 @@
                                    BitWidth - InnerVT.getSizeInBits()) &
                DemandedMask) == 0 &&
             isTypeDesirableForOp(ISD::SHL, InnerVT)) {
-          EVT ShTy = getShiftAmountTy();
+          EVT ShTy = getShiftAmountTy(InnerVT);
           if (!APInt(BitWidth, ShAmt).isIntN(ShTy.getSizeInBits()))
             ShTy = InnerVT;
           SDValue NarrowShl =
@@ -2188,7 +2192,7 @@
       if (ConstantSDNode *AndRHS =
                   dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
         EVT ShiftTy = DCI.isBeforeLegalize() ?
-          getPointerTy() : getShiftAmountTy();
+          getPointerTy() : getShiftAmountTy(N0.getValueType());
         if (Cond == ISD::SETNE && C1 == 0) {// (X & 8) != 0  -->  (X & 8) >> 3
           // Perform the xform if the AND RHS is a single bit.
           if (AndRHS->getAPIntValue().isPowerOf2()) {
@@ -2359,7 +2363,7 @@
           // (Z-X) == X  --> Z == X<<1
           SDValue SH = DAG.getNode(ISD::SHL, dl, N1.getValueType(),
                                      N1,
-                                     DAG.getConstant(1, getShiftAmountTy()));
+                       DAG.getConstant(1, getShiftAmountTy(N1.getValueType())));
           if (!DCI.isCalledByLegalizer())
             DCI.AddToWorklist(SH.getNode());
           return DAG.getSetCC(dl, VT, N0.getOperand(0), SH, Cond);
@@ -2381,7 +2385,7 @@
           assert(N1.getOpcode() == ISD::SUB && "Unexpected operation!");
           // X == (Z-X)  --> X<<1 == Z
           SDValue SH = DAG.getNode(ISD::SHL, dl, N1.getValueType(), N0,
-                                     DAG.getConstant(1, getShiftAmountTy()));
+                       DAG.getConstant(1, getShiftAmountTy(N0.getValueType())));
           if (!DCI.isCalledByLegalizer())
             DCI.AddToWorklist(SH.getNode());
           return DAG.getSetCC(dl, VT, SH, N1.getOperand(0), Cond);
@@ -2493,7 +2497,7 @@
       }
     }
   }
-  
+
   return false;
 }
 
@@ -3141,14 +3145,14 @@
   // Shift right algebraic if shift value is nonzero
   if (magics.s > 0) {
     Q = DAG.getNode(ISD::SRA, dl, VT, Q,
-                    DAG.getConstant(magics.s, getShiftAmountTy()));
+                 DAG.getConstant(magics.s, getShiftAmountTy(Q.getValueType())));
     if (Created)
       Created->push_back(Q.getNode());
   }
   // Extract the sign bit and add it to the quotient
   SDValue T =
     DAG.getNode(ISD::SRL, dl, VT, Q, DAG.getConstant(VT.getSizeInBits()-1,
-                                                 getShiftAmountTy()));
+                                           getShiftAmountTy(Q.getValueType())));
   if (Created)
     Created->push_back(T.getNode());
   return DAG.getNode(ISD::ADD, dl, VT, Q, T);
@@ -3192,19 +3196,19 @@
     assert(magics.s < N1C->getAPIntValue().getBitWidth() &&
            "We shouldn't generate an undefined shift!");
     return DAG.getNode(ISD::SRL, dl, VT, Q,
-                       DAG.getConstant(magics.s, getShiftAmountTy()));
+                 DAG.getConstant(magics.s, getShiftAmountTy(Q.getValueType())));
   } else {
     SDValue NPQ = DAG.getNode(ISD::SUB, dl, VT, N->getOperand(0), Q);
     if (Created)
       Created->push_back(NPQ.getNode());
     NPQ = DAG.getNode(ISD::SRL, dl, VT, NPQ,
-                      DAG.getConstant(1, getShiftAmountTy()));
+                      DAG.getConstant(1, getShiftAmountTy(NPQ.getValueType())));
     if (Created)
       Created->push_back(NPQ.getNode());
     NPQ = DAG.getNode(ISD::ADD, dl, VT, NPQ, Q);
     if (Created)
       Created->push_back(NPQ.getNode());
     return DAG.getNode(ISD::SRL, dl, VT, NPQ,
-                       DAG.getConstant(magics.s-1, getShiftAmountTy()));
+             DAG.getConstant(magics.s-1, getShiftAmountTy(NPQ.getValueType())));
   }
 }