[arcmt] Allow removing an -autorelease of a variable initialized in the previous statement.

rdar://11074996

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@171485 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/ARCMigrate/TransRetainReleaseDealloc.cpp b/lib/ARCMigrate/TransRetainReleaseDealloc.cpp
index 97d9bff..0c8d155 100644
--- a/lib/ARCMigrate/TransRetainReleaseDealloc.cpp
+++ b/lib/ARCMigrate/TransRetainReleaseDealloc.cpp
@@ -174,7 +174,7 @@
   ///   return var;
   ///
   bool isCommonUnusedAutorelease(ObjCMessageExpr *E) {
-    if (isPlusOneAssignAfterAutorelease(E))
+    if (isPlusOneAssignBeforeOrAfterAutorelease(E))
       return true;
     if (isReturnedAfterAutorelease(E))
       return true;
@@ -202,7 +202,7 @@
     return false;
   }
 
-  bool isPlusOneAssignAfterAutorelease(ObjCMessageExpr *E) {
+  bool isPlusOneAssignBeforeOrAfterAutorelease(ObjCMessageExpr *E) {
     Expr *Rec = E->getInstanceReceiver();
     if (!Rec)
       return false;
@@ -211,24 +211,46 @@
     if (!RefD)
       return false;
 
-    Stmt *nextStmt = getNextStmt(E);
-    if (!nextStmt)
+    Stmt *prevStmt, *nextStmt;
+    llvm::tie(prevStmt, nextStmt) = getPreviousAndNextStmt(E);
+
+    return isPlusOneAssignToVar(prevStmt, RefD) ||
+           isPlusOneAssignToVar(nextStmt, RefD);
+  }
+
+  bool isPlusOneAssignToVar(Stmt *S, Decl *RefD) {
+    if (!S)
       return false;
 
     // Check for "RefD = [+1 retained object];".
 
-    if (BinaryOperator *Bop = dyn_cast<BinaryOperator>(nextStmt)) {
+    if (BinaryOperator *Bop = dyn_cast<BinaryOperator>(S)) {
       if (RefD != getReferencedDecl(Bop->getLHS()))
         return false;
       if (isPlusOneAssign(Bop))
         return true;
+      return false;
     }
+
+    if (DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
+      if (DS->isSingleDecl() && DS->getSingleDecl() == RefD) {
+        if (VarDecl *VD = dyn_cast<VarDecl>(RefD))
+          return isPlusOne(VD->getInit());
+      }
+      return false;
+    }
+
     return false;
   }
 
   Stmt *getNextStmt(Expr *E) {
+    return getPreviousAndNextStmt(E).second;
+  }
+
+  std::pair<Stmt *, Stmt *> getPreviousAndNextStmt(Expr *E) {
+    Stmt *prevStmt = 0, *nextStmt = 0;
     if (!E)
-      return 0;
+      return std::make_pair(prevStmt, nextStmt);
 
     Stmt *OuterS = E, *InnerS;
     do {
@@ -240,24 +262,34 @@
                       isa<ExprWithCleanups>(OuterS)));
     
     if (!OuterS)
-      return 0;
+      return std::make_pair(prevStmt, nextStmt);
 
     Stmt::child_iterator currChildS = OuterS->child_begin();
     Stmt::child_iterator childE = OuterS->child_end();
+    Stmt::child_iterator prevChildS = childE;
     for (; currChildS != childE; ++currChildS) {
       if (*currChildS == InnerS)
         break;
+      prevChildS = currChildS;
     }
+
+    if (prevChildS != childE) {
+      prevStmt = *prevChildS;
+      if (prevStmt)
+        prevStmt = prevStmt->IgnoreImplicit();
+    }
+
     if (currChildS == childE)
-      return 0;
+      return std::make_pair(prevStmt, nextStmt);
     ++currChildS;
     if (currChildS == childE)
-      return 0;
+      return std::make_pair(prevStmt, nextStmt);
 
-    Stmt *nextStmt = *currChildS;
-    if (!nextStmt)
-      return 0;
-    return nextStmt->IgnoreImplicit();
+    nextStmt = *currChildS;
+    if (nextStmt)
+      nextStmt = nextStmt->IgnoreImplicit();
+
+    return std::make_pair(prevStmt, nextStmt);
   }
 
   Decl *getReferencedDecl(Expr *E) {
diff --git a/lib/ARCMigrate/Transforms.cpp b/lib/ARCMigrate/Transforms.cpp
index 938f015..136f618 100644
--- a/lib/ARCMigrate/Transforms.cpp
+++ b/lib/ARCMigrate/Transforms.cpp
@@ -71,13 +71,22 @@
   if (E->getOpcode() != BO_Assign)
     return false;
 
+  return isPlusOne(E->getRHS());
+}
+
+bool trans::isPlusOne(const Expr *E) {
+  if (!E)
+    return false;
+  if (const ExprWithCleanups *EWC = dyn_cast<ExprWithCleanups>(E))
+    E = EWC->getSubExpr();
+
   if (const ObjCMessageExpr *
-        ME = dyn_cast<ObjCMessageExpr>(E->getRHS()->IgnoreParenCasts()))
+        ME = dyn_cast<ObjCMessageExpr>(E->IgnoreParenCasts()))
     if (ME->getMethodFamily() == OMF_retain)
       return true;
 
   if (const CallExpr *
-        callE = dyn_cast<CallExpr>(E->getRHS()->IgnoreParenCasts())) {
+        callE = dyn_cast<CallExpr>(E->IgnoreParenCasts())) {
     if (const FunctionDecl *FD = callE->getDirectCallee()) {
       if (FD->getAttr<CFReturnsRetainedAttr>())
         return true;
@@ -98,7 +107,7 @@
     }
   }
 
-  const ImplicitCastExpr *implCE = dyn_cast<ImplicitCastExpr>(E->getRHS());
+  const ImplicitCastExpr *implCE = dyn_cast<ImplicitCastExpr>(E);
   while (implCE && implCE->getCastKind() ==  CK_BitCast)
     implCE = dyn_cast<ImplicitCastExpr>(implCE->getSubExpr());
 
diff --git a/lib/ARCMigrate/Transforms.h b/lib/ARCMigrate/Transforms.h
index 69e91e7..cb7d153 100644
--- a/lib/ARCMigrate/Transforms.h
+++ b/lib/ARCMigrate/Transforms.h
@@ -161,6 +161,7 @@
                   bool AllowOnUnknownClass = false);
 
 bool isPlusOneAssign(const BinaryOperator *E);
+bool isPlusOne(const Expr *E);
 
 /// \brief 'Loc' is the end of a statement range. This returns the location
 /// immediately after the semicolon following the statement.
diff --git a/test/ARCMT/autoreleases.m b/test/ARCMT/autoreleases.m
index d491394..543bcf6 100644
--- a/test/ARCMT/autoreleases.m
+++ b/test/ARCMT/autoreleases.m
@@ -69,3 +69,8 @@
   [[val retain] autorelease];
   return val;
 }
+
+id test3() {
+  id a = [[A alloc] init];
+  [a autorelease];
+}
diff --git a/test/ARCMT/autoreleases.m.result b/test/ARCMT/autoreleases.m.result
index 76ce8cf..9b71ff8 100644
--- a/test/ARCMT/autoreleases.m.result
+++ b/test/ARCMT/autoreleases.m.result
@@ -64,3 +64,7 @@
 id test2(A* val) {
   return val;
 }
+
+id test3() {
+  id a = [[A alloc] init];
+}