[OpenMP 5.0] Parsing/sema support for to clause with mapper modifier.

This patch implements parsing and sema support for the OpenMP 'to' clause
with an optional user-defined mapper attached. User-defined mappers are a
new feature in OpenMP 5.0. A to/from clause can have an explicit or
implicit associated mapper, which instructs the compiler to generate and
use customized mapping functions. Only the 'to' clause is wired up here;
the 'from' clause still passes an empty mapper. An example is shown below:

    struct S { int len; int *d; };
    #pragma omp declare mapper(id: struct S s) map(s, s.d[0:s.len])
    struct S ss;
    #pragma omp target update to(mapper(id): ss) // use the mapper with name 'id' to map ss to device

Contributed-by: <lildmh@gmail.com>

Differential Revision: https://reviews.llvm.org/D58523

llvm-svn: 354698
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 04be5b5..5a6891d 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -9802,7 +9802,8 @@
                                DepLinMapLoc, ColonLoc, VarList, Locs);
     break;
   case OMPC_to:
-    Res = ActOnOpenMPToClause(VarList, Locs);
+    Res = ActOnOpenMPToClause(VarList, ReductionOrMapperIdScopeSpec,
+                              ReductionOrMapperId, Locs);
     break;
   case OMPC_from:
     Res = ActOnOpenMPFromClause(VarList, Locs);
@@ -13140,17 +13141,23 @@
 static void checkMappableExpressionList(
     Sema &SemaRef, DSAStackTy *DSAS, OpenMPClauseKind CKind,
     MappableVarListInfo &MVLI, SourceLocation StartLoc,
+    CXXScopeSpec &MapperIdScopeSpec, DeclarationNameInfo MapperId,
+    ArrayRef<Expr *> UnresolvedMappers,
     OpenMPMapClauseKind MapType = OMPC_MAP_unknown,
-    bool IsMapTypeImplicit = false, CXXScopeSpec *MapperIdScopeSpec = nullptr,
-    const DeclarationNameInfo *MapperId = nullptr,
-    ArrayRef<Expr *> UnresolvedMappers = llvm::None) {
+    bool IsMapTypeImplicit = false) {
   // We only expect mappable expressions in 'to', 'from', and 'map' clauses.
   assert((CKind == OMPC_map || CKind == OMPC_to || CKind == OMPC_from) &&
          "Unexpected clause kind with mappable expressions!");
-  assert(
-      ((CKind == OMPC_map && MapperIdScopeSpec && MapperId) ||
-       (CKind != OMPC_map && !MapperIdScopeSpec && !MapperId)) &&
-      "Map clauses and only map clauses have user-defined mapper identifiers.");
+
+  // If the identifier of user-defined mapper is not specified, it is "default".
+  // We do not change the actual name in this clause to distinguish whether a
+  // mapper is specified explicitly, i.e., it is not explicitly specified when
+  // MapperId.getName() is empty.
+  if (!MapperId.getName() || MapperId.getName().isEmpty()) {
+    auto &DeclNames = SemaRef.getASTContext().DeclarationNames;
+    MapperId.setName(DeclNames.getIdentifier(
+        &SemaRef.getASTContext().Idents.get("default")));
+  }
 
   // Iterators to find the current unresolved mapper expression.
   auto UMIt = UnresolvedMappers.begin(), UMEnd = UnresolvedMappers.end();
@@ -13183,16 +13190,14 @@
     if (VE->isValueDependent() || VE->isTypeDependent() ||
         VE->isInstantiationDependent() ||
         VE->containsUnexpandedParameterPack()) {
-      if (CKind == OMPC_map) {
+      if (CKind != OMPC_from) {
         // Try to find the associated user-defined mapper.
         ExprResult ER = buildUserDefinedMapperRef(
-            SemaRef, DSAS->getCurScope(), *MapperIdScopeSpec, *MapperId,
+            SemaRef, DSAS->getCurScope(), MapperIdScopeSpec, MapperId,
             VE->getType().getCanonicalType(), UnresolvedMapper);
         if (ER.isInvalid())
           continue;
         MVLI.UDMapperList.push_back(ER.get());
-      } else {
-        MVLI.UDMapperList.push_back(nullptr);
       }
       // We can only analyze this information once the missing information is
       // resolved.
@@ -13225,16 +13230,14 @@
     if (const auto *TE = dyn_cast<CXXThisExpr>(BE)) {
       // Add store "this" pointer to class in DSAStackTy for future checking
       DSAS->addMappedClassesQualTypes(TE->getType());
-      if (CKind == OMPC_map) {
+      if (CKind != OMPC_from) {
         // Try to find the associated user-defined mapper.
         ExprResult ER = buildUserDefinedMapperRef(
-            SemaRef, DSAS->getCurScope(), *MapperIdScopeSpec, *MapperId,
+            SemaRef, DSAS->getCurScope(), MapperIdScopeSpec, MapperId,
             VE->getType().getCanonicalType(), UnresolvedMapper);
         if (ER.isInvalid())
           continue;
         MVLI.UDMapperList.push_back(ER.get());
-      } else {
-        MVLI.UDMapperList.push_back(nullptr);
       }
       // Skip restriction checking for variable or field declarations
       MVLI.ProcessedVarList.push_back(RE);
@@ -13352,16 +13355,16 @@
           continue;
         }
       }
+    }
 
-      // Try to find the associated user-defined mapper.
+    // Try to find the associated user-defined mapper.
+    if (CKind != OMPC_from) {
       ExprResult ER = buildUserDefinedMapperRef(
-          SemaRef, DSAS->getCurScope(), *MapperIdScopeSpec, *MapperId,
+          SemaRef, DSAS->getCurScope(), MapperIdScopeSpec, MapperId,
           Type.getCanonicalType(), UnresolvedMapper);
       if (ER.isInvalid())
         continue;
       MVLI.UDMapperList.push_back(ER.get());
-    } else {
-      MVLI.UDMapperList.push_back(nullptr);
     }
 
     // Save the current expression.
@@ -13410,17 +13413,10 @@
     ++Count;
   }
 
-  // If the identifier of user-defined mapper is not specified, it is "default".
-  if (!MapperId.getName() || MapperId.getName().isEmpty()) {
-    auto &DeclNames = getASTContext().DeclarationNames;
-    MapperId.setName(
-        DeclNames.getIdentifier(&getASTContext().Idents.get("default")));
-  }
-
   MappableVarListInfo MVLI(VarList);
   checkMappableExpressionList(*this, DSAStack, OMPC_map, MVLI, Locs.StartLoc,
-                              MapType, IsMapTypeImplicit, &MapperIdScopeSpec,
-                              &MapperId, UnresolvedMappers);
+                              MapperIdScopeSpec, MapperId, UnresolvedMappers,
+                              MapType, IsMapTypeImplicit);
 
   // We need to produce a map clause even if we don't have variables so that
   // other diagnostics related with non-existing map clauses are accurate.
