[OpenMP] `omp begin/end declare variant` - part 1, parsing

This is the first part extracted from D71179 and cleaned up.

This patch provides parsing support for `omp begin/end declare variant`,
as defined in OpenMP technical report 8 (TR8) [0].

A major purpose of this patch is to provide proper math.h/cmath support
for OpenMP target offloading. See PR42061, PR42798, PR42799. The current
code was developed with this feature in mind, see [1].

[0] https://www.openmp.org/wp-content/uploads/openmp-TR8.pdf
[1] https://reviews.llvm.org/D61399#change-496lQkg0mhRN

Reviewed By: aaron.ballman

Differential Revision: https://reviews.llvm.org/D74941
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index d47051c..1eca154 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -48,6 +48,8 @@
   OMPD_target_teams_distribute_parallel,
   OMPD_mapper,
   OMPD_variant,
+  OMPD_begin,
+  OMPD_begin_declare,
 };
 
 // Helper to unify the enum class OpenMPDirectiveKind with its extension
@@ -101,6 +103,7 @@
       .Case("update", OMPD_update)
       .Case("mapper", OMPD_mapper)
       .Case("variant", OMPD_variant)
+      .Case("begin", OMPD_begin)
       .Default(OMPD_unknown);
 }
 
@@ -109,18 +112,21 @@
   // E.g.: OMPD_for OMPD_simd ===> OMPD_for_simd
   // TODO: add other combined directives in topological order.
   static const OpenMPDirectiveKindExWrapper F[][3] = {
+      {OMPD_begin, OMPD_declare, OMPD_begin_declare},
+      {OMPD_end, OMPD_declare, OMPD_end_declare},
       {OMPD_cancellation, OMPD_point, OMPD_cancellation_point},
       {OMPD_declare, OMPD_reduction, OMPD_declare_reduction},
       {OMPD_declare, OMPD_mapper, OMPD_declare_mapper},
       {OMPD_declare, OMPD_simd, OMPD_declare_simd},
       {OMPD_declare, OMPD_target, OMPD_declare_target},
       {OMPD_declare, OMPD_variant, OMPD_declare_variant},
+      {OMPD_begin_declare, OMPD_variant, OMPD_begin_declare_variant},
+      {OMPD_end_declare, OMPD_variant, OMPD_end_declare_variant},
       {OMPD_distribute, OMPD_parallel, OMPD_distribute_parallel},
       {OMPD_distribute_parallel, OMPD_for, OMPD_distribute_parallel_for},
       {OMPD_distribute_parallel_for, OMPD_simd,
        OMPD_distribute_parallel_for_simd},
       {OMPD_distribute, OMPD_simd, OMPD_distribute_simd},
-      {OMPD_end, OMPD_declare, OMPD_end_declare},
       {OMPD_end_declare, OMPD_target, OMPD_end_declare_target},
       {OMPD_target, OMPD_data, OMPD_target_data},
       {OMPD_target, OMPD_enter, OMPD_target_enter},
@@ -1124,13 +1130,19 @@
   // Parse '('.
   (void)BDT.consumeOpen();
 
+  SourceLocation ScoreLoc = Tok.getLocation();
   ExprResult Score = parseContextScore(*this);
 
-  if (!AllowsTraitScore && Score.isUsable()) {
-    Diag(Score.get()->getBeginLoc(),
-         diag::warn_omp_ctx_incompatible_score_for_property)
-        << getOpenMPContextTraitSelectorName(TISelector.Kind)
-        << getOpenMPContextTraitSetName(Set) << Score.get();
+  if (!AllowsTraitScore && !Score.isUnset()) {
+    if (Score.isUsable()) {
+      Diag(ScoreLoc, diag::warn_omp_ctx_incompatible_score_for_property)
+          << getOpenMPContextTraitSelectorName(TISelector.Kind)
+          << getOpenMPContextTraitSetName(Set) << Score.get();
+    } else {
+      Diag(ScoreLoc, diag::warn_omp_ctx_incompatible_score_for_property)
+          << getOpenMPContextTraitSelectorName(TISelector.Kind)
+          << getOpenMPContextTraitSetName(Set) << "<invalid>";
+    }
     Score = ExprResult();
   }
 
@@ -1334,6 +1346,31 @@
     return;
   }
 
+  // TI is handed to ActOnOpenMPDeclareVariantDirective below, which may keep a
+  // pointer to it, so it must live in ASTContext-owned storage, not the stack.
+  OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
+  if (parseOMPDeclareVariantMatchClause(Loc, TI))
+    return;
+
+  Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData =
+      Actions.checkOpenMPDeclareVariantFunction(
+          Ptr, AssociatedFunction.get(), TI,
+          SourceRange(Loc, Tok.getLocation()));
+
+  // Skip last tokens.
+  while (Tok.isNot(tok::annot_pragma_openmp_end))
+    ConsumeAnyToken();
+  if (DeclVarData && !TI.Sets.empty())
+    Actions.ActOnOpenMPDeclareVariantDirective(
+        DeclVarData->first, DeclVarData->second, TI,
+        SourceRange(Loc, Tok.getLocation()));
+
+  // Skip the last annot_pragma_openmp_end.
+  (void)ConsumeAnnotationToken();
+}
+
+bool Parser::parseOMPDeclareVariantMatchClause(SourceLocation Loc,
+                                               OMPTraitInfo &TI) {
   // Parse 'match'.
   OpenMPClauseKind CKind = Tok.isAnnotation()
                                ? OMPC_unknown
@@ -1345,7 +1380,7 @@
       ;
     // Skip the last annot_pragma_openmp_end.
     (void)ConsumeAnnotationToken();
-    return;
+    return true;
   }
   (void)ConsumeToken();
   // Parse '('.
@@ -1356,31 +1391,15 @@
       ;
     // Skip the last annot_pragma_openmp_end.
     (void)ConsumeAnnotationToken();
-    return;
+    return true;
   }
 
   // Parse inner context selectors.
-  OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
   parseOMPContextSelectors(Loc, TI);
 
   // Parse ')'
   (void)T.consumeClose();
-
-  Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData =
-      Actions.checkOpenMPDeclareVariantFunction(
-          Ptr, AssociatedFunction.get(), TI,
-          SourceRange(Loc, Tok.getLocation()));
-
-  // Skip last tokens.
-  while (Tok.isNot(tok::annot_pragma_openmp_end))
-    ConsumeAnyToken();
-  if (DeclVarData.hasValue() && !TI.Sets.empty())
-    Actions.ActOnOpenMPDeclareVariantDirective(
-        DeclVarData.getValue().first, DeclVarData.getValue().second, TI,
-        SourceRange(Loc, Tok.getLocation()));
-
-  // Skip the last annot_pragma_openmp_end.
-  (void)ConsumeAnnotationToken();
+  return false;
 }
 
 /// Parsing of simple OpenMP clauses like 'default' or 'proc_bind'.
@@ -1530,17 +1549,36 @@
     ConsumeAnyToken();
 }
 
-void Parser::ParseOMPEndDeclareTargetDirective(OpenMPDirectiveKind DKind,
-                                               SourceLocation DTLoc) {
-  if (DKind != OMPD_end_declare_target) {
-    Diag(Tok, diag::err_expected_end_declare_target);
-    Diag(DTLoc, diag::note_matching) << "'#pragma omp declare target'";
+void Parser::parseOMPEndDirective(OpenMPDirectiveKind BeginKind,
+                                  OpenMPDirectiveKind ExpectedKind,
+                                  OpenMPDirectiveKind FoundKind,
+                                  SourceLocation BeginLoc,
+                                  SourceLocation FoundLoc,
+                                  bool SkipUntilOpenMPEnd) {
+  int DiagSelection = ExpectedKind == OMPD_end_declare_target ? 0 : 1;
+  // Found the expected end directive: consume it, skip to the pragma end.
+  if (FoundKind == ExpectedKind) {
+    ConsumeAnyToken();
+    skipUntilPragmaOpenMPEnd(ExpectedKind);
     return;
   }
-  ConsumeAnyToken();
-  skipUntilPragmaOpenMPEnd(OMPD_end_declare_target);
+  // Mismatch: diagnose it and point back at the directive that began it.
+  Diag(FoundLoc, diag::err_expected_end_declare_target_or_variant)
+      << DiagSelection;
+  Diag(BeginLoc, diag::note_matching)
+      << ("'#pragma omp " + getOpenMPDirectiveName(BeginKind) + "'").str();
+  if (SkipUntilOpenMPEnd)
+    SkipUntil(tok::annot_pragma_openmp_end, StopBeforeMatch);
+}
+
+void Parser::ParseOMPEndDeclareTargetDirective(OpenMPDirectiveKind DKind,
+                                               SourceLocation DKLoc) {
+  parseOMPEndDirective(OMPD_declare_target, OMPD_end_declare_target, DKind,
+                       DKLoc, Tok.getLocation(),
+                       /* SkipUntilOpenMPEnd */ false);
   // Skip the last annot_pragma_openmp_end.
-  ConsumeAnyToken();
+  if (Tok.is(tok::annot_pragma_openmp_end))
+    ConsumeAnnotationToken();
 }
 
 /// Parsing of declarative OpenMP directives.
@@ -1725,6 +1763,61 @@
       }
       break;
     }
+  case OMPD_begin_declare_variant: {
+    // The syntax is:
+    // { #pragma omp begin declare variant clause }
+    // <function-declaration-or-definition-sequence>
+    // { #pragma omp end declare variant }
+    //
+    ConsumeToken();
+    // TI only decides applicability below and is not passed to Sema here, so
+    // a local suffices instead of ASTContext-owned storage.
+    OMPTraitInfo TI;
+    if (parseOMPDeclareVariantMatchClause(Loc, TI))
+      break;
+
+    // Skip last tokens.
+    skipUntilPragmaOpenMPEnd(OMPD_begin_declare_variant);
+
+    VariantMatchInfo VMI;
+    ASTContext &ASTCtx = Actions.getASTContext();
+    TI.getAsVariantMatchInfo(ASTCtx, VMI, /* DeviceSetOnly */ true);
+    OMPContext OMPCtx(ASTCtx.getLangOpts().OpenMPIsDevice,
+                      ASTCtx.getTargetInfo().getTriple());
+
+    // If the selector matches this compilation, the enclosed declarations are
+    // left in place and parsed as usual.
+    if (isVariantApplicableInContext(VMI, OMPCtx))
+      break;
+
+    // Otherwise elide all tokens up to the matching 'end declare variant',
+    // counting nested 'begin declare variant' directives on the way.
+    unsigned Nesting = 1;
+    SourceLocation DKLoc;
+    OpenMPDirectiveKind DK = OMPD_unknown;
+    do {
+      DKLoc = Tok.getLocation();
+      DK = parseOpenMPDirectiveKind(*this);
+      if (DK == OMPD_end_declare_variant)
+        --Nesting;
+      else if (DK == OMPD_begin_declare_variant)
+        ++Nesting;
+      if (!Nesting || isEofOrEom())
+        break;
+      ConsumeAnyToken();
+    } while (true);
+
+    parseOMPEndDirective(OMPD_begin_declare_variant, OMPD_end_declare_variant,
+                         DK, Loc, DKLoc, /* SkipUntilOpenMPEnd */ true);
+    if (isEofOrEom())
+      return nullptr;
+    break;
+  }
+  case OMPD_end_declare_variant:
+    // FIXME: With the sema changes we will keep track of nesting and be able to
+    // diagnose unmatched OMPD_end_declare_variant.
+    ConsumeToken();
+    break;
   case OMPD_declare_variant:
   case OMPD_declare_simd: {
     // The syntax is:
@@ -2215,6 +2303,8 @@
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_requires:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
   case OMPD_declare_variant:
     Diag(Tok, diag::err_omp_unexpected_directive)
         << 1 << getOpenMPDirectiveName(DKind);