[OPENMP50]Initial support for use_device_addr clause.

Summary:
Added parsing/sema analysis/serialization support for use_device_addr
clauses.

Reviewers: jdoerfert

Subscribers: yaxunl, guansong, arphaman, sstefan1, llvm-commits, cfe-commits, caomhin

Tags: #clang, #llvm

Differential Revision: https://reviews.llvm.org/D80404
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 14c4c78..fa1c80f 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -136,6 +136,7 @@
   case OMPC_to:
   case OMPC_from:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_unified_address:
   case OMPC_unified_shared_memory:
@@ -227,6 +228,7 @@
   case OMPC_to:
   case OMPC_from:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_unified_address:
   case OMPC_unified_shared_memory:
@@ -1198,6 +1200,53 @@
   return new (Mem) OMPUseDevicePtrClause(Sizes);
 }
 
+OMPUseDeviceAddrClause *
+OMPUseDeviceAddrClause::Create(const ASTContext &C, const OMPVarListLocTy &Locs,
+                               ArrayRef<Expr *> Vars,
+                               ArrayRef<ValueDecl *> Declarations,
+                               MappableExprComponentListsRef ComponentLists) {
+  OMPMappableExprListSizeTy Sizes;
+  Sizes.NumVars = Vars.size();
+  Sizes.NumUniqueDeclarations = getUniqueDeclarationsTotalNumber(Declarations);
+  Sizes.NumComponentLists = ComponentLists.size();
+  Sizes.NumComponents = getComponentsTotalNumber(ComponentLists);
+
+  // We need to allocate:
+  // NumVars x Expr* - we have an original list expression for each clause
+  // list entry. Only the original expressions are stored by this function.
+  // NumUniqueDeclarations x ValueDecl* - unique base declarations associated
+  // with each component list.
+  // (NumUniqueDeclarations + NumComponentLists) x unsigned - we specify the
+  // number of lists for each unique declaration and the size of each component
+  // list.
+  // NumComponents x MappableComponent - the total of all the components in all
+  // the lists.
+  void *Mem = C.Allocate(
+      totalSizeToAlloc<Expr *, ValueDecl *, unsigned,
+                       OMPClauseMappableExprCommon::MappableComponent>(
+          Sizes.NumVars, Sizes.NumUniqueDeclarations,
+          Sizes.NumUniqueDeclarations + Sizes.NumComponentLists,
+          Sizes.NumComponents));
+
+  auto *Clause = new (Mem) OMPUseDeviceAddrClause(Locs, Sizes);
+
+  Clause->setVarRefs(Vars);
+  Clause->setClauseInfo(Declarations, ComponentLists);
+  return Clause;
+}
+
+OMPUseDeviceAddrClause *
+OMPUseDeviceAddrClause::CreateEmpty(const ASTContext &C,
+                                    const OMPMappableExprListSizeTy &Sizes) {
+  void *Mem = C.Allocate(
+      totalSizeToAlloc<Expr *, ValueDecl *, unsigned,
+                       OMPClauseMappableExprCommon::MappableComponent>(
+          Sizes.NumVars, Sizes.NumUniqueDeclarations,
+          Sizes.NumUniqueDeclarations + Sizes.NumComponentLists,
+          Sizes.NumComponents));
+  return new (Mem) OMPUseDeviceAddrClause(Sizes); // No lists are set here.
+}
+
 OMPIsDevicePtrClause *
 OMPIsDevicePtrClause::Create(const ASTContext &C, const OMPVarListLocTy &Locs,
                              ArrayRef<Expr *> Vars,
@@ -1934,6 +1983,15 @@
   }
 }
 
