[SimplifyLibCalls] Inline calls to cabs when it's safe to do so

When unsafe algebra is allowed, calls to cabs(r) can be replaced by:

  sqrt(creal(r)*creal(r) + cimag(r)*cimag(r))

Patch by Paul Walker, thanks!

Differential Revision: https://reviews.llvm.org/D40069

llvm-svn: 320901
diff --git a/llvm/lib/Analysis/TargetLibraryInfo.cpp b/llvm/lib/Analysis/TargetLibraryInfo.cpp
index f8facf2..609b996 100644
--- a/llvm/lib/Analysis/TargetLibraryInfo.cpp
+++ b/llvm/lib/Analysis/TargetLibraryInfo.cpp
@@ -182,6 +182,9 @@
     TLI.setUnavailable(LibFunc_atanh);
     TLI.setUnavailable(LibFunc_atanhf);
     TLI.setUnavailable(LibFunc_atanhl);
+    TLI.setUnavailable(LibFunc_cabs);
+    TLI.setUnavailable(LibFunc_cabsf);
+    TLI.setUnavailable(LibFunc_cabsl);
     TLI.setUnavailable(LibFunc_cbrt);
     TLI.setUnavailable(LibFunc_cbrtf);
     TLI.setUnavailable(LibFunc_cbrtl);
@@ -1267,6 +1270,25 @@
     return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() &&
             FTy.getReturnType()->isIntegerTy());
 
+  case LibFunc_cabs:
+  case LibFunc_cabsf:
+  case LibFunc_cabsl: {
+    Type* RetTy = FTy.getReturnType();
+    if (!RetTy->isFloatingPointTy())
+      return false;
+
+    // NOTE: These prototypes are target specific and currently support
+    // "complex" passed as an array or discrete real & imaginary parameters.
+    // Add other calling conventions to enable libcall optimizations.
+    if (NumParams == 1)
+      return (FTy.getParamType(0)->isArrayTy() &&
+              FTy.getParamType(0)->getArrayNumElements() == 2 &&
+              FTy.getParamType(0)->getArrayElementType() == RetTy);
+    else if (NumParams == 2)
+      return (FTy.getParamType(0) == RetTy && FTy.getParamType(1) == RetTy);
+    else
+      return false;
+  }
   case LibFunc::NumLibFuncs:
     break;
   }
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 60f6c60..03a1d55 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1033,6 +1033,35 @@
   return B.CreateFPExt(V, B.getDoubleTy());
 }
 
+// cabs(z) -> sqrt(creal(z)^2 + cimag(z)^2); needs all fast-math flags set.
+Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilder<> &B) {
+  if (!CI->isFast())
+    return nullptr;
+
+  // Propagate the call's fast-math flags to the instructions built below.
+  IRBuilder<>::FastMathFlagGuard FMFGuard(B);
+  B.setFastMathFlags(CI->getFastMathFlags());
+
+  // The operand is one array value or separate real and imaginary values.
+  Value *Re = nullptr, *Im = nullptr;
+  if (CI->getNumArgOperands() != 1) {
+    assert(CI->getNumArgOperands() == 2 && "Unexpected signature for cabs!");
+    Re = CI->getArgOperand(0);
+    Im = CI->getArgOperand(1);
+  } else {
+    Value *Cplx = CI->getArgOperand(0);
+    assert(Cplx->getType()->isArrayTy() && "Unexpected signature for cabs!");
+    Re = B.CreateExtractValue(Cplx, 0, "real");
+    Im = B.CreateExtractValue(Cplx, 1, "imag");
+  }
+
+  Value *ReSq = B.CreateFMul(Re, Re);
+  Value *ImSq = B.CreateFMul(Im, Im);
+  Function *SqrtFn = Intrinsic::getDeclaration(CI->getModule(), Intrinsic::sqrt,
+                                               CI->getType());
+  return B.CreateCall(SqrtFn, B.CreateFAdd(ReSq, ImSq), "cabs");
+}
+
 Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
@@ -2162,6 +2191,10 @@
   case LibFunc_fmax:
   case LibFunc_fmaxl:
     return optimizeFMinFMax(CI, Builder);
+  case LibFunc_cabs:
+  case LibFunc_cabsf:
+  case LibFunc_cabsl:
+    return optimizeCAbs(CI, Builder);
   default:
     return nullptr;
   }