[Concept] Associated Constraints Infrastructure

Add code to correctly calculate the associated constraints of a template (no enforcement yet).
Differential Revision: https://reviews.llvm.org/D41284

llvm-svn: 374938
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index cd6ea7f..1a3bde0 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -738,7 +738,7 @@
                                            cast<TemplateTemplateParmDecl>(*P)));
   }
 
-  assert(!TTP->getRequiresClause() &&
+  assert(!TTP->getTemplateParameters()->getRequiresClause() &&
          "Unexpected requires-clause on template template-parameter");
   Expr *const CanonRequiresClause = nullptr;
 
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index ccfc292..7e013c6 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -70,6 +70,8 @@
   }
   if (RequiresClause) {
     *getTrailingObjects<Expr *>() = RequiresClause;
+    if (RequiresClause->containsUnexpandedParameterPack())
+      ContainsUnexpandedParameterPack = true;
   }
 }
 
@@ -136,6 +138,18 @@
   }
 }
 
+void TemplateParameterList::
+getAssociatedConstraints(llvm::SmallVectorImpl<const Expr *> &AC) const {
+  // TODO: Concepts: Collect immediately-introduced constraints.
+  if (HasRequiresClause)
+    AC.push_back(getRequiresClause());
+}
+
+bool TemplateParameterList::hasAssociatedConstraints() const {
+  // TODO: Concepts: Regard immediately-introduced constraints.
+  return HasRequiresClause;
+}
+
 namespace clang {
 
 void *allocateDefaultArgStorageChain(const ASTContext &C) {
@@ -145,6 +159,28 @@
 } // namespace clang
 
 //===----------------------------------------------------------------------===//
+// TemplateDecl Implementation
+//===----------------------------------------------------------------------===//
+
+TemplateDecl::TemplateDecl(Kind DK, DeclContext *DC, SourceLocation L,
+                           DeclarationName Name, TemplateParameterList *Params,
+                           NamedDecl *Decl)
+    : NamedDecl(DK, DC, L, Name), TemplatedDecl(Decl), TemplateParams(Params) {}
+
+void TemplateDecl::anchor() {}
+
+void TemplateDecl::
+getAssociatedConstraints(llvm::SmallVectorImpl<const Expr *> &AC) const {
+  // TODO: Concepts: Append function trailing requires clause.
+  TemplateParams->getAssociatedConstraints(AC);
+}
+
+bool TemplateDecl::hasAssociatedConstraints() const {
+  // TODO: Concepts: Regard function trailing requires clause.
+  return TemplateParams->hasAssociatedConstraints();
+}
+
+//===----------------------------------------------------------------------===//
 // RedeclarableTemplateDecl Implementation
 //===----------------------------------------------------------------------===//
 
@@ -344,19 +380,10 @@
                                              SourceLocation L,
                                              DeclarationName Name,
                                              TemplateParameterList *Params,
-                                             NamedDecl *Decl,
-                                             Expr *AssociatedConstraints) {
+                                             NamedDecl *Decl) {
   AdoptTemplateParameterList(Params, cast<DeclContext>(Decl));
 
-  if (!AssociatedConstraints) {
-    return new (C, DC) ClassTemplateDecl(C, DC, L, Name, Params, Decl);
-  }
-
-  auto *const CTDI = new (C) ConstrainedTemplateDeclInfo;
-  auto *const New =
-      new (C, DC) ClassTemplateDecl(CTDI, C, DC, L, Name, Params, Decl);
-  New->setAssociatedConstraints(AssociatedConstraints);
-  return New;
+  return new (C, DC) ClassTemplateDecl(C, DC, L, Name, Params, Decl);
 }
 
 ClassTemplateDecl *ClassTemplateDecl::CreateDeserialized(ASTContext &C,
@@ -708,12 +735,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// TemplateDecl Implementation
-//===----------------------------------------------------------------------===//
-
-void TemplateDecl::anchor() {}
-
-//===----------------------------------------------------------------------===//
 // ClassTemplateSpecializationDecl Implementation
 //===----------------------------------------------------------------------===//
 
diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp
index 3131609..848ccf5 100644
--- a/clang/lib/Sema/SemaConcept.cpp
+++ b/clang/lib/Sema/SemaConcept.cpp
@@ -122,4 +122,4 @@
   IsSatisfied = EvalResult.Val.getInt().getBoolValue();
 
   return false;
-}
+}
\ No newline at end of file
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index cb756eb..2871511 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -45,27 +45,7 @@
   return SourceRange(Ps[0]->getTemplateLoc(), Ps[N-1]->getRAngleLoc());
 }
 
-namespace clang {
-/// [temp.constr.decl]p2: A template's associated constraints are
-/// defined as a single constraint-expression derived from the introduced
-/// constraint-expressions [ ... ].
-///
-/// \param Params The template parameter list and optional requires-clause.
-///
-/// \param FD The underlying templated function declaration for a function
-/// template.
-static Expr *formAssociatedConstraints(TemplateParameterList *Params,
-                                       FunctionDecl *FD);
-}
-
-static Expr *clang::formAssociatedConstraints(TemplateParameterList *Params,
-                                              FunctionDecl *FD) {
-  // FIXME: Concepts: collect additional introduced constraint-expressions
-  assert(!FD && "Cannot collect constraints from function declaration yet.");
-  return Params->getRequiresClause();
-}
-
 /// Determine whether the declaration found is acceptable as the name
 /// of a template and, if so, return that template declaration. Otherwise,
 /// returns null.
 ///
@@ -1533,9 +1513,6 @@
     }
   }
 
-  // TODO Memory management; associated constraints are not always stored.
-  Expr *const CurAC = formAssociatedConstraints(TemplateParams, nullptr);
-
   if (PrevClassTemplate) {
     // Ensure that the template parameter lists are compatible. Skip this check
     // for a friend in a dependent context: the template parameter list itself
@@ -1547,30 +1524,6 @@
                                         TPL_TemplateMatch))
       return true;
 
-    // Check for matching associated constraints on redeclarations.
-    const Expr *const PrevAC = PrevClassTemplate->getAssociatedConstraints();
-    const bool RedeclACMismatch = [&] {
-      if (!(CurAC || PrevAC))
-        return false; // Nothing to check; no mismatch.
-      if (CurAC && PrevAC) {
-        llvm::FoldingSetNodeID CurACInfo, PrevACInfo;
-        CurAC->Profile(CurACInfo, Context, /*Canonical=*/true);
-        PrevAC->Profile(PrevACInfo, Context, /*Canonical=*/true);
-        if (CurACInfo == PrevACInfo)
-          return false; // All good; no mismatch.
-      }
-      return true;
-    }();
-
-    if (RedeclACMismatch) {
-      Diag(CurAC ? CurAC->getBeginLoc() : NameLoc,
-           diag::err_template_different_associated_constraints);
-      Diag(PrevAC ? PrevAC->getBeginLoc() : PrevClassTemplate->getLocation(),
-           diag::note_template_prev_declaration)
-          << /*declaration*/ 0;
-      return true;
-    }
-
     // C++ [temp.class]p4:
     //   In a redeclaration, partial specialization, explicit
     //   specialization or explicit instantiation of a class template,
@@ -1674,15 +1627,10 @@
     AddMsStructLayoutForRecord(NewClass);
   }
 
-  // Attach the associated constraints when the declaration will not be part of
-  // a decl chain.
-  Expr *const ACtoAttach =
-      PrevClassTemplate && ShouldAddRedecl ? nullptr : CurAC;
-
   ClassTemplateDecl *NewTemplate
     = ClassTemplateDecl::Create(Context, SemanticContext, NameLoc,
                                 DeclarationName(Name), TemplateParams,
-                                NewClass, ACtoAttach);
+                                NewClass);
 
   if (ShouldAddRedecl)
     NewTemplate->setPreviousDecl(PrevClassTemplate);
