[PartiallyInlineLibCalls][x86] add TTI hook to allow sqrt inlining to depend on arg rather than result

This should fix PR31455:
https://bugs.llvm.org/show_bug.cgi?id=31455

Differential Revision: https://reviews.llvm.org/D28314

llvm-svn: 319094
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 53bedfe..7feb40d 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -281,6 +281,10 @@
   return TTIImpl->haveFastSqrt(Ty);
 }
 
+bool TargetTransformInfo::isFCmpOrdCheaperThanFCmpZero(Type *Ty) const {
+  return TTIImpl->isFCmpOrdCheaperThanFCmpZero(Ty); // Defer to TTIImpl.
+}
+
 int TargetTransformInfo::getFPOpCost(Type *Ty) const {
   int Cost = TTIImpl->getFPOpCost(Ty);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index d06d6a5..9b07491 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -2537,6 +2537,10 @@
   return TLI->isOperationLegal(IsSigned ? ISD::SDIVREM : ISD::UDIVREM, VT);
 }
 
+bool X86TTIImpl::isFCmpOrdCheaperThanFCmpZero(Type *Ty) {
+  return false; // Ty unused; false selects the fcmp-with-zero sqrt check.
+}
+
 bool X86TTIImpl::areInlineCompatible(const Function *Caller,
                                      const Function *Callee) const {
   const TargetMachine &TM = getTLI()->getTargetMachine();
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 81b804e..6f01a6f 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -125,6 +125,7 @@
   bool isLegalMaskedGather(Type *DataType);
   bool isLegalMaskedScatter(Type *DataType);
   bool hasDivRemOp(Type *DataType, bool IsSigned);
+  bool isFCmpOrdCheaperThanFCmpZero(Type *Ty);
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const;
   const TTI::MemCmpExpansionOptions *enableMemCmpExpansion(
diff --git a/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
index a044fe3..1748815 100644
--- a/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
+++ b/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
@@ -26,7 +26,8 @@
 
 
 static bool optimizeSQRT(CallInst *Call, Function *CalledFunc,
-                         BasicBlock &CurrBB, Function::iterator &BB) {
+                         BasicBlock &CurrBB, Function::iterator &BB,
+                         const TargetTransformInfo *TTI) {
   // There is no need to change the IR, since backend will emit sqrt
   // instruction if the call has already been marked read-only.
   if (Call->onlyReadsMemory())
@@ -39,7 +40,7 @@
   //
   // (after)
   // v0 = sqrt_noreadmem(src) # native sqrt instruction.
-  // if (v0 is a NaN)
+  // if (v0 is a NaN), or per the TTI hook: if !(src >= 0.0)
   //   v1 = sqrt(src)         # library call.
   // dst = phi(v0, v1)
   //
@@ -48,7 +49,8 @@
   // Create phi and replace all uses.
   BasicBlock *JoinBB = llvm::SplitBlock(&CurrBB, Call->getNextNode());
   IRBuilder<> Builder(JoinBB, JoinBB->begin());
-  PHINode *Phi = Builder.CreatePHI(Call->getType(), 2);
+  Type *Ty = Call->getType();
+  PHINode *Phi = Builder.CreatePHI(Ty, 2);
   Call->replaceAllUsesWith(Phi);
 
   // Create basic block LibCallBB and insert a call to library function sqrt.
@@ -65,7 +67,10 @@
   Call->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
   CurrBB.getTerminator()->eraseFromParent();
   Builder.SetInsertPoint(&CurrBB);
-  Value *FCmp = Builder.CreateFCmpOEQ(Call, Call);
+  Value *FCmp = TTI->isFCmpOrdCheaperThanFCmpZero(Ty)
+                    ? Builder.CreateFCmpORD(Call, Call)
+                    : Builder.CreateFCmpOGE(Call->getOperand(0),
+                                            ConstantFP::get(Ty, 0.0));
   Builder.CreateCondBr(FCmp, JoinBB, LibCallBB);
 
   // Add phi operands.
@@ -106,7 +111,7 @@
       case LibFunc_sqrtf:
       case LibFunc_sqrt:
         if (TTI->haveFastSqrt(Call->getType()) &&
-            optimizeSQRT(Call, CalledFunc, *CurrBB, BB))
+            optimizeSQRT(Call, CalledFunc, *CurrBB, BB, TTI))
           break;
         continue;
       default: