[OPENMP50]Add initial support for 'affinity' clause.

Summary:
Adds parsing, semantic analysis, and AST serialization support for the
'affinity' clause on task directives.

Reviewers: jdoerfert

Subscribers: yaxunl, guansong, arphaman, llvm-commits, cfe-commits, caomhin

Tags: #clang, #llvm

Differential Revision: https://reviews.llvm.org/D80148
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index d4d398f..14c4c78 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -151,6 +151,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     break;
   }
 
@@ -241,6 +242,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     break;
   }
 
@@ -1368,6 +1370,25 @@
   return new (Mem) OMPUsesAllocatorsClause(N);
 }
 
+OMPAffinityClause *
+OMPAffinityClause::Create(const ASTContext &C, SourceLocation StartLoc,
+                          SourceLocation LParenLoc, SourceLocation ColonLoc,
+                          SourceLocation EndLoc, Expr *Modifier,
+                          ArrayRef<Expr *> Locators) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(Locators.size() + 1));
+  auto *Clause = new (Mem) // Constructed in the storage allocated above.
+      OMPAffinityClause(StartLoc, LParenLoc, ColonLoc, EndLoc, Locators.size());
+  Clause->setModifier(Modifier); // Null when the clause has no modifier.
+  Clause->setVarRefs(Locators);
+  return Clause;
+}
+
+OMPAffinityClause *OMPAffinityClause::CreateEmpty(const ASTContext &C,
+                                                  unsigned N) { // N locators.
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N + 1)); // As in Create().
+  return new (Mem) OMPAffinityClause(N);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenMP clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -1969,6 +1990,21 @@
   OS << ")";
 }
 
+void OMPClausePrinter::VisitOMPAffinityClause(OMPAffinityClause *Node) {
+  if (Node->varlist_empty())
+    return; // Nothing is printed for a clause without locators.
+  OS << "affinity";
+  char StartSym = '('; // Opening symbol handed to VisitOMPClauseList.
+  if (Expr *Modifier = Node->getModifier()) {
+    OS << "("; // With a modifier, the '(' is emitted here instead.
+    Modifier->printPretty(OS, nullptr, Policy);
+    OS << " :";
+    StartSym = ' '; // '(' is already out; hand the list printer a space.
+  }
+  VisitOMPClauseList(Node, StartSym);
+  OS << ")";
+}
+
 void OMPTraitInfo::getAsVariantMatchInfo(ASTContext &ASTCtx,
                                          VariantMatchInfo &VMI) const {
   for (const OMPTraitSet &Set : Sets) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 501c07b..bd2eeb6 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -809,6 +809,12 @@
       Profiler->VisitStmt(D.AllocatorTraits);
   }
 }
+void OMPClauseProfiler::VisitOMPAffinityClause(const OMPAffinityClause *C) {
+  if (const Expr *Modifier = C->getModifier()) // Modifier may be absent.
+    Profiler->VisitStmt(Modifier);
+  for (const Expr *E : C->varlists()) // Then each locator in turn.
+    Profiler->VisitStmt(E);
+}
 void OMPClauseProfiler::VisitOMPOrderClause(const OMPOrderClause *C) {}
 } // namespace
 
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 841d76b..8dddb66 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -175,6 +175,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     break;
   }
   llvm_unreachable("Invalid OpenMP simple clause kind");
@@ -422,6 +423,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     break;
   }
   llvm_unreachable("Invalid OpenMP simple clause kind");
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index ae30944..f91098d 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -4734,6 +4734,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     llvm_unreachable("Clause is not allowed in 'omp atomic'.");
   }
 }
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 19acfe1..bd40e6b 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2668,6 +2668,7 @@
   case OMPC_nontemporal:
   case OMPC_inclusive:
   case OMPC_exclusive:
+  case OMPC_affinity:
     Clause = ParseOpenMPVarListClause(DKind, CKind, WrongDirective);
     break;
   case OMPC_uses_allocators:
@@ -3275,7 +3276,7 @@
                          getOpenMPClauseName(Kind).data()))
     return true;
 
-  bool DependWithIterator = false;
+  bool HasIterator = false;
   bool NeedRParenForLinear = false;
   BalancedDelimiterTracker LinearT(*this, tok::l_paren,
                                   tok::annot_pragma_openmp_end);
@@ -3321,7 +3322,7 @@
         // iterators-definition ]
         // where iterator-specifier is [ iterator-type ] identifier =
         // range-specification
-        DependWithIterator = true;
+        HasIterator = true;
         EnterScope(Scope::OpenMPDirectiveScope | Scope::DeclScope);
         ExprResult IteratorRes = ParseOpenMPIteratorsExpr();
         Data.DepModOrTailExpr = IteratorRes.get();
@@ -3440,12 +3441,24 @@
           ConsumeToken();
       }
     }
-  } else if (Kind == OMPC_allocate) {
+  } else if (Kind == OMPC_allocate ||
+             (Kind == OMPC_affinity && Tok.is(tok::identifier) &&
+              PP.getSpelling(Tok) == "iterator")) {
     // Handle optional allocator expression followed by colon delimiter.
     ColonProtectionRAIIObject ColonRAII(*this);
     TentativeParsingAction TPA(*this);
-    ExprResult Tail =
-        Actions.CorrectDelayedTyposInExpr(ParseAssignmentExpression());
+    // OpenMP 5.0, 2.10.1, task Construct.
+    // where aff-modifier is one of the following:
+    // iterator(iterators-definition)
+    ExprResult Tail;
+    if (Kind == OMPC_allocate) {
+      Tail = ParseAssignmentExpression();
+    } else {
+      HasIterator = true;
+      EnterScope(Scope::OpenMPDirectiveScope | Scope::DeclScope);
+      Tail = ParseOpenMPIteratorsExpr();
+    }
+    Tail = Actions.CorrectDelayedTyposInExpr(Tail);
     Tail = Actions.ActOnFinishFullExpr(Tail.get(), T.getOpenLocation(),
                                        /*DiscardedValue=*/false);
     if (Tail.isUsable()) {
@@ -3454,8 +3467,7 @@
         Data.ColonLoc = ConsumeToken();
         TPA.Commit();
       } else {
-        // colon not found, no allocator specified, parse only list of
-        // variables.
+        // Colon not found, parse only list of variables.
         TPA.Revert();
       }
     } else {
@@ -3524,7 +3536,7 @@
   if (!T.consumeClose())
     Data.RLoc = T.getCloseLocation();
   // Exit from scope when the iterator is used in depend clause.
-  if (DependWithIterator)
+  if (HasIterator)
     ExitScope();
   return (Kind != OMPC_depend && Kind != OMPC_map && Vars.empty()) ||
          (MustHaveTail && !Data.DepModOrTailExpr) || InvalidReductionId ||
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index e03b926..b62fb26 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -18,6 +18,7 @@
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/DeclOpenMP.h"
+#include "clang/AST/OpenMPClause.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/StmtVisitor.h"
@@ -5414,6 +5415,7 @@
       case OMPC_inclusive:
       case OMPC_exclusive:
       case OMPC_uses_allocators:
+      case OMPC_affinity:
         continue;
       case OMPC_allocator:
       case OMPC_flush:
@@ -11547,6 +11549,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     llvm_unreachable("Clause is not allowed.");
   }
   return Res;
@@ -12301,6 +12304,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     llvm_unreachable("Unexpected OpenMP clause.");
   }
   return CaptureRegion;
@@ -12740,6 +12744,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     llvm_unreachable("Clause is not allowed.");
   }
   return Res;
@@ -12966,6 +12971,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     llvm_unreachable("Clause is not allowed.");
   }
   return Res;
@@ -13199,6 +13205,7 @@
   case OMPC_inclusive:
   case OMPC_exclusive:
   case OMPC_uses_allocators:
+  case OMPC_affinity:
     llvm_unreachable("Clause is not allowed.");
   }
   return Res;
@@ -13415,6 +13422,10 @@
   case OMPC_exclusive:
     Res = ActOnOpenMPExclusiveClause(VarList, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_affinity:
+    Res = ActOnOpenMPAffinityClause(StartLoc, LParenLoc, ColonLoc, EndLoc,
+                                    DepModOrTailExpr, VarList);
+    break;
   case OMPC_if:
   case OMPC_depobj:
   case OMPC_final:
@@ -18785,3 +18796,42 @@
   return OMPUsesAllocatorsClause::Create(Context, StartLoc, LParenLoc, EndLoc,
                                          NewData);
 }
+
+OMPClause *Sema::ActOnOpenMPAffinityClause(
+    SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation ColonLoc,
+    SourceLocation EndLoc, Expr *Modifier, ArrayRef<Expr *> Locators) {
+  SmallVector<Expr *, 8> Vars; // Locators that pass the checks below.
+  for (Expr *RefExpr : Locators) {
+    assert(RefExpr && "NULL expr in OpenMP affinity clause.");
+    if (isa<DependentScopeDeclRefExpr>(RefExpr) || RefExpr->isTypeDependent()) {
+      // Dependent locator: keep it unchecked; it will be analyzed later.
+      Vars.push_back(RefExpr);
+      continue;
+    }
+
+    SourceLocation ELoc = RefExpr->getExprLoc();
+    Expr *SimpleExpr = RefExpr->IgnoreParenImpCasts();
+
+    if (!SimpleExpr->isLValue()) { // Locators must be lvalues.
+      Diag(ELoc, diag::err_omp_expected_addressable_lvalue_or_array_item)
+          << 1 << 0 << RefExpr->getSourceRange();
+      continue; // Diagnosed; this locator is left out of the clause.
+    }
+
+    ExprResult Res;
+    { // Build '&SimpleExpr' inside a tentative-analysis scope.
+      Sema::TentativeAnalysisScope Trap(*this);
+      Res = CreateBuiltinUnaryOp(ELoc, UO_AddrOf, SimpleExpr);
+    }
+    if (!Res.isUsable() && !isa<OMPArraySectionExpr>(SimpleExpr) &&
+        !isa<OMPArrayShapingExpr>(SimpleExpr)) {
+      Diag(ELoc, diag::err_omp_expected_addressable_lvalue_or_array_item)
+          << 1 << 0 << RefExpr->getSourceRange();
+      continue; // Not addressable and not an array section/shaping expr.
+    }
+    Vars.push_back(SimpleExpr);
+  }
+
+  return OMPAffinityClause::Create(Context, StartLoc, LParenLoc, ColonLoc,
+                                   EndLoc, Modifier, Vars);
+}
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index f8b84c2..923792f 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2107,6 +2107,19 @@
                                                     Data);
   }
 
