Implement AST import support for class template specializations.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@120523 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/AST/ASTImporter.cpp b/lib/AST/ASTImporter.cpp
index dc0aeaa..0628fea 100644
--- a/lib/AST/ASTImporter.cpp
+++ b/lib/AST/ASTImporter.cpp
@@ -69,7 +69,7 @@
     QualType VisitEnumType(EnumType *T);
     // FIXME: TemplateTypeParmType
     // FIXME: SubstTemplateTypeParmType
-    // FIXME: TemplateSpecializationType
+    QualType VisitTemplateSpecializationType(TemplateSpecializationType *T);
     QualType VisitElaboratedType(ElaboratedType *T);
     // FIXME: DependentNameType
     // FIXME: DependentTemplateSpecializationType
@@ -84,8 +84,13 @@
     void ImportDeclarationNameLoc(const DeclarationNameInfo &From,
                                   DeclarationNameInfo& To);
     void ImportDeclContext(DeclContext *FromDC);
+    bool ImportDefinition(RecordDecl *From, RecordDecl *To);
     TemplateParameterList *ImportTemplateParameterList(
                                                  TemplateParameterList *Params);
+    TemplateArgument ImportTemplateArgument(const TemplateArgument &From);
+    bool ImportTemplateArguments(const TemplateArgument *FromArgs,
+                                 unsigned NumFromArgs,
+                               llvm::SmallVectorImpl<TemplateArgument> &ToArgs);
     bool IsStructuralMatch(RecordDecl *FromRecord, RecordDecl *ToRecord);
     bool IsStructuralMatch(EnumDecl *FromEnum, EnumDecl *ToRecord);
     bool IsStructuralMatch(ClassTemplateDecl *From, ClassTemplateDecl *To);
@@ -117,6 +122,8 @@
     Decl *VisitNonTypeTemplateParmDecl(NonTypeTemplateParmDecl *D);
     Decl *VisitTemplateTemplateParmDecl(TemplateTemplateParmDecl *D);
     Decl *VisitClassTemplateDecl(ClassTemplateDecl *D);
+    Decl *VisitClassTemplateSpecializationDecl(
+                                            ClassTemplateSpecializationDecl *D);
                             
     // Importing statements
     Stmt *VisitStmt(Stmt *S);
@@ -267,7 +274,49 @@
 static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
                                      const TemplateArgument &Arg1,
                                      const TemplateArgument &Arg2) {
-  // FIXME: Implement!
+  if (Arg1.getKind() != Arg2.getKind())
+    return false;
+  // Both arguments have the same kind; compare their kind-specific payloads.
+  switch (Arg1.getKind()) {
+  case TemplateArgument::Null:
+    return true;
+      
+  case TemplateArgument::Type:
+    return Context.IsStructurallyEquivalent(Arg1.getAsType(), Arg2.getAsType());
+      
+  case TemplateArgument::Integral:
+    if (!Context.IsStructurallyEquivalent(Arg1.getIntegralType(), 
+                                          Arg2.getIntegralType()))
+      return false;
+    // The integral types are equivalent; the values must agree as well.
+    return IsSameValue(*Arg1.getAsIntegral(), *Arg2.getAsIntegral());
+      
+  case TemplateArgument::Declaration:
+    return Context.IsStructurallyEquivalent(Arg1.getAsDecl(), Arg2.getAsDecl());
+      
+  case TemplateArgument::Template:
+    return IsStructurallyEquivalent(Context, 
+                                    Arg1.getAsTemplate(), 
+                                    Arg2.getAsTemplate());
+      
+  case TemplateArgument::Expression:
+    return IsStructurallyEquivalent(Context, 
+                                    Arg1.getAsExpr(), Arg2.getAsExpr());
+      
+  case TemplateArgument::Pack:
+    if (Arg1.pack_size() != Arg2.pack_size())
+      return false;
+    // Equal-sized packs are equivalent only if every element pair is.
+    for (unsigned I = 0, N = Arg1.pack_size(); I != N; ++I)
+      if (!IsStructurallyEquivalent(Context, 
+                                    Arg1.pack_begin()[I],
+                                    Arg2.pack_begin()[I]))
+        return false;
+      
+    return true;
+  }
+  // Each case above returns; reaching here means an unknown kind value.
+  llvm_unreachable("Invalid template argument kind");
   return true;
 }
 
@@ -702,6 +751,33 @@
     return false;
   }
   
+  // If both declarations are class template specializations, we know
+  // the ODR applies, so check the template and template arguments.
+  ClassTemplateSpecializationDecl *Spec1
+    = dyn_cast<ClassTemplateSpecializationDecl>(D1);
+  ClassTemplateSpecializationDecl *Spec2
+    = dyn_cast<ClassTemplateSpecializationDecl>(D2);
+  if (Spec1 && Spec2) {
+    // Check that the specialized templates are the same.
+    if (!IsStructurallyEquivalent(Context, Spec1->getSpecializedTemplate(),
+                                  Spec2->getSpecializedTemplate()))
+      return false;
+    
+    // Check that the template arguments are the same.
+    if (Spec1->getTemplateArgs().size() != Spec2->getTemplateArgs().size())
+      return false;
+    
+    for (unsigned I = 0, N = Spec1->getTemplateArgs().size(); I != N; ++I)
+      if (!IsStructurallyEquivalent(Context, 
+                                    Spec1->getTemplateArgs().get(I),
+                                    Spec2->getTemplateArgs().get(I)))
+        return false;
+  }  
+  // If one is a class template specialization and the other is not, these
+  // structures are different.
+  else if (Spec1 || Spec2)
+    return false;
+
   // Compare the definitions of these two records. If either or both are
   // incomplete, we assume that they are equivalent.
   D1 = D1->getDefinition();
@@ -1450,6 +1526,30 @@
   return Importer.getToContext().getTagDeclType(ToDecl);
 }
 
+/// \brief Import a template specialization type; a null QualType means failure.
+QualType ASTNodeImporter::VisitTemplateSpecializationType(
+                                                TemplateSpecializationType *T) {
+  TemplateName ToTemplate = Importer.Import(T->getTemplateName());
+  if (ToTemplate.isNull())
+    return QualType();
+  llvm::SmallVector<TemplateArgument, 2> ToTemplateArgs;
+  if (ImportTemplateArguments(T->getArgs(), T->getNumArgs(), ToTemplateArgs))
+    return QualType();
+  // A non-canonical type also needs its canonical type imported.
+  QualType ToCanonType;
+  if (!QualType(T, 0).isCanonical()) {
+    QualType FromCanonType 
+      = Importer.getFromContext().getCanonicalType(QualType(T, 0));
+    ToCanonType =Importer.Import(FromCanonType);
+    if (ToCanonType.isNull())
+      return QualType();
+  }
+  return Importer.getToContext().getTemplateSpecializationType(ToTemplate, 
+                                                         ToTemplateArgs.data(), 
+                                                         ToTemplateArgs.size(),
+                                                               ToCanonType);
+}
+
 QualType ASTNodeImporter::VisitElaboratedType(ElaboratedType *T) {
   NestedNameSpecifier *ToQualifier = 0;
   // Note: the qualifier in an ElaboratedType is optional.
@@ -1576,6 +1676,43 @@
     Importer.Import(*From);
 }
 
+/// \brief Import the definition of \p From into \p To: its base classes (for
+/// C++ records) and then its declaration context. Returns true on failure and
+/// false on success, including when \p To already has a definition.
+bool ASTNodeImporter::ImportDefinition(RecordDecl *From, RecordDecl *To) {
+  if (To->getDefinition())
+    return false;
+  
+  To->startDefinition();
+  
+  // Add base classes.
+  if (CXXRecordDecl *ToCXX = dyn_cast<CXXRecordDecl>(To)) {
+    CXXRecordDecl *FromCXX = cast<CXXRecordDecl>(From);
+    llvm::SmallVector<CXXBaseSpecifier *, 4> Bases;
+    for (CXXRecordDecl::base_class_iterator Base1 = FromCXX->bases_begin(),
+                                            FromBaseEnd = FromCXX->bases_end();
+         Base1 != FromBaseEnd; ++Base1) {
+      // Callers treat a true result as an error, so report failure that way.
+      QualType T = Importer.Import(Base1->getType());
+      if (T.isNull())
+        return true;
+      Bases.push_back(
+                    new (Importer.getToContext()) 
+                      CXXBaseSpecifier(Importer.Import(Base1->getSourceRange()),
+                                       Base1->isVirtual(),
+                                       Base1->isBaseOfClass(),
+                                       Base1->getAccessSpecifierAsWritten(),
+                                       Importer.Import(Base1->getTypeSourceInfo())));
+    }
+    if (!Bases.empty())
+      ToCXX->setBases(Bases.data(), Bases.size());
+  }
+  
+  ImportDeclContext(From);
+  To->completeDefinition();
+  return false;
+}
+
 TemplateParameterList *ASTNodeImporter::ImportTemplateParameterList(
                                                 TemplateParameterList *Params) {
   llvm::SmallVector<NamedDecl *, 4> ToParams;
@@ -1597,6 +1734,75 @@
                                        Importer.Import(Params->getRAngleLoc()));
 }
 
+/// \brief Import \p From; yields a null argument if any part fails to import.
+TemplateArgument 
+ASTNodeImporter::ImportTemplateArgument(const TemplateArgument &From) {
+  switch (From.getKind()) {
+  case TemplateArgument::Null:
+    return TemplateArgument();
+     
+  case TemplateArgument::Type: {
+    QualType ToType = Importer.Import(From.getAsType());
+    if (ToType.isNull())
+      return TemplateArgument();
+    return TemplateArgument(ToType);
+  }
+      
+  case TemplateArgument::Integral: {
+    QualType ToType = Importer.Import(From.getIntegralType());
+    if (ToType.isNull())
+      return TemplateArgument();
+    return TemplateArgument(*From.getAsIntegral(), ToType);
+  }
+
+  case TemplateArgument::Declaration:
+    if (Decl *To = Importer.Import(From.getAsDecl()))
+      return TemplateArgument(To);
+    return TemplateArgument();
+      
+  case TemplateArgument::Template: {
+    TemplateName ToTemplate = Importer.Import(From.getAsTemplate());
+    if (ToTemplate.isNull())
+      return TemplateArgument();
+    return TemplateArgument(ToTemplate);
+  }
+      
+  case TemplateArgument::Expression:
+    if (Expr *ToExpr = Importer.Import(From.getAsExpr()))
+      return TemplateArgument(ToExpr);
+    return TemplateArgument();
+      
+  case TemplateArgument::Pack: {
+    llvm::SmallVector<TemplateArgument, 2> ToPack;
+    ToPack.reserve(From.pack_size());
+    if (ImportTemplateArguments(From.pack_begin(), From.pack_size(), ToPack))
+      return TemplateArgument();
+    // Copy the imported pack into an array allocated via the "to" context.
+    TemplateArgument *ToArgs 
+      = new (Importer.getToContext()) TemplateArgument[ToPack.size()];
+    std::copy(ToPack.begin(), ToPack.end(), ToArgs);
+    return TemplateArgument(ToArgs, ToPack.size());
+  }
+  }
+  
+  llvm_unreachable("Invalid template argument kind");
+  return TemplateArgument();
+}
+
+/// \brief Import each argument, appending to \p ToArgs; returns true on failure.
+bool ASTNodeImporter::ImportTemplateArguments(const TemplateArgument *FromArgs,
+                                              unsigned NumFromArgs,
+                              llvm::SmallVectorImpl<TemplateArgument> &ToArgs) {
+  for (unsigned I = 0; I != NumFromArgs; ++I) {
+    TemplateArgument To = ImportTemplateArgument(FromArgs[I]);
+    // A null result is an error only when the source argument was non-null.
+    if (To.isNull() && !FromArgs[I].isNull())
+      return true;
+    ToArgs.push_back(To);
+  }
+  return false;
+}
+
 bool ASTNodeImporter::IsStructuralMatch(RecordDecl *FromRecord, 
                                         RecordDecl *ToRecord) {
   StructuralEquivalenceContext Ctx(Importer.getFromContext(),
@@ -1939,38 +2145,8 @@
   
   Importer.Imported(D, D2);
 
-  if (D->isDefinition()) {
-    D2->startDefinition();
-
-    // Add base classes.
-    if (CXXRecordDecl *D2CXX = dyn_cast<CXXRecordDecl>(D2)) {
-      CXXRecordDecl *D1CXX = cast<CXXRecordDecl>(D);
-
-      llvm::SmallVector<CXXBaseSpecifier *, 4> Bases;
-      for (CXXRecordDecl::base_class_iterator 
-                Base1 = D1CXX->bases_begin(),
-             FromBaseEnd = D1CXX->bases_end();
-           Base1 != FromBaseEnd;
-           ++Base1) {
-        QualType T = Importer.Import(Base1->getType());
-        if (T.isNull())
-          return 0;
-          
-        Bases.push_back(
-          new (Importer.getToContext()) 
-                CXXBaseSpecifier(Importer.Import(Base1->getSourceRange()),
-                                 Base1->isVirtual(),
-                                 Base1->isBaseOfClass(),
-                                 Base1->getAccessSpecifierAsWritten(),
-                                 Importer.Import(Base1->getTypeSourceInfo())));
-      }
-      if (!Bases.empty())
-        D2CXX->setBases(Bases.data(), Bases.size());
-    }
-
-    ImportDeclContext(D);
-    D2->completeDefinition();
-  }
+  if (D->isDefinition() && ImportDefinition(D, D2))
+    return 0;
   
   return D2;
 }
@@ -3166,6 +3342,100 @@
   return D2;
 }
 
+/// \brief Import a class template specialization or reuse an existing one.
+Decl *ASTNodeImporter::VisitClassTemplateSpecializationDecl(
+                                          ClassTemplateSpecializationDecl *D) {
+  // If this record has a definition in the translation unit we're coming from,
+  // but this particular declaration is not that definition, import the
+  // definition and map to that.
+  TagDecl *Definition = D->getDefinition();
+  if (Definition && Definition != D) {
+    Decl *ImportedDef = Importer.Import(Definition);
+    if (!ImportedDef)
+      return 0;
+    
+    return Importer.Imported(D, ImportedDef);
+  }
+
+  ClassTemplateDecl *ClassTemplate
+    = cast_or_null<ClassTemplateDecl>(Importer.Import(
+                                                 D->getSpecializedTemplate()));
+  if (!ClassTemplate)
+    return 0;
+  
+  // Import the context of this declaration.
+  DeclContext *DC = ClassTemplate->getDeclContext();
+  if (!DC)
+    return 0;
+  
+  DeclContext *LexicalDC = DC;
+  if (D->getDeclContext() != D->getLexicalDeclContext()) {
+    LexicalDC = Importer.ImportContext(D->getLexicalDeclContext());
+    if (!LexicalDC)
+      return 0;
+  }
+  
+  // Import the location of this declaration.
+  SourceLocation Loc = Importer.Import(D->getLocation());
+
+  // Import template arguments.
+  llvm::SmallVector<TemplateArgument, 2> TemplateArgs;
+  if (ImportTemplateArguments(D->getTemplateArgs().data(), 
+                              D->getTemplateArgs().size(),
+                              TemplateArgs))
+    return 0;
+  
+  // Try to find an existing specialization with these template arguments.
+  void *InsertPos = 0;
+  ClassTemplateSpecializationDecl *D2
+    = ClassTemplate->findSpecialization(TemplateArgs.data(), 
+                                        TemplateArgs.size(), InsertPos);
+  if (D2) {
+    // We already have a class template specialization with these template
+    // arguments.
+    
+    // FIXME: Check for specialization vs. instantiation errors.
+    
+    if (RecordDecl *FoundDef = D2->getDefinition()) {
+      if (!D->isDefinition() || IsStructuralMatch(D, FoundDef)) {
+        // The record types structurally match, or the "from" translation
+        // unit only had a forward declaration anyway; treat them as the
+        // same specialization.
+        return Importer.Imported(D, FoundDef);
+      }
+    }
+  } else {
+    // Create a new specialization.
+    D2 = ClassTemplateSpecializationDecl::Create(Importer.getToContext(), 
+                                                 D->getTagKind(), DC, 
+                                                 Loc, ClassTemplate,
+                                                 TemplateArgs.data(), 
+                                                 TemplateArgs.size(), 
+                                                 /*PrevDecl=*/0);
+    D2->setSpecializationKind(D->getSpecializationKind());
+
+    // Add this specialization to the class template.
+    ClassTemplate->AddSpecialization(D2, InsertPos);
+    
+    // Import the qualifier, if any.
+    if (D->getQualifier()) {
+      NestedNameSpecifier *NNS = Importer.Import(D->getQualifier());
+      SourceRange NNSRange = Importer.Import(D->getQualifierRange());
+      D2->setQualifierInfo(NNS, NNSRange);
+    }
+    
+    // Add the specialization to this context.
+    D2->setLexicalDeclContext(LexicalDC);
+    LexicalDC->addDecl(D2);
+  }
+  Importer.Imported(D, D2);
+  
+  if (D->isDefinition() && ImportDefinition(D, D2))
+    return 0;
+  
+  return D2;
+}
+
 //----------------------------------------------------------------------------
 // Import Statements
 //----------------------------------------------------------------------------
@@ -3506,6 +3776,64 @@
   return 0;
 }
 
+/// \brief Import a template name; a null TemplateName signals failure.
+TemplateName ASTImporter::Import(TemplateName From) {
+  switch (From.getKind()) {
+  case TemplateName::Template:
+    if (TemplateDecl *ToTemplate
+                = cast_or_null<TemplateDecl>(Import(From.getAsTemplateDecl())))
+      return TemplateName(ToTemplate);
+    return TemplateName();
+      
+  case TemplateName::OverloadedTemplate: {
+    OverloadedTemplateStorage *FromStorage = From.getAsOverloadedTemplate();
+    UnresolvedSet<2> ToTemplates;
+    for (OverloadedTemplateStorage::iterator I = FromStorage->begin(),
+                                             E = FromStorage->end();
+         I != E; ++I) {
+      if (NamedDecl *To = cast_or_null<NamedDecl>(Import(*I))) 
+        ToTemplates.addDecl(To);
+      else
+        return TemplateName();
+    }
+    return ToContext.getOverloadedTemplateName(ToTemplates.begin(), 
+                                               ToTemplates.end());
+  }
+      
+  case TemplateName::QualifiedTemplate: {
+    QualifiedTemplateName *QTN = From.getAsQualifiedTemplateName();
+    NestedNameSpecifier *Qualifier = Import(QTN->getQualifier());
+    if (!Qualifier)
+      return TemplateName();
+    
+    if (TemplateDecl *ToTemplate
+        = cast_or_null<TemplateDecl>(Import(From.getAsTemplateDecl())))
+      return ToContext.getQualifiedTemplateName(Qualifier, 
+                                                QTN->hasTemplateKeyword(), 
+                                                ToTemplate);
+    
+    return TemplateName();
+  }
+  
+  case TemplateName::DependentTemplate: {
+    DependentTemplateName *DTN = From.getAsDependentTemplateName();
+    NestedNameSpecifier *Qualifier = Import(DTN->getQualifier());
+    if (!Qualifier)
+      return TemplateName();
+    // A dependent template name is either an identifier or an operator.
+    if (DTN->isIdentifier()) {
+      return ToContext.getDependentTemplateName(Qualifier, 
+                                                Import(DTN->getIdentifier()));
+    }
+    
+    return ToContext.getDependentTemplateName(Qualifier, DTN->getOperator());
+  }
+  }
+  
+  llvm_unreachable("Invalid template name kind");
+  return TemplateName();
+}
+
 SourceLocation ASTImporter::Import(SourceLocation FromLoc) {
   if (FromLoc.isInvalid())
     return SourceLocation();
@@ -3623,7 +3951,7 @@
   return DeclarationName();
 }
 
-IdentifierInfo *ASTImporter::Import(IdentifierInfo *FromId) {
+IdentifierInfo *ASTImporter::Import(const IdentifierInfo *FromId) {
   if (!FromId)
     return 0;