Add support for RSASTReplace for ref-counting.
BUG=3092382
Change-Id: Ia17d40753952e4a021b39549a082cc4b3f20916c
diff --git a/Android.mk b/Android.mk
index a1ee9c0..89baa5e 100644
--- a/Android.mk
+++ b/Android.mk
@@ -196,6 +196,7 @@
LOCAL_SRC_FILES := \
llvm-rs-cc.cpp \
slang_rs.cpp \
+ slang_rs_ast_replace.cpp \
slang_rs_context.cpp \
slang_rs_pragma_handler.cpp \
slang_rs_backend.cpp \
diff --git a/slang_rs_ast_replace.cpp b/slang_rs_ast_replace.cpp
new file mode 100644
index 0000000..7bcdd17
--- /dev/null
+++ b/slang_rs_ast_replace.cpp
@@ -0,0 +1,167 @@
+/*
+ * Copyright 2011, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "slang_rs_ast_replace.h"
+
+#include "slang_assert.h"
+
+namespace slang {
+
+void RSASTReplace::ReplaceStmt(
+ clang::Stmt *OuterStmt,
+ clang::Stmt *OldStmt,
+ clang::Stmt *NewStmt) {
+ mOldStmt = OldStmt;
+ mNewStmt = NewStmt;
+ mOuterStmt = OuterStmt;
+
+ // This simplifies use in various Stmt visitor passes where the only
+ // valid type is an Expr.
+ mOldExpr = dyn_cast<clang::Expr>(OldStmt);
+ if (mOldExpr) {
+ mNewExpr = dyn_cast<clang::Expr>(NewStmt);
+ }
+ Visit(mOuterStmt);
+}
+
+void RSASTReplace::ReplaceInCompoundStmt(clang::CompoundStmt *CS) {
+ clang::Stmt **UpdatedStmtList = new clang::Stmt*[CS->size()];
+
+ unsigned UpdatedStmtCount = 0;
+ clang::CompoundStmt::body_iterator bI = CS->body_begin();
+ clang::CompoundStmt::body_iterator bE = CS->body_end();
+
+ for ( ; bI != bE; bI++) {
+ if (matchesStmt(*bI)) {
+ UpdatedStmtList[UpdatedStmtCount++] = mNewStmt;
+ } else {
+ UpdatedStmtList[UpdatedStmtCount++] = *bI;
+ }
+ }
+
+ CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
+
+ delete [] UpdatedStmtList;
+
+ return;
+}
+
+void RSASTReplace::VisitStmt(clang::Stmt *S) {
+ // This function does the actual iteration through all sub-Stmt's within
+ // a given Stmt. Note that this function is skipped by all of the other
+ // Visit* functions if we have already found a higher-level match.
+ for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
+ I != E;
+ I++) {
+ if (clang::Stmt *Child = *I) {
+ if (!matchesStmt(Child)) {
+ Visit(Child);
+ }
+ }
+ }
+ return;
+}
+
+void RSASTReplace::VisitCompoundStmt(clang::CompoundStmt *CS) {
+ VisitStmt(CS);
+ ReplaceInCompoundStmt(CS);
+ return;
+}
+
+void RSASTReplace::VisitCaseStmt(clang::CaseStmt *CS) {
+ if (matchesStmt(CS->getSubStmt())) {
+ CS->setSubStmt(mNewStmt);
+ } else {
+ VisitStmt(CS);
+ }
+ return;
+}
+
+void RSASTReplace::VisitDefaultStmt(clang::DefaultStmt *DS) {
+ if (matchesStmt(DS->getSubStmt())) {
+ DS->setSubStmt(mNewStmt);
+ } else {
+ VisitStmt(DS);
+ }
+ return;
+}
+
+void RSASTReplace::VisitDoStmt(clang::DoStmt *DS) {
+ if (matchesExpr(DS->getCond())) {
+ DS->setCond(mNewExpr);
+ } else if (matchesStmt(DS->getBody())) {
+ DS->setBody(mNewStmt);
+ } else {
+ VisitStmt(DS);
+ }
+ return;
+}
+
+void RSASTReplace::VisitForStmt(clang::ForStmt *FS) {
+ if (matchesStmt(FS->getInit())) {
+ FS->setInit(mNewStmt);
+ } else if (matchesExpr(FS->getCond())) {
+ FS->setCond(mNewExpr);
+ } else if (matchesExpr(FS->getInc())) {
+ FS->setInc(mNewExpr);
+ } else if (matchesStmt(FS->getBody())) {
+ FS->setBody(mNewStmt);
+ } else {
+ VisitStmt(FS);
+ }
+ return;
+}
+
+void RSASTReplace::VisitIfStmt(clang::IfStmt *IS) {
+ if (matchesExpr(IS->getCond())) {
+ IS->setCond(mNewExpr);
+ } else if (matchesStmt(IS->getThen())) {
+ IS->setThen(mNewStmt);
+ } else if (matchesStmt(IS->getElse())) {
+ IS->setElse(mNewStmt);
+ } else {
+ VisitStmt(IS);
+ }
+ return;
+}
+
+void RSASTReplace::VisitSwitchCase(clang::SwitchCase *SC) {
+ slangAssert(false && "Both case and default have specialized handlers");
+ VisitStmt(SC);
+ return;
+}
+
+void RSASTReplace::VisitSwitchStmt(clang::SwitchStmt *SS) {
+ if (matchesExpr(SS->getCond())) {
+ SS->setCond(mNewExpr);
+ } else {
+ VisitStmt(SS);
+ }
+ return;
+}
+
+void RSASTReplace::VisitWhileStmt(clang::WhileStmt *WS) {
+ if (matchesExpr(WS->getCond())) {
+ WS->setCond(mNewExpr);
+ } else if (matchesStmt(WS->getBody())) {
+ WS->setBody(mNewStmt);
+ } else {
+ VisitStmt(WS);
+ }
+ return;
+}
+
+} // namespace slang
diff --git a/slang_rs_ast_replace.h b/slang_rs_ast_replace.h
new file mode 100644
index 0000000..78e094a
--- /dev/null
+++ b/slang_rs_ast_replace.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2011, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_AST_REPLACE_H_ // NOLINT
+#define _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_AST_REPLACE_H_
+
+#include "clang/AST/StmtVisitor.h"
+
+#include "slang_assert.h"
+#include "clang/AST/ASTContext.h"
+
+namespace clang {
+ class Diagnostic;
+ class Expr;
+ class Stmt;
+}
+
+namespace slang {
+
+class RSASTReplace : public clang::StmtVisitor<RSASTReplace> {
+ private:
+
+ clang::ASTContext &C;
+ clang::Stmt *mOuterStmt;
+ clang::Stmt *mOldStmt;
+ clang::Stmt *mNewStmt;
+ clang::Expr *mOldExpr;
+ clang::Expr *mNewExpr;
+
+ inline bool matchesExpr(const clang::Expr *E) const {
+ bool retVal = mOldExpr && (mOldExpr == E);
+ if (retVal) {
+ slangAssert(mNewExpr &&
+ "Cannot replace an expression if we don't have a new expression");
+ }
+ return retVal;
+ }
+
+ inline bool matchesStmt(const clang::Stmt *S) const {
+ slangAssert(mOldStmt);
+ return mOldStmt == S;
+ }
+
+ void ReplaceInCompoundStmt(clang::CompoundStmt *CS);
+
+ public:
+ explicit RSASTReplace(clang::ASTContext &Con)
+ : C(Con),
+ mOuterStmt(NULL),
+ mOldStmt(NULL),
+ mNewStmt(NULL),
+ mOldExpr(NULL),
+ mNewExpr(NULL) {
+ return;
+ }
+
+ void VisitStmt(clang::Stmt *S);
+ void VisitCompoundStmt(clang::CompoundStmt *CS);
+ void VisitCaseStmt(clang::CaseStmt *CS);
+ void VisitDefaultStmt(clang::DefaultStmt *DS);
+ void VisitDoStmt(clang::DoStmt *DS);
+ void VisitForStmt(clang::ForStmt *FS);
+ void VisitIfStmt(clang::IfStmt *IS);
+ void VisitSwitchCase(clang::SwitchCase *SC);
+ void VisitSwitchStmt(clang::SwitchStmt *SS);
+ void VisitWhileStmt(clang::WhileStmt *WS);
+
+ // Replace all instances of OldStmt in OuterStmt with NewStmt.
+ void ReplaceStmt(
+ clang::Stmt *OuterStmt,
+ clang::Stmt *OldStmt,
+ clang::Stmt *NewStmt);
+};
+
+} // namespace slang
+
+#endif // _FRAMEWORKS_COMPILE_SLANG_SLANG_RS_AST_REPLACE_H_ NOLINT
diff --git a/slang_rs_object_ref_count.cpp b/slang_rs_object_ref_count.cpp
index 8788a54..1cda516 100644
--- a/slang_rs_object_ref_count.cpp
+++ b/slang_rs_object_ref_count.cpp
@@ -27,6 +27,7 @@
#include "slang_assert.h"
#include "slang_rs.h"
+#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
namespace slang {
@@ -89,53 +90,78 @@
namespace {
-static void AppendToCompoundStatement(clang::ASTContext& C,
- clang::CompoundStmt *CS,
- std::list<clang::Stmt*> &StmtList,
- bool InsertAtEndOfBlock) {
- // Destructor code will be inserted before any return statement.
- // Any subsequent statements in the compound statement are then placed
- // after our new code.
- // TODO(srhines): This should also handle the case of goto/break/continue.
-
- clang::CompoundStmt::body_iterator bI = CS->body_begin();
-
- unsigned OldStmtCount = 0;
- for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
- OldStmtCount++;
- }
-
+// This function constructs a new CompoundStmt from the input StmtList.
+static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
+ std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
unsigned NewStmtCount = StmtList.size();
+ unsigned CompoundStmtCount = 0;
- clang::Stmt **UpdatedStmtList;
- UpdatedStmtList = new clang::Stmt*[OldStmtCount+NewStmtCount];
+ clang::Stmt **CompoundStmtList;
+ CompoundStmtList = new clang::Stmt*[NewStmtCount];
+
+ std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
+ std::list<clang::Stmt*>::const_iterator E = StmtList.end();
+ for ( ; I != E; I++) {
+ CompoundStmtList[CompoundStmtCount++] = *I;
+ }
+ slangAssert(CompoundStmtCount == NewStmtCount);
+
+ clang::CompoundStmt *CS = new(C) clang::CompoundStmt(C,
+ CompoundStmtList,
+ CompoundStmtCount,
+ Loc,
+ Loc);
+
+ delete [] CompoundStmtList;
+
+ return CS;
+}
+
+static void AppendAfterStmt(clang::ASTContext &C,
+ clang::CompoundStmt *CS,
+ clang::Stmt *S,
+ std::list<clang::Stmt*> &StmtList) {
+ slangAssert(CS);
+ clang::CompoundStmt::body_iterator bI = CS->body_begin();
+ clang::CompoundStmt::body_iterator bE = CS->body_end();
+ clang::Stmt **UpdatedStmtList =
+ new clang::Stmt*[CS->size() + StmtList.size()];
unsigned UpdatedStmtCount = 0;
- bool FoundReturn = false;
- for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
- if ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass) {
- FoundReturn = true;
- break;
- }
- UpdatedStmtList[UpdatedStmtCount++] = *bI;
- }
+ unsigned Once = 0;
+ for ( ; bI != bE; bI++) {
+ if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
+ // If we come across a return here, we don't have anything we can
+ // reasonably replace. We should have already inserted our destructor
+ // code in the proper spot, so we just clean up and return.
+ delete [] UpdatedStmtList;
- // Always insert before a return that we found, or if we are told
- // to insert at the end of the block
- if (FoundReturn || InsertAtEndOfBlock) {
+ return;
+ }
+
+ UpdatedStmtList[UpdatedStmtCount++] = *bI;
+
+ if ((*bI == S) && !Once) {
+ Once++;
+ std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
+ std::list<clang::Stmt*>::const_iterator E = StmtList.end();
+ for ( ; I != E; I++) {
+ UpdatedStmtList[UpdatedStmtCount++] = *I;
+ }
+ }
+ }
+ slangAssert(Once <= 1);
+
+ // When S is NULL, we are appending to the end of the CompoundStmt.
+ if (!S) {
+ slangAssert(Once == 0);
std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
- for (std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
- I != StmtList.end();
- I++) {
+ std::list<clang::Stmt*>::const_iterator E = StmtList.end();
+ for ( ; I != E; I++) {
UpdatedStmtList[UpdatedStmtCount++] = *I;
}
}
- // Pick up anything left over after a return statement
- for ( ; bI != CS->body_end(); bI++) {
- UpdatedStmtList[UpdatedStmtCount++] = *bI;
- }
-
CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
delete [] UpdatedStmtList;
@@ -143,99 +169,95 @@
return;
}
-static void AppendAfterStmt(clang::ASTContext& C,
- clang::CompoundStmt *CS,
- clang::Stmt *OldStmt,
- clang::Stmt *NewStmt) {
- slangAssert(CS && OldStmt && NewStmt);
- clang::CompoundStmt::body_iterator bI = CS->body_begin();
- unsigned StmtCount = 1; // Take into account new statement
- for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
- StmtCount++;
- }
-
- clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
-
- unsigned UpdatedStmtCount = 0;
- unsigned Once = 0;
- for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
- UpdatedStmtList[UpdatedStmtCount++] = *bI;
- if (*bI == OldStmt) {
- Once++;
- slangAssert(Once == 1);
- UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
- }
- }
-
- CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
-
- delete [] UpdatedStmtList;
-
- return;
-}
-
-static void ReplaceInCompoundStmt(clang::ASTContext& C,
- clang::CompoundStmt *CS,
- clang::Stmt* OldStmt,
- clang::Stmt* NewStmt) {
- clang::CompoundStmt::body_iterator bI = CS->body_begin();
-
- unsigned StmtCount = 0;
- for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
- StmtCount++;
- }
-
- clang::Stmt **UpdatedStmtList = new clang::Stmt*[StmtCount];
-
- unsigned UpdatedStmtCount = 0;
- for (bI = CS->body_begin(); bI != CS->body_end(); bI++) {
- if (*bI == OldStmt) {
- UpdatedStmtList[UpdatedStmtCount++] = NewStmt;
- } else {
- UpdatedStmtList[UpdatedStmtCount++] = *bI;
- }
- }
-
- CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
-
- delete [] UpdatedStmtList;
-
- return;
-}
-
-
// This class visits a compound statement and inserts the StmtList containing
// destructors in proper locations. This includes inserting them before any
// return statement in any sub-block, at the end of the logical enclosing
// scope (compound statement), and/or before any break/continue statement that
// would resume outside the declared scope. We will not handle the case for
// goto statements that leave a local scope.
-// TODO(srhines): Make this work properly for break/continue.
+//
+// To accomplish these goals, it collects a list of sub-Stmt's that
+// correspond to scope exit points. It then uses an RSASTReplace visitor to
+// transform the AST, inserting appropriate destructors before each of those
+// sub-Stmt's (and also before the exit of the outermost containing Stmt for
+// the scope).
class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
private:
clang::ASTContext &mC;
+
+ // The loop depth of the currently visited node.
+ int mLoopDepth;
+
+ // The switch statement depth of the currently visited node.
+ // Note that this is tracked separately from the loop depth because
+ // SwitchStmt-contained ContinueStmt's should have destructors for the
+ // corresponding loop scope.
+ int mSwitchDepth;
+
+ // The outermost statement block that we are currently visiting.
+ // This should always be a CompoundStmt.
+ clang::Stmt *mOuterStmt;
+
+ // The list of destructors to execute for this scope.
std::list<clang::Stmt*> &mStmtList;
- bool mTopLevel;
+
+ // The stack of statements which should be replaced by a compound statement
+ // containing the new destructor calls followed by the original Stmt.
+ std::stack<clang::Stmt*> mReplaceStmtStack;
+
public:
- DestructorVisitor(clang::ASTContext &C, std::list<clang::Stmt*> &StmtList);
+ DestructorVisitor(clang::ASTContext &C,
+ clang::Stmt* OuterStmt,
+ std::list<clang::Stmt*> &StmtList);
+
+ // This code walks the collected list of Stmts to replace and actually does
+ // the replacement. It also finishes up by appending appropriate destructors
+ // to the current outermost CompoundStmt.
+ void InsertDestructors() {
+ clang::Stmt *S = NULL;
+ while (!mReplaceStmtStack.empty()) {
+ S = mReplaceStmtStack.top();
+ mReplaceStmtStack.pop();
+
+ mStmtList.push_back(S);
+ clang::CompoundStmt *CS =
+ BuildCompoundStmt(mC, mStmtList, S->getLocEnd());
+ mStmtList.pop_back();
+
+ RSASTReplace R(mC);
+ R.ReplaceStmt(mOuterStmt, S, CS);
+ }
+ clang::CompoundStmt *CS = dyn_cast<clang::CompoundStmt>(mOuterStmt);
+ slangAssert(CS);
+ if (CS) {
+ AppendAfterStmt(mC, CS, NULL, mStmtList);
+ }
+ }
+
void VisitStmt(clang::Stmt *S);
void VisitCompoundStmt(clang::CompoundStmt *CS);
+
+ void VisitBreakStmt(clang::BreakStmt *BS);
+ void VisitCaseStmt(clang::CaseStmt *CS);
+ void VisitContinueStmt(clang::ContinueStmt *CS);
+ void VisitDefaultStmt(clang::DefaultStmt *DS);
+ void VisitDoStmt(clang::DoStmt *DS);
+ void VisitForStmt(clang::ForStmt *FS);
+ void VisitIfStmt(clang::IfStmt *IS);
+ void VisitReturnStmt(clang::ReturnStmt *RS);
+ void VisitSwitchCase(clang::SwitchCase *SC);
+ void VisitSwitchStmt(clang::SwitchStmt *SS);
+ void VisitWhileStmt(clang::WhileStmt *WS);
};
DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
- std::list<clang::Stmt*> &StmtList)
+ clang::Stmt *OuterStmt,
+ std::list<clang::Stmt*> &StmtList)
: mC(C),
- mStmtList(StmtList),
- mTopLevel(true) {
- return;
-}
-
-void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
- if (!CS->body_empty()) {
- AppendToCompoundStatement(mC, CS, mStmtList, mTopLevel);
- mTopLevel = false;
- VisitStmt(CS);
- }
+ mLoopDepth(0),
+ mSwitchDepth(0),
+ mOuterStmt(OuterStmt),
+ mStmtList(StmtList) {
return;
}
@@ -250,6 +272,82 @@
return;
}
+void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
+ VisitStmt(CS);
+ return;
+}
+
+void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
+ VisitStmt(BS);
+ if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
+ mReplaceStmtStack.push(BS);
+ }
+ return;
+}
+
+void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
+ VisitStmt(CS);
+ return;
+}
+
+void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
+ VisitStmt(CS);
+ if (mLoopDepth == 0) {
+ // Switch statements can have nested continues.
+ mReplaceStmtStack.push(CS);
+ }
+ return;
+}
+
+void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
+ VisitStmt(DS);
+ return;
+}
+
+void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
+ mLoopDepth++;
+ VisitStmt(DS);
+ mLoopDepth--;
+ return;
+}
+
+void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
+ mLoopDepth++;
+ VisitStmt(FS);
+ mLoopDepth--;
+ return;
+}
+
+void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
+ VisitStmt(IS);
+ return;
+}
+
+void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
+ mReplaceStmtStack.push(RS);
+ return;
+}
+
+void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
+ slangAssert(false && "Both case and default have specialized handlers");
+ VisitStmt(SC);
+ return;
+}
+
+void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
+ mSwitchDepth++;
+ VisitStmt(SS);
+ mSwitchDepth--;
+ return;
+}
+
+void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
+ mLoopDepth++;
+ VisitStmt(WS);
+ mLoopDepth--;
+ return;
+}
+
clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
clang::Expr *RefRSVar,
clang::SourceLocation Loc) {
@@ -974,7 +1072,8 @@
CreateSingleRSSetObject(C, Diags, AS->getLHS(), AS->getRHS(), Loc);
}
- ReplaceInCompoundStmt(C, mCS, AS, UpdatedStmt);
+ RSASTReplace R(C);
+ R.ReplaceStmt(mCS, AS, UpdatedStmt);
return;
}
@@ -996,7 +1095,6 @@
RSExportPrimitiveType::DataTypeRSFont)->getLocation();
if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
- // TODO(srhines): Skip struct initialization right now
const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
clang::DeclRefExpr *RefRSVar =
clang::DeclRefExpr::Create(C,
@@ -1010,7 +1108,9 @@
clang::Stmt *RSSetObjectOps =
CreateStructRSSetObject(C, Diags, RefRSVar, InitExpr, Loc);
- AppendAfterStmt(C, mCS, DS, RSSetObjectOps);
+ std::list<clang::Stmt*> StmtList;
+ StmtList.push_back(RSSetObjectOps);
+ AppendAfterStmt(C, mCS, DS, StmtList);
return;
}
@@ -1068,7 +1168,9 @@
clang::VK_RValue,
Loc);
- AppendAfterStmt(C, mCS, DS, RSSetObjectCall);
+ std::list<clang::Stmt*> StmtList;
+ StmtList.push_back(RSSetObjectCall);
+ AppendAfterStmt(C, mCS, DS, StmtList);
return;
}
@@ -1085,8 +1187,11 @@
}
}
if (RSClearObjectCalls.size() > 0) {
- DestructorVisitor DV((*mRSO.begin())->getASTContext(), RSClearObjectCalls);
+ DestructorVisitor DV((*mRSO.begin())->getASTContext(),
+ mCS,
+ RSClearObjectCalls);
DV.Visit(mCS);
+ DV.InsertDestructors();
}
return;
}
diff --git a/slang_rs_object_ref_count.h b/slang_rs_object_ref_count.h
index 33a5a89..78984a7 100644
--- a/slang_rs_object_ref_count.h
+++ b/slang_rs_object_ref_count.h
@@ -32,6 +32,16 @@
namespace slang {
+// This class provides the overall reference counting mechanism for handling
+// local variables of RS object types (rs_font, rs_allocation, ...). This
+// class ensures that appropriate functions (rsSetObject, rsClearObject) are
+// called at proper points in the object's lifetime.
+// 1) Each local object of appropriate type must be zero-initialized (to
+// prevent corruption) during subsequent rsSetObject()/rsClearObject() calls.
+// 2) Assignments using these types must also be converted into the
+// appropriate (possibly a series of) rsSetObject() calls.
+// 3) Finally, each local object must call rsClearObject() when it goes out
+// of scope.
class RSObjectRefCount : public clang::StmtVisitor<RSObjectRefCount> {
private:
class Scope {
@@ -79,9 +89,6 @@
// Initialize RSSetObjectFD and RSClearObjectFD.
static void GetRSRefCountingFunctions(clang::ASTContext &C);
- // TODO(srhines): Composite types and arrays based on RS object types need
- // to be handled for both zero-initialization + clearing.
-
// Return false if the type of variable declared in VD does not contain
// an RS object type.
static bool InitializeRSObject(clang::VarDecl *VD,