Fix vbtable indices when a class shares the vbptr with a non-virtual base

llvm-svn: 194082
diff --git a/clang/lib/AST/VTableBuilder.cpp b/clang/lib/AST/VTableBuilder.cpp
index c2b33e6..69a2fb2 100644
--- a/clang/lib/AST/VTableBuilder.cpp
+++ b/clang/lib/AST/VTableBuilder.cpp
@@ -2397,18 +2397,6 @@
   return CreateVTableLayout(Builder);
 }
 
-unsigned clang::GetVBTableIndex(const CXXRecordDecl *Derived,
-                                const CXXRecordDecl *VBase) {
-  unsigned VBTableIndex = 1; // Start with one to skip the self entry.
-  for (CXXRecordDecl::base_class_const_iterator I = Derived->vbases_begin(),
-       E = Derived->vbases_end(); I != E; ++I) {
-    if (I->getType()->getAsCXXRecordDecl() == VBase)
-      return VBTableIndex;
-    ++VBTableIndex;
-  }
-  llvm_unreachable("VBase must be a vbase of Derived");
-}
-
 namespace {
 
 // Vtables in the Microsoft ABI are different from the Itanium ABI.
@@ -2451,12 +2439,15 @@
 
 class VFTableBuilder {
 public:
-  typedef MicrosoftVFTableContext::MethodVFTableLocation MethodVFTableLocation;
+  typedef MicrosoftVTableContext::MethodVFTableLocation MethodVFTableLocation;
 
   typedef llvm::DenseMap<GlobalDecl, MethodVFTableLocation>
     MethodVFTableLocationsTy;
 
 private:
+  /// VTables - Global vtable information, used to look up vbtable indices.
+  MicrosoftVTableContext &VTables;
+
   /// Context - The ASTContext which we will use for layout information.
   ASTContext &Context;
 
@@ -2591,8 +2582,10 @@
   }
 
 public:
-  VFTableBuilder(const CXXRecordDecl *MostDerivedClass, VFPtrInfo Which)
-      : Context(MostDerivedClass->getASTContext()),
+  VFTableBuilder(MicrosoftVTableContext &VTables,
+                 const CXXRecordDecl *MostDerivedClass, VFPtrInfo Which)
+      : VTables(VTables),
+        Context(MostDerivedClass->getASTContext()),
         MostDerivedClass(MostDerivedClass),
         MostDerivedClassLayout(Context.getASTRecordLayout(MostDerivedClass)),
         WhichVFPtr(Which),
@@ -2889,7 +2882,7 @@
     // If we got here, MD is a method not seen in any of the sub-bases or
     // it requires return adjustment. Insert the method info for this method.
     unsigned VBIndex =
-        LastVBase ? GetVBTableIndex(MostDerivedClass, LastVBase) : 0;
+        LastVBase ? VTables.getVBTableIndex(MostDerivedClass, LastVBase) : 0;
     MethodInfo MI(VBIndex, Components.size());
 
     assert(!MethodInfoMap.count(MD) &&
@@ -2916,8 +2909,8 @@
         ReturnAdjustment.Virtual.Microsoft.VBPtrOffset =
             DerivedLayout.getVBPtrOffset().getQuantity();
         ReturnAdjustment.Virtual.Microsoft.VBIndex =
-            GetVBTableIndex(ReturnAdjustmentOffset.DerivedClass,
-                            ReturnAdjustmentOffset.VirtualBase);
+            VTables.getVBTableIndex(ReturnAdjustmentOffset.DerivedClass,
+                                    ReturnAdjustmentOffset.VirtualBase);
       }
     }
 
@@ -3087,13 +3080,13 @@
 }
 }
 
