[OPENMP5.0]Add support for device_type clause in declare target
construct.

OpenMP 5.0 introduced a new clause for the declare target directive, the
device_type clause, which accepts the values host, nohost, and any. host means
that the function must be emitted only for the host, nohost means that it must
be emitted only for the device, and any means that it must be emitted for both
the host and the device.

llvm-svn: 369775
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 1695683..9d3d3a5 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -1556,32 +1556,102 @@
          !S.isInOpenMPDeclareTargetContext();
 }
 
+namespace {
+/// Status of the function emission on the host/device.
+enum class FunctionEmissionStatus {
+  Emitted,
+  Discarded,
+  Unknown,
+};
+} // anonymous namespace
+
 /// Do we know that we will eventually codegen the given function?
-static bool isKnownEmitted(Sema &S, FunctionDecl *FD) {
+static FunctionEmissionStatus isKnownDeviceEmitted(Sema &S, FunctionDecl *FD) {
   assert(S.LangOpts.OpenMP && S.LangOpts.OpenMPIsDevice &&
          "Expected OpenMP device compilation.");
   // Templates are emitted when they're instantiated.
   if (FD->isDependentContext())
-    return false;
+    return FunctionEmissionStatus::Discarded;
 
-  if (OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(
-          FD->getCanonicalDecl()))
-    return true;
+  Optional<OMPDeclareTargetDeclAttr::DevTypeTy> DevTy =
+      OMPDeclareTargetDeclAttr::getDeviceType(FD->getCanonicalDecl());
+  if (DevTy.hasValue())
+    return (*DevTy == OMPDeclareTargetDeclAttr::DT_Host)
+               ? FunctionEmissionStatus::Discarded
+               : FunctionEmissionStatus::Emitted;
 
   // Otherwise, the function is known-emitted if it's in our set of
   // known-emitted functions.
-  return S.DeviceKnownEmittedFns.count(FD) > 0;
+  return (S.DeviceKnownEmittedFns.count(FD) > 0)
+             ? FunctionEmissionStatus::Emitted
+             : FunctionEmissionStatus::Unknown;
 }
 
 Sema::DeviceDiagBuilder Sema::diagIfOpenMPDeviceCode(SourceLocation Loc,
                                                      unsigned DiagID) {
   assert(LangOpts.OpenMP && LangOpts.OpenMPIsDevice &&
          "Expected OpenMP device compilation.");
-  return DeviceDiagBuilder((isOpenMPDeviceDelayedContext(*this) &&
-                            !isKnownEmitted(*this, getCurFunctionDecl()))
-                               ? DeviceDiagBuilder::K_Deferred
-                               : DeviceDiagBuilder::K_Immediate,
-                           Loc, DiagID, getCurFunctionDecl(), *this);
+  FunctionEmissionStatus FES =
+      isKnownDeviceEmitted(*this, getCurFunctionDecl());
+  DeviceDiagBuilder::Kind Kind = DeviceDiagBuilder::K_Nop;
+  switch (FES) {
+  case FunctionEmissionStatus::Emitted:
+    Kind = DeviceDiagBuilder::K_Immediate;
+    break;
+  case FunctionEmissionStatus::Unknown:
+    Kind = isOpenMPDeviceDelayedContext(*this) ? DeviceDiagBuilder::K_Deferred
+                                               : DeviceDiagBuilder::K_Immediate;
+    break;
+  case FunctionEmissionStatus::Discarded:
+    Kind = DeviceDiagBuilder::K_Nop;
+    break;
+  }
+
+  return DeviceDiagBuilder(Kind, Loc, DiagID, getCurFunctionDecl(), *this);
+}
+
+/// Do we know that we will eventually codegen the given function for the host?
+static FunctionEmissionStatus isKnownHostEmitted(Sema &S, FunctionDecl *FD) {
+  assert(S.LangOpts.OpenMP && !S.LangOpts.OpenMPIsDevice &&
+         "Expected OpenMP host compilation.");
+  // In OpenMP 4.5 all the functions are host functions.
+  if (S.LangOpts.OpenMP <= 45)
+    return FunctionEmissionStatus::Emitted;
+
+  Optional<OMPDeclareTargetDeclAttr::DevTypeTy> DevTy =
+      OMPDeclareTargetDeclAttr::getDeviceType(FD->getCanonicalDecl());
+  if (DevTy.hasValue())
+    return (*DevTy == OMPDeclareTargetDeclAttr::DT_NoHost)
+               ? FunctionEmissionStatus::Discarded
+               : FunctionEmissionStatus::Emitted;
+
+  // Otherwise, the function is known-emitted if it's in our set of
+  // known-emitted functions.
+  return (S.DeviceKnownEmittedFns.count(FD) > 0)
+             ? FunctionEmissionStatus::Emitted
+             : FunctionEmissionStatus::Unknown;
+}
+
+Sema::DeviceDiagBuilder Sema::diagIfOpenMPHostCode(SourceLocation Loc,
+                                                   unsigned DiagID) {
+  assert(LangOpts.OpenMP && !LangOpts.OpenMPIsDevice &&
+         "Expected OpenMP host compilation.");
+  FunctionEmissionStatus FES =
+      isKnownHostEmitted(*this, getCurFunctionDecl());
+  DeviceDiagBuilder::Kind Kind = DeviceDiagBuilder::K_Nop;
+  switch (FES) {
+  case FunctionEmissionStatus::Emitted:
+    Kind = DeviceDiagBuilder::K_Immediate;
+    break;
+  case FunctionEmissionStatus::Unknown:
+    Kind = DeviceDiagBuilder::K_Deferred;
+    break;
+  case FunctionEmissionStatus::Discarded:
+    Kind = DeviceDiagBuilder::K_Nop;
+    break;
+  }
+
+  return DeviceDiagBuilder(Kind, Loc, DiagID, getCurFunctionDecl(), *this);
 }
 
 void Sema::checkOpenMPDeviceFunction(SourceLocation Loc, FunctionDecl *Callee,
@@ -1589,21 +1659,75 @@
   assert(LangOpts.OpenMP && LangOpts.OpenMPIsDevice &&
          "Expected OpenMP device compilation.");
   assert(Callee && "Callee may not be null.");
+  Callee = Callee->getMostRecentDecl();
   FunctionDecl *Caller = getCurFunctionDecl();
 
+  // Host-only functions are not available on the device.
+  if (Caller &&
+      (isKnownDeviceEmitted(*this, Caller) == FunctionEmissionStatus::Emitted ||
+       (!isOpenMPDeviceDelayedContext(*this) &&
+        isKnownDeviceEmitted(*this, Caller) ==
+            FunctionEmissionStatus::Unknown)) &&
+      isKnownDeviceEmitted(*this, Callee) ==
+          FunctionEmissionStatus::Discarded) {
+    StringRef HostDevTy =
+        getOpenMPSimpleClauseTypeName(OMPC_device_type, OMPC_DEVICE_TYPE_host);
+    Diag(Loc, diag::err_omp_wrong_device_function_call) << HostDevTy << 0;
+    Diag(Callee->getAttr<OMPDeclareTargetDeclAttr>()->getLocation(),
+         diag::note_omp_marked_device_type_here)
+        << HostDevTy;
+    return;
+  }
   // If the caller is known-emitted, mark the callee as known-emitted.
   // Otherwise, mark the call in our call graph so we can traverse it later.
   if ((CheckForDelayedContext && !isOpenMPDeviceDelayedContext(*this)) ||
       (!Caller && !CheckForDelayedContext) ||
-      (Caller && isKnownEmitted(*this, Caller)))
+      (Caller &&
+       isKnownDeviceEmitted(*this, Caller) == FunctionEmissionStatus::Emitted))
     markKnownEmitted(*this, Caller, Callee, Loc,
                      [CheckForDelayedContext](Sema &S, FunctionDecl *FD) {
-                       return CheckForDelayedContext && isKnownEmitted(S, FD);
+                       return CheckForDelayedContext &&
+                              isKnownDeviceEmitted(S, FD) ==
+                                  FunctionEmissionStatus::Emitted;
                      });
   else if (Caller)
     DeviceCallGraph[Caller].insert({Callee, Loc});
 }
 
