[Intrinsic] Signed Fixed Point Saturation Multiplication Intrinsic

Add an intrinsic that takes 2 signed integers, with their scale provided as
the third argument, and performs fixed point multiplication on them. The
result is saturated, i.e. clamped between the largest and smallest values
representable in the type of the first 2 operands.

This is a part of implementing fixed point arithmetic in clang where some of
the more complex operations will be implemented as intrinsics.

Differential Revision: https://reviews.llvm.org/D55720

llvm-svn: 361289
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 63d407c..52ae1e0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1140,6 +1140,7 @@
     break;
   }
   case ISD::SMULFIX:
+  case ISD::SMULFIXSAT:
   case ISD::UMULFIX: {
     unsigned Scale = Node->getConstantOperandVal(2);
     Action = TLI.getFixedPointOperationAction(Node->getOpcode(),
@@ -3334,6 +3335,7 @@
     Results.push_back(TLI.expandAddSubSat(Node, DAG));
     break;
   case ISD::SMULFIX:
+  case ISD::SMULFIXSAT:
   case ISD::UMULFIX:
     Results.push_back(TLI.expandFixedPointMul(Node, DAG));
     break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 0930b63..357654f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -149,6 +149,7 @@
   case ISD::SSUBSAT:
   case ISD::USUBSAT:     Res = PromoteIntRes_ADDSUBSAT(N); break;
   case ISD::SMULFIX:
+  case ISD::SMULFIXSAT:
   case ISD::UMULFIX:     Res = PromoteIntRes_MULFIX(N); break;
   case ISD::ABS:         Res = PromoteIntRes_ABS(N); break;
 
@@ -670,14 +671,35 @@
   // Can just promote the operands then continue with operation.
   SDLoc dl(N);
   SDValue Op1Promoted, Op2Promoted;
-  if (N->getOpcode() == ISD::SMULFIX) {
+  bool Signed =
+      N->getOpcode() == ISD::SMULFIX || N->getOpcode() == ISD::SMULFIXSAT;
+  if (Signed) {
     Op1Promoted = SExtPromotedInteger(N->getOperand(0));
     Op2Promoted = SExtPromotedInteger(N->getOperand(1));
   } else {
     Op1Promoted = ZExtPromotedInteger(N->getOperand(0));
     Op2Promoted = ZExtPromotedInteger(N->getOperand(1));
   }
+  EVT OldType = N->getOperand(0).getValueType();
   EVT PromotedType = Op1Promoted.getValueType();
+  unsigned DiffSize =
+      PromotedType.getScalarSizeInBits() - OldType.getScalarSizeInBits();
+
+  bool Saturating = N->getOpcode() == ISD::SMULFIXSAT;
+  if (Saturating) {
+    // Promoting the operand and result values changes the saturation width,
+    // which extends the values that we clamp to on saturation. This could be
+    // resolved by shifting one of the operands the same amount, which would
+    // also shift the result we compare against, then shifting back.
+    EVT ShiftTy = TLI.getShiftAmountTy(PromotedType, DAG.getDataLayout());
+    Op1Promoted = DAG.getNode(ISD::SHL, dl, PromotedType, Op1Promoted,
+                              DAG.getConstant(DiffSize, dl, ShiftTy));
+    SDValue Result = DAG.getNode(N->getOpcode(), dl, PromotedType, Op1Promoted,
+                                 Op2Promoted, N->getOperand(2));
+    unsigned ShiftOp = Signed ? ISD::SRA : ISD::SRL;
+    return DAG.getNode(ShiftOp, dl, PromotedType, Result,
+                       DAG.getConstant(DiffSize, dl, ShiftTy));
+  }
   return DAG.getNode(N->getOpcode(), dl, PromotedType, Op1Promoted, Op2Promoted,
                      N->getOperand(2));
 }
@@ -1125,6 +1147,7 @@
   case ISD::PREFETCH: Res = PromoteIntOp_PREFETCH(N, OpNo); break;
 
   case ISD::SMULFIX:
+  case ISD::SMULFIXSAT:
   case ISD::UMULFIX: Res = PromoteIntOp_MULFIX(N); break;
 
   case ISD::FPOWI: Res = PromoteIntOp_FPOWI(N); break;
@@ -1688,7 +1711,9 @@
   case ISD::UADDSAT:
   case ISD::SSUBSAT:
   case ISD::USUBSAT: ExpandIntRes_ADDSUBSAT(N, Lo, Hi); break;
+
   case ISD::SMULFIX:
+  case ISD::SMULFIXSAT:
   case ISD::UMULFIX: ExpandIntRes_MULFIX(N, Lo, Hi); break;
 
   case ISD::VECREDUCE_ADD:
@@ -2712,19 +2737,40 @@
   SplitInteger(Result, Lo, Hi);
 }
 
+/// This performs an expansion of the integer result for a fixed point
+/// multiplication. The default expansion performs rounding down towards
+/// negative infinity, though targets that do care about rounding should specify
+/// a target hook for rounding and provide their own expansion or lowering of
+/// fixed point multiplication to be consistent with rounding.
 void DAGTypeLegalizer::ExpandIntRes_MULFIX(SDNode *N, SDValue &Lo,
                                            SDValue &Hi) {
-  assert(
-      (N->getOpcode() == ISD::SMULFIX || N->getOpcode() == ISD::UMULFIX) &&
-      "Expected operand to be signed or unsigned fixed point multiplication");
-
   SDLoc dl(N);
   EVT VT = N->getValueType(0);
+  unsigned VTSize = VT.getScalarSizeInBits();
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
   uint64_t Scale = N->getConstantOperandVal(2);
+  bool Saturating = N->getOpcode() == ISD::SMULFIXSAT;
+  EVT BoolVT =
+      TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue Zero = DAG.getConstant(0, dl, VT);
   if (!Scale) {
-    SDValue Result = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
+    SDValue Result;
+    if (!Saturating) {
+      Result = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
+    } else {
+      Result = DAG.getNode(ISD::SMULO, dl, DAG.getVTList(VT, BoolVT), LHS, RHS);
+      SDValue Product = Result.getValue(0);
+      SDValue Overflow = Result.getValue(1);
+
+      APInt MinVal = APInt::getSignedMinValue(VTSize);
+      APInt MaxVal = APInt::getSignedMaxValue(VTSize);
+      SDValue SatMin = DAG.getConstant(MinVal, dl, VT);
+      SDValue SatMax = DAG.getConstant(MaxVal, dl, VT);
+      SDValue ProdNeg = DAG.getSetCC(dl, BoolVT, Product, Zero, ISD::SETLT);
+      Result = DAG.getSelect(dl, VT, ProdNeg, SatMax, SatMin);
+      Result = DAG.getSelect(dl, VT, Overflow, Result, Product);
+    }
     SplitInteger(Result, Lo, Hi);
     return;
   }
@@ -2735,7 +2781,8 @@
   GetExpandedInteger(RHS, RL, RH);
   SmallVector<SDValue, 4> Result;
 
-  bool Signed = N->getOpcode() == ISD::SMULFIX;
+  bool Signed = (N->getOpcode() == ISD::SMULFIX ||
+                 N->getOpcode() == ISD::SMULFIXSAT);
   unsigned LoHiOp = Signed ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
   if (!TLI.expandMUL_LOHI(LoHiOp, VT, dl, LHS, RHS, Result, NVT, DAG,
                           TargetLowering::MulExpansionKind::OnlyLegalOrCustom,
@@ -2744,8 +2791,9 @@
     return;
   }
 
-  unsigned VTSize = VT.getScalarSizeInBits();
   unsigned NVTSize = NVT.getScalarSizeInBits();
+  assert((VTSize == NVTSize * 2) && "Expected the new value type to be half "
+                                    "the size of the current value type");
   EVT ShiftTy = TLI.getShiftAmountTy(NVT, DAG.getDataLayout());
 
   // Shift whole amount by scale.
@@ -2754,6 +2802,12 @@
   SDValue ResultHL = Result[2];
   SDValue ResultHH = Result[3];
 
+  SDValue SatMax, SatMin;
+  SDValue NVTZero = DAG.getConstant(0, dl, NVT);
+  SDValue NVTNeg1 = DAG.getConstant(-1, dl, NVT);
+  EVT BoolNVT =
+      TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), NVT);
+
   // After getting the multplication result in 4 parts, we need to perform a
   // shift right by the amount of the scale to get the result in that scale.
   // Let's say we multiply 2 64 bit numbers. The resulting value can be held in
@@ -2782,11 +2836,60 @@
     Hi = DAG.getNode(ISD::SRL, dl, NVT, ResultLH, SRLAmnt);
     Hi = DAG.getNode(ISD::OR, dl, NVT, Hi,
                      DAG.getNode(ISD::SHL, dl, NVT, ResultHL, SHLAmnt));
+
+    // We cannot overflow past HH when multiplying 2 ints of size VTSize, so the
+    // highest bit of HH determines saturation direction in the event of
+    // saturation.
+    // The number of overflow bits we can check is VTSize - Scale + 1 (we
+    // include the sign bit). If these top bits are > 0, then we overflowed past
+    // the max value. If these top bits are < -1, then we overflowed past the
+    // min value. Otherwise, we did not overflow.
+    if (Saturating) {
+      unsigned OverflowBits = VTSize - Scale + 1;
+      assert(OverflowBits <= VTSize && OverflowBits > NVTSize &&
+             "Extent of overflow bits must start within HL");
+      SDValue HLHiMask = DAG.getConstant(
+          APInt::getHighBitsSet(NVTSize, OverflowBits - NVTSize), dl, NVT);
+      SDValue HLLoMask = DAG.getConstant(
+          APInt::getLowBitsSet(NVTSize, VTSize - OverflowBits), dl, NVT);
+
+      // HH > 0 or HH == 0 && HL > HLLoMask
+      SDValue HHPos = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETGT);
+      SDValue HHZero = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETEQ);
+      SDValue HLPos =
+          DAG.getSetCC(dl, BoolNVT, ResultHL, HLLoMask, ISD::SETUGT);
+      SatMax = DAG.getNode(ISD::OR, dl, BoolNVT, HHPos,
+                           DAG.getNode(ISD::AND, dl, BoolNVT, HHZero, HLPos));
+
+      // HH < -1 or HH == -1 && HL < HLHiMask
+      SDValue HHNeg = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETLT);
+      SDValue HHNeg1 = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETEQ);
+      SDValue HLNeg =
+          DAG.getSetCC(dl, BoolNVT, ResultHL, HLHiMask, ISD::SETULT);
+      SatMin = DAG.getNode(ISD::OR, dl, BoolNVT, HHNeg,
+                           DAG.getNode(ISD::AND, dl, BoolNVT, HHNeg1, HLNeg));
+    }
   } else if (Scale == NVTSize) {
     // If the scales are equal, Lo and Hi are ResultLH and Result HL,
     // respectively. Avoid shifting to prevent undefined behavior.
     Lo = ResultLH;
     Hi = ResultHL;
+
+    // We overflow max if HH > 0 or HH == 0 && HL sign is negative.
+    // We overflow min if HH < -1 or HH == -1 && HL sign is 0.
+    if (Saturating) {
+      SDValue HHPos = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETGT);
+      SDValue HHZero = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETEQ);
+      SDValue HLNeg = DAG.getSetCC(dl, BoolNVT, ResultHL, NVTZero, ISD::SETLT);
+      SatMax = DAG.getNode(ISD::OR, dl, BoolNVT, HHPos,
+                           DAG.getNode(ISD::AND, dl, BoolNVT, HHZero, HLNeg));
+
+      SDValue HHNeg = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETLT);
+      SDValue HHNeg1 = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETEQ);
+      SDValue HLPos = DAG.getSetCC(dl, BoolNVT, ResultHL, NVTZero, ISD::SETGT);
+      SatMin = DAG.getNode(ISD::OR, dl, BoolNVT, HHNeg,
+                           DAG.getNode(ISD::AND, dl, BoolNVT, HHNeg1, HLPos));
+    }
   } else if (Scale < VTSize) {
     // If the scale is instead less than the old VT size, but greater than or
     // equal to the expanded VT size, the first part of the result (ResultLL) is
@@ -2801,6 +2904,19 @@
     Hi = DAG.getNode(ISD::SRL, dl, NVT, ResultHL, SRLAmnt);
     Hi = DAG.getNode(ISD::OR, dl, NVT, Hi,
                      DAG.getNode(ISD::SHL, dl, NVT, ResultHH, SHLAmnt));
+
+    // This is similar to the case when we saturate if Scale < NVTSize, but we
+    // only need to check HH.
+    if (Saturating) {
+      unsigned OverflowBits = VTSize - Scale + 1;
+      SDValue HHHiMask = DAG.getConstant(
+          APInt::getHighBitsSet(NVTSize, OverflowBits), dl, NVT);
+      SDValue HHLoMask = DAG.getConstant(
+          APInt::getLowBitsSet(NVTSize, NVTSize - OverflowBits), dl, NVT);
+
+      SatMax = DAG.getSetCC(dl, BoolNVT, ResultHH, HHLoMask, ISD::SETGT);
+      SatMin = DAG.getSetCC(dl, BoolNVT, ResultHH, HHHiMask, ISD::SETLT);
+    }
   } else if (Scale == VTSize) {
     assert(
         !Signed &&
@@ -2812,6 +2928,16 @@
     llvm_unreachable("Expected the scale to be less than or equal to the width "
                      "of the operands");
   }
+
+  if (Saturating) {
+    APInt LHMax = APInt::getSignedMaxValue(NVTSize);
+    APInt LLMax = APInt::getAllOnesValue(NVTSize);
+    APInt LHMin = APInt::getSignedMinValue(NVTSize);
+    Hi = DAG.getSelect(dl, NVT, SatMax, DAG.getConstant(LHMax, dl, NVT), Hi);
+    Hi = DAG.getSelect(dl, NVT, SatMin, DAG.getConstant(LHMin, dl, NVT), Hi);
+    Lo = DAG.getSelect(dl, NVT, SatMax, DAG.getConstant(LLMax, dl, NVT), Lo);
+    Lo = DAG.getSelect(dl, NVT, SatMin, NVTZero, Lo);
+  }
 }
 
 void DAGTypeLegalizer::ExpandIntRes_SADDSUBO(SDNode *Node,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index ad2e398..f77ccd9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -438,6 +438,7 @@
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   case ISD::SMULFIX:
+  case ISD::SMULFIXSAT:
   case ISD::UMULFIX: {
     unsigned Scale = Node->getConstantOperandVal(2);
     Action = TLI.getFixedPointOperationAction(Node->getOpcode(),
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index add97ec..8570f57 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -183,6 +183,7 @@
     R = ScalarizeVecRes_OverflowOp(N, ResNo);
     break;
   case ISD::SMULFIX:
+  case ISD::SMULFIXSAT:
   case ISD::UMULFIX:
     R = ScalarizeVecRes_MULFIX(N);
     break;
@@ -971,6 +972,7 @@
     SplitVecRes_OverflowOp(N, ResNo, Lo, Hi);
     break;
   case ISD::SMULFIX:
+  case ISD::SMULFIXSAT:
   case ISD::UMULFIX:
     SplitVecRes_MULFIX(N, Lo, Hi);
     break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 31410f2..5ac9d79 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6298,6 +6298,14 @@
                              Op1.getValueType(), Op1, Op2, Op3));
     return;
   }
+  case Intrinsic::smul_fix_sat: {
+    SDValue Op1 = getValue(I.getArgOperand(0));
+    SDValue Op2 = getValue(I.getArgOperand(1));
+    SDValue Op3 = getValue(I.getArgOperand(2));
+    setValue(&I, DAG.getNode(ISD::SMULFIXSAT, sdl, Op1.getValueType(), Op1, Op2,
+                             Op3));
+    return;
+  }
   case Intrinsic::stacksave: {
     SDValue Op = getRoot();
     Res = DAG.getNode(
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index cbef6cc..2841633 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -301,7 +301,9 @@
   case ISD::UADDSAT:                    return "uaddsat";
   case ISD::SSUBSAT:                    return "ssubsat";
   case ISD::USUBSAT:                    return "usubsat";
+
   case ISD::SMULFIX:                    return "smulfix";
+  case ISD::SMULFIXSAT:                 return "smulfixsat";
   case ISD::UMULFIX:                    return "umulfix";
 
   // Conversion operators.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index f07180a..ac45f4e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5695,25 +5695,42 @@
 SDValue
 TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
   assert((Node->getOpcode() == ISD::SMULFIX ||
-          Node->getOpcode() == ISD::UMULFIX) &&
-         "Expected opcode to be SMULFIX or UMULFIX.");
+          Node->getOpcode() == ISD::UMULFIX ||
+          Node->getOpcode() == ISD::SMULFIXSAT) &&
+         "Expected a fixed point multiplication opcode");
 
   SDLoc dl(Node);
   SDValue LHS = Node->getOperand(0);
   SDValue RHS = Node->getOperand(1);
   EVT VT = LHS.getValueType();
   unsigned Scale = Node->getConstantOperandVal(2);
+  bool Saturating = Node->getOpcode() == ISD::SMULFIXSAT;
+  EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  unsigned VTSize = VT.getScalarSizeInBits();
 
-  // [us]mul.fix(a, b, 0) -> mul(a, b)
   if (!Scale) {
-    if (VT.isVector() && !isOperationLegalOrCustom(ISD::MUL, VT))
-      return SDValue();
-    return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
+    // [us]mul.fix(a, b, 0) -> mul(a, b)
+    if (!Saturating && isOperationLegalOrCustom(ISD::MUL, VT)) {
+      return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
+    } else if (Saturating && isOperationLegalOrCustom(ISD::SMULO, VT)) {
+      SDValue Result =
+          DAG.getNode(ISD::SMULO, dl, DAG.getVTList(VT, BoolVT), LHS, RHS);
+      SDValue Product = Result.getValue(0);
+      SDValue Overflow = Result.getValue(1);
+      SDValue Zero = DAG.getConstant(0, dl, VT);
+
+      APInt MinVal = APInt::getSignedMinValue(VTSize);
+      APInt MaxVal = APInt::getSignedMaxValue(VTSize);
+      SDValue SatMin = DAG.getConstant(MinVal, dl, VT);
+      SDValue SatMax = DAG.getConstant(MaxVal, dl, VT);
+      SDValue ProdNeg = DAG.getSetCC(dl, BoolVT, Product, Zero, ISD::SETLT);
+      Result = DAG.getSelect(dl, VT, ProdNeg, SatMax, SatMin);
+      return DAG.getSelect(dl, VT, Overflow, Result, Product);
+    }
   }
 
