PR40642: Fix determination of whether the final statement of a statement
expression is a discarded-value expression.

Summary:
We used to get this wrong in three ways:

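1) During parsing, an expression-statement followed by the }) ending a
   statement expression was always treated as producing the value of the
   statement expression. That's wrong for ({ if (1) expr; }), where expr
   is a substatement of the if statement rather than the final statement
   of the statement expression.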
2) During template instantiation, various kinds of statement (most
   statements not appearing directly in a compound-statement) were not
   treated as discarded-value expressions, resulting in missing volatile
   loads (etc).
3) In all contexts, an expression-statement with attributes was not
   treated as producing the value of the statement expression, e.g.
   ({ [[attr]] expr; }).

Also fix incorrect enforcement of the OpenMP rule that directives can
"only be placed in the program at a position where ignoring or deleting
the directive would result in a program with correct syntax". In
particular, a label (be it goto, case, or default) should not affect
whether directives are permitted.

Reviewers: aaron.ballman, rjmccall

Subscribers: cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D57984

llvm-svn: 354090
diff --git a/clang/lib/AST/Stmt.cpp b/clang/lib/AST/Stmt.cpp
index d1a4dc0..f2bb289 100644
--- a/clang/lib/AST/Stmt.cpp
+++ b/clang/lib/AST/Stmt.cpp
@@ -320,6 +320,23 @@
   return New;
 }
 
+const Expr *ValueStmt::getExprStmt() const {
+  const Stmt *S = this;
+  do {
+    if (const auto *E = dyn_cast<Expr>(S))
+      return E;
+
+    if (const auto *LS = dyn_cast<LabelStmt>(S))
+      S = LS->getSubStmt();
+    else if (const auto *AS = dyn_cast<AttributedStmt>(S))
+      S = AS->getSubStmt();
+    else
+      llvm_unreachable("unknown kind of ValueStmt");
+  } while (isa<ValueStmt>(S));
+
+  return nullptr;
+}
+
 const char *LabelStmt::getName() const {
   return getDecl()->getIdentifier()->getNameStart();
 }
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index db3f4f1..5b5113b 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -391,26 +391,35 @@
     // at the end of a statement expression, they yield the value of their
     // subexpression.  Handle this by walking through all labels we encounter,
     // emitting them before we evaluate the subexpr.
+    // Similar issues arise for attributed statements.
     const Stmt *LastStmt = S.body_back();
-    while (const LabelStmt *LS = dyn_cast<LabelStmt>(LastStmt)) {
-      EmitLabel(LS->getDecl());
-      LastStmt = LS->getSubStmt();
+    while (!isa<Expr>(LastStmt)) {
+      if (const auto *LS = dyn_cast<LabelStmt>(LastStmt)) {
+        EmitLabel(LS->getDecl());
+        LastStmt = LS->getSubStmt();
+      } else if (const auto *AS = dyn_cast<AttributedStmt>(LastStmt)) {
+        // FIXME: Update this if we ever have attributes that affect the
+        // semantics of an expression.
+        LastStmt = AS->getSubStmt();
+      } else {
+        llvm_unreachable("unknown value statement");
+      }
     }
 
     EnsureInsertPoint();
 
-    QualType ExprTy = cast<Expr>(LastStmt)->getType();
+    const Expr *E = cast<Expr>(LastStmt);
+    QualType ExprTy = E->getType();
     if (hasAggregateEvaluationKind(ExprTy)) {
-      EmitAggExpr(cast<Expr>(LastStmt), AggSlot);
+      EmitAggExpr(E, AggSlot);
     } else {
       // We can't return an RValue here because there might be cleanups at
       // the end of the StmtExpr.  Because of that, we have to emit the result
       // here into a temporary alloca.
       RetAlloca = CreateMemTemp(ExprTy);
-      EmitAnyExprToMem(cast<Expr>(LastStmt), RetAlloca, Qualifiers(),
-                       /*IsInit*/false);
+      EmitAnyExprToMem(E, RetAlloca, Qualifiers(),
+                       /*IsInit*/ false);
     }
-
   }
 
   return RetAlloca;
diff --git a/clang/lib/Parse/ParseObjc.cpp b/clang/lib/Parse/ParseObjc.cpp
index 0a1dbf7..5ac01f6 100644
--- a/clang/lib/Parse/ParseObjc.cpp
+++ b/clang/lib/Parse/ParseObjc.cpp
@@ -2703,7 +2703,8 @@
   return MDecl;
 }
 
-StmtResult Parser::ParseObjCAtStatement(SourceLocation AtLoc) {
+StmtResult Parser::ParseObjCAtStatement(SourceLocation AtLoc,
+                                        ParsedStmtContext StmtCtx) {
   if (Tok.is(tok::code_completion)) {
     Actions.CodeCompleteObjCAtStatement(getCurScope());
     cutOffParsing();
@@ -2740,7 +2741,7 @@
 
   // Otherwise, eat the semicolon.
   ExpectAndConsumeSemi(diag::err_expected_semi_after_expr);
-  return Actions.ActOnExprStmt(Res, isExprValueDiscarded());
+  return handleExprStmt(Res, StmtCtx);
 }
 
 ExprResult Parser::ParseObjCAtExpression(SourceLocation AtLoc) {
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 5b8670e..d046341 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -1132,8 +1132,8 @@
 ///         'target teams distribute simd' {clause}
 ///         annot_pragma_openmp_end
 ///
-StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(
-    AllowedConstructsKind Allowed) {
+StmtResult
+Parser::ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx) {
   assert(Tok.is(tok::annot_pragma_openmp) && "Not an OpenMP directive!");
   ParenBraceBracketBalancer BalancerRAIIObj(*this);
   SmallVector<OMPClause *, 5> Clauses;
@@ -1152,7 +1152,9 @@
 
   switch (DKind) {
   case OMPD_threadprivate: {
-    if (Allowed != ACK_Any) {
+    // FIXME: Should this be permitted in C++?
+    if ((StmtCtx & ParsedStmtContext::AllowDeclarationsInC) ==
+        ParsedStmtContext()) {
       Diag(Tok, diag::err_omp_immediate_directive)
           << getOpenMPDirectiveName(DKind) << 0;
     }
@@ -1219,7 +1221,8 @@
   case OMPD_target_enter_data:
   case OMPD_target_exit_data:
   case OMPD_target_update:
-    if (Allowed == ACK_StatementsOpenMPNonStandalone) {
+    if ((StmtCtx & ParsedStmtContext::AllowStandaloneOpenMPDirectives) ==
+        ParsedStmtContext()) {
       Diag(Tok, diag::err_omp_immediate_directive)
           << getOpenMPDirectiveName(DKind) << 0;
     }
@@ -1322,7 +1325,8 @@
     // If the depend clause is specified, the ordered construct is a stand-alone
     // directive.
     if (DKind == OMPD_ordered && FirstClauses[OMPC_depend].getInt()) {
-      if (Allowed == ACK_StatementsOpenMPNonStandalone) {
+      if ((StmtCtx & ParsedStmtContext::AllowStandaloneOpenMPDirectives) ==
+          ParsedStmtContext()) {
         Diag(Loc, diag::err_omp_immediate_directive)
             << getOpenMPDirectiveName(DKind) << 1
             << getOpenMPClauseName(OMPC_depend);
diff --git a/clang/lib/Parse/ParseStmt.cpp b/clang/lib/Parse/ParseStmt.cpp
index 3fc2925..267003b 100644
--- a/clang/lib/Parse/ParseStmt.cpp
+++ b/clang/lib/Parse/ParseStmt.cpp
@@ -29,17 +29,14 @@
 /// Parse a standalone statement (for instance, as the body of an 'if',
 /// 'while', or 'for').
 StmtResult Parser::ParseStatement(SourceLocation *TrailingElseLoc,
-                                  bool AllowOpenMPStandalone) {
+                                  ParsedStmtContext StmtCtx) {
   StmtResult Res;
 
   // We may get back a null statement if we found a #pragma. Keep going until
   // we get an actual statement.
   do {
     StmtVector Stmts;
-    Res = ParseStatementOrDeclaration(
-        Stmts, AllowOpenMPStandalone ? ACK_StatementsOpenMPAnyExecutable
-                                     : ACK_StatementsOpenMPNonStandalone,
-        TrailingElseLoc);
+    Res = ParseStatementOrDeclaration(Stmts, StmtCtx, TrailingElseLoc);
   } while (!Res.isInvalid() && !Res.get());
 
   return Res;
@@ -96,7 +93,7 @@
 ///
 StmtResult
 Parser::ParseStatementOrDeclaration(StmtVector &Stmts,
-                                    AllowedConstructsKind Allowed,
+                                    ParsedStmtContext StmtCtx,
                                     SourceLocation *TrailingElseLoc) {
 
   ParenBraceBracketBalancer BalancerRAIIObj(*this);
@@ -107,7 +104,7 @@
     return StmtError();
 
   StmtResult Res = ParseStatementOrDeclarationAfterAttributes(
-      Stmts, Allowed, TrailingElseLoc, Attrs);
+      Stmts, StmtCtx, TrailingElseLoc, Attrs);
 
   assert((Attrs.empty() || Res.isInvalid() || Res.isUsable()) &&
          "attributes on empty statement");
@@ -147,10 +144,9 @@
 };
 }
 
-StmtResult
-Parser::ParseStatementOrDeclarationAfterAttributes(StmtVector &Stmts,
-          AllowedConstructsKind Allowed, SourceLocation *TrailingElseLoc,
-          ParsedAttributesWithRange &Attrs) {
+StmtResult Parser::ParseStatementOrDeclarationAfterAttributes(
+    StmtVector &Stmts, ParsedStmtContext StmtCtx,
+    SourceLocation *TrailingElseLoc, ParsedAttributesWithRange &Attrs) {
   const char *SemiError = nullptr;
   StmtResult Res;
 
@@ -165,7 +161,7 @@
     {
       ProhibitAttributes(Attrs); // TODO: is it correct?
       AtLoc = ConsumeToken();  // consume @
-      return ParseObjCAtStatement(AtLoc);
+      return ParseObjCAtStatement(AtLoc, StmtCtx);
     }
 
   case tok::code_completion:
@@ -177,7 +173,7 @@
     Token Next = NextToken();
     if (Next.is(tok::colon)) { // C99 6.8.1: labeled-statement
       // identifier ':' statement
-      return ParseLabeledStatement(Attrs);
+      return ParseLabeledStatement(Attrs, StmtCtx);
     }
 
     // Look up the identifier, and typo-correct it to a keyword if it's not
@@ -207,7 +203,8 @@
 
   default: {
     if ((getLangOpts().CPlusPlus || getLangOpts().MicrosoftExt ||
-         Allowed == ACK_Any) &&
+         (StmtCtx & ParsedStmtContext::AllowDeclarationsInC) !=
+             ParsedStmtContext()) &&
         isDeclarationStatement()) {
       SourceLocation DeclStart = Tok.getLocation(), DeclEnd;
       DeclGroupPtrTy Decl = ParseDeclaration(DeclaratorContext::BlockContext,
@@ -220,13 +217,13 @@
       return StmtError();
     }
 
-    return ParseExprStatement();
+    return ParseExprStatement(StmtCtx);
   }
 
   case tok::kw_case:                // C99 6.8.1: labeled-statement
-    return ParseCaseStatement();
+    return ParseCaseStatement(StmtCtx);
   case tok::kw_default:             // C99 6.8.1: labeled-statement
-    return ParseDefaultStatement();
+    return ParseDefaultStatement(StmtCtx);
 
   case tok::l_brace:                // C99 6.8.2: compound-statement
     return ParseCompoundStatement();
@@ -363,7 +360,7 @@
 
   case tok::annot_pragma_openmp:
     ProhibitAttributes(Attrs);
-    return ParseOpenMPDeclarativeOrExecutableDirective(Allowed);
+    return ParseOpenMPDeclarativeOrExecutableDirective(StmtCtx);
 
   case tok::annot_pragma_ms_pointers_to_members:
     ProhibitAttributes(Attrs);
@@ -382,7 +379,7 @@
 
   case tok::annot_pragma_loop_hint:
     ProhibitAttributes(Attrs);
-    return ParsePragmaLoopHint(Stmts, Allowed, TrailingElseLoc, Attrs);
+    return ParsePragmaLoopHint(Stmts, StmtCtx, TrailingElseLoc, Attrs);
 
   case tok::annot_pragma_dump:
     HandlePragmaDump();
@@ -407,7 +404,7 @@
 }
 
 /// Parse an expression statement.
-StmtResult Parser::ParseExprStatement() {
+StmtResult Parser::ParseExprStatement(ParsedStmtContext StmtCtx) {
   // If a case keyword is missing, this is where it should be inserted.
   Token OldToken = Tok;
 
@@ -433,12 +430,12 @@
       << FixItHint::CreateInsertion(OldToken.getLocation(), "case ");
 
     // Recover parsing as a case statement.
-    return ParseCaseStatement(/*MissingCase=*/true, Expr);
+    return ParseCaseStatement(StmtCtx, /*MissingCase=*/true, Expr);
   }
 
   // Otherwise, eat the semicolon.
   ExpectAndConsumeSemi(diag::err_expected_semi_after_expr);
-  return Actions.ActOnExprStmt(Expr, isExprValueDiscarded());
+  return handleExprStmt(Expr, StmtCtx);
 }
 
 /// ParseSEHTryBlockCommon
@@ -577,10 +574,15 @@
 ///         identifier ':' statement
 /// [GNU]   identifier ':' attributes[opt] statement
 ///
-StmtResult Parser::ParseLabeledStatement(ParsedAttributesWithRange &attrs) {
+StmtResult Parser::ParseLabeledStatement(ParsedAttributesWithRange &attrs,
+                                         ParsedStmtContext StmtCtx) {
   assert(Tok.is(tok::identifier) && Tok.getIdentifierInfo() &&
          "Not an identifier!");
 
+  // The substatement is always a 'statement', not a 'declaration', but is
+  // otherwise in the same context as the labeled-statement.
+  StmtCtx &= ~ParsedStmtContext::AllowDeclarationsInC;
+
   Token IdentTok = Tok;  // Save the whole token.
   ConsumeToken();  // eat the identifier.
 
@@ -610,9 +612,8 @@
       // statement, but that doesn't work correctly (because ProhibitAttributes
       // can't handle GNU attributes), so only call it in the one case where
       // GNU attributes are allowed.
-      SubStmt = ParseStatementOrDeclarationAfterAttributes(
-          Stmts, /*Allowed=*/ACK_StatementsOpenMPNonStandalone, nullptr,
-          TempAttrs);
+      SubStmt = ParseStatementOrDeclarationAfterAttributes(Stmts, StmtCtx,
+                                                           nullptr, TempAttrs);
       if (!TempAttrs.empty() && !SubStmt.isInvalid())
         SubStmt = Actions.ProcessStmtAttributes(SubStmt.get(), TempAttrs,
                                                 TempAttrs.Range);
@@ -623,7 +624,7 @@
 
   // If we've not parsed a statement yet, parse one now.
   if (!SubStmt.isInvalid() && !SubStmt.isUsable())
-    SubStmt = ParseStatement();
+    SubStmt = ParseStatement(nullptr, StmtCtx);
 
   // Broken substmt shouldn't prevent the label from being added to the AST.
   if (SubStmt.isInvalid())
@@ -643,9 +644,14 @@
 ///         'case' constant-expression ':' statement
 /// [GNU]   'case' constant-expression '...' constant-expression ':' statement
 ///
-StmtResult Parser::ParseCaseStatement(bool MissingCase, ExprResult Expr) {
+StmtResult Parser::ParseCaseStatement(ParsedStmtContext StmtCtx,
+                                      bool MissingCase, ExprResult Expr) {
   assert((MissingCase || Tok.is(tok::kw_case)) && "Not a case stmt!");
 
+  // The substatement is always a 'statement', not a 'declaration', but is
+  // otherwise in the same context as the labeled-statement.
+  StmtCtx &= ~ParsedStmtContext::AllowDeclarationsInC;
+
   // It is very very common for code to contain many case statements recursively
   // nested, as in (but usually without indentation):
   //  case 1:
@@ -737,8 +743,7 @@
     // continue parsing the sub-stmt.
     if (Case.isInvalid()) {
       if (TopLevelCase.isInvalid())  // No parsed case stmts.
-        return ParseStatement(/*TrailingElseLoc=*/nullptr,
-                              /*AllowOpenMPStandalone=*/true);
+        return ParseStatement(/*TrailingElseLoc=*/nullptr, StmtCtx);
       // Otherwise, just don't add it as a nested case.
     } else {
       // If this is the first case statement we parsed, it becomes TopLevelCase.
@@ -758,8 +763,7 @@
   StmtResult SubStmt;
 
   if (Tok.isNot(tok::r_brace)) {
-    SubStmt = ParseStatement(/*TrailingElseLoc=*/nullptr,
-                             /*AllowOpenMPStandalone=*/true);
+    SubStmt = ParseStatement(/*TrailingElseLoc=*/nullptr, StmtCtx);
   } else {
     // Nicely diagnose the common error "switch (X) { case 4: }", which is
     // not valid.  If ColonLoc doesn't point to a valid text location, there was
@@ -789,8 +793,13 @@
 ///         'default' ':' statement
 /// Note that this does not parse the 'statement' at the end.
 ///
-StmtResult Parser::ParseDefaultStatement() {
+StmtResult Parser::ParseDefaultStatement(ParsedStmtContext StmtCtx) {
   assert(Tok.is(tok::kw_default) && "Not a default stmt!");
+
+  // The substatement is always a 'statement', not a 'declaration', but is
+  // otherwise in the same context as the labeled-statement.
+  StmtCtx &= ~ParsedStmtContext::AllowDeclarationsInC;
+
   SourceLocation DefaultLoc = ConsumeToken();  // eat the 'default'.
 
   SourceLocation ColonLoc;
@@ -811,8 +820,7 @@
   StmtResult SubStmt;
 
   if (Tok.isNot(tok::r_brace)) {
-    SubStmt = ParseStatement(/*TrailingElseLoc=*/nullptr,
-                             /*AllowOpenMPStandalone=*/true);
+    SubStmt = ParseStatement(/*TrailingElseLoc=*/nullptr, StmtCtx);
   } else {
     // Diagnose the common error "switch (X) {... default: }", which is
     // not valid.
@@ -943,7 +951,8 @@
     EndLoc = Tok.getLocation();
 
     // Don't just ConsumeToken() this tok::semi, do store it in AST.
-    StmtResult R = ParseStatementOrDeclaration(Stmts, ACK_Any);
+    StmtResult R =
+        ParseStatementOrDeclaration(Stmts, ParsedStmtContext::SubStmt);
     if (R.isUsable())
       Stmts.push_back(R.get());
   }
@@ -957,14 +966,18 @@
   return true;
 }
 
-bool Parser::isExprValueDiscarded() {
-  if (Actions.isCurCompoundStmtAStmtExpr()) {
-    // Look to see if the next two tokens close the statement expression;
+StmtResult Parser::handleExprStmt(ExprResult E, ParsedStmtContext StmtCtx) {
+  bool IsStmtExprResult = false;
+  if ((StmtCtx & ParsedStmtContext::InStmtExpr) != ParsedStmtContext()) {
+    // Look ahead to see if the next two tokens close the statement expression;
     // if so, this expression statement is the last statement in a
     // statment expression.
-    return Tok.isNot(tok::r_brace) || NextToken().isNot(tok::r_paren);
+    IsStmtExprResult = Tok.is(tok::r_brace) && NextToken().is(tok::r_paren);
   }
-  return true;
+
+  if (IsStmtExprResult)
+    E = Actions.ActOnStmtExprResult(E);
+  return Actions.ActOnExprStmt(E, /*DiscardedValue=*/!IsStmtExprResult);
 }
 
 /// ParseCompoundStatementBody - Parse a sequence of statements and invoke the
@@ -1022,6 +1035,10 @@
       Stmts.push_back(R.get());
   }
 
+  ParsedStmtContext SubStmtCtx =
+      ParsedStmtContext::Compound |
+      (isStmtExpr ? ParsedStmtContext::InStmtExpr : ParsedStmtContext());
+
   while (!tryParseMisplacedModuleImport() && Tok.isNot(tok::r_brace) &&
          Tok.isNot(tok::eof)) {
     if (Tok.is(tok::annot_pragma_unused)) {
@@ -1034,7 +1051,7 @@
 
     StmtResult R;
     if (Tok.isNot(tok::kw___extension__)) {
-      R = ParseStatementOrDeclaration(Stmts, ACK_Any);
+      R = ParseStatementOrDeclaration(Stmts, SubStmtCtx);
     } else {
       // __extension__ can start declarations and it can also be a unary
       // operator for expressions.  Consume multiple __extension__ markers here
@@ -1067,11 +1084,12 @@
           continue;
         }
 
-        // FIXME: Use attributes?
         // Eat the semicolon at the end of stmt and convert the expr into a
         // statement.
         ExpectAndConsumeSemi(diag::err_expected_semi_after_expr);
-        R = Actions.ActOnExprStmt(Res, isExprValueDiscarded());
+        R = handleExprStmt(Res, SubStmtCtx);
+        if (R.isUsable())
+          R = Actions.ProcessStmtAttributes(R.get(), attrs, attrs.Range);
       }
     }
 
@@ -2001,7 +2019,7 @@
 }
 
 StmtResult Parser::ParsePragmaLoopHint(StmtVector &Stmts,
-                                       AllowedConstructsKind Allowed,
+                                       ParsedStmtContext StmtCtx,
                                        SourceLocation *TrailingElseLoc,
                                        ParsedAttributesWithRange &Attrs) {
   // Create temporary attribute list.
@@ -2024,7 +2042,7 @@
   MaybeParseCXX11Attributes(Attrs);
 
   StmtResult S = ParseStatementOrDeclarationAfterAttributes(
-      Stmts, Allowed, TrailingElseLoc, Attrs);
+      Stmts, StmtCtx, TrailingElseLoc, Attrs);
 
   Attrs.takeAllFrom(TempAttrs);
   return S;
@@ -2329,7 +2347,8 @@
 
   // Condition is true, parse the statements.
   while (Tok.isNot(tok::r_brace)) {
-    StmtResult R = ParseStatementOrDeclaration(Stmts, ACK_Any);
+    StmtResult R =
+        ParseStatementOrDeclaration(Stmts, ParsedStmtContext::Compound);
     if (R.isUsable())
       Stmts.push_back(R.get());
   }
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index b992f33..00b4cdd 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -13301,29 +13301,6 @@
                                      Context.getPointerType(Context.VoidTy));
 }
 
-/// Given the last statement in a statement-expression, check whether
-/// the result is a producing expression (like a call to an
-/// ns_returns_retained function) and, if so, rebuild it to hoist the
-/// release out of the full-expression.  Otherwise, return null.
-/// Cannot fail.
-static Expr *maybeRebuildARCConsumingStmt(Stmt *Statement) {
-  // Should always be wrapped with one of these.
-  ExprWithCleanups *cleanups = dyn_cast<ExprWithCleanups>(Statement);
-  if (!cleanups) return nullptr;
-
-  ImplicitCastExpr *cast = dyn_cast<ImplicitCastExpr>(cleanups->getSubExpr());
-  if (!cast || cast->getCastKind() != CK_ARCConsumeObject)
-    return nullptr;
-
-  // Splice out the cast.  This shouldn't modify any interesting
-  // features of the statement.
-  Expr *producer = cast->getSubExpr();
-  assert(producer->getType() == cast->getType());
-  assert(producer->getValueKind() == cast->getValueKind());
-  cleanups->setSubExpr(producer);
-  return cleanups;
-}
-
 void Sema::ActOnStartStmtExpr() {
   PushExpressionEvaluationContext(ExprEvalContexts.back().Context);
 }
@@ -13357,47 +13334,10 @@
   QualType Ty = Context.VoidTy;
   bool StmtExprMayBindToTemp = false;
   if (!Compound->body_empty()) {
-    Stmt *LastStmt = Compound->body_back();
-    LabelStmt *LastLabelStmt = nullptr;
-    // If LastStmt is a label, skip down through into the body.
-    while (LabelStmt *Label = dyn_cast<LabelStmt>(LastStmt)) {
-      LastLabelStmt = Label;
-      LastStmt = Label->getSubStmt();
-    }
-
-    if (Expr *LastE = dyn_cast<Expr>(LastStmt)) {
-      // Do function/array conversion on the last expression, but not
-      // lvalue-to-rvalue.  However, initialize an unqualified type.
-      ExprResult LastExpr = DefaultFunctionArrayConversion(LastE);
-      if (LastExpr.isInvalid())
-        return ExprError();
-      Ty = LastExpr.get()->getType().getUnqualifiedType();
-
-      if (!Ty->isDependentType() && !LastExpr.get()->isTypeDependent()) {
-        // In ARC, if the final expression ends in a consume, splice
-        // the consume out and bind it later.  In the alternate case
-        // (when dealing with a retainable type), the result
-        // initialization will create a produce.  In both cases the
-        // result will be +1, and we'll need to balance that out with
-        // a bind.
-        if (Expr *rebuiltLastStmt
-              = maybeRebuildARCConsumingStmt(LastExpr.get())) {
-          LastExpr = rebuiltLastStmt;
-        } else {
-          LastExpr = PerformCopyInitialization(
-              InitializedEntity::InitializeStmtExprResult(LPLoc, Ty),
-              SourceLocation(), LastExpr);
-        }
-
-        if (LastExpr.isInvalid())
-          return ExprError();
-        if (LastExpr.get() != nullptr) {
-          if (!LastLabelStmt)
-            Compound->setLastStmt(LastExpr.get());
-          else
-            LastLabelStmt->setSubStmt(LastExpr.get());
-          StmtExprMayBindToTemp = true;
-        }
+    if (const auto *LastStmt = dyn_cast<ValueStmt>(Compound->body_back())) {
+      if (const Expr *Value = LastStmt->getExprStmt()) {
+        StmtExprMayBindToTemp = true;
+        Ty = Value->getType();
       }
     }
   }
@@ -13410,6 +13350,37 @@
   return ResStmtExpr;
 }
 
+ExprResult Sema::ActOnStmtExprResult(ExprResult ER) {
+  if (ER.isInvalid())
+    return ExprError();
+
+  // Do function/array conversion on the last expression, but not
+  // lvalue-to-rvalue.  However, initialize an unqualified type.
+  ER = DefaultFunctionArrayConversion(ER.get());
+  if (ER.isInvalid())
+    return ExprError();
+  Expr *E = ER.get();
+
+  if (E->isTypeDependent())
+    return E;
+
+  // In ARC, if the final expression ends in a consume, splice
+  // the consume out and bind it later.  In the alternate case
+  // (when dealing with a retainable type), the result
+  // initialization will create a produce.  In both cases the
+  // result will be +1, and we'll need to balance that out with
+  // a bind.
+  auto *Cast = dyn_cast<ImplicitCastExpr>(E);
+  if (Cast && Cast->getCastKind() == CK_ARCConsumeObject)
+    return Cast->getSubExpr();
+
+  // FIXME: Provide a better location for the initialization.
+  return PerformCopyInitialization(
+      InitializedEntity::InitializeStmtExprResult(
+          E->getBeginLoc(), E->getType().getUnqualifiedType()),
+      SourceLocation(), E);
+}
+
 ExprResult Sema::BuildBuiltinOffsetOf(SourceLocation BuiltinLoc,
                                       TypeSourceInfo *TInfo,
                                       ArrayRef<OffsetOfComponent> Components,
@@ -14490,14 +14461,6 @@
     // Make sure we redo semantic analysis
     bool AlwaysRebuild() { return true; }
 
-    // Make sure we handle LabelStmts correctly.
-    // FIXME: This does the right thing, but maybe we need a more general
-    // fix to TreeTransform?
-    StmtResult TransformLabelStmt(LabelStmt *S) {
-      S->getDecl()->setStmt(nullptr);
-      return BaseTransform::TransformLabelStmt(S);
-    }
-
     // We need to special-case DeclRefExprs referring to FieldDecls which
     // are not part of a member pointer formation; normal TreeTransforming
     // doesn't catch this case because of the way we represent them in the AST.
diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index e4b4357..5f04139 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -346,10 +346,6 @@
   return getCurFunction()->CompoundScopes.back();
 }
 
-bool Sema::isCurCompoundStmtAStmtExpr() const {
-  return getCurCompoundScope().IsStmtExpr;
-}
-
 StmtResult Sema::ActOnCompoundStmt(SourceLocation L, SourceLocation R,
                                    ArrayRef<Stmt *> Elts, bool isStmtExpr) {
   const unsigned NumElts = Elts.size();
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index e411a2a..ff33028 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -318,6 +318,13 @@
   TypeSourceInfo *TransformTypeWithDeducedTST(TypeSourceInfo *DI);
   /// @}
 
+  /// The reason why the value of a statement is not discarded, if any.
+  enum StmtDiscardKind {
+    SDK_Discarded,
+    SDK_NotDiscarded,
+    SDK_StmtExprResult,
+  };
+
   /// Transform the given statement.
   ///
   /// By default, this routine transforms a statement by delegating to the
@@ -327,7 +334,7 @@
   /// other mechanism.
   ///
   /// \returns the transformed statement.
-  StmtResult TransformStmt(Stmt *S, bool DiscardedValue = false);
+  StmtResult TransformStmt(Stmt *S, StmtDiscardKind SDK = SDK_Discarded);
 
   /// Transform the given statement.
   ///
@@ -672,6 +679,9 @@
 #define STMT(Node, Parent)                        \
   LLVM_ATTRIBUTE_NOINLINE \
   StmtResult Transform##Node(Node *S);
+#define VALUESTMT(Node, Parent)                   \
+  LLVM_ATTRIBUTE_NOINLINE \
+  StmtResult Transform##Node(Node *S, StmtDiscardKind SDK);
 #define EXPR(Node, Parent)                        \
   LLVM_ATTRIBUTE_NOINLINE \
   ExprResult Transform##Node(Node *E);
@@ -3270,7 +3280,7 @@
 };
 
 template <typename Derived>
-StmtResult TreeTransform<Derived>::TransformStmt(Stmt *S, bool DiscardedValue) {
+StmtResult TreeTransform<Derived>::TransformStmt(Stmt *S, StmtDiscardKind SDK) {
   if (!S)
     return S;
 
@@ -3278,8 +3288,12 @@
   case Stmt::NoStmtClass: break;
 
   // Transform individual statement nodes
+  // Pass SDK into statements that can produce a value
 #define STMT(Node, Parent)                                              \
   case Stmt::Node##Class: return getDerived().Transform##Node(cast<Node>(S));
+#define VALUESTMT(Node, Parent)                                         \
+  case Stmt::Node##Class:                                               \
+    return getDerived().Transform##Node(cast<Node>(S), SDK);
 #define ABSTRACT_STMT(Node)
 #define EXPR(Node, Parent)
 #include "clang/AST/StmtNodes.inc"
@@ -3291,10 +3305,10 @@
 #include "clang/AST/StmtNodes.inc"
     {
       ExprResult E = getDerived().TransformExpr(cast<Expr>(S));
-      if (E.isInvalid())
-        return StmtError();
 
-      return getSema().ActOnExprStmt(E, DiscardedValue);
+      if (SDK == SDK_StmtExprResult)
+        E = getSema().ActOnStmtExprResult(E);
+      return getSema().ActOnExprStmt(E, SDK == SDK_Discarded);
     }
   }
 
@@ -6522,8 +6536,9 @@
   bool SubStmtChanged = false;
   SmallVector<Stmt*, 8> Statements;
   for (auto *B : S->body()) {
-    StmtResult Result =
-        getDerived().TransformStmt(B, !IsStmtExpr || B != S->body_back());
+    StmtResult Result = getDerived().TransformStmt(
+        B,
+        IsStmtExpr && B == S->body_back() ? SDK_StmtExprResult : SDK_Discarded);
 
     if (Result.isInvalid()) {
       // Immediately fail if this was a DeclStmt, since it's very
@@ -6586,7 +6601,8 @@
     return StmtError();
 
   // Transform the statement following the case
-  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt());
+  StmtResult SubStmt =
+      getDerived().TransformStmt(S->getSubStmt());
   if (SubStmt.isInvalid())
     return StmtError();
 
@@ -6594,11 +6610,11 @@
   return getDerived().RebuildCaseStmtBody(Case.get(), SubStmt.get());
 }
 
-template<typename Derived>
-StmtResult
-TreeTransform<Derived>::TransformDefaultStmt(DefaultStmt *S) {
+template <typename Derived>
+StmtResult TreeTransform<Derived>::TransformDefaultStmt(DefaultStmt *S) {
   // Transform the statement following the default case
-  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt());
+  StmtResult SubStmt =
+      getDerived().TransformStmt(S->getSubStmt());
   if (SubStmt.isInvalid())
     return StmtError();
 
@@ -6609,8 +6625,8 @@
 
 template<typename Derived>
 StmtResult
-TreeTransform<Derived>::TransformLabelStmt(LabelStmt *S) {
-  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt());
+TreeTransform<Derived>::TransformLabelStmt(LabelStmt *S, StmtDiscardKind SDK) {
+  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt(), SDK);
   if (SubStmt.isInvalid())
     return StmtError();
 
@@ -6619,6 +6635,11 @@
   if (!LD)
     return StmtError();
 
+  // If we're transforming "in-place" (we're not creating new local
+  // declarations), assume we're replacing the old label statement
+  // and clear out the reference to it.
+  if (LD == S->getDecl())
+    S->getDecl()->setStmt(nullptr);
 
   // FIXME: Pass the real colon location in.
   return getDerived().RebuildLabelStmt(S->getIdentLoc(),
@@ -6644,7 +6665,9 @@
 }
 
 template <typename Derived>
-StmtResult TreeTransform<Derived>::TransformAttributedStmt(AttributedStmt *S) {
+StmtResult
+TreeTransform<Derived>::TransformAttributedStmt(AttributedStmt *S,
+                                                StmtDiscardKind SDK) {
   bool AttrsChanged = false;
   SmallVector<const Attr *, 1> Attrs;
 
@@ -6655,7 +6678,7 @@
     Attrs.push_back(R);
   }
 
-  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt());
+  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt(), SDK);
   if (SubStmt.isInvalid())
     return StmtError();
 
@@ -7360,7 +7383,8 @@
 TreeTransform<Derived>::TransformObjCForCollectionStmt(
                                                   ObjCForCollectionStmt *S) {
   // Transform the element statement.
-  StmtResult Element = getDerived().TransformStmt(S->getElement());
+  StmtResult Element =
+      getDerived().TransformStmt(S->getElement(), SDK_NotDiscarded);
   if (Element.isInvalid())
     return StmtError();