Make -Wabsolute-value C++-aware.

Warn when std::abs() is called with an unsigned argument.
Suggest std::abs as the replacement for the C absolute value functions.
Suggest including the matching C++ header (<cstdlib> or <cmath>) when no suitable std::abs overload is found.

llvm-svn: 206340
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index b824ec4..e2a69d9 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -3862,54 +3862,107 @@
 // If the replacement is valid, emit a note with replacement function.
 // Additionally, suggest including the proper header if not already included.
 static void emitReplacement(Sema &S, SourceLocation Loc, SourceRange Range,
-                            unsigned AbsKind) {
-  std::string AbsName = S.Context.BuiltinInfo.GetName(AbsKind);
-
-  // Look up absolute value function in TU scope.
-  DeclarationName DN(&S.Context.Idents.get(AbsName));
-  LookupResult R(S, DN, Loc, Sema::LookupAnyName);
-  R.suppressDiagnostics();
-  S.LookupName(R, S.TUScope);
-
-  // Skip notes if multiple results found in lookup.
-  if (!R.empty() && !R.isSingleResult())
-    return;
-
-  FunctionDecl *FD = 0;
-  bool FoundFunction = R.isSingleResult();
-  // When one result is found, see if it is the correct function.
-  if (R.isSingleResult()) {
-    FD = dyn_cast<FunctionDecl>(R.getFoundDecl());
-    if (!FD || FD->getBuiltinID() != AbsKind)
-      return;
-  }
-
-  // Look for local name conflict, prepend "::" as necessary.
-  R.clear();
-  S.LookupName(R, S.getCurScope());
-
-  if (!FoundFunction) {
-    if (!R.empty()) {
-      AbsName = "::" + AbsName;
+                            unsigned AbsKind, QualType ArgType) {
+  bool EmitHeaderHint = true;
+  const char *HeaderName = 0;
+  const char *FunctionName = 0;
+  if (S.getLangOpts().CPlusPlus && !ArgType->isAnyComplexType()) {
+    FunctionName = "std::abs";
+    if (ArgType->isIntegralOrEnumerationType()) {
+      HeaderName = "cstdlib";
+    } else if (ArgType->isRealFloatingType()) {
+      HeaderName = "cmath";
+    } else {
+      llvm_unreachable("Invalid Type");
     }
-  } else { // FoundFunction
-    if (R.isSingleResult()) {
-      if (R.getFoundDecl() != FD) {
-        AbsName = "::" + AbsName;
+
+    // Look up all declarations named abs in namespace std.
+    if (NamespaceDecl *Std = S.getStdNamespace()) {
+      LookupResult R(S, &S.PP.getIdentifierTable().get("abs"), Loc,
+                     Sema::LookupAnyName);
+      R.suppressDiagnostics();
+      S.LookupQualifiedName(R, Std);
+
+      for (const auto *I : R) {
+        const FunctionDecl *FDecl = 0;
+        if (const UsingShadowDecl *UsingD = dyn_cast<UsingShadowDecl>(I)) {
+          FDecl = dyn_cast<FunctionDecl>(UsingD->getTargetDecl());
+        } else {
+          FDecl = dyn_cast<FunctionDecl>(I);
+        }
+        if (!FDecl)
+          continue;
+
+        // Found a std::abs() candidate; check that it is a suitable overload.
+        if (FDecl->getNumParams() != 1)
+          continue;
+
+        // Check that the parameter type can handle the argument.
+        QualType ParamType = FDecl->getParamDecl(0)->getType();
+        if (getAbsoluteValueKind(ArgType) == getAbsoluteValueKind(ParamType) &&
+            S.Context.getTypeSize(ArgType) <=
+                S.Context.getTypeSize(ParamType)) {
+          // Found a function, don't need the header hint.
+          EmitHeaderHint = false;
+          break;
+        }
       }
-    } else if (!R.empty()) {
-      AbsName = "::" + AbsName;
+    }
+  } else {
+    FunctionName = S.Context.BuiltinInfo.GetName(AbsKind);
+    HeaderName = S.Context.BuiltinInfo.getHeaderName(AbsKind);
+
+    if (HeaderName) {
+      DeclarationName DN(&S.Context.Idents.get(FunctionName));
+      LookupResult R(S, DN, Loc, Sema::LookupAnyName);
+      R.suppressDiagnostics();
+      S.LookupName(R, S.getCurScope());
+
+      if (R.isSingleResult()) {
+        FunctionDecl *FD = dyn_cast<FunctionDecl>(R.getFoundDecl());
+        if (FD && FD->getBuiltinID() == AbsKind) {
+          EmitHeaderHint = false;
+        } else {
+          return;
+        }
+      } else if (!R.empty()) {
+        return;
+      }
     }
   }
 
   S.Diag(Loc, diag::note_replace_abs_function)
-      << AbsName << FixItHint::CreateReplacement(Range, AbsName);
+      << FunctionName << FixItHint::CreateReplacement(Range, FunctionName);
 
-  if (!FoundFunction) {
-    S.Diag(Loc, diag::note_please_include_header)
-        << S.Context.BuiltinInfo.getHeaderName(AbsKind)
-        << S.Context.BuiltinInfo.GetName(AbsKind);
+  if (!HeaderName)
+    return;
+
+  if (!EmitHeaderHint)
+    return;
+
+  S.Diag(Loc, diag::note_please_include_header) << HeaderName << FunctionName;
+}
+
+static bool IsFunctionStdAbs(const FunctionDecl *FDecl) {
+  if (!FDecl)
+    return false;
+
+  if (!FDecl->getIdentifier() || !FDecl->getIdentifier()->isStr("abs"))
+    return false;
+
+  const NamespaceDecl *ND = dyn_cast<NamespaceDecl>(FDecl->getDeclContext());
+
+  while (ND && ND->isInlineNamespace()) {
+    ND = dyn_cast<NamespaceDecl>(ND->getDeclContext());
   }
+
+  if (!ND || !ND->getIdentifier() || !ND->getIdentifier()->isStr("std"))
+    return false;
+
+  if (!isa<TranslationUnitDecl>(ND->getDeclContext()))
+    return false;
+
+  return true;
 }
 
 // Warn when using the wrong abs() function.
@@ -3920,7 +3973,8 @@
     return;
 
   unsigned AbsKind = getAbsoluteValueFunctionKind(FDecl);
-  if (AbsKind == 0)
+  bool IsStdAbs = IsFunctionStdAbs(FDecl);
+  if (AbsKind == 0 && !IsStdAbs)
     return;
 
   QualType ArgType = Call->getArg(0)->IgnoreParenImpCasts()->getType();
@@ -3929,13 +3983,20 @@
   // Unsigned types can not be negative.  Suggest to drop the absolute value
   // function.
   if (ArgType->isUnsignedIntegerType()) {
+    const char *FunctionName =
+        IsStdAbs ? "std::abs" : Context.BuiltinInfo.GetName(AbsKind);
     Diag(Call->getExprLoc(), diag::warn_unsigned_abs) << ArgType << ParamType;
     Diag(Call->getExprLoc(), diag::note_remove_abs)
-        << FDecl
+        << FunctionName
         << FixItHint::CreateRemoval(Call->getCallee()->getSourceRange());
     return;
   }
 
+  // std::abs has overloads which prevent most of the absolute value problems
+  // from occurring.
+  if (IsStdAbs)
+    return;
+
   AbsoluteValueKind ArgValueKind = getAbsoluteValueKind(ArgType);
   AbsoluteValueKind ParamValueKind = getAbsoluteValueKind(ParamType);
 
@@ -3953,7 +4014,7 @@
       return;
 
     emitReplacement(*this, Call->getExprLoc(),
-                    Call->getCallee()->getSourceRange(), NewAbsKind);
+                    Call->getCallee()->getSourceRange(), NewAbsKind, ArgType);
     return;
   }
 
@@ -3969,7 +4030,7 @@
       << FDecl << ParamValueKind << ArgValueKind;
 
   emitReplacement(*this, Call->getExprLoc(),
-                  Call->getCallee()->getSourceRange(), NewAbsKind);
+                  Call->getCallee()->getSourceRange(), NewAbsKind, ArgType);
   return;
 }