[OpenMP] Codegen support for 'target parallel' on the host.

This patch adds support for codegen of 'target parallel' on the host.
It is also the first combined directive that requires two or more
captured statements.  Support for this functionality is included in
the patch.

A combined directive such as 'target parallel' has two captured
statements, one for the 'target' region and one for the 'parallel'
region.  Two captured statements are required because each region has
different implicit parameters (see SemaOpenMP.cpp).  For example,
the 'parallel' region has the implicit parameters '.global_tid.' and
'.bound_tid.', while the 'target' region has none.  The patch adds
support for handling the multiple captured statements of a combined
directive.

When generating code for the 'target parallel' directive, the 'target'
outlined function is created from the outer captured statement and
the 'parallel' outlined function from the inner captured statement.

Reviewers: ABataev
Differential Revision: https://reviews.llvm.org/D28753

llvm-svn: 292419
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 905c3693..12686ff 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -863,3 +863,79 @@
          Kind == OMPD_target_teams_distribute_parallel_for_simd ||
          Kind == OMPD_target_teams_distribute_simd;
 }
+
+/// Appends to \p CaptureRegions the kinds of the regions that get their own
+/// CapturedStmt for directive \p DKind, ordered from outermost to innermost.
+/// Most directives have a single region: OMPD_parallel for the 'parallel'
+/// family, otherwise \p DKind itself.  'target parallel' has two regions,
+/// 'target' followed by 'parallel'.
+void clang::getOpenMPCaptureRegions(
+    SmallVectorImpl<OpenMPDirectiveKind> &CaptureRegions,
+    OpenMPDirectiveKind DKind) {
+  assert(DKind <= OMPD_unknown && "Invalid OpenMP directive kind");
+  switch (DKind) {
+  case OMPD_parallel:
+  case OMPD_parallel_for:
+  case OMPD_parallel_for_simd:
+  case OMPD_parallel_sections:
+    CaptureRegions.push_back(OMPD_parallel);
+    break;
+  case OMPD_teams:
+  case OMPD_target_teams:
+  case OMPD_simd:
+  case OMPD_for:
+  case OMPD_for_simd:
+  case OMPD_sections:
+  case OMPD_section:
+  case OMPD_single:
+  case OMPD_master:
+  case OMPD_critical:
+  case OMPD_taskgroup:
+  case OMPD_distribute:
+  case OMPD_ordered:
+  case OMPD_atomic:
+  case OMPD_target_data:
+  case OMPD_target:
+  case OMPD_target_parallel_for:
+  case OMPD_target_parallel_for_simd:
+  case OMPD_target_simd:
+  case OMPD_task:
+  case OMPD_taskloop:
+  case OMPD_taskloop_simd:
+  case OMPD_distribute_parallel_for_simd:
+  case OMPD_distribute_simd:
+  case OMPD_distribute_parallel_for:
+  case OMPD_teams_distribute:
+  case OMPD_teams_distribute_simd:
+  case OMPD_teams_distribute_parallel_for_simd:
+  case OMPD_teams_distribute_parallel_for:
+  case OMPD_target_teams_distribute:
+  case OMPD_target_teams_distribute_parallel_for:
+  case OMPD_target_teams_distribute_parallel_for_simd:
+  case OMPD_target_teams_distribute_simd:
+    CaptureRegions.push_back(DKind);
+    break;
+  case OMPD_target_parallel:
+    // The outer 'target' region encloses the inner 'parallel' region.
+    CaptureRegions.push_back(OMPD_target);
+    CaptureRegions.push_back(OMPD_parallel);
+    break;
+  case OMPD_threadprivate:
+  case OMPD_taskyield:
+  case OMPD_barrier:
+  case OMPD_taskwait:
+  case OMPD_cancellation_point:
+  case OMPD_cancel:
+  case OMPD_flush:
+  case OMPD_target_enter_data:
+  case OMPD_target_exit_data:
+  case OMPD_declare_reduction:
+  case OMPD_declare_simd:
+  case OMPD_declare_target:
+  case OMPD_end_declare_target:
+  case OMPD_target_update:
+    llvm_unreachable("OpenMP Directive is not allowed");
+  case OMPD_unknown:
+    llvm_unreachable("Unknown OpenMP directive");
+  }
+}
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 6419ac9..2f0648e0 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -842,12 +842,12 @@
   return CGF.Builder.CreateStructGEP(Addr, Field, Offset, Name);
 }
 
