[OPENMP] Replace calls to getAssociatedStmt().

getAssociatedStmt() returns the outermost captured statement for the
OpenMP directive. It may return an incorrect region in the case of
combined constructs. Reworked the code to reduce the number of calls to
getAssociatedStmt() and to use the getInnermostCapturedStmt() and
getCapturedStmt() functions instead.
For firstprivate variables, using the incorrect region could lead to
the generation of extra allocas for private copies, even if the variable
is passed by value into the outlined function and could be used directly
as the private copy.

llvm-svn: 322393
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 9fa173b..9cff2a3 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -1371,7 +1371,10 @@
   CodeGen.setAction(Action);
   assert(!ThreadIDVar->getType()->isPointerType() &&
          "thread id variable must be of type kmp_int32 for tasks");
-  auto *CS = cast<CapturedStmt>(D.getAssociatedStmt());
+  const OpenMPDirectiveKind Region =
+      isOpenMPTaskLoopDirective(D.getDirectiveKind()) ? OMPD_taskloop
+                                                      : OMPD_task;
+  auto *CS = D.getCapturedStmt(Region);
   auto *TD = dyn_cast<OMPTaskDirective>(&D);
   CodeGenFunction CGF(CGM, true);
   CGOpenMPTaskOutlinedRegionInfo CGInfo(*CS, ThreadIDVar, CodeGen,
@@ -5885,7 +5888,7 @@
        << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
   }
 
-  const CapturedStmt &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
+  const CapturedStmt &CS = *D.getCapturedStmt(OMPD_target);
 
   CodeGenFunction CGF(CGM, true);
   CGOpenMPTargetRegionInfo CGInfo(CS, CodeGen, EntryFnName);
@@ -5979,7 +5982,7 @@
   // the expression is captured in the enclosing target environment when the
   // teams directive is not combined with target.
 
-  const CapturedStmt &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
+  const CapturedStmt &CS = *D.getCapturedStmt(OMPD_target);
 
   if (auto *TeamsDir = dyn_cast_or_null<OMPExecutableDirective>(
           ignoreCompoundStmts(CS.getCapturedStmt()))) {
@@ -6082,7 +6085,7 @@
   // the expression is captured in the enclosing target environment when the
   // teams directive is not combined with target.
 
-  const CapturedStmt &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
+  const CapturedStmt &CS = *D.getCapturedStmt(OMPD_target);
 
   if (auto *TeamsDir = dyn_cast_or_null<OMPExecutableDirective>(
           ignoreCompoundStmts(CS.getCapturedStmt()))) {
@@ -7059,7 +7062,7 @@
   // Get mappable expression information.
   MappableExprsHandler MEHandler(D, CGF);
 
-  const CapturedStmt &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
+  const CapturedStmt &CS = *D.getCapturedStmt(OMPD_target);
   auto RI = CS.getCapturedRecordDecl()->field_begin();
   auto CV = CapturedVars.begin();
   for (CapturedStmt::const_capture_iterator CI = CS.capture_begin(),
@@ -7314,12 +7317,11 @@
   }
 
   if (const OMPExecutableDirective *E = dyn_cast<OMPExecutableDirective>(S)) {
-    if (!E->hasAssociatedStmt())
+    if (!E->hasAssociatedStmt() || !E->getAssociatedStmt())
       return;
 
     scanForTargetRegionsFunctions(
-        cast<CapturedStmt>(E->getAssociatedStmt())->getCapturedStmt(),
-        ParentName);
+        E->getInnermostCapturedStmt()->getCapturedStmt(), ParentName);
     return;
   }
 
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
index b90e87a..6c0f00d10 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -2428,7 +2428,7 @@
 llvm::Function *CGOpenMPRuntimeNVPTX::createDataSharingWrapper(
     llvm::Function *OutlinedParallelFn, const OMPExecutableDirective &D) {
   ASTContext &Ctx = CGM.getContext();
-  const auto &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
+  const CapturedStmt &CS = *D.getCapturedStmt(OMPD_parallel);
 
   // Create a function that takes as argument the source thread.
   FunctionArgList WrapperArgs;
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 7221ad9..116647d 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -53,34 +53,35 @@
   }
 
 public:
-  OMPLexicalScope(CodeGenFunction &CGF, const OMPExecutableDirective &S,
-                  bool AsInlined = false, bool EmitPreInitStmt = true)
+  OMPLexicalScope(
+      CodeGenFunction &CGF, const OMPExecutableDirective &S,
+      const llvm::Optional<OpenMPDirectiveKind> CapturedRegion = llvm::None,
+      const bool EmitPreInitStmt = true)
       : CodeGenFunction::LexicalScope(CGF, S.getSourceRange()),
         InlinedShareds(CGF) {
     if (EmitPreInitStmt)
       emitPreInitStmt(CGF, S);
-    if (AsInlined) {
-      if (S.hasAssociatedStmt()) {
-        auto *CS = cast<CapturedStmt>(S.getAssociatedStmt());
-        for (auto &C : CS->captures()) {
-          if (C.capturesVariable() || C.capturesVariableByCopy()) {
-            auto *VD = C.getCapturedVar();
-            assert(VD == VD->getCanonicalDecl() &&
-                        "Canonical decl must be captured.");
-            DeclRefExpr DRE(const_cast<VarDecl *>(VD),
-                            isCapturedVar(CGF, VD) ||
-                                (CGF.CapturedStmtInfo &&
-                                 InlinedShareds.isGlobalVarCaptured(VD)),
-                            VD->getType().getNonReferenceType(), VK_LValue,
-                            SourceLocation());
-            InlinedShareds.addPrivate(VD, [&CGF, &DRE]() -> Address {
-              return CGF.EmitLValue(&DRE).getAddress();
-            });
-          }
-        }
-        (void)InlinedShareds.Privatize();
+    if (!CapturedRegion.hasValue())
+      return;
+    assert(S.hasAssociatedStmt() &&
+           "Expected associated statement for inlined directive.");
+    const CapturedStmt *CS = S.getCapturedStmt(*CapturedRegion);
+    for (auto &C : CS->captures()) {
+      if (C.capturesVariable() || C.capturesVariableByCopy()) {
+        auto *VD = C.getCapturedVar();
+        assert(VD == VD->getCanonicalDecl() &&
+               "Canonical decl must be captured.");
+        DeclRefExpr DRE(
+            const_cast<VarDecl *>(VD),
+            isCapturedVar(CGF, VD) || (CGF.CapturedStmtInfo &&
+                                       InlinedShareds.isGlobalVarCaptured(VD)),
+            VD->getType().getNonReferenceType(), VK_LValue, SourceLocation());
+        InlinedShareds.addPrivate(VD, [&CGF, &DRE]() -> Address {
+          return CGF.EmitLValue(&DRE).getAddress();
+        });
       }
     }
+    (void)InlinedShareds.Privatize();
   }
 };
 
@@ -96,9 +97,8 @@
 
 public:
   OMPParallelScope(CodeGenFunction &CGF, const OMPExecutableDirective &S)
-      : OMPLexicalScope(CGF, S,
-                        /*AsInlined=*/false,
-                        /*EmitPreInitStmt=*/EmitPreInitStmt(S)) {}
+      : OMPLexicalScope(CGF, S, /*CapturedRegion=*/llvm::None,
+                        EmitPreInitStmt(S)) {}
 };
 
 /// Lexical scope for OpenMP teams construct, that handles correct codegen
@@ -112,9 +112,8 @@
 
 public:
   OMPTeamsScope(CodeGenFunction &CGF, const OMPExecutableDirective &S)
-      : OMPLexicalScope(CGF, S,
-                        /*AsInlined=*/false,
-                        /*EmitPreInitStmt=*/EmitPreInitStmt(S)) {}
+      : OMPLexicalScope(CGF, S, /*CapturedRegion=*/llvm::None,
+                        EmitPreInitStmt(S)) {}
 };
 
 /// Private scope for OpenMP loop-based directives, that supports capturing
@@ -738,7 +737,12 @@
           cast<VarDecl>(cast<DeclRefExpr>(D)->getDecl())->getCanonicalDecl());
   }
   llvm::DenseSet<const VarDecl *> EmittedAsFirstprivate;
-  CGCapturedStmtInfo CapturesInfo(cast<CapturedStmt>(*D.getAssociatedStmt()));
+  llvm::SmallVector<OpenMPDirectiveKind, 4> CaptureRegions;
+  getOpenMPCaptureRegions(CaptureRegions, D.getDirectiveKind());
+  // Force emission of the firstprivate copy if the directive does not emit
+  // outlined function, like omp for, omp simd, omp distribute etc.
+  bool MustEmitFirstprivateCopy =
+      CaptureRegions.size() == 1 && CaptureRegions.back() == OMPD_unknown;
   for (const auto *C : D.getClausesOfKind<OMPFirstprivateClause>()) {
     auto IRef = C->varlist_begin();
     auto InitsRef = C->inits().begin();
@@ -746,9 +750,8 @@
       auto *OrigVD = cast<VarDecl>(cast<DeclRefExpr>(*IRef)->getDecl());
       bool ThisFirstprivateIsLastprivate =
           Lastprivates.count(OrigVD->getCanonicalDecl()) > 0;
-      auto *CapFD = CapturesInfo.lookup(OrigVD);
       auto *FD = CapturedStmtInfo->lookup(OrigVD);
-      if (!ThisFirstprivateIsLastprivate && FD && (FD == CapFD) &&
+      if (!MustEmitFirstprivateCopy && !ThisFirstprivateIsLastprivate && FD &&
           !FD->getType()->isReferenceType()) {
         EmittedAsFirstprivate.insert(OrigVD->getCanonicalDecl());
         ++IRef;
@@ -1272,7 +1275,7 @@
     CGF.EmitOMPPrivateClause(S, PrivateScope);
     CGF.EmitOMPReductionClauseInit(S, PrivateScope);
     (void)PrivateScope.Privatize();
-    CGF.EmitStmt(cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+    CGF.EmitStmt(S.getCapturedStmt(OMPD_parallel)->getCapturedStmt());
     CGF.EmitOMPReductionClauseFinal(S, /*ReductionKind=*/OMPD_parallel);
   };
   emitCommonOMPParallelDirective(*this, S, OMPD_parallel, CodeGen,
@@ -1734,7 +1737,7 @@
   auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
     emitOMPSimdRegion(CGF, S, Action);
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_simd, CodeGen);
 }
 
@@ -2122,7 +2125,7 @@
     CGF.EmitOMPDistributeLoop(S, emitInnerParallelForWhenCombined,
                               S.getDistInc());
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_parallel);
   CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_distribute, CodeGen);
 }
 
@@ -2132,7 +2135,7 @@
     CGF.EmitOMPDistributeLoop(S, emitInnerParallelForWhenCombined,
                               S.getDistInc());
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_parallel);
   CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_distribute, CodeGen);
 }
 
@@ -2141,7 +2144,7 @@
   auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &) {
     CGF.EmitOMPDistributeLoop(S, emitOMPLoopBodyWithStopPoint, S.getInc());
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_simd, CodeGen);
 }
 
@@ -2169,12 +2172,11 @@
 
 void CodeGenFunction::EmitOMPTargetTeamsDistributeParallelForSimdDirective(
     const OMPTargetTeamsDistributeParallelForSimdDirective &S) {
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitInlinedDirective(
       *this, OMPD_target_teams_distribute_parallel_for_simd,
       [&S](CodeGenFunction &CGF, PrePostActionTy &) {
-        CGF.EmitStmt(
-            cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+        CGF.EmitStmt(S.getInnermostCapturedStmt()->getCapturedStmt());
       });
 }
 
@@ -2414,7 +2416,7 @@
                                                  emitDispatchForLoopBounds);
   };
   {
-    OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+    OMPLexicalScope Scope(*this, S, OMPD_unknown);
     CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_for, CodeGen,
                                                 S.hasCancel());
   }
