[Sema] Add warning when comparing nonnull and null

Currently, we emit warnings in some cases where nonnull function
parameters are compared against null. This patch extends this support
to warn when comparing the result of `returns_nonnull` functions
against null.

More specifically, we will now warn about cases like:

int *foo() __attribute__((returns_nonnull));
int main() {
  if (foo() == NULL) {} // warning: will always evaluate to false
}

Differential Revision: http://reviews.llvm.org/D15324

llvm-svn: 255058
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 9d8a3f5..d5c80c9 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1180,8 +1180,7 @@
 /// Checks if a the given expression evaluates to null.
 ///
 /// \brief Returns true if the value evaluates to null.
-static bool CheckNonNullExpr(Sema &S,
-                             const Expr *Expr) {
+static bool CheckNonNullExpr(Sema &S, const Expr *Expr) {
   // If the expression has non-null type, it doesn't evaluate to null.
   if (auto nullability
         = Expr->IgnoreImplicit()->getType()->getNullability(S.Context)) {
@@ -7666,6 +7665,26 @@
     }
   }
 
+  auto ComplainAboutNonnullParamOrCall = [&](bool IsParam) {
+    std::string Str;
+    llvm::raw_string_ostream S(Str);
+    E->printPretty(S, nullptr, getPrintingPolicy());
+    unsigned DiagID = IsCompare ? diag::warn_nonnull_expr_compare
+                                : diag::warn_cast_nonnull_to_bool;
+    Diag(E->getExprLoc(), DiagID) << IsParam << S.str()
+      << E->getSourceRange() << Range << IsEqual;
+  };
+
+  // If we have a CallExpr that is tagged with returns_nonnull, we can complain.
+  if (auto *Call = dyn_cast<CallExpr>(E->IgnoreParenImpCasts())) {
+    if (auto *Callee = Call->getDirectCallee()) {
+      if (Callee->hasAttr<ReturnsNonNullAttr>()) {
+        ComplainAboutNonnullParamOrCall(false);
+        return;
+      }
+    }
+  }
+
   // Expect to find a single Decl.  Skip anything more complicated.
   ValueDecl *D = nullptr;
   if (DeclRefExpr *R = dyn_cast<DeclRefExpr>(E)) {
@@ -7677,40 +7696,38 @@
   // Weak Decls can be null.
   if (!D || D->isWeak())
     return;
-  
+
   // Check for parameter decl with nonnull attribute
-  if (const ParmVarDecl* PV = dyn_cast<ParmVarDecl>(D)) {
-    if (getCurFunction() && !getCurFunction()->ModifiedNonNullParams.count(PV))
-      if (const FunctionDecl* FD = dyn_cast<FunctionDecl>(PV->getDeclContext())) {
-        unsigned NumArgs = FD->getNumParams();
-        llvm::SmallBitVector AttrNonNull(NumArgs);
+  if (const auto* PV = dyn_cast<ParmVarDecl>(D)) {
+    if (getCurFunction() &&
+        !getCurFunction()->ModifiedNonNullParams.count(PV)) {
+      if (PV->hasAttr<NonNullAttr>()) {
+        ComplainAboutNonnullParamOrCall(true);
+        return;
+      }
+
+      if (const auto *FD = dyn_cast<FunctionDecl>(PV->getDeclContext())) {
+        auto ParamIter = std::find(FD->param_begin(), FD->param_end(), PV);
+        assert(ParamIter != FD->param_end());
+        unsigned ParamNo = std::distance(FD->param_begin(), ParamIter);
+
         for (const auto *NonNull : FD->specific_attrs<NonNullAttr>()) {
           if (!NonNull->args_size()) {
-            AttrNonNull.set(0, NumArgs);
-            break;
+              ComplainAboutNonnullParamOrCall(true);
+              return;
           }
-          for (unsigned Val : NonNull->args()) {
-            if (Val >= NumArgs)
-              continue;
-            AttrNonNull.set(Val);
-          }
-        }
-        if (!AttrNonNull.empty())
-          for (unsigned i = 0; i < NumArgs; ++i)
-            if (FD->getParamDecl(i) == PV &&
-                (AttrNonNull[i] || PV->hasAttr<NonNullAttr>())) {
-              std::string Str;
-              llvm::raw_string_ostream S(Str);
-              E->printPretty(S, nullptr, getPrintingPolicy());
-              unsigned DiagID = IsCompare ? diag::warn_nonnull_parameter_compare
-                                          : diag::warn_cast_nonnull_to_bool;
-              Diag(E->getExprLoc(), DiagID) << S.str() << E->getSourceRange()
-                << Range << IsEqual;
+
+          for (unsigned ArgNo : NonNull->args()) {
+            if (ArgNo == ParamNo) {
+              ComplainAboutNonnullParamOrCall(true);
               return;
             }
+          }
+        }
       }
     }
-  
+  }
+
   QualType T = D->getType();
   const bool IsArray = T->isArrayType();
   const bool IsFunction = T->isFunctionType();