-llvm::Value *CGOpenMPRuntime::emitParallelOrTeamsOutlinedFunction(
-    const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
-    OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
+static llvm::Value *emitParallelOrTeamsOutlinedFunction(
+    CodeGenModule &CGM, const OMPExecutableDirective &D, const CapturedStmt *CS,
+    const VarDecl *ThreadIDVar, OpenMPDirectiveKind InnermostKind,
+    const StringRef OutlinedHelperName, const RegionCodeGenTy &CodeGen) {
   assert(ThreadIDVar->getType()->isPointerType() &&
          "thread id variable must be of type kmp_int32 *");
-  const CapturedStmt *CS = cast<CapturedStmt>(D.getAssociatedStmt());
   CodeGenFunction CGF(CGM, true);
   bool HasCancel = false;
   if (auto *OPD = dyn_cast<OMPParallelDirective>(&D))
@@ -857,11 +857,27 @@
   else if (auto *OPFD = dyn_cast<OMPParallelForDirective>(&D))
     HasCancel = OPFD->hasCancel();
   CGOpenMPOutlinedRegionInfo CGInfo(*CS, ThreadIDVar, CodeGen, InnermostKind,
-                                    HasCancel, getOutlinedHelperName());
+                                    HasCancel, OutlinedHelperName);
   CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
   return CGF.GenerateOpenMPCapturedStmtFunction(*CS);
 }
 
+llvm::Value *CGOpenMPRuntime::emitParallelOutlinedFunction(
+    const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
+    OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
+  return emitParallelOrTeamsOutlinedFunction(
+      CGM, D, D.getCapturedStmt(OMPD_parallel), ThreadIDVar, InnermostKind,
+      getOutlinedHelperName(), CodeGen);
+}
+
+llvm::Value *CGOpenMPRuntime::emitTeamsOutlinedFunction(
+    const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
+    OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
+  return emitParallelOrTeamsOutlinedFunction(
+      CGM, D, D.getCapturedStmt(OMPD_teams), ThreadIDVar, InnermostKind,
+      getOutlinedHelperName(), CodeGen);
+}
+
 llvm::Value *CGOpenMPRuntime::emitTaskOutlinedFunction(
     const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
     const VarDecl *PartIDVar, const VarDecl *TaskTVar,
@@ -6124,6 +6140,10 @@
       CodeGenFunction::EmitOMPTargetDeviceFunction(
           CGM, ParentName, cast<OMPTargetDirective>(*S));
       break;
+    case Stmt::OMPTargetParallelDirectiveClass:
+      CodeGenFunction::EmitOMPTargetParallelDeviceFunction(
+          CGM, ParentName, cast<OMPTargetParallelDirective>(*S));
+      break;
     default:
       llvm_unreachable("Unknown target directive for OpenMP device codegen.");
     }
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index 61ddc70..ee8c4da 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -527,6 +527,7 @@
   /// Get combiner/initializer for the specified user-defined reduction, if any.
   virtual std::pair<llvm::Function *, llvm::Function *>
   getUserDefinedReduction(const OMPDeclareReductionDecl *D);
+
   /// \brief Emits outlined function for the specified OpenMP parallel directive
   /// \a D. This outlined function has type void(*)(kmp_int32 *ThreadID,
   /// kmp_int32 BoundID, struct context_vars*).
@@ -535,7 +536,19 @@
   /// \param InnermostKind Kind of innermost directive (for simple directives it
   /// is a directive itself, for combined - its innermost directive).
   /// \param CodeGen Code generation sequence for the \a D directive.
-  virtual llvm::Value *emitParallelOrTeamsOutlinedFunction(
+  virtual llvm::Value *emitParallelOutlinedFunction(
+      const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
+      OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen);
+
+  /// \brief Emits an outlined function for the 'teams' region of directive
+  /// \a D (the statement captured for OMPD_teams). The function has type
+  /// void(*)(kmp_int32 *ThreadID, kmp_int32 BoundID, struct context_vars*).
+  /// \param D OpenMP directive.
+  /// \param ThreadIDVar Variable for thread id in the current OpenMP region.
+  /// \param InnermostKind Innermost directive kind: the directive itself for
+  /// simple directives, its innermost directive for combined ones.
+  /// \param CodeGen Code generation sequence for the \a D directive.
+  virtual llvm::Value *emitTeamsOutlinedFunction(
       const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
       OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen);
 
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
index e749552..f03c0d9 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -478,24 +478,22 @@
                                               const Expr *ThreadLimit,
                                               SourceLocation Loc) {}
 
-llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOrTeamsOutlinedFunction(
+llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOutlinedFunction(
+    const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
+    OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
+  return CGOpenMPRuntime::emitParallelOutlinedFunction(
+      D, ThreadIDVar, InnermostKind, CodeGen);
+}
+
+llvm::Value *CGOpenMPRuntimeNVPTX::emitTeamsOutlinedFunction(
     const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
     OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
-
-  llvm::Function *OutlinedFun = nullptr;
-  if (isa<OMPTeamsDirective>(D)) {
-    llvm::Value *OutlinedFunVal =
-        CGOpenMPRuntime::emitParallelOrTeamsOutlinedFunction(
-            D, ThreadIDVar, InnermostKind, CodeGen);
-    OutlinedFun = cast<llvm::Function>(OutlinedFunVal);
-    OutlinedFun->removeFnAttr(llvm::Attribute::NoInline);
-    OutlinedFun->addFnAttr(llvm::Attribute::AlwaysInline);
-  } else {
-    llvm::Value *OutlinedFunVal =
-        CGOpenMPRuntime::emitParallelOrTeamsOutlinedFunction(
-            D, ThreadIDVar, InnermostKind, CodeGen);
-    OutlinedFun = cast<llvm::Function>(OutlinedFunVal);
-  }
+  // Outline the teams region generically, then force it to be inlined.
+  auto *OutlinedFun = cast<llvm::Function>(
+      CGOpenMPRuntime::emitTeamsOutlinedFunction(D, ThreadIDVar,
+                                                 InnermostKind, CodeGen));
+  OutlinedFun->removeFnAttr(llvm::Attribute::NoInline);
+  OutlinedFun->addFnAttr(llvm::Attribute::AlwaysInline);
 
   return OutlinedFun;
 }
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h
index 4010b46..a69f051 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.h
@@ -138,7 +138,7 @@
                           const Expr *ThreadLimit, SourceLocation Loc) override;
 
   /// \brief Emits inlined function for the specified OpenMP parallel
-  //  directive but an inlined function for teams.
+  //  directive.
   /// \a D. This outlined function has type void(*)(kmp_int32 *ThreadID,
   /// kmp_int32 BoundID, struct context_vars*).
   /// \param D OpenMP directive.
@@ -147,10 +147,25 @@
   /// is a directive itself, for combined - its innermost directive).
   /// \param CodeGen Code generation sequence for the \a D directive.
   llvm::Value *
-  emitParallelOrTeamsOutlinedFunction(const OMPExecutableDirective &D,
-                                      const VarDecl *ThreadIDVar,
-                                      OpenMPDirectiveKind InnermostKind,
-                                      const RegionCodeGenTy &CodeGen) override;
+  emitParallelOutlinedFunction(const OMPExecutableDirective &D,
+                               const VarDecl *ThreadIDVar,
+                               OpenMPDirectiveKind InnermostKind,
+                               const RegionCodeGenTy &CodeGen) override;
+
+  /// \brief Emits outlined function, forced always-inline, for the specified
+  /// OpenMP teams directive
+  /// \a D. This outlined function has type void(*)(kmp_int32 *ThreadID,
+  /// kmp_int32 BoundID, struct context_vars*).
+  /// \param D OpenMP directive.
+  /// \param ThreadIDVar Variable for thread id in the current OpenMP region.
+  /// \param InnermostKind Kind of innermost directive (for simple directives it
+  /// is a directive itself, for combined - its innermost directive).
+  /// \param CodeGen Code generation sequence for the \a D directive.
+  llvm::Value *
+  emitTeamsOutlinedFunction(const OMPExecutableDirective &D,
+                            const VarDecl *ThreadIDVar,
+                            OpenMPDirectiveKind InnermostKind,
+                            const RegionCodeGenTy &CodeGen) override;
 
   /// \brief Emits code for teams call of the \a OutlinedFn with
   /// variables captured in a record which address is stored in \a
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 2d14ec4..ba531b9 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -1213,10 +1213,9 @@
                                            const OMPExecutableDirective &S,
                                            OpenMPDirectiveKind InnermostKind,
                                            const RegionCodeGenTy &CodeGen) {
-  auto CS = cast<CapturedStmt>(S.getAssociatedStmt());
-  auto OutlinedFn = CGF.CGM.getOpenMPRuntime().
-      emitParallelOrTeamsOutlinedFunction(S,
-          *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen);
+  const CapturedStmt *CS = S.getCapturedStmt(OMPD_parallel);
+  auto OutlinedFn = CGF.CGM.getOpenMPRuntime().emitParallelOutlinedFunction(
+      S, *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen);
   if (const auto *NumThreadsClause = S.getSingleClause<OMPNumThreadsClause>()) {
     CodeGenFunction::RunCleanupsScope NumThreadsScope(CGF);
     auto NumThreads = CGF.EmitScalarExpr(NumThreadsClause->getNumThreads(),
@@ -3497,10 +3496,9 @@
                                         const OMPExecutableDirective &S,
                                         OpenMPDirectiveKind InnermostKind,
                                         const RegionCodeGenTy &CodeGen) {
-  auto CS = cast<CapturedStmt>(S.getAssociatedStmt());
-  auto OutlinedFn = CGF.CGM.getOpenMPRuntime().
-      emitParallelOrTeamsOutlinedFunction(S,
-          *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen);
+  const CapturedStmt *CS = S.getCapturedStmt(OMPD_teams);
+  auto OutlinedFn = CGF.CGM.getOpenMPRuntime().emitTeamsOutlinedFunction(
+      S, *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen);
 
   const OMPTeamsDirective &TD = *dyn_cast<OMPTeamsDirective>(&S);
   const OMPNumTeamsClause *NT = TD.getSingleClause<OMPNumTeamsClause>();
@@ -3755,9 +3753,39 @@
   CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(*this, S, IfCond, Device);
 }
 
+static void emitTargetParallelRegion(CodeGenFunction &CGF,
+                                     const OMPTargetParallelDirective &S,
+                                     PrePostActionTy &Action) {
+  Action.Enter(CGF);
+  // The region body lives in the inner, 'parallel', captured statement.
+  const CapturedStmt *ParallelCS = S.getCapturedStmt(OMPD_parallel);
+  auto &&ParallelGen = [ParallelCS](CodeGenFunction &CGF, PrePostActionTy &) {
+    // TODO: Add support for clauses.
+    CGF.EmitStmt(ParallelCS->getCapturedStmt());
+  };
+  emitCommonOMPParallelDirective(CGF, S, OMPD_parallel, ParallelGen);
+}
+
+void CodeGenFunction::EmitOMPTargetParallelDeviceFunction(
+    CodeGenModule &CGM, StringRef ParentName,
+    const OMPTargetParallelDirective &S) {
+  auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
+    emitTargetParallelRegion(CGF, S, Action);
+  };
+  llvm::Function *Fn = nullptr;
+  llvm::Constant *Addr = nullptr;
+  // Emit as a standalone region; null-init makes the assert meaningful.
+  CGM.getOpenMPRuntime().emitTargetOutlinedFunction(
+      S, ParentName, Fn, Addr, /*IsOffloadEntry=*/true, CodeGen);
+  assert(Fn && Addr && "Target device function emission failed.");
+}
+
 void CodeGenFunction::EmitOMPTargetParallelDirective(
     const OMPTargetParallelDirective &S) {
-  // TODO: codegen for target parallel.
+  auto &&TargetGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
+    emitTargetParallelRegion(CGF, S, Action);
+  };
+  emitCommonOMPTargetDirective(*this, S, TargetGen);
 }
 
 void CodeGenFunction::EmitOMPTargetParallelForDirective(
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index fe62618..7db72dd 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -2708,6 +2708,9 @@
   static void EmitOMPTargetDeviceFunction(CodeGenModule &CGM,
                                           StringRef ParentName,
                                           const OMPTargetDirective &S);
+  static void EmitOMPTargetParallelDeviceFunction(
+      CodeGenModule &CGM, StringRef ParentName,
+      const OMPTargetParallelDirective &S);
   /// \brief Emit inner loop of the worksharing/simd construct.
   ///
   /// \param S Directive, for which the inner loop must be emitted.
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index dcd19c8..5b21f580 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -1608,6 +1608,26 @@
                              Params);
     break;
   }
+  case OMPD_target_parallel: {
+    Sema::CapturedParamNameType ParamsTarget[] = {
+        std::make_pair(StringRef(), QualType()) // __context with shared vars
+    };
+    // Start a captured region for 'target' with no implicit parameters.
+    ActOnCapturedRegionStart(DSAStack->getConstructLoc(), CurScope, CR_OpenMP,
+                             ParamsTarget);
+    QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1);
+    QualType KmpInt32PtrTy =
+        Context.getPointerType(KmpInt32Ty).withConst().withRestrict();
+    Sema::CapturedParamNameType ParamsParallel[] = {
+        std::make_pair(".global_tid.", KmpInt32PtrTy),
+        std::make_pair(".bound_tid.", KmpInt32PtrTy),
+        std::make_pair(StringRef(), QualType()) // __context with shared vars
+    };
+    // Start a captured region for 'parallel'.
+    ActOnCapturedRegionStart(DSAStack->getConstructLoc(), CurScope, CR_OpenMP,
+                             ParamsParallel);
+    break;
+  }
   case OMPD_simd:
   case OMPD_for:
   case OMPD_for_simd:
