Diagnose unsafe uses of nil and __nonnull pointers.

This generalizes the checking of null arguments to also work with
values of pointer-to-function, reference-to-function, and block
pointer types, using the nullability information within the underlying
function prototype to extend non-null checking. It also diagnoses
returns of 'nil' within a function with a __nonnull return type.

Note that we don't warn about nil returns from Objective-C methods,
because it's common for Objective-C methods to mimic the nil-swallowing
behavior of the receiver: they check ostensibly non-null parameters
and return nil from otherwise non-null methods when such a parameter
is nil.

It also diagnoses conversions from nullable to nonnull pointers, under
a separate warning flag because this warning can be noisy.

llvm-svn: 240153
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 2014052..f76727c 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1115,6 +1115,13 @@
 /// \brief Returns true if the value evaluates to null.
 static bool CheckNonNullExpr(Sema &S,
                              const Expr *Expr) {
+  // If the expression has non-null type, it doesn't evaluate to null.
+  if (auto nullability
+        = Expr->IgnoreImplicit()->getType()->getNullability(S.Context)) {
+    if (*nullability == NullabilityKind::NonNull)
+      return false;
+  }
+
   // As a special case, transparent unions initialized with zero are
   // considered null for the purposes of the nonnull attribute.
   if (const RecordType *UT = Expr->getType()->getAsUnionType()) {
@@ -1190,56 +1197,111 @@
   }
 }
 
