Handle base-to-derived casts. Will land test case shortly.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@89678 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/CGCXXClass.cpp b/lib/CodeGen/CGCXXClass.cpp
index 533aabc..e122b95 100644
--- a/lib/CodeGen/CGCXXClass.cpp
+++ b/lib/CodeGen/CGCXXClass.cpp
@@ -117,10 +117,10 @@
 }
 
 llvm::Value *
-CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue,
-                                          const CXXRecordDecl *ClassDecl,
-                                          const CXXRecordDecl *BaseClassDecl,
-                                          bool NullCheckValue) {
+CodeGenFunction::GetAddressOfBaseClass(llvm::Value *Value,
+                                       const CXXRecordDecl *ClassDecl,
+                                       const CXXRecordDecl *BaseClassDecl,
+                                       bool NullCheckValue) {
   QualType BTy =
     getContext().getCanonicalType(
       getContext().getTypeDeclType(const_cast<CXXRecordDecl*>(BaseClassDecl)));
@@ -128,7 +128,7 @@
 
   if (ClassDecl == BaseClassDecl) {
     // Just cast back.
-    return Builder.CreateBitCast(BaseValue, BasePtrTy);
+    return Builder.CreateBitCast(Value, BasePtrTy);
   }
   
   llvm::BasicBlock *CastNull = 0;
@@ -141,8 +141,8 @@
     CastEnd = createBasicBlock("cast.end");
     
     llvm::Value *IsNull = 
-      Builder.CreateICmpEQ(BaseValue,
-                           llvm::Constant::getNullValue(BaseValue->getType()));
+      Builder.CreateICmpEQ(Value,
+                           llvm::Constant::getNullValue(Value->getType()));
     Builder.CreateCondBr(IsNull, CastNull, CastNotNull);
     EmitBlock(CastNotNull);
   }
@@ -150,16 +150,16 @@
   const llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(VMContext);
 
   llvm::Value *Offset = 
-    GetCXXBaseClassOffset(*this, BaseValue, ClassDecl, BaseClassDecl);
+    GetCXXBaseClassOffset(*this, Value, ClassDecl, BaseClassDecl);
   
   if (Offset) {
     // Apply the offset.
-    BaseValue = Builder.CreateBitCast(BaseValue, Int8PtrTy);
-    BaseValue = Builder.CreateGEP(BaseValue, Offset, "add.ptr");
+    Value = Builder.CreateBitCast(Value, Int8PtrTy);
+    Value = Builder.CreateGEP(Value, Offset, "add.ptr");
   }
   
   // Cast back.
-  BaseValue = Builder.CreateBitCast(BaseValue, BasePtrTy);
+  Value = Builder.CreateBitCast(Value, BasePtrTy);
  
   if (NullCheckValue) {
     Builder.CreateBr(CastEnd);
@@ -167,13 +167,73 @@
     Builder.CreateBr(CastEnd);
     EmitBlock(CastEnd);
     
-    llvm::PHINode *PHI = Builder.CreatePHI(BaseValue->getType());
+    llvm::PHINode *PHI = Builder.CreatePHI(Value->getType());
     PHI->reserveOperandSpace(2);
-    PHI->addIncoming(BaseValue, CastNotNull);
-    PHI->addIncoming(llvm::Constant::getNullValue(BaseValue->getType()), 
+    PHI->addIncoming(Value, CastNotNull);
+    PHI->addIncoming(llvm::Constant::getNullValue(Value->getType()), 
                      CastNull);
-    BaseValue = PHI;
+    Value = PHI;
   }
   
-  return BaseValue;
+  return Value;
+}
+
+llvm::Value *
+CodeGenFunction::GetAddressOfDerivedClass(llvm::Value *Value,
+                                          const CXXRecordDecl *ClassDecl,
+                                          const CXXRecordDecl *DerivedClassDecl,
+                                          bool NullCheckValue) {
+  QualType DerivedTy =
+    getContext().getCanonicalType(
+    getContext().getTypeDeclType(const_cast<CXXRecordDecl*>(DerivedClassDecl)));
+  const llvm::Type *DerivedPtrTy = ConvertType(DerivedTy)->getPointerTo();
+  
+  if (ClassDecl == DerivedClassDecl) {
+    // Same class: no adjustment is needed, so just bitcast the pointer.
+    return Builder.CreateBitCast(Value, DerivedPtrTy);
+  }
+
+  llvm::BasicBlock *CastNull = 0;
+  llvm::BasicBlock *CastNotNull = 0;
+  llvm::BasicBlock *CastEnd = 0;
+  
+  if (NullCheckValue) {
+    CastNull = createBasicBlock("cast.null");
+    CastNotNull = createBasicBlock("cast.notnull");
+    CastEnd = createBasicBlock("cast.end");
+    
+    llvm::Value *IsNull = 
+    Builder.CreateICmpEQ(Value,
+                         llvm::Constant::getNullValue(Value->getType()));
+    Builder.CreateCondBr(IsNull, CastNull, CastNotNull);
+    EmitBlock(CastNotNull);
+  }
+  
+  llvm::Value *Offset = GetCXXBaseClassOffset(*this, Value, DerivedClassDecl,
+                                              ClassDecl);
+  if (Offset) {
+    // Subtract the offset in integer space, then convert back to a pointer.
+    Value = Builder.CreatePtrToInt(Value, Offset->getType());
+    Value = Builder.CreateSub(Value, Offset);
+    Value = Builder.CreateIntToPtr(Value, DerivedPtrTy);
+  } else {
+    // No offset was computed; a plain bitcast is enough.
+    Value = Builder.CreateBitCast(Value, DerivedPtrTy);
+  }
+
+  if (NullCheckValue) {
+    Builder.CreateBr(CastEnd);
+    EmitBlock(CastNull);
+    Builder.CreateBr(CastEnd);
+    EmitBlock(CastEnd);
+    
+    llvm::PHINode *PHI = Builder.CreatePHI(Value->getType());
+    PHI->reserveOperandSpace(2);
+    PHI->addIncoming(Value, CastNotNull);
+    PHI->addIncoming(llvm::Constant::getNullValue(Value->getType()), 
+                     CastNull);
+    Value = PHI;
+  }
+  
+  return Value;
 }