Implement Objective-C Related Result Type semantics.

Related result types apply Cocoa conventions to the type of message
sends and property accesses to Objective-C methods that are known to
always return objects whose type is the same as the type of the
receiving class (or a subclass thereof), such as +alloc and
-init. This tightens up static type safety for Objective-C, so that we
now diagnose mistakes like this:

t.m:4:10: warning: incompatible pointer types initializing 'NSSet *' with an
      expression of type 'NSArray *' [-Wincompatible-pointer-types]
  NSSet *array = [[NSArray alloc] init];
         ^       ~~~~~~~~~~~~~~~~~~~~~~
/System/Library/Frameworks/Foundation.framework/Headers/NSObject.h:72:1: note:
      instance method 'init' is assumed to return an instance of its receiver
      type ('NSArray *')
- (id)init;
^

It also means that we get decent type inference when writing code in
Objective-C++0x:

  auto array = [[NSMutableArray alloc] initWithObjects:@"one", @"two", nil];
  //    ^ now infers NSMutableArray* rather than id




git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@132868 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/SemaDeclObjC.cpp b/lib/Sema/SemaDeclObjC.cpp
index 4e41aa9..de9097e 100644
--- a/lib/Sema/SemaDeclObjC.cpp
+++ b/lib/Sema/SemaDeclObjC.cpp
@@ -24,6 +24,141 @@
 
 using namespace clang;
 
