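Change CodeGenFunction::GetAddressOfDerivedClass to take a BasePath.

GetAddressOfDerivedClass now receives the cast's CXXBaseSpecifierArray
instead of the base class declaration, asserts that the path is non-empty,
and no longer special-cases Class == DerivedClass. A new
CodeGenModule::GetNonVirtualBaseClassOffset overload computes the offset
from a BasePath (returning null when the offset is zero). The
CK_BaseToDerived callers in CGExpr.cpp and CGExprScalar.cpp are updated to
pass getBasePath().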

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@102273 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/CGClass.cpp b/lib/CodeGen/CGClass.cpp
index b8f7df9..6dea9f7 100644
--- a/lib/CodeGen/CGClass.cpp
+++ b/lib/CodeGen/CGClass.cpp
@@ -74,6 +74,23 @@
 }
 
 llvm::Constant *
+CodeGenModule::GetNonVirtualBaseClassOffset(const CXXRecordDecl *ClassDecl,
+                                        const CXXBaseSpecifierArray &BasePath) {
+  assert(!BasePath.empty() && "Base path should not be empty!");
+
+  uint64_t Offset = 
+    ComputeNonVirtualBaseClassOffset(getContext(), ClassDecl, 
+                                     BasePath.begin(), BasePath.end());
+  if (!Offset)
+    return 0;
+  
+  const llvm::Type *PtrDiffTy = 
+  Types.ConvertType(getContext().getPointerDiffType());
+  
+  return llvm::ConstantInt::get(PtrDiffTy, Offset);
+}  
+
+llvm::Constant *
 CodeGenModule::GetNonVirtualBaseClassOffset(const CXXRecordDecl *Class,
                                             const CXXRecordDecl *BaseClass) {
   if (Class == BaseClass)
@@ -336,21 +353,18 @@
 
 llvm::Value *
 CodeGenFunction::GetAddressOfDerivedClass(llvm::Value *Value,
-                                          const CXXRecordDecl *Class,
                                           const CXXRecordDecl *DerivedClass,
+                                          const CXXBaseSpecifierArray &BasePath,
                                           bool NullCheckValue) {
+  assert(!BasePath.empty() && "Base path should not be empty!");
+
   QualType DerivedTy =
     getContext().getCanonicalType(
     getContext().getTypeDeclType(const_cast<CXXRecordDecl*>(DerivedClass)));
   const llvm::Type *DerivedPtrTy = ConvertType(DerivedTy)->getPointerTo();
   
-  if (Class == DerivedClass) {
-    // Just cast back.
-    return Builder.CreateBitCast(Value, DerivedPtrTy);
-  }
-
   llvm::Value *NonVirtualOffset =
-    CGM.GetNonVirtualBaseClassOffset(DerivedClass, Class);
+    CGM.GetNonVirtualBaseClassOffset(DerivedClass, BasePath);
   
   if (!NonVirtualOffset) {
     // No offset, we can just cast back.
diff --git a/lib/CodeGen/CGExpr.cpp b/lib/CodeGen/CGExpr.cpp
index 62b5a3d..5c9374d 100644
--- a/lib/CodeGen/CGExpr.cpp
+++ b/lib/CodeGen/CGExpr.cpp
@@ -1693,11 +1693,6 @@
   case CastExpr::CK_ToUnion:
     return EmitAggExprToLValue(E);
   case CastExpr::CK_BaseToDerived: {
-    const RecordType *BaseClassTy = 
-      E->getSubExpr()->getType()->getAs<RecordType>();
-    CXXRecordDecl *BaseClassDecl = 
-      cast<CXXRecordDecl>(BaseClassTy->getDecl());
-    
     const RecordType *DerivedClassTy = E->getType()->getAs<RecordType>();
     CXXRecordDecl *DerivedClassDecl = 
       cast<CXXRecordDecl>(DerivedClassTy->getDecl());
@@ -1706,8 +1701,8 @@
     
     // Perform the base-to-derived conversion
     llvm::Value *Derived = 
-      GetAddressOfDerivedClass(LV.getAddress(), BaseClassDecl, 
-                               DerivedClassDecl, /*NullCheckValue=*/false);
+      GetAddressOfDerivedClass(LV.getAddress(), DerivedClassDecl, 
+                               E->getBasePath(),/*NullCheckValue=*/false);
     
     return LValue::MakeAddr(Derived, MakeQualifiers(E->getType()));
   }
diff --git a/lib/CodeGen/CGExprScalar.cpp b/lib/CodeGen/CGExprScalar.cpp
index f38126b..ad072c6 100644
--- a/lib/CodeGen/CGExprScalar.cpp
+++ b/lib/CodeGen/CGExprScalar.cpp
@@ -822,16 +822,12 @@
     return Visit(const_cast<Expr*>(E));
 
   case CastExpr::CK_BaseToDerived: {
-    const CXXRecordDecl *BaseClassDecl = 
-      E->getType()->getCXXRecordDeclForPointerType();
     const CXXRecordDecl *DerivedClassDecl = 
       DestTy->getCXXRecordDeclForPointerType();
     
-    Value *Src = Visit(const_cast<Expr*>(E));
-    
-    bool NullCheckValue = ShouldNullCheckClassCastValue(CE);
-    return CGF.GetAddressOfDerivedClass(Src, BaseClassDecl, DerivedClassDecl, 
-                                        NullCheckValue);
+    return CGF.GetAddressOfDerivedClass(Visit(E), DerivedClassDecl, 
+                                        CE->getBasePath(), 
+                                        ShouldNullCheckClassCastValue(CE));
   }
   case CastExpr::CK_UncheckedDerivedToBase:
   case CastExpr::CK_DerivedToBase: {
diff --git a/lib/CodeGen/CodeGenFunction.h b/lib/CodeGen/CodeGenFunction.h
index 0707ada..c0786b8 100644
--- a/lib/CodeGen/CodeGenFunction.h
+++ b/lib/CodeGen/CodeGenFunction.h
@@ -793,8 +793,8 @@
                                      bool NullCheckValue);
 
   llvm::Value *GetAddressOfDerivedClass(llvm::Value *Value,
-                                        const CXXRecordDecl *ClassDecl,
                                         const CXXRecordDecl *DerivedClassDecl,
+                                        const CXXBaseSpecifierArray &BasePath,
                                         bool NullCheckValue);
 
   llvm::Value *GetVirtualBaseClassOffset(llvm::Value *This,
diff --git a/lib/CodeGen/CodeGenModule.h b/lib/CodeGen/CodeGenModule.h
index 238b36d..c324d61 100644
--- a/lib/CodeGen/CodeGenModule.h
+++ b/lib/CodeGen/CodeGenModule.h
@@ -245,7 +245,10 @@
   llvm::Constant *
   GetNonVirtualBaseClassOffset(const CXXRecordDecl *ClassDecl,
                                const CXXRecordDecl *BaseClassDecl);
-
+  llvm::Constant *
+  GetNonVirtualBaseClassOffset(const CXXRecordDecl *ClassDecl,
+                               const CXXBaseSpecifierArray &BasePath);
+  
   /// GetStringForStringLiteral - Return the appropriate bytes for a string
   /// literal, properly padded to match the literal type. If only the address of
   /// a constant is needed consider using GetAddrOfConstantStringLiteral.