More work to allow the DAG combiner to promote 16-bit ops to 32-bit.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@101621 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 9ba9bb5..5fe67b5 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -582,9 +582,8 @@
   return SDValue(N, 0);
 }
 
-void
-DAGCombiner::CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &
-                                                                          TLO) {
+void DAGCombiner::
+CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
   // Replace all uses.  If any nodes become isomorphic to other nodes and
   // are deleted, make sure to remove them from our worklist.
   WorkListRemover DeadNodes(*this);
@@ -614,7 +613,7 @@
 /// it can be simplified or if things it uses can be simplified by bit
 /// propagation.  If so, return true.
 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &Demanded) {
-  TargetLowering::TargetLoweringOpt TLO(DAG);
+  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
   APInt KnownZero, KnownOne;
   if (!TLI.SimplifyDemandedBits(Op, Demanded, KnownZero, KnownOne, TLO))
     return false;
@@ -634,18 +633,77 @@
   return true;
 }
 
-static SDValue PromoteOperand(SDValue Op, EVT PVT, SelectionDAG &DAG) {
-  unsigned Opc = ISD::ZERO_EXTEND;
-  if (Op.getOpcode() == ISD::Constant) {
+/// PromoteOperand - Widen Op to the integer type PVT.  A load is re-emitted
+/// as an extending load of the same memory type; a constant is sign extended
+/// (zero extended if it is not byte sized); anything else is any-extended.
+/// Returns a null SDValue if the extend chosen for a non-load operand is not
+/// legal for PVT.
+static SDValue PromoteOperand(SDValue Op, EVT PVT, SelectionDAG &DAG,
+                              const TargetLowering &TLI) {
+  if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Op)) {
+    // Keep the original extension kind.  A sextload/zextload defines the
+    // bits above its memory type, and some of those bits survive the final
+    // truncate back to the original type; a plain EXTLOAD would leave them
+    // undefined.
+    ISD::LoadExtType ExtType =
+      ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD : LD->getExtensionType();
+    // NOTE(review): only value #0 of the new load is returned.  Its chain
+    // result is not wired to LD's chain users here, so nodes ordered after
+    // LD are not ordered after the new load, and LD itself may stay alive.
+    // TODO confirm the caller replaces LD's chain uses.
+    return DAG.getExtLoad(ExtType, Op.getDebugLoc(), PVT,
+                          LD->getChain(), LD->getBasePtr(),
+                          LD->getSrcValue(), LD->getSrcValueOffset(),
+                          LD->getMemoryVT(), LD->isVolatile(),
+                          LD->isNonTemporal(), LD->getAlignment());
+  }
+
+  unsigned Opc = ISD::ANY_EXTEND;
+  if (Op.getOpcode() == ISD::Constant)
     // Zero extend things like i1, sign extend everything else.  It shouldn't
     // matter in theory which one we pick, but this tends to give better code?
     // See DAGTypeLegalizer::PromoteIntRes_Constant.
-    if (Op.getValueType().isByteSized())
-      Opc = ISD::SIGN_EXTEND;
-  }
+    Opc = Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+  if (!TLI.isOperationLegal(Opc, PVT))
+    return SDValue();
   return DAG.getNode(Opc, Op.getDebugLoc(), PVT, Op);
 }
 
+/// SExtPromoteOperand - Promote Op to PVT, then SIGN_EXTEND_INREG from Op's
+/// original type so the new high bits are copies of its sign bit.  Returns a
+/// null SDValue if SIGN_EXTEND_INREG is not legal for PVT or Op cannot be
+/// promoted.
+static SDValue SExtPromoteOperand(SDValue Op, EVT PVT, SelectionDAG &DAG,
+                                  const TargetLowering &TLI) {
+  if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
+    return SDValue();
+  EVT OldVT = Op.getValueType();
+  DebugLoc dl = Op.getDebugLoc();
+  Op = PromoteOperand(Op, PVT, DAG, TLI);
+  if (Op.getNode() == 0)
+    return SDValue();
+  return DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, Op.getValueType(), Op,
+                     DAG.getValueType(OldVT));
+}
+
+/// ZExtPromoteOperand - Promote Op to PVT, then clear the bits above Op's
+/// original width with getZeroExtendInReg (an AND with a low-bits mask).
+/// Returns a null SDValue if AND is not legal for PVT or Op cannot be
+/// promoted.
+static SDValue ZExtPromoteOperand(SDValue Op, EVT PVT, SelectionDAG &DAG,
+                                  const TargetLowering &TLI) {
+  // Check the in-register extension is legal up front, as
+  // SExtPromoteOperand does for SIGN_EXTEND_INREG.
+  if (!TLI.isOperationLegal(ISD::AND, PVT))
+    return SDValue();
+  EVT OldVT = Op.getValueType();
+  DebugLoc dl = Op.getDebugLoc();
+  Op = PromoteOperand(Op, PVT, DAG, TLI);
+  if (Op.getNode() == 0)
+    return SDValue();
+  return DAG.getZeroExtendInReg(Op, dl, OldVT);
+}
+
 /// PromoteIntBinOp - Promote the specified integer binary operation if the
 /// target indicates it is beneficial. e.g. On x86, it's usually better to
 /// promote i16 operations to i32 since i16 instructions are longer.
