[RISCV] Custom lower SHL_PARTS, SRA_PARTS, SRL_PARTS

When not optimizing for minimum size (-Oz), we custom lower wide shifts
(SHL_PARTS, SRA_PARTS, SRL_PARTS) instead of expanding them to a libcall.

Differential Revision: https://reviews.llvm.org/D59477

llvm-svn: 358498
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index cc04c36..377933e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -125,9 +125,9 @@
   setOperationAction(ISD::SMUL_LOHI, XLenVT, Expand);
   setOperationAction(ISD::UMUL_LOHI, XLenVT, Expand);
 
-  setOperationAction(ISD::SHL_PARTS, XLenVT, Expand);
-  setOperationAction(ISD::SRL_PARTS, XLenVT, Expand);
-  setOperationAction(ISD::SRA_PARTS, XLenVT, Expand);
+  setOperationAction(ISD::SHL_PARTS, XLenVT, Custom);
+  setOperationAction(ISD::SRL_PARTS, XLenVT, Custom);
+  setOperationAction(ISD::SRA_PARTS, XLenVT, Custom);
 
   setOperationAction(ISD::ROTL, XLenVT, Expand);
   setOperationAction(ISD::ROTR, XLenVT, Expand);
@@ -360,6 +360,12 @@
     return lowerFRAMEADDR(Op, DAG);
   case ISD::RETURNADDR:
     return lowerRETURNADDR(Op, DAG);
+  case ISD::SHL_PARTS:
+    return lowerShiftLeftParts(Op, DAG);
+  case ISD::SRA_PARTS:
+    return lowerShiftRightParts(Op, DAG, true);
+  case ISD::SRL_PARTS:
+    return lowerShiftRightParts(Op, DAG, false);
   case ISD::BITCAST: {
     assert(Subtarget.is64Bit() && Subtarget.hasStdExtF() &&
            "Unexpected custom legalisation");
@@ -568,6 +574,97 @@
   return DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, XLenVT);
 }
 