+/// Warn if NewMethod lacks the related result type that Overridden has, and
+/// note the overridden method. IsImplementation is unused; returns false.
+bool Sema::CheckObjCMethodOverride(ObjCMethodDecl *NewMethod, 
+                                   const ObjCMethodDecl *Overridden,
+                                   bool IsImplementation) {
+  if (!Overridden->hasRelatedResultType() ||
+      NewMethod->hasRelatedResultType())
+    return false;
+
+  // This can only happen when the method follows a naming convention that
+  // implies a related result type, and the original (overridden) method has
+  // a suitable return type, but the new (overriding) method does not have
+  // a suitable return type.
+  QualType ResultType = NewMethod->getResultType();
+  SourceRange ResultTypeRange;
+  if (const TypeSourceInfo *ResultTypeInfo
+        = NewMethod->getResultTypeSourceInfo())
+    ResultTypeRange = ResultTypeInfo->getTypeLoc().getSourceRange();
+
+  // Figure out which class this method is part of, if any. ObjCImplDecl is
+  // the common base of ObjCImplementationDecl and ObjCCategoryImplDecl, so a
+  // single cast handles both kinds of @implementation.
+  DeclContext *DC = NewMethod->getDeclContext();
+  ObjCInterfaceDecl *CurrentClass = dyn_cast<ObjCInterfaceDecl>(DC);
+  if (!CurrentClass) {
+    if (ObjCCategoryDecl *Cat = dyn_cast<ObjCCategoryDecl>(DC))
+      CurrentClass = Cat->getClassInterface();
+    else if (ObjCImplDecl *Impl = dyn_cast<ObjCImplDecl>(DC))
+      CurrentClass = Impl->getClassInterface();
+  }
+
+  if (CurrentClass) {
+    Diag(NewMethod->getLocation(),
+         diag::warn_related_result_type_compatibility_class)
+      << Context.getObjCInterfaceType(CurrentClass)
+      << ResultType
+      << ResultTypeRange;
+  } else {
+    Diag(NewMethod->getLocation(),
+         diag::warn_related_result_type_compatibility_protocol)
+      << ResultType
+      << ResultTypeRange;
+  }
+
+  Diag(Overridden->getLocation(), diag::note_related_result_type_overridden)
+    << Overridden->getMethodFamily();
+  return false;
+}
+
+
+/// Search the contexts reachable from \p DC (categories, protocols, and the
+/// superclass) for a method that \p NewMethod overrides, and check it.
+static bool CheckObjCMethodOverrides(Sema &S, ObjCMethodDecl *NewMethod,
+                                     DeclContext *DC, 
+                                     bool SkipCurrent = true) {
+  if (!DC)
+    return false;
+
+  if (!SkipCurrent) {
+    // Look for a method of the same kind with this selector in DC itself.
+    DeclContext::lookup_const_iterator Found, FoundEnd;
+    llvm::tie(Found, FoundEnd) = DC->lookup(NewMethod->getSelector());
+    for (; Found != FoundEnd; ++Found) {
+      ObjCMethodDecl *Candidate = dyn_cast<ObjCMethodDecl>(*Found);
+      if (Candidate &&
+          Candidate->isInstanceMethod() == NewMethod->isInstanceMethod())
+        return S.CheckObjCMethodOverride(NewMethod, Candidate, false);
+    }
+  }
+
+  if (ObjCInterfaceDecl *Iface = llvm::dyn_cast<ObjCInterfaceDecl>(DC)) {
+    // Categories of the class come first.
+    ObjCCategoryDecl *Cat = Iface->getCategoryList();
+    while (Cat) {
+      if (CheckObjCMethodOverrides(S, NewMethod, Cat, false))
+        return true;
+      Cat = Cat->getNextClassCategory();
+    }
+
+    // Then the protocols listed on the class.
+    ObjCList<ObjCProtocolDecl>::iterator P = Iface->protocol_begin();
+    ObjCList<ObjCProtocolDecl>::iterator PEnd = Iface->protocol_end();
+    for (; P != PEnd; ++P)
+      if (CheckObjCMethodOverrides(S, NewMethod, *P, false))
+        return true;
+
+    // Finally the superclass; a null superclass is rejected on entry.
+    return CheckObjCMethodOverrides(S, NewMethod, Iface->getSuperClass(), 
+                                    false);
+  }
+
+  if (ObjCCategoryDecl *Cat = dyn_cast<ObjCCategoryDecl>(DC)) {
+    // A category contributes only its protocols.
+    ObjCList<ObjCProtocolDecl>::iterator P = Cat->protocol_begin();
+    ObjCList<ObjCProtocolDecl>::iterator PEnd = Cat->protocol_end();
+    for (; P != PEnd; ++P)
+      if (CheckObjCMethodOverrides(S, NewMethod, *P, false))
+        return true;
+    return false;
+  }
+
+  if (ObjCProtocolDecl *Proto = dyn_cast<ObjCProtocolDecl>(DC)) {
+    // A protocol contributes only the protocols in its own protocol list.
+    ObjCList<ObjCProtocolDecl>::iterator P = Proto->protocol_begin();
+    ObjCList<ObjCProtocolDecl>::iterator PEnd = Proto->protocol_end();
+    for (; P != PEnd; ++P)
+      if (CheckObjCMethodOverrides(S, NewMethod, *P, false))
+        return true;
+  }
+
+  return false;
+}
+
+/// Entry point: pick the context to start searching from based on the kind
+/// of \p DC, then defer to the file-local recursive helper. The helper's
+/// SkipCurrent argument is left at its default (true), so the starting
+/// context itself is not searched for a matching method.
+bool Sema::CheckObjCMethodOverrides(ObjCMethodDecl *NewMethod, 
+                                    DeclContext *DC) {
+  // Interfaces, categories, and protocols are used as the starting point.
+  if (isa<ObjCInterfaceDecl>(DC) || isa<ObjCCategoryDecl>(DC) ||
+      isa<ObjCProtocolDecl>(DC))
+    return ::CheckObjCMethodOverrides(*this, NewMethod, DC);
+
+  // Implementations (of a class or of a category) start from the interface
+  // of the class they implement. ObjCImplDecl is the shared base of
+  // ObjCImplementationDecl and ObjCCategoryImplDecl.
+  if (ObjCImplDecl *Impl = dyn_cast<ObjCImplDecl>(DC))
+    return ::CheckObjCMethodOverrides(*this, NewMethod,
+                                      Impl->getClassInterface());
+
+  // Any other kind of context: start from CurContext instead of DC.
+  return ::CheckObjCMethodOverrides(*this, NewMethod, CurContext);
+}
+
 static void DiagnoseObjCImplementedDeprecations(Sema &S,
                                                 NamedDecl *ND,
                                                 SourceLocation ImplLoc,
@@ -1717,11 +1852,71 @@
   return false;
 }
 