@@ -657,15 +688,37 @@
   if (VT.isVector() || !VT.isInteger())
     return SDValue();
 
+  // If operation type is 'undesirable', e.g. i16 on x86, consider
+  // promoting it.
+  unsigned Opc = Op.getOpcode();
+  if (TLI.isTypeDesirableForOp(Opc, VT))
+    return SDValue();
+
   EVT PVT = VT;
-  if (TLI.PerformDAGCombinePromotion(Op, PVT)) {
+  // Consult target whether it is a good idea to promote this operation and
+  // what's the right type to promote it to.
+  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
     assert(PVT != VT && "Don't know what type to promote to!");
 
-    SDValue N0 = PromoteOperand(Op.getOperand(0), PVT, DAG);
-    AddToWorkList(N0.getNode());
+    bool isShift = (Opc == ISD::SHL) || (Opc == ISD::SRA) || (Opc == ISD::SRL);
+    SDValue N0 = Op.getOperand(0);
+    if (Opc == ISD::SRA)
+      N0 = SExtPromoteOperand(Op.getOperand(0), PVT, DAG, TLI);
+    else if (Opc == ISD::SRL)
+      N0 = ZExtPromoteOperand(Op.getOperand(0), PVT, DAG, TLI);
+    else
+      N0 = PromoteOperand(N0, PVT, DAG, TLI);
+    if (N0.getNode() == 0)
+      return SDValue();
 
-    SDValue N1 = PromoteOperand(Op.getOperand(1), PVT, DAG);
-    AddToWorkList(N1.getNode());
+    SDValue N1 = Op.getOperand(1);
+    if (!isShift) {
+      N1 = PromoteOperand(N1, PVT, DAG, TLI);
+      if (N1.getNode() == 0)
+        return SDValue();
+      AddToWorkList(N1.getNode());
+    }
+    AddToWorkList(N0.getNode());
 
     DebugLoc dl = Op.getDebugLoc();
     return DAG.getNode(ISD::TRUNCATE, dl, VT,
@@ -674,6 +727,7 @@
   return SDValue();
 }
 
+
 //===----------------------------------------------------------------------===//
 //  Main DAG Combiner implementation
 //===----------------------------------------------------------------------===//
@@ -1765,8 +1819,10 @@
   // into a vsetcc.
   EVT Op0VT = N0.getOperand(0).getValueType();
   if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
-       N0.getOpcode() == ISD::ANY_EXTEND  ||
        N0.getOpcode() == ISD::SIGN_EXTEND ||
+       // Avoid infinite looping with PromoteIntBinOp.
+       (N0.getOpcode() == ISD::ANY_EXTEND &&
+        (!LegalTypes || TLI.isTypeDesirableForOp(N->getOpcode(), Op0VT))) ||
        (N0.getOpcode() == ISD::TRUNCATE && TLI.isTypeLegal(Op0VT))) &&
       !VT.isVector() &&
       Op0VT == N1.getOperand(0).getValueType() &&
@@ -2624,7 +2680,13 @@
                        HiBitsMask);
   }
 