@@ -14169,20 +14165,30 @@
 }
 
 OMPClause *Sema::ActOnOpenMPToClause(ArrayRef<Expr *> VarList,
-                                     const OMPVarListLocTy &Locs) {
+                                     CXXScopeSpec &MapperIdScopeSpec,
+                                     DeclarationNameInfo &MapperId,
+                                     const OMPVarListLocTy &Locs,
+                                     ArrayRef<Expr *> UnresolvedMappers) {
   MappableVarListInfo MVLI(VarList);
-  checkMappableExpressionList(*this, DSAStack, OMPC_to, MVLI, Locs.StartLoc);
+  checkMappableExpressionList(*this, DSAStack, OMPC_to, MVLI, Locs.StartLoc,
+                              MapperIdScopeSpec, MapperId, UnresolvedMappers);
   if (MVLI.ProcessedVarList.empty())
     return nullptr;
 
-  return OMPToClause::Create(Context, Locs, MVLI.ProcessedVarList,
-                             MVLI.VarBaseDeclarations, MVLI.VarComponents);
+  return OMPToClause::Create(
+      Context, Locs, MVLI.ProcessedVarList, MVLI.VarBaseDeclarations,
+      MVLI.VarComponents, MVLI.UDMapperList,
+      MapperIdScopeSpec.getWithLocInContext(Context), MapperId);
 }
 
 OMPClause *Sema::ActOnOpenMPFromClause(ArrayRef<Expr *> VarList,
                                        const OMPVarListLocTy &Locs) {
   MappableVarListInfo MVLI(VarList);
-  checkMappableExpressionList(*this, DSAStack, OMPC_from, MVLI, Locs.StartLoc);
+  CXXScopeSpec MapperIdScopeSpec;
+  DeclarationNameInfo MapperId;
+  ArrayRef<Expr *> UnresolvedMappers;
+  checkMappableExpressionList(*this, DSAStack, OMPC_from, MVLI, Locs.StartLoc,
+                              MapperIdScopeSpec, MapperId, UnresolvedMappers);
   if (MVLI.ProcessedVarList.empty())
     return nullptr;
 
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 870a960..b8f922a 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -1902,8 +1902,12 @@
   /// By default, performs semantic analysis to build the new statement.
   /// Subclasses may override this routine to provide different behavior.
   OMPClause *RebuildOMPToClause(ArrayRef<Expr *> VarList,
-                                const OMPVarListLocTy &Locs) {
-    return getSema().ActOnOpenMPToClause(VarList, Locs);
+                                CXXScopeSpec &MapperIdScopeSpec,
+                                DeclarationNameInfo &MapperId,
+                                const OMPVarListLocTy &Locs,
+                                ArrayRef<Expr *> UnresolvedMappers) {
+    return getSema().ActOnOpenMPToClause(VarList, MapperIdScopeSpec, MapperId,
+                                         Locs, UnresolvedMappers);
   }
 
   /// Build a new OpenMP 'from' clause.
@@ -8811,34 +8815,37 @@
                                              C->getLParenLoc(), C->getEndLoc());
 }
 