+/// \brief Check whether the declared result type of the given Objective-C
+/// method declaration is compatible with the method's class.
+///
+/// \param Method The method whose declared result type is examined.
+/// \param CurrentClass The class the method belongs to; may be null, in
+/// which case only 'id' and qualified 'id' result types are compatible.
+///
+/// \returns false if the result type is compatible, true if it is not.
+static bool 
+CheckRelatedResultTypeCompatibility(Sema &S, ObjCMethodDecl *Method,
+                                    ObjCInterfaceDecl *CurrentClass) {
+  // If an Objective-C method inherits its related result type, then its
+  // declared result type must be compatible with its own class type.
+  const ObjCObjectPointerType *ResultObjectType
+    = Method->getResultType()->getAs<ObjCObjectPointerType>();
+  // Anything that is not an Objective-C object pointer is incompatible.
+  if (!ResultObjectType)
+    return true;
+
+  // The declared result type is compatible if:
+  //   - it is id or qualified id, or
+  if (ResultObjectType->isObjCIdType() ||
+      ResultObjectType->isObjCQualifiedIdType())
+    return false;
+
+  if (!CurrentClass)
+    return true;
+
+  ObjCInterfaceDecl *ResultClass = ResultObjectType->getInterfaceDecl();
+  if (!ResultClass)
+    return true;
+
+  //   - it is the same as the method's class type, or
+  //   - it is a superclass of the method's class type
+  return CurrentClass != ResultClass &&
+         !ResultClass->isSuperClassOf(CurrentClass);
+}
+
+/// \brief Determine if any method recorded in the global method pool for
+/// the given selector has a related result type.
+///
+/// Only the instance-method or class-method list for \p Sel is examined,
+/// as selected by \p IsInstance.
+static bool 
+anyMethodInfersRelatedResultType(Sema &S, Selector Sel, bool IsInstance) {
+  Sema::GlobalMethodPool::iterator Pos = S.MethodPool.find(Sel);
+  if (Pos == S.MethodPool.end()) {
+    if (!S.ExternalSource)
+      return false;
+    Pos = S.ReadMethodPool(Sel);
+  }
+
+  ObjCMethodList &List = IsInstance ? Pos->second.first : Pos->second.second;
+  for (ObjCMethodList *M = &List; M; M = M->Next)
+    if (M->Method && M->Method->hasRelatedResultType())
+      return true;
+  return false;
+}
+
 Decl *Sema::ActOnMethodDeclaration(
     Scope *S,
     SourceLocation MethodLoc, SourceLocation EndLoc,
     tok::TokenKind MethodType, Decl *ClassDecl,
     ObjCDeclSpec &ReturnQT, ParsedType ReturnType,
+    SourceLocation SelectorStartLoc,
     Selector Sel,
     // optional arguments. The number of types/arguments is obtained
     // from the Sel.getNumArgs().
@@ -1746,7 +1941,7 @@
       Diag(MethodLoc, diag::err_object_cannot_be_passed_returned_by_value)
         << 0 << resultDeclType;
       return 0;
-    }
+    }    
   } else // get the type for "id".
     resultDeclType = Context.getObjCIdType();
 
@@ -1756,9 +1951,10 @@
                            cast<DeclContext>(ClassDecl),
                            MethodType == tok::minus, isVariadic,
                            false, false,
-                           MethodDeclKind == tok::objc_optional ?
-                           ObjCMethodDecl::Optional :
-                           ObjCMethodDecl::Required);
+                           MethodDeclKind == tok::objc_optional 
+                             ? ObjCMethodDecl::Optional
+                             : ObjCMethodDecl::Required,
+                           false);
 
   llvm::SmallVector<ParmVarDecl*, 16> Params;
 