+void Sema::checkOpenMPHostFunction(SourceLocation Loc, FunctionDecl *Callee,
+                                   bool CheckCaller) {
+  assert(LangOpts.OpenMP && !LangOpts.OpenMPIsDevice &&
+         "Expected OpenMP host compilation.");
+  assert(Callee && "Callee may not be null.");
+  Callee = Callee->getMostRecentDecl();
+  FunctionDecl *Caller = getCurFunctionDecl();
+
+  // Device-only functions are not available on the host.
+  if (Caller &&
+      isKnownHostEmitted(*this, Caller) == FunctionEmissionStatus::Emitted &&
+      isKnownHostEmitted(*this, Callee) == FunctionEmissionStatus::Discarded) {
+    StringRef NoHostDevTy = getOpenMPSimpleClauseTypeName(
+        OMPC_device_type, OMPC_DEVICE_TYPE_nohost);
+    Diag(Loc, diag::err_omp_wrong_device_function_call) << NoHostDevTy << 1;
+    Diag(Callee->getAttr<OMPDeclareTargetDeclAttr>()->getLocation(),
+         diag::note_omp_marked_device_type_here)
+        << NoHostDevTy;
+    return;
+  }
+  // If the caller is known-emitted, mark the callee as known-emitted.
+  // Otherwise, mark the call in our call graph so we can traverse it later.
+  if ((!CheckCaller && !Caller) ||
+      (Caller &&
+       isKnownHostEmitted(*this, Caller) == FunctionEmissionStatus::Emitted))
+    markKnownEmitted(
+        *this, Caller, Callee, Loc, [CheckCaller](Sema &S, FunctionDecl *FD) {
+          return CheckCaller &&
+                 isKnownHostEmitted(S, FD) == FunctionEmissionStatus::Emitted;
+        });
+  else if (Caller)
+    DeviceCallGraph[Caller].insert({Callee, Loc});
+}
+
 void Sema::checkOpenMPDeviceExpr(const Expr *E) {
   assert(getLangOpts().OpenMP && getLangOpts().OpenMPIsDevice &&
          "OpenMP device compilation mode is expected.");
@@ -1970,6 +2094,54 @@
 
 void Sema::DestroyDataSharingAttributesStack() { delete DSAStack; }
 
+void Sema::finalizeOpenMPDelayedAnalysis() {
+  assert(LangOpts.OpenMP && "Expected OpenMP compilation mode.");
+  // Diagnose implicit declare target functions and their callees.
+  for (const auto &CallerCallees : DeviceCallGraph) {
+    Optional<OMPDeclareTargetDeclAttr::DevTypeTy> DevTy =
+        OMPDeclareTargetDeclAttr::getDeviceType(
+            CallerCallees.getFirst()->getMostRecentDecl());
+    // Ignore host functions during device analysis.
+    if (LangOpts.OpenMPIsDevice && DevTy &&
+        *DevTy == OMPDeclareTargetDeclAttr::DT_Host)
+      continue;
+    // Ignore nohost functions during host analysis.
+    if (!LangOpts.OpenMPIsDevice && DevTy &&
+        *DevTy == OMPDeclareTargetDeclAttr::DT_NoHost)
+      continue;
+    for (const std::pair<CanonicalDeclPtr<FunctionDecl>, SourceLocation>
+             &Callee : CallerCallees.getSecond()) {
+      const FunctionDecl *FD = Callee.first->getMostRecentDecl();
+      Optional<OMPDeclareTargetDeclAttr::DevTypeTy> DevTy =
+          OMPDeclareTargetDeclAttr::getDeviceType(FD);
+      if (LangOpts.OpenMPIsDevice && DevTy &&
+          *DevTy == OMPDeclareTargetDeclAttr::DT_Host) {
+        // Diagnose host function called during device codegen.
+        StringRef HostDevTy = getOpenMPSimpleClauseTypeName(
+            OMPC_device_type, OMPC_DEVICE_TYPE_host);
+        Diag(Callee.second, diag::err_omp_wrong_device_function_call)
+            << HostDevTy << 0;
+        Diag(FD->getAttr<OMPDeclareTargetDeclAttr>()->getLocation(),
+             diag::note_omp_marked_device_type_here)
+            << HostDevTy;
+        continue;
+      }
+      if (!LangOpts.OpenMPIsDevice && DevTy &&
+          *DevTy == OMPDeclareTargetDeclAttr::DT_NoHost) {
+        // Diagnose nohost function called during host codegen.
+        StringRef NoHostDevTy = getOpenMPSimpleClauseTypeName(
+            OMPC_device_type, OMPC_DEVICE_TYPE_nohost);
+        Diag(Callee.second, diag::err_omp_wrong_device_function_call)
+            << NoHostDevTy << 1;
+        Diag(FD->getAttr<OMPDeclareTargetDeclAttr>()->getLocation(),
+             diag::note_omp_marked_device_type_here)
+            << NoHostDevTy;
+        continue;
+      }
+    }
+  }
+}
+
 void Sema::StartOpenMPDSABlock(OpenMPDirectiveKind DKind,
                                const DeclarationNameInfo &DirName,
                                Scope *CurScope, SourceLocation Loc) {
@@ -4415,6 +4587,7 @@
       case OMPC_reverse_offload:
       case OMPC_dynamic_allocators:
       case OMPC_atomic_default_mem_order:
+      case OMPC_device_type:
         llvm_unreachable("Unexpected clause");
       }
       for (Stmt *CC : C->children()) {
@@ -9642,6 +9815,7 @@
   case OMPC_reverse_offload:
   case OMPC_dynamic_allocators:
   case OMPC_atomic_default_mem_order:
+  case OMPC_device_type:
     llvm_unreachable("Clause is not allowed.");
   }
   return Res;
@@ -10184,6 +10358,7 @@
   case OMPC_reverse_offload:
   case OMPC_dynamic_allocators:
   case OMPC_atomic_default_mem_order:
+  case OMPC_device_type:
     llvm_unreachable("Unexpected OpenMP clause.");
   }
   return CaptureRegion;
