Several improvements to template argument deduction:
  - Once we have deduced template arguments for a class template partial
    specialization, we use exactly those template arguments for instantiating
    the definition of the class template partial specialization.
  - Added template argument deduction for non-type template parameters.
  - Added template argument deduction for dependently-sized array types.

With these changes, we can now implement, e.g., the remove_reference
type trait. Also, Daniel's Ackermann template metaprogram now compiles
properly.



git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@72909 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/include/clang/AST/DeclTemplate.h b/include/clang/AST/DeclTemplate.h
index 556ee28..c477b76 100644
--- a/include/clang/AST/DeclTemplate.h
+++ b/include/clang/AST/DeclTemplate.h
@@ -203,6 +203,9 @@
 
   /// Get the position of the template parameter within its parameter list.
   unsigned getPosition() const { return Position; }
+  
+  /// Get the index of the template parameter within its parameter list.
+  unsigned getIndex() const { return Position; }
 };
 
 /// TemplateTypeParmDecl - Declaration of a template type parameter,
@@ -299,7 +302,8 @@
 
   using TemplateParmPosition::getDepth;
   using TemplateParmPosition::getPosition;
-
+  using TemplateParmPosition::getIndex;
+    
   /// \brief Determine whether this template parameter has a default
   /// argument.
   bool hasDefaultArgument() const { return DefaultArgument; }
@@ -350,7 +354,8 @@
 
   using TemplateParmPosition::getDepth;
   using TemplateParmPosition::getPosition;
-
+  using TemplateParmPosition::getIndex;
+    
   /// \brief Determine whether this template parameter has a default
   /// argument.
   bool hasDefaultArgument() const { return DefaultArgument; }
@@ -523,6 +528,12 @@
     return QualType::getFromOpaquePtr(Integer.Type);
   }
 
+  void setIntegralType(QualType T) {
+    assert(Kind == Integral && 
+           "Cannot set the integral type of a non-integral template argument");
+    Integer.Type = T.getAsOpaquePtr();
+  };
+
   /// \brief Retrieve the template argument as an expression.
   Expr *getAsExpr() const {
     if (Kind != Expression)
diff --git a/lib/Sema/Sema.h b/lib/Sema/Sema.h
index 3969da8..d8cc01e 100644
--- a/lib/Sema/Sema.h
+++ b/lib/Sema/Sema.h
@@ -2019,8 +2019,9 @@
                              const IdentifierInfo &II,
                              SourceRange Range);
 
-  bool DeduceTemplateArguments(ClassTemplatePartialSpecializationDecl *Partial,
-                               const TemplateArgumentList &TemplateArgs);
+  TemplateArgumentList *
+  DeduceTemplateArguments(ClassTemplatePartialSpecializationDecl *Partial,
+                          const TemplateArgumentList &TemplateArgs);
                              
   //===--------------------------------------------------------------------===//
   // C++ Template Instantiation
@@ -2227,7 +2228,7 @@
 
   QualType InstantiateType(QualType T, const TemplateArgumentList &TemplateArgs,
                            SourceLocation Loc, DeclarationName Entity);
-
+  
   OwningExprResult InstantiateExpr(Expr *E, 
                                    const TemplateArgumentList &TemplateArgs);
 
diff --git a/lib/Sema/SemaTemplateDeduction.cpp b/lib/Sema/SemaTemplateDeduction.cpp
index 82b027c..87968bf 100644
--- a/lib/Sema/SemaTemplateDeduction.cpp
+++ b/lib/Sema/SemaTemplateDeduction.cpp
@@ -20,6 +20,86 @@
 #include "llvm/Support/Compiler.h"
 using namespace clang;
 
