Hide the specialization folding sets of ClassTemplateDecl as an implementation detail (though
InsertPos still leaks) and add methods to its interface for adding and finding specializations.

This simplifies its users a bit, and we no longer need to replace specializations in the folding set
with their redeclarations; the lookup methods simply return the most recent redeclaration.

As a bonus, it fixes http://llvm.org/PR7670.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@108832 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/AST/DeclTemplate.cpp b/lib/AST/DeclTemplate.cpp
index 9e1d79d..a75c1c0 100644
--- a/lib/AST/DeclTemplate.cpp
+++ b/lib/AST/DeclTemplate.cpp
@@ -171,6 +171,28 @@
   Decl::Destroy(C);
 }
 
+ClassTemplateSpecializationDecl *ClassTemplateDecl::findSpecialization(
+    const TemplateArgument *Args, unsigned NumArgs, void *&InsertPos) {
+  llvm::FoldingSetNodeID Key;
+  ClassTemplateSpecializationDecl::Profile(Key, Args, NumArgs, getASTContext());
+  if (ClassTemplateSpecializationDecl *Found
+        = getSpecializations().FindNodeOrInsertPos(Key, InsertPos))
+    return Found->getMostRecentDeclaration();
+  return 0; // Not found; InsertPos is set for AddSpecialization().
+}
+
+ClassTemplatePartialSpecializationDecl *
+ClassTemplateDecl::findPartialSpecialization(
+    const TemplateArgument *Args, unsigned NumArgs, void *&InsertPos) {
+  llvm::FoldingSetNodeID Key;
+  ClassTemplatePartialSpecializationDecl::Profile(Key, Args, NumArgs,
+                                                  getASTContext());
+  if (ClassTemplatePartialSpecializationDecl *Found
+        = getPartialSpecializations().FindNodeOrInsertPos(Key, InsertPos))
+    return Found->getMostRecentDeclaration();
+  return 0; // Not found; InsertPos is set for AddPartialSpecialization().
+}
+
 void ClassTemplateDecl::getPartialSpecializations(
           llvm::SmallVectorImpl<ClassTemplatePartialSpecializationDecl *> &PS) {
   llvm::FoldingSet<ClassTemplatePartialSpecializationDecl> &PartialSpecs
@@ -181,7 +203,7 @@
        P = PartialSpecs.begin(), PEnd = PartialSpecs.end();
        P != PEnd; ++P) {
     assert(!PS[P->getSequenceNumber()]);
-    PS[P->getSequenceNumber()] = &*P;
+    PS[P->getSequenceNumber()] = P->getMostRecentDeclaration();
   }
 }
 
@@ -194,7 +216,22 @@
                           PEnd = getPartialSpecializations().end();
        P != PEnd; ++P) {
     if (Context.hasSameType(P->getInjectedSpecializationType(), T))
-      return &*P;
+      return P->getMostRecentDeclaration();
+  }
+
+  return 0;
+}
+
+ClassTemplatePartialSpecializationDecl *
+ClassTemplateDecl::findPartialSpecInstantiatedFromMember(
+                                    ClassTemplatePartialSpecializationDecl *D) {
+  Decl *DCanon = D->getCanonicalDecl();
+  for (llvm::FoldingSet<ClassTemplatePartialSpecializationDecl>::iterator
+            P = getPartialSpecializations().begin(),
+         PEnd = getPartialSpecializations().end();
+       P != PEnd; ++P) {
+    if (P->getInstantiatedFromMember()->getCanonicalDecl() == DCanon)
+      return P->getMostRecentDeclaration();
   }
 
   return 0;
diff --git a/lib/Sema/SemaTemplate.cpp b/lib/Sema/SemaTemplate.cpp
index 4cb9433..c8b5338 100644
--- a/lib/Sema/SemaTemplate.cpp
+++ b/lib/Sema/SemaTemplate.cpp
@@ -1477,14 +1477,10 @@
                = dyn_cast<ClassTemplateDecl>(Template)) {
     // Find the class template specialization declaration that
     // corresponds to these arguments.
-    llvm::FoldingSetNodeID ID;
-    ClassTemplateSpecializationDecl::Profile(ID,
-                                             Converted.getFlatArguments(),
-                                             Converted.flatSize(),
-                                             Context);
     void *InsertPos = 0;
     ClassTemplateSpecializationDecl *Decl
-      = ClassTemplate->getSpecializations().FindNodeOrInsertPos(ID, InsertPos);
+      = ClassTemplate->findSpecialization(Converted.getFlatArguments(),
+                                          Converted.flatSize(), InsertPos);
     if (!Decl) {
       // This is the first time we have referenced this class template
       // specialization. Create the canonical declaration and add it to
@@ -1495,7 +1491,7 @@
                                                 ClassTemplate->getLocation(),
                                                 ClassTemplate,
                                                 Converted, 0);
-      ClassTemplate->getSpecializations().InsertNode(Decl, InsertPos);
+      ClassTemplate->AddSpecialization(Decl, InsertPos);
       Decl->setLexicalDeclContext(CurContext);
     }
 