+SDValue RISCVTargetLowering::lowerShiftLeftParts(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Lo = Op.getOperand(0);
+  SDValue Hi = Op.getOperand(1);
+  SDValue Shamt = Op.getOperand(2);
+  EVT VT = Lo.getValueType();
+  // Shift the double-width value Hi:Lo left by Shamt using XLEN-wide ops:
+  // if Shamt-XLEN < 0: // Shamt < XLEN
+  //   Hi = (Hi << Shamt) | ((Lo >>u 1) >>u (XLEN-1 - Shamt))
+  //   Lo = Lo << Shamt
+  // else:
+  //   Hi = Lo << (Shamt-XLEN)
+  //   Lo = 0
+
+  SDValue Zero = DAG.getConstant(0, DL, VT);
+  SDValue One = DAG.getConstant(1, DL, VT);
+  SDValue MinusXLen = DAG.getConstant(-(int)Subtarget.getXLen(), DL, VT);
+  SDValue XLenMinus1 = DAG.getConstant(Subtarget.getXLen() - 1, DL, VT);
+  SDValue ShamtMinusXLen = DAG.getNode(ISD::ADD, DL, VT, Shamt, MinusXLen);
+  SDValue XLenMinus1Shamt = DAG.getNode(ISD::SUB, DL, VT, XLenMinus1, Shamt);
+  // Shifting Lo right in two steps avoids a shift by XLEN when Shamt == 0.
+  SDValue LoTrue = DAG.getNode(ISD::SHL, DL, VT, Lo, Shamt);
+  SDValue ShiftRight1Lo = DAG.getNode(ISD::SRL, DL, VT, Lo, One);
+  SDValue ShiftRightLo =
+      DAG.getNode(ISD::SRL, DL, VT, ShiftRight1Lo, XLenMinus1Shamt);
+  SDValue ShiftLeftHi = DAG.getNode(ISD::SHL, DL, VT, Hi, Shamt);
+  SDValue HiTrue = DAG.getNode(ISD::OR, DL, VT, ShiftLeftHi, ShiftRightLo);
+  SDValue HiFalse = DAG.getNode(ISD::SHL, DL, VT, Lo, ShamtMinusXLen);
+  // Both arms are always built; the SELECTs discard the unused arm's value.
+  SDValue CC = DAG.getSetCC(DL, VT, ShamtMinusXLen, Zero, ISD::SETLT);
+
+  Lo = DAG.getNode(ISD::SELECT, DL, VT, CC, LoTrue, Zero);
+  Hi = DAG.getNode(ISD::SELECT, DL, VT, CC, HiTrue, HiFalse);
+
+  SDValue Parts[2] = {Lo, Hi};
+  return DAG.getMergeValues(Parts, DL);
+}
+
+SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
+                                                  bool IsSRA) const {
+  SDLoc DL(Op);
+  SDValue Lo = Op.getOperand(0);
+  SDValue Hi = Op.getOperand(1);
+  SDValue Shamt = Op.getOperand(2);
+  EVT VT = Lo.getValueType();
+  // Shift Hi:Lo right by Shamt; arithmetic if IsSRA, logical otherwise.
+  // SRA expansion:
+  //   if Shamt-XLEN < 0: // Shamt < XLEN
+  //     Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - Shamt))
+  //     Hi = Hi >>s Shamt
+  //   else:
+  //     Lo = Hi >>s (Shamt-XLEN)
+  //     Hi = Hi >>s (XLEN-1)
+  //
+  // SRL expansion:
+  //   if Shamt-XLEN < 0: // Shamt < XLEN
+  //     Lo = (Lo >>u Shamt) | ((Hi << 1) << (XLEN-1 - Shamt))
+  //     Hi = Hi >>u Shamt
+  //   else:
+  //     Lo = Hi >>u (Shamt-XLEN)
+  //     Hi = 0
+
+  unsigned ShiftRightOp = IsSRA ? ISD::SRA : ISD::SRL;
+
+  SDValue Zero = DAG.getConstant(0, DL, VT);
+  SDValue One = DAG.getConstant(1, DL, VT);
+  SDValue MinusXLen = DAG.getConstant(-(int)Subtarget.getXLen(), DL, VT);
+  SDValue XLenMinus1 = DAG.getConstant(Subtarget.getXLen() - 1, DL, VT);
+  SDValue ShamtMinusXLen = DAG.getNode(ISD::ADD, DL, VT, Shamt, MinusXLen);
+  SDValue XLenMinus1Shamt = DAG.getNode(ISD::SUB, DL, VT, XLenMinus1, Shamt);
+  // Shifting Hi left in two steps avoids a shift by XLEN when Shamt == 0.
+  SDValue ShiftRightLo = DAG.getNode(ISD::SRL, DL, VT, Lo, Shamt);
+  SDValue ShiftLeftHi1 = DAG.getNode(ISD::SHL, DL, VT, Hi, One);
+  SDValue ShiftLeftHi =
+      DAG.getNode(ISD::SHL, DL, VT, ShiftLeftHi1, XLenMinus1Shamt);
+  SDValue LoTrue = DAG.getNode(ISD::OR, DL, VT, ShiftRightLo, ShiftLeftHi);
+  SDValue HiTrue = DAG.getNode(ShiftRightOp, DL, VT, Hi, Shamt);
+  SDValue LoFalse = DAG.getNode(ShiftRightOp, DL, VT, Hi, ShamtMinusXLen);
+  SDValue HiFalse =
+      IsSRA ? DAG.getNode(ISD::SRA, DL, VT, Hi, XLenMinus1) : Zero;
+  // Both arms are always built; the SELECTs discard the unused arm's value.
+  SDValue CC = DAG.getSetCC(DL, VT, ShamtMinusXLen, Zero, ISD::SETLT);
+
+  Lo = DAG.getNode(ISD::SELECT, DL, VT, CC, LoTrue, LoFalse);
+  Hi = DAG.getNode(ISD::SELECT, DL, VT, CC, HiTrue, HiFalse);
+
+  SDValue Parts[2] = {Lo, Hi};
+  return DAG.getMergeValues(Parts, DL);
+}
+
 // Returns the opcode of the target-specific SDNode that implements the 32-bit
 // form of the given Opcode.
 static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {