<rdar://problem/13267210> Ensure that Sema::CompareReferenceRelationship returns consistent results with invalid types.

When Sema::RequireCompleteType() is given a class template
specialization type that then fails to instantiate, it returns
'true'. On subsequent invocations for the same type, it can return
'false'. Make sure that this difference doesn't change the result of
Sema::CompareReferenceRelationship, which is expected to remain stable
while we're checking an initialization sequence.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@178088 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/SemaDeclCXX.cpp b/lib/Sema/SemaDeclCXX.cpp
index cc0aade..1da1884 100644
--- a/lib/Sema/SemaDeclCXX.cpp
+++ b/lib/Sema/SemaDeclCXX.cpp
@@ -1311,29 +1311,25 @@
                        (CXXBaseSpecifier**)(Bases), NumBases);
 }
 
-static CXXRecordDecl *GetClassForType(QualType T) {
-  if (const RecordType *RT = T->getAs<RecordType>())
-    return cast<CXXRecordDecl>(RT->getDecl());
-  else if (const InjectedClassNameType *ICT = T->getAs<InjectedClassNameType>())
-    return ICT->getDecl();
-  else
-    return 0;
-}
-
 /// \brief Determine whether the type \p Derived is a C++ class that is
 /// derived from the type \p Base.
 bool Sema::IsDerivedFrom(QualType Derived, QualType Base) {
   if (!getLangOpts().CPlusPlus)
     return false;
   
-  CXXRecordDecl *DerivedRD = GetClassForType(Derived);
+  CXXRecordDecl *DerivedRD = Derived->getAsCXXRecordDecl();
   if (!DerivedRD)
     return false;
   
-  CXXRecordDecl *BaseRD = GetClassForType(Base);
+  CXXRecordDecl *BaseRD = Base->getAsCXXRecordDecl();
   if (!BaseRD)
     return false;
-  
+
+  // If either the base or the derived type is invalid, don't try to
+  // check whether one is derived from the other.
+  if (BaseRD->isInvalidDecl() || DerivedRD->isInvalidDecl())
+    return false;
+
   // FIXME: instantiate DerivedRD if necessary.  We need a PoI for this.
   return DerivedRD->hasDefinition() && DerivedRD->isDerivedFrom(BaseRD);
 }
@@ -1344,11 +1340,11 @@
   if (!getLangOpts().CPlusPlus)
     return false;
   
-  CXXRecordDecl *DerivedRD = GetClassForType(Derived);
+  CXXRecordDecl *DerivedRD = Derived->getAsCXXRecordDecl();
   if (!DerivedRD)
     return false;
   
-  CXXRecordDecl *BaseRD = GetClassForType(Base);
+  CXXRecordDecl *BaseRD = Base->getAsCXXRecordDecl();
   if (!BaseRD)
     return false;