+/// \brief If the given expression is of a form that permits the deduction
+/// of a non-type template parameter, return the declaration of that
+/// non-type template parameter.
+static NonTypeTemplateParmDecl *getDeducedParameterFromExpr(Expr *E) {
+  if (ImplicitCastExpr *IC = dyn_cast<ImplicitCastExpr>(E))
+    E = IC->getSubExpr();
+  
+  if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E))
+    return dyn_cast<NonTypeTemplateParmDecl>(DRE->getDecl());
+  
+  return 0;
+}
+
+/// \brief Deduce the value of the given non-type template parameter 
+/// from the given constant.
+///
+/// \returns true if deduction succeeded, false otherwise.
+static bool DeduceNonTypeTemplateArgument(ASTContext &Context, 
+                                          NonTypeTemplateParmDecl *NTTP, 
+                                          llvm::APInt Value,
+                             llvm::SmallVectorImpl<TemplateArgument> &Deduced) {
+  assert(NTTP->getDepth() == 0 && 
+         "Cannot deduce non-type template argument with depth > 0");
+  
+  if (Deduced[NTTP->getIndex()].isNull()) {
+    Deduced[NTTP->getIndex()] = TemplateArgument(SourceLocation(), 
+                                                 llvm::APSInt(Value),
+                                                 NTTP->getType());
+    return true;
+  }
+  
+  if (Deduced[NTTP->getIndex()].getKind() != TemplateArgument::Integral)
+    return false;
+  
+  // If the template argument was previously deduced to a negative value, 
+  // then our deduction fails.
+  const llvm::APSInt *PrevValuePtr = Deduced[NTTP->getIndex()].getAsIntegral();
+  assert(PrevValuePtr && "Not an integral template argument?");
+  if (PrevValuePtr->isSigned() && PrevValuePtr->isNegative())
+    return false;
+  
+  llvm::APInt PrevValue = *PrevValuePtr;
+  if (Value.getBitWidth() > PrevValue.getBitWidth())
+    PrevValue.zext(Value.getBitWidth());
+  else if (Value.getBitWidth() < PrevValue.getBitWidth())
+    Value.zext(PrevValue.getBitWidth());
+  return Value == PrevValue;
+}
+
+/// \brief Deduce the value of the given non-type template parameter 
+/// from the given type- or value-dependent expression.
+///
+/// \returns true if deduction succeeded, false otherwise.
+
+static bool DeduceNonTypeTemplateArgument(ASTContext &Context, 
+                                          NonTypeTemplateParmDecl *NTTP,
+                                          Expr *Value,
+                            llvm::SmallVectorImpl<TemplateArgument> &Deduced) {
+  assert(NTTP->getDepth() == 0 && 
+         "Cannot deduce non-type template argument with depth > 0");
+  assert((Value->isTypeDependent() || Value->isValueDependent()) &&
+         "Expression template argument must be type- or value-dependent.");
+  
+  if (Deduced[NTTP->getIndex()].isNull()) {
+    // FIXME: Clone the Value?
+    Deduced[NTTP->getIndex()] = TemplateArgument(Value);
+    return true;
+  }
+  
+  if (Deduced[NTTP->getIndex()].getKind() == TemplateArgument::Integral) {
+    // Okay, we deduced a constant in one case and a dependent expression 
+    // in another case. FIXME: Later, we will check that instantiating the 
+    // dependent expression gives us the constant value.
+    return true;
+  }
+  
+  // FIXME: Compare the expressions for equality!
+  return true;
+}
+
 static bool DeduceTemplateArguments(ASTContext &Context, QualType Param, 
                                     QualType Arg,
                              llvm::SmallVectorImpl<TemplateArgument> &Deduced) {
@@ -33,8 +113,14 @@
   if (!Param->isDependentType())
     return Param == Arg;
 
-  // FIXME: Use a visitor or switch to handle all of the kinds of
-  // types that the parameter may be.
+  // C++ [temp.deduct.type]p9:
+  //
+  //   A template type argument T, a template template argument TT or a 
+  //   template non-type argument i can be deduced if P and A have one of 
+  //   the following forms:
+  //
+  //     T
+  //     cv-list T
   if (const TemplateTypeParmType *TemplateTypeParm 
         = Param->getAsTemplateTypeParmType()) {
     // The argument type can not be less qualified than the parameter
@@ -67,6 +153,12 @@
     return false;
 
   switch (Param->getTypeClass()) {
+    // No deduction possible for these types
+    case Type::Builtin:
+      return false;
+      
+      
+    //     T *
     case Type::Pointer: {
       const PointerType *PointerArg = Arg->getAsPointerType();
       if (!PointerArg)
@@ -78,6 +170,7 @@
                                      Deduced);
     }
       
+    //     T &
     case Type::LValueReference: {
       const LValueReferenceType *ReferenceArg = Arg->getAsLValueReferenceType();
       if (!ReferenceArg)
@@ -89,6 +182,7 @@
                                      Deduced);
     }
 
+    //     T && [C++0x]
     case Type::RValueReference: {
       const RValueReferenceType *ReferenceArg = Arg->getAsRValueReferenceType();
       if (!ReferenceArg)
@@ -100,6 +194,7 @@
                                      Deduced);
     }
       
