move ObjCQualifiedIdTypesAreCompatible out of ASTContext into Sema.
While it is similar to the other compatibility predicates in ASTContext,
it is not used by them and is different.

In addition, greatly simplify ObjCQualifiedIdTypesAreCompatible and
fix some canonical type bugs.  Also, simplify my Type::getAsObjC* methods.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@49313 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/SemaExprObjC.cpp b/lib/Sema/SemaExprObjC.cpp
index ec6e41b..63dd0d7 100644
--- a/lib/Sema/SemaExprObjC.cpp
+++ b/lib/Sema/SemaExprObjC.cpp
@@ -297,3 +297,170 @@
   return new ObjCMessageExpr(RExpr, Sel, returnType, Method, lbrac, rbrac, 
                              ArgExprs, NumArgs);
 }
+
+//===----------------------------------------------------------------------===//
+// ObjCQualifiedIdTypesAreCompatible - Compatibility testing for qualified id's.
+//===----------------------------------------------------------------------===//
+
+/// ProtocolCompatibleWithProtocol - Return true when 'lProto' is 'rProto'
+/// itself or is reachable through the protocols that 'rProto' references.
+static bool ProtocolCompatibleWithProtocol(ObjCProtocolDecl *lProto,
+                                           ObjCProtocolDecl *rProto) {
+  bool Found = lProto == rProto;
+  // Otherwise search, depth first, every protocol 'rProto' refers to.
+  ObjCProtocolDecl **Parents = Found ? 0 : rProto->getReferencedProtocols();
+  for (unsigned Idx = 0;
+       !Found && Idx != rProto->getNumReferencedProtocols(); ++Idx)
+    Found = ProtocolCompatibleWithProtocol(lProto, Parents[Idx]);
+  return Found;
+}
+
+/// ClassImplementsProtocol - Return true if 'lProto' is adopted by the class
+/// 'IDecl' or by one of its super classes.  When 'lookupCategory' is true the
+/// categories of each of those classes are searched as well.
+static bool ClassImplementsProtocol(ObjCProtocolDecl *lProto,
+                                    ObjCInterfaceDecl *IDecl,
+                                    bool lookupCategory) {
+  // Walk from the class itself up through its chain of super classes; the
+  // first class (or category) that lists a compatible protocol wins.
+  for (ObjCInterfaceDecl *Class = IDecl; Class;
+       Class = Class->getSuperClass()) {
+    // Protocols listed on the class interface itself.
+    ObjCProtocolDecl **ClassProtos = Class->getReferencedProtocols();
+    for (unsigned Idx = 0; Idx != Class->getNumIntfRefProtocols(); ++Idx)
+      if (ProtocolCompatibleWithProtocol(lProto, ClassProtos[Idx]))
+        return true;
+
+    // Categories are only consulted when the caller asked for them.
+    if (!lookupCategory)
+      continue;
+
+    // Protocols listed on any category of this class.
+    for (ObjCCategoryDecl *Cat = Class->getCategoryList(); Cat;
+         Cat = Cat->getNextClassCategory()) {
+      ObjCProtocolDecl **CatProtos = Cat->getReferencedProtocols();
+      for (unsigned Idx = 0; Idx != Cat->getNumReferencedProtocols(); ++Idx)
+        if (ProtocolCompatibleWithProtocol(lProto, CatProtos[Idx]))
+          return true;
+    }
+  }
+
+  return false;
+}
+
+/// ObjCQualifiedIdTypesAreCompatible - Return true if 'lhs' and 'rhs' are
+/// compatible with respect to a protocol qualified 'id' on either side.
+/// 'compare' relaxes the protocol checks (see its uses below).
+bool Sema::ObjCQualifiedIdTypesAreCompatible(QualType lhs, QualType rhs,
+                                             bool compare) {
+  // Allow id<P..> and an 'id' or void* type in all cases.
+  if (const PointerType *PT = lhs->getAsPointerType()) {
+    QualType PointeeTy = PT->getPointeeType();
+    if (Context.isObjCIdType(PointeeTy) || PointeeTy->isVoidType())
+      return true;
+  } else if (const PointerType *PT = rhs->getAsPointerType()) {
+    QualType PointeeTy = PT->getPointeeType();
+    if (Context.isObjCIdType(PointeeTy) || PointeeTy->isVoidType())
+      return true;
+  }
+
+  const ObjCQualifiedIdType *lhsQID = lhs->getAsObjCQualifiedIdType();
+  const ObjCQualifiedIdType *rhsQID = rhs->getAsObjCQualifiedIdType();
+
+  if (lhsQID) {
+    // RHS is a qualified id, or points to a (qualified) interface.
+    const ObjCQualifiedInterfaceType *rhsQI = 0;
+    ObjCInterfaceDecl *rhsID = 0;
+    if (!rhsQID && rhs->isPointerType()) {
+      QualType rtype = rhs->getAsPointerType()->getPointeeType();
+      rhsQI = rtype->getAsObjCQualifiedInterfaceType();
+      if (!rhsQI)
+        if (const ObjCInterfaceType *IT = rtype->getAsObjCInterfaceType())
+          rhsID = IT->getDecl();
+    }
+    if (!rhsQI && !rhsQID && !rhsID)
+      return false;
+    // Protocol list of the RHS; left unset (and unused) for a plain class.
+    ObjCQualifiedIdType::qual_iterator RHSProtoB, RHSProtoE;
+    if (rhsQI) {
+      RHSProtoB = rhsQI->qual_begin();
+      RHSProtoE = rhsQI->qual_end();
+    } else if (rhsQID) {
+      RHSProtoB = rhsQID->qual_begin();
+      RHSProtoE = rhsQID->qual_end();
+    }
+
+    for (unsigned i = 0; i < lhsQID->getNumProtocols(); i++) {
+      ObjCProtocolDecl *lhsProto = lhsQID->getProtocols(i);
+      bool match = false;
+      // when comparing an id<P> on lhs with a static type on rhs,
+      // see if static class implements all of id's protocols, directly or
+      // through its super class and categories.
+      if (rhsID) {
+        match = ClassImplementsProtocol(lhsProto, rhsID, true);
+      } else {
+        // Restart at the first RHS protocol for every LHS protocol; a cursor
+        // shared across iterations resumes where the previous search stopped
+        // and wrongly rejects e.g. id<P1,P2> against id<P2,P1>.
+        for (ObjCQualifiedIdType::qual_iterator RHSProtoI = RHSProtoB;
+             RHSProtoI != RHSProtoE; ++RHSProtoI) {
+          ObjCProtocolDecl *rhsProto = *RHSProtoI;
+          if (ProtocolCompatibleWithProtocol(lhsProto, rhsProto) ||
+              (compare && ProtocolCompatibleWithProtocol(rhsProto, lhsProto))) {
+            match = true;
+            break;
+          }
+        }
+      }
+      if (!match)
+        return false;
+    }
+  } else if (rhsQID) {
+    // Only the RHS is a qualified id; LHS must be an interface pointer.
+    const ObjCQualifiedInterfaceType *lhsQI = 0;
+    ObjCInterfaceDecl *lhsID = 0;
+    if (lhs->isPointerType()) {
+      QualType ltype = lhs->getAsPointerType()->getPointeeType();
+      lhsQI = ltype->getAsObjCQualifiedInterfaceType();
+      if (!lhsQI)
+        if (const ObjCInterfaceType *IT = ltype->getAsObjCInterfaceType())
+          lhsID = IT->getDecl();
+    }
+    if (!lhsQI && !lhsID)
+      return false;
+    bool match = false;
+    // for static type vs. qualified 'id' type, check that class implements
+    // one of 'id's protocols.
+    if (lhsID) {
+      for (unsigned j = 0; j < rhsQID->getNumProtocols(); j++) {
+        ObjCProtocolDecl *rhsProto = rhsQID->getProtocols(j);
+        if (ClassImplementsProtocol(rhsProto, lhsID, compare)) {
+          match = true;
+          break;
+        }
+      }
+    } else {
+      // Every protocol of the qualified interface needs a match in the RHS.
+      for (ObjCQualifiedIdType::qual_iterator LHSProtoI = lhsQI->qual_begin(),
+           LHSProtoE = lhsQI->qual_end(); LHSProtoI != LHSProtoE; ++LHSProtoI) {
+        match = false;
+        ObjCProtocolDecl *lhsProto = *LHSProtoI;
+        for (unsigned j = 0; j < rhsQID->getNumProtocols(); j++) {
+          ObjCProtocolDecl *rhsProto = rhsQID->getProtocols(j);
+          if (ProtocolCompatibleWithProtocol(lhsProto, rhsProto) ||
+              (compare && ProtocolCompatibleWithProtocol(rhsProto, lhsProto))) {
+            match = true;
+            break;
+          }
+        }
+        // Stop at the first miss so a later match cannot overwrite it.
+        if (!match)
+          break;
+      }
+    }
+    if (!match)
+      return false;
+  }
+  return true;
+}
+