@@ -7266,6 +7214,9 @@
                                             TemplateArgLoc);
   }
 
+  // TODO: Concepts: Match immediately-introduced constraints for type
+  // constraints.
+
   return true;
 }
 
@@ -7291,6 +7242,15 @@
     << SourceRange(Old->getTemplateLoc(), Old->getRAngleLoc());
 }
 
+static void
+DiagnoseTemplateParameterListRequiresClauseMismatch(Sema &S,
+                                                    TemplateParameterList *New,
+                                                    TemplateParameterList *Old){
+  S.Diag(New->getTemplateLoc(), diag::err_template_different_requires_clause);
+  S.Diag(Old->getTemplateLoc(),  diag::note_template_prev_declaration)
+      << /*declaration*/0;
+}
+
 /// Determine whether the given template parameter lists are
 /// equivalent.
 ///
@@ -7380,6 +7340,27 @@
     return false;
   }
 
+  if (Kind != TPL_TemplateTemplateArgumentMatch) {
+    const Expr *NewRC = New->getRequiresClause();
+    const Expr *OldRC = Old->getRequiresClause();
+    if (!NewRC != !OldRC) {
+      if (Complain)
+        DiagnoseTemplateParameterListRequiresClauseMismatch(*this, New, Old);
+      return false;
+    }
+
+    if (NewRC) {
+      llvm::FoldingSetNodeID OldRCID, NewRCID;
+      OldRC->Profile(OldRCID, Context, /*Canonical=*/true);
+      NewRC->Profile(NewRCID, Context, /*Canonical=*/true);
+      if (OldRCID != NewRCID) {
+        if (Complain)
+          DiagnoseTemplateParameterListRequiresClauseMismatch(*this, New, Old);
+        return false;
+      }
+    }
+  }
+
   return true;
 }
 
@@ -8089,10 +8070,9 @@
                                              TemplateParameterLists.front(),
                                              ConstraintExpr);
                                              
-  if (NewDecl->getAssociatedConstraints()) {
+  if (NewDecl->hasAssociatedConstraints()) {
     // C++2a [temp.concept]p4:
     // A concept shall not have associated constraints.
-    // TODO: Make a test once we have actual associated constraints.
     Diag(NameLoc, diag::err_concept_no_associated_constraints);
     NewDecl->setInvalidDecl();
   }
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 818548c..d1ad304 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -3515,14 +3515,21 @@
   if (Invalid)
     return nullptr;
 
-  // Note: we substitute into associated constraints later
-  Expr *const UninstantiatedRequiresClause = L->getRequiresClause();
+  // FIXME: Concepts: Substitution into requires clause should only happen when
+  // checking satisfaction.
+  Expr *InstRequiresClause = nullptr;
+  if (Expr *E = L->getRequiresClause()) {
+    ExprResult Res = SemaRef.SubstExpr(E, TemplateArgs);
+    if (Res.isInvalid() || !Res.isUsable()) {
+      return nullptr;
+    }
+    InstRequiresClause = Res.get();
+  }
 
   TemplateParameterList *InstL
     = TemplateParameterList::Create(SemaRef.Context, L->getTemplateLoc(),
                                     L->getLAngleLoc(), Params,
-                                    L->getRAngleLoc(),
-                                    UninstantiatedRequiresClause);
+                                    L->getRAngleLoc(), InstRequiresClause);
   return InstL;
 }
 
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 20c6cbe..d879076 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -9345,9 +9345,11 @@
   while (NumParams--)
     Params.push_back(ReadDeclAs<NamedDecl>(F, Record, Idx));
 
-  // TODO: Concepts
+  bool HasRequiresClause = Record[Idx++];
+  Expr *RequiresClause = HasRequiresClause ? ReadExpr(F) : nullptr;
+
   TemplateParameterList *TemplateParams = TemplateParameterList::Create(
-      getContext(), TemplateLoc, LAngleLoc, Params, RAngleLoc, nullptr);
+      getContext(), TemplateLoc, LAngleLoc, Params, RAngleLoc, RequiresClause);
   return TemplateParams;
 }
 
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index d906286..65d6252 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -2000,7 +2000,6 @@
   DeclID PatternID = ReadDeclID();
   auto *TemplatedDecl = cast_or_null<NamedDecl>(Reader.GetDecl(PatternID));
   TemplateParameterList *TemplateParams = Record.readTemplateParameterList();
-  // FIXME handle associated constraints
   D->init(TemplatedDecl, TemplateParams);
 
   return PatternID;
@@ -2166,7 +2165,8 @@
                                     ClassTemplatePartialSpecializationDecl *D) {
   RedeclarableResult Redecl = VisitClassTemplateSpecializationDeclImpl(D);
 
-  D->TemplateParams = Record.readTemplateParameterList();
+  TemplateParameterList *Params = Record.readTemplateParameterList();
+  D->TemplateParams = Params;
   D->ArgsAsWritten = Record.readASTTemplateArgumentListInfo();
 
   // These are read/set from/to the first declaration.
@@ -2268,7 +2268,8 @@
     VarTemplatePartialSpecializationDecl *D) {
   RedeclarableResult Redecl = VisitVarTemplateSpecializationDeclImpl(D);
 
-  D->TemplateParams = Record.readTemplateParameterList();
+  TemplateParameterList *Params = Record.readTemplateParameterList();
+  D->TemplateParams = Params;
   D->ArgsAsWritten = Record.readASTTemplateArgumentListInfo();
 
   // These are read/set from/to the first declaration.
@@ -2284,6 +2285,7 @@
 
   D->setDeclaredWithTypename(Record.readInt());
 
+  // TODO: Concepts: Immediately-introduced constraint.
   if (Record.readInt())
     D->setDefaultArgument(GetTypeSourceInfo());
 }
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 5e9e650..aef3523 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6070,10 +6070,16 @@
   AddSourceLocation(TemplateParams->getTemplateLoc());
   AddSourceLocation(TemplateParams->getLAngleLoc());
   AddSourceLocation(TemplateParams->getRAngleLoc());
-  // TODO: Concepts
+
   Record->push_back(TemplateParams->size());
   for (const auto &P : *TemplateParams)
     AddDeclRef(P);
+  if (const Expr *RequiresClause = TemplateParams->getRequiresClause()) {
+    Record->push_back(true);
+    AddStmt(const_cast<Expr*>(RequiresClause));
+  } else {
+    Record->push_back(false);
+  }
 }
 
 /// Emit a template argument list.
diff --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp
index 2c22587..039b57f 100644
--- a/clang/lib/Serialization/ASTWriterDecl.cpp
+++ b/clang/lib/Serialization/ASTWriterDecl.cpp
@@ -1608,7 +1608,7 @@
   VisitTypeDecl(D);
 
   Record.push_back(D->wasDeclaredWithTypename());
-
+  // TODO: Concepts - constrained parameters.
   bool OwnsDefaultArg = D->hasDefaultArgument() &&
                         !D->defaultArgumentWasInherited();
   Record.push_back(OwnsDefaultArg);
@@ -1638,6 +1638,7 @@
 
     Code = serialization::DECL_EXPANDED_NON_TYPE_TEMPLATE_PARM_PACK;
   } else {
+    // TODO: Concepts - constrained parameters.
     // Rest of NonTypeTemplateParmDecl.
     Record.push_back(D->isParameterPack());
     bool OwnsDefaultArg = D->hasDefaultArgument() &&
@@ -1667,6 +1668,7 @@
       Record.AddTemplateParameterList(D->getExpansionTemplateParameters(I));
     Code = serialization::DECL_EXPANDED_TEMPLATE_TEMPLATE_PARM_PACK;
   } else {
+    // TODO: Concepts - constrained parameters.
     // Rest of TemplateTemplateParmDecl.
     Record.push_back(D->isParameterPack());
     bool OwnsDefaultArg = D->hasDefaultArgument() &&