Change GetAddressCXXOfBaseClass to use CXXBasePaths for calculating base class offsets. Fix the code to handle virtual bases as well.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@83426 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/CGCXXClass.cpp b/lib/CodeGen/CGCXXClass.cpp
index 9c8174b..ff879f5 100644
--- a/lib/CodeGen/CGCXXClass.cpp
+++ b/lib/CodeGen/CGCXXClass.cpp
@@ -12,61 +12,35 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodeGenFunction.h"
+#include "clang/AST/CXXInheritance.h"
 #include "clang/AST/RecordLayout.h"
+
 using namespace clang;
 using namespace CodeGen;
 
-static bool
-GetNestedPaths(llvm::SmallVectorImpl<const CXXRecordDecl *> &NestedBasePaths,
-               const CXXRecordDecl *ClassDecl,
-               const CXXRecordDecl *BaseClassDecl) {
-  for (CXXRecordDecl::base_class_const_iterator i = ClassDecl->bases_begin(),
-      e = ClassDecl->bases_end(); i != e; ++i) {
-    if (i->isVirtual())
-      continue;
-    const CXXRecordDecl *Base =
-      cast<CXXRecordDecl>(i->getType()->getAs<RecordType>()->getDecl());
-    if (Base == BaseClassDecl) {
-      NestedBasePaths.push_back(BaseClassDecl);
-      return true;
-    }
-  }
-  // BaseClassDecl not an immediate base of ClassDecl.
-  for (CXXRecordDecl::base_class_const_iterator i = ClassDecl->bases_begin(),
-       e = ClassDecl->bases_end(); i != e; ++i) {
-    if (i->isVirtual())
-      continue;
-    const CXXRecordDecl *Base =
-      cast<CXXRecordDecl>(i->getType()->getAs<RecordType>()->getDecl());
-    if (GetNestedPaths(NestedBasePaths, Base, BaseClassDecl)) {
-      NestedBasePaths.push_back(Base);
-      return true;
-    }
-  }
-  return false;
-}
+static uint64_t 
+ComputeNonVirtualBaseClassOffset(ASTContext &Context, CXXBasePaths &Paths,
+                                 unsigned Start) {
+  uint64_t Offset = 0;  // running sum of the per-step base class offsets
 
-static uint64_t ComputeBaseClassOffset(ASTContext &Context,
-                                       const CXXRecordDecl *ClassDecl,
-                                       const CXXRecordDecl *BaseClassDecl) {
-    uint64_t Offset = 0;
+  const CXXBasePath &Path = Paths.front();  // only the first path is walked
+  for (unsigned i = Start, e = Path.size(); i != e; ++i) {
+    const CXXBasePathElement& Element = Path[i];
 
-    llvm::SmallVector<const CXXRecordDecl *, 16> NestedBasePaths;
-    GetNestedPaths(NestedBasePaths, ClassDecl, BaseClassDecl);
-    assert(NestedBasePaths.size() > 0 &&
-           "AddressCXXOfBaseClass - inheritence path failed");
-    NestedBasePaths.push_back(ClassDecl);
+    // Get the layout of this step's class; its base offset is read below.
+    const ASTRecordLayout &Layout = Context.getASTRecordLayout(Element.Class);
     
-    for (unsigned i = NestedBasePaths.size() - 1; i > 0; i--) {
-        const CXXRecordDecl *DerivedClass = NestedBasePaths[i];
-        const CXXRecordDecl *BaseClass = NestedBasePaths[i-1];
-        const ASTRecordLayout &Layout = 
-            Context.getASTRecordLayout(DerivedClass);
-        
-        Offset += Layout.getBaseClassOffset(BaseClass) / 8;
-    }
+    const CXXBaseSpecifier *BS = Element.Base;
+    assert(!BS->isVirtual() && "Should not see virtual bases here!");
     
-    return Offset;
+    const CXXRecordDecl *Base = 
+      cast<CXXRecordDecl>(BS->getType()->getAs<RecordType>()->getDecl());
+    
+    // Add this step's offset; the / 8 presumably converts bits to bytes.
+    Offset += Layout.getBaseClassOffset(Base) / 8;
+  }
+  // Offset is still 0 when Start == Path.size() (no steps were added).
+  return Offset;
 }
 
 llvm::Constant *
@@ -75,12 +49,15 @@
   if (ClassDecl == BaseClassDecl)
     return 0;
 
-  QualType BTy =
-    getContext().getCanonicalType(
-      getContext().getTypeDeclType(const_cast<CXXRecordDecl*>(BaseClassDecl)));
+  CXXBasePaths Paths(/*FindAmbiguities=*/false,
+                     /*RecordPaths=*/true, /*DetectVirtual=*/false);
+  if (!const_cast<CXXRecordDecl *>(ClassDecl)->
+        isDerivedFrom(const_cast<CXXRecordDecl *>(BaseClassDecl), Paths)) {
+    assert(false && "Class must be derived from the passed in base class!");
+    return 0;
+  }
 
-  uint64_t Offset = ComputeBaseClassOffset(getContext(), 
-                                           ClassDecl, BaseClassDecl);
+  uint64_t Offset = ComputeNonVirtualBaseClassOffset(getContext(), Paths, 0);
   if (!Offset)
     return 0;
 
@@ -90,19 +67,63 @@
   return llvm::ConstantInt::get(PtrDiffTy, Offset);
 }
 
+static llvm::Value *GetCXXBaseClassOffset(CodeGenFunction &CGF,
+                                          llvm::Value *BaseValue,
+                                          const CXXRecordDecl *ClassDecl,
+                                          const CXXRecordDecl *BaseClassDecl) {
+  CXXBasePaths Paths(/*FindAmbiguities=*/false,
+                     /*RecordPaths=*/true, /*DetectVirtual=*/true);
+  if (!const_cast<CXXRecordDecl *>(ClassDecl)->
+        isDerivedFrom(const_cast<CXXRecordDecl *>(BaseClassDecl), Paths)) {
+    assert(false && "Class must be derived from the passed in base class!");
+    return 0;
+  }
+  // Result: virtual part plus constant non-virtual part, or 0 if both are absent.
+  unsigned Start = 0;  // stays 0 when no virtual base is detected
+  llvm::Value *VirtualOffset = 0;
+  if (const RecordType *RT = Paths.getDetectedVirtual()) {
+    const CXXRecordDecl *VBase = cast<CXXRecordDecl>(RT->getDecl());
+    // The virtual-base part is delegated to CGF and comes back as a Value.
+    VirtualOffset = 
+      CGF.GetVirtualCXXBaseClassOffset(BaseValue, ClassDecl, VBase);
+    // Advance Start to the first path element whose Class is VBase.
+    const CXXBasePath &Path = Paths.front();
+    unsigned e = Path.size();
+    for (Start = 0; Start != e; ++Start) {
+      const CXXBasePathElement& Element = Path[Start];
+      // No match leaves Start == e, which makes the sum below empty.
+      if (Element.Class == VBase)
+        break;
+    }
+  }
+  // NOTE(review): only the base from getDetectedVirtual() is treated as virtual - confirm.
+  uint64_t Offset = 
+    ComputeNonVirtualBaseClassOffset(CGF.getContext(), Paths, Start);
+  // A zero constant offset leaves only the (possibly null) virtual part.
+  if (!Offset)
+    return VirtualOffset;
+  // Build the constant using the converted pointer-difference type.
+  const llvm::Type *PtrDiffTy = 
+    CGF.ConvertType(CGF.getContext().getPointerDiffType());
+  llvm::Value *NonVirtualOffset = llvm::ConstantInt::get(PtrDiffTy, Offset);
+  // Both parts present: emit an add of the two offsets.
+  if (VirtualOffset)
+    return CGF.Builder.CreateAdd(VirtualOffset, NonVirtualOffset);
+  // Otherwise only the constant part exists.
+  return NonVirtualOffset;
+}
+
 llvm::Value *
 CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue,
                                           const CXXRecordDecl *ClassDecl,
                                           const CXXRecordDecl *BaseClassDecl,
                                           bool NullCheckValue) {
-  llvm::Constant *Offset = CGM.GetCXXBaseClassOffset(ClassDecl, BaseClassDecl);
-  
   QualType BTy =
     getContext().getCanonicalType(
       getContext().getTypeDeclType(const_cast<CXXRecordDecl*>(BaseClassDecl)));
   const llvm::Type *BasePtrTy = llvm::PointerType::getUnqual(ConvertType(BTy));
 
-  if (!Offset) {
+  if (ClassDecl == BaseClassDecl) {
     // Just cast back.
     return Builder.CreateBitCast(BaseValue, BasePtrTy);
   }
@@ -125,10 +146,15 @@
   
   const llvm::Type *Int8PtrTy = 
     llvm::PointerType::getUnqual(llvm::Type::getInt8Ty(VMContext));
+
+  llvm::Value *Offset = 
+    GetCXXBaseClassOffset(*this, BaseValue, ClassDecl, BaseClassDecl);
   
-  // Apply the offset.
-  BaseValue = Builder.CreateBitCast(BaseValue, Int8PtrTy);
-  BaseValue = Builder.CreateGEP(BaseValue, Offset, "add.ptr");
+  if (Offset) {
+    // Apply the offset.
+    BaseValue = Builder.CreateBitCast(BaseValue, Int8PtrTy);
+    BaseValue = Builder.CreateGEP(BaseValue, Offset, "add.ptr");
+  }
   
   // Cast back.
   BaseValue = Builder.CreateBitCast(BaseValue, BasePtrTy);