Add a new attribute 'enable_if' which can be used to control overload resolution based on the values of the function arguments at the call site.

llvm-svn: 198996
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 945525b..6555a3d 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -827,6 +827,34 @@
                                Attr.getAttributeSpellingListIndex()));
 }
 
+/// Handles enable_if(cond, "msg"): converts cond to bool unless it is
+/// type-dependent and rejects conds that can never be constant expressions.
+static void handleEnableIfAttr(Sema &S, Decl *D, const AttributeList &Attr) {
+  Expr *Condition = Attr.getArgAsExpr(0);
+  if (!Condition->isTypeDependent()) {
+    ExprResult AsBool = S.PerformContextuallyConvertToBool(Condition);
+    if (AsBool.isInvalid())
+      return;
+    Condition = AsBool.take();
+  }
+  StringRef Message;
+  if (!S.checkStringLiteralArgumentAttr(Attr, 1, Message))
+    return;
+  // Value-dependent conditions are checked when the template is instantiated.
+  SmallVector<PartialDiagnosticAt, 8> Notes;
+  if (!Condition->isValueDependent() &&
+      !Expr::isPotentialConstantExprUnevaluated(
+          Condition, cast<FunctionDecl>(D), Notes)) {
+    S.Diag(Attr.getLoc(), diag::err_enable_if_never_constant_expr);
+    for (unsigned I = 0; I != Notes.size(); ++I)
+      S.Diag(Notes[I].first, Notes[I].second);
+    return;
+  }
+  D->addAttr(::new (S.Context) EnableIfAttr(
+      Attr.getRange(), S.Context, Condition, Message,
+      Attr.getAttributeSpellingListIndex()));
+}
+
 static void handleConsumableAttr(Sema &S, Decl *D, const AttributeList &Attr) {
   ConsumableAttr::ConsumedState DefaultState;
 
@@ -3990,6 +4018,7 @@
     handleAttrWithMessage<DeprecatedAttr>(S, D, Attr);
     break;
   case AttributeList::AT_Destructor:  handleDestructorAttr  (S, D, Attr); break;
+  case AttributeList::AT_EnableIf:    handleEnableIfAttr    (S, D, Attr); break;
   case AttributeList::AT_ExtVectorType:
     handleExtVectorTypeAttr(S, scope, D, Attr);
     break;
diff --git a/clang/lib/Sema/SemaDeclCXX.cpp b/clang/lib/Sema/SemaDeclCXX.cpp
index 28f038a..820f57f 100644
--- a/clang/lib/Sema/SemaDeclCXX.cpp
+++ b/clang/lib/Sema/SemaDeclCXX.cpp
@@ -6079,6 +6079,18 @@
   PopDeclContext();
 }
 
+/// Reenters \p Param into scope \p S, and into the identifier resolver when
+/// it has a name, so name lookup can find it again. Used only by the
+/// enable_if extension's constant-expression support; standard C++ never
+/// needs to reenter parameters. A null \p Param is ignored.
+void Sema::ActOnReenterCXXMethodParameter(Scope *S, ParmVarDecl *Param) {
+  if (Param) {
+    S->AddDecl(Param);
+    if (Param->getDeclName())
+      IdResolver.AddDecl(Param);
+  }
+}
+
 /// ActOnStartDelayedCXXMethodDeclaration - We have completed
 /// parsing a top-level (non-nested) C++ class, and we are now
 /// parsing those parts of the given Method declaration that could
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ca261cd..f479dc8 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -4474,6 +4474,21 @@
   else if (isa<MemberExpr>(NakedFn))
     NDecl = cast<MemberExpr>(NakedFn)->getMemberDecl();
 
+  if (FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(NDecl)) {
+    if (FD->hasAttr<EnableIfAttr>()) {
+      if (const EnableIfAttr *Attr = CheckEnableIf(FD, ArgExprs, true)) {
+        Diag(Fn->getLocStart(),
+             isa<CXXMethodDecl>(FD) ?
+                 diag::err_ovl_no_viable_member_function_in_call :
+                 diag::err_ovl_no_viable_function_in_call)
+          << FD << FD->getSourceRange();
+        Diag(FD->getLocation(),
+             diag::note_ovl_candidate_disabled_by_enable_if_attr)
+            << Attr->getCond()->getSourceRange() << Attr->getMessage();
+      }
+    }
+  }
+
   return BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc,
                                ExecConfig, IsExecConfig);
 }
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 1333748..6032ed3 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -1008,8 +1008,8 @@
       isa<FunctionNoProtoType>(NewQType.getTypePtr()))
     return false;
 