+    //     T [] (implied, but not stated explicitly)
     case Type::IncompleteArray: {
       const IncompleteArrayType *IncompleteArrayArg = 
         Context.getAsIncompleteArrayType(Arg);
@@ -111,7 +206,8 @@
                                      IncompleteArrayArg->getElementType(),
                                      Deduced);
     }
-    
+
+    //     T [integer-constant]
     case Type::ConstantArray: {
       const ConstantArrayType *ConstantArrayArg = 
         Context.getAsConstantArrayType(Arg);
@@ -129,6 +225,46 @@
                                      Deduced);
     }
 
+    //     type [i]
+    case Type::DependentSizedArray: {
+      const ArrayType *ArrayArg = dyn_cast<ArrayType>(Arg);
+      if (!ArrayArg)
+        return false;
+      
+      // Check the element type of the arrays
+      const DependentSizedArrayType *DependentArrayParm
+        = cast<DependentSizedArrayType>(Param);
+      if (!DeduceTemplateArguments(Context,
+                                   DependentArrayParm->getElementType(),
+                                   ArrayArg->getElementType(),
+                                   Deduced))
+        return false;
+          
+      // Determine the array bound is something we can deduce.
+      NonTypeTemplateParmDecl *NTTP 
+        = getDeducedParameterFromExpr(DependentArrayParm->getSizeExpr());
+      if (!NTTP)
+        return true;
+      
+      // We can perform template argument deduction for the given non-type 
+      // template parameter.
+      assert(NTTP->getDepth() == 0 && 
+             "Cannot deduce non-type template argument at depth > 0");
+      if (const ConstantArrayType *ConstantArrayArg 
+            = dyn_cast<ConstantArrayType>(ArrayArg))
+        return DeduceNonTypeTemplateArgument(Context, NTTP, 
+                                             ConstantArrayArg->getSize(),
+                                             Deduced);
+      if (const DependentSizedArrayType *DependentArrayArg
+            = dyn_cast<DependentSizedArrayType>(ArrayArg))
+        return DeduceNonTypeTemplateArgument(Context, NTTP,
+                                             DependentArrayArg->getSizeExpr(),
+                                             Deduced);
+      
+      // Incomplete type does not match a dependently-sized array type
+      return false;
+    }
+      
     default:
       break;
   }
@@ -141,16 +277,53 @@
 DeduceTemplateArguments(ASTContext &Context, const TemplateArgument &Param,
                         const TemplateArgument &Arg,
                         llvm::SmallVectorImpl<TemplateArgument> &Deduced) {
-  assert(Param.getKind() == Arg.getKind() &&
-         "Template argument kind mismatch during deduction");
   switch (Param.getKind()) {
+  case TemplateArgument::Null:
+    assert(false && "Null template argument in parameter list");
+    break;
+      
   case TemplateArgument::Type: 
+    assert(Arg.getKind() == TemplateArgument::Type && "Type/value mismatch");
     return DeduceTemplateArguments(Context, Param.getAsType(), 
                                    Arg.getAsType(), Deduced);
 
-  default:
+  case TemplateArgument::Declaration:
+    // FIXME: Implement this check
+    assert(false && "Unimplemented template argument deduction case");
     return false;
+      
+  case TemplateArgument::Integral:
+    if (Arg.getKind() == TemplateArgument::Integral) {
+      // FIXME: Zero extension + sign checking here?
+      return *Param.getAsIntegral() == *Arg.getAsIntegral();
+    }
+    if (Arg.getKind() == TemplateArgument::Expression)
+      return false;
+
+    assert(false && "Type/value mismatch");
+    return false;
+      
+  case TemplateArgument::Expression: {
+    if (NonTypeTemplateParmDecl *NTTP 
+          = getDeducedParameterFromExpr(Param.getAsExpr())) {
+      if (Arg.getKind() == TemplateArgument::Integral)
+        // FIXME: Sign problems here
+        return DeduceNonTypeTemplateArgument(Context, NTTP, 
+                                             *Arg.getAsIntegral(), Deduced);
+      if (Arg.getKind() == TemplateArgument::Expression)
+        return DeduceNonTypeTemplateArgument(Context, NTTP, Arg.getAsExpr(),
+                                             Deduced);
+      
+      assert(false && "Type/value mismatch");
+      return false;
+    }
+    
+    // Can't deduce anything, but that's okay.
+    return true;
   }
+  }
+      
+  return true;
 }
 
 static bool 
@@ -167,11 +340,50 @@
 }
 
 
-bool 
+TemplateArgumentList * 
 Sema::DeduceTemplateArguments(ClassTemplatePartialSpecializationDecl *Partial,
                               const TemplateArgumentList &TemplateArgs) {
+  // Deduce the template arguments for the partial specialization
   llvm::SmallVector<TemplateArgument, 4> Deduced;
   Deduced.resize(Partial->getTemplateParameters()->size());
-  return ::DeduceTemplateArguments(Context, Partial->getTemplateArgs(), 
-                                  TemplateArgs, Deduced);
+  if (! ::DeduceTemplateArguments(Context, Partial->getTemplateArgs(), 
+                                  TemplateArgs, Deduced))
+    return 0;
+  
+  // FIXME: Substitute the deduced template arguments into the template
+  // arguments of the class template partial specialization; the resulting
+  // template arguments should match TemplateArgs exactly.
+  
+  for (unsigned I = 0, N = Deduced.size(); I != N; ++I) {
+    TemplateArgument &Arg = Deduced[I];
+
+    // FIXME: If this template argument was not deduced, but the corresponding
+    // template parameter has a default argument, instantiate the default
+    // argument.
+    if (Arg.isNull()) // FIXME: Result->Destroy(Context);
+      return 0;
+    
+    if (Arg.getKind() == TemplateArgument::Integral) {
+      // FIXME: Instantiate the type, but we need some context!
+      const NonTypeTemplateParmDecl *Parm 
+        = cast<NonTypeTemplateParmDecl>(Partial->getTemplateParameters()
+                                          ->getParam(I));
+      //      QualType T = InstantiateType(Parm->getType(), *Result,
+      //                                   Parm->getLocation(), Parm->getDeclName());
+      //      if (T.isNull()) // FIXME: Result->Destroy(Context);
+      //        return 0;
+      QualType T = Parm->getType();
+      
+      // FIXME: Make sure we didn't overflow our data type!
+      llvm::APSInt &Value = *Arg.getAsIntegral();
+      unsigned AllowedBits = Context.getTypeSize(T);
+      if (Value.getBitWidth() != AllowedBits)
+        Value.extOrTrunc(AllowedBits);
+      Value.setIsSigned(T->isSignedIntegerType());
+      Arg.setIntegralType(T);
+    }
+  }
+  
+  return new (Context) TemplateArgumentList(Context, Deduced.data(),
+                                            Deduced.size(), /*CopyArgs=*/true);
 }
diff --git a/lib/Sema/SemaTemplateInstantiate.cpp b/lib/Sema/SemaTemplateInstantiate.cpp
index 0400b4c..562749e 100644
--- a/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/lib/Sema/SemaTemplateInstantiate.cpp
@@ -833,21 +833,23 @@
 
   // Determine whether any class template partial specializations
   // match the given template arguments.
-  llvm::SmallVector<ClassTemplatePartialSpecializationDecl *, 4> Matched;
+  typedef std::pair<ClassTemplatePartialSpecializationDecl *,
+                    TemplateArgumentList *> MatchResult;
+  llvm::SmallVector<MatchResult, 4> Matched;
   for (llvm::FoldingSet<ClassTemplatePartialSpecializationDecl>::iterator 
          Partial = Template->getPartialSpecializations().begin(),
          PartialEnd = Template->getPartialSpecializations().end();
        Partial != PartialEnd;
        ++Partial) {
-    if (DeduceTemplateArguments(&*Partial, ClassTemplateSpec->getTemplateArgs()))
-      Matched.push_back(&*Partial);
+    if (TemplateArgumentList *Deduced 
+          = DeduceTemplateArguments(&*Partial, 
+                                    ClassTemplateSpec->getTemplateArgs()))
+      Matched.push_back(std::make_pair(&*Partial, Deduced));
   }
 
   if (Matched.size() == 1) {
-    Pattern = Matched[0];
-    // FIXME: set TemplateArgs to the template arguments of the
-    // partial specialization, instantiated with the deduced template
-    // arguments.
+    Pattern = Matched[0].first;
+    TemplateArgs = Matched[0].second;
   } else if (Matched.size() > 1) {
     // FIXME: Implement partial ordering of class template partial
     // specializations.
@@ -860,9 +862,17 @@
                         ExplicitInstantiation? TSK_ExplicitInstantiation 
                                              : TSK_ImplicitInstantiation);
 
