Allow targets to specify the type of the RHS of a shift, parameterized on the type of the LHS.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@126518 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 911dbfd..9cc70a3 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -279,8 +279,8 @@
 
     /// getShiftAmountTy - Returns a type large enough to hold any valid
     /// shift amount - before type legalization these can be huge.
-    EVT getShiftAmountTy() {
-      return LegalTypes ? TLI.getShiftAmountTy() : TLI.getPointerTy();
+    EVT getShiftAmountTy(EVT LHSTy) {
+      return LegalTypes ? TLI.getShiftAmountTy(LHSTy) : TLI.getPointerTy();
     }
 
     /// isTypeLegal - This method returns true if we are running before type
@@ -670,7 +670,7 @@
   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Op)) {
     EVT MemVT = LD->getMemoryVT();
     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD)
-      ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, MemVT) ? ISD::ZEXTLOAD 
+      ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, MemVT) ? ISD::ZEXTLOAD
                                                   : ISD::EXTLOAD)
       : LD->getExtensionType();
     Replace = true;
@@ -894,7 +894,7 @@
     LoadSDNode *LD = cast<LoadSDNode>(N);
     EVT MemVT = LD->getMemoryVT();
     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD)
-      ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, MemVT) ? ISD::ZEXTLOAD 
+      ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, MemVT) ? ISD::ZEXTLOAD
                                                   : ISD::EXTLOAD)
       : LD->getExtensionType();
     SDValue NewLD = DAG.getExtLoad(ExtType, dl, PVT,
@@ -1521,7 +1521,7 @@
 // Since it may not be valid to emit a fold to zero for vector initializers
 // check if we can before folding.
 static SDValue tryFoldToZero(DebugLoc DL, const TargetLowering &TLI, EVT VT,
-                             SelectionDAG &DAG, bool LegalOperations) {                            
+                             SelectionDAG &DAG, bool LegalOperations) {
   if (!VT.isVector()) {
     return DAG.getConstant(0, VT);
   } else if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
@@ -1647,7 +1647,7 @@
   if (N1C && N1C->getAPIntValue().isPowerOf2())
     return DAG.getNode(ISD::SHL, N->getDebugLoc(), VT, N0,
                        DAG.getConstant(N1C->getAPIntValue().logBase2(),
-                                       getShiftAmountTy()));
+                                       getShiftAmountTy(N0.getValueType())));
   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
   if (N1C && (-N1C->getAPIntValue()).isPowerOf2()) {
     unsigned Log2Val = (-N1C->getAPIntValue()).logBase2();
@@ -1656,7 +1656,8 @@
     return DAG.getNode(ISD::SUB, N->getDebugLoc(), VT,
                        DAG.getConstant(0, VT),
                        DAG.getNode(ISD::SHL, N->getDebugLoc(), VT, N0,
-                            DAG.getConstant(Log2Val, getShiftAmountTy())));
+                            DAG.getConstant(Log2Val,
+                                      getShiftAmountTy(N0.getValueType()))));
   }
   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
   if (N1C && N0.getOpcode() == ISD::SHL &&
@@ -1753,18 +1754,18 @@
     // Splat the sign bit into the register
     SDValue SGN = DAG.getNode(ISD::SRA, N->getDebugLoc(), VT, N0,
                               DAG.getConstant(VT.getSizeInBits()-1,
-                                              getShiftAmountTy()));
+                                       getShiftAmountTy(N0.getValueType())));
     AddToWorkList(SGN.getNode());
 
     // Add (N0 < 0) ? abs2 - 1 : 0;
     SDValue SRL = DAG.getNode(ISD::SRL, N->getDebugLoc(), VT, SGN,
                               DAG.getConstant(VT.getSizeInBits() - lg2,
-                                              getShiftAmountTy()));
+                                       getShiftAmountTy(SGN.getValueType())));
     SDValue ADD = DAG.getNode(ISD::ADD, N->getDebugLoc(), VT, N0, SRL);
     AddToWorkList(SRL.getNode());
     AddToWorkList(ADD.getNode());    // Divide by pow2
     SDValue SRA = DAG.getNode(ISD::SRA, N->getDebugLoc(), VT, ADD,
-                              DAG.getConstant(lg2, getShiftAmountTy()));
+                  DAG.getConstant(lg2, getShiftAmountTy(ADD.getValueType())));
 
     // If we're dividing by a positive value, we're done.  Otherwise, we must
     // negate the result.