@@ -1854,6 +2050,7 @@
     }
     InterfaceMD = ImpDecl->getClassInterface()->getMethod(Sel,
                                                    MethodType == tok::minus);
+    
     if (ObjCMethod->hasAttrs() &&
         containsInvalidMethodImplAttribute(ObjCMethod->getAttrs()))
       Diag(EndLoc, diag::warn_attribute_method_def);
@@ -1866,6 +2063,10 @@
       PrevMethod = CatImpDecl->getClassMethod(Sel);
       CatImpDecl->addClassMethod(ObjCMethod);
     }
+
+    if (ObjCCategoryDecl *Cat = CatImpDecl->getCategoryDecl())
+      InterfaceMD = Cat->getMethod(Sel, MethodType == tok::minus);
+
     if (ObjCMethod->hasAttrs() &&
         containsInvalidMethodImplAttribute(ObjCMethod->getAttrs()))
       Diag(EndLoc, diag::warn_attribute_method_def);
@@ -1879,10 +2080,65 @@
     Diag(PrevMethod->getLocation(), diag::note_previous_declaration);
   }
 
+  // If this Objective-C method does not have a related result type, but we
+  // are allowed to infer related result types, try to do so based on the
+  // method family.
+  ObjCInterfaceDecl *CurrentClass = dyn_cast<ObjCInterfaceDecl>(ClassDecl);
+  if (!CurrentClass) {
+    if (ObjCCategoryDecl *Cat = dyn_cast<ObjCCategoryDecl>(ClassDecl))
+      CurrentClass = Cat->getClassInterface();
+    else if (ObjCImplDecl *Impl = dyn_cast<ObjCImplDecl>(ClassDecl))
+      CurrentClass = Impl->getClassInterface();
+    else if (ObjCCategoryImplDecl *CatImpl
+                                   = dyn_cast<ObjCCategoryImplDecl>(ClassDecl))
+      CurrentClass = CatImpl->getClassInterface();
+  }
+  
   // Merge information down from the interface declaration if we have one.
-  if (InterfaceMD)
+  if (InterfaceMD) {
+    // Inherit the related result type, if we can.
+    if (InterfaceMD->hasRelatedResultType() &&
+        !CheckRelatedResultTypeCompatibility(*this, ObjCMethod, CurrentClass))
+      ObjCMethod->SetRelatedResultType();
+      
     mergeObjCMethodDecls(ObjCMethod, InterfaceMD);
-
+  }
+  
+  if (!ObjCMethod->hasRelatedResultType() && 
+      getLangOptions().ObjCInferRelatedResultType) {
+    bool InferRelatedResultType = false;
+    switch (ObjCMethod->getMethodFamily()) {
+    case OMF_None:
+    case OMF_copy:
+    case OMF_dealloc:
+    case OMF_mutableCopy:
+    case OMF_release:
+    case OMF_retainCount:
+      break;
+      
+    case OMF_alloc:
+    case OMF_new:
+      InferRelatedResultType = ObjCMethod->isClassMethod();
+      break;
+        
+    case OMF_init:
+    case OMF_autorelease:
+    case OMF_retain:
+    case OMF_self:
+      InferRelatedResultType = ObjCMethod->isInstanceMethod();
+      break;
+    }
+    
+    if (InferRelatedResultType &&
+        !CheckRelatedResultTypeCompatibility(*this, ObjCMethod, CurrentClass))
+      ObjCMethod->SetRelatedResultType();
+    
+    if (!InterfaceMD && 
+        anyMethodInfersRelatedResultType(*this, ObjCMethod->getSelector(),
+                                         ObjCMethod->isInstanceMethod()))
+      CheckObjCMethodOverrides(ObjCMethod, cast<DeclContext>(ClassDecl));
+  }
+    
   return ObjCMethod;
 }