Add constrained intrinsics for some libm-equivalent operations

Differential revision: https://reviews.llvm.org/D32319

llvm-svn: 303922
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 9a47a91..d0a8b34 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -899,6 +899,39 @@
   }
 }
 
+/// Return the legalize action for the strict FP pseudo-op \p Opcode at type
+/// \p VT, derived from the action \p TLI reports for the equivalent
+/// non-strict opcode.  The result is always either Legal or Expand.
+static TargetLowering::LegalizeAction
+getStrictFPOpcodeAction(const TargetLowering &TLI, unsigned Opcode, EVT VT) {
+  unsigned NonStrictOpc;
+  switch (Opcode) {
+    default: llvm_unreachable("Unexpected FP pseudo-opcode");
+    case ISD::STRICT_FSQRT: NonStrictOpc = ISD::FSQRT; break;
+    case ISD::STRICT_FPOW: NonStrictOpc = ISD::FPOW; break;
+    case ISD::STRICT_FPOWI: NonStrictOpc = ISD::FPOWI; break;
+    case ISD::STRICT_FSIN: NonStrictOpc = ISD::FSIN; break;
+    case ISD::STRICT_FCOS: NonStrictOpc = ISD::FCOS; break;
+    case ISD::STRICT_FEXP: NonStrictOpc = ISD::FEXP; break;
+    case ISD::STRICT_FEXP2: NonStrictOpc = ISD::FEXP2; break;
+    case ISD::STRICT_FLOG: NonStrictOpc = ISD::FLOG; break;
+    case ISD::STRICT_FLOG10: NonStrictOpc = ISD::FLOG10; break;
+    case ISD::STRICT_FLOG2: NonStrictOpc = ISD::FLOG2; break;
+    case ISD::STRICT_FRINT: NonStrictOpc = ISD::FRINT; break;
+    case ISD::STRICT_FNEARBYINT: NonStrictOpc = ISD::FNEARBYINT; break;
+  }
+
+  // Custom and Promote are not handled for strict FP pseudo-ops yet, so
+  // anything that is not Legal is expanded.  STRICT_FPOWI is expanded
+  // unconditionally: ISD::FPOWI returns 'Legal' even though it should be
+  // expanded.
+  const bool IsLegal =
+      TLI.getOperationAction(NonStrictOpc, VT) == TargetLowering::Legal;
+  if (IsLegal && Opcode != ISD::STRICT_FPOWI)
+    return TargetLowering::Legal;
+  return TargetLowering::Expand;
+}
+
 /// Return a legal replacement for the given operation, with all legal operands.
 void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
   DEBUG(dbgs() << "\nLegalizing: "; Node->dump(&DAG));
@@ -1043,6 +1076,25 @@
       return;
     }
     break;
+  case ISD::STRICT_FSQRT:
+  case ISD::STRICT_FPOW:
+  case ISD::STRICT_FPOWI:
+  case ISD::STRICT_FSIN:
+  case ISD::STRICT_FCOS:
+  case ISD::STRICT_FEXP:
+  case ISD::STRICT_FEXP2:
+  case ISD::STRICT_FLOG:
+  case ISD::STRICT_FLOG10:
+  case ISD::STRICT_FLOG2:
+  case ISD::STRICT_FRINT:
+  case ISD::STRICT_FNEARBYINT:
+    // These pseudo-ops get legalized as if they were their non-strict
+    // equivalent.  For instance, if ISD::FSQRT is legal then ISD::STRICT_FSQRT
+    // is also legal, but if ISD::FSQRT requires expansion then so does
+    // ISD::STRICT_FSQRT.
+    Action = getStrictFPOpcodeAction(TLI, Node->getOpcode(),
+                                     Node->getValueType(0));
+    break;
 
   default:
     if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
@@ -2032,6 +2084,9 @@
                                               RTLIB::Libcall Call_F80,
                                               RTLIB::Libcall Call_F128,
                                               RTLIB::Libcall Call_PPCF128) {
+  if (Node->isStrictFPOpcode())
+    Node = DAG.mutateStrictFPToFP(Node);
+
   RTLIB::Libcall LC;
   switch (Node->getSimpleValueType(0).SimpleTy) {
   default: llvm_unreachable("Unexpected request for libcall!");
@@ -3907,16 +3962,19 @@
                                       RTLIB::FMAX_PPCF128));
     break;
   case ISD::FSQRT:
+  case ISD::STRICT_FSQRT:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::SQRT_F32, RTLIB::SQRT_F64,
                                       RTLIB::SQRT_F80, RTLIB::SQRT_F128,
                                       RTLIB::SQRT_PPCF128));
     break;
   case ISD::FSIN:
+  case ISD::STRICT_FSIN:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::SIN_F32, RTLIB::SIN_F64,
                                       RTLIB::SIN_F80, RTLIB::SIN_F128,
                                       RTLIB::SIN_PPCF128));
     break;
   case ISD::FCOS:
+  case ISD::STRICT_FCOS:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::COS_F32, RTLIB::COS_F64,
                                       RTLIB::COS_F80, RTLIB::COS_F128,
                                       RTLIB::COS_PPCF128));
@@ -3926,26 +3984,31 @@
     ExpandSinCosLibCall(Node, Results);
     break;
   case ISD::FLOG:
+  case ISD::STRICT_FLOG:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::LOG_F32, RTLIB::LOG_F64,
                                       RTLIB::LOG_F80, RTLIB::LOG_F128,
                                       RTLIB::LOG_PPCF128));
     break;
   case ISD::FLOG2:
+  case ISD::STRICT_FLOG2:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::LOG2_F32, RTLIB::LOG2_F64,
                                       RTLIB::LOG2_F80, RTLIB::LOG2_F128,
                                       RTLIB::LOG2_PPCF128));
     break;
   case ISD::FLOG10:
+  case ISD::STRICT_FLOG10:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::LOG10_F32, RTLIB::LOG10_F64,
                                       RTLIB::LOG10_F80, RTLIB::LOG10_F128,
                                       RTLIB::LOG10_PPCF128));
     break;
   case ISD::FEXP:
+  case ISD::STRICT_FEXP:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::EXP_F32, RTLIB::EXP_F64,
                                       RTLIB::EXP_F80, RTLIB::EXP_F128,
                                       RTLIB::EXP_PPCF128));
     break;
   case ISD::FEXP2:
+  case ISD::STRICT_FEXP2:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::EXP2_F32, RTLIB::EXP2_F64,
                                       RTLIB::EXP2_F80, RTLIB::EXP2_F128,
                                       RTLIB::EXP2_PPCF128));
@@ -3966,11 +4029,13 @@
                                       RTLIB::CEIL_PPCF128));
     break;
   case ISD::FRINT:
+  case ISD::STRICT_FRINT:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::RINT_F32, RTLIB::RINT_F64,
                                       RTLIB::RINT_F80, RTLIB::RINT_F128,
                                       RTLIB::RINT_PPCF128));
     break;
   case ISD::FNEARBYINT:
+  case ISD::STRICT_FNEARBYINT:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::NEARBYINT_F32,
                                       RTLIB::NEARBYINT_F64,
                                       RTLIB::NEARBYINT_F80,
@@ -3985,11 +4050,13 @@
                                       RTLIB::ROUND_PPCF128));
     break;
   case ISD::FPOWI:
+  case ISD::STRICT_FPOWI:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::POWI_F32, RTLIB::POWI_F64,
                                       RTLIB::POWI_F80, RTLIB::POWI_F128,
                                       RTLIB::POWI_PPCF128));
     break;
   case ISD::FPOW:
+  case ISD::STRICT_FPOW:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::POW_F32, RTLIB::POW_F64,
                                       RTLIB::POW_F80, RTLIB::POW_F128,
                                       RTLIB::POW_PPCF128));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 16c1f78..b26f1ce 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6542,6 +6542,63 @@
   return N;
 }
 
+/// Rewrite the strict FP pseudo-op \p Node as its non-strict equivalent.
+/// The node's chain operand and chain result are dropped, so users of its
+/// output chain are re-pointed at its input chain.  Returns the new node.
+SDNode* SelectionDAG::mutateStrictFPToFP(SDNode *Node) {
+  unsigned NewOpc;
+  bool IsUnary = false;
+  switch (Node->getOpcode()) {
+  default:
+    llvm_unreachable("mutateStrictFPToFP called with unexpected opcode!");
+  case ISD::STRICT_FADD: NewOpc = ISD::FADD; break;
+  case ISD::STRICT_FSUB: NewOpc = ISD::FSUB; break;
+  case ISD::STRICT_FMUL: NewOpc = ISD::FMUL; break;
+  case ISD::STRICT_FDIV: NewOpc = ISD::FDIV; break;
+  case ISD::STRICT_FREM: NewOpc = ISD::FREM; break;
+  case ISD::STRICT_FSQRT: NewOpc = ISD::FSQRT; IsUnary = true; break;
+  case ISD::STRICT_FPOW: NewOpc = ISD::FPOW; break;
+  case ISD::STRICT_FPOWI: NewOpc = ISD::FPOWI; break;
+  case ISD::STRICT_FSIN: NewOpc = ISD::FSIN; IsUnary = true; break;
+  case ISD::STRICT_FCOS: NewOpc = ISD::FCOS; IsUnary = true; break;
+  case ISD::STRICT_FEXP: NewOpc = ISD::FEXP; IsUnary = true; break;
+  case ISD::STRICT_FEXP2: NewOpc = ISD::FEXP2; IsUnary = true; break;
+  case ISD::STRICT_FLOG: NewOpc = ISD::FLOG; IsUnary = true; break;
+  case ISD::STRICT_FLOG10: NewOpc = ISD::FLOG10; IsUnary = true; break;
+  case ISD::STRICT_FLOG2: NewOpc = ISD::FLOG2; IsUnary = true; break;
+  case ISD::STRICT_FRINT: NewOpc = ISD::FRINT; IsUnary = true; break;
+  case ISD::STRICT_FNEARBYINT: NewOpc = ISD::FNEARBYINT; IsUnary = true; break;
+  }
+
+  // The node is leaving the chain: anything that used its output chain
+  // (result 1) must now use its input chain (operand 0) instead.
+  ReplaceAllUsesOfValueWith(SDValue(Node, 1), Node->getOperand(0));
+
+  // Operand 0 is the chain, so the remaining operands start at index 1.
+  // The result type of the morphed node is taken from operand 1.
+  SDValue FirstArg = Node->getOperand(1);
+  SDVTList VTs = getVTList(FirstArg.getValueType());
+  SDNode *Res = IsUnary
+                    ? MorphNodeTo(Node, NewOpc, VTs, { FirstArg })
+                    : MorphNodeTo(Node, NewOpc, VTs,
+                                  { FirstArg, Node->getOperand(2) });
+
+  // MorphNodeTo either updates the node in place or hands back an existing
+  // node that already has the requested operands.
+  if (Res != Node) {
+    // An equivalent node already existed: redirect all users to it and
+    // remove the now-dead original.
+    ReplaceAllUsesWith(Node, Res);
+    RemoveDeadNode(Node);
+    return Res;
+  }
+
+  // Updated in place: reset the node ID so that, to the isel, this looks
+  // just like a newly allocated machine node.
+  Res->setNodeId(-1);
+  return Res;
+}
+
 
 /// getMachineNode - These are used for target selectors to create a new node
 /// with specified return type(s), MachineInstr opcode, and operands.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1ed14de..b895da2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5245,7 +5245,19 @@
   case Intrinsic::experimental_constrained_fmul:
   case Intrinsic::experimental_constrained_fdiv:
   case Intrinsic::experimental_constrained_frem:
-    visitConstrainedFPIntrinsic(I, Intrinsic);
+  case Intrinsic::experimental_constrained_sqrt:
+  case Intrinsic::experimental_constrained_pow:
+  case Intrinsic::experimental_constrained_powi:
+  case Intrinsic::experimental_constrained_sin:
+  case Intrinsic::experimental_constrained_cos:
+  case Intrinsic::experimental_constrained_exp:
+  case Intrinsic::experimental_constrained_exp2:
+  case Intrinsic::experimental_constrained_log:
+  case Intrinsic::experimental_constrained_log10:
+  case Intrinsic::experimental_constrained_log2:
+  case Intrinsic::experimental_constrained_rint:
+  case Intrinsic::experimental_constrained_nearbyint:
+    visitConstrainedFPIntrinsic(cast<ConstrainedFPIntrinsic>(I));
     return nullptr;
   case Intrinsic::fmuladd: {
     EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType());
@@ -5743,11 +5755,11 @@
   }
 }
 
-void SelectionDAGBuilder::visitConstrainedFPIntrinsic(const CallInst &I,
-                                                      unsigned Intrinsic) {
+void SelectionDAGBuilder::visitConstrainedFPIntrinsic(
+    const ConstrainedFPIntrinsic &FPI) {
   SDLoc sdl = getCurSDLoc();
   unsigned Opcode;
-  switch (Intrinsic) {
+  switch (FPI.getIntrinsicID()) {
   default: llvm_unreachable("Impossible intrinsic");  // Can't reach here.
   case Intrinsic::experimental_constrained_fadd:
     Opcode = ISD::STRICT_FADD;
@@ -5764,23 +5776,64 @@
   case Intrinsic::experimental_constrained_frem:
     Opcode = ISD::STRICT_FREM;
     break;
+  case Intrinsic::experimental_constrained_sqrt:
+    Opcode = ISD::STRICT_FSQRT;
+    break;
+  case Intrinsic::experimental_constrained_pow:
+    Opcode = ISD::STRICT_FPOW;
+    break;
+  case Intrinsic::experimental_constrained_powi:
+    Opcode = ISD::STRICT_FPOWI;
+    break;
+  case Intrinsic::experimental_constrained_sin:
+    Opcode = ISD::STRICT_FSIN;
+    break;
+  case Intrinsic::experimental_constrained_cos:
+    Opcode = ISD::STRICT_FCOS;
+    break;
+  case Intrinsic::experimental_constrained_exp:
+    Opcode = ISD::STRICT_FEXP;
+    break;
+  case Intrinsic::experimental_constrained_exp2:
+    Opcode = ISD::STRICT_FEXP2;
+    break;
+  case Intrinsic::experimental_constrained_log:
+    Opcode = ISD::STRICT_FLOG;
+    break;
+  case Intrinsic::experimental_constrained_log10:
+    Opcode = ISD::STRICT_FLOG10;
+    break;
+  case Intrinsic::experimental_constrained_log2:
+    Opcode = ISD::STRICT_FLOG2;
+    break;
+  case Intrinsic::experimental_constrained_rint:
+    Opcode = ISD::STRICT_FRINT;
+    break;
+  case Intrinsic::experimental_constrained_nearbyint:
+    Opcode = ISD::STRICT_FNEARBYINT;
+    break;
   }
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   SDValue Chain = getRoot();
-  SDValue Ops[3] = { Chain, getValue(I.getArgOperand(0)),
-                     getValue(I.getArgOperand(1)) };
   SmallVector<EVT, 4> ValueVTs;
-  ComputeValueVTs(TLI, DAG.getDataLayout(), I.getType(), ValueVTs);
+  ComputeValueVTs(TLI, DAG.getDataLayout(), FPI.getType(), ValueVTs);
   ValueVTs.push_back(MVT::Other); // Out chain
 
   SDVTList VTs = DAG.getVTList(ValueVTs);
-  SDValue Result = DAG.getNode(Opcode, sdl, VTs, Ops);
+  SDValue Result;
+  if (FPI.isUnaryOp())
+    Result = DAG.getNode(Opcode, sdl, VTs, 
+                         { Chain, getValue(FPI.getArgOperand(0)) });
+  else
+    Result = DAG.getNode(Opcode, sdl, VTs, 
+                         { Chain, getValue(FPI.getArgOperand(0)),
+                           getValue(FPI.getArgOperand(1))  });
 
   assert(Result.getNode()->getNumValues() == 2);
   SDValue OutChain = Result.getValue(1);
   DAG.setRoot(OutChain);
   SDValue FPResult = Result.getValue(0);
-  setValue(&I, FPResult);
+  setValue(&FPI, FPResult);
 }
 
 std::pair<SDValue, SDValue>
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index bdaee85..77e131f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -895,7 +895,7 @@
   void visitInlineAsm(ImmutableCallSite CS);
   const char *visitIntrinsicCall(const CallInst &I, unsigned Intrinsic);
   void visitTargetIntrinsic(const CallInst &I, unsigned Intrinsic);
-  void visitConstrainedFPIntrinsic(const CallInst &I, unsigned Intrinsic);
+  void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI);
 
   void visitVAStart(const CallInst &I);
   void visitVAArg(const VAArgInst &I);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 5e0fecc..687b882 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -905,50 +905,6 @@
 
 } // end anonymous namespace
 
-static bool isStrictFPOp(SDNode *Node, unsigned &NewOpc) {
-  unsigned OrigOpc = Node->getOpcode();
-  switch (OrigOpc) {
-    case ISD::STRICT_FADD: NewOpc = ISD::FADD; return true;
-    case ISD::STRICT_FSUB: NewOpc = ISD::FSUB; return true;
-    case ISD::STRICT_FMUL: NewOpc = ISD::FMUL; return true;
-    case ISD::STRICT_FDIV: NewOpc = ISD::FDIV; return true;
-    case ISD::STRICT_FREM: NewOpc = ISD::FREM; return true;
-    default: return false;
-  }
-}
-
-SDNode* SelectionDAGISel::MutateStrictFPToFP(SDNode *Node, unsigned NewOpc) {
-  assert(((Node->getOpcode() == ISD::STRICT_FADD && NewOpc == ISD::FADD) ||
-          (Node->getOpcode() == ISD::STRICT_FSUB && NewOpc == ISD::FSUB) ||
-          (Node->getOpcode() == ISD::STRICT_FMUL && NewOpc == ISD::FMUL) ||
-          (Node->getOpcode() == ISD::STRICT_FDIV && NewOpc == ISD::FDIV) ||
-          (Node->getOpcode() == ISD::STRICT_FREM && NewOpc == ISD::FREM)) &&
-          "Unexpected StrictFP opcode!");
-
-  // We're taking this node out of the chain, so we need to re-link things.
-  SDValue InputChain = Node->getOperand(0);
-  SDValue OutputChain = SDValue(Node, 1);
-  CurDAG->ReplaceAllUsesOfValueWith(OutputChain, InputChain);
-
-  SDVTList VTs = CurDAG->getVTList(Node->getOperand(1).getValueType());
-  SDValue Ops[2] = { Node->getOperand(1), Node->getOperand(2) };
-  SDNode *Res = CurDAG->MorphNodeTo(Node, NewOpc, VTs, Ops);
-  
-  // MorphNodeTo can operate in two ways: if an existing node with the
-  // specified operands exists, it can just return it.  Otherwise, it
-  // updates the node in place to have the requested operands.
-  if (Res == Node) {
-    // If we updated the node in place, reset the node ID.  To the isel,
-    // this should be just like a newly allocated machine node.
-    Res->setNodeId(-1);
-  } else {
-    CurDAG->ReplaceAllUsesWith(Node, Res);
-    CurDAG->RemoveDeadNode(Node);
-  }
-
-  return Res; 
-}
-
 void SelectionDAGISel::DoInstructionSelection() {
   DEBUG(dbgs() << "===== Instruction selection begins: BB#"
         << FuncInfo->MBB->getNumber()
@@ -992,15 +948,12 @@
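-      // If the current node is a strict FP pseudo-op, the isStrictFPOp()
-      // function will provide the corresponding normal FP opcode to which the
-      // node should be mutated.
+      // If the current node is a strict FP pseudo-op, it is mutated here to
+      // the corresponding normal FP opcode via SelectionDAG's
+      // mutateStrictFPToFP().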
-      unsigned NormalFPOpc = ISD::UNDEF;
-      bool IsStrictFPOp = isStrictFPOp(Node, NormalFPOpc);
-      if (IsStrictFPOp)
-        Node = MutateStrictFPToFP(Node, NormalFPOpc);
+      //
+      // FIXME: The backends need a way to handle FP constraints.
+      if (Node->isStrictFPOpcode())
+        Node = CurDAG->mutateStrictFPToFP(Node);
 
       Select(Node);
-
-      // FIXME: Add code here to attach an implicit def and use of
-      // target-specific FP environment registers.
     }
 
     CurDAG->setRoot(Dummy.getValue());