Add support for RSASTReplace for ref-counting.

BUG=3092382

Change-Id: Ia17d40753952e4a021b39549a082cc4b3f20916c
diff --git a/Android.mk b/Android.mk
index a1ee9c0..89baa5e 100644
--- a/Android.mk
+++ b/Android.mk
@@ -196,6 +196,7 @@
 LOCAL_SRC_FILES :=	\
 	llvm-rs-cc.cpp	\
 	slang_rs.cpp	\
+	slang_rs_ast_replace.cpp	\
 	slang_rs_context.cpp	\
 	slang_rs_pragma_handler.cpp	\
 	slang_rs_backend.cpp	\
diff --git a/slang_rs_ast_replace.cpp b/slang_rs_ast_replace.cpp
new file mode 100644
index 0000000..7bcdd17
--- /dev/null
+++ b/slang_rs_ast_replace.cpp
@@ -0,0 +1,167 @@
+/*
+ * Copyright 2011, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "slang_rs_ast_replace.h"
+
+#include "slang_assert.h"
+
+namespace slang {
+
+void RSASTReplace::ReplaceStmt(
+    clang::Stmt *OuterStmt,
+    clang::Stmt *OldStmt,
+    clang::Stmt *NewStmt) {
+  mOldStmt = OldStmt;
+  mNewStmt = NewStmt;
+  mOuterStmt = OuterStmt;
+
+  // This simplifies use in various Stmt visitor passes where the only
+  // valid type is an Expr.
+  mOldExpr = dyn_cast<clang::Expr>(OldStmt);
+  if (mOldExpr) {
+    mNewExpr = dyn_cast<clang::Expr>(NewStmt);
+  }
+  Visit(mOuterStmt);
+}
+
+void RSASTReplace::ReplaceInCompoundStmt(clang::CompoundStmt *CS) {
+  clang::Stmt **UpdatedStmtList = new clang::Stmt*[CS->size()];
+
+  unsigned UpdatedStmtCount = 0;
+  clang::CompoundStmt::body_iterator bI = CS->body_begin();
+  clang::CompoundStmt::body_iterator bE = CS->body_end();
+
+  for ( ; bI != bE; bI++) {
+    if (matchesStmt(*bI)) {
+      UpdatedStmtList[UpdatedStmtCount++] = mNewStmt;
+    } else {
+      UpdatedStmtList[UpdatedStmtCount++] = *bI;
+    }
+  }
+
+  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
+
+  delete [] UpdatedStmtList;
+
+  return;
+}
+
+void RSASTReplace::VisitStmt(clang::Stmt *S) {
+  // This function does the actual iteration through all sub-Stmt's within
+  // a given Stmt. Note that this function is skipped by all of the other
+  // Visit* functions if we have already found a higher-level match.
+  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
+       I != E;
+       I++) {
+    if (clang::Stmt *Child = *I) {
+      if (!matchesStmt(Child)) {
+        Visit(Child);
+      }
+    }
+  }
+  return;
+}
+
+void RSASTReplace::VisitCompoundStmt(clang::CompoundStmt *CS) {
+  VisitStmt(CS);
+  ReplaceInCompoundStmt(CS);
+  return;
+}
+
+void RSASTReplace::VisitCaseStmt(clang::CaseStmt *CS) {
+  if (matchesStmt(CS->getSubStmt())) {
+    CS->setSubStmt(mNewStmt);
+  } else {
+    VisitStmt(CS);
+  }
+  return;
+}
+
+void RSASTReplace::VisitDefaultStmt(clang::DefaultStmt *DS) {
+  if (matchesStmt(DS->getSubStmt())) {
+    DS->setSubStmt(mNewStmt);
+  } else {
+    VisitStmt(DS);
+  }
+  return;
+}
+
+void RSASTReplace::VisitDoStmt(clang::DoStmt *DS) {
+  if (matchesExpr(DS->getCond())) {
+    DS->setCond(mNewExpr);
+  } else if (matchesStmt(DS->getBody())) {
+    DS->setBody(mNewStmt);
+  } else {
+    VisitStmt(DS);
+  }
+  return;
+}
+
+void RSASTReplace::VisitForStmt(clang::ForStmt *FS) {
+  if (matchesStmt(FS->getInit())) {
+    FS->setInit(mNewStmt);
+  } else if (matchesExpr(FS->getCond())) {
+    FS->setCond(mNewExpr);
+  } else if (matchesExpr(FS->getInc())) {
+    FS->setInc(mNewExpr);
+  } else if (matchesStmt(FS->getBody())) {
+    FS->setBody(mNewStmt);
+  } else {
+    VisitStmt(FS);
+  }
+  return;
+}
+
+void RSASTReplace::VisitIfStmt(clang::IfStmt *IS) {
+  if (matchesExpr(IS->getCond())) {
+    IS->setCond(mNewExpr);
+  } else if (matchesStmt(IS->getThen())) {
+    IS->setThen(mNewStmt);
+  } else if (matchesStmt(IS->getElse())) {
+    IS->setElse(mNewStmt);
+  } else {
+    VisitStmt(IS);
+  }
+  return;
+}
+
+void RSASTReplace::VisitSwitchCase(clang::SwitchCase *SC) {
+  slangAssert(false && "Both case and default have specialized handlers");
+  VisitStmt(SC);
+  return;
+}
+
+void RSASTReplace::VisitSwitchStmt(clang::SwitchStmt *SS) {
+  if (matchesExpr(SS->getCond())) {
+    SS->setCond(mNewExpr);
+  } else {
+    VisitStmt(SS);
+  }
+  return;
+}
+
+void RSASTReplace::VisitWhileStmt(clang::WhileStmt *WS) {
+  if (matchesExpr(WS->getCond())) {
+    WS->setCond(mNewExpr);
+  } else if (matchesStmt(WS->getBody())) {
+    WS->setBody(mNewStmt);
+  } else {
+    VisitStmt(WS);
+  }
+  return;
+}
+
+}  // namespace slang
diff --git a/slang_rs_ast_replace.h b/slang_rs_ast_replace.h
new file mode 100644
index 0000000..78e094a
--- /dev/null
+++ b/slang_rs_ast_replace.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2011, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_AST_REPLACE_H_  // NOLINT
+#define _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_AST_REPLACE_H_
+
+#include "clang/AST/StmtVisitor.h"
+
+#include "slang_assert.h"
+#include "clang/AST/ASTContext.h"
+
+namespace clang {
+  class Diagnostic;
+  class Expr;
+  class Stmt;
+}
+
+namespace slang {
+
+class RSASTReplace : public clang::StmtVisitor<RSASTReplace> {
+ private:
+
+  clang::ASTContext &C;
+  clang::Stmt *mOuterStmt;
+  clang::Stmt *mOldStmt;
+  clang::Stmt *mNewStmt;
+  clang::Expr *mOldExpr;
+  clang::Expr *mNewExpr;
+
+  inline bool matchesExpr(const clang::Expr *E) const {
+    bool retVal = mOldExpr && (mOldExpr == E);
+    if (retVal) {
+      slangAssert(mNewExpr &&
+          "Cannot replace an expression if we don't have a new expression");
+    }
+    return retVal;
+  }
+
+  inline bool matchesStmt(const clang::Stmt *S) const {
+    slangAssert(mOldStmt);
+    return mOldStmt == S;
+  }
+
+  void ReplaceInCompoundStmt(clang::CompoundStmt *CS);
+
+ public:
+  explicit RSASTReplace(clang::ASTContext &Con)
+      : C(Con),
+        mOuterStmt(NULL),
+        mOldStmt(NULL),
+        mNewStmt(NULL),
+        mOldExpr(NULL),
+        mNewExpr(NULL) {
+    return;
+  }
+
+  void VisitStmt(clang::Stmt *S);
+  void VisitCompoundStmt(clang::CompoundStmt *CS);
+  void VisitCaseStmt(clang::CaseStmt *CS);
+  void VisitDefaultStmt(clang::DefaultStmt *DS);
+  void VisitDoStmt(clang::DoStmt *DS);
+  void VisitForStmt(clang::ForStmt *FS);
+  void VisitIfStmt(clang::IfStmt *IS);
+  void VisitSwitchCase(clang::SwitchCase *SC);
+  void VisitSwitchStmt(clang::SwitchStmt *SS);
+  void VisitWhileStmt(clang::WhileStmt *WS);
+
+  // Replace all instances of OldStmt in OuterStmt with NewStmt.
+  void ReplaceStmt(
+      clang::Stmt *OuterStmt,
+      clang::Stmt *OldStmt,
+      clang::Stmt *NewStmt);
+};
+
+}  // namespace slang
+
+#endif  // _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_AST_REPLACE_H_  NOLINT
diff --git a/slang_rs_object_ref_count.cpp b/slang_rs_object_ref_count.cpp
index 8788a54..1cda516 100644
--- a/slang_rs_object_ref_count.cpp
+++ b/slang_rs_object_ref_count.cpp
@@ -27,6 +27,7 @@
 
 #include "slang_assert.h"
 #include "slang_rs.h"
+#include "slang_rs_ast_replace.h"
 #include "slang_rs_export_type.h"
 
 namespace slang {
@@ -89,53 +90,78 @@
 
 namespace {
 
-static void AppendToCompoundStatement(clang::ASTContext& C,
-                                      clang::CompoundStmt *CS,
-                                      std::list<clang::Stmt*> &StmtList,
-                                      bool InsertAtEndOfBlock) {
-  // Destructor code will be inserted before any return statement.
-  // Any subsequent statements in the compound statement are then placed
-  // after our new code.
-  // TODO(srhines): This should also handle the case of goto/break/continue.
-
-  clang::CompoundStmt::body_iterator bI = CS->body_begin();
-
-  unsigned OldStmtCount = 0;
-  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
-    OldStmtCount++;
-  }
-
+// This function constructs a new CompoundStmt from the input StmtList.
+static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
+      std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
   unsigned NewStmtCount = StmtList.size();
+  unsigned CompoundStmtCount = 0;
 
-  clang::Stmt **UpdatedStmtList;
-  UpdatedStmtList = new clang::Stmt*[OldStmtCount+NewStmtCount];
+  clang::Stmt **CompoundStmtList;
+  CompoundStmtList = new clang::Stmt*[NewStmtCount];
+
+  std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
+  std::list<clang::Stmt*>::const_iterator E = StmtList.end();
+  for ( ; I != E; I++) {
+    CompoundStmtList[CompoundStmtCount++] = *I;
+  }
+  slangAssert(CompoundStmtCount == NewStmtCount);
+
+  clang::CompoundStmt *CS = new(C) clang::CompoundStmt(C,
+                                                       CompoundStmtList,
+                                                       CompoundStmtCount,
+                                                       Loc,
+                                                       Loc);
+
+  delete [] CompoundStmtList;
+
+  return CS;
+}
+
+static void AppendAfterStmt(clang::ASTContext &C,
+                            clang::CompoundStmt *CS,
+                            clang::Stmt *S,
+                            std::list<clang::Stmt*> &StmtList) {
+  slangAssert(CS);
+  clang::CompoundStmt::body_iterator bI = CS->body_begin();
+  clang::CompoundStmt::body_iterator bE = CS->body_end();
+  clang::Stmt **UpdatedStmtList =
+      new clang::Stmt*[CS->size() + StmtList.size()];
 
   unsigned UpdatedStmtCount = 0;
-  bool FoundReturn = false;
-  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
-    if ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass) {
-      FoundReturn = true;
-      break;
-    }
-    UpdatedStmtList[UpdatedStmtCount++] = *bI;
-  }
+  unsigned Once = 0;
+  for ( ; bI != bE; bI++) {
+    if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
+      // If we come across a return here, we don't have anything we can
+      // reasonably replace. We should have already inserted our destructor
+      // code in the proper spot, so we just clean up and return.
+      delete [] UpdatedStmtList;
 
-  // Always insert before a return that we found, or if we are told
-  // to insert at the end of the block
-  if (FoundReturn || InsertAtEndOfBlock) {
+      return;
+    }
+
+    UpdatedStmtList[UpdatedStmtCount++] = *bI;
+
+    if ((*bI == S) && !Once) {
+      Once++;
+      std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
+      std::list<clang::Stmt*>::const_iterator E = StmtList.end();
+      for ( ; I != E; I++) {
+        UpdatedStmtList[UpdatedStmtCount++] = *I;
+      }
+    }
+  }
+  slangAssert(Once <= 1);
+
+  // When S is NULL, we are appending to the end of the CompoundStmt.
+  if (!S) {
+    slangAssert(Once == 0);
     std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
-    for (std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
-         I != StmtList.end();
-         I++) {
+    std::list<clang::Stmt*>::const_iterator E = StmtList.end();
+    for ( ; I != E; I++) {
       UpdatedStmtList[UpdatedStmtCount++] = *I;
     }
   }
 
-  // Pick up anything left over after a return statement
-  for ( ; bI != CS->body_end(); bI++) {
-    UpdatedStmtList[UpdatedStmtCount++] = *bI;
-  }
-
   CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
 
   delete [] UpdatedStmtList;
@@ -143,99 +169,95 @@
   return;
 }
 
-static void AppendAfterStmt(clang::ASTContext& C,
-                            clang::CompoundStmt *CS,
-                            clang::Stmt *OldStmt,
-                            clang::Stmt *NewStmt) {
-  slangAssert(CS && OldStmt && NewStmt);
-  clang::CompoundStmt::body_iterator bI = CS->body_begin();
-  unsigned StmtCount = 1;  // Take into account new statement
-  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
-    StmtCount++;
-  }
-
-  clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
-
-  unsigned UpdatedStmtCount = 0;
-  unsigned Once = 0;
-  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
-    UpdatedStmtList[UpdatedStmtCount++] = *bI;
-    if (*bI == OldStmt) {
-      Once++;
-      slangAssert(Once == 1);
-      UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
-    }
-  }
-
-  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
-
-  delete [] UpdatedStmtList;
-
-  return;
-}
-
-static void ReplaceInCompoundStmt(clang::ASTContext& C,
-                                  clang::CompoundStmt *CS,
-                                  clang::Stmt* OldStmt,
-                                  clang::Stmt* NewStmt) {
-  clang::CompoundStmt::body_iterator bI = CS->body_begin();
-
-  unsigned StmtCount = 0;
-  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
-    StmtCount++;
-  }
-
-  clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
-
-  unsigned UpdatedStmtCount = 0;
-  for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
-    if (*bI == OldStmt) {
-      UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
-    } else {
-      UpdatedStmtList[UpdatedStmtCount++] = *bI;
-    }
-  }
-
-  CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
-
-  delete [] UpdatedStmtList;
-
-  return;
-}
-
-
 // This class visits a compound statement and inserts the StmtList containing
 // destructors in proper locations. This includes inserting them before any
 // return statement in any sub-block, at the end of the logical enclosing
 // scope (compound statement), and/or before any break/continue statement that
 // would resume outside the declared scope. We will not handle the case for
 // goto statements that leave a local scope.
-// TODO(srhines): Make this work properly for break/continue.
+//
+// To accomplish these goals, it collects a list of sub-Stmt's that
+// correspond to scope exit points. It then uses an RSASTReplace visitor to
+// transform the AST, inserting appropriate destructors before each of those
+// sub-Stmt's (and also before the exit of the outermost containing Stmt for
+// the scope).
 class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
  private:
   clang::ASTContext &mC;
+
+  // The loop depth of the currently visited node.
+  int mLoopDepth;
+
+  // The switch statement depth of the currently visited node.
+  // Note that this is tracked separately from the loop depth because
+  // SwitchStmt-contained ContinueStmt's should have destructors for the
+  // corresponding loop scope.
+  int mSwitchDepth;
+
+  // The outermost statement block that we are currently visiting.
+  // This should always be a CompoundStmt.
+  clang::Stmt *mOuterStmt;
+
+  // The list of destructors to execute for this scope.
   std::list<clang::Stmt*> &mStmtList;
-  bool mTopLevel;
+
+  // The stack of statements which should be replaced by a compound statement
+  // containing the new destructor calls followed by the original Stmt.
+  std::stack<clang::Stmt*> mReplaceStmtStack;
+
  public:
-  DestructorVisitor(clang::ASTContext &C, std::list<clang::Stmt*> &StmtList);
+  DestructorVisitor(clang::ASTContext &C,
+                    clang::Stmt* OuterStmt,
+                    std::list<clang::Stmt*> &StmtList);
+
+  // This code walks the collected list of Stmts to replace and actually does
+  // the replacement. It also finishes up by appending appropriate destructors
+  // to the current outermost CompoundStmt.
+  void InsertDestructors() {
+    clang::Stmt *S = NULL;
+    while (!mReplaceStmtStack.empty()) {
+      S = mReplaceStmtStack.top();
+      mReplaceStmtStack.pop();
+
+      mStmtList.push_back(S);
+      clang::CompoundStmt *CS =
+          BuildCompoundStmt(mC, mStmtList, S->getLocEnd());
+      mStmtList.pop_back();
+
+      RSASTReplace R(mC);
+      R.ReplaceStmt(mOuterStmt, S, CS);
+    }
+    clang::CompoundStmt *CS = dyn_cast<clang::CompoundStmt>(mOuterStmt);
+    slangAssert(CS);
+    if (CS) {
+      AppendAfterStmt(mC, CS, NULL, mStmtList);
+    }
+  }
+
   void VisitStmt(clang::Stmt *S);
   void VisitCompoundStmt(clang::CompoundStmt *CS);
+
+  void VisitBreakStmt(clang::BreakStmt *BS);
+  void VisitCaseStmt(clang::CaseStmt *CS);
+  void VisitContinueStmt(clang::ContinueStmt *CS);
+  void VisitDefaultStmt(clang::DefaultStmt *DS);
+  void VisitDoStmt(clang::DoStmt *DS);
+  void VisitForStmt(clang::ForStmt *FS);
+  void VisitIfStmt(clang::IfStmt *IS);
+  void VisitReturnStmt(clang::ReturnStmt *RS);
+  void VisitSwitchCase(clang::SwitchCase *SC);
+  void VisitSwitchStmt(clang::SwitchStmt *SS);
+  void VisitWhileStmt(clang::WhileStmt *WS);
 };
 
 DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
-                                     std::list<clang::Stmt*> &StmtList)
+                         clang::Stmt *OuterStmt,
+                         std::list<clang::Stmt*> &StmtList)
   : mC(C),
-    mStmtList(StmtList),
-    mTopLevel(true) {
-  return;
-}
-
-void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
-  if (!CS->body_empty()) {
-    AppendToCompoundStatement(mC, CS, mStmtList, mTopLevel);
-    mTopLevel = false;
-    VisitStmt(CS);
-  }
+    mLoopDepth(0),
+    mSwitchDepth(0),
+    mOuterStmt(OuterStmt),
+    mStmtList(StmtList) {
   return;
 }
 
@@ -250,6 +272,82 @@
   return;
 }
 
+void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
+  VisitStmt(CS);
+  return;
+}
+
+void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
+  VisitStmt(BS);
+  if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
+    mReplaceStmtStack.push(BS);
+  }
+  return;
+}
+
+void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
+  VisitStmt(CS);
+  return;
+}
+
+void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
+  VisitStmt(CS);
+  if (mLoopDepth == 0) {
+    // Switch statements can have nested continues.
+    mReplaceStmtStack.push(CS);
+  }
+  return;
+}
+
+void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
+  VisitStmt(DS);
+  return;
+}
+
+void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
+  mLoopDepth++;
+  VisitStmt(DS);
+  mLoopDepth--;
+  return;
+}
+
+void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
+  mLoopDepth++;
+  VisitStmt(FS);
+  mLoopDepth--;
+  return;
+}
+
+void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
+  VisitStmt(IS);
+  return;
+}
+
+void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
+  mReplaceStmtStack.push(RS);
+  return;
+}
+
+void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
+  slangAssert(false && "Both case and default have specialized handlers");
+  VisitStmt(SC);
+  return;
+}
+
+void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
+  mSwitchDepth++;
+  VisitStmt(SS);
+  mSwitchDepth--;
+  return;
+}
+
+void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
+  mLoopDepth++;
+  VisitStmt(WS);
+  mLoopDepth--;
+  return;
+}
+
 clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
                                  clang::Expr *RefRSVar,
                                  clang::SourceLocation Loc) {
@@ -974,7 +1072,8 @@
         CreateSingleRSSetObject(C, Diags, AS->getLHS(), AS->getRHS(), Loc);
   }
 
-  ReplaceInCompoundStmt(C, mCS, AS, UpdatedStmt);
+  RSASTReplace R(C);
+  R.ReplaceStmt(mCS, AS, UpdatedStmt);
   return;
 }
 
@@ -996,7 +1095,6 @@
       RSExportPrimitiveType::DataTypeRSFont)->getLocation();
 
   if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
-    // TODO(srhines): Skip struct initialization right now
     const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
     clang::DeclRefExpr *RefRSVar =
         clang::DeclRefExpr::Create(C,
@@ -1010,7 +1108,9 @@
     clang::Stmt *RSSetObjectOps =
         CreateStructRSSetObject(C, Diags, RefRSVar, InitExpr, Loc);
 
-    AppendAfterStmt(C, mCS, DS, RSSetObjectOps);
+    std::list<clang::Stmt*> StmtList;
+    StmtList.push_back(RSSetObjectOps);
+    AppendAfterStmt(C, mCS, DS, StmtList);
     return;
   }
 
@@ -1068,7 +1168,9 @@
                              clang::VK_RValue,
                              Loc);
 
-  AppendAfterStmt(C, mCS, DS, RSSetObjectCall);
+  std::list<clang::Stmt*> StmtList;
+  StmtList.push_back(RSSetObjectCall);
+  AppendAfterStmt(C, mCS, DS, StmtList);
 
   return;
 }
@@ -1085,8 +1187,11 @@
     }
   }
   if (RSClearObjectCalls.size() > 0) {
-    DestructorVisitor DV((*mRSO.begin())->getASTContext(), RSClearObjectCalls);
+    DestructorVisitor DV((*mRSO.begin())->getASTContext(),
+                         mCS,
+                         RSClearObjectCalls);
     DV.Visit(mCS);
+    DV.InsertDestructors();
   }
   return;
 }
diff --git a/slang_rs_object_ref_count.h b/slang_rs_object_ref_count.h
index 33a5a89..78984a7 100644
--- a/slang_rs_object_ref_count.h
+++ b/slang_rs_object_ref_count.h
@@ -32,6 +32,16 @@
 
 namespace slang {
 
+// This class provides the overall reference counting mechanism for handling
+// local variables of RS object types (rs_font, rs_allocation, ...). This
+// class ensures that appropriate functions (rsSetObject, rsClearObject) are
+// called at proper points in the object's lifetime.
+// 1) Each local object of appropriate type must be zero-initialized (to
+// prevent corruption) during subsequent rsSetObject()/rsClearObject() calls.
+// 2) Assignments using these types must also be converted into the
+// appropriate (possibly a series of) rsSetObject() calls.
+// 3) Finally, each local object must call rsClearObject() when it goes out
+// of scope.
 class RSObjectRefCount : public clang::StmtVisitor<RSObjectRefCount> {
  private:
   class Scope {
@@ -79,9 +89,6 @@
   // Initialize RSSetObjectFD and RSClearObjectFD.
   static void GetRSRefCountingFunctions(clang::ASTContext &C);
 
-  // TODO(srhines): Composite types and arrays based on RS object types need
-  // to be handled for both zero-initialization + clearing.
-
   // Return false if the type of variable declared in VD does not contain
   // an RS object type.
   static bool InitializeRSObject(clang::VarDecl *VD,