Support local destructors for multiple return statements.

This inserts rsClearObject() destructor calls for all local variables at the
end of their scope (the closing curly brace), as well as before any return
statement that appears within an enclosed compound statement. Note that proper
support for local destructors with break/continue is still unimplemented.

Change-Id: I7c0633d2901b9bc7e1bac8e211b9eae2ac9f6e92
diff --git a/slang_rs_object_ref_count.cpp b/slang_rs_object_ref_count.cpp
index 9e813ba..bba37f6 100644
--- a/slang_rs_object_ref_count.cpp
+++ b/slang_rs_object_ref_count.cpp
@@ -86,17 +86,22 @@
   }
 }
 
-void RSObjectRefCount::Scope::AppendToCompoundStatement(
-    clang::ASTContext& C, std::list<clang::Expr*> &ExprList) {
+namespace {
+
+static void AppendToCompoundStatement(clang::ASTContext& C,
+                                      clang::CompoundStmt *CS,
+                                      std::list<clang::Expr*> &ExprList,
+                                      bool InsertAtEndOfBlock) {
   // Destructor code will be inserted before any return statement.
   // Any subsequent statements in the compound statement are then placed
   // after our new code.
   // TODO(srhines): This should also handle the case of goto/break/continue.
-  clang::CompoundStmt::body_iterator bI = mCS->body_begin();
-  clang::CompoundStmt::body_iterator bE = mCS->body_end();
+
+  clang::CompoundStmt::body_iterator bI = CS->body_begin();
+  clang::CompoundStmt::body_iterator bE = CS->body_end();
 
   unsigned OldStmtCount = 0;
-  for ( ; bI != bE; bI++) {
+  for (bI = CS->body_begin(); bI != bE; bI++) {
     OldStmtCount++;
   }
 
@@ -106,19 +111,25 @@
   StmtList = new clang::Stmt*[OldStmtCount+NewExprCount];
 
   unsigned UpdatedStmtCount = 0;
-  for (bI = mCS->body_begin(); bI != bE; bI++) {
+  bool FoundReturn = false;
+  for (bI = CS->body_begin(); bI != bE; bI++) {
     if ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass) {
+      FoundReturn = true;
       break;
     }
     StmtList[UpdatedStmtCount++] = *bI;
   }
 
-  std::list<clang::Expr*>::const_iterator E = ExprList.end();
-  for (std::list<clang::Expr*>::const_iterator I = ExprList.begin(),
-          E = ExprList.end();
-       I != E;
-       I++) {
-    StmtList[UpdatedStmtCount++] = *I;
+  // Always insert before a return that we found, or if we are told
+  // to insert at the end of the block
+  if (FoundReturn || InsertAtEndOfBlock) {
+    std::list<clang::Expr*>::const_iterator E = ExprList.end();
+    for (std::list<clang::Expr*>::const_iterator I = ExprList.begin(),
+            E = ExprList.end();
+         I != E;
+         I++) {
+      StmtList[UpdatedStmtCount++] = *I;
+    }
   }
 
   // Pick up anything left over after a return statement
@@ -126,14 +137,61 @@
     StmtList[UpdatedStmtCount++] = *bI;
   }
 
-  mCS->setStmts(C, StmtList, UpdatedStmtCount);
-  assert(UpdatedStmtCount == (OldStmtCount + NewExprCount));
+  CS->setStmts(C, StmtList, UpdatedStmtCount);
 
   delete [] StmtList;
 
   return;
 }
 
+// Visits a compound statement and splices the ExprList destructor calls into
+// it: at the end of the first (outermost) non-empty block visited, or before
+// that block's first direct-child return, and before the first direct-child
+// return of every nested non-empty block. A return that is not a direct child
+// of a CompoundStmt (e.g. an unbraced "if (c) return;") is NOT handled, nor
+// are goto/break/continue exits. NOTE(review): the same Expr nodes are reused.
+// TODO(srhines): Make this work properly for break/continue.
+class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
+ private:
+  clang::ASTContext &mC;  // Passed through to CompoundStmt::setStmts().
+  std::list<clang::Expr*> &mExprList;  // Destructor calls to splice in.
+  bool mTopLevel;  // True until the first non-empty block is processed.
+ public:
+  DestructorVisitor(clang::ASTContext &C, std::list<clang::Expr*> &ExprList);
+  void VisitStmt(clang::Stmt *S);  // Recurses into all non-null children.
+  void VisitCompoundStmt(clang::CompoundStmt *CS);  // Splice, then recurse.
+};
+
+DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
+                                     std::list<clang::Expr*> &ExprList)
+  : mC(C),
+    mExprList(ExprList),
+    mTopLevel(true) {
+  return;
+}
+
+void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
+  if (!CS->body_empty()) {  // Empty blocks are left untouched.
+    AppendToCompoundStatement(mC, CS, mExprList, mTopLevel);
+    mTopLevel = false;  // Nested blocks: insert only before a return.
+    VisitStmt(CS);  // Recurse; this also walks the calls just inserted.
+  }
+  return;
+}
+
+void DestructorVisitor::VisitStmt(clang::Stmt *S) {
+  for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
+       I != E;
+       I++) {
+    if (clang::Stmt *Child = *I) {  // Some child slots may be null.
+      Visit(Child);
+    }
+  }
+  return;
+}
+
+}  // namespace
+
 void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
   std::list<clang::Expr*> RSClearObjectCalls;
   for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
@@ -146,11 +204,8 @@
     }
   }
   if (RSClearObjectCalls.size() > 0) {
-    clang::ASTContext &C = (*mRSO.begin())->getASTContext();
-    AppendToCompoundStatement(C, RSClearObjectCalls);
-    // TODO(srhines): This should also be extended to append destructors to any
-    // further nested scope (we need another visitor here from within the
-    // current compound statement in case they call return/goto).
+    DestructorVisitor DV((*mRSO.begin())->getASTContext(), RSClearObjectCalls);
+    DV.Visit(mCS);
   }
   return;
 }