[RISCV] Custom-legalise i32 SDIV/UDIV/UREM on RV64M

Follow the same custom legalisation strategy as used in D57085 for
variable-length shifts (see that patch summary for more discussion). Although
we may lose out on some late-stage DAG combines, I think this custom
legalisation strategy is ultimately easier to reason about.

There are some codegen changes in rv64m-exhaustive-w-insts.ll, but they are
all neutral in terms of the number of instructions.

Differential Revision: https://reviews.llvm.org/D57096

llvm-svn: 352171
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4bbdd86..79385de 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -80,7 +80,6 @@
     setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);
 
   if (Subtarget.is64Bit()) {
-    setTargetDAGCombine(ISD::ANY_EXTEND);
     setOperationAction(ISD::SHL, MVT::i32, Custom);
     setOperationAction(ISD::SRA, MVT::i32, Custom);
     setOperationAction(ISD::SRL, MVT::i32, Custom);
@@ -96,6 +95,12 @@
     setOperationAction(ISD::UREM, XLenVT, Expand);
   }
 
+  if (Subtarget.is64Bit() && Subtarget.hasStdExtM()) {
+    setOperationAction(ISD::SDIV, MVT::i32, Custom);
+    setOperationAction(ISD::UDIV, MVT::i32, Custom);
+    setOperationAction(ISD::UREM, MVT::i32, Custom);
+  }
+
   setOperationAction(ISD::SDIVREM, XLenVT, Expand);
   setOperationAction(ISD::UDIVREM, XLenVT, Expand);
   setOperationAction(ISD::SMUL_LOHI, XLenVT, Expand);
@@ -524,6 +529,12 @@
     return RISCVISD::SRAW;
   case ISD::SRL:
     return RISCVISD::SRLW;
+  case ISD::SDIV:
+    return RISCVISD::DIVW;
+  case ISD::UDIV:
+    return RISCVISD::DIVUW;
+  case ISD::UREM:
+    return RISCVISD::REMUW;
   }
 }
 
@@ -558,46 +569,24 @@
       return;
     Results.push_back(customLegalizeToWOp(N, DAG));
     break;
-  }
-}
-
-// Returns true if the given node is an sdiv, udiv, or urem with non-constant
-// operands.
-static bool isVariableSDivUDivURem(SDValue Val) {
-  switch (Val.getOpcode()) {
-  default:
-    return false;
   case ISD::SDIV:
   case ISD::UDIV:
   case ISD::UREM:
-    return Val.getOperand(0).getOpcode() != ISD::Constant &&
-           Val.getOperand(1).getOpcode() != ISD::Constant;
+    assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
+           Subtarget.hasStdExtM() && "Unexpected custom legalisation");
+    if (N->getOperand(0).getOpcode() == ISD::Constant ||
+        N->getOperand(1).getOpcode() == ISD::Constant)
+      return;
+    Results.push_back(customLegalizeToWOp(N, DAG));
+    break;
   }
 }
 
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
-  SelectionDAG &DAG = DCI.DAG;
-
   switch (N->getOpcode()) {
   default:
     break;
-  case ISD::ANY_EXTEND: {
-    // If any-extending an i32 sdiv/udiv/urem to i64, then instead sign-extend
-    // in order to increase the chance of being able to select the
-    // divw/divuw/remuw instructions.
-    SDValue Src = N->getOperand(0);
-    if (N->getValueType(0) != MVT::i64 || Src.getValueType() != MVT::i32)
-      break;
-    if (!(Subtarget.hasStdExtM() && isVariableSDivUDivURem(Src)))
-      break;
-    SDLoc DL(N);
-    // Don't add the new node to the DAGCombiner worklist, in order to avoid
-    // an infinite cycle due to SimplifyDemandedBits converting the
-    // SIGN_EXTEND back to ANY_EXTEND.
-    return DCI.CombineTo(N, DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Src),
-                         false);
-  }
   case RISCVISD::SplitF64: {
     // If the input to SplitF64 is just BuildPairF64 then the operation is
     // redundant. Instead, use BuildPairF64's operands directly.
@@ -633,6 +622,9 @@
   case RISCVISD::SLLW:
   case RISCVISD::SRAW:
   case RISCVISD::SRLW:
+  case RISCVISD::DIVW:
+  case RISCVISD::DIVUW:
+  case RISCVISD::REMUW:
     // TODO: As the result is sign-extended, this is conservatively correct. A
     // more precise answer could be calculated for SRAW depending on known
     // bits in the shift amount.
@@ -1736,6 +1728,12 @@
     return "RISCVISD::SRAW";
   case RISCVISD::SRLW:
     return "RISCVISD::SRLW";
+  case RISCVISD::DIVW:
+    return "RISCVISD::DIVW";
+  case RISCVISD::DIVUW:
+    return "RISCVISD::DIVUW";
+  case RISCVISD::REMUW:
+    return "RISCVISD::REMUW";
   }
   return nullptr;
 }