@@ -2434,7 +2436,7 @@
                                                  emitDispatchForLoopBounds);
   };
   {
-    OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+    OMPLexicalScope Scope(*this, S, OMPD_unknown);
     CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_simd, CodeGen);
   }
 
@@ -2454,8 +2456,8 @@
 }
 
 void CodeGenFunction::EmitSections(const OMPExecutableDirective &S) {
-  auto *Stmt = cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt();
-  auto *CS = dyn_cast<CompoundStmt>(Stmt);
+  const Stmt *Stmt = S.getInnermostCapturedStmt()->getCapturedStmt();
+  const auto *CS = dyn_cast<CompoundStmt>(Stmt);
   bool HasLastprivates = false;
   auto &&CodeGen = [&S, Stmt, CS, &HasLastprivates](CodeGenFunction &CGF,
                                                     PrePostActionTy &) {
@@ -2595,7 +2597,7 @@
 
 void CodeGenFunction::EmitOMPSectionsDirective(const OMPSectionsDirective &S) {
   {
-    OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+    OMPLexicalScope Scope(*this, S, OMPD_unknown);
     EmitSections(S);
   }
   // Emit an implicit barrier at the end.
@@ -2607,9 +2609,9 @@
 
 void CodeGenFunction::EmitOMPSectionDirective(const OMPSectionDirective &S) {
   auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &) {
-    CGF.EmitStmt(cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+    CGF.EmitStmt(S.getInnermostCapturedStmt()->getCapturedStmt());
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_section, CodeGen,
                                               S.hasCancel());
 }
@@ -2638,10 +2640,10 @@
     (void)CGF.EmitOMPFirstprivateClause(S, SingleScope);
     CGF.EmitOMPPrivateClause(S, SingleScope);
     (void)SingleScope.Privatize();
-    CGF.EmitStmt(cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+    CGF.EmitStmt(S.getInnermostCapturedStmt()->getCapturedStmt());
   };
   {
-    OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+    OMPLexicalScope Scope(*this, S, OMPD_unknown);
     CGM.getOpenMPRuntime().emitSingleRegion(*this, CodeGen, S.getLocStart(),
                                             CopyprivateVars, DestExprs,
                                             SrcExprs, AssignmentOps);
@@ -2658,21 +2660,21 @@
 void CodeGenFunction::EmitOMPMasterDirective(const OMPMasterDirective &S) {
   auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
     Action.Enter(CGF);
-    CGF.EmitStmt(cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+    CGF.EmitStmt(S.getInnermostCapturedStmt()->getCapturedStmt());
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitMasterRegion(*this, CodeGen, S.getLocStart());
 }
 
 void CodeGenFunction::EmitOMPCriticalDirective(const OMPCriticalDirective &S) {
   auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
     Action.Enter(CGF);
-    CGF.EmitStmt(cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+    CGF.EmitStmt(S.getInnermostCapturedStmt()->getCapturedStmt());
   };
   Expr *Hint = nullptr;
   if (auto *HintClause = S.getSingleClause<OMPHintClause>())
     Hint = HintClause->getHint();
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitCriticalRegion(*this,
                                             S.getDirectiveName().getAsString(),
                                             CodeGen, S.getLocStart(), Hint);
@@ -2714,12 +2716,12 @@
                                  emitEmptyBoundParameters);
 }
 
-void CodeGenFunction::EmitOMPTaskBasedDirective(const OMPExecutableDirective &S,
-                                                const RegionCodeGenTy &BodyGen,
-                                                const TaskGenTy &TaskGen,
-                                                OMPTaskDataTy &Data) {
+void CodeGenFunction::EmitOMPTaskBasedDirective(
+    const OMPExecutableDirective &S, const OpenMPDirectiveKind CapturedRegion,
+    const RegionCodeGenTy &BodyGen, const TaskGenTy &TaskGen,
+    OMPTaskDataTy &Data) {
   // Emit outlined function for task construct.
-  auto CS = cast<CapturedStmt>(S.getAssociatedStmt());
+  const CapturedStmt *CS = S.getCapturedStmt(CapturedRegion);
   auto *I = CS->getCapturedDecl()->param_begin();
   auto *PartId = std::next(I);
   auto *TaskT = std::next(I, 4);
@@ -2820,8 +2822,9 @@
   for (const auto *C : S.getClausesOfKind<OMPDependClause>())
     for (auto *IRef : C->varlists())
       Data.Dependences.push_back(std::make_pair(C->getDependencyKind(), IRef));
-  auto &&CodeGen = [&Data, &S, CS, &BodyGen, &LastprivateDstsOrigs](
-      CodeGenFunction &CGF, PrePostActionTy &Action) {
+  auto &&CodeGen = [&Data, &S, CS, &BodyGen, &LastprivateDstsOrigs,
+                    CapturedRegion](CodeGenFunction &CGF,
+                                    PrePostActionTy &Action) {
     // Set proper addresses for generated private copies.
     OMPPrivateScope Scope(CGF);
     if (!Data.PrivateVars.empty() || !Data.FirstprivateVars.empty() ||
@@ -2878,7 +2881,7 @@
       }
     }
     if (Data.Reductions) {
-      OMPLexicalScope LexScope(CGF, S, /*AsInlined=*/true);
+      OMPLexicalScope LexScope(CGF, S, CapturedRegion);
       ReductionCodeGen RedCG(Data.ReductionVars, Data.ReductionCopies,
                              Data.ReductionOps);
       llvm::Value *ReductionsPtr = CGF.Builder.CreateLoad(
@@ -3096,8 +3099,7 @@
         CGF.GetAddrOfLocalVar(SVD), /*Index=*/0, CGF.getSizeSize());
 
     Action.Enter(CGF);
-    OMPLexicalScope LexScope(CGF, S, /*AsInlined=*/true,
-                             /*EmitPreInitStmt=*/false);
+    OMPLexicalScope LexScope(CGF, S, OMPD_task, /*EmitPreInitStmt=*/false);
     BodyGen(CGF);
   };
   auto *OutlinedFn = CGM.getOpenMPRuntime().emitTaskOutlinedFunction(
@@ -3114,7 +3116,7 @@
 
 void CodeGenFunction::EmitOMPTaskDirective(const OMPTaskDirective &S) {
   // Emit outlined function for task construct.
-  auto CS = cast<CapturedStmt>(S.getAssociatedStmt());
+  const CapturedStmt *CS = S.getCapturedStmt(OMPD_task);
   auto CapturedStruct = GenerateCapturedStmtArgument(*CS);
   auto SharedsTy = getContext().getRecordType(CS->getCapturedRecordDecl());
   const Expr *IfCond = nullptr;
@@ -3139,7 +3141,7 @@
                                             SharedsTy, CapturedStruct, IfCond,
                                             Data);
   };
-  EmitOMPTaskBasedDirective(S, BodyGen, TaskGen, Data);
+  EmitOMPTaskBasedDirective(S, OMPD_task, BodyGen, TaskGen, Data);
 }
 
 void CodeGenFunction::EmitOMPTaskyieldDirective(
@@ -3188,9 +3190,9 @@
       CGF.EmitStoreOfScalar(ReductionDesc, CGF.GetAddrOfLocalVar(VD),
                             /*Volatile=*/false, E->getType());
     }
-    CGF.EmitStmt(cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+    CGF.EmitStmt(S.getInnermostCapturedStmt()->getCapturedStmt());
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitTaskgroupRegion(*this, CodeGen, S.getLocStart());
 }
 
@@ -3398,7 +3400,7 @@
 
     CGF.EmitOMPDistributeLoop(S, emitOMPLoopBodyWithStopPoint, S.getInc());
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_distribute, CodeGen);
 }
 
@@ -3413,7 +3415,9 @@
 }
 
 void CodeGenFunction::EmitOMPOrderedDirective(const OMPOrderedDirective &S) {
-  if (!S.getAssociatedStmt()) {
+  if (S.hasClausesOfKind<OMPDependClause>()) {
+    assert(!S.getAssociatedStmt() &&
+           "No associated statement must be in ordered depend construct.");
     for (const auto *DC : S.getClausesOfKind<OMPDependClause>())
       CGM.getOpenMPRuntime().emitDoacrossOrdered(*this, DC);
     return;
@@ -3421,8 +3425,8 @@
   auto *C = S.getSingleClause<OMPSIMDClause>();
   auto &&CodeGen = [&S, C, this](CodeGenFunction &CGF,
                                  PrePostActionTy &Action) {
+    const CapturedStmt *CS = S.getInnermostCapturedStmt();
     if (C) {
-      auto CS = cast<CapturedStmt>(S.getAssociatedStmt());
       llvm::SmallVector<llvm::Value *, 16> CapturedVars;
       CGF.GenerateOpenMPCapturedVars(*CS, CapturedVars);
       auto *OutlinedFn = emitOutlinedOrderedFunction(CGM, CS);
@@ -3430,11 +3434,10 @@
                                                       OutlinedFn, CapturedVars);
     } else {
       Action.Enter(CGF);
-      CGF.EmitStmt(
-          cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+      CGF.EmitStmt(CS->getCapturedStmt());
     }
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitOrderedRegion(*this, CodeGen, S.getLocStart(), !C);
 }
 
@@ -3878,8 +3881,7 @@
     }
   }
 
-  const auto *CS =
-      S.getAssociatedStmt()->IgnoreContainers(/*IgnoreCaptured=*/true);
+  const auto *CS = S.getInnermostCapturedStmt()->IgnoreContainers();
   if (const auto *EWC = dyn_cast<ExprWithCleanups>(CS)) {
     enterFullExpression(EWC);
   }
@@ -3899,7 +3901,7 @@
                       S.getV(), S.getExpr(), S.getUpdateExpr(),
                       S.isXLHSInRHSPart(), S.getLocStart());
   };
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_unknown);
   CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_atomic, CodeGen);
 }
 