@@ -1814,7 +1815,7 @@
   if (N1C && N1C->getAPIntValue().isPowerOf2())
     return DAG.getNode(ISD::SRL, N->getDebugLoc(), VT, N0,
                        DAG.getConstant(N1C->getAPIntValue().logBase2(),
-                                       getShiftAmountTy()));
+                                       getShiftAmountTy(N0.getValueType())));
   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
   if (N1.getOpcode() == ISD::SHL) {
     if (ConstantSDNode *SHC = dyn_cast<ConstantSDNode>(N1.getOperand(0))) {
@@ -1955,7 +1956,7 @@
   if (N1C && N1C->getAPIntValue() == 1)
     return DAG.getNode(ISD::SRA, N->getDebugLoc(), N0.getValueType(), N0,
                        DAG.getConstant(N0.getValueType().getSizeInBits() - 1,
-                                       getShiftAmountTy()));
+                                       getShiftAmountTy(N0.getValueType())));
   // fold (mulhs x, undef) -> 0
   if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF)
     return DAG.getConstant(0, VT);
@@ -1971,11 +1972,11 @@
       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
-                       DAG.getConstant(SimpleSize, getShiftAmountTy()));
+            DAG.getConstant(SimpleSize, getShiftAmountTy(N1.getValueType())));
       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
     }
   }
-  
+
   return SDValue();
 }
 
@@ -2007,11 +2008,11 @@
       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
-                       DAG.getConstant(SimpleSize, getShiftAmountTy()));
+            DAG.getConstant(SimpleSize, getShiftAmountTy(N1.getValueType())));
       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
     }
   }
-  
+
   return SDValue();
 }
 
@@ -2090,14 +2091,14 @@
       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
       // Compute the high part as N1.
       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
-                       DAG.getConstant(SimpleSize, getShiftAmountTy()));
+            DAG.getConstant(SimpleSize, getShiftAmountTy(Lo.getValueType())));
       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
       // Compute the low part as N0.
       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
       return CombineTo(N, Lo, Hi);
     }
   }
-  
+
   return SDValue();
 }
 
@@ -2107,7 +2108,7 @@
 
   EVT VT = N->getValueType(0);
   DebugLoc DL = N->getDebugLoc();
-  
+
   // If the type twice as wide is legal, transform the mulhu to a wider multiply
   // plus a shift.
   if (VT.isSimple() && !VT.isVector()) {
@@ -2120,14 +2121,14 @@
       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
       // Compute the high part as N1.
       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
-                       DAG.getConstant(SimpleSize, getShiftAmountTy()));
+            DAG.getConstant(SimpleSize, getShiftAmountTy(Lo.getValueType())));
       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
       // Compute the low part as N0.
       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
       return CombineTo(N, Lo, Hi);
     }
   }
-  
+
   return SDValue();
 }
 
@@ -3004,7 +3005,7 @@
               N0.getOpcode() == ISD::SIGN_EXTEND) &&
       N0.getOperand(0).getOpcode() == ISD::SHL &&
       isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) {
-    uint64_t c1 = 
+    uint64_t c1 =
       cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue();
     uint64_t c2 = N1C->getZExtValue();
     EVT InnerShiftVT = N0.getOperand(0).getValueType();
@@ -3133,7 +3134,8 @@
           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
           TLI.isTruncateFree(VT, TruncVT)) {
 
-          SDValue Amt = DAG.getConstant(ShiftAmt, getShiftAmountTy());
+          SDValue Amt = DAG.getConstant(ShiftAmt,
+              getShiftAmountTy(N0.getOperand(0).getValueType()));
           SDValue Shift = DAG.getNode(ISD::SRL, N0.getDebugLoc(), VT,
                                       N0.getOperand(0), Amt);
           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, N0.getDebugLoc(), TruncVT,
@@ -3180,7 +3182,7 @@
         LargeShiftAmt->getZExtValue()) {
       SDValue Amt =
         DAG.getConstant(LargeShiftAmt->getZExtValue() + N1C->getZExtValue(),
-                        getShiftAmountTy());
+              getShiftAmountTy(N0.getOperand(0).getOperand(0).getValueType()));
       SDValue SRA = DAG.getNode(ISD::SRA, N->getDebugLoc(), LargeVT,
                                 N0.getOperand(0).getOperand(0), Amt);
       return DAG.getNode(ISD::TRUNCATE, N->getDebugLoc(), VT, SRA);
@@ -3245,7 +3247,7 @@
   if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
       N0.getOperand(0).getOpcode() == ISD::SRL &&
       isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) {
-    uint64_t c1 = 
+    uint64_t c1 =
       cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue();
     uint64_t c2 = N1C->getZExtValue();
     EVT InnerShiftVT = N0.getOperand(0).getValueType();
@@ -3256,7 +3258,7 @@
       if (c1 + c2 >= InnerShiftSize)
         return DAG.getConstant(0, VT);
       return DAG.getNode(ISD::TRUNCATE, N0->getDebugLoc(), VT,
-                         DAG.getNode(ISD::SRL, N0->getDebugLoc(), InnerShiftVT, 
+                         DAG.getNode(ISD::SRL, N0->getDebugLoc(), InnerShiftVT,
                                      N0.getOperand(0)->getOperand(0),
                                      DAG.getConstant(c1 + c2, ShiftCountVT)));
     }
@@ -3320,7 +3322,7 @@
 
       if (ShAmt) {
         Op = DAG.getNode(ISD::SRL, N0.getDebugLoc(), VT, Op,
-                         DAG.getConstant(ShAmt, getShiftAmountTy()));
+                  DAG.getConstant(ShAmt, getShiftAmountTy(Op.getValueType())));
         AddToWorkList(Op.getNode());
       }
 
@@ -4025,11 +4027,11 @@
     }
 
     DebugLoc DL = N->getDebugLoc();
-    
-    // Ensure that the shift amount is wide enough for the shifted value. 
+
+    // Ensure that the shift amount is wide enough for the shifted value.
     if (VT.getSizeInBits() >= 256)
       ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
-    
+
     return DAG.getNode(N0.getOpcode(), DL, VT,
                        DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)),
                        ShAmt);
@@ -4278,12 +4280,12 @@
     return SDValue();
 
   unsigned EVTBits = ExtVT.getSizeInBits();
-  
+
   // Do not generate loads of non-round integer types since these can
   // be expensive (and would be wrong if the type is not byte sized).
   if (!ExtVT.isRound())
     return SDValue();
-  
+
   unsigned ShAmt = 0;
   if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
     if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
@@ -4298,7 +4300,7 @@
 
       // At this point, we must have a load or else we can't do the transform.
       if (!isa<LoadSDNode>(N0)) return SDValue();
-      
+
       // If the shift amount is larger than the input type then we're not
       // accessing any of the loaded bytes.  If the load was a zextload/extload
       // then the result of the shift+trunc is zero/undef (handled elsewhere).
@@ -4319,18 +4321,18 @@
       N0 = N0.getOperand(0);
     }
   }
-  
+
   // If we haven't found a load, we can't narrow it.  Don't transform one with
   // multiple uses, this would require adding a new load.
   if (!isa<LoadSDNode>(N0) || !N0.hasOneUse() ||
       // Don't change the width of a volatile load.
       cast<LoadSDNode>(N0)->isVolatile())
     return SDValue();
-  
+
   // Verify that we are actually reducing a load width here.
   if (cast<LoadSDNode>(N0)->getMemoryVT().getSizeInBits() < EVTBits)
     return SDValue();
-  
+
   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   EVT PtrType = N0.getOperand(1).getValueType();
 
