Add a new GetAddressOfBaseClass overload that takes a base path, and use it for derived-to-base casts.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@102270 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/CGClass.cpp b/lib/CodeGen/CGClass.cpp
index 2b06e5e..991b6a7 100644
--- a/lib/CodeGen/CGClass.cpp
+++ b/lib/CodeGen/CGClass.cpp
@@ -20,6 +20,35 @@
 using namespace CodeGen;
 
 static uint64_t 
+ComputeNonVirtualBaseClassOffset(ASTContext &Context,
+                                 const CXXRecordDecl *DerivedClass,
+                                 CXXBaseSpecifierArray::iterator Start,
+                                 CXXBaseSpecifierArray::iterator End) {
+  // Sum the layout offsets along a purely non-virtual base path beginning
+  // at DerivedClass; each step's offset comes from the layout of the class
+  // that directly contains that base.
+  const CXXRecordDecl *Current = DerivedClass;
+  uint64_t Total = 0;
+
+  for (; Start != End; ++Start) {
+    const CXXBaseSpecifier *Spec = *Start;
+    assert(!Spec->isVirtual() && "Should not see virtual bases here!");
+
+    const RecordType *BaseTy = Spec->getType()->getAs<RecordType>();
+    const CXXRecordDecl *Next = cast<CXXRecordDecl>(BaseTy->getDecl());
+
+    const ASTRecordLayout &Layout = Context.getASTRecordLayout(Current);
+    Total += Layout.getBaseClassOffset(Next);
+
+    // Descend one level for the next step of the path.
+    Current = Next;
+  }
+
+  // FIXME: We should not use / 8 here.
+  return Total / 8;
+}
+                                 
+static uint64_t 
 ComputeNonVirtualBaseClassOffset(ASTContext &Context,
                                  const CXXBasePath &Path,
                                  unsigned Start) {
@@ -133,6 +162,98 @@
 }
 
+/// Adjust \p Value, a pointer to an object of class \p ClassDecl, so that it
+/// points at the base class subobject reached by following \p BasePath.
+///
+/// Only the first element of \p BasePath may name a virtual base; its offset
+/// is computed at run time via GetVirtualBaseClassOffset. The remaining
+/// (non-virtual) steps are summed into a single constant offset.
+///
+/// If \p NullCheckValue is true, a null \p Value yields a null result instead
+/// of being adjusted. The result is bitcast to a pointer to the last base.
 llvm::Value *
+CodeGenFunction::GetAddressOfBaseClass(llvm::Value *Value,
+                                       const CXXRecordDecl *ClassDecl,
+                                       const CXXBaseSpecifierArray &BasePath,
+                                       bool NullCheckValue) {
+  assert(!BasePath.empty() && "Base path should not be empty!");
+
+  CXXBaseSpecifierArray::iterator Start = BasePath.begin();
+  const CXXRecordDecl *VBase = 0;
+
+  // Only the first step of the path may be virtual; remember it and skip it
+  // so the rest of the path can be folded into a constant offset.
+  if ((*Start)->isVirtual()) {
+    VBase =
+      cast<CXXRecordDecl>((*Start)->getType()->getAs<RecordType>()->getDecl());
+    ++Start;
+  }
+
+  // Offsets of the non-virtual steps are relative to the virtual base, if
+  // any, otherwise to the derived class itself.
+  uint64_t NonVirtualOffset =
+    ComputeNonVirtualBaseClassOffset(getContext(), VBase ? VBase : ClassDecl,
+                                     Start, BasePath.end());
+
+  // Get the base pointer type.
+  const llvm::Type *BasePtrTy =
+    llvm::PointerType::getUnqual(ConvertType((BasePath.end()[-1])->getType()));
+
+  if (!NonVirtualOffset && !VBase) {
+    // No adjustment is needed, so a plain bitcast suffices (and maps null to
+    // null without an explicit check).
+    return Builder.CreateBitCast(Value, BasePtrTy);
+  }
+
+  llvm::BasicBlock *CastNull = 0;
+  llvm::BasicBlock *CastNotNull = 0;
+  llvm::BasicBlock *CastEnd = 0;
+
+  if (NullCheckValue) {
+    CastNull = createBasicBlock("cast.null");
+    CastNotNull = createBasicBlock("cast.notnull");
+    CastEnd = createBasicBlock("cast.end");
+
+    llvm::Value *IsNull =
+      Builder.CreateICmpEQ(Value,
+                           llvm::Constant::getNullValue(Value->getType()));
+    Builder.CreateCondBr(IsNull, CastNull, CastNotNull);
+    EmitBlock(CastNotNull);
+  }
+
+  llvm::Value *VirtualOffset = 0;
+
+  if (VBase)
+    VirtualOffset = GetVirtualBaseClassOffset(Value, ClassDecl, VBase);
+
+  // Apply the offsets.
+  Value = ApplyNonVirtualAndVirtualOffset(*this, Value, NonVirtualOffset,
+                                          VirtualOffset);
+
+  // Cast back.
+  Value = Builder.CreateBitCast(Value, BasePtrTy);
+
+  if (NullCheckValue) {
+    // If anything emitted above started a new basic block, CastNotNull is no
+    // longer the block that branches to CastEnd; use the current insertion
+    // block as the PHI's incoming edge so the IR stays well-formed.
+    CastNotNull = Builder.GetInsertBlock();
+    Builder.CreateBr(CastEnd);
+    EmitBlock(CastNull);
+    Builder.CreateBr(CastEnd);
+    EmitBlock(CastEnd);
+
+    llvm::PHINode *PHI = Builder.CreatePHI(Value->getType());
+    PHI->reserveOperandSpace(2);
+    PHI->addIncoming(Value, CastNotNull);
+    PHI->addIncoming(llvm::Constant::getNullValue(Value->getType()),
+                     CastNull);
+    Value = PHI;
+  }
+
+  return Value;
+}
+
+llvm::Value *
 CodeGenFunction::GetAddressOfBaseClass(llvm::Value *Value,
                                        const CXXRecordDecl *Class,
                                        const CXXRecordDecl *BaseClass,