When we can't prove that the target of an aggregate copy is
a complete object, the memcpy needs to use the data size of
the structure instead of its sizeof() value.  Fixes PR12204.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@153613 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/CGExprAgg.cpp b/lib/CodeGen/CGExprAgg.cpp
index b6efc1c..8f0d1db 100644
--- a/lib/CodeGen/CGExprAgg.cpp
+++ b/lib/CodeGen/CGExprAgg.cpp
@@ -179,7 +179,8 @@
 
   void VisitVAArgExpr(VAArgExpr *E);
 
-  void EmitInitializationToLValue(Expr *E, LValue Address);
+  void EmitInitializationToLValue(Expr *E, LValue Address,
+                          AggValueSlot::IsCompleteObject_t isCompleteObject);
   void EmitNullInitializationToLValue(LValue Address);
   //  case Expr::ChooseExprClass:
   void VisitCXXThrowExpr(const CXXThrowExpr *E) { CGF.EmitCXXThrowExpr(E); }
@@ -279,7 +280,7 @@
   // is volatile, unless copy has volatile for both source and destination..
   CGF.EmitAggregateCopy(Dest.getAddr(), Src.getAggregateAddr(), E->getType(),
                         Dest.isVolatile()|Src.isVolatileQualified(),
-                        Alignment);
+                        Alignment, Dest.isCompleteObject());
 }
 
 /// EmitFinalDestCopy - Perform the final copy to DestPtr, if desired.
@@ -441,7 +442,8 @@
       EmitStdInitializerList(element, initList);
     } else {
       LValue elementLV = CGF.MakeAddrLValue(element, elementType);
-      EmitInitializationToLValue(E->getInit(i), elementLV);
+      EmitInitializationToLValue(E->getInit(i), elementLV,
+                                 AggValueSlot::IsCompleteObject);
     }
   }
 
@@ -488,7 +490,8 @@
     // Emit the actual filler expression.
     LValue elementLV = CGF.MakeAddrLValue(currentElement, elementType);
     if (filler)
-      EmitInitializationToLValue(filler, elementLV);
+      EmitInitializationToLValue(filler, elementLV,
+                                 AggValueSlot::IsCompleteObject);
     else
       EmitNullInitializationToLValue(elementLV);
 
@@ -567,7 +570,8 @@
     llvm::Value *CastPtr = Builder.CreateBitCast(Dest.getAddr(),
                                                  CGF.ConvertType(PtrTy));
     EmitInitializationToLValue(E->getSubExpr(),
-                               CGF.MakeAddrLValue(CastPtr, Ty));
+                               CGF.MakeAddrLValue(CastPtr, Ty),
+                               Dest.isCompleteObject());
     break;
   }
 
@@ -675,6 +679,29 @@
   EmitFinalDestCopy(E, LV);
 }
 
+/// Quickly check whether the object looks like it might be a complete
+/// object.
+static AggValueSlot::IsCompleteObject_t isCompleteObject(const Expr *E) {
+  E = E->IgnoreParens();
+
+  QualType objectType;
+  if (const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E)) {
+    objectType = DRE->getDecl()->getType();
+  } else if (const MemberExpr *ME = dyn_cast<MemberExpr>(E)) {
+    objectType = ME->getMemberDecl()->getType();
+  } else {
+    // Be conservative.
+    return AggValueSlot::MayNotBeCompleteObject;
+  }
+
+  // The expression refers directly to some sort of object.
+  // If that object has reference type, be conservative.
+  if (objectType->isReferenceType())
+    return AggValueSlot::MayNotBeCompleteObject;
+
+  return AggValueSlot::IsCompleteObject;
+}
+
 void AggExprEmitter::VisitBinAssign(const BinaryOperator *E) {
   // For an assignment to work, the value on the right has
   // to be compatible with the value on the left.
@@ -682,7 +709,8 @@
                                                  E->getRHS()->getType())
          && "Invalid assignment");
 
-  if (const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->getLHS()))
+  if (const DeclRefExpr *DRE
+        = dyn_cast<DeclRefExpr>(E->getLHS()->IgnoreParens()))
     if (const VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl()))
       if (VD->hasAttr<BlocksAttr>() &&
           E->getRHS()->HasSideEffects(CGF.getContext())) {
@@ -692,18 +720,20 @@
         LValue LHS = CGF.EmitLValue(E->getLHS());
         Dest = AggValueSlot::forLValue(LHS, AggValueSlot::IsDestructed,
                                        needsGC(E->getLHS()->getType()),
-                                       AggValueSlot::IsAliased);
+                                       AggValueSlot::IsAliased,
+                                       AggValueSlot::IsCompleteObject);
         EmitFinalDestCopy(E, RHS, true);
         return;
       }
-  
+
   LValue LHS = CGF.EmitLValue(E->getLHS());
 
   // Codegen the RHS so that it stores directly into the LHS.
   AggValueSlot LHSSlot =
     AggValueSlot::forLValue(LHS, AggValueSlot::IsDestructed, 
                             needsGC(E->getLHS()->getType()),
-                            AggValueSlot::IsAliased);
+                            AggValueSlot::IsAliased,
+                            isCompleteObject(E->getLHS()));
   CGF.EmitAggExpr(E->getRHS(), LHSSlot, false);
   EmitFinalDestCopy(E, LHS, true);
 }
@@ -836,7 +866,8 @@
 
 
 void 
-AggExprEmitter::EmitInitializationToLValue(Expr* E, LValue LV) {
+AggExprEmitter::EmitInitializationToLValue(Expr* E, LValue LV,
+                                 AggValueSlot::IsCompleteObject_t isCompleteObject) {
   QualType type = LV.getType();
   // FIXME: Ignore result?
   // FIXME: Are initializers affected by volatile?
@@ -854,6 +885,7 @@
                                                AggValueSlot::IsDestructed,
                                       AggValueSlot::DoesNotNeedGCBarriers,
                                                AggValueSlot::IsNotAliased,
+                                               isCompleteObject,
                                                Dest.isZeroed()));
   } else if (LV.isSimple()) {
     CGF.EmitScalarInit(E, /*D=*/0, LV, /*Captured=*/false);
@@ -969,7 +1001,8 @@
     LValue FieldLoc = CGF.EmitLValueForFieldInitialization(DestPtr, Field, 0);
     if (NumInitElements) {
       // Store the initializer into the field
-      EmitInitializationToLValue(E->getInit(0), FieldLoc);
+      EmitInitializationToLValue(E->getInit(0), FieldLoc,
+                                 AggValueSlot::IsCompleteObject);
     } else {
       // Default-initialize to null.
       EmitNullInitializationToLValue(FieldLoc);
@@ -1011,7 +1044,8 @@
     
     if (curInitIndex < NumInitElements) {
       // Store the initializer into the field.
-      EmitInitializationToLValue(E->getInit(curInitIndex++), LV);
+      EmitInitializationToLValue(E->getInit(curInitIndex++), LV,
+                                 AggValueSlot::IsCompleteObject);
     } else {
       // We're out of initalizers; default-initialize to null
       EmitNullInitializationToLValue(LV);
@@ -1186,30 +1220,94 @@
   LValue LV = MakeAddrLValue(Temp, E->getType());
   EmitAggExpr(E, AggValueSlot::forLValue(LV, AggValueSlot::IsNotDestructed,
                                          AggValueSlot::DoesNotNeedGCBarriers,
-                                         AggValueSlot::IsNotAliased));
+                                         AggValueSlot::IsNotAliased,
+                                         AggValueSlot::IsCompleteObject));
   return LV;
 }
 
-void CodeGenFunction::EmitAggregateCopy(llvm::Value *DestPtr,
-                                        llvm::Value *SrcPtr, QualType Ty,
-                                        bool isVolatile, unsigned Alignment) {
-  assert(!Ty->isAnyComplexType() && "Shouldn't happen for complex");
+void CodeGenFunction::EmitAggregateCopy(llvm::Value *dest, llvm::Value *src,
+                                        QualType type,
+                                        bool isVolatile, unsigned alignment,
+                                        bool destIsCompleteObject) {
+  assert(!type->isAnyComplexType() && "Shouldn't happen for complex");
 
+  // Get size and alignment info for this type.  Note that the type
+  // might include an alignment attribute, so we can't just rely on
+  // the layout.
+  // FIXME: Do we need to handle VLAs here?
+  std::pair<CharUnits, CharUnits> typeInfo =
+    getContext().getTypeInfoInChars(type);
+
+  // If we weren't given an alignment, use the natural alignment.
+  if (!alignment) alignment = typeInfo.second.getQuantity();
+
+  CharUnits sizeToCopy = typeInfo.first;
+
+  // There's some special logic that applies to C++ classes:
   if (getContext().getLangOpts().CPlusPlus) {
-    if (const RecordType *RT = Ty->getAs<RecordType>()) {
-      CXXRecordDecl *Record = cast<CXXRecordDecl>(RT->getDecl());
-      assert((Record->hasTrivialCopyConstructor() || 
-              Record->hasTrivialCopyAssignment() ||
-              Record->hasTrivialMoveConstructor() ||
-              Record->hasTrivialMoveAssignment()) &&
+    if (const RecordType *RT = type->getAs<RecordType>()) {
+      // First, we want to assert that we're not doing this to
+      // something with a non-trivial operator/constructor.
+      CXXRecordDecl *record = cast<CXXRecordDecl>(RT->getDecl());
+      assert((record->hasTrivialCopyConstructor() || 
+              record->hasTrivialCopyAssignment() ||
+              record->hasTrivialMoveConstructor() ||
+              record->hasTrivialMoveAssignment()) &&
              "Trying to aggregate-copy a type without a trivial copy "
              "constructor or assignment operator");
-      // Ignore empty classes in C++.
-      if (Record->isEmpty())
+
+      // Second, we want to ignore empty classes.
+      if (record->isEmpty())
         return;
+
+      // Third, if it's possible that the destination might not be a
+      // complete object, then we need to make sure we only copy the
+      // data size, not the full sizeof, so that we don't overwrite
+      // subclass fields in the trailing padding.  It's generally going
+      // to be more efficient to copy the sizeof, since we can use
+      // larger stores.
+      //
+      // Unions and final classes can never be base classes.
+      if (!destIsCompleteObject && !record->isUnion() &&
+          !record->hasAttr<FinalAttr>()) {
+        const ASTRecordLayout &layout
+          = getContext().getASTRecordLayout(record);
+        sizeToCopy = layout.getDataSize();
+      }
     }
   }
   
+  llvm::PointerType *DPT = cast<llvm::PointerType>(dest->getType());
+  llvm::Type *DBP =
+    llvm::Type::getInt8PtrTy(getLLVMContext(), DPT->getAddressSpace());
+  dest = Builder.CreateBitCast(dest, DBP);
+
+  llvm::PointerType *SPT = cast<llvm::PointerType>(src->getType());
+  llvm::Type *SBP =
+    llvm::Type::getInt8PtrTy(getLLVMContext(), SPT->getAddressSpace());
+  src = Builder.CreateBitCast(src, SBP);
+
+  llvm::Value *sizeVal =
+    llvm::ConstantInt::get(CGM.SizeTy, sizeToCopy.getQuantity());
+
+  // Don't do any of the memmove_collectable tests if GC isn't set.
+  if (CGM.getLangOpts().getGC() == LangOptions::NonGC) {
+    // fall through
+  } else if (const RecordType *RT = type->getAs<RecordType>()) {
+    if (RT->getDecl()->hasObjectMember()) {
+      CGM.getObjCRuntime().EmitGCMemmoveCollectable(*this, dest, src, sizeVal);
+      return;
+    }
+  } else if (type->isArrayType()) {
+    QualType baseType = getContext().getBaseElementType(type);
+    if (const RecordType *RT = baseType->getAs<RecordType>()) {
+      if (RT->getDecl()->hasObjectMember()) {
+        CGM.getObjCRuntime().EmitGCMemmoveCollectable(*this, dest, src,sizeVal);
+        return;
+      }
+    }
+  }
+
   // Aggregate assignment turns into llvm.memcpy.  This is almost valid per
   // C99 6.5.16.1p3, which states "If the value being stored in an object is
   // read from another object that overlaps in anyway the storage of the first
@@ -1220,71 +1318,8 @@
   // equal, but other compilers do this optimization, and almost every memcpy
   // implementation handles this case safely.  If there is a libc that does not
   // safely handle this, we can add a target hook.
-
-  // Get size and alignment info for this aggregate.
-  std::pair<CharUnits, CharUnits> TypeInfo = 
-    getContext().getTypeInfoInChars(Ty);
-
-  if (!Alignment)
-    Alignment = TypeInfo.second.getQuantity();
-
-  // FIXME: Handle variable sized types.
-
-  // FIXME: If we have a volatile struct, the optimizer can remove what might
-  // appear to be `extra' memory ops:
-  //
-  // volatile struct { int i; } a, b;
-  //
-  // int main() {
-  //   a = b;
-  //   a = b;
-  // }
-  //
-  // we need to use a different call here.  We use isVolatile to indicate when
-  // either the source or the destination is volatile.
-
-  llvm::PointerType *DPT = cast<llvm::PointerType>(DestPtr->getType());
-  llvm::Type *DBP =
-    llvm::Type::getInt8PtrTy(getLLVMContext(), DPT->getAddressSpace());
-  DestPtr = Builder.CreateBitCast(DestPtr, DBP);
-
-  llvm::PointerType *SPT = cast<llvm::PointerType>(SrcPtr->getType());
-  llvm::Type *SBP =
-    llvm::Type::getInt8PtrTy(getLLVMContext(), SPT->getAddressSpace());
-  SrcPtr = Builder.CreateBitCast(SrcPtr, SBP);
-
-  // Don't do any of the memmove_collectable tests if GC isn't set.
-  if (CGM.getLangOpts().getGC() == LangOptions::NonGC) {
-    // fall through
-  } else if (const RecordType *RecordTy = Ty->getAs<RecordType>()) {
-    RecordDecl *Record = RecordTy->getDecl();
-    if (Record->hasObjectMember()) {
-      CharUnits size = TypeInfo.first;
-      llvm::Type *SizeTy = ConvertType(getContext().getSizeType());
-      llvm::Value *SizeVal = llvm::ConstantInt::get(SizeTy, size.getQuantity());
-      CGM.getObjCRuntime().EmitGCMemmoveCollectable(*this, DestPtr, SrcPtr, 
-                                                    SizeVal);
-      return;
-    }
-  } else if (Ty->isArrayType()) {
-    QualType BaseType = getContext().getBaseElementType(Ty);
-    if (const RecordType *RecordTy = BaseType->getAs<RecordType>()) {
-      if (RecordTy->getDecl()->hasObjectMember()) {
-        CharUnits size = TypeInfo.first;
-        llvm::Type *SizeTy = ConvertType(getContext().getSizeType());
-        llvm::Value *SizeVal = 
-          llvm::ConstantInt::get(SizeTy, size.getQuantity());
-        CGM.getObjCRuntime().EmitGCMemmoveCollectable(*this, DestPtr, SrcPtr, 
-                                                      SizeVal);
-        return;
-      }
-    }
-  }
   
-  Builder.CreateMemCpy(DestPtr, SrcPtr,
-                       llvm::ConstantInt::get(IntPtrTy, 
-                                              TypeInfo.first.getQuantity()),
-                       Alignment, isVolatile);
+  Builder.CreateMemCpy(dest, src, sizeVal, alignment, isVolatile);
 }
 
 void CodeGenFunction::MaybeEmitStdInitializerListCleanup(llvm::Value *loc,