@@ -3971,7 +3973,7 @@
   (void)PrivateScope.Privatize();
 
   Action.Enter(CGF);
-  CGF.EmitStmt(cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+  CGF.EmitStmt(S.getCapturedStmt(OMPD_target)->getCapturedStmt());
 }
 
 void CodeGenFunction::EmitOMPTargetDeviceFunction(CodeGenModule &CGM,
@@ -4028,7 +4030,7 @@
     CGF.EmitOMPPrivateClause(S, PrivateScope);
     CGF.EmitOMPReductionClauseInit(S, PrivateScope);
     (void)PrivateScope.Privatize();
-    CGF.EmitStmt(cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+    CGF.EmitStmt(S.getCapturedStmt(OMPD_teams)->getCapturedStmt());
     CGF.EmitOMPReductionClauseFinal(S, /*ReductionKind=*/OMPD_teams);
   };
   emitCommonOMPTeamsDirective(*this, S, OMPD_distribute, CodeGen);
@@ -4421,10 +4423,9 @@
   DevicePointerPrivActionTy PrivAction(PrivatizeDevicePointers);
 
   auto &&CodeGen = [&S, &Info, &PrivatizeDevicePointers](
-      CodeGenFunction &CGF, PrePostActionTy &Action) {
+                       CodeGenFunction &CGF, PrePostActionTy &Action) {
     auto &&InnermostCodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &) {
-      CGF.EmitStmt(
-          cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
+      CGF.EmitStmt(S.getInnermostCapturedStmt()->getCapturedStmt());
     };
 
     // Codegen that selects wheather to generate the privatization code or not.
@@ -4506,7 +4507,7 @@
   if (auto *C = S.getSingleClause<OMPDeviceClause>())
     Device = C->getDevice();
 
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_task);
   CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(*this, S, IfCond, Device);
 }
 
