[InstCombine] snprintf optimizations

Reviewers: spatel, efriedma, majnemer, rja, bkramer

Reviewed By: rja, bkramer

Subscribers: mstorsjo, rja, llvm-commits

Differential Revision: https://reviews.llvm.org/D46285

llvm-svn: 332110
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index ce804ae..6e351d4 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1912,6 +1912,94 @@
   return nullptr;
 }
 
+/// Try to fold snprintf(dst, size, fmt[, arg]) when size and fmt are constant:
+/// handles a '%'-free fmt, "%c", and "%s" with a constant string argument.
+/// Returns the value replacing the call, or nullptr when no fold applies.
+Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, IRBuilder<> &B) {
+  // Check for a fixed format string.
+  StringRef FormatStr;
+  if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr))
+    return nullptr;
+
+  // Check for size
+  ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+  if (!Size)
+    return nullptr;
+  uint64_t N = Size->getZExtValue();
+
+  // If we just have a format string (nothing else crazy) transform it.
+  if (CI->getNumArgOperands() == 3) {
+    // Make sure there's no % in the constant array.  We could try to handle
+    // %% -> % in the future if we cared.
+    if (FormatStr.find('%') != StringRef::npos)
+      return nullptr; // we found a format specifier, bail out.
+
+    if (N == 0) // Emit no copy; the result is just the string length.
+      return ConstantInt::get(CI->getType(), FormatStr.size());
+    if (N < FormatStr.size() + 1)
+      return nullptr; // String plus terminator does not fit in N bytes.
+
+    // snprintf(str, size, fmt) -> llvm.memcpy(align 1 str, align 1 fmt,
+    // strlen(fmt)+1)
+    B.CreateMemCpy(
+        CI->getArgOperand(0), 1, CI->getArgOperand(2), 1,
+        ConstantInt::get(DL.getIntPtrType(CI->getContext()),
+                         FormatStr.size() + 1)); // Copy the null byte.
+    return ConstantInt::get(CI->getType(), FormatStr.size());
+  }
+
+  // The remaining optimizations require the format string to be "%s" or "%c"
+  // and have an extra operand.
+  if (FormatStr.size() != 2 || FormatStr[0] != '%' ||
+      CI->getNumArgOperands() != 4)
+    return nullptr;
+
+  // Decode the second character of the format string.
+  if (FormatStr[1] == 'c') {
+    if (N == 0)
+      return ConstantInt::get(CI->getType(), 1);
+    if (N == 1)
+      return nullptr; // Char plus terminator does not fit; keep the call.
+
+    // snprintf(dst, size, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0
+    if (!CI->getArgOperand(3)->getType()->isIntegerTy())
+      return nullptr;
+    Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char");
+    Value *Ptr = castToCStr(CI->getArgOperand(0), B);
+    B.CreateStore(V, Ptr);
+    Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul");
+    B.CreateStore(B.getInt8(0), Ptr);
+    return ConstantInt::get(CI->getType(), 1);
+  }
+
+  if (FormatStr[1] == 's') {
+    StringRef Str;
+    if (!getConstantStringInfo(CI->getArgOperand(3), Str))
+      return nullptr;
+
+    if (N == 0)
+      return ConstantInt::get(CI->getType(), Str.size());
+    if (N < Str.size() + 1)
+      return nullptr;
+
+    // snprintf(dest, size, "%s", str) to llvm.memcpy(dest, str, len+1, 1)
+    B.CreateMemCpy(
+        CI->getArgOperand(0), 1, CI->getArgOperand(3), 1,
+        ConstantInt::get(DL.getIntPtrType(CI->getContext()), Str.size() + 1));
+    // The snprintf result is the unincremented number of bytes in the string.
+    return ConstantInt::get(CI->getType(), Str.size());
+  }
+  return nullptr;
+}
+
+/// Entry point for simplifying a call to snprintf. The only transformation
+/// attempted is the one in optimizeSnPrintFString; returns the value that
+/// replaces the call, or nullptr if the call was not simplified.
+Value *LibCallSimplifier::optimizeSnPrintF(CallInst *CI, IRBuilder<> &B) {
+  // The helper already yields nullptr on failure, so forward its result.
+  return optimizeSnPrintFString(CI, B);
+}
+
 Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, IRBuilder<> &B) {
   optimizeErrorReporting(CI, B, 0);
 
@@ -2318,6 +2406,8 @@
       return optimizePrintF(CI, Builder);
     case LibFunc_sprintf:
       return optimizeSPrintF(CI, Builder);
+    case LibFunc_snprintf:
+      return optimizeSnPrintF(CI, Builder);
     case LibFunc_fprintf:
       return optimizeFPrintF(CI, Builder);
     case LibFunc_fwrite: