Make the InjectedClassNameType the canonical type of the current instantiation
of a class template or class template partial specialization.  That is to
say, in
  template <class T> class A { ... };
or
  template <class T> class B<const T*> { ... };
make 'A<T>' and 'B<const T*>' sugar for the corresponding InjectedClassNameType
when written in a context where they name the current instantiation.  This
allows us to track the current instantiation appropriately even inside AST
routines.  It also allows us to compute a DeclContext for a type much more
efficiently, at some extra cost every time we build a dependent template
specialization type, since we now walk the enclosing contexts to look for a
matching current instantiation (this can be optimized, e.g. into a hash table
lookup, but I've left it simple in this patch).



git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@102407 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/Sema.h b/lib/Sema/Sema.h
index 3ea9bd0..212a36f 100644
--- a/lib/Sema/Sema.h
+++ b/lib/Sema/Sema.h
@@ -3047,6 +3047,7 @@
 
   QualType RebuildTypeInCurrentInstantiation(QualType T, SourceLocation Loc,
                                              DeclarationName Name);
+  void RebuildNestedNameSpecifierInCurrentInstantiation(CXXScopeSpec &SS);
 
   std::string
   getTemplateArgumentBindingsText(const TemplateParameterList *Params,
diff --git a/lib/Sema/SemaCXXScopeSpec.cpp b/lib/Sema/SemaCXXScopeSpec.cpp
index 89f8aec..10adc67 100644
--- a/lib/Sema/SemaCXXScopeSpec.cpp
+++ b/lib/Sema/SemaCXXScopeSpec.cpp
@@ -24,61 +24,17 @@
 using namespace clang;
 
 /// \brief Find the current instantiation that associated with the given type.
-static CXXRecordDecl *
-getCurrentInstantiationOf(ASTContext &Context, DeclContext *CurContext, 
-                          QualType T) {
+static CXXRecordDecl *getCurrentInstantiationOf(QualType T) {
   if (T.isNull())
     return 0;
-  
-  T = Context.getCanonicalType(T).getUnqualifiedType();
-  
-  for (DeclContext *Ctx = CurContext; Ctx; Ctx = Ctx->getLookupParent()) {
-    // If we've hit a namespace or the global scope, then the
-    // nested-name-specifier can't refer to the current instantiation.
-    if (Ctx->isFileContext())
-      return 0;
-    
-    // Skip non-class contexts.
-    CXXRecordDecl *Record = dyn_cast<CXXRecordDecl>(Ctx);
-    if (!Record)
-      continue;
-    
-    // If this record type is not dependent,
-    if (!Record->isDependentType())
-      return 0;
-    
-    // C++ [temp.dep.type]p1:
-    //
-    //   In the definition of a class template, a nested class of a
-    //   class template, a member of a class template, or a member of a
-    //   nested class of a class template, a name refers to the current
-    //   instantiation if it is
-    //     -- the injected-class-name (9) of the class template or
-    //        nested class,
-    //     -- in the definition of a primary class template, the name
-    //        of the class template followed by the template argument
-    //        list of the primary template (as described below)
-    //        enclosed in <>,
-    //     -- in the definition of a nested class of a class template,
-    //        the name of the nested class referenced as a member of
-    //        the current instantiation, or
-    //     -- in the definition of a partial specialization, the name
-    //        of the class template followed by the template argument
-    //        list of the partial specialization enclosed in <>. If
-    //        the nth template parameter is a parameter pack, the nth
-    //        template argument is a pack expansion (14.6.3) whose
-    //        pattern is the name of the parameter pack.
-    //        (FIXME: parameter packs)
-    //
-    // All of these options come down to having the
-    // nested-name-specifier type that is equivalent to the
-    // injected-class-name of one of the types that is currently in
-    // our context.
-    if (Context.getCanonicalType(Context.getTypeDeclType(Record)) == T)
-      return Record;
-  }  
-  
-  return 0;
+
+  const Type *Ty = T->getCanonicalTypeInternal().getTypePtr();
+  if (isa<RecordType>(Ty))
+    return cast<CXXRecordDecl>(cast<RecordType>(Ty)->getDecl());
+  else if (isa<InjectedClassNameType>(Ty))
+    return cast<InjectedClassNameType>(Ty)->getDecl();
+  else
+    return 0;
 }
 
 /// \brief Compute the DeclContext that is associated with the given type.
@@ -92,7 +48,7 @@
   if (const TagType *Tag = T->getAs<TagType>())
     return Tag->getDecl();
 
-  return ::getCurrentInstantiationOf(Context, CurContext, T);
+  return ::getCurrentInstantiationOf(T);
 }
 
 /// \brief Compute the DeclContext that is associated with the given
@@ -218,7 +174,7 @@
     return 0;
 
   QualType T = QualType(NNS->getAsType(), 0);
-  return ::getCurrentInstantiationOf(Context, CurContext, T);
+  return ::getCurrentInstantiationOf(T);
 }
 
 /// \brief Require that the context specified by SS be complete.
@@ -704,6 +660,11 @@
     return true;
     
   EnterDeclaratorContext(S, DC);
+
+  // Rebuild the nested name specifier for the new scope.
+  if (DC->isDependentContext())
+    RebuildNestedNameSpecifierInCurrentInstantiation(SS);
+
   return false;
 }
 
diff --git a/lib/Sema/SemaTemplate.cpp b/lib/Sema/SemaTemplate.cpp
index 958ed44..731836b 100644
--- a/lib/Sema/SemaTemplate.cpp
+++ b/lib/Sema/SemaTemplate.cpp
@@ -1344,27 +1344,25 @@
 
     // Check the template parameter list against its corresponding template-id.
     if (DependentTemplateId) {
-      TemplateDecl *Template
-        = TemplateIdsInSpecifier[Idx]->getTemplateName().getAsTemplateDecl();
+      TemplateParameterList *ExpectedTemplateParams = 0;
 
-      if (ClassTemplateDecl *ClassTemplate
-            = dyn_cast<ClassTemplateDecl>(Template)) {
-        TemplateParameterList *ExpectedTemplateParams = 0;
-        // Is this template-id naming the primary template?
-        if (Context.hasSameType(TemplateId,
-                 ClassTemplate->getInjectedClassNameSpecialization(Context)))
-          ExpectedTemplateParams = ClassTemplate->getTemplateParameters();
-        // ... or a partial specialization?
-        else if (ClassTemplatePartialSpecializationDecl *PartialSpec
-                   = ClassTemplate->findPartialSpecialization(TemplateId))
-          ExpectedTemplateParams = PartialSpec->getTemplateParameters();
-
-        if (ExpectedTemplateParams)
-          TemplateParameterListsAreEqual(ParamLists[Idx],
-                                         ExpectedTemplateParams,
-                                         true, TPL_TemplateMatch);
+      // Are there cases in (e.g.) friends where this won't match?
+      if (const InjectedClassNameType *Injected
+            = TemplateId->getAs<InjectedClassNameType>()) {
+        CXXRecordDecl *Record = Injected->getDecl();
+        if (ClassTemplatePartialSpecializationDecl *Partial =
+              dyn_cast<ClassTemplatePartialSpecializationDecl>(Record))
+          ExpectedTemplateParams = Partial->getTemplateParameters();
+        else
+          ExpectedTemplateParams = Record->getDescribedClassTemplate()
+            ->getTemplateParameters();
       }
 
+      if (ExpectedTemplateParams)
+        TemplateParameterListsAreEqual(ParamLists[Idx],
+                                       ExpectedTemplateParams,
+                                       true, TPL_TemplateMatch);
+
       CheckTemplateParameterList(ParamLists[Idx], 0, TPC_ClassTemplateMember);
     } else if (ParamLists[Idx]->size() > 0)
       Diag(ParamLists[Idx]->getTemplateLoc(),
@@ -1430,6 +1428,7 @@
          "Converted template argument list is too short!");
 
   QualType CanonType;
+  bool IsCurrentInstantiation = false;
 
   if (Name.isDependent() ||
       TemplateSpecializationType::anyDependentTemplateArguments(
@@ -1451,6 +1450,45 @@
     // In the future, we need to teach getTemplateSpecializationType to only
     // build the canonical type and return that to us.
     CanonType = Context.getCanonicalType(CanonType);
+
+    // This might work out to be a current instantiation, in which
+    // case the canonical type needs to be the InjectedClassNameType.
+    //
+    // TODO: in theory this could be a simple hashtable lookup; most
+    // changes to CurContext don't change the set of current
+    // instantiations.
+    if (isa<ClassTemplateDecl>(Template)) {
+      for (DeclContext *Ctx = CurContext; Ctx; Ctx = Ctx->getLookupParent()) {
+        // If we get out to a namespace, we're done.
+        if (Ctx->isFileContext()) break;
+
+        // If this isn't a record, keep looking.
+        CXXRecordDecl *Record = dyn_cast<CXXRecordDecl>(Ctx);
+        if (!Record) continue;
+
+        // Look for one of the two cases with InjectedClassNameTypes
+        // and check whether it's the same template.
+        if (!isa<ClassTemplatePartialSpecializationDecl>(Record) &&
+            !Record->getDescribedClassTemplate())
+          continue;
+          
+        // Fetch the injected class name type and check whether its
+        // injected type is equal to the type we just built.
+        QualType ICNT = Context.getTypeDeclType(Record);
+        QualType Injected = cast<InjectedClassNameType>(ICNT)
+          ->getInjectedSpecializationType();
+
+        if (CanonType != Injected->getCanonicalTypeInternal())
+          continue;
+
+        // If so, the canonical type of this TST is the injected
+        // class name type of the record we just found.
+        assert(ICNT.isCanonical());
+        CanonType = ICNT;
+        IsCurrentInstantiation = true;
+        break;
+      }
+    }
   } else if (ClassTemplateDecl *ClassTemplate
                = dyn_cast<ClassTemplateDecl>(Template)) {
     // Find the class template specialization declaration that
@@ -1484,7 +1522,8 @@
   // Build the fully-sugared type for this class template
   // specialization, which refers back to the class template
   // specialization we created or found.
-  return Context.getTemplateSpecializationType(Name, TemplateArgs, CanonType);
+  return Context.getTemplateSpecializationType(Name, TemplateArgs, CanonType,
+                                               IsCurrentInstantiation);
 }
 
 Action::TypeResult
@@ -5389,6 +5428,17 @@
   return Rebuilder.TransformType(T);
 }
 
+void Sema::RebuildNestedNameSpecifierInCurrentInstantiation(CXXScopeSpec &SS) {
+  if (SS.isInvalid()) return;
+
+  NestedNameSpecifier *NNS = static_cast<NestedNameSpecifier*>(SS.getScopeRep());
+  CurrentInstantiationRebuilder Rebuilder(*this, SS.getRange().getBegin(),
+                                          DeclarationName());
+  NestedNameSpecifier *Rebuilt = 
+    Rebuilder.TransformNestedNameSpecifier(NNS, SS.getRange());
+  if (Rebuilt) SS.setScopeRep(Rebuilt);
+}
+
 /// \brief Produces a formatted string that describes the binding of
 /// template parameters to template arguments.
 std::string
diff --git a/lib/Sema/SemaTemplateDeduction.cpp b/lib/Sema/SemaTemplateDeduction.cpp
index d61a767..7154d62 100644
--- a/lib/Sema/SemaTemplateDeduction.cpp
+++ b/lib/Sema/SemaTemplateDeduction.cpp
@@ -638,7 +638,8 @@
     case Type::InjectedClassName: {
       // Treat a template's injected-class-name as if the template
       // specialization type had been used.
-      Param = cast<InjectedClassNameType>(Param)->getUnderlyingType();
+      Param = cast<InjectedClassNameType>(Param)
+        ->getInjectedSpecializationType();
       assert(isa<TemplateSpecializationType>(Param) &&
              "injected class name is not a template specialization type");
       // fall through
@@ -2340,13 +2341,16 @@
   // are more constrained. We know that every template parameter is deduc
   llvm::SmallVector<DeducedTemplateArgument, 4> Deduced;
   Sema::TemplateDeductionInfo Info(Context, Loc);
+
+  QualType PT1 = PS1->getInjectedSpecializationType();
+  QualType PT2 = PS2->getInjectedSpecializationType();
   
   // Determine whether PS1 is at least as specialized as PS2
   Deduced.resize(PS2->getTemplateParameters()->size());
   bool Better1 = !DeduceTemplateArgumentsDuringPartialOrdering(*this,
                                                   PS2->getTemplateParameters(),
-                                                  Context.getTypeDeclType(PS2),
-                                                  Context.getTypeDeclType(PS1),
+                                                               PT2,
+                                                               PT1,
                                                                Info,
                                                                Deduced,
                                                                0);
@@ -2356,8 +2360,8 @@
   Deduced.resize(PS1->getTemplateParameters()->size());
   bool Better2 = !DeduceTemplateArgumentsDuringPartialOrdering(*this,
                                                   PS1->getTemplateParameters(),
-                                                  Context.getTypeDeclType(PS1),
-                                                  Context.getTypeDeclType(PS2),
+                                                               PT1,
+                                                               PT2,
                                                                Info,
                                                                Deduced,
                                                                0);
@@ -2537,6 +2541,10 @@
     break;
   }
 
+  case Type::InjectedClassName:
+    T = cast<InjectedClassNameType>(T)->getInjectedSpecializationType();
+    // fall through
+
   case Type::TemplateSpecialization: {
     const TemplateSpecializationType *Spec
       = cast<TemplateSpecializationType>(T);
diff --git a/lib/Sema/SemaTemplateInstantiateDecl.cpp b/lib/Sema/SemaTemplateInstantiateDecl.cpp
index d14ea1a..4575d47 100644
--- a/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -2500,7 +2500,7 @@
       T = Context.getTypeDeclType(Record);
       assert(isa<InjectedClassNameType>(T) &&
              "type of partial specialization is not an InjectedClassNameType");
-      T = cast<InjectedClassNameType>(T)->getUnderlyingType();
+      T = cast<InjectedClassNameType>(T)->getInjectedSpecializationType();
     }  
     
     if (!T.isNull()) {