-  const FunctionProtoType* OldType = cast<FunctionProtoType>(OldQType);
-  const FunctionProtoType* NewType = cast<FunctionProtoType>(NewQType);
+  const FunctionProtoType *OldType = cast<FunctionProtoType>(OldQType);
+  const FunctionProtoType *NewType = cast<FunctionProtoType>(NewQType);
 
   // The signature of a function includes the types of its
   // parameters (C++ 1.3.10), which includes the presence or absence
@@ -1085,6 +1085,22 @@
       return true;
   }
 
+  // enable_if attributes are an order-sensitive part of the signature.
+  for (specific_attr_iterator<EnableIfAttr>
+         NewI = New->specific_attr_begin<EnableIfAttr>(),
+         NewE = New->specific_attr_end<EnableIfAttr>(),
+         OldI = Old->specific_attr_begin<EnableIfAttr>(),
+         OldE = Old->specific_attr_end<EnableIfAttr>();
+       NewI != NewE || OldI != OldE; ++NewI, ++OldI) {
+    if (NewI == NewE || OldI == OldE)
+      return true;
+    llvm::FoldingSetNodeID NewID, OldID;
+    NewI->getCond()->Profile(NewID, Context, true);
+    OldI->getCond()->Profile(OldID, Context, true);
+    if (!(NewID == OldID))
+      return true;
+  }
+
   // The signatures match; this is not an overload.
   return false;
 }
@@ -5452,11 +5468,11 @@
 Sema::AddOverloadCandidate(FunctionDecl *Function,
                            DeclAccessPair FoundDecl,
                            ArrayRef<Expr *> Args,
-                           OverloadCandidateSet& CandidateSet,
+                           OverloadCandidateSet &CandidateSet,
                            bool SuppressUserConversions,
                            bool PartialOverloading,
                            bool AllowExplicit) {
-  const FunctionProtoType* Proto
+  const FunctionProtoType *Proto
     = dyn_cast<FunctionProtoType>(Function->getType()->getAs<FunctionType>());
   assert(Proto && "Functions without a prototype cannot be overloaded");
   assert(!Function->getDescribedFunctionTemplate() &&
@@ -5568,7 +5584,7 @@
       if (Candidate.Conversions[ArgIdx].isBad()) {
         Candidate.Viable = false;
         Candidate.FailureKind = ovl_fail_bad_conversion;
-        break;
+        return;
       }
     } else {
       // (C++ 13.3.2p2): For the purposes of overload resolution, any
@@ -5577,6 +5593,77 @@
       Candidate.Conversions[ArgIdx].setEllipsis();
     }
   }
+
+  if (EnableIfAttr *FailedAttr = CheckEnableIf(Function, Args)) {
+    Candidate.Viable = false;
+    Candidate.FailureKind = ovl_fail_enable_if;
+    Candidate.DeductionFailure.Data = FailedAttr;
+    return;
+  }
+}
+
+static bool IsNotEnableIfAttr(Attr *A) { return !EnableIfAttr::classof(A); }
+
+/// Returns the first enable_if attribute on \p Function (in declaration order)
+/// whose condition \p Args fail to satisfy, the first one outright if \p Args
+/// cannot be converted to the parameter types, or null if all are satisfied.
+/// Set \p MissingImplicitThis when \p Args omits a method's object argument.
+EnableIfAttr *Sema::CheckEnableIf(FunctionDecl *Function, ArrayRef<Expr *> Args,
+                                  bool MissingImplicitThis) {
+  // FIXME: specific_attr_iterator<EnableIfAttr> iterates in reverse order, but
+  // we need to find the first failing one.
+  if (!Function->hasAttrs())
+    return 0;
+  AttrVec Attrs = Function->getAttrs();
+  AttrVec::iterator E = std::remove_if(Attrs.begin(), Attrs.end(),
+                                       IsNotEnableIfAttr);
+  if (Attrs.begin() == E)
+    return 0;
+  std::reverse(Attrs.begin(), E);
+
+  SFINAETrap Trap(*this);
+
+  // Convert the arguments. Constructors are non-static CXXMethodDecls but take
+  // no implicit object argument, so they must use the parameter path.
+  CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Function);
+  bool HasObjectArg = !MissingImplicitThis && Method && !Method->isStatic() &&
+                      !isa<CXXConstructorDecl>(Method);
+  SmallVector<Expr *, 16> ConvertedArgs;
+  bool InitializationFailed = false;
+  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
+    ExprResult R;
+    if (i == 0 && HasObjectArg) {
+      R = PerformObjectArgumentInitialization(Args[0], /*Qualifier=*/0,
+                                              Method, Method);
+    } else if (i < Function->getNumParams()) {
+      R = PerformCopyInitialization(InitializedEntity::InitializeParameter(
+                                        Context, Function->getParamDecl(i)),
+                                    SourceLocation(), Args[i]);
+    } else {
+      break; // Ellipsis arguments have no ParmVarDecl to initialize.
+    }
+    if (R.isInvalid()) {
+      InitializationFailed = true;
+      break;
+    }
+    ConvertedArgs.push_back(R.take());
+  }
+
+  if (InitializationFailed || Trap.hasErrorOccurred())
+    return cast<EnableIfAttr>(Attrs[0]);
+
+  for (AttrVec::iterator I = Attrs.begin(); I != E; ++I) {
+    APValue Result;
+    EnableIfAttr *EIA = cast<EnableIfAttr>(*I);
+    if (!EIA->getCond()->EvaluateWithSubstitution(
+            Result, Context, Function,
+            llvm::ArrayRef<const Expr*>(ConvertedArgs.data(),
+                                        ConvertedArgs.size())) ||
+        !Result.isInt() || !Result.getInt().getBoolValue()) {
+      return EIA;
+    }
+  }
+  return 0;
 }
 
 /// \brief Add all of the function declarations in the given function set to
@@ -5658,9 +5745,9 @@
                          CXXRecordDecl *ActingContext, QualType ObjectType,
                          Expr::Classification ObjectClassification,
                          ArrayRef<Expr *> Args,
-                         OverloadCandidateSet& CandidateSet,
+                         OverloadCandidateSet &CandidateSet,
                          bool SuppressUserConversions) {
-  const FunctionProtoType* Proto
+  const FunctionProtoType *Proto
     = dyn_cast<FunctionProtoType>(Method->getType()->getAs<FunctionType>());
   assert(Proto && "Methods without a prototype cannot be overloaded");
   assert(!isa<CXXConstructorDecl>(Method) &&
@@ -5747,15 +5834,22 @@
       if (Candidate.Conversions[ArgIdx + 1].isBad()) {
         Candidate.Viable = false;
         Candidate.FailureKind = ovl_fail_bad_conversion;
-        break;
+        return;
       }
     } else {
       // (C++ 13.3.2p2): For the purposes of overload resolution, any
       // argument for which there is no corresponding parameter is
-      // considered to ""match the ellipsis" (C+ 13.3.3.1.3).
+      // considered to "match the ellipsis" (C+ 13.3.3.1.3).
       Candidate.Conversions[ArgIdx + 1].setEllipsis();
     }
   }
+
+  if (EnableIfAttr *FailedAttr = CheckEnableIf(Method, Args, true)) {
+    Candidate.Viable = false;
+    Candidate.FailureKind = ovl_fail_enable_if;
+    Candidate.DeductionFailure.Data = FailedAttr;
+    return;
+  }
 }
 
 /// \brief Add a C++ member function template as a candidate to the candidate
@@ -5971,7 +6065,7 @@
     return;
   }
 
-  // We won't go through a user-define type conversion function to convert a
+  // We won't go through a user-defined type conversion function to convert a
   // derived to base as such conversions are given Conversion Rank. They only
   // go through a copy constructor. 13.3.3.1.2-p4 [over.ics.user]
   QualType FromCanon
@@ -6031,6 +6125,7 @@
         GetConversionRank(ICS.Standard.Second) != ICR_Exact_Match) {
       Candidate.Viable = false;
       Candidate.FailureKind = ovl_fail_final_conversion_not_exact;
+      return;
     }
 
     // C++0x [dcl.init.ref]p5:
@@ -6042,18 +6137,26 @@
         ICS.Standard.First == ICK_Lvalue_To_Rvalue) {
       Candidate.Viable = false;
       Candidate.FailureKind = ovl_fail_bad_final_conversion;
+      return;
     }
     break;
 
   case ImplicitConversionSequence::BadConversion:
     Candidate.Viable = false;
     Candidate.FailureKind = ovl_fail_bad_final_conversion;
-    break;
+    return;
 
   default:
     llvm_unreachable(
            "Can only end up with a standard conversion sequence or failure");
   }
+
+  if (EnableIfAttr *FailedAttr = CheckEnableIf(Conversion, ArrayRef<Expr*>())) {
+    Candidate.Viable = false;
+    Candidate.FailureKind = ovl_fail_enable_if;
+    Candidate.DeductionFailure.Data = FailedAttr;
+    return;
+  }
 }
 
 /// \brief Adds a conversion function template specialization
@@ -6191,7 +6294,7 @@
       if (Candidate.Conversions[ArgIdx + 1].isBad()) {
         Candidate.Viable = false;
         Candidate.FailureKind = ovl_fail_bad_conversion;
-        break;
+        return;
       }
     } else {
       // (C++ 13.3.2p2): For the purposes of overload resolution, any
@@ -6200,6 +6303,13 @@
       Candidate.Conversions[ArgIdx + 1].setEllipsis();
     }
   }
+
+  if (EnableIfAttr *FailedAttr = CheckEnableIf(Conversion, ArrayRef<Expr*>())) {
+    Candidate.Viable = false;
+    Candidate.FailureKind = ovl_fail_enable_if;
+    Candidate.DeductionFailure.Data = FailedAttr;
+    return;
+  }
 }
 
 /// \brief Add overload candidates for overloaded operators that are
@@ -8111,6 +8221,47 @@
     }
   }
 
+  // Check for enable_if value-based overload resolution.
+  if (Cand1.Function && Cand2.Function &&
+      (Cand1.Function->hasAttr<EnableIfAttr>() ||
+       Cand2.Function->hasAttr<EnableIfAttr>())) {
+    // FIXME: The next several lines are just
+    // specific_attr_iterator<EnableIfAttr> but going in declaration order,
+    // instead of reverse order which is how they're stored in the AST.
+    AttrVec Cand1Attrs;
+    AttrVec::iterator Cand1E = Cand1Attrs.end();
+    if (Cand1.Function->hasAttrs()) {
+      Cand1Attrs = Cand1.Function->getAttrs();
+      Cand1E = std::remove_if(Cand1Attrs.begin(), Cand1Attrs.end(),
+                              IsNotEnableIfAttr);
+      std::reverse(Cand1Attrs.begin(), Cand1E);
+    }
+
+    AttrVec Cand2Attrs;
+    AttrVec::iterator Cand2E = Cand2Attrs.end();
+    if (Cand2.Function->hasAttrs()) {
+      Cand2Attrs = Cand2.Function->getAttrs();
+      Cand2E = std::remove_if(Cand2Attrs.begin(), Cand2Attrs.end(),
+                              IsNotEnableIfAttr);
+      std::reverse(Cand2Attrs.begin(), Cand2E);
+    }
+    for (AttrVec::iterator
+         Cand1I = Cand1Attrs.begin(), Cand2I = Cand2Attrs.begin();
+         Cand1I != Cand1E || Cand2I != Cand2E; ++Cand1I, ++Cand2I) {
+      if (Cand1I == Cand1E)
+        return false;
+      if (Cand2I == Cand2E)
+        return true;
+      llvm::FoldingSetNodeID Cand1ID, Cand2ID;
+      cast<EnableIfAttr>(*Cand1I)->getCond()->Profile(Cand1ID,
+                                                      S.getASTContext(), true);
+      cast<EnableIfAttr>(*Cand2I)->getCond()->Profile(Cand2ID,
+                                                      S.getASTContext(), true);
+      if (!(Cand1ID == Cand2ID))
+        return false;
+    }
+  }
+
   return false;
 }
 