-  unsigned VTSize = VT.getScalarSizeInBits();
-  bool Signed = Node->getOpcode() == ISD::SMULFIX;
-
+  bool Signed =
+      Node->getOpcode() == ISD::SMULFIX || Node->getOpcode() == ISD::SMULFIXSAT;
   assert(((Signed && Scale < VTSize) || (!Signed && Scale <= VTSize)) &&
          "Expected scale to be less than the number of bits if signed or at "
          "most the number of bits if unsigned.");
@@ -5746,8 +5763,25 @@
   // are scaled. The result is given to us in 2 halves, so we only want part of
   // both in the result.
   EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout());
-  return DAG.getNode(ISD::FSHR, dl, VT, Hi, Lo,
-                     DAG.getConstant(Scale, dl, ShiftTy));
+  SDValue Result = DAG.getNode(ISD::FSHR, dl, VT, Hi, Lo,
+                               DAG.getConstant(Scale, dl, ShiftTy));
+  if (!Saturating)
+    return Result;
+
+  unsigned OverflowBits = VTSize - Scale + 1; // +1 for the sign
+  SDValue HiMask =
+      DAG.getConstant(APInt::getHighBitsSet(VTSize, OverflowBits), dl, VT);
+  SDValue LoMask = DAG.getConstant(
+      APInt::getLowBitsSet(VTSize, VTSize - OverflowBits), dl, VT);
+  APInt MaxVal = APInt::getSignedMaxValue(VTSize);
+  APInt MinVal = APInt::getSignedMinValue(VTSize);
+
+  Result = DAG.getSelectCC(dl, Hi, LoMask,
+                           DAG.getConstant(MaxVal, dl, VT), Result,
+                           ISD::SETGT);
+  return DAG.getSelectCC(dl, Hi, HiMask,
+                         DAG.getConstant(MinVal, dl, VT), Result,
+                         ISD::SETLT);
 }
 
 void TargetLowering::expandUADDSUBO(