[OpenMP] Add support for handling declare target to clause when unified memory is required

Summary:
This patch adds support for handling variables that appear in a `declare target to` clause when unified shared memory is required.

In this case the variables are handled the same way as `declare target link` variables: a pointer is created on the host and then mapped to the device. The runtime then copies the address of the host variable into the device pointer.

Reviewers: ABataev, AlexEichenberger, caomhin

Reviewed By: ABataev

Subscribers: guansong, jdoerfert, cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D63108

llvm-svn: 363959
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 1a1e0b0..602c9a2 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -2552,16 +2552,18 @@
   return CGM.CreateRuntimeFunction(FnTy, Name);
 }
 
-Address CGOpenMPRuntime::getAddrOfDeclareTargetLink(const VarDecl *VD) {
+Address CGOpenMPRuntime::getAddrOfDeclareTargetVar(const VarDecl *VD) {
   if (CGM.getLangOpts().OpenMPSimd)
     return Address::invalid();
   llvm::Optional<OMPDeclareTargetDeclAttr::MapTypeTy> Res =
       OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(VD);
-  if (Res && *Res == OMPDeclareTargetDeclAttr::MT_Link) {
+  if (Res && (*Res == OMPDeclareTargetDeclAttr::MT_Link ||
+              (*Res == OMPDeclareTargetDeclAttr::MT_To &&
+               HasRequiresUnifiedSharedMemory))) {
     SmallString<64> PtrName;
     {
       llvm::raw_svector_ostream OS(PtrName);
-      OS << CGM.getMangledName(GlobalDecl(VD)) << "_decl_tgt_link_ptr";
+      OS << CGM.getMangledName(GlobalDecl(VD)) << "_decl_tgt_ref_ptr";
     }
     llvm::Value *Ptr = CGM.getModule().getNamedValue(PtrName);
     if (!Ptr) {
@@ -2778,7 +2780,9 @@
                                                      bool PerformInit) {
   Optional<OMPDeclareTargetDeclAttr::MapTypeTy> Res =
       OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(VD);
-  if (!Res || *Res == OMPDeclareTargetDeclAttr::MT_Link)
+  if (!Res || *Res == OMPDeclareTargetDeclAttr::MT_Link ||
+      (*Res == OMPDeclareTargetDeclAttr::MT_To &&
+       HasRequiresUnifiedSharedMemory))
     return CGM.getLangOpts().OpenMPIsDevice;
   VD = VD->getDefinition(CGM.getContext());
   if (VD && !DeclareTargetWithDefinition.insert(CGM.getMangledName(VD)).second)
@@ -4194,6 +4198,9 @@
               CE->getFlags());
       switch (Flags) {
       case OffloadEntriesInfoManagerTy::OMPTargetGlobalVarEntryTo: {
+        if (CGM.getLangOpts().OpenMPIsDevice &&
+            CGM.getOpenMPRuntime().hasRequiresUnifiedSharedMemory())
+          continue;
         if (!CE->getAddress()) {
           unsigned DiagID = CGM.getDiags().getCustomDiagID(
               DiagnosticsEngine::Error,
@@ -7452,7 +7459,10 @@
 
     // Track if the map information being generated is the first for a capture.
     bool IsCaptureFirstInfo = IsFirstComponentList;
-    bool IsLink = false; // Is this variable a "declare target link"?
+    // When the variable is on a declare target link or in a to clause with
+    // unified memory, a reference is needed to hold the host/device address
+    // of the variable.
+    bool RequiresReference = false;
 
     // Scan the components from the base to the complete expression.
     auto CI = Components.rbegin();
@@ -7482,11 +7492,14 @@
       if (const auto *VD =
               dyn_cast_or_null<VarDecl>(I->getAssociatedDeclaration())) {
         if (llvm::Optional<OMPDeclareTargetDeclAttr::MapTypeTy> Res =
-                OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(VD))
-          if (*Res == OMPDeclareTargetDeclAttr::MT_Link) {
-            IsLink = true;
-            BP = CGF.CGM.getOpenMPRuntime().getAddrOfDeclareTargetLink(VD);
+                OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(VD)) {
+          if ((*Res == OMPDeclareTargetDeclAttr::MT_Link) ||
+              (*Res == OMPDeclareTargetDeclAttr::MT_To &&
+               CGF.CGM.getOpenMPRuntime().hasRequiresUnifiedSharedMemory())) {
+            RequiresReference = true;
+            BP = CGF.CGM.getOpenMPRuntime().getAddrOfDeclareTargetVar(VD);
           }
+        }
       }
 
       // If the variable is a pointer and is being dereferenced (i.e. is not
@@ -7652,7 +7665,8 @@
           // (there is a set of entries for each capture).
           OpenMPOffloadMappingFlags Flags = getMapTypeBits(
               MapType, MapModifiers, IsImplicit,
-              !IsExpressionFirstInfo || IsLink, IsCaptureFirstInfo && !IsLink);
+              !IsExpressionFirstInfo || RequiresReference,
+              IsCaptureFirstInfo && !RequiresReference);
 
           if (!IsExpressionFirstInfo) {
             // If we have a PTR_AND_OBJ pair where the OBJ is a pointer as well,
@@ -9124,7 +9138,9 @@
   llvm::Optional<OMPDeclareTargetDeclAttr::MapTypeTy> Res =
       OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(
           cast<VarDecl>(GD.getDecl()));
-  if (!Res || *Res == OMPDeclareTargetDeclAttr::MT_Link) {
+  if (!Res || *Res == OMPDeclareTargetDeclAttr::MT_Link ||
+      (*Res == OMPDeclareTargetDeclAttr::MT_To &&
+       HasRequiresUnifiedSharedMemory)) {
     DeferredGlobalVariables.insert(cast<VarDecl>(GD.getDecl()));
     return true;
   }
@@ -9183,8 +9199,9 @@
   StringRef VarName;
   CharUnits VarSize;
   llvm::GlobalValue::LinkageTypes Linkage;
-  switch (*Res) {
-  case OMPDeclareTargetDeclAttr::MT_To:
+
+  if (*Res == OMPDeclareTargetDeclAttr::MT_To &&
+      !HasRequiresUnifiedSharedMemory) {
     Flags = OffloadEntriesInfoManagerTy::OMPTargetGlobalVarEntryTo;
     VarName = CGM.getMangledName(VD);
     if (VD->hasDefinition(CGM.getContext()) != VarDecl::DeclarationOnly) {
@@ -9207,20 +9224,27 @@
         CGM.addCompilerUsedGlobal(GVAddrRef);
       }
     }
-    break;
-  case OMPDeclareTargetDeclAttr::MT_Link:
-    Flags = OffloadEntriesInfoManagerTy::OMPTargetGlobalVarEntryLink;
+  } else {
+    assert(((*Res == OMPDeclareTargetDeclAttr::MT_Link) ||
+            (*Res == OMPDeclareTargetDeclAttr::MT_To &&
+             HasRequiresUnifiedSharedMemory)) &&
+           "Declare target attribute must link or to with unified memory.");
+    if (*Res == OMPDeclareTargetDeclAttr::MT_Link)
+      Flags = OffloadEntriesInfoManagerTy::OMPTargetGlobalVarEntryLink;
+    else
+      Flags = OffloadEntriesInfoManagerTy::OMPTargetGlobalVarEntryTo;
+
     if (CGM.getLangOpts().OpenMPIsDevice) {
       VarName = Addr->getName();
       Addr = nullptr;
     } else {
-      VarName = getAddrOfDeclareTargetLink(VD).getName();
-      Addr = cast<llvm::Constant>(getAddrOfDeclareTargetLink(VD).getPointer());
+      VarName = getAddrOfDeclareTargetVar(VD).getName();
+      Addr = cast<llvm::Constant>(getAddrOfDeclareTargetVar(VD).getPointer());
     }
     VarSize = CGM.getPointerSize();
     Linkage = llvm::GlobalValue::WeakAnyLinkage;
-    break;
   }
+
   OffloadEntriesInfoManager.registerDeviceGlobalVarEntryInfo(
       VarName, Addr, VarSize, Flags, Linkage);
 }
@@ -9239,12 +9263,15 @@
         OMPDeclareTargetDeclAttr::isDeclareTargetDeclaration(VD);
     if (!Res)
       continue;
-    if (*Res == OMPDeclareTargetDeclAttr::MT_To) {
+    if (*Res == OMPDeclareTargetDeclAttr::MT_To &&
+        !HasRequiresUnifiedSharedMemory) {
       CGM.EmitGlobal(VD);
     } else {
-      assert(*Res == OMPDeclareTargetDeclAttr::MT_Link &&
-             "Expected to or link clauses.");
-      (void)CGM.getOpenMPRuntime().getAddrOfDeclareTargetLink(VD);
+      assert((*Res == OMPDeclareTargetDeclAttr::MT_Link ||
+              (*Res == OMPDeclareTargetDeclAttr::MT_To &&
+               HasRequiresUnifiedSharedMemory)) &&
+             "Expected link clause or to clause with unified memory.");
+      (void)CGM.getOpenMPRuntime().getAddrOfDeclareTargetVar(VD);
     }
   }
 }