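Replace struct assignments with rsSetObject calls.

Assignments to structs that contain RS object fields are now lowered to
per-field rsSetObject calls (recursing into nested structs) followed by a
plain struct copy, so RS object reference counts stay balanced. Copying
arrays of RS objects inside structs is not supported yet and is reported
as an error. CountRSObjectTypesInStruct is renamed to CountRSObjectTypes
because it is now also applied to non-struct types.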

Bug: 3092382
Change-Id: I63f16a7dac02eb348b87a6225944d48faa615899
diff --git a/slang_rs_object_ref_count.cpp b/slang_rs_object_ref_count.cpp
index 5fd8718..dcbcf2d 100644
--- a/slang_rs_object_ref_count.cpp
+++ b/slang_rs_object_ref_count.cpp
@@ -475,12 +475,12 @@
   return CS;
 }
 
-static unsigned CountRSObjectTypesInStruct(const clang::Type *T) {
+static unsigned CountRSObjectTypes(const clang::Type *T) {
   slangAssert(T);
   unsigned RSObjectCount = 0;
 
   if (T->isArrayType()) {
-    return CountRSObjectTypesInStruct(T->getArrayElementTypeNoTypeQual());
+    return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
   }
 
   RSExportPrimitiveType::DataType DT =
@@ -501,7 +501,7 @@
        FI++) {
     const clang::FieldDecl *FD = *FI;
     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
-    if (CountRSObjectTypesInStruct(FT)) {
+    if (CountRSObjectTypes(FT)) {
       // Sub-structs should only count once (as should arrays, etc.)
       RSObjectCount++;
     }
@@ -526,10 +526,13 @@
   // Structs should show up as unknown primitive types
   slangAssert(DT == RSExportPrimitiveType::DataTypeUnknown);
 
-  unsigned FieldsToDestroy = CountRSObjectTypesInStruct(BaseType);
+  unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
 
   unsigned StmtCount = 0;
   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
+  for (unsigned i = 0; i < FieldsToDestroy; i++) {
+    StmtArray[i] = NULL;
+  }
 
   // Populate StmtArray by creating a destructor for each RS object field
   clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
@@ -577,7 +580,7 @@
                                                      RSObjectMember,
                                                      Loc);
       }
-    } else if (FT->isStructureType() && CountRSObjectTypesInStruct(FT)) {
+    } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
       // In this case, we have a nested struct. We may not end up filling all
       // of the spaces in StmtArray (sub-structs should handle themselves
       // with separate compound statements).
@@ -620,25 +623,21 @@
   return CS;
 }
 
-}  // namespace
-
-void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
-    clang::BinaryOperator *AS) {
-
-  clang::QualType QT = AS->getType();
-
-  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(
-      QT.getTypePtr());
+static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
+                                            clang::Diagnostic *Diags,
+                                            clang::Expr *DstExpr,
+                                            clang::Expr *SrcExpr,
+                                            clang::SourceLocation Loc) {
+  const clang::Type *T = DstExpr->getType().getTypePtr();
+  clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
   slangAssert((SetObjectFD != NULL) &&
               "rsSetObject doesn't cover all RS object types");
-  clang::ASTContext &C = SetObjectFD->getASTContext();
 
   clang::QualType SetObjectFDType = SetObjectFD->getType();
   clang::QualType SetObjectFDArgType[2];
   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
 
-  clang::SourceLocation Loc = SetObjectFD->getLocation();
   clang::Expr *RefRSSetObjectFD =
       clang::DeclRefExpr::Create(C,
                                  NULL,
@@ -656,11 +655,11 @@
                                       clang::VK_RValue);
 
   clang::Expr *ArgList[2];
-  ArgList[0] = new(C) clang::UnaryOperator(AS->getLHS(),
+  ArgList[0] = new(C) clang::UnaryOperator(DstExpr,
                                            clang::UO_AddrOf,
                                            SetObjectFDArgType[0],
                                            Loc);
-  ArgList[1] = AS->getRHS();
+  ArgList[1] = SrcExpr;
 
   clang::CallExpr *RSSetObjectCall =
       new(C) clang::CallExpr(C,
@@ -670,8 +669,323 @@
                              SetObjectFD->getCallResultType(),
                              Loc);
 
-  ReplaceInCompoundStmt(C, mCS, AS, RSSetObjectCall);
+  return RSSetObjectCall;
+}
 
+// Forward declaration: CreateArrayRSSetObject recurses into it for arrays of
+// structs.
+static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
+                                            clang::Diagnostic *Diags,
+                                            clang::Expr *LHS,
+                                            clang::Expr *RHS,
+                                            clang::SourceLocation Loc);
+
+// Builds a statement that copies SrcArr into DstArr element by element.
+// RS object elements are set with rsSetObject; struct and nested-array
+// elements recurse. The generated code has the shape:
+//
+//   {
+//     int rsIntIter;
+//     for (rsIntIter = 0; rsIntIter < N; rsIntIter++)
+//       <set DstArr[rsIntIter] from SrcArr[rsIntIter]>;
+//   }
+//
+// Returns NULL when ArrayDim() reports no elements for this array or for a
+// nested one.
+static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
+                                           clang::Diagnostic *Diags,
+                                           clang::Expr *DstArr,
+                                           clang::Expr *SrcArr,
+                                           clang::SourceLocation Loc) {
+  clang::DeclContext *DC = NULL;
+  clang::SourceRange Range;
+  const clang::Type *BaseType = DstArr->getType().getTypePtr();
+  slangAssert(BaseType->isArrayType());
+
+  int NumArrayElements = ArrayDim(BaseType);
+  // Element type: an RS object, a struct, or a nested array (see below).
+  BaseType = BaseType->getArrayElementTypeNoTypeQual();
+
+  clang::Stmt *StmtArray[2] = {NULL};
+  int StmtCtr = 0;
+
+  if (NumArrayElements <= 0) {
+    return NULL;
+  }
+
+  // Create helper variable for iterating through elements
+  clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
+  clang::VarDecl *IIVD =
+      clang::VarDecl::Create(C,
+                             DC,
+                             Loc,
+                             &II,
+                             C.IntTy,
+                             C.getTrivialTypeSourceInfo(C.IntTy),
+                             clang::SC_None,
+                             clang::SC_None);
+  clang::Decl *IID = (clang::Decl *)IIVD;
+
+  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
+  StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
+
+  // Form the actual loop
+  // for (Init; Cond; Inc)
+  //   RSSetObjectCall;
+
+  // Init -> "rsIntIter = 0"
+  clang::DeclRefExpr *RefrsIntIter =
+      clang::DeclRefExpr::Create(C,
+                                 NULL,
+                                 Range,
+                                 IIVD,
+                                 Loc,
+                                 C.IntTy);
+
+  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
+      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
+
+  clang::BinaryOperator *Init =
+      new(C) clang::BinaryOperator(RefrsIntIter,
+                                   Int0,
+                                   clang::BO_Assign,
+                                   C.IntTy,
+                                   Loc);
+
+  // Cond -> "rsIntIter < NumArrayElements"
+  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
+      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
+
+  clang::BinaryOperator *Cond =
+      new(C) clang::BinaryOperator(RefrsIntIter,
+                                   NumArrayElementsExpr,
+                                   clang::BO_LT,
+                                   C.IntTy,
+                                   Loc);
+
+  // Inc -> "rsIntIter++"
+  clang::UnaryOperator *Inc =
+      new(C) clang::UnaryOperator(RefrsIntIter,
+                                  clang::UO_PostInc,
+                                  C.IntTy,
+                                  Loc);
+
+  // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
+  // Loop operates on individual array elements
+
+  clang::Expr *DstArrPtr =
+      clang::ImplicitCastExpr::Create(C,
+          C.getPointerType(BaseType->getCanonicalTypeInternal()),
+          clang::CK_ArrayToPointerDecay,
+          DstArr,
+          NULL,
+          clang::VK_RValue);
+
+  clang::Expr *DstArrPtrSubscript =
+      new(C) clang::ArraySubscriptExpr(DstArrPtr,
+                                       RefrsIntIter,
+                                       BaseType->getCanonicalTypeInternal(),
+                                       Loc);
+
+  clang::Expr *SrcArrPtr =
+      clang::ImplicitCastExpr::Create(C,
+          C.getPointerType(BaseType->getCanonicalTypeInternal()),
+          clang::CK_ArrayToPointerDecay,
+          SrcArr,
+          NULL,
+          clang::VK_RValue);
+
+  clang::Expr *SrcArrPtrSubscript =
+      new(C) clang::ArraySubscriptExpr(SrcArrPtr,
+                                       RefrsIntIter,
+                                       BaseType->getCanonicalTypeInternal(),
+                                       Loc);
+
+  RSExportPrimitiveType::DataType DT =
+      RSExportPrimitiveType::GetRSSpecificType(BaseType);
+
+  clang::Stmt *RSSetObjectCall = NULL;
+  if (BaseType->isArrayType()) {
+    RSSetObjectCall = CreateArrayRSSetObject(C, Diags, DstArrPtrSubscript,
+                                             SrcArrPtrSubscript, Loc);
+  } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
+    RSSetObjectCall = CreateStructRSSetObject(C, Diags, DstArrPtrSubscript,
+                                              SrcArrPtrSubscript, Loc);
+  } else {
+    RSSetObjectCall = CreateSingleRSSetObject(C, Diags, DstArrPtrSubscript,
+                                              SrcArrPtrSubscript, Loc);
+  }
+
+  // A nested zero-length array yields no statement; propagate that rather
+  // than building a for-loop with a NULL body.
+  if (RSSetObjectCall == NULL) {
+    return NULL;
+  }
+
+  clang::ForStmt *SetObjectLoop =
+      new(C) clang::ForStmt(C,
+                            Init,
+                            Cond,
+                            NULL,  // no condVar
+                            Inc,
+                            RSSetObjectCall,
+                            Loc,
+                            Loc,
+                            Loc);
+
+  StmtArray[StmtCtr++] = SetObjectLoop;
+  slangAssert(StmtCtr == 2);
+
+  clang::CompoundStmt *CS =
+      new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
+
+  return CS;
+}
+
+// Builds a statement that assigns struct RHS to struct LHS while keeping RS
+// object reference counts balanced: every field that contains RS objects is
+// first set with rsSetObject (recursing into nested structs), then a plain
+// struct assignment copies the whole value. Arrays of RS objects inside the
+// struct are not supported yet and are reported as an error.
+//
+// NOTE(review): LHS and RHS are reused as the base of every generated member
+// expression and again in the final assignment, so an operand with side
+// effects (e.g. a function call) would presumably be evaluated more than
+// once. TODO: confirm, and consider copying such operands to a temporary.
+static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
+                                            clang::Diagnostic *Diags,
+                                            clang::Expr *LHS,
+                                            clang::Expr *RHS,
+                                            clang::SourceLocation Loc) {
+  clang::SourceRange Range;
+  clang::QualType QT = LHS->getType();
+  const clang::Type *T = QT.getTypePtr();
+  slangAssert(T->isStructureType());
+  slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
+
+  // Keep an extra slot for the final whole-struct assignment (CopyStruct).
+  unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
+
+  unsigned StmtCount = 0;
+  clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
+  for (unsigned i = 0; i < FieldsToSet; i++) {
+    StmtArray[i] = NULL;
+  }
+
+  clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
+  RD = RD->getDefinition();
+  for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
+         FE = RD->field_end();
+       FI != FE;
+       FI++) {
+    bool IsArrayType = false;
+    clang::FieldDecl *FD = *FI;
+    const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
+    const clang::Type *OrigType = FT;
+
+    if (!CountRSObjectTypes(FT)) {
+      // Skip to next if we don't have any viable RS object types
+      continue;
+    }
+
+    clang::DeclAccessPair FoundDecl =
+        clang::DeclAccessPair::make(FD, clang::AS_none);
+    clang::MemberExpr *DstMember =
+        clang::MemberExpr::Create(C,
+                                  LHS,
+                                  false,
+                                  NULL,
+                                  Range,
+                                  FD,
+                                  FoundDecl,
+                                  clang::DeclarationNameInfo(),
+                                  NULL,
+                                  OrigType->getCanonicalTypeInternal());
+
+    clang::MemberExpr *SrcMember =
+        clang::MemberExpr::Create(C,
+                                  RHS,
+                                  false,
+                                  NULL,
+                                  Range,
+                                  FD,
+                                  FoundDecl,
+                                  clang::DeclarationNameInfo(),
+                                  NULL,
+                                  OrigType->getCanonicalTypeInternal());
+
+    if (FT->isArrayType()) {
+      FT = FT->getArrayElementTypeNoTypeQual();
+      IsArrayType = true;
+    }
+
+    RSExportPrimitiveType::DataType DT =
+        RSExportPrimitiveType::GetRSSpecificType(FT);
+
+    if (IsArrayType) {
+      Diags->Report(Diags->getCustomDiagID(clang::Diagnostic::Error,
+           "Arrays of RS object types within structures cannot be copied"));
+      // TODO(srhines): Support setting arrays of RS objects
+      // StmtArray[StmtCount++] =
+      //    CreateArrayRSSetObject(C, Diags, DstMember, SrcMember, Loc);
+    } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
+      StmtArray[StmtCount++] =
+          CreateStructRSSetObject(C, Diags, DstMember, SrcMember, Loc);
+    } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
+      StmtArray[StmtCount++] =
+          CreateSingleRSSetObject(C, Diags, DstMember, SrcMember, Loc);
+    } else {
+      slangAssert(false);
+    }
+  }
+
+  // StmtCount may be 0: when every RS object field is an array, only the
+  // error diagnostics above are emitted and no per-field statement is built.
+  slangAssert(StmtCount < FieldsToSet);
+
+  // We still need to actually do the overall struct copy. For simplicity,
+  // we just do a straight-up assignment (which will still preserve all
+  // the proper RS object reference counts).
+  clang::BinaryOperator *CopyStruct =
+      new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT, Loc);
+  StmtArray[StmtCount++] = CopyStruct;
+
+  clang::CompoundStmt *CS =
+      new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
+
+  delete [] StmtArray;
+
+  return CS;
+}
+
+}  // namespace
+
+// Replaces assignment AS (to an RS object, or to a struct containing RS
+// objects) in the current compound statement with rsSetObject-based code.
+void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
+    clang::BinaryOperator *AS,
+    clang::Diagnostic *Diags) {
+
+  clang::QualType QT = AS->getType();
+
+  // The RSFont rsSetObject declaration is used only as a handle to reach the
+  // ASTContext; AS itself does not expose one.
+  clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
+      RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
+
+  clang::SourceLocation Loc = AS->getLocEnd();
+  clang::Stmt *UpdatedStmt = NULL;
+
+  if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
+    // By definition, this is a struct assignment if we get here
+    UpdatedStmt =
+        CreateStructRSSetObject(C, Diags, AS->getLHS(), AS->getRHS(), Loc);
+  } else {
+    UpdatedStmt =
+        CreateSingleRSSetObject(C, Diags, AS->getLHS(), AS->getRHS(), Loc);
+  }
+
+  ReplaceInCompoundStmt(C, mCS, AS, UpdatedStmt);
   return;
 }
 
@@ -988,8 +1266,8 @@
 void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
   clang::QualType QT = AS->getType();
 
-  if (RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
-    getCurrentScope()->ReplaceRSObjectAssignment(AS);
+  if (CountRSObjectTypes(QT.getTypePtr())) {
+    getCurrentScope()->ReplaceRSObjectAssignment(AS, mDiags);
   }
 
   return;