-template <typename Derived>
-OMPClause *TreeTransform<Derived>::TransformOMPMapClause(OMPMapClause *C) {
-  llvm::SmallVector<Expr *, 16> Vars;
+template <typename Derived, class T>
+bool transformOMPMappableExprListClause(
+    TreeTransform<Derived> &TT, OMPMappableExprListClause<T> *C,
+    llvm::SmallVectorImpl<Expr *> &Vars, CXXScopeSpec &MapperIdScopeSpec,
+    DeclarationNameInfo &MapperIdInfo,
+    llvm::SmallVectorImpl<Expr *> &UnresolvedMappers) {
+  // Transform expressions in the list.
   Vars.reserve(C->varlist_size());
   for (auto *VE : C->varlists()) {
-    ExprResult EVar = getDerived().TransformExpr(cast<Expr>(VE));
+    ExprResult EVar = TT.getDerived().TransformExpr(cast<Expr>(VE));
     if (EVar.isInvalid())
-      return nullptr;
+      return true;
     Vars.push_back(EVar.get());
   }
+  // Transform mapper scope specifier and identifier.
   NestedNameSpecifierLoc QualifierLoc;
   if (C->getMapperQualifierLoc()) {
-    QualifierLoc = getDerived().TransformNestedNameSpecifierLoc(
+    QualifierLoc = TT.getDerived().TransformNestedNameSpecifierLoc(
         C->getMapperQualifierLoc());
     if (!QualifierLoc)
-      return nullptr;
+      return true;
   }
-  CXXScopeSpec MapperIdScopeSpec;
   MapperIdScopeSpec.Adopt(QualifierLoc);
-  DeclarationNameInfo MapperIdInfo = C->getMapperIdInfo();
+  MapperIdInfo = C->getMapperIdInfo();
   if (MapperIdInfo.getName()) {
-    MapperIdInfo = getDerived().TransformDeclarationNameInfo(MapperIdInfo);
+    MapperIdInfo = TT.getDerived().TransformDeclarationNameInfo(MapperIdInfo);
     if (!MapperIdInfo.getName())
-      return nullptr;
+      return true;
   }
   // Build a list of all candidate OMPDeclareMapperDecls, which is provided by
   // the previous user-defined mapper lookup in dependent environment.
-  llvm::SmallVector<Expr *, 16> UnresolvedMappers;
   for (auto *E : C->mapperlists()) {
     // Transform all the decls.
     if (E) {
@@ -8846,18 +8853,31 @@
       UnresolvedSet<8> Decls;
       for (auto *D : ULE->decls()) {
         NamedDecl *InstD =
-            cast<NamedDecl>(getDerived().TransformDecl(E->getExprLoc(), D));
+            cast<NamedDecl>(TT.getDerived().TransformDecl(E->getExprLoc(), D));
         Decls.addDecl(InstD, InstD->getAccess());
       }
       UnresolvedMappers.push_back(UnresolvedLookupExpr::Create(
-          SemaRef.Context, /*NamingClass=*/nullptr,
-          MapperIdScopeSpec.getWithLocInContext(SemaRef.Context), MapperIdInfo,
-          /*ADL=*/false, ULE->isOverloaded(), Decls.begin(), Decls.end()));
+          TT.getSema().Context, /*NamingClass=*/nullptr,
+          MapperIdScopeSpec.getWithLocInContext(TT.getSema().Context),
+          MapperIdInfo, /*ADL=*/false, ULE->isOverloaded(), Decls.begin(),
+          Decls.end()));
     } else {
       UnresolvedMappers.push_back(nullptr);
     }
   }
+  return false;
+}
+
+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPMapClause(OMPMapClause *C) {
   OMPVarListLocTy Locs(C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
+  llvm::SmallVector<Expr *, 16> Vars;
+  CXXScopeSpec MapperIdScopeSpec;
+  DeclarationNameInfo MapperIdInfo;
+  llvm::SmallVector<Expr *, 16> UnresolvedMappers;
+  if (transformOMPMappableExprListClause<Derived, OMPMapClause>(
+          *this, C, Vars, MapperIdScopeSpec, MapperIdInfo, UnresolvedMappers))
+    return nullptr;
   return getDerived().RebuildOMPMapClause(
       C->getMapTypeModifiers(), C->getMapTypeModifiersLoc(), MapperIdScopeSpec,
       MapperIdInfo, C->getMapType(), C->isImplicitMapType(), C->getMapLoc(),
@@ -8942,16 +8962,16 @@
 
 template <typename Derived>
 OMPClause *TreeTransform<Derived>::TransformOMPToClause(OMPToClause *C) {
-  llvm::SmallVector<Expr *, 16> Vars;
-  Vars.reserve(C->varlist_size());
-  for (auto *VE : C->varlists()) {
-    ExprResult EVar = getDerived().TransformExpr(cast<Expr>(VE));
-    if (EVar.isInvalid())
-      return 0;
-    Vars.push_back(EVar.get());
-  }
   OMPVarListLocTy Locs(C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
-  return getDerived().RebuildOMPToClause(Vars, Locs);
+  llvm::SmallVector<Expr *, 16> Vars;
+  CXXScopeSpec MapperIdScopeSpec;
+  DeclarationNameInfo MapperIdInfo;
+  llvm::SmallVector<Expr *, 16> UnresolvedMappers;
+  if (transformOMPMappableExprListClause<Derived, OMPToClause>(
+          *this, C, Vars, MapperIdScopeSpec, MapperIdInfo, UnresolvedMappers))
+    return nullptr;
+  return getDerived().RebuildOMPToClause(Vars, MapperIdScopeSpec, MapperIdInfo,
+                                         Locs, UnresolvedMappers);
 }
 
 template <typename Derived>