@@ -1622,7 +1642,6 @@
   case OMPD_atomic:
   case OMPD_target_data:
   case OMPD_target:
-  case OMPD_target_parallel:
   case OMPD_target_parallel_for:
   case OMPD_target_parallel_for_simd:
   case OMPD_target_simd: {
@@ -1737,6 +1756,12 @@
   }
 }
 
+int Sema::getOpenMPCaptureLevels(OpenMPDirectiveKind DKind) {
+  SmallVector<OpenMPDirectiveKind, 4> Regions;
+  getOpenMPCaptureRegions(Regions, DKind);
+  return static_cast<int>(Regions.size());
+}
+
 static OMPCapturedExprDecl *buildCaptureDecl(Sema &S, IdentifierInfo *Id,
                                              Expr *CaptureExpr, bool WithInit,
                                              bool AsExpression) {
@@ -1796,10 +1821,42 @@
   return CaptureExpr->isGLValue() ? Res : S.DefaultLvalueConversion(Res.get());
 }
 
+namespace {
+// Every OpenMP directive that reaches ActOnOpenMPRegionEnd owns one or
+// more nested captured regions; combined directives such as 'target
+// parallel' own several (one per entry of getOpenMPCaptureRegions).  When
+// the directive is rejected -- its associated statement failed to parse
+// or a clause check failed -- each of those regions must be closed with
+// ActOnCapturedRegionError().  This RAII helper performs that unwinding
+// when it goes out of scope with the caller's error flag set, so early
+// returns do not have to repeat it.
+class CaptureRegionUnwinderRAII {
+  Sema &S;
+  bool &ErrorFound;
+  OpenMPDirectiveKind DKind;
+
+public:
+  CaptureRegionUnwinderRAII(Sema &S, bool &ErrorFound,
+                            OpenMPDirectiveKind DKind)
+      : S(S), ErrorFound(ErrorFound), DKind(DKind) {}
+
+  ~CaptureRegionUnwinderRAII() {
+    if (!ErrorFound)
+      return;
+    for (int Level = 0, NumLevels = S.getOpenMPCaptureLevels(DKind);
+         Level < NumLevels; ++Level)
+      S.ActOnCapturedRegionError();
+  }
+};
+} // namespace
+
 StmtResult Sema::ActOnOpenMPRegionEnd(StmtResult S,
                                       ArrayRef<OMPClause *> Clauses) {
+  bool ErrorFound = false;
+  CaptureRegionUnwinderRAII CaptureRegionUnwinder(
+      *this, ErrorFound, DSAStack->getCurrentDirective());
   if (!S.isUsable()) {
-    ActOnCapturedRegionError();
+    ErrorFound = true;
     return StmtError();
   }
 
@@ -1843,7 +1900,6 @@
     else if (Clause->getClauseKind() == OMPC_linear)
       LCs.push_back(cast<OMPLinearClause>(Clause));
   }
-  bool ErrorFound = false;
   // OpenMP, 2.7.1 Loop Construct, Restrictions
   // The nonmonotonic modifier cannot be specified if an ordered clause is
   // specified.
@@ -1874,10 +1930,14 @@
     ErrorFound = true;
   }
   if (ErrorFound) {
-    ActOnCapturedRegionError();
     return StmtError();
   }
-  return ActOnCapturedRegionEnd(S.get());
+  StmtResult SR = S;
+  int ThisCaptureLevel =
+      getOpenMPCaptureLevels(DSAStack->getCurrentDirective());
+  while (--ThisCaptureLevel >= 0)
+    SR = ActOnCapturedRegionEnd(SR.get());
+  return SR;
 }
 
 static bool CheckNestingOfRegions(Sema &SemaRef, DSAStackTy *Stack,
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index c2aa3fe..1d620bb 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -7238,8 +7238,12 @@
     StmtResult Body;
     {
       Sema::CompoundScopeRAII CompoundScope(getSema());
-      Body = getDerived().TransformStmt(
-          cast<CapturedStmt>(D->getAssociatedStmt())->getCapturedStmt());
+      int ThisCaptureLevel =
+          Sema::getOpenMPCaptureLevels(D->getDirectiveKind());
+      Stmt *CS = D->getAssociatedStmt();
+      while (--ThisCaptureLevel >= 0)
+        CS = cast<CapturedStmt>(CS)->getCapturedStmt();
+      Body = getDerived().TransformStmt(CS);
     }
     AssociatedStmt =
         getDerived().getSema().ActOnOpenMPRegionEnd(Body, TClauses);