[ASTImporter] Support LambdaExprs and improve template support

Also, a number of style and bug fixes were made:

 *  ASTImporterTest: add a sanity check for the source node
 *  ExternalASTMerger: improve lookup of template specializations
 *  ASTImporter: don't add templated declarations into DeclContext
 *  ASTImporter: introduce an ImportTemplateArgumentListInfo overload that also imports the LAngle/RAngle SourceLocations
 *  ASTImporter: properly set ParmVarDecls for imported FunctionProtoTypeLoc

Differential Revision: https://reviews.llvm.org/D42301

llvm-svn: 323519
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index 27b6ff1..aea044c 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -97,6 +97,8 @@
     typedef DesignatedInitExpr::Designator Designator;
     Designator ImportDesignator(const Designator &D);
 
+    Optional<LambdaCapture> ImportLambdaCapture(const LambdaCapture &From);
+
                         
     /// \brief What we should import from the definition.
     enum ImportDefinitionKind { 
@@ -127,16 +129,26 @@
     bool ImportDefinition(ObjCProtocolDecl *From, ObjCProtocolDecl *To,
                           ImportDefinitionKind Kind = IDK_Default);
     TemplateParameterList *ImportTemplateParameterList(
-                                                 TemplateParameterList *Params);
+        TemplateParameterList *Params);
     TemplateArgument ImportTemplateArgument(const TemplateArgument &From);
     Optional<TemplateArgumentLoc> ImportTemplateArgumentLoc(
         const TemplateArgumentLoc &TALoc);
     bool ImportTemplateArguments(const TemplateArgument *FromArgs,
                                  unsigned NumFromArgs,
-                               SmallVectorImpl<TemplateArgument> &ToArgs);
+                                 SmallVectorImpl<TemplateArgument> &ToArgs);
+
     template <typename InContainerTy>
     bool ImportTemplateArgumentListInfo(const InContainerTy &Container,
                                         TemplateArgumentListInfo &ToTAInfo);
