Cleanups and fixups for calculating the virtual base offsets.  WIP.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@79156 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/CGCXX.cpp b/lib/CodeGen/CGCXX.cpp
index f2e58af..1c9c563 100644
--- a/lib/CodeGen/CGCXX.cpp
+++ b/lib/CodeGen/CGCXX.cpp
@@ -680,6 +680,7 @@
 }
 
 void CodeGenFunction::GenerateVtableForVBases(const CXXRecordDecl *RD,
+                                              const CXXRecordDecl *Class,
                                               llvm::Constant *rtti,
                                          std::vector<llvm::Constant *> &methods,
                    llvm::SmallSet<const CXXRecordDecl *, 32> &IndirectPrimary) {
@@ -690,19 +691,40 @@
     if (i->isVirtual() && !IndirectPrimary.count(Base)) {
       // Mark it so we don't output it twice.
       IndirectPrimary.insert(Base);
-      GenerateVtableForBase(Base, RD, rtti, methods, false, true,
+      GenerateVtableForBase(Base, true, 0, Class, rtti, methods, true,
                             IndirectPrimary);
     }
     if (Base->getNumVBases())
-      GenerateVtableForVBases(Base, rtti, methods, IndirectPrimary);
+      GenerateVtableForVBases(Base, Class, rtti, methods, IndirectPrimary);
+  }
+}
+
+void CodeGenFunction::GenerateVBaseOffsets( // append vbase offset slots
+  std::vector<llvm::Constant *> &methods, const CXXRecordDecl *RD,
+  llvm::SmallSet<const CXXRecordDecl *, 32> &SeenVBase, // dedups vbases
+  uint64_t Offset, const ASTRecordLayout &BLayout, llvm::Type *Ptr8Ty) {
+  for (CXXRecordDecl::base_class_const_iterator i =RD->bases_begin(),
+         e = RD->bases_end(); i != e; ++i) { // depth-first over all bases
+    const CXXRecordDecl *Base = 
+      cast<CXXRecordDecl>(i->getType()->getAs<RecordType>()->getDecl());
+    if (i->isVirtual() && !SeenVBase.count(Base)) { // emit each vbase once
+      SeenVBase.insert(Base);
+      int64_t BaseOffset = Offset/8 + BLayout.getVBaseClassOffset(Base) / 8;
+      llvm::Constant *m; // offsets look like bits; /8 -> bytes
+      m = llvm::ConstantInt::get(llvm::Type::getInt64Ty(VMContext), BaseOffset);
+      m = llvm::ConstantExpr::getIntToPtr(m, Ptr8Ty); // slots are i8*
+      methods.push_back(m); // reversed by GenerateVtableForBase
+    } // NOTE(review): recursion reuses caller's BLayout; confirm intent
+    GenerateVBaseOffsets(methods, Base, SeenVBase, Offset, BLayout, Ptr8Ty);
   }
 }
 
 void CodeGenFunction::GenerateVtableForBase(const CXXRecordDecl *RD,
+                                            bool forPrimary,
+                                            int64_t Offset,
                                             const CXXRecordDecl *Class,
                                             llvm::Constant *rtti,
                                          std::vector<llvm::Constant *> &methods,
-                                            bool isPrimary,
                                             bool ForVirtualBase,
                    llvm::SmallSet<const CXXRecordDecl *, 32> &IndirectPrimary) {
   llvm::Type *Ptr8Ty;
@@ -712,69 +734,70 @@
   if (RD && !RD->isDynamicClass())
     return;
 
-  const ASTRecordLayout &Layout = getContext().getASTRecordLayout(Class);
+  const ASTRecordLayout &Layout = getContext().getASTRecordLayout(RD);
+  const CXXRecordDecl *PrimaryBase = Layout.getPrimaryBase(); 
+  const bool PrimaryBaseWasVirtual = Layout.getPrimaryBaseWasVirtual();
 
-  if (isPrimary) {
-    // The virtual base offsets come first...
-    // FIXME: audit
-    for (CXXRecordDecl::reverse_base_class_const_iterator i
-           = Class->bases_rbegin(),
-           e = Class->bases_rend(); i != e; ++i) {
-      if (!i->isVirtual())
-        continue;
-      const CXXRecordDecl *Base = 
-        cast<CXXRecordDecl>(i->getType()->getAs<RecordType>()->getDecl());
-      int64_t BaseOffset = Layout.getVBaseClassOffset(Base) / 8;
-      llvm::Constant *m;
-      m = llvm::ConstantInt::get(llvm::Type::getInt64Ty(VMContext), BaseOffset);
-      m = llvm::ConstantExpr::getIntToPtr(m, Ptr8Ty);
-      methods.push_back(m);
-    }
+  // The virtual base offsets come first...
+  // FIXME: Audit, is this right?
+  if (forPrimary || !PrimaryBaseWasVirtual) {
+    llvm::SmallSet<const CXXRecordDecl *, 32> SeenVBase;
+    std::vector<llvm::Constant *> offsets;
+    GenerateVBaseOffsets(offsets, RD, SeenVBase, Offset, Layout, Ptr8Ty);
+    for (std::vector<llvm::Constant *>::reverse_iterator i = offsets.rbegin(),
+           e = offsets.rend(); i != e; ++i)
+      methods.push_back(*i);
   }
   
-  // then comes the the vcall offsets for all our functions...
-  if (isPrimary && ForVirtualBase)
-    GenerateVcalls(methods, Class, Ptr8Ty);
-
-  bool TopPrimary = true;
-  // Primary tables are composed from the chain of primaries.
-  if (isPrimary) {
-    const CXXRecordDecl *PrimaryBase = Layout.getPrimaryBase(); 
-    const bool PrimaryBaseWasVirtual = Layout.getPrimaryBaseWasVirtual();
-    if (PrimaryBase) {
-      if (PrimaryBaseWasVirtual)
-        IndirectPrimary.insert(PrimaryBase);
-      TopPrimary = false;
-      GenerateVtableForBase(0, PrimaryBase, rtti, methods, true,
-                            PrimaryBaseWasVirtual, IndirectPrimary);
-    }
+  if (forPrimary || ForVirtualBase) {
+    // then come the vcall offsets for all our functions...
+    GenerateVcalls(methods, RD, Ptr8Ty);
   }
+
+  bool Top = true;
+
+  // vtables are composed from the chain of primaries.
+  if (PrimaryBase) {
+    if (PrimaryBaseWasVirtual)
+      IndirectPrimary.insert(PrimaryBase);
+    Top = false;
+    GenerateVtableForBase(PrimaryBase, true, Offset, Class, rtti, methods,
+                          PrimaryBaseWasVirtual, IndirectPrimary);
+  }
+
   // then come the vcall offsets for all our virtual bases.
-  if (!isPrimary && RD && ForVirtualBase)
+  if (!1 && ForVirtualBase)
     GenerateVcalls(methods, RD, Ptr8Ty);
 
-  if (TopPrimary) {
-    if (RD) {
-      int64_t BaseOffset;
-      if (ForVirtualBase)
-        BaseOffset = -(Layout.getVBaseClassOffset(RD) / 8);
-      else
-        BaseOffset = -(Layout.getBaseClassOffset(RD) / 8);
-      m = llvm::ConstantInt::get(llvm::Type::getInt64Ty(VMContext), BaseOffset);
-      m = llvm::ConstantExpr::getIntToPtr(m, Ptr8Ty);
-    }
+  if (Top) {
+    int64_t BaseOffset;
+    if (ForVirtualBase) {
+      const ASTRecordLayout &BLayout = getContext().getASTRecordLayout(Class);
+      BaseOffset = -(BLayout.getVBaseClassOffset(RD) / 8);
+    } else
+      BaseOffset = -Offset/8;
+    m = llvm::ConstantInt::get(llvm::Type::getInt64Ty(VMContext), BaseOffset);
+    m = llvm::ConstantExpr::getIntToPtr(m, Ptr8Ty);
     methods.push_back(m);
     methods.push_back(rtti);
   }
 
-  if (!isPrimary) {
-    if (RD)
-      GenerateMethods(methods, RD, Ptr8Ty);
-    return;
-  }
-
   // And add the virtuals for the class to the primary vtable.
-  GenerateMethods(methods, Class, Ptr8Ty);
+  GenerateMethods(methods, RD, Ptr8Ty);
+
+  // and then the non-virtual bases.
+  for (CXXRecordDecl::base_class_const_iterator i = RD->bases_begin(),
+         e = RD->bases_end(); i != e; ++i) {
+    if (i->isVirtual())
+      continue;
+    const CXXRecordDecl *Base = 
+      cast<CXXRecordDecl>(i->getType()->getAs<RecordType>()->getDecl());
+    if (Base != PrimaryBase || PrimaryBaseWasVirtual) {
+      uint64_t o = Offset + Layout.getBaseClassOffset(Base);
+      GenerateVtableForBase(Base, true, o, Class, rtti, methods, false,
+                            IndirectPrimary);
+    }
+  }
 }
 
 llvm::Value *CodeGenFunction::GenerateVtable(const CXXRecordDecl *RD) {
@@ -787,36 +810,20 @@
   llvm::GlobalVariable::LinkageTypes linktype;
   linktype = llvm::GlobalValue::WeakAnyLinkage;
   std::vector<llvm::Constant *> methods;
-  llvm::Type *Ptr8Ty = llvm::PointerType::get(llvm::Type::getInt8Ty(VMContext), 0);
+  llvm::Type *Ptr8Ty=llvm::PointerType::get(llvm::Type::getInt8Ty(VMContext),0);
   int64_t Offset = 0;
   llvm::Constant *rtti = GenerateRtti(RD);
 
   Offset += LLVMPointerWidth;
   Offset += LLVMPointerWidth;
 
-  const ASTRecordLayout &Layout = getContext().getASTRecordLayout(RD);
-  const CXXRecordDecl *PrimaryBase = Layout.getPrimaryBase();
-  const bool PrimaryBaseWasVirtual = Layout.getPrimaryBaseWasVirtual();
   llvm::SmallSet<const CXXRecordDecl *, 32> IndirectPrimary;
 
-  // The primary base comes first.
-  GenerateVtableForBase(PrimaryBase, RD, rtti, methods, true,
-                        PrimaryBaseWasVirtual, IndirectPrimary);
+  // First comes the vtables for all the non-virtual bases...
+  GenerateVtableForBase(RD, true, 0, RD, rtti, methods, false, IndirectPrimary);
 
-  // Then come the non-virtual bases.
-  for (CXXRecordDecl::base_class_const_iterator i = RD->bases_begin(),
-         e = RD->bases_end(); i != e; ++i) {
-    if (i->isVirtual())
-      continue;
-    const CXXRecordDecl *Base = 
-      cast<CXXRecordDecl>(i->getType()->getAs<RecordType>()->getDecl());
-    if (Base != PrimaryBase || PrimaryBaseWasVirtual)
-      GenerateVtableForBase(Base, RD, rtti, methods, false, false,
-                            IndirectPrimary);
-  }
-
-  // Then come the vtables for all the virtual bases.
-  GenerateVtableForVBases(RD, rtti, methods, IndirectPrimary);
+  // then the vtables for all the virtual bases.
+  GenerateVtableForVBases(RD, RD, rtti, methods, IndirectPrimary);
 
   llvm::Constant *C;
   llvm::ArrayType *type = llvm::ArrayType::get(Ptr8Ty, methods.size());
@@ -825,7 +832,7 @@
                                                  linktype, C, Name);
   vtable = Builder.CreateBitCast(vtable, Ptr8Ty);
   vtable = Builder.CreateGEP(vtable,
-                             llvm::ConstantInt::get(llvm::Type::getInt64Ty(VMContext),
+                       llvm::ConstantInt::get(llvm::Type::getInt64Ty(VMContext),
                                                     Offset/8));
   return vtable;
 }
diff --git a/lib/CodeGen/CodeGenFunction.h b/lib/CodeGen/CodeGenFunction.h
index b4469c5..3110ac0 100644
--- a/lib/CodeGen/CodeGenFunction.h
+++ b/lib/CodeGen/CodeGenFunction.h
@@ -362,19 +362,26 @@
   void FinishFunction(SourceLocation EndLoc=SourceLocation());
 
   llvm::Constant *GenerateRtti(const CXXRecordDecl *RD);
+  void GenerateVBaseOffsets(std::vector<llvm::Constant *> &methods,
+                            const CXXRecordDecl *RD, 
+                           llvm::SmallSet<const CXXRecordDecl *, 32> &SeenVBase,
+                            uint64_t Offset,
+                            const ASTRecordLayout &Layout, llvm::Type *Ptr8Ty);
   void GenerateVcalls(std::vector<llvm::Constant *> &methods,
                       const CXXRecordDecl *RD, llvm::Type *Ptr8Ty);
   void GenerateMethods(std::vector<llvm::Constant *> &methods,
                        const CXXRecordDecl *RD, llvm::Type *Ptr8Ty);
-void GenerateVtableForVBases(const CXXRecordDecl *RD,
-                             llvm::Constant *rtti,
-                             std::vector<llvm::Constant *> &methods,
+  void GenerateVtableForVBases(const CXXRecordDecl *RD,
+                               const CXXRecordDecl *Class,
+                               llvm::Constant *rtti,
+                               std::vector<llvm::Constant *> &methods,
                     llvm::SmallSet<const CXXRecordDecl *, 32> &IndirectPrimary);
   void GenerateVtableForBase(const CXXRecordDecl *RD,
+                             bool ForPrimary,
+                             int64_t Offset,
                              const CXXRecordDecl *Class,
                              llvm::Constant *rtti,
                              std::vector<llvm::Constant *> &methods,
-                             bool isPrimary,
                              bool ForVirtualBase,
                     llvm::SmallSet<const CXXRecordDecl *, 32> &IndirectPrimary);
   llvm::Value *GenerateVtable(const CXXRecordDecl *RD);