[TargetLowering][AMDGPU] Remove the SimplifyDemandedBits function that takes a User and OpIdx. Stop using it in AMDGPU target for simplifyI24.

As we saw in D56057 when we tried to use this function on X86, it's unsafe. It allows the operand node to have multiple users, but doesn't prevent recursing past the first node when it does have multiple users. This can cause other simplifications earlier in the graph without regard to what bits are needed by the other users of the first node. Ideally, all we should do to the first node if it has multiple uses is bypass it when it's not needed by the user we started from. Doing any other transformation that SimplifyDemandedBits can do, like turning ZEXT/SEXT into AEXT, would result in an increase in instructions.

Fortunately, we already have a function that can do just that, GetDemandedBits. It will only make transformations that involve bypassing a node.

This patch changes AMDGPU's simplifyI24 to first use GetDemandedBits to handle the multiple-use simplifications, and then to use the regular SimplifyDemandedBits on each operand to handle simplifications allowed when the operand only has a single use. Unfortunately, GetDemandedBits simplifies constants more aggressively than SimplifyDemandedBits. This caused the -7 constant in the changed test to be simplified to remove the upper bits. I had to modify computeKnownBits to account for this by ignoring the upper 8 bits of the input.

Differential Revision: https://reviews.llvm.org/D56087

llvm-svn: 350560
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 2a05bf5..6951c91 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -2717,21 +2717,33 @@
     AMDGPUTargetLowering::numBitsSigned(Op, DAG) < 24;
 }
 
-static bool simplifyI24(SDNode *Node24, unsigned OpIdx,
-                        TargetLowering::DAGCombinerInfo &DCI) {
-
+static SDValue simplifyI24(SDNode *Node24,
+                           TargetLowering::DAGCombinerInfo &DCI) {
   SelectionDAG &DAG = DCI.DAG;
-  SDValue Op = Node24->getOperand(OpIdx);
+  SDValue LHS = Node24->getOperand(0);
+  SDValue RHS = Node24->getOperand(1);
+
+  APInt Demanded = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 24);
+
+  // First try to simplify using GetDemandedBits which allows the operands to
+  // have other uses, but will only perform simplifications that involve
+  // bypassing some nodes for this user.
+  SDValue DemandedLHS = DAG.GetDemandedBits(LHS, Demanded);
+  SDValue DemandedRHS = DAG.GetDemandedBits(RHS, Demanded);
+  if (DemandedLHS || DemandedRHS)
+    return DAG.getNode(Node24->getOpcode(), SDLoc(Node24), Node24->getVTList(),
+                       DemandedLHS ? DemandedLHS : LHS,
+                       DemandedRHS ? DemandedRHS : RHS);
+
+  // Now try SimplifyDemandedBits which can simplify the nodes used by our
+  // operands if this node is the only user.
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  EVT VT = Op.getValueType();
+  if (TLI.SimplifyDemandedBits(LHS, Demanded, DCI))
+    return SDValue(Node24, 0);
+  if (TLI.SimplifyDemandedBits(RHS, Demanded, DCI))
+    return SDValue(Node24, 0);
 
-  APInt Demanded = APInt::getLowBitsSet(VT.getSizeInBits(), 24);
-  APInt KnownZero, KnownOne;
-  TargetLowering::TargetLoweringOpt TLO(DAG, true, true);
-  if (TLI.SimplifyDemandedBits(Node24, OpIdx, Demanded, DCI, TLO))
-    return true;
-
-  return false;
+  return SDValue();
 }
 
 template <typename IntTy>
@@ -3279,8 +3291,8 @@
   SelectionDAG &DAG = DCI.DAG;
 
   // Simplify demanded bits before splitting into multiple users.
-  if (simplifyI24(N, 0, DCI) || simplifyI24(N, 1, DCI))
-    return SDValue();
+  if (SDValue V = simplifyI24(N, DCI))
+    return V;
 
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -3866,9 +3878,8 @@
   case AMDGPUISD::MUL_U24:
   case AMDGPUISD::MULHI_I24:
   case AMDGPUISD::MULHI_U24: {
-    // If the first call to simplify is successfull, then N may end up being
-    // deleted, so we shouldn't call simplifyI24 again.
-    simplifyI24(N, 0, DCI) || simplifyI24(N, 1, DCI);
+    if (SDValue V = simplifyI24(N, DCI))
+      return V;
     return SDValue();
   }
   case AMDGPUISD::MUL_LOHI_I24:
@@ -4296,25 +4307,36 @@
                       RHSKnown.countMinTrailingZeros();
     Known.Zero.setLowBits(std::min(TrailZ, 32u));
 
-    unsigned LHSValBits = 32 - std::max(LHSKnown.countMinSignBits(), 8u);
-    unsigned RHSValBits = 32 - std::max(RHSKnown.countMinSignBits(), 8u);
-    unsigned MaxValBits = std::min(LHSValBits + RHSValBits, 32u);
-    if (MaxValBits >= 32)
-      break;
+    // Truncate to 24 bits.
+    LHSKnown = LHSKnown.trunc(24);
+    RHSKnown = RHSKnown.trunc(24);
+
     bool Negative = false;
     if (Opc == AMDGPUISD::MUL_I24) {
-      bool LHSNegative = !!(LHSKnown.One  & (1 << 23));
-      bool LHSPositive = !!(LHSKnown.Zero & (1 << 23));
-      bool RHSNegative = !!(RHSKnown.One  & (1 << 23));
-      bool RHSPositive = !!(RHSKnown.Zero & (1 << 23));
+      unsigned LHSValBits = 24 - LHSKnown.countMinSignBits();
+      unsigned RHSValBits = 24 - RHSKnown.countMinSignBits();
+      unsigned MaxValBits = std::min(LHSValBits + RHSValBits, 32u);
+      if (MaxValBits >= 32)
+        break;
+      bool LHSNegative = LHSKnown.isNegative();
+      bool LHSPositive = LHSKnown.isNonNegative();
+      bool RHSNegative = RHSKnown.isNegative();
+      bool RHSPositive = RHSKnown.isNonNegative();
       if ((!LHSNegative && !LHSPositive) || (!RHSNegative && !RHSPositive))
         break;
       Negative = (LHSNegative && RHSPositive) || (LHSPositive && RHSNegative);
-    }
-    if (Negative)
-      Known.One.setHighBits(32 - MaxValBits);
-    else
+      if (Negative)
+        Known.One.setHighBits(32 - MaxValBits);
+      else
+        Known.Zero.setHighBits(32 - MaxValBits);
+    } else {
+      unsigned LHSValBits = 24 - LHSKnown.countMinLeadingZeros();
+      unsigned RHSValBits = 24 - RHSKnown.countMinLeadingZeros();
+      unsigned MaxValBits = std::min(LHSValBits + RHSValBits, 32u);
+      if (MaxValBits >= 32)
+        break;
       Known.Zero.setHighBits(32 - MaxValBits);
+    }
     break;
   }
   case AMDGPUISD::PERM: {