[Intrinsic] Unsigned Fixed Point Multiplication Intrinsic

Add an intrinsic that takes two unsigned integers, with their common scale
provided as the third argument, and performs fixed point multiplication on
them.

This is part of implementing fixed point arithmetic in Clang, where some of
the more complex operations will be implemented as intrinsics.

Differential Revision: https://reviews.llvm.org/D55625

llvm-svn: 353059
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 9e7cb45..cba80e1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1127,7 +1127,8 @@
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   }
-  case ISD::SMULFIX: {
+  case ISD::SMULFIX:
+  case ISD::UMULFIX: {
     unsigned Scale = Node->getConstantOperandVal(2);
     Action = TLI.getFixedPointOperationAction(Node->getOpcode(),
                                               Node->getValueType(0), Scale);
@@ -3290,6 +3291,7 @@
     Results.push_back(TLI.expandAddSubSat(Node, DAG));
     break;
   case ISD::SMULFIX:
+  case ISD::UMULFIX:
     Results.push_back(TLI.expandFixedPointMul(Node, DAG));
     break;
   case ISD::SADDO:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index ae98935..9873b1e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -148,7 +148,8 @@
   case ISD::UADDSAT:
   case ISD::SSUBSAT:
   case ISD::USUBSAT:     Res = PromoteIntRes_ADDSUBSAT(N); break;
-  case ISD::SMULFIX:     Res = PromoteIntRes_SMULFIX(N); break;
+  case ISD::SMULFIX:
+  case ISD::UMULFIX:     Res = PromoteIntRes_MULFIX(N); break;
 
   case ISD::ATOMIC_LOAD:
     Res = PromoteIntRes_Atomic0(cast<AtomicSDNode>(N)); break;
@@ -645,11 +646,17 @@
   return DAG.getNode(ShiftOp, dl, PromotedType, Result, ShiftAmount);
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_SMULFIX(SDNode *N) {
+SDValue DAGTypeLegalizer::PromoteIntRes_MULFIX(SDNode *N) {
   // Can just promote the operands then continue with operation.
   SDLoc dl(N);
-  SDValue Op1Promoted = SExtPromotedInteger(N->getOperand(0));
-  SDValue Op2Promoted = SExtPromotedInteger(N->getOperand(1));
+  SDValue Op1Promoted, Op2Promoted;
+  if (N->getOpcode() == ISD::SMULFIX) {
+    Op1Promoted = SExtPromotedInteger(N->getOperand(0));
+    Op2Promoted = SExtPromotedInteger(N->getOperand(1));
+  } else {
+    Op1Promoted = ZExtPromotedInteger(N->getOperand(0));
+    Op2Promoted = ZExtPromotedInteger(N->getOperand(1));
+  }
   EVT PromotedType = Op1Promoted.getValueType();
   return DAG.getNode(N->getOpcode(), dl, PromotedType, Op1Promoted, Op2Promoted,
                      N->getOperand(2));
@@ -1090,7 +1097,8 @@
 
   case ISD::PREFETCH: Res = PromoteIntOp_PREFETCH(N, OpNo); break;
 
-  case ISD::SMULFIX: Res = PromoteIntOp_SMULFIX(N); break;
+  case ISD::SMULFIX:
+  case ISD::UMULFIX: Res = PromoteIntOp_MULFIX(N); break;
 
   case ISD::FPOWI: Res = PromoteIntOp_FPOWI(N); break;
   }
@@ -1452,7 +1460,7 @@
   return SDValue(DAG.UpdateNodeOperands(N, LHS, RHS, Carry), 0);
 }
 
-SDValue DAGTypeLegalizer::PromoteIntOp_SMULFIX(SDNode *N) {
+SDValue DAGTypeLegalizer::PromoteIntOp_MULFIX(SDNode *N) {
   SDValue Op2 = ZExtPromotedInteger(N->getOperand(2));
   return SDValue(
       DAG.UpdateNodeOperands(N, N->getOperand(0), N->getOperand(1), Op2), 0);
@@ -1620,7 +1628,8 @@
   case ISD::UADDSAT:
   case ISD::SSUBSAT:
   case ISD::USUBSAT: ExpandIntRes_ADDSUBSAT(N, Lo, Hi); break;
-  case ISD::SMULFIX: ExpandIntRes_SMULFIX(N, Lo, Hi); break;
+  case ISD::SMULFIX:
+  case ISD::UMULFIX: ExpandIntRes_MULFIX(N, Lo, Hi); break;
   }
 
   // If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -2588,8 +2597,12 @@
   SplitInteger(Result, Lo, Hi);
 }
 
-void DAGTypeLegalizer::ExpandIntRes_SMULFIX(SDNode *N, SDValue &Lo,
-                                            SDValue &Hi) {
+void DAGTypeLegalizer::ExpandIntRes_MULFIX(SDNode *N, SDValue &Lo,
+                                           SDValue &Hi) {
+  assert(
+      (N->getOpcode() == ISD::SMULFIX || N->getOpcode() == ISD::UMULFIX) &&
+      "Expected operand to be signed or unsigned fixed point multiplication");
+
   SDLoc dl(N);
   EVT VT = N->getValueType(0);
   SDValue LHS = N->getOperand(0);
@@ -2607,10 +2620,12 @@
   GetExpandedInteger(RHS, RL, RH);
   SmallVector<SDValue, 4> Result;
 
-  if (!TLI.expandMUL_LOHI(ISD::SMUL_LOHI, VT, dl, LHS, RHS, Result, NVT, DAG,
+  bool Signed = N->getOpcode() == ISD::SMULFIX;
+  unsigned LoHiOp = Signed ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
+  if (!TLI.expandMUL_LOHI(LoHiOp, VT, dl, LHS, RHS, Result, NVT, DAG,
                           TargetLowering::MulExpansionKind::OnlyLegalOrCustom,
                           LL, LH, RL, RH)) {
-    report_fatal_error("Unable to expand SMUL_FIX using SMUL_LOHI.");
+    report_fatal_error("Unable to expand MUL_FIX using MUL_LOHI.");
     return;
   }
 
@@ -2671,9 +2686,16 @@
     Hi = DAG.getNode(ISD::SRL, dl, NVT, ResultHL, SRLAmnt);
     Hi = DAG.getNode(ISD::OR, dl, NVT, Hi,
                      DAG.getNode(ISD::SHL, dl, NVT, ResultHH, SHLAmnt));
+  } else if (Scale == VTSize) {
+    assert(
+        !Signed &&
+        "Only unsigned types can have a scale equal to the operand bit width");
+
+    Lo = ResultHL;
+    Hi = ResultHH;
   } else {
-    llvm_unreachable(
-        "Expected the scale to be less than the width of the operands");
+    llvm_unreachable("Expected the scale to be less than or equal to the width "
+                     "of the operands");
   }
 }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index eb14c63..a2a8d16 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -344,7 +344,7 @@
   SDValue PromoteIntRes_VAARG(SDNode *N);
   SDValue PromoteIntRes_XMULO(SDNode *N, unsigned ResNo);
   SDValue PromoteIntRes_ADDSUBSAT(SDNode *N);
-  SDValue PromoteIntRes_SMULFIX(SDNode *N);
+  SDValue PromoteIntRes_MULFIX(SDNode *N);
   SDValue PromoteIntRes_FLT_ROUNDS(SDNode *N);
 
   // Integer Operand Promotion.
@@ -378,7 +378,7 @@
   SDValue PromoteIntOp_ADDSUBCARRY(SDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_FRAMERETURNADDR(SDNode *N);
   SDValue PromoteIntOp_PREFETCH(SDNode *N, unsigned OpNo);
-  SDValue PromoteIntOp_SMULFIX(SDNode *N);
+  SDValue PromoteIntOp_MULFIX(SDNode *N);
   SDValue PromoteIntOp_FPOWI(SDNode *N);
 
   void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
@@ -435,7 +435,7 @@
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_XMULO             (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_ADDSUBSAT         (SDNode *N, SDValue &Lo, SDValue &Hi);
-  void ExpandIntRes_SMULFIX           (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_MULFIX            (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_ATOMIC_LOAD       (SDNode *N, SDValue &Lo, SDValue &Hi);
 
@@ -692,7 +692,7 @@
   SDValue ScalarizeVecRes_UNDEF(SDNode *N);
   SDValue ScalarizeVecRes_VECTOR_SHUFFLE(SDNode *N);
 
-  SDValue ScalarizeVecRes_SMULFIX(SDNode *N);
+  SDValue ScalarizeVecRes_MULFIX(SDNode *N);
 
   // Vector Operand Scalarization: <1 x ty> -> ty.
   bool ScalarizeVectorOperand(SDNode *N, unsigned OpNo);
@@ -729,7 +729,7 @@
   void SplitVecRes_ExtVecInRegOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_StrictFPOp(SDNode *N, SDValue &Lo, SDValue &Hi);
 
-  void SplitVecRes_SMULFIX(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_MULFIX(SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void SplitVecRes_BITCAST(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_BUILD_VECTOR(SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6ff288c..5d080e0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -425,7 +425,8 @@
   case ISD::USUBSAT:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
-  case ISD::SMULFIX: {
+  case ISD::SMULFIX:
+  case ISD::UMULFIX: {
     unsigned Scale = Node->getConstantOperandVal(2);
     Action = TLI.getFixedPointOperationAction(Node->getOpcode(),
                                               Node->getValueType(0), Scale);
@@ -784,6 +785,7 @@
   case ISD::SADDSAT:
     return ExpandAddSubSat(Op);
   case ISD::SMULFIX:
+  case ISD::UMULFIX:
     return ExpandFixedPointMul(Op);
   case ISD::STRICT_FADD:
   case ISD::STRICT_FSUB:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 8252f8a..32876bf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -172,7 +172,8 @@
     R = ScalarizeVecRes_StrictFPOp(N);
     break;
   case ISD::SMULFIX:
-    R = ScalarizeVecRes_SMULFIX(N);
+  case ISD::UMULFIX:
+    R = ScalarizeVecRes_MULFIX(N);
     break;
   }
 
@@ -196,7 +197,7 @@
                      Op0.getValueType(), Op0, Op1, Op2);
 }
 
-SDValue DAGTypeLegalizer::ScalarizeVecRes_SMULFIX(SDNode *N) {
+SDValue DAGTypeLegalizer::ScalarizeVecRes_MULFIX(SDNode *N) {
   SDValue Op0 = GetScalarizedVector(N->getOperand(0));
   SDValue Op1 = GetScalarizedVector(N->getOperand(1));
   SDValue Op2 = N->getOperand(2);
@@ -859,7 +860,8 @@
     SplitVecRes_StrictFPOp(N, Lo, Hi);
     break;
   case ISD::SMULFIX:
-    SplitVecRes_SMULFIX(N, Lo, Hi);
+  case ISD::UMULFIX:
+    SplitVecRes_MULFIX(N, Lo, Hi);
     break;
   }
 
@@ -898,8 +900,7 @@
                    Op0Hi, Op1Hi, Op2Hi);
 }
 
-void DAGTypeLegalizer::SplitVecRes_SMULFIX(SDNode *N, SDValue &Lo,
-                                           SDValue &Hi) {
+void DAGTypeLegalizer::SplitVecRes_MULFIX(SDNode *N, SDValue &Lo, SDValue &Hi) {
   SDValue LHSLo, LHSHi;
   GetSplitVector(N->getOperand(0), LHSLo, LHSHi);
   SDValue RHSLo, RHSHi;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1111a3e..58b69dc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5077,6 +5077,17 @@
 #  define setjmp_undefined_for_msvc
 #endif
 
+static unsigned FixedPointIntrinsicToOpcode(unsigned Intrinsic) {
+  switch (Intrinsic) { // Map a fixed point intrinsic ID to its ISD opcode.
+  case Intrinsic::smul_fix: // Signed fixed point multiplication.
+    return ISD::SMULFIX;
+  case Intrinsic::umul_fix: // Unsigned fixed point multiplication.
+    return ISD::UMULFIX;
+  default: // Only the two intrinsics above are handled here.
+    llvm_unreachable("Unhandled fixed point intrinsic");
+  }
+}
+
 /// Lower the call to the specified intrinsic function. If we want to emit this
 /// as a call to a named external function, return the name. Otherwise, lower it
 /// and return null.
@@ -5880,12 +5891,13 @@
     setValue(&I, DAG.getNode(ISD::USUBSAT, sdl, Op1.getValueType(), Op1, Op2));
     return nullptr;
   }
-  case Intrinsic::smul_fix: {
+  case Intrinsic::smul_fix:
+  case Intrinsic::umul_fix: {
     SDValue Op1 = getValue(I.getArgOperand(0));
     SDValue Op2 = getValue(I.getArgOperand(1));
     SDValue Op3 = getValue(I.getArgOperand(2));
-    setValue(&I,
-             DAG.getNode(ISD::SMULFIX, sdl, Op1.getValueType(), Op1, Op2, Op3));
+    setValue(&I, DAG.getNode(FixedPointIntrinsicToOpcode(Intrinsic), sdl,
+                             Op1.getValueType(), Op1, Op2, Op3));
     return nullptr;
   }
   case Intrinsic::stacksave: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 4b0481c..c14b94e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -300,6 +300,7 @@
   case ISD::SSUBSAT:                    return "ssubsat";
   case ISD::USUBSAT:                    return "usubsat";
   case ISD::SMULFIX:                    return "smulfix";
+  case ISD::UMULFIX:                    return "umulfix";
 
   // Conversion operators.
   case ISD::SIGN_EXTEND:                return "sign_extend";
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 7c96813..2257b44 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5415,7 +5415,9 @@
 
 SDValue
 TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
-  assert(Node->getOpcode() == ISD::SMULFIX && "Expected opcode to be SMULFIX.");
+  assert((Node->getOpcode() == ISD::SMULFIX ||
+          Node->getOpcode() == ISD::UMULFIX) &&
+         "Expected opcode to be SMULFIX or UMULFIX.");
 
   SDLoc dl(Node);
   SDValue LHS = Node->getOperand(0);
@@ -5430,27 +5432,37 @@
     return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
   }
 
+  unsigned VTSize = VT.getScalarSizeInBits();
+  bool Signed = Node->getOpcode() == ISD::SMULFIX;
+
+  assert(((Signed && Scale < VTSize) || (!Signed && Scale <= VTSize)) &&
+         "Expected scale to be less than the number of bits if signed or at "
+         "most the number of bits if unsigned.");
   assert(LHS.getValueType() == RHS.getValueType() &&
          "Expected both operands to be the same type");
-  assert(Scale < VT.getScalarSizeInBits() &&
-         "Expected scale to be less than the number of bits.");
 
   // Get the upper and lower bits of the result.
   SDValue Lo, Hi;
-  if (isOperationLegalOrCustom(ISD::SMUL_LOHI, VT)) {
-    SDValue Result =
-        DAG.getNode(ISD::SMUL_LOHI, dl, DAG.getVTList(VT, VT), LHS, RHS);
+  unsigned LoHiOp = Signed ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
+  unsigned HiOp = Signed ? ISD::MULHS : ISD::MULHU;
+  if (isOperationLegalOrCustom(LoHiOp, VT)) {
+    SDValue Result = DAG.getNode(LoHiOp, dl, DAG.getVTList(VT, VT), LHS, RHS);
     Lo = Result.getValue(0);
     Hi = Result.getValue(1);
-  } else if (isOperationLegalOrCustom(ISD::MULHS, VT)) {
+  } else if (isOperationLegalOrCustom(HiOp, VT)) {
     Lo = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
-    Hi = DAG.getNode(ISD::MULHS, dl, VT, LHS, RHS);
+    Hi = DAG.getNode(HiOp, dl, VT, LHS, RHS);
   } else if (VT.isVector()) {
     return SDValue();
   } else {
-    report_fatal_error("Unable to expand signed fixed point multiplication.");
+    report_fatal_error("Unable to expand fixed point multiplication.");
   }
 
+  if (Scale == VTSize)
+    // Result is just the top half since we'd be shifting by the width of the
+    // operand.
+    return Hi;
+
   // The result will need to be shifted right by the scale since both operands
   // are scaled. The result is given to us in 2 halves, so we only want part of
   // both in the result.
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index b87aa5d..280305d 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -624,6 +624,7 @@
     setOperationAction(ISD::SSUBSAT, VT, Expand);
     setOperationAction(ISD::USUBSAT, VT, Expand);
     setOperationAction(ISD::SMULFIX, VT, Expand);
+    setOperationAction(ISD::UMULFIX, VT, Expand);
 
     // Overflow operations default to expand
     setOperationAction(ISD::SADDO, VT, Expand);