-  return N1C ? visitShiftByConstant(N, N1C->getZExtValue()) : SDValue();
+  if (N1C) {
+    SDValue NewSHL = visitShiftByConstant(N, N1C->getZExtValue());
+    if (NewSHL.getNode())
+      return NewSHL;
+  }
+
+  return PromoteIntBinOp(SDValue(N, 0));
 }
 
 SDValue DAGCombiner::visitSRA(SDNode *N) {
@@ -2738,7 +2800,13 @@
   if (DAG.SignBitIsZero(N0))
     return DAG.getNode(ISD::SRL, N->getDebugLoc(), VT, N0, N1);
 
-  return N1C ? visitShiftByConstant(N, N1C->getZExtValue()) : SDValue();
+  if (N1C) {
+    SDValue NewSRA = visitShiftByConstant(N, N1C->getZExtValue());
+    if (NewSRA.getNode())
+      return NewSRA;
+  }
+
+  return PromoteIntBinOp(SDValue(N, 0));
 }
 
 SDValue DAGCombiner::visitSRL(SDNode *N) {
@@ -2793,10 +2861,12 @@
     if (N1C->getZExtValue() >= SmallVT.getSizeInBits())
       return DAG.getUNDEF(VT);
 
-    SDValue SmallShift = DAG.getNode(ISD::SRL, N0.getDebugLoc(), SmallVT,
-                                     N0.getOperand(0), N1);
-    AddToWorkList(SmallShift.getNode());
-    return DAG.getNode(ISD::ANY_EXTEND, N->getDebugLoc(), VT, SmallShift);
+    if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
+      SDValue SmallShift = DAG.getNode(ISD::SRL, N0.getDebugLoc(), SmallVT,
+                                       N0.getOperand(0), N1);
+      AddToWorkList(SmallShift.getNode());
+      return DAG.getNode(ISD::ANY_EXTEND, N->getDebugLoc(), VT, SmallShift);
+    }
   }
 
   // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
@@ -2902,7 +2972,7 @@
     }
   }
 
-  return SDValue();
+  return PromoteIntBinOp(SDValue(N, 0));
 }
 
 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
@@ -3861,7 +3931,9 @@
 
   // fold (truncate (load x)) -> (smaller load x)
   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
-  return ReduceLoadWidth(N);
+  if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT))
+    return ReduceLoadWidth(N);
+  return SDValue();
 }
 
 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 9082fb6..4b7fb86 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -355,7 +355,7 @@
     InWorklist.insert(I);
   }
 
-  TargetLowering::TargetLoweringOpt TLO(*CurDAG, true);
+  TargetLowering::TargetLoweringOpt TLO(*CurDAG, true, true, true);
   while (!Worklist.empty()) {
     SDNode *N = Worklist.pop_back_val();
     InWorklist.erase(N);
diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index df923c6..11dca39 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1279,8 +1279,9 @@
     // variable.  The low bit of the shift cannot be an input sign bit unless
     // the shift amount is >= the size of the datatype, which is undefined.
     if (DemandedMask == 1)
-      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, Op.getValueType(),
-                                               Op.getOperand(0), Op.getOperand(1)));
+      return TLO.CombineTo(Op,
+                           TLO.DAG.getNode(ISD::SRL, dl, Op.getValueType(),
+                                           Op.getOperand(0), Op.getOperand(1)));
 
     if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
       EVT VT = Op.getValueType();
@@ -1465,23 +1466,29 @@
       case ISD::SRL:
         // Shrink SRL by a constant if none of the high bits shifted in are
         // demanded.
-        if (ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(In.getOperand(1))){
-          APInt HighBits = APInt::getHighBitsSet(OperandBitWidth,
-                                                 OperandBitWidth - BitWidth);
-          HighBits = HighBits.lshr(ShAmt->getZExtValue());
-          HighBits.trunc(BitWidth);
-          
-          if (ShAmt->getZExtValue() < BitWidth && !(HighBits & NewMask)) {
-            // None of the shifted in bits are needed.  Add a truncate of the
-            // shift input, then shift it.
-            SDValue NewTrunc = TLO.DAG.getNode(ISD::TRUNCATE, dl,
-                                                 Op.getValueType(), 
-                                                 In.getOperand(0));
-            return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl,
-                                                     Op.getValueType(),
-                                                     NewTrunc, 
-                                                     In.getOperand(1)));
-          }
+        if (TLO.LegalTypes() &&
+            !isTypeDesirableForOp(ISD::SRL, Op.getValueType()))
+          // Do not turn (vt1 truncate (vt2 srl)) into (vt1 srl) if vt1 is
+          // undesirable.
+          break;
+        ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(In.getOperand(1));
+        if (!ShAmt)
+          break;
+        APInt HighBits = APInt::getHighBitsSet(OperandBitWidth,
+                                               OperandBitWidth - BitWidth);
+        HighBits = HighBits.lshr(ShAmt->getZExtValue());
+        HighBits.trunc(BitWidth);
+
+        if (ShAmt->getZExtValue() < BitWidth && !(HighBits & NewMask)) {
+          // None of the shifted in bits are needed.  Add a truncate of the
+          // shift input, then shift it.
+          SDValue NewTrunc = TLO.DAG.getNode(ISD::TRUNCATE, dl,
+                                             Op.getValueType(), 
+                                             In.getOperand(0));
+          return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl,
+                                                   Op.getValueType(),
+                                                   NewTrunc, 
+                                                   In.getOperand(1)));
         }
         break;
       }