@@ -4527,7 +4528,7 @@
   if (auto *C = S.getSingleClause<OMPDeviceClause>())
     Device = C->getDevice();
 
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_task);
   CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(*this, S, IfCond, Device);
 }
 
@@ -4664,7 +4665,7 @@
 void CodeGenFunction::EmitOMPTaskLoopBasedDirective(const OMPLoopDirective &S) {
   assert(isOpenMPTaskLoopDirective(S.getDirectiveKind()));
   // Emit outlined function for task construct.
-  auto CS = cast<CapturedStmt>(S.getAssociatedStmt());
+  const CapturedStmt *CS = S.getCapturedStmt(OMPD_taskloop);
   auto CapturedStruct = GenerateCapturedStmtArgument(*CS);
   auto SharedsTy = getContext().getRecordType(CS->getCapturedRecordDecl());
   const Expr *IfCond = nullptr;
@@ -4786,15 +4787,16 @@
     CGF.CGM.getOpenMPRuntime().emitInlinedDirective(CGF, OMPD_taskloop,
                                                     CodeGen);
   };
-  if (Data.Nogroup)
-    EmitOMPTaskBasedDirective(S, BodyGen, TaskGen, Data);
-  else {
+  if (Data.Nogroup) {
+    EmitOMPTaskBasedDirective(S, OMPD_taskloop, BodyGen, TaskGen, Data);
+  } else {
     CGM.getOpenMPRuntime().emitTaskgroupRegion(
         *this,
         [&S, &BodyGen, &TaskGen, &Data](CodeGenFunction &CGF,
                                         PrePostActionTy &Action) {
           Action.Enter(CGF);
-          CGF.EmitOMPTaskBasedDirective(S, BodyGen, TaskGen, Data);
+          CGF.EmitOMPTaskBasedDirective(S, OMPD_taskloop, BodyGen, TaskGen,
+                                        Data);
         },
         S.getLocStart());
   }
@@ -4827,7 +4829,7 @@
   if (auto *C = S.getSingleClause<OMPDeviceClause>())
     Device = C->getDevice();
 
-  OMPLexicalScope Scope(*this, S, /*AsInlined=*/true);
+  OMPLexicalScope Scope(*this, S, OMPD_task);
   CGM.getOpenMPRuntime().emitTargetDataStandAloneCall(*this, S, IfCond, Device);
 }
 
@@ -4849,10 +4851,7 @@
           }
         }
       }
-      const auto *CS = cast<CapturedStmt>(D.getAssociatedStmt());
-      while (const auto *CCS = dyn_cast<CapturedStmt>(CS->getCapturedStmt()))
-        CS = CCS;
-      CGF.EmitStmt(CS->getCapturedStmt());
+      CGF.EmitStmt(D.getInnermostCapturedStmt()->getCapturedStmt());
     }
   };
   OMPSimdLexicalScope Scope(*this, D);
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index cedf327..228cbe6 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -2832,6 +2832,7 @@
                                         const OMPTaskDataTy & /*Data*/)>
       TaskGenTy;
   void EmitOMPTaskBasedDirective(const OMPExecutableDirective &S,
+                                 const OpenMPDirectiveKind CapturedRegion,
                                  const RegionCodeGenTy &BodyGen,
                                  const TaskGenTy &TaskGen, OMPTaskDataTy &Data);
   struct OMPTargetDataInfo {