+
+    template<typename InContainerTy>
+    bool ImportTemplateArgumentListInfo(SourceLocation FromLAngleLoc,
+                                        SourceLocation FromRAngleLoc,
+                                        const InContainerTy &Container,
+                                        TemplateArgumentListInfo &Result);
+
+    bool ImportTemplateInformation(FunctionDecl *FromFD, FunctionDecl *ToFD);
+
     bool IsStructuralMatch(RecordDecl *FromRecord, RecordDecl *ToRecord,
                            bool Complain = true);
     bool IsStructuralMatch(VarDecl *FromVar, VarDecl *ToVar,
@@ -295,6 +307,7 @@
     Expr *VisitCXXPseudoDestructorExpr(CXXPseudoDestructorExpr *E);
     Expr *VisitMemberExpr(MemberExpr *E);
     Expr *VisitCallExpr(CallExpr *E);
+    Expr *VisitLambdaExpr(LambdaExpr *LE);
     Expr *VisitInitListExpr(InitListExpr *E);
     Expr *VisitArrayInitLoopExpr(ArrayInitLoopExpr *E);
     Expr *VisitArrayInitIndexExpr(ArrayInitIndexExpr *E);
@@ -1045,7 +1058,6 @@
       = FromData.HasDeclaredCopyConstructorWithConstParam;
     ToData.HasDeclaredCopyAssignmentWithConstParam
       = FromData.HasDeclaredCopyAssignmentWithConstParam;
-    ToData.IsLambda = FromData.IsLambda;
 
     SmallVector<CXXBaseSpecifier *, 4> Bases;
     for (const auto &Base1 : FromCXX->bases()) {
@@ -1256,6 +1268,9 @@
   return false;
 }
 
+// We cannot use Optional<> pattern here and below because
+// TemplateArgumentListInfo's operator new is declared as deleted so it cannot
+// be stored in Optional.
 template <typename InContainerTy>
 bool ASTNodeImporter::ImportTemplateArgumentListInfo(
     const InContainerTy &Container, TemplateArgumentListInfo &ToTAInfo) {
@@ -1268,6 +1283,18 @@
   return false;
 }
 
+template <typename InContainerTy>
+bool ASTNodeImporter::ImportTemplateArgumentListInfo(
+    SourceLocation FromLAngleLoc, SourceLocation FromRAngleLoc,
+    const InContainerTy &Container, TemplateArgumentListInfo &Result) {
+  TemplateArgumentListInfo ToTAInfo(Importer.Import(FromLAngleLoc),
+                                    Importer.Import(FromRAngleLoc));
+  if (ImportTemplateArgumentListInfo(Container, ToTAInfo))
+    return true;
+  Result = ToTAInfo;
+  return false;
+}
+
 bool ASTNodeImporter::IsStructuralMatch(RecordDecl *FromRecord, 
                                         RecordDecl *ToRecord, bool Complain) {
   // Eliminate a potential failure point where we attempt to re-import
@@ -1918,16 +1945,16 @@
         if (DCXX->getLambdaContextDecl() && !CDecl)
           return nullptr;
         D2CXX->setLambdaMangling(DCXX->getLambdaManglingNumber(), CDecl);
-      } else if (DCXX->isInjectedClassName()) {                                                 
-        // We have to be careful to do a similar dance to the one in                            
-        // Sema::ActOnStartCXXMemberDeclarations                                                
-        CXXRecordDecl *const PrevDecl = nullptr;                                                
-        const bool DelayTypeCreation = true;                                                    
-        D2CXX = CXXRecordDecl::Create(                                                          
-            Importer.getToContext(), D->getTagKind(), DC, StartLoc, Loc,                        
-            Name.getAsIdentifierInfo(), PrevDecl, DelayTypeCreation);                           
-        Importer.getToContext().getTypeDeclType(                                                
-            D2CXX, llvm::dyn_cast<CXXRecordDecl>(DC));                                          
+      } else if (DCXX->isInjectedClassName()) {
+        // We have to be careful to do a similar dance to the one in
+        // Sema::ActOnStartCXXMemberDeclarations
+        CXXRecordDecl *const PrevDecl = nullptr;
+        const bool DelayTypeCreation = true;
+        D2CXX = CXXRecordDecl::Create(
+            Importer.getToContext(), D->getTagKind(), DC, StartLoc, Loc,
+            Name.getAsIdentifierInfo(), PrevDecl, DelayTypeCreation);
+        Importer.getToContext().getTypeDeclType(
+            D2CXX, llvm::dyn_cast<CXXRecordDecl>(DC));
       } else {
         D2CXX = CXXRecordDecl::Create(Importer.getToContext(),
                                       D->getTagKind(),
@@ -1936,6 +1963,9 @@
       }
       D2 = D2CXX;
       D2->setAccess(D->getAccess());
+      D2->setLexicalDeclContext(LexicalDC);
+      if (!DCXX->getDescribedClassTemplate())
+        LexicalDC->addDeclInternal(D2);
 
       Importer.Imported(D, D2);
 
@@ -1964,11 +1994,11 @@
     } else {
       D2 = RecordDecl::Create(Importer.getToContext(), D->getTagKind(),
                               DC, StartLoc, Loc, Name.getAsIdentifierInfo());
+      D2->setLexicalDeclContext(LexicalDC);
+      LexicalDC->addDeclInternal(D2);
     }
     
     D2->setQualifierInfo(Importer.Import(D->getQualifierLoc()));
-    D2->setLexicalDeclContext(LexicalDC);
-    LexicalDC->addDeclInternal(D2);
     if (D->isAnonymousStructOrUnion())
       D2->setAnonymousStructOrUnion(true);
     if (PrevDecl) {
@@ -2044,6 +2074,94 @@
   return ToEnumerator;
 }
 
+bool ASTNodeImporter::ImportTemplateInformation(FunctionDecl *FromFD,
+                                                FunctionDecl *ToFD) {
+  switch (FromFD->getTemplatedKind()) {
+  case FunctionDecl::TK_NonTemplate:
+  case FunctionDecl::TK_FunctionTemplate:
+    break;
+
+  case FunctionDecl::TK_MemberSpecialization: {
+    auto *InstFD = cast_or_null<FunctionDecl>(
+          Importer.Import(FromFD->getInstantiatedFromMemberFunction()));
+    if (!InstFD)
+      return true;
+
+    TemplateSpecializationKind TSK = FromFD->getTemplateSpecializationKind();
+    SourceLocation POI = Importer.Import(
+          FromFD->getMemberSpecializationInfo()->getPointOfInstantiation());
+    ToFD->setInstantiationOfMemberFunction(InstFD, TSK);
+    ToFD->getMemberSpecializationInfo()->setPointOfInstantiation(POI);
+    break;
+  }
+
+  case FunctionDecl::TK_FunctionTemplateSpecialization: {
+    auto *FTSInfo = FromFD->getTemplateSpecializationInfo();
+    auto *Template = cast_or_null<FunctionTemplateDecl>(
+        Importer.Import(FTSInfo->getTemplate()));
+    if (!Template)
+      return true;
+    TemplateSpecializationKind TSK = FTSInfo->getTemplateSpecializationKind();
+
+    // Import template arguments.
+    auto TemplArgs = FTSInfo->TemplateArguments->asArray();
+    SmallVector<TemplateArgument, 8> ToTemplArgs;
+    if (ImportTemplateArguments(TemplArgs.data(), TemplArgs.size(),
+                                ToTemplArgs))
+      return true;
+
+    TemplateArgumentList *ToTAList = TemplateArgumentList::CreateCopy(
+          Importer.getToContext(), ToTemplArgs);
+
+    TemplateArgumentListInfo ToTAInfo;
+    const auto *FromTAArgsAsWritten = FTSInfo->TemplateArgumentsAsWritten;
+    if (FromTAArgsAsWritten) {
+      if (ImportTemplateArgumentListInfo(
+              FromTAArgsAsWritten->LAngleLoc, FromTAArgsAsWritten->RAngleLoc,
+              FromTAArgsAsWritten->arguments(), ToTAInfo))
+        return true;
+    }
+
+    SourceLocation POI = Importer.Import(FTSInfo->getPointOfInstantiation());
+
+    ToFD->setFunctionTemplateSpecialization(
+        Template, ToTAList, /* InsertPos= */ nullptr,
+        TSK, FromTAArgsAsWritten ? &ToTAInfo : nullptr, POI);
+    break;
+  }
+
+  case FunctionDecl::TK_DependentFunctionTemplateSpecialization: {
+    auto *FromInfo = FromFD->getDependentSpecializationInfo();
+    UnresolvedSet<8> TemplDecls;
+    unsigned NumTemplates = FromInfo->getNumTemplates();
+    for (unsigned I = 0; I < NumTemplates; I++) {
+      if (auto *ToFTD = cast_or_null<FunctionTemplateDecl>(
+              Importer.Import(FromInfo->getTemplate(I))))
+        TemplDecls.addDecl(ToFTD);
+      else
+        return true;
+    }
+
+    // Import TemplateArgumentListInfo.
+    TemplateArgumentListInfo ToTAInfo;
+    if (ImportTemplateArgumentListInfo(
+            FromInfo->getLAngleLoc(), FromInfo->getRAngleLoc(),
+            llvm::makeArrayRef(FromInfo->getTemplateArgs(),
+                               FromInfo->getNumTemplateArgs()),
+            ToTAInfo))
+      return true;
+
+    ToFD->setDependentTemplateSpecialization(Importer.getToContext(),
+                                             TemplDecls, ToTAInfo);
+    break;
+  }
+  default:
+    llvm_unreachable("All cases should be covered!");
+  }
+
+  return false;
+}
+
 Decl *ASTNodeImporter::VisitFunctionDecl(FunctionDecl *D) {
   // Import the major distinguishing characteristics of this function.
   DeclContext *DC, *LexicalDC;
@@ -2151,15 +2269,18 @@
     Parameters.push_back(ToP);
   }
   
-  // Create the imported function.
   TypeSourceInfo *TInfo = Importer.Import(D->getTypeSourceInfo());
+  if (D->getTypeSourceInfo() && !TInfo)
+    return nullptr;
+
+  // Create the imported function.
   FunctionDecl *ToFunction = nullptr;
   SourceLocation InnerLocStart = Importer.Import(D->getInnerLocStart());
   if (CXXConstructorDecl *FromConstructor = dyn_cast<CXXConstructorDecl>(D)) {
     ToFunction = CXXConstructorDecl::Create(Importer.getToContext(),
                                             cast<CXXRecordDecl>(DC),
                                             InnerLocStart,
-                                            NameInfo, T, TInfo, 
+                                            NameInfo, T, TInfo,
                                             FromConstructor->isExplicit(),
                                             D->isInlineSpecified(), 
                                             D->isImplicit(),
@@ -2225,9 +2346,9 @@
   Importer.Imported(D, ToFunction);
 
   // Set the parameters.
-  for (unsigned I = 0, N = Parameters.size(); I != N; ++I) {
-    Parameters[I]->setOwningFunction(ToFunction);
-    ToFunction->addDeclInternal(Parameters[I]);
+  for (ParmVarDecl *Param : Parameters) {
+    Param->setOwningFunction(ToFunction);
+    ToFunction->addDeclInternal(Param);
   }
   ToFunction->setParams(Parameters);
 
@@ -2237,6 +2358,16 @@
     ToFunction->setPreviousDecl(Recent);
   }
 
+  // Complete the creation of the FunctionProtoTypeLoc manually by setting the
+  // ParmVarDecls it refers to.
+  if (TInfo) {
+    if (auto ProtoLoc =
+        TInfo->getTypeLoc().IgnoreParens().getAs<FunctionProtoTypeLoc>()) {
+      for (unsigned I = 0, N = Parameters.size(); I != N; ++I)
+        ProtoLoc.setParam(I, Parameters[I]);
+    }
+  }
+
   if (usedDifferentExceptionSpec) {
     // Update FunctionProtoType::ExtProtoInfo.
     QualType T = Importer.Import(D->getType());
@@ -2254,8 +2385,17 @@
 
   // FIXME: Other bits to merge?
 
+  // If it is a template, import all related things.
+  if (ImportTemplateInformation(D, ToFunction))
+    return nullptr;
+
   // Add this function to the lexical context.
-  LexicalDC->addDeclInternal(ToFunction);
+  // NOTE: If the function is a templated declaration, it should not be added
+  // into LexicalDC. However, the described template is imported later, during
+  // the import of the FunctionTemplateDecl, so we use the source declaration
+  // to decide whether the resulting function should be added.
+  if (!D->getDescribedFunctionTemplate())
+    LexicalDC->addDeclInternal(ToFunction);
 
   if (auto *FromCXXMethod = dyn_cast<CXXMethodDecl>(D))
     ImportOverrides(cast<CXXMethodDecl>(ToFunction), FromCXXMethod);
@@ -2749,6 +2889,14 @@
   if (FromDefArg && !ToDefArg)
     return nullptr;
 
+  if (D->isObjCMethodParameter()) {
+    ToParm->setObjCMethodScopeInfo(D->getFunctionScopeIndex());
+    ToParm->setObjCDeclQualifier(D->getObjCDeclQualifier());
+  } else {
+    ToParm->setScopeInfo(D->getFunctionScopeDepth(),
+                         D->getFunctionScopeIndex());
+  }
+
   if (D->isUsed())
     ToParm->setIsUsed();
 
@@ -3850,12 +3998,12 @@
       return nullptr;
   }
 
-  CXXRecordDecl *DTemplated = D->getTemplatedDecl();
-  
+  CXXRecordDecl *FromTemplated = D->getTemplatedDecl();
+
   // Create the declaration that is being templated.
-  CXXRecordDecl *D2Templated = cast_or_null<CXXRecordDecl>(
-        Importer.Import(DTemplated));
-  if (!D2Templated)
+  CXXRecordDecl *ToTemplated = cast_or_null<CXXRecordDecl>(
+        Importer.Import(FromTemplated));
+  if (!ToTemplated)
     return nullptr;
 
   // Resolve possible cyclic import.
@@ -3863,15 +4011,15 @@
     return AlreadyImported;
 
   // Create the class template declaration itself.
-  TemplateParameterList *TemplateParams
-    = ImportTemplateParameterList(D->getTemplateParameters());
+  TemplateParameterList *TemplateParams =
+      ImportTemplateParameterList(D->getTemplateParameters());
   if (!TemplateParams)
     return nullptr;
 
   ClassTemplateDecl *D2 = ClassTemplateDecl::Create(Importer.getToContext(), DC, 
                                                     Loc, Name, TemplateParams, 
-                                                    D2Templated);
-  D2Templated->setDescribedClassTemplate(D2);    
+                                                    ToTemplated);
+  ToTemplated->setDescribedClassTemplate(D2);
   
   D2->setAccess(D->getAccess());
   D2->setLexicalDeclContext(LexicalDC);
@@ -3879,10 +4027,10 @@
   
   // Note the relationship between the class templates.
   Importer.Imported(D, D2);
-  Importer.Imported(DTemplated, D2Templated);
+  Importer.Imported(FromTemplated, ToTemplated);
 
-  if (DTemplated->isCompleteDefinition() &&
-      !D2Templated->isCompleteDefinition()) {
+  if (FromTemplated->isCompleteDefinition() &&
+      !ToTemplated->isCompleteDefinition()) {
     // FIXME: Import definition!
   }
   
@@ -3958,12 +4106,8 @@
       // Import TemplateArgumentListInfo
       TemplateArgumentListInfo ToTAInfo;
       auto &ASTTemplateArgs = *PartialSpec->getTemplateArgsAsWritten();
-      for (unsigned I = 0, E = ASTTemplateArgs.NumTemplateArgs; I < E; ++I) {
-        if (auto ToLoc = ImportTemplateArgumentLoc(ASTTemplateArgs[I]))
-          ToTAInfo.addArgument(*ToLoc);
-        else
-          return nullptr;
-      }
+      if (ImportTemplateArgumentListInfo(ASTTemplateArgs.arguments(), ToTAInfo))
+        return nullptr;
 
       QualType CanonInjType = Importer.Import(
             PartialSpec->getInjectedSpecializationType());
@@ -4901,12 +5045,8 @@
   TemplateArgumentListInfo ToTAInfo;
   TemplateArgumentListInfo *ResInfo = nullptr;
   if (E->hasExplicitTemplateArgs()) {
-    for (const auto &FromLoc : E->template_arguments()) {
-      if (auto ToTALoc = ImportTemplateArgumentLoc(FromLoc))
-        ToTAInfo.addArgument(*ToTALoc);
-      else
-        return nullptr;
-    }
+    if (ImportTemplateArgumentListInfo(E->template_arguments(), ToTAInfo))
+      return nullptr;
     ResInfo = &ToTAInfo;
   }
 
@@ -5861,11 +6001,10 @@
   if (BaseType.isNull())
     return nullptr;
 
-  TemplateArgumentListInfo ToTAInfo(Importer.Import(E->getLAngleLoc()),
-                                    Importer.Import(E->getRAngleLoc()));
-  TemplateArgumentListInfo *ResInfo = nullptr;
+  TemplateArgumentListInfo ToTAInfo, *ResInfo = nullptr;
   if (E->hasExplicitTemplateArgs()) {
-    if (ImportTemplateArgumentListInfo(E->template_arguments(), ToTAInfo))
+    if (ImportTemplateArgumentListInfo(E->getLAngleLoc(), E->getRAngleLoc(),
+                                       E->template_arguments(), ToTAInfo))
       return nullptr;
     ResInfo = &ToTAInfo;
   }
@@ -5926,11 +6065,10 @@
       return nullptr;
   }
 
-  TemplateArgumentListInfo ToTAInfo(Importer.Import(E->getLAngleLoc()),
-                                    Importer.Import(E->getRAngleLoc()));
-  TemplateArgumentListInfo *ResInfo = nullptr;
+  TemplateArgumentListInfo ToTAInfo, *ResInfo = nullptr;
   if (E->hasExplicitTemplateArgs()) {
-    if (ImportTemplateArgumentListInfo(E->template_arguments(), ToTAInfo))
+    if (ImportTemplateArgumentListInfo(E->getLAngleLoc(), E->getRAngleLoc(),
+                                       E->template_arguments(), ToTAInfo))
       return nullptr;
     ResInfo = &ToTAInfo;
   }
@@ -5981,6 +6119,73 @@
              Importer.Import(E->getRParenLoc()));
 }
 