@@ -10577,6 +10752,7 @@
   case OMPC_unified_shared_memory:
   case OMPC_reverse_offload:
   case OMPC_dynamic_allocators:
+  case OMPC_device_type:
     llvm_unreachable("Clause is not allowed.");
   }
   return Res;
@@ -10755,6 +10931,7 @@
   case OMPC_reverse_offload:
   case OMPC_dynamic_allocators:
   case OMPC_atomic_default_mem_order:
+  case OMPC_device_type:
     llvm_unreachable("Clause is not allowed.");
   }
   return Res;
@@ -10964,6 +11141,7 @@
   case OMPC_use_device_ptr:
   case OMPC_is_device_ptr:
   case OMPC_atomic_default_mem_order:
+  case OMPC_device_type:
     llvm_unreachable("Clause is not allowed.");
   }
   return Res;
@@ -11170,6 +11348,7 @@
   case OMPC_reverse_offload:
   case OMPC_dynamic_allocators:
   case OMPC_atomic_default_mem_order:
+  case OMPC_device_type:
     llvm_unreachable("Clause is not allowed.");
   }
   return Res;
@@ -15333,16 +15512,15 @@
   --DeclareTargetNestingLevel;
 }
 
-void Sema::ActOnOpenMPDeclareTargetName(Scope *CurScope,
-                                        CXXScopeSpec &ScopeSpec,
-                                        const DeclarationNameInfo &Id,
-                                        OMPDeclareTargetDeclAttr::MapTypeTy MT,
-                                        NamedDeclSetType &SameDirectiveDecls) {
+NamedDecl *
+Sema::lookupOpenMPDeclareTargetName(Scope *CurScope, CXXScopeSpec &ScopeSpec,
+                                    const DeclarationNameInfo &Id,
+                                    NamedDeclSetType &SameDirectiveDecls) {
   LookupResult Lookup(*this, Id, LookupOrdinaryName);
   LookupParsedName(Lookup, CurScope, &ScopeSpec, true);
 
   if (Lookup.isAmbiguous())
-    return;
+    return nullptr;
   Lookup.suppressDiagnostics();
 
   if (!Lookup.isSingleResult()) {
@@ -15353,33 +15531,56 @@
       diagnoseTypo(Corrected, PDiag(diag::err_undeclared_var_use_suggest)
                                   << Id.getName());
       checkDeclIsAllowedInOpenMPTarget(nullptr, Corrected.getCorrectionDecl());
-      return;
+      return nullptr;
     }
 
     Diag(Id.getLoc(), diag::err_undeclared_var_use) << Id.getName();
-    return;
+    return nullptr;
   }
 
   NamedDecl *ND = Lookup.getAsSingle<NamedDecl>();
-  if (isa<VarDecl>(ND) || isa<FunctionDecl>(ND) ||
-      isa<FunctionTemplateDecl>(ND)) {
-    if (!SameDirectiveDecls.insert(cast<NamedDecl>(ND->getCanonicalDecl())))
-      Diag(Id.getLoc(), diag::err_omp_declare_target_multiple) << Id.getName();
-    llvm::Optional<OMPDeclareTargetDeclAttr::MapTypeTy> Res =
-        OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(
-            cast<ValueDecl>(ND));
-    if (!Res) {
-      auto *A = OMPDeclareTargetDeclAttr::CreateImplicit(Context, MT);
-      ND->addAttr(A);
-      if (ASTMutationListener *ML = Context.getASTMutationListener())
-        ML->DeclarationMarkedOpenMPDeclareTarget(ND, A);
-      checkDeclIsAllowedInOpenMPTarget(nullptr, ND, Id.getLoc());
-    } else if (*Res != MT) {
-      Diag(Id.getLoc(), diag::err_omp_declare_target_to_and_link)
-          << Id.getName();
-    }
-  } else {
+  if (!isa<VarDecl>(ND) && !isa<FunctionDecl>(ND) &&
+      !isa<FunctionTemplateDecl>(ND)) {
     Diag(Id.getLoc(), diag::err_omp_invalid_target_decl) << Id.getName();
+    return nullptr;
+  }
+  if (!SameDirectiveDecls.insert(cast<NamedDecl>(ND->getCanonicalDecl())))
+    Diag(Id.getLoc(), diag::err_omp_declare_target_multiple) << Id.getName();
+  return ND;
+}
+
+void Sema::ActOnOpenMPDeclareTargetName(
+    NamedDecl *ND, SourceLocation Loc, OMPDeclareTargetDeclAttr::MapTypeTy MT,
+    OMPDeclareTargetDeclAttr::DevTypeTy DT) {
+  assert((isa<VarDecl>(ND) || isa<FunctionDecl>(ND) ||
+          isa<FunctionTemplateDecl>(ND)) &&
+         "Expected variable, function or function template.");
+
+  // Diagnose marking after use as it may lead to incorrect diagnosis and
+  // codegen.
+  if (LangOpts.OpenMP >= 50 &&
+      (ND->isUsed(/*CheckUsedAttr=*/false) || ND->isReferenced()))
+    Diag(Loc, diag::warn_omp_declare_target_after_first_use);
+
+  Optional<OMPDeclareTargetDeclAttr::DevTypeTy> DevTy =
+      OMPDeclareTargetDeclAttr::getDeviceType(cast<ValueDecl>(ND));
+  if (DevTy.hasValue() && *DevTy != DT) {
+    Diag(Loc, diag::err_omp_device_type_mismatch)
+        << OMPDeclareTargetDeclAttr::ConvertDevTypeTyToStr(DT)
+        << OMPDeclareTargetDeclAttr::ConvertDevTypeTyToStr(*DevTy);
+    return;
+  }
+  Optional<OMPDeclareTargetDeclAttr::MapTypeTy> Res =
+      OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(cast<ValueDecl>(ND));
+  if (!Res) {
+    auto *A = OMPDeclareTargetDeclAttr::CreateImplicit(Context, MT, DT,
+                                                       SourceRange(Loc, Loc));
+    ND->addAttr(A);
+    if (ASTMutationListener *ML = Context.getASTMutationListener())
+      ML->DeclarationMarkedOpenMPDeclareTarget(ND, A);
+    checkDeclIsAllowedInOpenMPTarget(nullptr, ND, Loc);
+  } else if (*Res != MT) {
+    Diag(Loc, diag::err_omp_declare_target_to_and_link) << ND;
   }
 }
 
@@ -15453,8 +15654,14 @@
       return;
     }
     // Mark the function as must be emitted for the device.
-    if (LangOpts.OpenMPIsDevice && Res.hasValue() && IdLoc.isValid())
+    Optional<OMPDeclareTargetDeclAttr::DevTypeTy> DevTy =
+        OMPDeclareTargetDeclAttr::getDeviceType(FD);
+    if (LangOpts.OpenMPIsDevice && Res.hasValue() && IdLoc.isValid() &&
+        *DevTy != OMPDeclareTargetDeclAttr::DT_Host)
       checkOpenMPDeviceFunction(IdLoc, FD, /*CheckForDelayedContext=*/false);
+    if (!LangOpts.OpenMPIsDevice && Res.hasValue() && IdLoc.isValid() &&
+        *DevTy != OMPDeclareTargetDeclAttr::DT_NoHost)
+      checkOpenMPHostFunction(IdLoc, FD, /*CheckCaller=*/false);
   }
   if (auto *VD = dyn_cast<ValueDecl>(D)) {
     // Problem if any with var declared with incomplete type will be reported
@@ -15467,7 +15674,8 @@
       if (isa<VarDecl>(D) || isa<FunctionDecl>(D) ||
           isa<FunctionTemplateDecl>(D)) {
         auto *A = OMPDeclareTargetDeclAttr::CreateImplicit(
-            Context, OMPDeclareTargetDeclAttr::MT_To);
+            Context, OMPDeclareTargetDeclAttr::MT_To,
+            OMPDeclareTargetDeclAttr::DT_Any, SourceRange(IdLoc, IdLoc));
         D->addAttr(A);
         if (ASTMutationListener *ML = Context.getASTMutationListener())
           ML->DeclarationMarkedOpenMPDeclareTarget(D, A);