Implement template argument deduction for pack expansions whose
pattern is a template argument. Deduction repeatedly matches the
pattern of the pack expansion against each remaining template argument,
then bundles the deductions for each expanded parameter pack into an
argument pack.

We can now handle a variety of simple list-handling metaprograms using
variadic templates. See, e.g., the new "count" metaprogram.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@122439 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/SemaTemplateDeduction.cpp b/lib/Sema/SemaTemplateDeduction.cpp
index 80ba1a3..d5a036d 100644
--- a/lib/Sema/SemaTemplateDeduction.cpp
+++ b/lib/Sema/SemaTemplateDeduction.cpp
@@ -21,6 +21,7 @@
 #include "clang/AST/StmtVisitor.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
+#include "llvm/ADT/BitVector.h"
 #include <algorithm>
 
 namespace clang {
@@ -947,6 +948,24 @@
   return ArgIdx < NumArgs;
 }
 
+/// \brief Return the (depth, index) pair of the template parameter named by an unexpanded parameter pack.
+static std::pair<unsigned, unsigned> 
+getDepthAndIndex(UnexpandedParameterPack UPP) {
+  if (const TemplateTypeParmType *TTP // pack referenced as a type
+                          = UPP.first.dyn_cast<const TemplateTypeParmType *>())
+    return std::make_pair(TTP->getDepth(), TTP->getIndex());
+  // Otherwise, a type parameter pack referenced through its declaration.
+  if (TemplateTypeParmDecl *TTP = UPP.first.dyn_cast<TemplateTypeParmDecl *>())
+    return std::make_pair(TTP->getDepth(), TTP->getIndex());
+  // A non-type template parameter pack.
+  if (NonTypeTemplateParmDecl *NTTP 
+                              = UPP.first.dyn_cast<NonTypeTemplateParmDecl *>())
+    return std::make_pair(NTTP->getDepth(), NTTP->getIndex());
+  // Anything else is assumed to be a template template parameter pack.
+  TemplateTemplateParmDecl *TTP = UPP.first.get<TemplateTemplateParmDecl *>();
+  return std::make_pair(TTP->getDepth(), TTP->getIndex());
+}
+
 static Sema::TemplateDeductionResult
 DeduceTemplateArguments(Sema &S,
                         TemplateParameterList *TemplateParams,
@@ -955,18 +974,32 @@
                         TemplateDeductionInfo &Info,
                     llvm::SmallVectorImpl<DeducedTemplateArgument> &Deduced,
                         bool NumberOfArgumentsMustMatch) {
+  // C++0x [temp.deduct.type]p9:
+  //   If the template argument list of P contains a pack expansion that is not 
+  //   the last template argument, the entire template argument list is a 
+  //   non-deduced context.
+  // FIXME: Implement this.
+
+
+  // C++0x [temp.deduct.type]p9:
+  //   If P has a form that contains <T> or <i>, then each argument Pi of the 
+  //   respective template argument list P is compared with the corresponding 
+  //   argument Ai of the corresponding template argument list of A.
   unsigned ArgIdx = 0, ParamIdx = 0;
   for (; hasTemplateArgumentForDeduction(Params, ParamIdx, NumParams); 
        ++ParamIdx) {
+    // FIXME: Variadic templates.
+    // What do we do if the argument is a pack expansion?
+    
     if (!Params[ParamIdx].isPackExpansion()) {
-      // The simple case: deduce template arguments by matching P and A.
+      // The simple case: deduce template arguments by matching Pi and Ai.
       
       // Check whether we have enough arguments.
       if (!hasTemplateArgumentForDeduction(Args, ArgIdx, NumArgs))
         return NumberOfArgumentsMustMatch? Sema::TDK_TooFewArguments 
                                          : Sema::TDK_Success;
       
-      // Perform deduction for this P/A pair.
+      // Perform deduction for this Pi/Ai pair.
       if (Sema::TemplateDeductionResult Result
           = DeduceTemplateArguments(S, TemplateParams,
                                     Params[ParamIdx], Args[ArgIdx],
@@ -978,12 +1011,107 @@
       continue;
     }
     
-    // FIXME: Variadic templates. 
-    // The parameter is a pack expansion, so we'll
-    // need to repeatedly unify arguments against the parameter, capturing
-    // the bindings for each expanded parameter pack.
-    S.Diag(Info.getLocation(), diag::err_pack_expansion_deduction);
-    return Sema::TDK_TooManyArguments;
+    // The parameter is a pack expansion.
+    
+    // C++0x [temp.deduct.type]p9:
+    //   If Pi is a pack expansion, then the pattern of Pi is compared with 
+    //   each remaining argument in the template argument list of A. Each 
+    //   comparison deduces template arguments for subsequent positions in the 
+    //   template parameter packs expanded by Pi.
+    TemplateArgument Pattern = Params[ParamIdx].getPackExpansionPattern();
+    
+    // Compute the set of template parameter indices that correspond to
+    // parameter packs expanded by the pack expansion.
+    llvm::SmallVector<unsigned, 2> PackIndices;
+    {
+      llvm::BitVector SawIndices(TemplateParams->size());
+      llvm::SmallVector<UnexpandedParameterPack, 2> Unexpanded;
+      S.collectUnexpandedParameterPacks(Pattern, Unexpanded);
+      for (unsigned I = 0, N = Unexpanded.size(); I != N; ++I) {
+        unsigned Depth, Index;
+        llvm::tie(Depth, Index) = getDepthAndIndex(Unexpanded[I]);
+        if (Depth == 0 && !SawIndices[Index]) {
+          SawIndices[Index] = true;
+          PackIndices.push_back(Index);
+        }
+      }
+    }
+    assert(!PackIndices.empty() && "Pack expansion without unexpanded packs?");
+        
+    // FIXME: If there are no remaining arguments, we can bail out early
+    // and set any deduced parameter packs to an empty argument pack.
+    // The latter part of this is a (minor) correctness issue.
+    
+    // Save the deduced template arguments for each parameter pack expanded
+    // by this pack expansion, then clear out the deduction.
+    llvm::SmallVector<DeducedTemplateArgument, 2> 
+      SavedPacks(PackIndices.size());
+    for (unsigned I = 0, N = PackIndices.size(); I != N; ++I) {
+      SavedPacks[I] = Deduced[PackIndices[I]];
+      Deduced[PackIndices[I]] = DeducedTemplateArgument();
+    }
+
+    // Keep track of the deduced template arguments for each parameter pack
+    // expanded by this pack expansion (the outer index) and for each 
+    // template argument (the inner SmallVectors).
+    llvm::SmallVector<llvm::SmallVector<DeducedTemplateArgument, 4>, 2>
+      NewlyDeducedPacks(PackIndices.size());
+    bool HasAnyArguments = false;
+    while (hasTemplateArgumentForDeduction(Args, ArgIdx, NumArgs)) {
+      HasAnyArguments = true;
+      
+      // Deduce template arguments from the pattern.
+      if (Sema::TemplateDeductionResult Result 
+            = DeduceTemplateArguments(S, TemplateParams, Pattern, Args[ArgIdx],
+                                      Info, Deduced))
+        return Result;
+      
+      // Capture the deduced template arguments for each parameter pack expanded
+      // by this pack expansion, add them to the list of arguments we've deduced
+      // for that pack, then clear out the deduced argument.
+      for (unsigned I = 0, N = PackIndices.size(); I != N; ++I) {
+        DeducedTemplateArgument &DeducedArg = Deduced[PackIndices[I]];
+        if (!DeducedArg.isNull()) {
+          NewlyDeducedPacks[I].push_back(DeducedArg);
+          DeducedArg = DeducedTemplateArgument();
+        }
+      }
+      
+      ++ArgIdx;
+    }
+    
+    // Build argument packs for each of the parameter packs expanded by this
+    // pack expansion.
+    for (unsigned I = 0, N = PackIndices.size(); I != N; ++I) {
+      if (HasAnyArguments && NewlyDeducedPacks[I].empty()) {
+        // We were not able to deduce anything for this parameter pack,
+        // so just restore the saved argument pack.
+        Deduced[PackIndices[I]] = SavedPacks[I];
+        continue;
+      }
+      
+      if (!SavedPacks[I].isNull()) {
+        // FIXME: Check against the existing argument pack.
+        S.Diag(Info.getLocation(), diag::err_pack_expansion_deduction_compare);
+        return Sema::TDK_TooFewArguments;
+      }
+      
+      if (NewlyDeducedPacks[I].empty()) {
+        // If we deduced an empty argument pack, create it now.
+        Deduced[PackIndices[I]]
+          = DeducedTemplateArgument(TemplateArgument(0, 0));
+        continue;
+      }
+                
+      TemplateArgument *ArgumentPack
+        = new (S.Context) TemplateArgument [NewlyDeducedPacks[I].size()];
+      std::copy(NewlyDeducedPacks[I].begin(), NewlyDeducedPacks[I].end(),
+                ArgumentPack);
+      Deduced[PackIndices[I]]
+        = DeducedTemplateArgument(TemplateArgument(ArgumentPack,
+                                                   NewlyDeducedPacks[I].size()),
+                            NewlyDeducedPacks[I][0].wasDeducedFromArrayBound());
+    }
   }
   
   // If there is an argument remaining, then we had too many arguments.
@@ -1087,12 +1215,16 @@
   // C++ [temp.deduct.type]p2:
   //   [...] or if any template argument remains neither deduced nor
   //   explicitly specified, template argument deduction fails.
+  // FIXME: Variadic templates Empty parameter packs?
   llvm::SmallVector<TemplateArgument, 4> Builder;
   for (unsigned I = 0, N = Deduced.size(); I != N; ++I) {
     if (Deduced[I].isNull()) {
+      unsigned ParamIdx = I;
+      if (ParamIdx >= Partial->getTemplateParameters()->size())
+        ParamIdx = Partial->getTemplateParameters()->size() - 1;
       Decl *Param
         = const_cast<NamedDecl *>(
-                                Partial->getTemplateParameters()->getParam(I));
+                          Partial->getTemplateParameters()->getParam(ParamIdx));
       Info.Param = makeTemplateParameter(Param);
       return Sema::TDK_Incomplete;
     }
@@ -1118,22 +1250,23 @@
   ClassTemplateDecl *ClassTemplate = Partial->getSpecializedTemplate();
   const TemplateArgumentLoc *PartialTemplateArgs
     = Partial->getTemplateArgsAsWritten();
-  unsigned N = Partial->getNumTemplateArgsAsWritten();
 
   // Note that we don't provide the langle and rangle locations.
   TemplateArgumentListInfo InstArgs;
 
-  for (unsigned I = 0; I != N; ++I) {
-    Decl *Param = const_cast<NamedDecl *>(
-                    ClassTemplate->getTemplateParameters()->getParam(I));
-    TemplateArgumentLoc InstArg;
-    if (S.Subst(PartialTemplateArgs[I], InstArg,
-                MultiLevelTemplateArgumentList(*DeducedArgumentList))) {
-      Info.Param = makeTemplateParameter(Param);
-      Info.FirstArg = PartialTemplateArgs[I].getArgument();
-      return Sema::TDK_SubstitutionFailure;
-    }
-    InstArgs.addArgument(InstArg);
+  if (S.Subst(PartialTemplateArgs,
+              Partial->getNumTemplateArgsAsWritten(),
+              InstArgs, MultiLevelTemplateArgumentList(*DeducedArgumentList))) {
+    unsigned ArgIdx = InstArgs.size(), ParamIdx = ArgIdx;
+    if (ParamIdx >= Partial->getTemplateParameters()->size())
+      ParamIdx = Partial->getTemplateParameters()->size() - 1;
+
+    Decl *Param
+      = const_cast<NamedDecl *>(
+                          Partial->getTemplateParameters()->getParam(ParamIdx));
+    Info.Param = makeTemplateParameter(Param);
+    Info.FirstArg = PartialTemplateArgs[ArgIdx].getArgument();
+    return Sema::TDK_SubstitutionFailure;
   }
 
   llvm::SmallVector<TemplateArgument, 4> ConvertedInstArgs;