@@ -4368,7 +4370,7 @@
   // Shift the result left, if we've swallowed a left shift.
   SDValue Result = Load;
   if (ShLeftAmt != 0) {
-    EVT ShImmTy = getShiftAmountTy();
+    EVT ShImmTy = getShiftAmountTy(Result.getValueType());
     if (!isUIntN(ShImmTy.getSizeInBits(), ShLeftAmt))
       ShImmTy = VT;
     Result = DAG.getNode(ISD::SHL, N0.getDebugLoc(), VT,
@@ -5984,7 +5986,8 @@
   // shifted by ByteShift and truncated down to NumBytes.
   if (ByteShift)
     IVal = DAG.getNode(ISD::SRL, IVal->getDebugLoc(), IVal.getValueType(), IVal,
-                       DAG.getConstant(ByteShift*8, DC->getShiftAmountTy()));
+                       DAG.getConstant(ByteShift*8,
+                                    DC->getShiftAmountTy(IVal.getValueType())));
 
   // Figure out the offset for the store and the alignment of the access.
   unsigned StOffset;
@@ -6399,7 +6402,7 @@
 
   EVT VT = InVec.getValueType();
 
-  // If we can't generate a legal BUILD_VECTOR, exit 
+  // If we can't generate a legal BUILD_VECTOR, exit
   if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
     return SDValue();
 
@@ -7107,7 +7110,8 @@
       if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue()-1)) == 0)) {
         unsigned ShCtV = N2C->getAPIntValue().logBase2();
         ShCtV = XType.getSizeInBits()-ShCtV-1;
-        SDValue ShCt = DAG.getConstant(ShCtV, getShiftAmountTy());
+        SDValue ShCt = DAG.getConstant(ShCtV,
+                                       getShiftAmountTy(N0.getValueType()));
         SDValue Shift = DAG.getNode(ISD::SRL, N0.getDebugLoc(),
                                     XType, N0, ShCt);
         AddToWorkList(Shift.getNode());
@@ -7123,7 +7127,7 @@
       SDValue Shift = DAG.getNode(ISD::SRA, N0.getDebugLoc(),
                                   XType, N0,
                                   DAG.getConstant(XType.getSizeInBits()-1,
-                                                  getShiftAmountTy()));
+                                         getShiftAmountTy(N0.getValueType())));
       AddToWorkList(Shift.getNode());
 
       if (XType.bitsGT(AType)) {
@@ -7151,13 +7155,15 @@
       // Shift the tested bit over the sign bit.
       APInt AndMask = ConstAndRHS->getAPIntValue();
       SDValue ShlAmt =
-        DAG.getConstant(AndMask.countLeadingZeros(), getShiftAmountTy());
+        DAG.getConstant(AndMask.countLeadingZeros(),
+                        getShiftAmountTy(AndLHS.getValueType()));
       SDValue Shl = DAG.getNode(ISD::SHL, N0.getDebugLoc(), VT, AndLHS, ShlAmt);
 
       // Now arithmetic right shift it all the way over, so the result is either
       // all-ones, or zero.
       SDValue ShrAmt =
-        DAG.getConstant(AndMask.getBitWidth()-1, getShiftAmountTy());
+        DAG.getConstant(AndMask.getBitWidth()-1,
+                        getShiftAmountTy(Shl.getValueType()));
       SDValue Shr = DAG.getNode(ISD::SRA, N0.getDebugLoc(), VT, Shl, ShrAmt);
 
       return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
@@ -7201,7 +7207,7 @@
     // shl setcc result by log2 n2c
     return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
                        DAG.getConstant(N2C->getAPIntValue().logBase2(),
-                                       getShiftAmountTy()));
+                                       getShiftAmountTy(Temp.getValueType())));
   }
 
   // Check to see if this is the equivalent of setcc
@@ -7224,7 +7230,7 @@
       SDValue Ctlz = DAG.getNode(ISD::CTLZ, N0.getDebugLoc(), XType, N0);
       return DAG.getNode(ISD::SRL, DL, XType, Ctlz,
                          DAG.getConstant(Log2_32(XType.getSizeInBits()),
-                                         getShiftAmountTy()));
+                                       getShiftAmountTy(Ctlz.getValueType())));
     }
     // fold (setgt X, 0) -> (srl (and (-X, ~X), size(X)-1))
     if (N1C && N1C->isNullValue() && CC == ISD::SETGT) {
@@ -7234,13 +7240,13 @@
       return DAG.getNode(ISD::SRL, DL, XType,
                          DAG.getNode(ISD::AND, DL, XType, NegN0, NotN0),
                          DAG.getConstant(XType.getSizeInBits()-1,
-                                         getShiftAmountTy()));
+                                         getShiftAmountTy(XType)));
     }
     // fold (setgt X, -1) -> (xor (srl (X, size(X)-1), 1))
     if (N1C && N1C->isAllOnesValue() && CC == ISD::SETGT) {
       SDValue Sign = DAG.getNode(ISD::SRL, N0.getDebugLoc(), XType, N0,
                                  DAG.getConstant(XType.getSizeInBits()-1,
-                                                 getShiftAmountTy()));
+                                         getShiftAmountTy(N0.getValueType())));
       return DAG.getNode(ISD::XOR, DL, XType, Sign, DAG.getConstant(1, XType));
     }
   }
@@ -7267,7 +7273,7 @@
       SDValue Shift = DAG.getNode(ISD::SRA, N0.getDebugLoc(), XType,
                                   N0,
                                   DAG.getConstant(XType.getSizeInBits()-1,
-                                                  getShiftAmountTy()));
+                                         getShiftAmountTy(N0.getValueType())));
       SDValue Add = DAG.getNode(ISD::ADD, N0.getDebugLoc(),
                                 XType, N0, Shift);
       AddToWorkList(Shift.getNode());