[OPENMP50]Add device/kind context selector support.

Summary: Added basic parsing/sema/codegen support for the device={kind(<kind>)} context selector.

Reviewers: jdoerfert

Subscribers: rampitec, aheejin, fedor.sergeev, simoncook, guansong, s.egerton, hfinkel, kkwli0, caomhin, cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D70245
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index df5e2d0..347606f 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -20,6 +20,7 @@
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/Basic/BitmaskEnum.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/Bitcode/BitcodeReader.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalValue.h"
@@ -11032,8 +11033,10 @@
 } // anonymous namespace
 
 /// Checks current context and returns true if it matches the context selector.
-template <OpenMPContextSelectorSetKind CtxSet, OpenMPContextSelectorKind Ctx>
-static bool checkContext(const OMPContextSelectorData &Data) {
+template <OpenMPContextSelectorSetKind CtxSet, OpenMPContextSelectorKind Ctx,
+          typename... Arguments>
+static bool checkContext(const OMPContextSelectorData &Data,
+                         Arguments... Params) {
   assert(Data.CtxSet != OMP_CTX_SET_unknown && Data.Ctx != OMP_CTX_unknown &&
          "Unknown context selector or context selector set.");
   return false;
@@ -11048,7 +11051,92 @@
                       [](StringRef S) { return !S.compare_lower("llvm"); });
 }
 
-bool matchesContext(const CompleteOMPContextSelectorData &ContextData) {
+/// Checks for device={kind(<kind>, ...)} context selector.
+/// \returns true only if every listed <kind> (compared case-insensitively)
+/// matches: "host" when compiling for the host, "nohost" when compiling for a
+/// device, "cpu" for Arm/AArch64 (not Thumb), X86 or PPC targets and "gpu"
+/// for NVPTX or AMDGCN targets. Any other <kind> yields false. The arch switch
+/// has no default label so -Wswitch can flag architectures left unlisted.
+template <>
+bool checkContext<OMP_CTX_SET_device, OMP_CTX_kind, CodeGenModule &>(
+    const OMPContextSelectorData &Data, CodeGenModule &CGM) {
+  for (StringRef Name : Data.Names) {
+    if (!Name.compare_lower("host")) {
+      if (CGM.getLangOpts().OpenMPIsDevice)
+        return false;
+      continue;
+    }
+    if (!Name.compare_lower("nohost")) {
+      if (!CGM.getLangOpts().OpenMPIsDevice)
+        return false;
+      continue;
+    }
+    switch (CGM.getTriple().getArch()) {
+    case llvm::Triple::arm:
+    case llvm::Triple::armeb:
+    case llvm::Triple::aarch64:
+    case llvm::Triple::aarch64_be:
+    case llvm::Triple::aarch64_32:
+    case llvm::Triple::ppc:
+    case llvm::Triple::ppc64:
+    case llvm::Triple::ppc64le:
+    case llvm::Triple::x86:
+    case llvm::Triple::x86_64:
+      if (Name.compare_lower("cpu"))
+        return false;
+      break;
+    case llvm::Triple::amdgcn:
+    case llvm::Triple::nvptx:
+    case llvm::Triple::nvptx64:
+      if (Name.compare_lower("gpu"))
+        return false;
+      break;
+    case llvm::Triple::UnknownArch:
+    case llvm::Triple::arc:
+    case llvm::Triple::avr:
+    case llvm::Triple::bpfel:
+    case llvm::Triple::bpfeb:
+    case llvm::Triple::hexagon:
+    case llvm::Triple::mips:
+    case llvm::Triple::mipsel:
+    case llvm::Triple::mips64:
+    case llvm::Triple::mips64el:
+    case llvm::Triple::msp430:
+    case llvm::Triple::r600:
+    case llvm::Triple::riscv32:
+    case llvm::Triple::riscv64:
+    case llvm::Triple::sparc:
+    case llvm::Triple::sparcv9:
+    case llvm::Triple::sparcel:
+    case llvm::Triple::systemz:
+    case llvm::Triple::tce:
+    case llvm::Triple::tcele:
+    case llvm::Triple::thumb:
+    case llvm::Triple::thumbeb:
+    case llvm::Triple::xcore:
+    case llvm::Triple::le32:
+    case llvm::Triple::le64:
+    case llvm::Triple::amdil:
+    case llvm::Triple::amdil64:
+    case llvm::Triple::hsail:
+    case llvm::Triple::hsail64:
+    case llvm::Triple::spir:
+    case llvm::Triple::spir64:
+    case llvm::Triple::kalimba:
+    case llvm::Triple::shave:
+    case llvm::Triple::lanai:
+    case llvm::Triple::wasm32:
+    case llvm::Triple::wasm64:
+    case llvm::Triple::renderscript32:
+    case llvm::Triple::renderscript64:
+      return false;
+    }
+  }
+  return true;
+}
+
+bool matchesContext(CodeGenModule &CGM,
+                    const CompleteOMPContextSelectorData &ContextData) {
   for (const OMPContextSelectorData &Data : ContextData) {
     switch (Data.CtxSet) {
     case OMP_CTX_SET_implementation:
@@ -11057,8 +11145,22 @@
         if (!checkContext<OMP_CTX_SET_implementation, OMP_CTX_vendor>(Data))
           return false;
         break;
+      case OMP_CTX_kind:
       case OMP_CTX_unknown:
-        llvm_unreachable("Unexpected context selector kind.");
+        llvm_unreachable(
+            "Unexpected context selector kind in implementation set.");
+      }
+      break;
+    case OMP_CTX_SET_device:
+      switch (Data.Ctx) {
+      case OMP_CTX_kind:
+        if (!checkContext<OMP_CTX_SET_device, OMP_CTX_kind, CodeGenModule &>(
+                Data, CGM))
+          return false;
+        break;
+      case OMP_CTX_vendor:
+      case OMP_CTX_unknown:
+        llvm_unreachable("Unexpected context selector kind in device set.");
       }
       break;
     case OMP_CTX_SET_unknown:
@@ -11089,8 +11191,21 @@
         Data.back().Names =
             llvm::makeArrayRef(A->implVendors_begin(), A->implVendors_end());
         break;
+      case OMP_CTX_kind:
       case OMP_CTX_unknown:
-        llvm_unreachable("Unexpected context selector kind.");
+        llvm_unreachable(
+            "Unexpected context selector kind in implementation set.");
+      }
+      break;
+    case OMP_CTX_SET_device:
+      switch (Ctx) {
+      case OMP_CTX_kind:
+        Data.back().Names =
+            llvm::makeArrayRef(A->deviceKinds_begin(), A->deviceKinds_end());
+        break;
+      case OMP_CTX_vendor:
+      case OMP_CTX_unknown:
+        llvm_unreachable("Unexpected context selector kind in device set.");
       }
       break;
     case OMP_CTX_SET_unknown:
@@ -11100,27 +11215,71 @@
   return Data;
 }
 
+/// Checks whether context selector \p LHS is a strict subset of \p RHS.
+/// \returns true if every (set, selector) pair of \p LHS also appears in
+/// \p RHS with a superset of its names and the two selectors are not equal.
+static bool isStrictSubset(const CompleteOMPContextSelectorData &LHS,
+                           const CompleteOMPContextSelectorData &RHS) {
+  llvm::SmallDenseMap<std::pair<int, int>, llvm::StringSet<>, 4> RHSData;
+  for (const OMPContextSelectorData &D : RHS) {
+    auto &Pair = RHSData.FindAndConstruct(std::make_pair(D.CtxSet, D.Ctx));
+    Pair.getSecond().insert(D.Names.begin(), D.Names.end());
+  }
+  bool AllSetsAreEqual = true;
+  for (const OMPContextSelectorData &D : LHS) {
+    auto It = RHSData.find(std::make_pair(D.CtxSet, D.Ctx));
+    if (It == RHSData.end())
+      return false;
+    if (D.Names.size() > It->getSecond().size())
+      return false;
+    // set_union() returns true iff it had to insert a name, i.e. some name of
+    // D is missing in RHS. We return immediately in that case, so modifying
+    // the local RHSData map is harmless.
+    if (llvm::set_union(It->getSecond(), D.Names))
+      return false;
+    // NOTE(review): D.Names is compared by size, so duplicated names in LHS
+    // would skew this check; presumably Sema rejects them - TODO confirm.
+    AllSetsAreEqual =
+        AllSetsAreEqual && (D.Names.size() == It->getSecond().size());
+  }
+
+  return LHS.size() != RHS.size() || !AllSetsAreEqual;
+}
+
+/// Compares the scores of two context selectors that match the current
+/// context. \returns true if \p LHS scores at least as high as \p RHS.
 static bool greaterCtxScore(const CompleteOMPContextSelectorData &LHS,
                             const CompleteOMPContextSelectorData &RHS) {
   // Score is calculated as sum of all scores + 1.
   llvm::APSInt LHSScore(llvm::APInt(64, 1), /*isUnsigned=*/false);
-  for (const OMPContextSelectorData &Data : LHS) {
-    if (Data.Score.getBitWidth() > LHSScore.getBitWidth()) {
-      LHSScore = LHSScore.extend(Data.Score.getBitWidth()) + Data.Score;
-    } else if (Data.Score.getBitWidth() < LHSScore.getBitWidth()) {
-      LHSScore += Data.Score.extend(LHSScore.getBitWidth());
-    } else {
-      LHSScore += Data.Score;
+  // OpenMP 5.0, 2.3.3: a context selector that is a strict subset of another
+  // context selector has a score of zero, so zero out the subset side only.
+  bool LHSIsSubsetOfRHS = isStrictSubset(LHS, RHS);
+  if (LHSIsSubsetOfRHS) {
+    LHSScore = llvm::APSInt::get(0);
+  } else {
+    for (const OMPContextSelectorData &Data : LHS) {
+      if (Data.Score.getBitWidth() > LHSScore.getBitWidth()) {
+        LHSScore = LHSScore.extend(Data.Score.getBitWidth()) + Data.Score;
+      } else if (Data.Score.getBitWidth() < LHSScore.getBitWidth()) {
+        LHSScore += Data.Score.extend(LHSScore.getBitWidth());
+      } else {
+        LHSScore += Data.Score;
+      }
     }
   }
   llvm::APSInt RHSScore(llvm::APInt(64, 1), /*isUnsigned=*/false);
-  for (const OMPContextSelectorData &Data : RHS) {
-    if (Data.Score.getBitWidth() > RHSScore.getBitWidth()) {
-      RHSScore = RHSScore.extend(Data.Score.getBitWidth()) + Data.Score;
-    } else if (Data.Score.getBitWidth() < RHSScore.getBitWidth()) {
-      RHSScore += Data.Score.extend(RHSScore.getBitWidth());
-    } else {
-      RHSScore += Data.Score;
+  if (!LHSIsSubsetOfRHS && isStrictSubset(RHS, LHS)) {
+    RHSScore = llvm::APSInt::get(0);
+  } else {
+    for (const OMPContextSelectorData &Data : RHS) {
+      if (Data.Score.getBitWidth() > RHSScore.getBitWidth()) {
+        RHSScore = RHSScore.extend(Data.Score.getBitWidth()) + Data.Score;
+      } else if (Data.Score.getBitWidth() < RHSScore.getBitWidth()) {
+        RHSScore += Data.Score.extend(RHSScore.getBitWidth());
+      } else {
+        RHSScore += Data.Score;
+      }
     }
   }
   return llvm::APSInt::compareValues(LHSScore, RHSScore) >= 0;
@@ -11128,7 +11275,7 @@
 
 /// Finds the variant function that matches current context with its context
 /// selector.
-static const FunctionDecl *getDeclareVariantFunction(ASTContext &Ctx,
+static const FunctionDecl *getDeclareVariantFunction(CodeGenModule &CGM,
                                                      const FunctionDecl *FD) {
   if (!FD->hasAttrs() || !FD->hasAttr<OMPDeclareVariantAttr>())
     return FD;
@@ -11137,8 +11284,8 @@
   CompleteOMPContextSelectorData TopMostData;
   for (const auto *A : FD->specific_attrs<OMPDeclareVariantAttr>()) {
     CompleteOMPContextSelectorData Data =
-        translateAttrToContextSelectorData(Ctx, A);
-    if (!matchesContext(Data))
+        translateAttrToContextSelectorData(CGM.getContext(), A);
+    if (!matchesContext(CGM, Data))
       continue;
     // If the attribute matches the context, find the attribute with the highest
     // score.
@@ -11161,7 +11308,7 @@
   llvm::GlobalValue *Orig = CGM.GetGlobalValue(MangledName);
   if (Orig && !Orig->isDeclaration())
     return false;
-  const FunctionDecl *NewFD = getDeclareVariantFunction(CGM.getContext(), D);
+  const FunctionDecl *NewFD = getDeclareVariantFunction(CGM, D);
   // Emit original function if it does not have declare variant attribute or the
   // context does not match.
   if (NewFD == D)