+void OMPClausePrinter::VisitOMPUseDeviceAddrClause(
+    OMPUseDeviceAddrClause *Node) {
+  if (!Node->varlist_empty()) { // An empty list prints nothing.
+    OS << "use_device_addr";
+    VisitOMPClauseList(Node, '(');
+    OS << ")";
+  }
+}
+
 void OMPClausePrinter::VisitOMPIsDevicePtrClause(OMPIsDevicePtrClause *Node) {
   if (!Node->varlist_empty()) {
     OS << "is_device_ptr";
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index bd2eeb6..e573c04 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -784,6 +784,10 @@
     const OMPUseDevicePtrClause *C) {
   VisitOMPClauseList(C);
 }
+void OMPClauseProfiler::VisitOMPUseDeviceAddrClause(
+    const OMPUseDeviceAddrClause *C) {
+  VisitOMPClauseList(C); // Only the shared clause-list visitor is invoked.
+}
 void OMPClauseProfiler::VisitOMPIsDevicePtrClause(
     const OMPIsDevicePtrClause *C) {
   VisitOMPClauseList(C);
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 8dddb66..a000e4d 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -163,6 +163,7 @@
   case OMPC_hint:
   case OMPC_uniform:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_unified_address:
   case OMPC_unified_shared_memory:
@@ -411,6 +412,7 @@
   case OMPC_hint:
   case OMPC_uniform:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_unified_address:
   case OMPC_unified_shared_memory:
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index d12aa65..ae4e340 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -4730,6 +4730,7 @@
   case OMPC_to:
   case OMPC_from:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_unified_address:
   case OMPC_unified_shared_memory:
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index bd40e6b..5161c7d 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2497,7 +2497,7 @@
 ///       in_reduction-clause | allocator-clause | allocate-clause |
 ///       acq_rel-clause | acquire-clause | release-clause | relaxed-clause |
 ///       depobj-clause | destroy-clause | detach-clause | inclusive-clause |
-///       exclusive-clause | uses_allocators-clause
+///       exclusive-clause | uses_allocators-clause | use_device_addr-clause
 ///
 OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
                                      OpenMPClauseKind CKind, bool FirstClause) {
@@ -2663,6 +2663,7 @@
   case OMPC_to:
   case OMPC_from:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_allocate:
   case OMPC_nontemporal:
@@ -3581,6 +3582,8 @@
 ///       'from' '(' [ mapper '(' mapper-identifier ')' ':' ] list ')'
 ///    use_device_ptr-clause:
 ///       'use_device_ptr' '(' list ')'
+///    use_device_addr-clause:
+///       'use_device_addr' '(' list ')'
 ///    is_device_ptr-clause:
 ///       'is_device_ptr' '(' list ')'
 ///    allocate-clause:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index e556969..a60a047 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -5408,6 +5408,7 @@
       case OMPC_to:
       case OMPC_from:
       case OMPC_use_device_ptr:
+      case OMPC_use_device_addr:
       case OMPC_is_device_ptr:
       case OMPC_nontemporal:
       case OMPC_order:
@@ -10165,12 +10166,18 @@
 
   assert(isa<CapturedStmt>(AStmt) && "Captured statement expected");
 
-  // OpenMP [2.10.1, Restrictions, p. 97]
-  // At least one map clause must appear on the directive.
-  if (!hasClauses(Clauses, OMPC_map, OMPC_use_device_ptr)) {
+  // OpenMP [2.12.2, target data Construct, Restrictions]
+  // At least one map, use_device_addr or use_device_ptr clause must appear on
+  // the directive.
+  if (!hasClauses(Clauses, OMPC_map, OMPC_use_device_ptr) &&
+      (LangOpts.OpenMP < 50 || !hasClauses(Clauses, OMPC_use_device_addr))) {
+    StringRef Expected;
+    if (LangOpts.OpenMP < 50)
+      Expected = "'map' or 'use_device_ptr'";
+    else
+      Expected = "'map', 'use_device_ptr', or 'use_device_addr'";
     Diag(StartLoc, diag::err_omp_no_clause_for_directive)
-        << "'map' or 'use_device_ptr'"
-        << getOpenMPDirectiveName(OMPD_target_data);
+        << Expected << getOpenMPDirectiveName(OMPD_target_data);
     return StmtError();
   }
 
@@ -11535,6 +11542,7 @@
   case OMPC_to:
   case OMPC_from:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_unified_address:
   case OMPC_unified_shared_memory:
@@ -12289,6 +12297,7 @@
   case OMPC_to:
   case OMPC_from:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_unified_address:
   case OMPC_unified_shared_memory:
@@ -12731,6 +12740,7 @@
   case OMPC_to:
   case OMPC_from:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_unified_address:
   case OMPC_unified_shared_memory:
@@ -12956,6 +12966,7 @@
   case OMPC_to:
   case OMPC_from:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_unified_address:
   case OMPC_unified_shared_memory:
@@ -13195,6 +13206,7 @@
   case OMPC_to:
   case OMPC_from:
   case OMPC_use_device_ptr:
+  case OMPC_use_device_addr:
   case OMPC_is_device_ptr:
   case OMPC_atomic_default_mem_order:
   case OMPC_device_type:
@@ -13406,6 +13418,9 @@
   case OMPC_use_device_ptr:
     Res = ActOnOpenMPUseDevicePtrClause(VarList, Locs);
     break;
+  case OMPC_use_device_addr:
+    Res = ActOnOpenMPUseDeviceAddrClause(VarList, Locs);
+    break;
   case OMPC_is_device_ptr:
     Res = ActOnOpenMPIsDevicePtrClause(VarList, Locs);
     break;
@@ -18389,6 +18404,54 @@
       MVLI.VarBaseDeclarations, MVLI.VarComponents);
 }
 
+OMPClause *Sema::ActOnOpenMPUseDeviceAddrClause(ArrayRef<Expr *> VarList,
+                                                const OMPVarListLocTy &Locs) {
+  MappableVarListInfo MVLI(VarList);
+
+  for (Expr *RefExpr : VarList) {
+    assert(RefExpr && "NULL expr in OpenMP use_device_addr clause.");
+    SourceLocation ELoc;
+    SourceRange ERange;
+    Expr *SimpleRefExpr = RefExpr;
+    auto Res = getPrivateItem(*this, SimpleRefExpr, ELoc, ERange,
+                              /*AllowArraySection=*/true);
+    if (Res.second) {
+      // Keep the expression unchanged; it will be analyzed later.
+      MVLI.ProcessedVarList.push_back(RefExpr);
+    }
+    ValueDecl *D = Res.first;
+    if (!D)
+      continue;
+    auto *VD = dyn_cast<VarDecl>(D);
+
+    // If D is not a VarDecl, build a capture to implement the privatization
+    // initialized with the current list item value.
+    DeclRefExpr *Ref = nullptr;
+    if (!VD)
+      Ref = buildCapture(*this, D, SimpleRefExpr, /*WithInit=*/true);
+    MVLI.ProcessedVarList.push_back(VD ? RefExpr->IgnoreParens() : Ref);
+
+    // We need to add a data sharing attribute for this variable to make sure it
+    // is correctly captured. A variable that shows up in a use_device_addr
+    // clause is recorded as firstprivate on the DSA stack.
+    DSAStack->addDSA(D, RefExpr->IgnoreParens(), OMPC_firstprivate, Ref);
+
+    // Create a mappable component for the list item. List items in this clause
+    // only need a single component.
+    MVLI.VarBaseDeclarations.push_back(D);
+    MVLI.VarComponents.emplace_back();
+    MVLI.VarComponents.back().push_back(
+        OMPClauseMappableExprCommon::MappableComponent(SimpleRefExpr, D));
+  }
+
+  if (MVLI.ProcessedVarList.empty())
+    return nullptr;
+
+  return OMPUseDeviceAddrClause::Create(Context, Locs, MVLI.ProcessedVarList,
+                                        MVLI.VarBaseDeclarations,
+                                        MVLI.VarComponents);
+}
+
 OMPClause *Sema::ActOnOpenMPIsDevicePtrClause(ArrayRef<Expr *> VarList,
                                               const OMPVarListLocTy &Locs) {
   MappableVarListInfo MVLI(VarList);
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 923792f..e4c7155 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2036,6 +2036,15 @@
     return getSema().ActOnOpenMPUseDevicePtrClause(VarList, Locs);
   }
 
+  /// Build a new OpenMP 'use_device_addr' clause.
+  ///
+  /// By default, performs semantic analysis to build the new OpenMP clause.
+  /// Subclasses may override this routine to provide different behavior.
+  OMPClause *RebuildOMPUseDeviceAddrClause(ArrayRef<Expr *> VarList,
+                                           const OMPVarListLocTy &Locs) {
+    return getSema().ActOnOpenMPUseDeviceAddrClause(VarList, Locs);
+  }
+
   /// Build a new OpenMP 'is_device_ptr' clause.
   ///
   /// By default, performs semantic analysis to build the new OpenMP clause.
@@ -9741,6 +9750,21 @@
 }
 
 template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPUseDeviceAddrClause(
+    OMPUseDeviceAddrClause *C) {
+  llvm::SmallVector<Expr *, 16> Vars;
+  Vars.reserve(C->varlist_size());
+  for (auto *VE : C->varlists()) {
+    ExprResult EVar = getDerived().TransformExpr(cast<Expr>(VE));
+    if (EVar.isInvalid())
+      return nullptr; // A failed list item aborts the whole clause.
+    Vars.push_back(EVar.get());
+  }
+  OMPVarListLocTy Locs(C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
+  return getDerived().RebuildOMPUseDeviceAddrClause(Vars, Locs);
+}
+
+template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPIsDevicePtrClause(OMPIsDevicePtrClause *C) {
   llvm::SmallVector<Expr *, 16> Vars;
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 16bcb18..a5a1276 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11918,6 +11918,15 @@
     C = OMPUseDevicePtrClause::CreateEmpty(Context, Sizes);
     break;
   }
+  case llvm::omp::OMPC_use_device_addr: {
+    OMPMappableExprListSizeTy Sizes;
+    Sizes.NumVars = Record.readInt();
+    Sizes.NumUniqueDeclarations = Record.readInt();
+    Sizes.NumComponentLists = Record.readInt();
+    Sizes.NumComponents = Record.readInt();
+    C = OMPUseDeviceAddrClause::CreateEmpty(Context, Sizes);
+    break;
+  }
   case llvm::omp::OMPC_is_device_ptr: {
     OMPMappableExprListSizeTy Sizes;
     Sizes.NumVars = Record.readInt();
@@ -12704,6 +12713,48 @@
   C->setComponents(Components, ListSizes);
 }
 
+void OMPClauseReader::VisitOMPUseDeviceAddrClause(OMPUseDeviceAddrClause *C) {
+  C->setLParenLoc(Record.readSourceLocation()); // Order mirrors the writer.
+  auto NumVars = C->varlist_size();
+  auto UniqueDecls = C->getUniqueDeclarationsNum();
+  auto TotalLists = C->getTotalComponentListNum();
+  auto TotalComponents = C->getTotalComponentsNum();
+
+  SmallVector<Expr *, 16> Vars;
+  Vars.reserve(NumVars);
+  for (unsigned i = 0; i != NumVars; ++i)
+    Vars.push_back(Record.readSubExpr());
+  C->setVarRefs(Vars);
+
+  SmallVector<ValueDecl *, 16> Decls;
+  Decls.reserve(UniqueDecls);
+  for (unsigned i = 0; i < UniqueDecls; ++i)
+    Decls.push_back(Record.readDeclAs<ValueDecl>());
+  C->setUniqueDecls(Decls);
+
+  SmallVector<unsigned, 16> ListsPerDecl;
+  ListsPerDecl.reserve(UniqueDecls);
+  for (unsigned i = 0; i < UniqueDecls; ++i)
+    ListsPerDecl.push_back(Record.readInt());
+  C->setDeclNumLists(ListsPerDecl);
+
+  SmallVector<unsigned, 32> ListSizes;
+  ListSizes.reserve(TotalLists);
+  for (unsigned i = 0; i < TotalLists; ++i)
+    ListSizes.push_back(Record.readInt());
+  C->setComponentListSizes(ListSizes);
+
+  SmallVector<OMPClauseMappableExprCommon::MappableComponent, 32> Components;
+  Components.reserve(TotalComponents);
+  for (unsigned i = 0; i < TotalComponents; ++i) {
+    Expr *AssociatedExpr = Record.readSubExpr(); // Expression first, then decl.
+    auto *AssociatedDecl = Record.readDeclAs<ValueDecl>();
+    Components.push_back(OMPClauseMappableExprCommon::MappableComponent(
+        AssociatedExpr, AssociatedDecl));
+  }
+  C->setComponents(Components, ListSizes);
+}
+
 void OMPClauseReader::VisitOMPIsDevicePtrClause(OMPIsDevicePtrClause *C) {
   C->setLParenLoc(Record.readSourceLocation());
   auto NumVars = C->varlist_size();
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 1e3adb5..9d81e13 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6625,6 +6625,26 @@
   }
 }
 
+void OMPClauseWriter::VisitOMPUseDeviceAddrClause(OMPUseDeviceAddrClause *C) {
+  Record.push_back(C->varlist_size()); // Sizes first, for CreateEmpty().
+  Record.push_back(C->getUniqueDeclarationsNum());
+  Record.push_back(C->getTotalComponentListNum());
+  Record.push_back(C->getTotalComponentsNum());
+  Record.AddSourceLocation(C->getLParenLoc());
+  for (auto *E : C->varlists())
+    Record.AddStmt(E);
+  for (auto *D : C->all_decls())
+    Record.AddDeclRef(D);
+  for (auto N : C->all_num_lists())
+    Record.push_back(N);
+  for (auto N : C->all_lists_sizes())
+    Record.push_back(N);
+  for (auto &M : C->all_components()) {
+    Record.AddStmt(M.getAssociatedExpression()); // Expression, then its decl.
+    Record.AddDeclRef(M.getAssociatedDeclaration());
+  }
+}
+
 void OMPClauseWriter::VisitOMPIsDevicePtrClause(OMPIsDevicePtrClause *C) {
   Record.push_back(C->varlist_size());
   Record.push_back(C->getUniqueDeclarationsNum());