+  /// Build a new OpenMP 'affinity' clause (optional modifier plus locators).
+  ///
+  /// By default, performs semantic analysis to build the new OpenMP clause.
+  /// Subclasses may override this routine to provide different behavior.
+  OMPClause *RebuildOMPAffinityClause(SourceLocation StartLoc,
+                                      SourceLocation LParenLoc,
+                                      SourceLocation ColonLoc,
+                                      SourceLocation EndLoc, Expr *Modifier,
+                                      ArrayRef<Expr *> Locators) {
+    return getSema().ActOnOpenMPAffinityClause(StartLoc, LParenLoc, ColonLoc,
+                                               EndLoc, Modifier, Locators);
+  }
+
   /// Build a new OpenMP 'order' clause.
   ///
   /// By default, performs semantic analysis to build the new OpenMP clause.
@@ -9814,6 +9827,28 @@
 }
 
 template <typename Derived>
+OMPClause *
+TreeTransform<Derived>::TransformOMPAffinityClause(OMPAffinityClause *C) {
+  SmallVector<Expr *, 4> Locators;
+  Locators.reserve(C->varlist_size());
+  ExprResult ModifierRes; // Left default-constructed when C has no modifier.
+  if (Expr *Modifier = C->getModifier()) {
+    ModifierRes = getDerived().TransformExpr(Modifier);
+    if (ModifierRes.isInvalid())
+      return nullptr; // A failed modifier transform abandons the clause.
+  }
+  for (Expr *E : C->varlists()) {
+    ExprResult Locator = getDerived().TransformExpr(E);
+    if (Locator.isInvalid())
+      continue; // A failed locator is dropped; the rest are still rebuilt.
+    Locators.push_back(Locator.get());
+  }
+  return getDerived().RebuildOMPAffinityClause(
+      C->getBeginLoc(), C->getLParenLoc(), C->getColonLoc(), C->getEndLoc(),
+      ModifierRes.get(), Locators);
+}
+
+template <typename Derived>
 OMPClause *TreeTransform<Derived>::TransformOMPOrderClause(OMPOrderClause *C) {
   return getDerived().RebuildOMPOrderClause(C->getKind(), C->getKindKwLoc(),
                                             C->getBeginLoc(), C->getLParenLoc(),
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 3f41646..16bcb18 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11951,6 +11951,9 @@
   case llvm::omp::OMPC_uses_allocators:
     C = OMPUsesAllocatorsClause::CreateEmpty(Context, Record.readInt());
     break;
+  case llvm::omp::OMPC_affinity:
+    C = OMPAffinityClause::CreateEmpty(Context, Record.readInt());
+    break;
 #define OMP_CLAUSE_NO_CLASS(Enum, Str)                                         \
   case llvm::omp::Enum:                                                        \
     break;
@@ -12794,6 +12797,18 @@
   C->setAllocatorsData(Data);
 }
 
+void OMPClauseReader::VisitOMPAffinityClause(OMPAffinityClause *C) {
+  C->setLParenLoc(Record.readSourceLocation()); // Order mirrors the writer.
+  C->setModifier(Record.readSubExpr());
+  C->setColonLoc(Record.readSourceLocation());
+  unsigned NumOfLocators = C->varlist_size(); // Set via CreateEmpty(N).
+  SmallVector<Expr *, 4> Locators;
+  Locators.reserve(NumOfLocators);
+  for (unsigned I = 0; I != NumOfLocators; ++I)
+    Locators.push_back(Record.readSubExpr());
+  C->setVarRefs(Locators);
+}
+
 void OMPClauseReader::VisitOMPOrderClause(OMPOrderClause *C) {
   C->setKind(Record.readEnum<OpenMPOrderClauseKind>());
   C->setLParenLoc(Record.readSourceLocation());
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index bd7cbfc..1e3adb5 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6704,6 +6704,15 @@
   }
 }
 
+void OMPClauseWriter::VisitOMPAffinityClause(OMPAffinityClause *C) {
+  Record.push_back(C->varlist_size()); // Count first; read by CreateEmpty.
+  Record.AddSourceLocation(C->getLParenLoc());
+  Record.AddStmt(C->getModifier()); // May be null (no modifier).
+  Record.AddSourceLocation(C->getColonLoc());
+  for (Expr *E : C->varlists())
+    Record.AddStmt(E);
+}
+
 void ASTRecordWriter::writeOMPTraitInfo(const OMPTraitInfo *TI) {
   writeUInt32(TI->Sets.size());
   for (const auto &Set : TI->Sets) {