Add simple destructors to local RS objects.

Change-Id: Ie4aa964840b25c3aa8eed257d4ff0a1e4f6ef22a
diff --git a/slang_rs_object_ref_count.cpp b/slang_rs_object_ref_count.cpp
index e2e9117..7e61b63 100644
--- a/slang_rs_object_ref_count.cpp
+++ b/slang_rs_object_ref_count.cpp
@@ -27,6 +27,205 @@
 
 using namespace slang;
 
+clang::FunctionDecl *RSObjectRefCount::Scope::RSSetObjectFD[
+    RSExportPrimitiveType::LastRSObjectType -
+    RSExportPrimitiveType::FirstRSObjectType + 1];  // rsSetObject() per type
+clang::FunctionDecl *RSObjectRefCount::Scope::RSClearObjectFD[
+    RSExportPrimitiveType::LastRSObjectType -
+    RSExportPrimitiveType::FirstRSObjectType + 1];  // rsClearObject() per type
+
+// Collects the rsSetObject()/rsClearObject() overloads declared at TU scope,
+// indexed by the RS object type their first (pointer) parameter points to.
+void RSObjectRefCount::Scope::GetRSRefCountingFunctions(
+    clang::ASTContext &C) {
+  const unsigned N = sizeof(RSClearObjectFD) / sizeof(RSClearObjectFD[0]);
+  for (unsigned i = 0; i < N; i++) {
+    RSSetObjectFD[i] = NULL;  // slots with no matching overload stay NULL
+    RSClearObjectFD[i] = NULL;
+  }
+
+  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
+
+  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
+          E = TUDecl->decls_end(); I != E; I++) {
+    if ((I->getKind() >= clang::Decl::firstFunction) &&
+        (I->getKind() <= clang::Decl::lastFunction)) {
+      clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
+      clang::FunctionDecl **RSObjectFD;  // RSSetObjectFD or RSClearObjectFD
+
+      if (FD->getName() == "rsSetObject") {
+        assert((FD->getNumParams() == 2) &&
+               "Invalid rsSetObject function prototype (# params)");
+        RSObjectFD = RSSetObjectFD;
+      } else if (FD->getName() == "rsClearObject") {
+        assert((FD->getNumParams() == 1) &&
+               "Invalid rsClearObject function prototype (# params)");
+        RSObjectFD = RSClearObjectFD;
+      } else {
+        continue;
+      }
+
+      const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
+      clang::QualType PVT = PVD->getOriginalType();
+      // The first parameter must be a pointer like rs_allocation*
+      assert(PVT->isPointerType() &&
+             "Invalid rs{Set,Clear}Object function prototype (pointer param)");
+      // The rs object type passed to the FD
+      clang::QualType RST = PVT->getPointeeType();
+      RSExportPrimitiveType::DataType DT =
+          RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
+      assert(RSExportPrimitiveType::IsRSObjectType(DT)
+             && "must be RS object type");
+      if (!RSExportPrimitiveType::IsRSObjectType(DT)) {
+        continue;  // assert is compiled out under NDEBUG: avoid an OOB write
+      }
+
+      RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
+    }
+  }
+}
+
+// Splices the expressions in ExprList into mCS just before its first
+// top-level return statement, or appends them when there is no such return.
+// The original statements keep their relative order.
+void RSObjectRefCount::Scope::AppendToCompoundStatement(
+    clang::ASTContext& C, std::list<clang::Expr*> &ExprList) {
+  // Only the direct children of mCS are inspected: the first top-level
+  // ReturnStmt and everything after it are placed after the new code.
+  // TODO: This should also handle the case of goto/break/continue, and
+  // return statements nested inside inner statements.
+  clang::CompoundStmt::body_iterator bI = mCS->body_begin();
+  clang::CompoundStmt::body_iterator bE = mCS->body_end();
+
+  unsigned OldStmtCount = 0;
+  for ( ; bI != bE; bI++) {
+    OldStmtCount++;
+  }
+
+  unsigned NewExprCount = ExprList.size();
+
+  // Scratch buffer holding the new statement order; freed after setStmts().
+  clang::Stmt **StmtList = new clang::Stmt*[OldStmtCount + NewExprCount];
+
+  // 1. Original statements up to (not including) the first return.
+  unsigned UpdatedStmtCount = 0;
+  for (bI = mCS->body_begin(); bI != bE; bI++) {
+    if ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass) {
+      break;
+    }
+    StmtList[UpdatedStmtCount++] = *bI;
+  }
+
+  // 2. The new expressions, in list order.
+  for (std::list<clang::Expr*>::const_iterator I = ExprList.begin(),
+          E = ExprList.end(); I != E; I++) {
+    StmtList[UpdatedStmtCount++] = *I;
+  }
+
+  // 3. Whatever is left: the return statement and everything after it.
+  for ( ; bI != bE; bI++) {
+    StmtList[UpdatedStmtCount++] = *bI;
+  }
+
+  assert(UpdatedStmtCount == (OldStmtCount + NewExprCount));
+  mCS->setStmts(C, StmtList, UpdatedStmtCount);
+
+  delete [] StmtList;
+}
+
+// Builds an rsClearObject() call for each RS object variable in mRSO and
+// hands the non-NULL results to AppendToCompoundStatement().
+void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
+  std::list<clang::Expr*> RSClearObjectCalls;
+  for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
+          E = mRSO.end(); I != E; I++) {
+    clang::Expr *ClearCall = ClearRSObject(*I);
+    if (ClearCall) {
+      RSClearObjectCalls.push_back(ClearCall);
+    }
+  }
+  if (!RSClearObjectCalls.empty()) {
+    // At least one call was built, so mRSO is non-empty and front() is valid.
+    clang::ASTContext &C = mRSO.front()->getASTContext();
+    AppendToCompoundStatement(C, RSClearObjectCalls);
+    // TODO: This should also be extended to append destructors to any
+    // further nested scope (we need another visitor here from within the
+    // current compound statement in case they call return/goto).
+  }
+}
+
+// Builds the AST for the call "rsClearObject(&VD)" on RS object variable VD.
+// Asserts if VD's RS object type has no rsClearObject() overload recorded by
+// GetRSRefCountingFunctions(); with NDEBUG it returns NULL in that case.
+clang::Expr *RSObjectRefCount::Scope::ClearRSObject(clang::VarDecl *VD) {
+  clang::ASTContext &C = VD->getASTContext();
+  clang::SourceLocation Loc = VD->getLocation();
+  const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
+  RSExportPrimitiveType::DataType DT =
+      RSExportPrimitiveType::GetRSSpecificType(T);
+
+  assert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
+      "Should be RS object");
+  if (!RSExportPrimitiveType::IsRSObjectType(DT)) {
+    return NULL;  // NDEBUG only: never index RSClearObjectFD out of range
+  }
+
+  // Find the rsClearObject() for VD of RS object type DT
+  clang::FunctionDecl *ClearObjectFD =
+      RSClearObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)];
+  assert((ClearObjectFD != NULL) &&
+      "rsClearObject doesn't cover all RS object types");
+  if (ClearObjectFD == NULL) {
+    return NULL;  // NDEBUG only: the caller skips NULL results
+  }
+
+  clang::QualType ClearObjectFDType = ClearObjectFD->getType();
+  clang::QualType ClearObjectFDArgType =
+      ClearObjectFD->getParamDecl(0)->getOriginalType();
+
+  // We generate a call to rsClearObject passing &VD as the parameter
+  // (CallExpr 'void'
+  //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
+  //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
+  //   (UnaryOperator 'rs_font *' prefix '&'
+  //     (DeclRefExpr 'rs_font':'rs_font' Var='[var name]')))
+
+  // Reference expr to target RS object variable
+  clang::DeclRefExpr *RefRSVar =
+      clang::DeclRefExpr::Create(C, NULL, VD->getQualifierRange(), VD, Loc,
+                                 T->getCanonicalTypeInternal(), NULL);
+
+  // Get address of RSObject in VD
+  clang::Expr *AddrRefRSVar =
+      new (C) clang::UnaryOperator(RefRSVar, clang::UO_AddrOf,
+                                   ClearObjectFDArgType, Loc);
+
+  // Reference expr to the selected rsClearObject() overload
+  clang::Expr *RefRSClearObjectFD =
+      clang::DeclRefExpr::Create(C, NULL, ClearObjectFD->getQualifierRange(),
+                                 ClearObjectFD, ClearObjectFD->getLocation(),
+                                 ClearObjectFDType, NULL);
+
+  // Decay the function reference to a function pointer (the call's callee)
+  clang::Expr *RSClearObjectFP =
+      clang::ImplicitCastExpr::Create(C,
+                                      C.getPointerType(ClearObjectFDType),
+                                      clang::CK_FunctionToPointerDecay,
+                                      RefRSClearObjectFD,
+                                      NULL,
+                                      clang::VK_RValue);
+
+  clang::CallExpr *RSClearObjectCall =
+      new (C) clang::CallExpr(C,
+                              RSClearObjectFP,
+                              &AddrRefRSVar,
+                              1,
+                              ClearObjectFD->getCallResultType(),
+                              clang::SourceLocation());
+
+  return RSClearObjectCall;
+}
+
 bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD) {
   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
   RSExportPrimitiveType::DataType DT =
@@ -178,6 +377,7 @@
     // TODO: Update reference count of the RS object refenced by
     //       getCurrentScope().
     assert((getCurrentScope() == S) && "Corrupted scope stack!");
+    S->InsertLocalVarDestructors();
     mScopeStack.pop();
     delete S;
   }
@@ -200,3 +400,4 @@
   return;
 }
 
+