+Optional<LambdaCapture>
+ASTNodeImporter::ImportLambdaCapture(const LambdaCapture &From) {
+  VarDecl *Var = nullptr;
+  if (From.capturesVariable()) {
+    Var = cast_or_null<VarDecl>(Importer.Import(From.getCapturedVar()));
+    if (!Var)
+      return None;
+  }
+
+  return LambdaCapture(Importer.Import(From.getLocation()), From.isImplicit(),
+                       From.getCaptureKind(), Var,
+                       From.isPackExpansion()
+                         ? Importer.Import(From.getEllipsisLoc())
+                         : SourceLocation());
+}
+
+Expr *ASTNodeImporter::VisitLambdaExpr(LambdaExpr *LE) {
+  CXXRecordDecl *FromClass = LE->getLambdaClass();
+  auto *ToClass = dyn_cast_or_null<CXXRecordDecl>(Importer.Import(FromClass));
+  if (!ToClass)
+    return nullptr;
+
+  // NOTE: lambda classes are created with the BeingDefined flag set, so
+  // ImportDefinition doesn't work for them; import their fields manually
+  // instead.
+  if (ToClass->isBeingDefined()) {
+    for (auto FromField : FromClass->fields()) {
+      auto *ToField = cast_or_null<FieldDecl>(Importer.Import(FromField));
+      if (!ToField)
+        return nullptr;
+    }
+  }
+
+  auto *ToCallOp = dyn_cast_or_null<CXXMethodDecl>(
+        Importer.Import(LE->getCallOperator()));
+  if (!ToCallOp)
+    return nullptr;
+
+  ToClass->completeDefinition();
+
+  unsigned NumCaptures = LE->capture_size();
+  SmallVector<LambdaCapture, 8> Captures;
+  Captures.reserve(NumCaptures);
+  for (const auto &FromCapture : LE->captures()) {
+    if (auto ToCapture = ImportLambdaCapture(FromCapture))
+      Captures.push_back(*ToCapture);
+    else
+      return nullptr;
+  }
+
+  SmallVector<Expr *, 8> InitCaptures(NumCaptures);
+  if (ImportContainerChecked(LE->capture_inits(), InitCaptures))
+    return nullptr;
+
+  return LambdaExpr::Create(Importer.getToContext(), ToClass,
+                            Importer.Import(LE->getIntroducerRange()),
+                            LE->getCaptureDefault(),
+                            Importer.Import(LE->getCaptureDefaultLoc()),
+                            Captures,
+                            LE->hasExplicitParameters(),
+                            LE->hasExplicitResultType(),
+                            InitCaptures,
+                            Importer.Import(LE->getLocEnd()),
+                            LE->containsUnexpandedParameterPack());
+}
+
+
 Expr *ASTNodeImporter::VisitInitListExpr(InitListExpr *ILE) {
   QualType T = Importer.Import(ILE->getType());
   if (T.isNull())