@@ -8819,6 +8970,15 @@
       << (unsigned) FnKind << CalleeTarget << CallerTarget;
 }
 
+/// Emits a note at the candidate naming the enable_if condition that failed.
+void DiagnoseFailedEnableIfAttr(Sema &S, OverloadCandidate *Cand) {
+  const EnableIfAttr *Failed =
+      static_cast<const EnableIfAttr *>(Cand->DeductionFailure.Data);
+  S.Diag(Cand->Function->getLocation(),
+         diag::note_ovl_candidate_disabled_by_enable_if_attr)
+      << Failed->getCond()->getSourceRange() << Failed->getMessage();
+}
+
 /// Generates a 'note' diagnostic for an overload candidate.  We've
 /// already generated a primary error at the call site.
 ///
@@ -8882,6 +9042,9 @@
 
   case ovl_fail_bad_target:
     return DiagnoseBadTarget(S, Cand);
+
+  case ovl_fail_enable_if:
+    return DiagnoseFailedEnableIfAttr(S, Cand);
   }
 }
 
@@ -11107,7 +11270,7 @@
         << qualsString
         << (qualsString.find(' ') == std::string::npos ? 1 : 2);
     }
-              
+
     CXXMemberCallExpr *call
       = new (Context) CXXMemberCallExpr(Context, MemExprE, Args,
                                         resultType, valueKind, RParenLoc);
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 22f13d7..6995ae7 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -129,6 +129,39 @@
   }
 }
 
+/// Substitutes enable_if \p A's value-dependent condition and attaches it to
+/// \p New; like handleEnableIfAttr, constancy is checked against the owner.
+static void instantiateDependentEnableIfAttr(
+    Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
+    const EnableIfAttr *A, const Decl *Tmpl, Decl *New) {
+  Expr *Cond = 0;
+  {
+    EnterExpressionEvaluationContext Unevaluated(S, Sema::Unevaluated);
+    ExprResult Result = S.SubstExpr(A->getCond(), TemplateArgs);
+    if (Result.isInvalid())
+      return;
+    Cond = Result.takeAs<Expr>();
+  }
+  if (A->getCond()->isTypeDependent() && !Cond->isTypeDependent()) {
+    ExprResult Converted = S.PerformContextuallyConvertToBool(Cond);
+    if (Converted.isInvalid())
+      return;
+    Cond = Converted.take();
+  }
+  SmallVector<PartialDiagnosticAt, 8> Diags;
+  if (A->getCond()->isValueDependent() && !Cond->isValueDependent() &&
+      !Expr::isPotentialConstantExprUnevaluated(Cond, cast<FunctionDecl>(New),
+                                                Diags)) {
+    S.Diag(A->getLocation(), diag::err_enable_if_never_constant_expr);
+    for (int I = 0, N = Diags.size(); I != N; ++I)
+      S.Diag(Diags[I].first, Diags[I].second);
+    return;
+  }
+  New->addAttr(new (S.getASTContext()) EnableIfAttr(
+      A->getLocation(), S.getASTContext(), Cond, A->getMessage(),
+      A->getSpellingListIndex()));
+}
+
 void Sema::InstantiateAttrs(const MultiLevelTemplateArgumentList &TemplateArgs,
                             const Decl *Tmpl, Decl *New,
                             LateInstantiatedAttrVec *LateAttrs,
@@ -144,6 +177,13 @@
       continue;
     }
 
+    const EnableIfAttr *EnableIf = dyn_cast<EnableIfAttr>(TmplAttr);
+    if (EnableIf && EnableIf->getCond()->isValueDependent()) {
+      instantiateDependentEnableIfAttr(*this, TemplateArgs, EnableIf, Tmpl,
+                                       New);
+      continue;
+    }
+
     assert(!TmplAttr->isPackExpansion());
     if (TmplAttr->isLateParsed() && LateAttrs) {
       // Late parsed attributes must be instantiated and attached after the