-  return InstantiateClass(ClassTemplateSpec->getLocation(),
-                          ClassTemplateSpec, Pattern, *TemplateArgs,
-                          ExplicitInstantiation);
+  bool Result = InstantiateClass(ClassTemplateSpec->getLocation(),
+                                 ClassTemplateSpec, Pattern, *TemplateArgs,
+                                 ExplicitInstantiation);
+  
+  for (unsigned I = 0, N = Matched.size(); I != N; ++I) {
+    // FIXME: Implement TemplateArgumentList::Destroy!
+    //    if (Matched[I].first != Pattern)
+    //      Matched[I].second->Destroy(Context);
+  }
+  
+  return Result;
 }
 
 /// \brief Instantiate the definitions of all of the member of the
diff --git a/lib/Sema/SemaTemplateInstantiateExpr.cpp b/lib/Sema/SemaTemplateInstantiateExpr.cpp
index a6b9703..5ba42f2 100644
--- a/lib/Sema/SemaTemplateInstantiateExpr.cpp
+++ b/lib/Sema/SemaTemplateInstantiateExpr.cpp
@@ -119,12 +119,13 @@
                                           T->isWideCharType(),
                                           T, 
                                        E->getSourceRange().getBegin()));
-    else if (T->isBooleanType())
+    if (T->isBooleanType())
       return SemaRef.Owned(new (SemaRef.Context) CXXBoolLiteralExpr(
                                           Arg.getAsIntegral()->getBoolValue(),
                                                  T, 
                                        E->getSourceRange().getBegin()));
 
+    assert(Arg.getAsIntegral()->getBitWidth() == SemaRef.Context.getIntWidth(T));
     return SemaRef.Owned(new (SemaRef.Context) IntegerLiteral(
                                                  *Arg.getAsIntegral(),
                                                  T, 
diff --git a/test/SemaTemplate/ackermann.cpp b/test/SemaTemplate/ackermann.cpp
new file mode 100644
index 0000000..48fbbbb
--- /dev/null
+++ b/test/SemaTemplate/ackermann.cpp
@@ -0,0 +1,37 @@
+// RUN: clang-cc -fsyntax-only -ftemplate-depth=1000 -verify %s
+
+// template<unsigned M, unsigned N>
+// struct Ackermann {
+//   enum {
+//     value = M ? (N ? Ackermann<M-1, Ackermann<M-1, N-1> >::value
+//                    : Ackermann<M-1, 1>::value)
+//               : N + 1
+//   };
+// };
+
+template<unsigned M, unsigned N>
+struct Ackermann {
+ enum {
+   value = Ackermann<M-1, Ackermann<M, N-1>::value >::value
+ };
+};
+
+template<unsigned M> struct Ackermann<M, 0> {
+ enum {
+   value = Ackermann<M-1, 1>::value
+ };
+};
+
+template<unsigned N> struct Ackermann<0, N> {
+ enum {
+   value = N + 1
+ };
+};
+
+template<> struct Ackermann<0, 0> {
+ enum {
+   value = 1
+ };
+};
+
+int g0[Ackermann<3, 8>::value == 2045 ? 1 : -1];
diff --git a/test/SemaTemplate/temp_class_spec.cpp b/test/SemaTemplate/temp_class_spec.cpp
index d516f01..8cb46cf 100644
--- a/test/SemaTemplate/temp_class_spec.cpp
+++ b/test/SemaTemplate/temp_class_spec.cpp
@@ -51,6 +51,19 @@
 int is_same3[is_same<int_ptr, int>::value? -1 : 1];
 
 template<typename T>
+struct remove_reference {
+  typedef T type;
+};
+
+template<typename T>
+struct remove_reference<T&> {
+  typedef T type;
+};
+
+int remove_ref0[is_same<remove_reference<int>::type, int>::value? 1 : -1];
+int remove_ref1[is_same<remove_reference<int&>::type, int>::value? 1 : -1];
+                
+template<typename T>
 struct is_incomplete_array {
   static const bool value = false;
 };
@@ -79,3 +92,13 @@
 int array_with_4_elements1[is_array_with_4_elements<int[1]>::value ? -1 : 1];
 int array_with_4_elements2[is_array_with_4_elements<int[4]>::value ? 1 : -1];
 int array_with_4_elements3[is_array_with_4_elements<int[4][2]>::value ? 1 : -1];
+
+template<typename T>
+struct get_array_size;
+
+template<typename T, unsigned N>
+struct get_array_size<T[N]> {
+  static const unsigned value = N;
+};
+
+int array_size0[get_array_size<int[12]>::value == 12? 1 : -1];