@@ -3727,7 +3723,6 @@
 
   // Find the class template (partial) specialization declaration that
   // corresponds to these arguments.
-  llvm::FoldingSetNodeID ID;
   if (isPartialSpecialization) {
     bool MirrorsPrimaryTemplate;
     if (CheckClassTemplatePartialSpecializationArgs(
@@ -3760,30 +3755,22 @@
       Diag(TemplateNameLoc, diag::err_partial_spec_fully_specialized)
         << ClassTemplate->getDeclName();
       isPartialSpecialization = false;
-    } else {
-      // FIXME: Template parameter list matters, too
-      ClassTemplatePartialSpecializationDecl::Profile(ID,
-                                                  Converted.getFlatArguments(),
-                                                      Converted.flatSize(),
-                                                      Context);
     }
   }
-  
-  if (!isPartialSpecialization)
-    ClassTemplateSpecializationDecl::Profile(ID,
-                                             Converted.getFlatArguments(),
-                                             Converted.flatSize(),
-                                             Context);
+
   void *InsertPos = 0;
   ClassTemplateSpecializationDecl *PrevDecl = 0;
 
   if (isPartialSpecialization)
+    // FIXME: Template parameter list matters, too
     PrevDecl
-      = ClassTemplate->getPartialSpecializations().FindNodeOrInsertPos(ID,
-                                                                    InsertPos);
+      = ClassTemplate->findPartialSpecialization(Converted.getFlatArguments(),
+                                                 Converted.flatSize(),
+                                                 InsertPos);
   else
     PrevDecl
-      = ClassTemplate->getSpecializations().FindNodeOrInsertPos(ID, InsertPos);
+      = ClassTemplate->findSpecialization(Converted.getFlatArguments(),
+                                          Converted.flatSize(), InsertPos);
 
   ClassTemplateSpecializationDecl *Specialization = 0;
 
@@ -3821,7 +3808,7 @@
     ClassTemplatePartialSpecializationDecl *PrevPartial
       = cast_or_null<ClassTemplatePartialSpecializationDecl>(PrevDecl);
     unsigned SequenceNumber = PrevPartial? PrevPartial->getSequenceNumber()
-                            : ClassTemplate->getPartialSpecializations().size();
+                            : ClassTemplate->getNextPartialSpecSequenceNumber();
     ClassTemplatePartialSpecializationDecl *Partial
       = ClassTemplatePartialSpecializationDecl::Create(Context, Kind,
                                              ClassTemplate->getDeclContext(),
@@ -3840,12 +3827,8 @@
                     (TemplateParameterList**) TemplateParameterLists.release());
     }
 
-    if (PrevPartial) {
-      ClassTemplate->getPartialSpecializations().RemoveNode(PrevPartial);
-      ClassTemplate->getPartialSpecializations().GetOrInsertNode(Partial);
-    } else {
-      ClassTemplate->getPartialSpecializations().InsertNode(Partial, InsertPos);
-    }
+    if (!PrevPartial)
+      ClassTemplate->AddPartialSpecialization(Partial, InsertPos);
     Specialization = Partial;
 
     // If we are providing an explicit specialization of a member class 
@@ -3902,13 +3885,8 @@
                     (TemplateParameterList**) TemplateParameterLists.release());
     }
 
-    if (PrevDecl) {
-      ClassTemplate->getSpecializations().RemoveNode(PrevDecl);
-      ClassTemplate->getSpecializations().GetOrInsertNode(Specialization);
-    } else {
-      ClassTemplate->getSpecializations().InsertNode(Specialization,
-                                                     InsertPos);
-    }
+    if (!PrevDecl)
+      ClassTemplate->AddSpecialization(Specialization, InsertPos);
 
     CanonType = Context.getTypeDeclType(Specialization);
   }
@@ -4701,14 +4679,10 @@
 
   // Find the class template specialization declaration that
   // corresponds to these arguments.
-  llvm::FoldingSetNodeID ID;
-  ClassTemplateSpecializationDecl::Profile(ID,
-                                           Converted.getFlatArguments(),
-                                           Converted.flatSize(),
-                                           Context);
   void *InsertPos = 0;
   ClassTemplateSpecializationDecl *PrevDecl
-    = ClassTemplate->getSpecializations().FindNodeOrInsertPos(ID, InsertPos);
+    = ClassTemplate->findSpecialization(Converted.getFlatArguments(),
+                                        Converted.flatSize(), InsertPos);
 
   TemplateSpecializationKind PrevDecl_TSK
     = PrevDecl ? PrevDecl->getTemplateSpecializationKind() : TSK_Undeclared;
@@ -4761,15 +4735,9 @@
                                                 Converted, PrevDecl);
     SetNestedNameSpecifier(Specialization, SS);
 
-    if (!HasNoEffect) {
-      if (PrevDecl) {
-        // Remove the previous declaration from the folding set, since we want
-        // to introduce a new declaration.
-        ClassTemplate->getSpecializations().RemoveNode(PrevDecl);
-        ClassTemplate->getSpecializations().FindNodeOrInsertPos(ID, InsertPos);
-      }
+    if (!HasNoEffect && !PrevDecl) {
       // Insert the new specialization.
-      ClassTemplate->getSpecializations().InsertNode(Specialization, InsertPos);
+      ClassTemplate->AddSpecialization(Specialization, InsertPos);
     }
   }
 
diff --git a/lib/Sema/SemaTemplateInstantiateDecl.cpp b/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 79b5532..7e06175 100644
--- a/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -857,16 +857,7 @@
   if (!InstClassTemplate)
     return 0;
   
-  Decl *DCanon = D->getCanonicalDecl();
-  for (llvm::FoldingSet<ClassTemplatePartialSpecializationDecl>::iterator
-            P = InstClassTemplate->getPartialSpecializations().begin(),
-         PEnd = InstClassTemplate->getPartialSpecializations().end();
-       P != PEnd; ++P) {
-    if (P->getInstantiatedFromMember()->getCanonicalDecl() == DCanon)
-      return &*P;
-  }
-  
-  return 0;
+  return InstClassTemplate->findPartialSpecInstantiatedFromMember(D);
 }
 
 Decl *
@@ -1804,15 +1795,10 @@
 
   // Figure out where to insert this class template partial specialization
   // in the member template's set of class template partial specializations.
-  llvm::FoldingSetNodeID ID;
-  ClassTemplatePartialSpecializationDecl::Profile(ID,
-                                                  Converted.getFlatArguments(),
-                                                  Converted.flatSize(),
-                                                  SemaRef.Context);
   void *InsertPos = 0;
   ClassTemplateSpecializationDecl *PrevDecl
-    = ClassTemplate->getPartialSpecializations().FindNodeOrInsertPos(ID,
-                                                                     InsertPos);
+    = ClassTemplate->findPartialSpecialization(Converted.getFlatArguments(),
+                                                Converted.flatSize(), InsertPos);
   
   // Build the canonical type that describes the converted template
   // arguments of the class template partial specialization.
@@ -1871,7 +1857,7 @@
                                                      InstTemplateArgs,
                                                      CanonType,
                                                      0,
-                             ClassTemplate->getPartialSpecializations().size());
+                             ClassTemplate->getNextPartialSpecSequenceNumber());
   // Substitute the nested name specifier, if any.
   if (SubstQualifier(PartialSpec, InstPartialSpec))
     return 0;
@@ -1881,8 +1867,7 @@
   
   // Add this partial specialization to the set of class template partial
   // specializations.
-  ClassTemplate->getPartialSpecializations().InsertNode(InstPartialSpec,
-                                                        InsertPos);
+  ClassTemplate->AddPartialSpecialization(InstPartialSpec, InsertPos);
   return false;
 }