+/// Determine whether the given type has a non-null nullability annotation.
+static bool isNonNullType(ASTContext &ctx, QualType type) {
+  if (auto nullability = type->getNullability(ctx))
+    return *nullability == NullabilityKind::NonNull;
+     
+  return false;
+}
+
 static void CheckNonNullArguments(Sema &S,
                                   const NamedDecl *FDecl,
+                                  const FunctionProtoType *Proto,
                                   ArrayRef<const Expr *> Args,
                                   SourceLocation CallSiteLoc) {
+  assert((FDecl || Proto) && "Need a function declaration or prototype");
+
   // Check the attributes attached to the method/function itself.
   llvm::SmallBitVector NonNullArgs;
-  for (const auto *NonNull : FDecl->specific_attrs<NonNullAttr>()) {
-    if (!NonNull->args_size()) {
-      // Easy case: all pointer arguments are nonnull.
-      for (const auto *Arg : Args)
-        if (S.isValidPointerAttrType(Arg->getType()))
-          CheckNonNullArgument(S, Arg, CallSiteLoc);
-      return;
-    }
+  if (FDecl) {
+    // Handle the nonnull attribute on the function/method declaration itself.
+    for (const auto *NonNull : FDecl->specific_attrs<NonNullAttr>()) {
+      if (!NonNull->args_size()) {
+        // Easy case: all pointer arguments are nonnull.
+        for (const auto *Arg : Args)
+          if (S.isValidPointerAttrType(Arg->getType()))
+            CheckNonNullArgument(S, Arg, CallSiteLoc);
+        return;
+      }
 
-    for (unsigned Val : NonNull->args()) {
-      if (Val >= Args.size())
-        continue;
-      if (NonNullArgs.empty())
-        NonNullArgs.resize(Args.size());
-      NonNullArgs.set(Val);
+      for (unsigned Val : NonNull->args()) {
+        if (Val >= Args.size())
+          continue;
+        if (NonNullArgs.empty())
+          NonNullArgs.resize(Args.size());
+        NonNullArgs.set(Val);
+      }
     }
   }
 
-  // Check the attributes on the parameters.
-  ArrayRef<ParmVarDecl*> parms;
-  if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(FDecl))
-    parms = FD->parameters();
-  else if (const ObjCMethodDecl *MD = dyn_cast<ObjCMethodDecl>(FDecl))
-    parms = MD->parameters();
+  if (FDecl && (isa<FunctionDecl>(FDecl) || isa<ObjCMethodDecl>(FDecl))) {
+    // Handle the nonnull attribute on the parameters of the
+    // function/method.
+    ArrayRef<ParmVarDecl*> parms;
+    if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(FDecl))
+      parms = FD->parameters();
+    else
+      parms = cast<ObjCMethodDecl>(FDecl)->parameters();
+    
+    unsigned ParamIndex = 0;
+    for (ArrayRef<ParmVarDecl*>::iterator I = parms.begin(), E = parms.end();
+         I != E; ++I, ++ParamIndex) {
+      const ParmVarDecl *PVD = *I;
+      if (PVD->hasAttr<NonNullAttr>() || 
+          isNonNullType(S.Context, PVD->getType())) {
+        if (NonNullArgs.empty())
+          NonNullArgs.resize(Args.size());
 
-  unsigned ArgIndex = 0;
-  for (ArrayRef<ParmVarDecl*>::iterator I = parms.begin(), E = parms.end();
-       I != E; ++I, ++ArgIndex) {
-    const ParmVarDecl *PVD = *I;
-    if (PVD->hasAttr<NonNullAttr>() ||
-        (ArgIndex < NonNullArgs.size() && NonNullArgs[ArgIndex]))
-      CheckNonNullArgument(S, Args[ArgIndex], CallSiteLoc);
+        NonNullArgs.set(ParamIndex);
+      }
+    }
+  } else {
+    // If we have a non-function, non-method declaration but no
+    // function prototype, try to dig out the function prototype.
+    if (!Proto) {
+      if (const ValueDecl *VD = dyn_cast<ValueDecl>(FDecl)) {
+        QualType type = VD->getType().getNonReferenceType();
+        if (auto pointerType = type->getAs<PointerType>())
+          type = pointerType->getPointeeType();
+        else if (auto blockType = type->getAs<BlockPointerType>())
+          type = blockType->getPointeeType();
+        // FIXME: data member pointers?
+
+        // Dig out the function prototype, if there is one.
+        Proto = type->getAs<FunctionProtoType>();
+      } 
+    }
+
+    // Fill in non-null argument information from the nullability
+    // information on the parameter types (if we have them).
+    if (Proto) {
+      unsigned Index = 0;
+      for (auto paramType : Proto->getParamTypes()) {
+        if (isNonNullType(S.Context, paramType)) {
+          if (NonNullArgs.empty())
+            NonNullArgs.resize(Args.size());
+          
+          NonNullArgs.set(Index);
+        }
+        
+        ++Index;
+      }
+    }
   }
 
-  // In case this is a variadic call, check any remaining arguments.
-  for (/**/; ArgIndex < NonNullArgs.size(); ++ArgIndex)
+  // Check for non-null arguments.
+  for (unsigned ArgIndex = 0, ArgIndexEnd = NonNullArgs.size(); 
+       ArgIndex != ArgIndexEnd; ++ArgIndex) {
     if (NonNullArgs[ArgIndex])
       CheckNonNullArgument(S, Args[ArgIndex], CallSiteLoc);
+  }
 }
 
 /// Handles the checks for format strings, non-POD arguments to vararg
 /// functions, and NULL arguments passed to non-NULL parameters.
-void Sema::checkCall(NamedDecl *FDecl, ArrayRef<const Expr *> Args,
-                     unsigned NumParams, bool IsMemberFunction,
+void Sema::checkCall(NamedDecl *FDecl, const FunctionProtoType *Proto,
+                     ArrayRef<const Expr *> Args, bool IsMemberFunction,
                      SourceLocation Loc, SourceRange Range,
                      VariadicCallType CallType) {
   // FIXME: We should check as much as we can in the template definition.
@@ -1261,6 +1323,13 @@
   // Refuse POD arguments that weren't caught by the format string
   // checks above.
   if (CallType != VariadicDoesNotApply) {
+    unsigned NumParams = Proto ? Proto->getNumParams()
+                       : FDecl && isa<FunctionDecl>(FDecl)
+                           ? cast<FunctionDecl>(FDecl)->getNumParams()
+                       : FDecl && isa<ObjCMethodDecl>(FDecl)
+                           ? cast<ObjCMethodDecl>(FDecl)->param_size()
+                       : 0;
+
     for (unsigned ArgIdx = NumParams; ArgIdx < Args.size(); ++ArgIdx) {
       // Args[ArgIdx] can be null in malformed code.
       if (const Expr *Arg = Args[ArgIdx]) {
@@ -1270,12 +1339,14 @@
     }
   }
 
-  if (FDecl) {
-    CheckNonNullArguments(*this, FDecl, Args, Loc);
+  if (FDecl || Proto) {
+    CheckNonNullArguments(*this, FDecl, Proto, Args, Loc);
 
     // Type safety checking.
-    for (const auto *I : FDecl->specific_attrs<ArgumentWithTypeTagAttr>())
-      CheckArgumentWithTypeTag(I, Args.data());
+    if (FDecl) {
+      for (const auto *I : FDecl->specific_attrs<ArgumentWithTypeTagAttr>())
+        CheckArgumentWithTypeTag(I, Args.data());
+    }
   }
 }
 
@@ -1287,8 +1358,8 @@
                                 SourceLocation Loc) {
   VariadicCallType CallType =
     Proto->isVariadic() ? VariadicConstructor : VariadicDoesNotApply;
-  checkCall(FDecl, Args, Proto->getNumParams(),
-            /*IsMemberFunction=*/true, Loc, SourceRange(), CallType);
+  checkCall(FDecl, Proto, Args, /*IsMemberFunction=*/true, Loc, SourceRange(), 
+            CallType);
 }
 
 /// CheckFunctionCall - Check a direct function call for various correctness
@@ -1301,7 +1372,6 @@
                           IsMemberOperatorCall;
   VariadicCallType CallType = getVariadicCallType(FDecl, Proto,
                                                   TheCall->getCallee());
-  unsigned NumParams = Proto ? Proto->getNumParams() : 0;
   Expr** Args = TheCall->getArgs();
   unsigned NumArgs = TheCall->getNumArgs();
   if (IsMemberOperatorCall) {
@@ -1311,7 +1381,7 @@
     ++Args;
     --NumArgs;
   }
-  checkCall(FDecl, llvm::makeArrayRef(Args, NumArgs), NumParams,
+  checkCall(FDecl, Proto, llvm::makeArrayRef(Args, NumArgs), 
             IsMemberFunction, TheCall->getRParenLoc(),
             TheCall->getCallee()->getSourceRange(), CallType);
 
@@ -1345,9 +1415,9 @@
   VariadicCallType CallType =
       Method->isVariadic() ? VariadicMethod : VariadicDoesNotApply;
 
-  checkCall(Method, Args, Method->param_size(),
-            /*IsMemberFunction=*/false,
-            lbrac, Method->getSourceRange(), CallType);
+  checkCall(Method, nullptr, Args,
+            /*IsMemberFunction=*/false, lbrac, Method->getSourceRange(), 
+            CallType);
 
   return false;
 }
@@ -1356,13 +1426,14 @@
                             const FunctionProtoType *Proto) {
   QualType Ty;
   if (const auto *V = dyn_cast<VarDecl>(NDecl))
-    Ty = V->getType();
+    Ty = V->getType().getNonReferenceType();
   else if (const auto *F = dyn_cast<FieldDecl>(NDecl))
-    Ty = F->getType();
+    Ty = F->getType().getNonReferenceType();
   else
     return false;
 
-  if (!Ty->isBlockPointerType() && !Ty->isFunctionPointerType())
+  if (!Ty->isBlockPointerType() && !Ty->isFunctionPointerType() &&
+      !Ty->isFunctionProtoType())
     return false;
 
   VariadicCallType CallType;
@@ -1373,11 +1444,10 @@
   } else { // Ty->isFunctionPointerType()
     CallType = VariadicFunction;
   }
-  unsigned NumParams = Proto ? Proto->getNumParams() : 0;
 
-  checkCall(NDecl, llvm::makeArrayRef(TheCall->getArgs(),
-                                      TheCall->getNumArgs()),
-            NumParams, /*IsMemberFunction=*/false, TheCall->getRParenLoc(),
+  checkCall(NDecl, Proto,
+            llvm::makeArrayRef(TheCall->getArgs(), TheCall->getNumArgs()),
+            /*IsMemberFunction=*/false, TheCall->getRParenLoc(),
             TheCall->getCallee()->getSourceRange(), CallType);
 
   return false;
@@ -1388,11 +1458,9 @@
 bool Sema::CheckOtherCall(CallExpr *TheCall, const FunctionProtoType *Proto) {
   VariadicCallType CallType = getVariadicCallType(/*FDecl=*/nullptr, Proto,
                                                   TheCall->getCallee());
-  unsigned NumParams = Proto ? Proto->getNumParams() : 0;
-
-  checkCall(/*FDecl=*/nullptr,
+  checkCall(/*FDecl=*/nullptr, Proto,
             llvm::makeArrayRef(TheCall->getArgs(), TheCall->getNumArgs()),
-            NumParams, /*IsMemberFunction=*/false, TheCall->getRParenLoc(),
+            /*IsMemberFunction=*/false, TheCall->getRParenLoc(),
             TheCall->getCallee()->getSourceRange(), CallType);
 
   return false;
@@ -5680,7 +5748,8 @@
   CheckReturnStackAddr(*this, RetValExp, lhsType, ReturnLoc);
 
   // Check if the return value is null but should not be.
-  if (Attrs && hasSpecificAttr<ReturnsNonNullAttr>(*Attrs) &&
+  if (((Attrs && hasSpecificAttr<ReturnsNonNullAttr>(*Attrs)) ||
+       (!isObjCMethod && isNonNullType(Context, lhsType))) &&
       CheckNonNullExpr(*this, RetValExp))
     Diag(ReturnLoc, diag::warn_null_ret)
       << (isObjCMethod ? 1 : 0) << RetValExp->getSourceRange();