-static void EnumerateVFPtrs(
-    ASTContext &Context, const CXXRecordDecl *MostDerivedClass,
-    const ASTRecordLayout &MostDerivedClassLayout,
-    BaseSubobject Base, const CXXRecordDecl *LastVBase,
+void MicrosoftVTableContext::enumerateVFPtrs(
+    const CXXRecordDecl *MostDerivedClass,
+    const ASTRecordLayout &MostDerivedClassLayout, BaseSubobject Base,
+    const CXXRecordDecl *LastVBase,
     const VFPtrInfo::BasePath &PathFromCompleteClass,
     BasesSetVectorTy &VisitedVBases,
-    MicrosoftVFTableContext::VFPtrListTy &Result) {
+    VFPtrListTy &Result) {
   const CXXRecordDecl *CurrentClass = Base.getBase();
   CharUnits OffsetInCompleteClass = Base.getBaseOffset();
   const ASTRecordLayout &CurrentClassLayout =
@@ -3101,7 +3094,7 @@
 
   if (CurrentClassLayout.hasOwnVFPtr()) {
     if (LastVBase) {
-      uint64_t VBIndex = GetVBTableIndex(MostDerivedClass, LastVBase);
+      uint64_t VBIndex = getVBTableIndex(MostDerivedClass, LastVBase);
       assert(VBIndex > 0 && "vbases must have vbindex!");
       CharUnits VFPtrOffset =
           OffsetInCompleteClass -
@@ -3134,7 +3127,7 @@
     NewPath.push_back(BaseDecl);
     BaseSubobject NextBase(BaseDecl, NextBaseOffset);
 
-    EnumerateVFPtrs(Context, MostDerivedClass, MostDerivedClassLayout, NextBase,
+    enumerateVFPtrs(MostDerivedClass, MostDerivedClassLayout, NextBase,
                     NextLastVBase, NewPath, VisitedVBases, Result);
   }
 }
@@ -3188,12 +3181,13 @@
   }
 }
 
-static void EnumerateVFPtrs(ASTContext &Context, const CXXRecordDecl *ForClass,
-                            MicrosoftVFTableContext::VFPtrListTy &Result) {
+void MicrosoftVTableContext::enumerateVFPtrs(
+    const CXXRecordDecl *ForClass,
+    VFPtrListTy &Result) {
   Result.clear();
   const ASTRecordLayout &ClassLayout = Context.getASTRecordLayout(ForClass);
   BasesSetVectorTy VisitedVBases;
-  EnumerateVFPtrs(Context, ForClass, ClassLayout,
+  enumerateVFPtrs(ForClass, ClassLayout,
                   BaseSubobject(ForClass, CharUnits::Zero()), 0,
                   VFPtrInfo::BasePath(), VisitedVBases, Result);
   if (Result.size() > 1) {
@@ -3202,7 +3196,7 @@
   }
 }
 
-void MicrosoftVFTableContext::computeVTableRelatedInformation(
+void MicrosoftVTableContext::computeVTableRelatedInformation(
     const CXXRecordDecl *RD) {
   assert(RD->isDynamicClass());
 
@@ -3213,12 +3207,12 @@
   const VTableLayout::AddressPointsMapTy EmptyAddressPointsMap;
 
   VFPtrListTy &VFPtrs = VFPtrLocations[RD];
-  EnumerateVFPtrs(Context, RD, VFPtrs);
+  enumerateVFPtrs(RD, VFPtrs);
 
   MethodVFTableLocationsTy NewMethodLocations;
   for (VFPtrListTy::iterator I = VFPtrs.begin(), E = VFPtrs.end();
        I != E; ++I) {
-    VFTableBuilder Builder(RD, *I);
+    VFTableBuilder Builder(*this, RD, *I);
 
     VFTableIdTy id(RD, I->VFPtrFullOffset);
     assert(VFTableLayouts.count(id) == 0);
@@ -3238,7 +3232,7 @@
     dumpMethodLocations(RD, NewMethodLocations, llvm::errs());
 }
 
-void MicrosoftVFTableContext::dumpMethodLocations(
+void MicrosoftVTableContext::dumpMethodLocations(
     const CXXRecordDecl *RD, const MethodVFTableLocationsTy &NewMethods,
     raw_ostream &Out) {
   // Compute the vtable indices for all the member functions.
@@ -3297,8 +3291,56 @@
   }
 }
 
-const MicrosoftVFTableContext::VFPtrListTy &
-MicrosoftVFTableContext::getVFPtrOffsets(const CXXRecordDecl *RD) {
+// Computes the vbtable index of each vbase of RD, memoized per class.
+void MicrosoftVTableContext::computeVBTableRelatedInformation(
+    const CXXRecordDecl *RD) {
+  if (ComputedVBTableIndices.count(RD))
+    return;
+  ComputedVBTableIndices.insert(RD);
+
+  const ASTRecordLayout &Layout = Context.getASTRecordLayout(RD);
+  CharUnits DerivedVBPtrOffset = Layout.getVBPtrOffset();
+  BasesSetVectorTy VisitedBases;
+  // First, see if the Derived class shares the vbptr with one of its
+  // non-virtual bases, i.e. a base whose vbptr sits at the same offset.
+  for (CXXRecordDecl::base_class_const_iterator I = RD->bases_begin(),
+       E = RD->bases_end(); I != E; ++I) {
+    if (I->isVirtual())
+      continue;
+
+    const CXXRecordDecl *CurBase = I->getType()->getAsCXXRecordDecl();
+    CharUnits BaseOffset = Layout.getBaseClassOffset(CurBase);
+    const ASTRecordLayout &BaseLayout = Context.getASTRecordLayout(CurBase);
+    if (!BaseLayout.hasVBPtr() ||
+        DerivedVBPtrOffset != BaseOffset + BaseLayout.getVBPtrOffset())
+      continue;
+
+    // Sharing the vbptr means inheriting the base's vbase indices. Copy each
+    // index via a local: growing a map like DenseMap invalidates references.
+    computeVBTableRelatedInformation(CurBase);
+    for (CXXRecordDecl::base_class_const_iterator J = CurBase->vbases_begin(),
+         F = CurBase->vbases_end(); J != F; ++J) {
+      const CXXRecordDecl *SubVBase = J->getType()->getAsCXXRecordDecl();
+      assert(VBTableIndices.count(ClassPairTy(CurBase, SubVBase)));
+      unsigned SubVBaseIndex = VBTableIndices[ClassPairTy(CurBase, SubVBase)];
+      VBTableIndices[ClassPairTy(RD, SubVBase)] = SubVBaseIndex;
+      VisitedBases.insert(SubVBase);
+    }
+  }
+
+  // New vbases are added to the end of the vbtable.
+  // Skip the self entry and vbases visited in the non-virtual base, if any.
+  unsigned VBTableIndex = 1 + VisitedBases.size();
+  for (CXXRecordDecl::base_class_const_iterator I = RD->vbases_begin(),
+       E = RD->vbases_end(); I != E; ++I) {
+    const CXXRecordDecl *CurVBase = I->getType()->getAsCXXRecordDecl();
+    if (VisitedBases.insert(CurVBase))
+      VBTableIndices[ClassPairTy(RD, CurVBase)] = VBTableIndex++;
+  }
+}
+
+const MicrosoftVTableContext::VFPtrListTy &
+MicrosoftVTableContext::getVFPtrOffsets(const CXXRecordDecl *RD) {
   computeVTableRelatedInformation(RD);
 
   assert(VFPtrLocations.count(RD) && "Couldn't find vfptr locations");
@@ -3306,8 +3348,8 @@
 }
 
 const VTableLayout &
-MicrosoftVFTableContext::getVFTableLayout(const CXXRecordDecl *RD,
-                                          CharUnits VFPtrOffset) {
+MicrosoftVTableContext::getVFTableLayout(const CXXRecordDecl *RD,
+                                         CharUnits VFPtrOffset) {
   computeVTableRelatedInformation(RD);
 
   VFTableIdTy id(RD, VFPtrOffset);
@@ -3315,8 +3357,8 @@
   return *VFTableLayouts[id];
 }
 
-const MicrosoftVFTableContext::MethodVFTableLocation &
-MicrosoftVFTableContext::getMethodVFTableLocation(GlobalDecl GD) {
+const MicrosoftVTableContext::MethodVFTableLocation &
+MicrosoftVTableContext::getMethodVFTableLocation(GlobalDecl GD) {
   assert(cast<CXXMethodDecl>(GD.getDecl())->isVirtual() &&
          "Only use this method for virtual methods or dtors");
   if (isa<CXXDestructorDecl>(GD.getDecl()))