Better diagnostics for covariance when checking overriding return types.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@71786 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/SemaDeclCXX.cpp b/lib/Sema/SemaDeclCXX.cpp
index bdd3cc2..5a5c3c7 100644
--- a/lib/Sema/SemaDeclCXX.cpp
+++ b/lib/Sema/SemaDeclCXX.cpp
@@ -2701,12 +2701,71 @@
       CNewTy.getCVRQualifiers() == COldTy.getCVRQualifiers())
     return false;
   
-  // FIXME: Check covariance.
+  // Check if the return types are covariant
+  QualType NewClassTy, OldClassTy;
+  
+  /// Both types must be pointers or references to classes.
+  if (PointerType *NewPT = dyn_cast<PointerType>(NewTy)) {
+    if (PointerType *OldPT = dyn_cast<PointerType>(OldTy)) {
+      NewClassTy = NewPT->getPointeeType();
+      OldClassTy = OldPT->getPointeeType();
+    }
+  } else if (ReferenceType *NewRT = dyn_cast<ReferenceType>(NewTy)) {
+    if (ReferenceType *OldRT = dyn_cast<ReferenceType>(OldTy)) {
+      NewClassTy = NewRT->getPointeeType();
+      OldClassTy = OldRT->getPointeeType();
+    }
+  }
+  
+  // The return types aren't either both pointers or references to a class type.
+  if (NewClassTy.isNull()) {
+    Diag(New->getLocation(), 
+         diag::err_different_return_type_for_overriding_virtual_function)
+      << New->getDeclName() << NewTy << OldTy;
+    Diag(Old->getLocation(), diag::note_overridden_virtual_function);
+    
+    return true;
+  }
 
-  Diag(New->getLocation(), 
-       diag::err_different_return_type_for_overriding_virtual_function)
+  if (NewClassTy.getUnqualifiedType() != OldClassTy.getUnqualifiedType()) {
+    // Check if the new class derives from the old class.
+    if (!IsDerivedFrom(NewClassTy, OldClassTy)) {
+      Diag(New->getLocation(),
+           diag::err_covariant_return_not_derived)
+      << New->getDeclName() << NewTy << OldTy;
+      Diag(Old->getLocation(), diag::note_overridden_virtual_function);
+      return true;
+    }
+    
+    // Check if we the conversion from derived to base is valid.
+    if (CheckDerivedToBaseConversion(NewClassTy, OldClassTy, 
+                      diag::err_covariant_return_inaccessible_base,
+                      diag::err_covariant_return_ambiguous_derived_to_base_conv,
+                      // FIXME: Should this point to the return type?
+                      New->getLocation(), SourceRange(), New->getDeclName())) {
+      Diag(Old->getLocation(), diag::note_overridden_virtual_function);
+      return true;
+    }
+  }
+  
+  // The qualifiers of the return types must be the same.
+  if (CNewTy.getCVRQualifiers() != COldTy.getCVRQualifiers()) {
+    Diag(New->getLocation(),
+         diag::err_covariant_return_type_different_qualifications)
     << New->getDeclName() << NewTy << OldTy;
-  Diag(Old->getLocation(), diag::note_overridden_virtual_function);
-       
-  return true;
+    Diag(Old->getLocation(), diag::note_overridden_virtual_function);
+    return true;
+  };
+  
+
+  // The new class type must have the same or less qualifiers as the old type.
+  if (NewClassTy.isMoreQualifiedThan(OldClassTy)) {
+    Diag(New->getLocation(),
+         diag::err_covariant_return_type_class_type_more_qualified)
+    << New->getDeclName() << NewTy << OldTy;
+    Diag(Old->getLocation(), diag::note_overridden_virtual_function);
+    return true;
+  };
+  
+  return false;
 }