[SimplifyLibCalls] powf(x, sitofp(n)) -> powi(x, n)

Summary:
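Partially solves https://bugs.llvm.org/show_bug.cgi?id=42190

When the pow call carries the approximate-functions (afn) fast-math flag:
  * pow(x, sitofp(n)) -> powi(x, n) if n is at most 32 bits wide
    (narrower values are sign-extended to i32).
  * pow(x, uitofp(n)) -> powi(x, n) if n is narrower than 32 bits
    (the value is zero-extended to i32).
  * pow(x, C) -> powi(x, C) if the constant C converts exactly to a signed
    32-bit integer and was not already expanded into a multiplication chain.

The pow(x, n) -> x * x * ... expansion now requires only afn instead of the
full set of fast-math flags.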



Reviewers: spatel, nikic, efriedma

Reviewed By: efriedma

Subscribers: efriedma, nikic, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D63038

llvm-svn: 364940
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index b5f8b39..79b1b86 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1322,12 +1322,12 @@
     APFloat BaseR = APFloat(1.0);
     BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored);
     BaseR = BaseR / *BaseF;
-    bool IsInteger    = BaseF->isInteger(),
-         IsReciprocal = BaseR.isInteger();
+    bool IsInteger = BaseF->isInteger(), IsReciprocal = BaseR.isInteger();
     const APFloat *NF = IsReciprocal ? &BaseR : BaseF;
     APSInt NI(64, false);
     if ((IsInteger || IsReciprocal) &&
-        !NF->convertToInteger(NI, APFloat::rmTowardZero, &Ignored) &&
+        NF->convertToInteger(NI, APFloat::rmTowardZero, &Ignored) ==
+            APFloat::opOK &&
         NI > 1 && NI.isPowerOf2()) {
       double N = NI.logBase2() * (IsReciprocal ? -1.0 : 1.0);
       Value *FMul = B.CreateFMul(Expo, ConstantFP::get(Ty, N), "mul");
@@ -1410,12 +1410,22 @@
   return Sqrt;
 }
 
+static Value *createPowWithIntegerExponent(Value *Base, Value *Expo, Module *M,
+                                           IRBuilder<> &B) {
+  Value *Args[] = {Base, Expo};
+  Function *F = Intrinsic::getDeclaration(M, Intrinsic::powi, Base->getType());
+  return B.CreateCall(F, Args);
+}
+
 Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) {
-  Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
+  Value *Base = Pow->getArgOperand(0);
+  Value *Expo = Pow->getArgOperand(1);
   Function *Callee = Pow->getCalledFunction();
   StringRef Name = Callee->getName();
   Type *Ty = Pow->getType();
+  Module *M = Pow->getModule();
   Value *Shrunk = nullptr;
+  bool AllowApprox = Pow->hasApproxFunc();
   bool Ignored;
 
   // Bail out if simplifying libcalls to pow() is disabled.
@@ -1428,8 +1438,8 @@
 
   // Shrink pow() to powf() if the arguments are single precision,
   // unless the result is expected to be double precision.
-  if (UnsafeFPShrink &&
-      Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name))
+  if (UnsafeFPShrink && Name == TLI->getName(LibFunc_pow) &&
+      hasFloatVersion(Name))
     Shrunk = optimizeBinaryDoubleFP(Pow, B, true);
 
   // Evaluate special cases related to the base.
@@ -1438,6 +1448,21 @@
   if (match(Base, m_FPOne()))
     return Base;
 
+  // powf(x, sitofp(e)) -> powi(x, e)
+  // powf(x, uitofp(e)) -> powi(x, e)
+  if (AllowApprox && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo))) {
+    Value *IntExpo = cast<Instruction>(Expo)->getOperand(0);
+    Value *NewExpo = nullptr;
+    unsigned BitWidth = IntExpo->getType()->getPrimitiveSizeInBits();
+    if (isa<SIToFPInst>(Expo) && BitWidth == 32)
+      NewExpo = IntExpo;
+    else if (BitWidth < 32)
+      NewExpo = isa<SIToFPInst>(Expo) ? B.CreateSExt(IntExpo, B.getInt32Ty())
+                                      : B.CreateZExt(IntExpo, B.getInt32Ty());
+    if (NewExpo)
+      return createPowWithIntegerExponent(Base, NewExpo, M, B);
+  }
+
   if (Value *Exp = replacePowWithExp(Pow, B))
     return Exp;
 
@@ -1449,7 +1474,7 @@
 
   // pow(x, 0.0) -> 1.0
   if (match(Expo, m_SpecificFP(0.0)))
-      return ConstantFP::get(Ty, 1.0);
+    return ConstantFP::get(Ty, 1.0);
 
   // pow(x, 1.0) -> x
   if (match(Expo, m_FPOne()))
@@ -1462,9 +1487,12 @@
   if (Value *Sqrt = replacePowWithSqrt(Pow, B))
     return Sqrt;
 
+  if (!AllowApprox)
+    return Shrunk;
+
   // pow(x, n) -> x * x * x * ...
   const APFloat *ExpoF;
-  if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) {
+  if (match(Expo, m_APFloat(ExpoF))) {
     // We limit to a max of 7 multiplications, thus the maximum exponent is 32.
     // If the exponent is an integer+0.5 we generate a call to sqrt and an
     // additional fmul.
@@ -1488,9 +1516,8 @@
         if (!Expo2.isInteger())
           return nullptr;
 
-        Sqrt =
-            getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(),
-                        Pow->doesNotAccessMemory(), Pow->getModule(), B, TLI);
+        Sqrt = getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(),
+                           Pow->doesNotAccessMemory(), M, B, TLI);
       }
 
       // We will memoize intermediate products of the Addition Chain.
@@ -1513,6 +1540,14 @@
 
       return FMul;
     }
+
+    APSInt IntExpo(32, /*isUnsigned=*/false);
+    // powf(x, C) -> powi(x, C) iff C is a constant signed integer value
+    if (ExpoF->convertToInteger(IntExpo, APFloat::rmTowardZero, &Ignored) ==
+        APFloat::opOK) {
+      return createPowWithIntegerExponent(
+          Base, ConstantInt::get(B.getInt32Ty(), IntExpo), M, B);
+    }
   }
 
   return Shrunk;
@@ -3101,4 +3136,4 @@
 
 FortifiedLibCallSimplifier::FortifiedLibCallSimplifier(
     const TargetLibraryInfo *TLI, bool OnlyLowerUnknownSize)
-    : TLI(TLI), OnlyLowerUnknownSize(OnlyLowerUnknownSize) {}
+    : TLI(TLI), OnlyLowerUnknownSize(OnlyLowerUnknownSize) {}
\ No newline at end of file