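Make the sext_in_reg(srl X, C) -> sra X, C fold I just checked in stronger: it
now also fires when C is smaller than the exact-fit shift amount, as long as X
already has enough sign bits.  Now we compile this: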

short test2(short X, short x) {
  int Y = (short)(X+x);
  return Y >> 1;
}

to:

_test2:
        add r2, r3, r4
        extsh r2, r2
        srawi r3, r2, 1
        blr

instead of:

_test2:
        add r2, r3, r4
        extsh r2, r2
        srwi r2, r2, 1
        extsh r3, r2
        blr


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@28175 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 1515df5..de23285 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1948,17 +1948,25 @@
       EVT < cast<VTSDNode>(N0.getOperand(1))->getVT()) {
     return DAG.getNode(ISD::SIGN_EXTEND_INREG, VT, N0.getOperand(0), N1);
   }
-  
-  // fold (sext_in_reg (srl X, 24), i8) -> sra X, 24
-  if (N0.getOpcode() == ISD::SRL) {
-    if (ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
-      if (ShAmt->getValue()+EVTBits == MVT::getSizeInBits(VT))
-        return DAG.getNode(ISD::SRA, VT, N0.getOperand(0), N0.getOperand(1));
-  }
-  
+
   // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is zero
   if (TLI.MaskedValueIsZero(N0, 1ULL << (EVTBits-1)))
     return DAG.getZeroExtendInReg(N0, EVT);
+  
+  // fold (sext_in_reg (srl X, 24), i8) -> sra X, 24
+  // fold (sext_in_reg (srl X, 23), i8) -> sra X, 23 iff possible.
+  // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
+  if (N0.getOpcode() == ISD::SRL) {
+    if (ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
+      if (ShAmt->getValue()+EVTBits <= MVT::getSizeInBits(VT)) {
+        // We can turn this into an SRA iff the input to the SRL is already sign
+        // extended enough.
+        unsigned InSignBits = TLI.ComputeNumSignBits(N0.getOperand(0));
+        if (MVT::getSizeInBits(VT)-(ShAmt->getValue()+EVTBits) < InSignBits)
+          return DAG.getNode(ISD::SRA, VT, N0.getOperand(0), N0.getOperand(1));
+      }
+  }
+  
   // fold (sext_inreg (extload x)) -> (sextload x)
   if (N0.getOpcode() == ISD::EXTLOAD && 
       EVT == cast<VTSDNode>(N0.getOperand(3))->getVT() &&