[ObjC] Circular containers: add support for subclasses

llvm-svn: 244193
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index e1656f2..a8a7009 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -8727,23 +8727,10 @@
 
 static Optional<int> GetNSMutableArrayArgumentIndex(Sema &S,
                                                     ObjCMessageExpr *Message) {
-  if (S.NSMutableArrayPointer.isNull()) {
-    IdentifierInfo *NSMutableArrayId =
-      S.NSAPIObj->getNSClassId(NSAPI::ClassId_NSMutableArray);
-    NamedDecl *IF = S.LookupSingleName(S.TUScope, NSMutableArrayId,
-                                       Message->getLocStart(),
-                                       Sema::LookupOrdinaryName);
-    ObjCInterfaceDecl *InterfaceDecl = dyn_cast_or_null<ObjCInterfaceDecl>(IF);
-    if (!InterfaceDecl) {
-      return None;
-    }
-    QualType NSMutableArrayObject =
-      S.Context.getObjCInterfaceType(InterfaceDecl);
-    S.NSMutableArrayPointer =
-      S.Context.getObjCObjectPointerType(NSMutableArrayObject);
-  }
-
-  if (S.NSMutableArrayPointer != Message->getReceiverType()) {
+  bool IsMutableArray = S.NSAPIObj->isSubclassOfNSClass(
+                                                Message->getReceiverInterface(),
+                                                NSAPI::ClassId_NSMutableArray);
+  if (!IsMutableArray) {
     return None;
   }
 
@@ -8775,24 +8762,10 @@
 static
 Optional<int> GetNSMutableDictionaryArgumentIndex(Sema &S,
                                                   ObjCMessageExpr *Message) {
-
-  if (S.NSMutableDictionaryPointer.isNull()) {
-    IdentifierInfo *NSMutableDictionaryId =
-      S.NSAPIObj->getNSClassId(NSAPI::ClassId_NSMutableDictionary);
-    NamedDecl *IF = S.LookupSingleName(S.TUScope, NSMutableDictionaryId,
-                                       Message->getLocStart(),
-                                       Sema::LookupOrdinaryName);
-    ObjCInterfaceDecl *InterfaceDecl = dyn_cast_or_null<ObjCInterfaceDecl>(IF);
-    if (!InterfaceDecl) {
-      return None;
-    }
-    QualType NSMutableDictionaryObject =
-      S.Context.getObjCInterfaceType(InterfaceDecl);
-    S.NSMutableDictionaryPointer =
-      S.Context.getObjCObjectPointerType(NSMutableDictionaryObject);
-  }
-
-  if (S.NSMutableDictionaryPointer != Message->getReceiverType()) {
+  bool IsMutableDictionary = S.NSAPIObj->isSubclassOfNSClass(
+                                            Message->getReceiverInterface(),
+                                            NSAPI::ClassId_NSMutableDictionary);
+  if (!IsMutableDictionary) {
     return None;
   }
 
@@ -8820,63 +8793,14 @@
 }
 
 static Optional<int> GetNSSetArgumentIndex(Sema &S, ObjCMessageExpr *Message) {
+  bool IsMutableSet = S.NSAPIObj->isSubclassOfNSClass(
+                                                Message->getReceiverInterface(),
+                                                NSAPI::ClassId_NSMutableSet);
 
-  ObjCInterfaceDecl *InterfaceDecl;
-  if (S.NSMutableSetPointer.isNull()) {
-    IdentifierInfo *NSMutableSetId =
-      S.NSAPIObj->getNSClassId(NSAPI::ClassId_NSMutableSet);
-    NamedDecl *IF = S.LookupSingleName(S.TUScope, NSMutableSetId,
-                                       Message->getLocStart(),
-                                       Sema::LookupOrdinaryName);
-    InterfaceDecl = dyn_cast_or_null<ObjCInterfaceDecl>(IF);
-    if (InterfaceDecl) {
-      QualType NSMutableSetObject =
-        S.Context.getObjCInterfaceType(InterfaceDecl);
-      S.NSMutableSetPointer =
-        S.Context.getObjCObjectPointerType(NSMutableSetObject);
-    }
-  }
-
-  if (S.NSCountedSetPointer.isNull()) {
-    IdentifierInfo *NSCountedSetId =
-      S.NSAPIObj->getNSClassId(NSAPI::ClassId_NSCountedSet);
-    NamedDecl *IF = S.LookupSingleName(S.TUScope, NSCountedSetId,
-                                       Message->getLocStart(),
-                                       Sema::LookupOrdinaryName);
-    InterfaceDecl = dyn_cast_or_null<ObjCInterfaceDecl>(IF);
-    if (InterfaceDecl) {
-      QualType NSCountedSetObject =
-        S.Context.getObjCInterfaceType(InterfaceDecl);
-      S.NSCountedSetPointer =
-        S.Context.getObjCObjectPointerType(NSCountedSetObject);
-    }
-  }
-
-  if (S.NSMutableOrderedSetPointer.isNull()) {
-    IdentifierInfo *NSOrderedSetId =
-      S.NSAPIObj->getNSClassId(NSAPI::ClassId_NSMutableOrderedSet);
-    NamedDecl *IF = S.LookupSingleName(S.TUScope, NSOrderedSetId,
-                                       Message->getLocStart(),
-                                       Sema::LookupOrdinaryName);
-    InterfaceDecl = dyn_cast_or_null<ObjCInterfaceDecl>(IF);
-    if (InterfaceDecl) {
-      QualType NSOrderedSetObject =
-        S.Context.getObjCInterfaceType(InterfaceDecl);
-      S.NSMutableOrderedSetPointer =
-        S.Context.getObjCObjectPointerType(NSOrderedSetObject);
-    }
-  }
-
-  QualType ReceiverType = Message->getReceiverType();
-
-  bool IsMutableSet = !S.NSMutableSetPointer.isNull() &&
-    ReceiverType == S.NSMutableSetPointer;
-  bool IsMutableOrderedSet = !S.NSMutableOrderedSetPointer.isNull() &&
-    ReceiverType == S.NSMutableOrderedSetPointer;
-  bool IsCountedSet = !S.NSCountedSetPointer.isNull() &&
-    ReceiverType == S.NSCountedSetPointer;
-
-  if (!IsMutableSet && !IsMutableOrderedSet && !IsCountedSet) {
+  bool IsMutableOrderedSet = S.NSAPIObj->isSubclassOfNSClass(
+                                            Message->getReceiverInterface(),
+                                            NSAPI::ClassId_NSMutableOrderedSet);
+  if (!IsMutableSet && !IsMutableOrderedSet) {
     return None;
   }
 
@@ -8917,38 +8841,51 @@
 
   int ArgIndex = *ArgOpt;
 
-  Expr *Receiver = Message->getInstanceReceiver()->IgnoreImpCasts();
-  if (OpaqueValueExpr *OE = dyn_cast<OpaqueValueExpr>(Receiver)) {
-    Receiver = OE->getSourceExpr()->IgnoreImpCasts();
-  }
-
   Expr *Arg = Message->getArg(ArgIndex)->IgnoreImpCasts();
   if (OpaqueValueExpr *OE = dyn_cast<OpaqueValueExpr>(Arg)) {
     Arg = OE->getSourceExpr()->IgnoreImpCasts();
   }
 
-  if (DeclRefExpr *ReceiverRE = dyn_cast<DeclRefExpr>(Receiver)) {
+  if (Message->getReceiverKind() == ObjCMessageExpr::SuperInstance) {
     if (DeclRefExpr *ArgRE = dyn_cast<DeclRefExpr>(Arg)) {
-      if (ReceiverRE->getDecl() == ArgRE->getDecl()) {
-        ValueDecl *Decl = ReceiverRE->getDecl();
+      if (ArgRE->isObjCSelfExpr()) {
         Diag(Message->getSourceRange().getBegin(),
              diag::warn_objc_circular_container)
-          << Decl->getName();
-        Diag(Decl->getLocation(),
-             diag::note_objc_circular_container_declared_here)
-          << Decl->getName();
+          << ArgRE->getDecl()->getName() << StringRef("super");
       }
     }
-  } else if (ObjCIvarRefExpr *IvarRE = dyn_cast<ObjCIvarRefExpr>(Receiver)) {
-    if (ObjCIvarRefExpr *IvarArgRE = dyn_cast<ObjCIvarRefExpr>(Arg)) {
-      if (IvarRE->getDecl() == IvarArgRE->getDecl()) {
-        ObjCIvarDecl *Decl = IvarRE->getDecl();
-        Diag(Message->getSourceRange().getBegin(),
-             diag::warn_objc_circular_container)
-          << Decl->getName();
-        Diag(Decl->getLocation(),
-             diag::note_objc_circular_container_declared_here)
-          << Decl->getName();
+  } else {
+    Expr *Receiver = Message->getInstanceReceiver()->IgnoreImpCasts();
+
+    if (OpaqueValueExpr *OE = dyn_cast<OpaqueValueExpr>(Receiver)) {
+      Receiver = OE->getSourceExpr()->IgnoreImpCasts();
+    }
+
+    if (DeclRefExpr *ReceiverRE = dyn_cast<DeclRefExpr>(Receiver)) {
+      if (DeclRefExpr *ArgRE = dyn_cast<DeclRefExpr>(Arg)) {
+        if (ReceiverRE->getDecl() == ArgRE->getDecl()) {
+          ValueDecl *Decl = ReceiverRE->getDecl();
+          Diag(Message->getSourceRange().getBegin(),
+               diag::warn_objc_circular_container)
+            << Decl->getName() << Decl->getName();
+          if (!ArgRE->isObjCSelfExpr()) {
+            Diag(Decl->getLocation(),
+                 diag::note_objc_circular_container_declared_here)
+              << Decl->getName();
+          }
+        }
+      }
+    } else if (ObjCIvarRefExpr *IvarRE = dyn_cast<ObjCIvarRefExpr>(Receiver)) {
+      if (ObjCIvarRefExpr *IvarArgRE = dyn_cast<ObjCIvarRefExpr>(Arg)) {
+        if (IvarRE->getDecl() == IvarArgRE->getDecl()) {
+          ObjCIvarDecl *Decl = IvarRE->getDecl();
+          Diag(Message->getSourceRange().getBegin(),
+               diag::warn_objc_circular_container)
+            << Decl->getName() << Decl->getName();
+          Diag(Decl->getLocation(),
+               diag::note_objc_circular_container_declared_here)
+            << Decl->getName();
+        }
       }
     }
   }