[OPENMP5.0]Add basic support for declare variant directive.

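Added basic support for the declare variant directive and its match clause
with a user context selector. The contents of the context selector are
skipped for now, and the directive is validated but not yet recorded.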

llvm-svn: 371892
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 761df3b..4859027 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -606,8 +606,6 @@
       break;
     }
     break;
-  case OMPD_declare_simd:
-    break;
   case OMPD_cancel:
     switch (CKind) {
 #define OPENMP_CANCEL_CLAUSE(Name)                                             \
@@ -849,6 +847,8 @@
   case OMPD_taskwait:
   case OMPD_cancellation_point:
   case OMPD_declare_reduction:
+  case OMPD_declare_simd:
+  case OMPD_declare_variant:
     break;
   }
   return false;
@@ -1078,6 +1078,7 @@
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_requires:
+  case OMPD_declare_variant:
     llvm_unreachable("OpenMP Directive is not allowed");
   case OMPD_unknown:
     llvm_unreachable("Unknown OpenMP directive");
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 99fa079..45833e1 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6770,6 +6770,7 @@
   case OMPD_teams_distribute_parallel_for_simd:
   case OMPD_target_update:
   case OMPD_declare_simd:
+  case OMPD_declare_variant:
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_declare_reduction:
@@ -7075,6 +7076,7 @@
   case OMPD_teams_distribute_parallel_for_simd:
   case OMPD_target_update:
   case OMPD_declare_simd:
+  case OMPD_declare_variant:
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_declare_reduction:
@@ -8826,6 +8828,7 @@
     case OMPD_teams_distribute_parallel_for_simd:
     case OMPD_target_update:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
@@ -9583,6 +9586,7 @@
     case OMPD_teams_distribute_parallel_for_simd:
     case OMPD_target_update:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
@@ -10205,6 +10209,7 @@
     case OMPD_teams_distribute_parallel_for:
     case OMPD_teams_distribute_parallel_for_simd:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
index 09042b9..c6c595d 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -795,6 +795,7 @@
     case OMPD_teams_distribute_parallel_for_simd:
     case OMPD_target_update:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
@@ -865,6 +866,7 @@
   case OMPD_teams_distribute_parallel_for_simd:
   case OMPD_target_update:
   case OMPD_declare_simd:
+  case OMPD_declare_variant:
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_declare_reduction:
@@ -1028,6 +1030,7 @@
     case OMPD_teams_distribute_parallel_for_simd:
     case OMPD_target_update:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
@@ -1104,6 +1107,7 @@
   case OMPD_teams_distribute_parallel_for_simd:
   case OMPD_target_update:
   case OMPD_declare_simd:
+  case OMPD_declare_variant:
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_declare_reduction:
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 6bd9605..293660e 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -42,6 +42,7 @@
   OMPD_teams_distribute_parallel,
   OMPD_target_teams_distribute_parallel,
   OMPD_mapper,
+  OMPD_variant,
 };
 
 class DeclDirectiveListParserHelper final {
@@ -80,6 +81,7 @@
       .Case("reduction", OMPD_reduction)
       .Case("update", OMPD_update)
       .Case("mapper", OMPD_mapper)
+      .Case("variant", OMPD_variant)
       .Default(OMPD_unknown);
 }
 
@@ -93,6 +95,7 @@
       {OMPD_declare, OMPD_mapper, OMPD_declare_mapper},
       {OMPD_declare, OMPD_simd, OMPD_declare_simd},
       {OMPD_declare, OMPD_target, OMPD_declare_target},
+      {OMPD_declare, OMPD_variant, OMPD_declare_variant},
       {OMPD_distribute, OMPD_parallel, OMPD_distribute_parallel},
       {OMPD_distribute_parallel, OMPD_for, OMPD_distribute_parallel_for},
       {OMPD_distribute_parallel_for, OMPD_simd,
@@ -752,6 +755,7 @@
                       /*IsReinject*/ true);
   // Consume the previously pushed token.
   ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
+  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
 
   FNContextRAII FnContext(*this, Ptr);
   OMPDeclareSimdDeclAttr::BranchStateTy BS =
@@ -782,6 +786,124 @@
       LinModifiers, Steps, SourceRange(Loc, EndLoc));
 }
 
+/// Parses the clause of the 'declare variant' directive.
+///
+/// clause:
+///   'match' '(' <selector_set_name> '=' '{' <context_selectors> '}' ')'
+///
+/// The context selectors between the braces are skipped, not interpreted.
+///
+/// \returns true if an error was diagnosed, false otherwise. After an error
+/// the current token may be anywhere on the directive line, so the caller
+/// must skip to the annot_pragma_openmp_end token.
+static bool parseDeclareVariantClause(Parser &P) {
+  Token Tok = P.getCurToken();
+  // Parse 'match'.
+  if (!Tok.is(tok::identifier) ||
+      P.getPreprocessor().getSpelling(Tok) != "match") {
+    P.Diag(Tok.getLocation(), diag::err_omp_declare_variant_wrong_clause)
+        << "match";
+    while (!P.SkipUntil(tok::annot_pragma_openmp_end, Parser::StopBeforeMatch))
+      ;
+    return true;
+  }
+  (void)P.ConsumeToken();
+  // Parse '('.
+  BalancedDelimiterTracker T(P, tok::l_paren, tok::annot_pragma_openmp_end);
+  if (T.expectAndConsume(diag::err_expected_lparen_after, "match"))
+    return true;
+  // Parse the context selector set name.
+  Tok = P.getCurToken();
+  if (!Tok.is(tok::identifier)) {
+    P.Diag(Tok.getLocation(), diag::err_omp_declare_variant_no_ctx_selector)
+        << "match";
+    return true;
+  }
+  SmallString<16> Buffer;
+  StringRef CtxSelectorName = P.getPreprocessor().getSpelling(Tok, Buffer);
+  // Parse '='.
+  (void)P.ConsumeToken();
+  Tok = P.getCurToken();
+  if (Tok.isNot(tok::equal)) {
+    P.Diag(Tok.getLocation(), diag::err_omp_declare_variant_equal_expected)
+        << CtxSelectorName;
+    return true;
+  }
+  (void)P.ConsumeToken();
+  bool IsError = false;
+  // No context selector is understood yet, so skip the selector body.
+  {
+    // Parse '{'.
+    BalancedDelimiterTracker TBr(P, tok::l_brace, tok::annot_pragma_openmp_end);
+    if (TBr.expectAndConsume(diag::err_expected_lbrace_after, "="))
+      return true;
+    while (!P.SkipUntil(tok::r_brace, tok::r_paren,
+                        tok::annot_pragma_openmp_end, Parser::StopBeforeMatch))
+      ;
+    // Parse '}'. consumeClose() diagnoses a missing brace itself and returns
+    // true; report that to the caller instead of claiming success.
+    if (TBr.consumeClose())
+      IsError = true;
+  }
+  // Parse ')'; a missing paren is likewise diagnosed by consumeClose().
+  if (T.consumeClose())
+    IsError = true;
+  // TBD: add parsing of known context selectors.
+  return IsError;
+}
+
+/// Parses the rest of '#pragma omp declare variant ( variant-func-id ) clause'
+/// after the associated function declaration \p Ptr has been parsed.
+///
+/// The directive tokens cached in \p Toks are re-injected into the token
+/// stream and parsed in the context of the associated function. If parsing
+/// fails, the rest of the directive is skipped and \p Ptr is returned as is;
+/// otherwise the result of Sema's declare variant checks is returned.
+Parser::DeclGroupPtrTy
+Parser::ParseOMPDeclareVariantClauses(Parser::DeclGroupPtrTy Ptr,
+                                      CachedTokens &Toks, SourceLocation Loc) {
+  PP.EnterToken(Tok, /*IsReinject*/ true);
+  PP.EnterTokenStream(Toks, /*DisableMacroExpansion=*/true,
+                      /*IsReinject*/ true);
+  // Consume the previously pushed token and the first cached token (the
+  // directive name), so parsing resumes at the '(' of variant-func-id.
+  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
+  ConsumeAnyToken(/*ConsumeCodeCompletionTok=*/true);
+
+  FNContextRAII FnContext(*this, Ptr);
+  // Parse function declaration id.
+  SourceLocation RLoc;
+  // Parse with IsAddressOfOperand set to true to parse methods as DeclRefExprs
+  // instead of MemberExprs.
+  ExprResult AssociatedFunction =
+      ParseOpenMPParensExpr(getOpenMPDirectiveName(OMPD_declare_variant), RLoc,
+                            /*IsAddressOfOperand=*/true);
+  if (!AssociatedFunction.isUsable()) {
+    if (!Tok.is(tok::annot_pragma_openmp_end))
+      while (!SkipUntil(tok::annot_pragma_openmp_end, StopBeforeMatch))
+        ;
+    // Skip the last annot_pragma_openmp_end.
+    (void)ConsumeAnnotationToken();
+    return Ptr;
+  }
+
+  bool IsError = parseDeclareVariantClause(*this);
+  // Warn about trailing tokens only after a well-formed clause; after an error
+  // they are leftovers of the malformed clause, which is already diagnosed.
+  if (!IsError && Tok.isNot(tok::annot_pragma_openmp_end))
+    Diag(Tok, diag::warn_omp_extra_tokens_at_eol)
+        << getOpenMPDirectiveName(OMPD_declare_variant);
+  // Skip whatever is left on the directive line.
+  while (Tok.isNot(tok::annot_pragma_openmp_end))
+    ConsumeAnyToken();
+  // Skip the last annot_pragma_openmp_end.
+  SourceLocation EndLoc = ConsumeAnnotationToken();
+  if (IsError)
+    return Ptr;
+  return Actions.ActOnOpenMPDeclareVariantDirective(
+      Ptr, AssociatedFunction.get(), SourceRange(Loc, EndLoc));
+}
+
 /// Parsing of simple OpenMP clauses like 'default' or 'proc_bind'.
 ///
 ///    default-clause:
@@ -1103,13 +1208,15 @@
     }
     break;
   }
+  case OMPD_declare_variant:
   case OMPD_declare_simd: {
     // The syntax is:
-    // { #pragma omp declare simd }
+    // { #pragma omp declare {simd|variant} }
     // <function-declaration-or-definition>
     //
-    ConsumeToken();
     CachedTokens Toks;
+    Toks.push_back(Tok);
+    ConsumeToken();
     while(Tok.isNot(tok::annot_pragma_openmp_end)) {
       Toks.push_back(Tok);
       ConsumeAnyToken();
@@ -1133,10 +1240,15 @@
       }
     }
     if (!Ptr) {
-      Diag(Loc, diag::err_omp_decl_in_declare_simd);
+      Diag(Loc, diag::err_omp_decl_in_declare_simd_variant)
+          << (DKind == OMPD_declare_simd ? 0 : 1);
       return DeclGroupPtrTy();
     }
-    return ParseOMPDeclareSimdClauses(Ptr, Toks, Loc);
+    if (DKind == OMPD_declare_simd)
+      return ParseOMPDeclareSimdClauses(Ptr, Toks, Loc);
+    assert(DKind == OMPD_declare_variant &&
+           "Expected declare variant directive only");
+    return ParseOMPDeclareVariantClauses(Ptr, Toks, Loc);
   }
   case OMPD_declare_target: {
     SourceLocation DTLoc = ConsumeAnyToken();
@@ -1572,6 +1684,7 @@
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_requires:
+  case OMPD_declare_variant:
     Diag(Tok, diag::err_omp_unexpected_directive)
         << 1 << getOpenMPDirectiveName(DKind);
     SkipUntil(tok::annot_pragma_openmp_end);
@@ -1831,14 +1944,15 @@
 /// constructs.
 /// \param RLoc Returned location of right paren.
 ExprResult Parser::ParseOpenMPParensExpr(StringRef ClauseName,
-                                         SourceLocation &RLoc) {
+                                         SourceLocation &RLoc,
+                                         bool IsAddressOfOperand) {
   BalancedDelimiterTracker T(*this, tok::l_paren, tok::annot_pragma_openmp_end);
   if (T.expectAndConsume(diag::err_expected_lparen_after, ClauseName.data()))
     return ExprError();
 
   SourceLocation ELoc = Tok.getLocation();
   ExprResult LHS(ParseCastExpression(
-      /*isUnaryExpression=*/false, /*isAddressOfOperand=*/false, NotTypeCast));
+      /*isUnaryExpression=*/false, IsAddressOfOperand, NotTypeCast));
   ExprResult Val(ParseRHSOfBinaryExpression(LHS, prec::Conditional));
   Val = Actions.ActOnFinishFullExpr(Val.get(), ELoc, /*DiscardedValue*/ false);
 
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 1290c61..dfa5647 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -9670,10 +9670,13 @@
   return false;
 }
 
-static bool CheckMultiVersionAdditionalRules(Sema &S, const FunctionDecl *OldFD,
-                                             const FunctionDecl *NewFD,
-                                             bool CausesMV,
-                                             MultiVersionKind MVType) {
+bool Sema::areMultiversionVariantFunctionsCompatible(
+    const FunctionDecl *OldFD, const FunctionDecl *NewFD,
+    const PartialDiagnostic &NoProtoDiagID,
+    const PartialDiagnosticAt &NoteCausedDiagIDAt,
+    const PartialDiagnosticAt &NoSupportDiagIDAt,
+    const PartialDiagnosticAt &DiffDiagIDAt, bool TemplatesSupported,
+    bool ConstexprSupported) {
   enum DoesntSupport {
     FuncTemplates = 0,
     VirtFuncs = 1,
@@ -9691,22 +9694,96 @@
     ConstexprSpec = 2,
     InlineSpec = 3,
     StorageClass = 4,
-    Linkage = 5
+    Linkage = 5,
   };
 
-  bool IsCPUSpecificCPUDispatchMVType =
-      MVType == MultiVersionKind::CPUDispatch ||
-      MVType == MultiVersionKind::CPUSpecific;
-
   if (OldFD && !OldFD->getType()->getAs<FunctionProtoType>()) {
-    S.Diag(OldFD->getLocation(), diag::err_multiversion_noproto);
-    S.Diag(NewFD->getLocation(), diag::note_multiversioning_caused_here);
+    Diag(OldFD->getLocation(), NoProtoDiagID);
+    Diag(NoteCausedDiagIDAt.first, NoteCausedDiagIDAt.second);
     return true;
   }
 
   if (!NewFD->getType()->getAs<FunctionProtoType>())
-    return S.Diag(NewFD->getLocation(), diag::err_multiversion_noproto);
+    return Diag(NewFD->getLocation(), NoProtoDiagID);
 
+  if (!TemplatesSupported &&
+      NewFD->getTemplatedKind() == FunctionDecl::TK_FunctionTemplate)
+    return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
+           << FuncTemplates;
+
+  if (const auto *NewCXXFD = dyn_cast<CXXMethodDecl>(NewFD)) {
+    if (NewCXXFD->isVirtual())
+      return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
+             << VirtFuncs;
+
+    if (isa<CXXConstructorDecl>(NewCXXFD))
+      return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
+             << Constructors;
+
+    if (isa<CXXDestructorDecl>(NewCXXFD))
+      return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
+             << Destructors;
+  }
+
+  if (NewFD->isDeleted())
+    return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
+           << DeletedFuncs;
+
+  if (NewFD->isDefaulted())
+    return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
+           << DefaultedFuncs;
+
+  if (!ConstexprSupported && NewFD->isConstexpr())
+    return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
+           << (NewFD->isConsteval() ? ConstevalFuncs : ConstexprFuncs);
+
+  QualType NewQType = Context.getCanonicalType(NewFD->getType());
+  const auto *NewType = cast<FunctionType>(NewQType);
+  QualType NewReturnType = NewType->getReturnType();
+
+  if (NewReturnType->isUndeducedType())
+    return Diag(NoSupportDiagIDAt.first, NoSupportDiagIDAt.second)
+           << DeducedReturn;
+
+  // Ensure the return type is identical.
+  if (OldFD) {
+    QualType OldQType = Context.getCanonicalType(OldFD->getType());
+    const auto *OldType = cast<FunctionType>(OldQType);
+    FunctionType::ExtInfo OldTypeInfo = OldType->getExtInfo();
+    FunctionType::ExtInfo NewTypeInfo = NewType->getExtInfo();
+
+    if (OldTypeInfo.getCC() != NewTypeInfo.getCC())
+      return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << CallingConv;
+
+    QualType OldReturnType = OldType->getReturnType();
+
+    if (OldReturnType != NewReturnType)
+      return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << ReturnType;
+
+    if (OldFD->getConstexprKind() != NewFD->getConstexprKind())
+      return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << ConstexprSpec;
+
+    if (OldFD->isInlineSpecified() != NewFD->isInlineSpecified())
+      return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << InlineSpec;
+
+    if (OldFD->getStorageClass() != NewFD->getStorageClass())
+      return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << StorageClass;
+
+    if (OldFD->isExternC() != NewFD->isExternC())
+      return Diag(DiffDiagIDAt.first, DiffDiagIDAt.second) << Linkage;
+
+    if (CheckEquivalentExceptionSpec(
+            OldFD->getType()->getAs<FunctionProtoType>(), OldFD->getLocation(),
+            NewFD->getType()->getAs<FunctionProtoType>(), NewFD->getLocation()))
+      return true;
+  }
+  return false;
+}
+
+static bool CheckMultiVersionAdditionalRules(Sema &S, const FunctionDecl *OldFD,
+                                             const FunctionDecl *NewFD,
+                                             bool CausesMV,
+                                             MultiVersionKind MVType) {
   if (!S.getASTContext().getTargetInfo().supportsMultiVersioning()) {
     S.Diag(NewFD->getLocation(), diag::err_multiversion_not_supported);
     if (OldFD)
@@ -9714,6 +9791,10 @@
     return true;
   }
 
+  bool IsCPUSpecificCPUDispatchMVType =
+      MVType == MultiVersionKind::CPUDispatch ||
+      MVType == MultiVersionKind::CPUSpecific;
+
   // For now, disallow all other attributes.  These should be opt-in, but
   // an analysis of all of them is a future FIXME.
   if (CausesMV && OldFD && HasNonMultiVersionAttributes(OldFD, MVType)) {
@@ -9727,92 +9808,21 @@
     return S.Diag(NewFD->getLocation(), diag::err_multiversion_no_other_attrs)
            << IsCPUSpecificCPUDispatchMVType;
 
-  if (NewFD->getTemplatedKind() == FunctionDecl::TK_FunctionTemplate)
-    return S.Diag(NewFD->getLocation(), diag::err_multiversion_doesnt_support)
-           << IsCPUSpecificCPUDispatchMVType << FuncTemplates;
-
-  if (const auto *NewCXXFD = dyn_cast<CXXMethodDecl>(NewFD)) {
-    if (NewCXXFD->isVirtual())
-      return S.Diag(NewCXXFD->getLocation(),
-                    diag::err_multiversion_doesnt_support)
-             << IsCPUSpecificCPUDispatchMVType << VirtFuncs;
-
-    if (const auto *NewCXXCtor = dyn_cast<CXXConstructorDecl>(NewFD))
-      return S.Diag(NewCXXCtor->getLocation(),
-                    diag::err_multiversion_doesnt_support)
-             << IsCPUSpecificCPUDispatchMVType << Constructors;
-
-    if (const auto *NewCXXDtor = dyn_cast<CXXDestructorDecl>(NewFD))
-      return S.Diag(NewCXXDtor->getLocation(),
-                    diag::err_multiversion_doesnt_support)
-             << IsCPUSpecificCPUDispatchMVType << Destructors;
-  }
-
-  if (NewFD->isDeleted())
-    return S.Diag(NewFD->getLocation(), diag::err_multiversion_doesnt_support)
-           << IsCPUSpecificCPUDispatchMVType << DeletedFuncs;
-
-  if (NewFD->isDefaulted())
-    return S.Diag(NewFD->getLocation(), diag::err_multiversion_doesnt_support)
-           << IsCPUSpecificCPUDispatchMVType << DefaultedFuncs;
-
-  if (NewFD->isConstexpr() && (MVType == MultiVersionKind::CPUDispatch ||
-                               MVType == MultiVersionKind::CPUSpecific))
-    return S.Diag(NewFD->getLocation(), diag::err_multiversion_doesnt_support)
-           << IsCPUSpecificCPUDispatchMVType
-           << (NewFD->isConsteval() ? ConstevalFuncs : ConstexprFuncs);
-
-  QualType NewQType = S.getASTContext().getCanonicalType(NewFD->getType());
-  const auto *NewType = cast<FunctionType>(NewQType);
-  QualType NewReturnType = NewType->getReturnType();
-
-  if (NewReturnType->isUndeducedType())
-    return S.Diag(NewFD->getLocation(), diag::err_multiversion_doesnt_support)
-           << IsCPUSpecificCPUDispatchMVType << DeducedReturn;
-
   // Only allow transition to MultiVersion if it hasn't been used.
   if (OldFD && CausesMV && OldFD->isUsed(false))
     return S.Diag(NewFD->getLocation(), diag::err_multiversion_after_used);
 
-  // Ensure the return type is identical.
-  if (OldFD) {
-    QualType OldQType = S.getASTContext().getCanonicalType(OldFD->getType());
-    const auto *OldType = cast<FunctionType>(OldQType);
-    FunctionType::ExtInfo OldTypeInfo = OldType->getExtInfo();
-    FunctionType::ExtInfo NewTypeInfo = NewType->getExtInfo();
-
-    if (OldTypeInfo.getCC() != NewTypeInfo.getCC())
-      return S.Diag(NewFD->getLocation(), diag::err_multiversion_diff)
-             << CallingConv;
-
-    QualType OldReturnType = OldType->getReturnType();
-
-    if (OldReturnType != NewReturnType)
-      return S.Diag(NewFD->getLocation(), diag::err_multiversion_diff)
-             << ReturnType;
-
-    if (OldFD->getConstexprKind() != NewFD->getConstexprKind())
-      return S.Diag(NewFD->getLocation(), diag::err_multiversion_diff)
-             << ConstexprSpec;
-
-    if (OldFD->isInlineSpecified() != NewFD->isInlineSpecified())
-      return S.Diag(NewFD->getLocation(), diag::err_multiversion_diff)
-             << InlineSpec;
-
-    if (OldFD->getStorageClass() != NewFD->getStorageClass())
-      return S.Diag(NewFD->getLocation(), diag::err_multiversion_diff)
-             << StorageClass;
-
-    if (OldFD->isExternC() != NewFD->isExternC())
-      return S.Diag(NewFD->getLocation(), diag::err_multiversion_diff)
-             << Linkage;
-
-    if (S.CheckEquivalentExceptionSpec(
-            OldFD->getType()->getAs<FunctionProtoType>(), OldFD->getLocation(),
-            NewFD->getType()->getAs<FunctionProtoType>(), NewFD->getLocation()))
-      return true;
-  }
-  return false;
+  return S.areMultiversionVariantFunctionsCompatible(
+      OldFD, NewFD, S.PDiag(diag::err_multiversion_noproto),
+      PartialDiagnosticAt(NewFD->getLocation(),
+                          S.PDiag(diag::note_multiversioning_caused_here)),
+      PartialDiagnosticAt(NewFD->getLocation(),
+                          S.PDiag(diag::err_multiversion_doesnt_support)
+                              << IsCPUSpecificCPUDispatchMVType),
+      PartialDiagnosticAt(NewFD->getLocation(),
+                          S.PDiag(diag::err_multiversion_diff)),
+      /*TemplatesSupported=*/false,
+      /*ConstexprSupported=*/!IsCPUSpecificCPUDispatchMVType);
 }
 
 /// Check the validity of a multiversion function declaration that is the
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index dbf8155..97844cd 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -3447,6 +3447,7 @@
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_requires:
+  case OMPD_declare_variant:
     llvm_unreachable("OpenMP Directive is not allowed");
   case OMPD_unknown:
     llvm_unreachable("Unknown OpenMP directive");
@@ -4516,6 +4517,7 @@
   case OMPD_declare_mapper:
   case OMPD_declare_simd:
   case OMPD_requires:
+  case OMPD_declare_variant:
     llvm_unreachable("OpenMP Directive is not allowed");
   case OMPD_unknown:
     llvm_unreachable("Unknown OpenMP directive");
@@ -4653,8 +4655,10 @@
   if (!DG || DG.get().isNull())
     return DeclGroupPtrTy();
 
+  const int SimdId = 0;
   if (!DG.get().isSingleDecl()) {
-    Diag(SR.getBegin(), diag::err_omp_single_decl_in_declare_simd);
+    Diag(SR.getBegin(), diag::err_omp_single_decl_in_declare_simd_variant)
+        << SimdId;
     return DG;
   }
   Decl *ADecl = DG.get().getSingleDecl();
@@ -4663,7 +4667,7 @@
 
   auto *FD = dyn_cast<FunctionDecl>(ADecl);
   if (!FD) {
-    Diag(ADecl->getLocation(), diag::err_omp_function_expected);
+    Diag(ADecl->getLocation(), diag::err_omp_function_expected) << SimdId;
     return DeclGroupPtrTy();
   }
 
@@ -4888,6 +4892,216 @@
   return ConvertDeclToDeclGroup(ADecl);
 }
 
+/// Semantic checks for '#pragma omp declare variant'.
+///
+/// \param DG The declaration group the directive applies to; it must contain
+/// a single function or function template declaration.
+/// \param VariantRef The expression naming the variant function (may be null).
+/// \param SR The source range of the directive, used in diagnostics.
+///
+/// \returns An empty group if \p DG is empty or its declaration is not a
+/// function, otherwise \p DG. Failed checks are diagnosed; nothing is
+/// attached to the function yet, the directive is only validated.
+Sema::DeclGroupPtrTy
+Sema::ActOnOpenMPDeclareVariantDirective(Sema::DeclGroupPtrTy DG,
+                                         Expr *VariantRef, SourceRange SR) {
+  if (!DG || DG.get().isNull())
+    return DeclGroupPtrTy();
+
+  const int VariantId = 1;
+  // Must be applied only to single decl.
+  if (!DG.get().isSingleDecl()) {
+    Diag(SR.getBegin(), diag::err_omp_single_decl_in_declare_simd_variant)
+        << VariantId << SR;
+    return DG;
+  }
+  Decl *ADecl = DG.get().getSingleDecl();
+  if (auto *FTD = dyn_cast<FunctionTemplateDecl>(ADecl))
+    ADecl = FTD->getTemplatedDecl();
+
+  // Decl must be a function.
+  auto *FD = dyn_cast<FunctionDecl>(ADecl);
+  if (!FD) {
+    Diag(ADecl->getLocation(), diag::err_omp_function_expected)
+        << VariantId << SR;
+    return DeclGroupPtrTy();
+  }
+
+  auto &&HasMultiVersionAttributes = [](const FunctionDecl *FD) {
+    return FD->hasAttrs() &&
+           (FD->hasAttr<CPUDispatchAttr>() || FD->hasAttr<CPUSpecificAttr>() ||
+            FD->hasAttr<TargetAttr>());
+  };
+  // OpenMP is not compatible with CPU-specific attributes.
+  if (HasMultiVersionAttributes(FD)) {
+    Diag(FD->getLocation(), diag::err_omp_declare_variant_incompat_attributes)
+        << SR;
+    return DG;
+  }
+
+  // Allow #pragma omp declare variant only if the function is not used.
+  if (FD->isUsed(false)) {
+    Diag(SR.getBegin(), diag::err_omp_declare_variant_after_used)
+        << FD->getLocation();
+    return DG;
+  }
+
+  // The VariantRef must point to function.
+  if (!VariantRef) {
+    Diag(SR.getBegin(), diag::err_omp_function_expected) << VariantId;
+    return DG;
+  }
+
+  // Do not check templates, wait until instantiation.
+  if (VariantRef->isTypeDependent() || VariantRef->isValueDependent() ||
+      VariantRef->containsUnexpandedParameterPack() ||
+      VariantRef->isInstantiationDependent() || FD->isDependentContext())
+    return DG;
+
+  // Convert VariantRef expression to the type of the original function to
+  // resolve possible conflicts.
+  ExprResult VariantRefCast;
+  if (LangOpts.CPlusPlus) {
+    QualType FnPtrType;
+    auto *Method = dyn_cast<CXXMethodDecl>(FD);
+    if (Method && !Method->isStatic()) {
+      const Type *ClassType =
+          Context.getTypeDeclType(Method->getParent()).getTypePtr();
+      FnPtrType = Context.getMemberPointerType(FD->getType(), ClassType);
+      ExprResult ER;
+      {
+        // Build addr_of unary op to correctly handle type checks for member
+        // functions.
+        Sema::TentativeAnalysisScope Trap(*this);
+        ER = CreateBuiltinUnaryOp(VariantRef->getBeginLoc(), UO_AddrOf,
+                                  VariantRef);
+      }
+      if (!ER.isUsable()) {
+        Diag(VariantRef->getExprLoc(), diag::err_omp_function_expected)
+            << VariantId << VariantRef->getSourceRange();
+        return DG;
+      }
+      VariantRef = ER.get();
+    } else {
+      FnPtrType = Context.getPointerType(FD->getType());
+    }
+    ImplicitConversionSequence ICS =
+        TryImplicitConversion(VariantRef, FnPtrType.getUnqualifiedType(),
+                              /*SuppressUserConversions=*/false,
+                              /*AllowExplicit=*/false,
+                              /*InOverloadResolution=*/false,
+                              /*CStyle=*/false,
+                              /*AllowObjCWritebackConversion=*/false);
+    if (ICS.isFailure()) {
+      Diag(VariantRef->getExprLoc(),
+           diag::err_omp_declare_variant_incompat_types)
+          << VariantRef->getType() << FnPtrType << VariantRef->getSourceRange();
+      return DG;
+    }
+    VariantRefCast = PerformImplicitConversion(
+        VariantRef, FnPtrType.getUnqualifiedType(), AA_Converting);
+    if (!VariantRefCast.isUsable())
+      return DG;
+    // Drop previously built artificial addr_of unary op for member functions.
+    if (Method && !Method->isStatic()) {
+      Expr *PossibleAddrOfVariantRef = VariantRefCast.get();
+      if (auto *UO = dyn_cast<UnaryOperator>(
+              PossibleAddrOfVariantRef->IgnoreImplicit()))
+        VariantRefCast = UO->getSubExpr();
+    }
+  } else {
+    VariantRefCast = VariantRef;
+  }
+
+  ExprResult ER = CheckPlaceholderExpr(VariantRefCast.get());
+  if (!ER.isUsable() ||
+      !ER.get()->IgnoreParenImpCasts()->getType()->isFunctionType()) {
+    Diag(VariantRef->getExprLoc(), diag::err_omp_function_expected)
+        << VariantId << VariantRef->getSourceRange();
+    return DG;
+  }
+
+  // The VariantRef must point to function.
+  auto *DRE = dyn_cast<DeclRefExpr>(ER.get()->IgnoreParenImpCasts());
+  if (!DRE) {
+    Diag(VariantRef->getExprLoc(), diag::err_omp_function_expected)
+        << VariantId << VariantRef->getSourceRange();
+    return DG;
+  }
+  auto *NewFD = dyn_cast_or_null<FunctionDecl>(DRE->getDecl());
+  if (!NewFD) {
+    Diag(VariantRef->getExprLoc(), diag::err_omp_function_expected)
+        << VariantId << VariantRef->getSourceRange();
+    return DG;
+  }
+
+  enum DoesntSupport {
+    VirtFuncs = 1,
+    Constructors = 3,
+    Destructors = 4,
+    DeletedFuncs = 5,
+    DefaultedFuncs = 6,
+    ConstexprFuncs = 7,
+    ConstevalFuncs = 8,
+  };
+  if (const auto *CXXFD = dyn_cast<CXXMethodDecl>(FD)) {
+    if (CXXFD->isVirtual()) {
+      Diag(FD->getLocation(), diag::err_omp_declare_variant_doesnt_support)
+          << VirtFuncs;
+      return DG;
+    }
+
+    if (isa<CXXConstructorDecl>(FD)) {
+      Diag(FD->getLocation(), diag::err_omp_declare_variant_doesnt_support)
+          << Constructors;
+      return DG;
+    }
+
+    if (isa<CXXDestructorDecl>(FD)) {
+      Diag(FD->getLocation(), diag::err_omp_declare_variant_doesnt_support)
+          << Destructors;
+      return DG;
+    }
+  }
+
+  if (FD->isDeleted()) {
+    Diag(FD->getLocation(), diag::err_omp_declare_variant_doesnt_support)
+        << DeletedFuncs;
+    return DG;
+  }
+
+  if (FD->isDefaulted()) {
+    Diag(FD->getLocation(), diag::err_omp_declare_variant_doesnt_support)
+        << DefaultedFuncs;
+    return DG;
+  }
+
+  // The diagnostic describes FD, so ask FD (not the variant) about consteval.
+  if (FD->isConstexpr()) {
+    Diag(FD->getLocation(), diag::err_omp_declare_variant_doesnt_support)
+        << (FD->isConsteval() ? ConstevalFuncs : ConstexprFuncs);
+    return DG;
+  }
+
+  // Check general compatibility.
+  if (areMultiversionVariantFunctionsCompatible(
+          FD, NewFD, PDiag(diag::err_omp_declare_variant_noproto),
+          PartialDiagnosticAt(
+              SR.getBegin(),
+              PDiag(diag::note_omp_declare_variant_specified_here) << SR),
+          PartialDiagnosticAt(
+              VariantRef->getExprLoc(),
+              PDiag(diag::err_omp_declare_variant_doesnt_support)),
+          PartialDiagnosticAt(VariantRef->getExprLoc(),
+                              PDiag(diag::err_omp_declare_variant_diff)
+                                  << FD->getLocation()),
+          /*TemplatesSupported=*/true, /*ConstexprSupported=*/false))
+    return DG;
+
+  // Nothing is attached to FD yet; the directive is only validated.
+  return DG;
+}
+
 StmtResult Sema::ActOnOpenMPParallelDirective(ArrayRef<OMPClause *> Clauses,
                                               Stmt *AStmt,
                                               SourceLocation StartLoc,
@@ -9895,6 +10097,7 @@
     case OMPD_declare_reduction:
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_teams:
@@ -9963,6 +10166,7 @@
     case OMPD_declare_reduction:
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_teams:
@@ -10032,6 +10236,7 @@
     case OMPD_declare_reduction:
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
@@ -10098,6 +10303,7 @@
     case OMPD_declare_reduction:
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
@@ -10165,6 +10371,7 @@
     case OMPD_declare_reduction:
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
@@ -10231,6 +10438,7 @@
     case OMPD_declare_reduction:
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
@@ -10296,6 +10504,7 @@
     case OMPD_declare_reduction:
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
+    case OMPD_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd: