Teach code completion to adjust its completion priorities based on the
type that we expect to see at a given point in the grammar, e.g., when
initializing a variable, assigning to a variable, returning a result, or
passing an argument to a function. We don't prune the candidate set at
all, just adjust priorities to favor candidates that are likely to
type-check, using an ultra-simplified type system.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@105128 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp
index 76da135..5d52d8b 100644
--- a/lib/Parse/ParseDecl.cpp
+++ b/lib/Parse/ParseDecl.cpp
@@ -556,6 +556,22 @@
         Actions.ActOnCXXEnterDeclInitializer(CurScope, ThisDecl);
       }
 
+      // Code completion at the start of the initializer: rank results by the
+      // declared type, then skip the rest of this declarator.
+      if (Tok.is(tok::code_completion)) {
+        Actions.CodeCompleteInitializer(CurScope, ThisDecl);
+        ConsumeCodeCompletionToken();
+        SkipUntil(tok::comma, true, true);
+        // We return before ParseInitializer(), so undo the initializer-scope
+        // entry made above; otherwise it is never exited and the parser's
+        // scope stack is left unbalanced.
+        if (getLang().CPlusPlus && D.getCXXScopeSpec().isSet()) {
+          Actions.ActOnCXXExitDeclInitializer(CurScope, ThisDecl);
+          ExitScope();
+        }
+        return ThisDecl;
+      }
+      
       OwningExprResult Init(ParseInitializer());
 
       if (getLang().CPlusPlus && D.getCXXScopeSpec().isSet()) {
diff --git a/lib/Parse/ParseExpr.cpp b/lib/Parse/ParseExpr.cpp
index 0059a9a..7be1a19 100644
--- a/lib/Parse/ParseExpr.cpp
+++ b/lib/Parse/ParseExpr.cpp
@@ -343,6 +343,14 @@
       }
     }
     
+    // Code completion for the right-hand side of an assignment expression
+    // goes through a special hook that takes the left-hand side into account.
+    if (Tok.is(tok::code_completion) && NextTokPrec == prec::Assignment) {
+      Actions.CodeCompleteAssignmentRHS(CurScope, LHS.get());
+      ConsumeCodeCompletionToken();
+      return ExprError();
+    }
+    
     // Parse another leaf here for the RHS of the operator.
     // ParseCastExpression works here because all RHS expressions in C have it
     // as a prefix, at least. However, in C++, an assignment-expression could
diff --git a/lib/Parse/ParseStmt.cpp b/lib/Parse/ParseStmt.cpp
index 1ca6a0c..f0930a0 100644
--- a/lib/Parse/ParseStmt.cpp
+++ b/lib/Parse/ParseStmt.cpp
@@ -1199,6 +1199,13 @@
 
   OwningExprResult R(Actions);
   if (Tok.isNot(tok::semi)) {
+    if (Tok.is(tok::code_completion)) {
+      Actions.CodeCompleteReturn(CurScope);
+      ConsumeCodeCompletionToken();
+      SkipUntil(tok::semi, false, true);
+      return StmtError();
+    }
+        
     R = ParseExpression();
     if (R.isInvalid()) {  // Skip to the semicolon, but don't consume it.
       SkipUntil(tok::semi, false, true);
diff --git a/lib/Sema/Sema.h b/lib/Sema/Sema.h
index 0870cf4..7338915 100644
--- a/lib/Sema/Sema.h
+++ b/lib/Sema/Sema.h
@@ -4417,6 +4417,7 @@
   //@{
   virtual void CodeCompleteOrdinaryName(Scope *S,
                                      CodeCompletionContext CompletionContext);
+  virtual void CodeCompleteExpression(Scope *S, QualType T);
   virtual void CodeCompleteMemberReferenceExpr(Scope *S, ExprTy *Base,
                                                SourceLocation OpLoc,
                                                bool IsArrow);
@@ -4424,6 +4425,10 @@
   virtual void CodeCompleteCase(Scope *S);
   virtual void CodeCompleteCall(Scope *S, ExprTy *Fn,
                                 ExprTy **Args, unsigned NumArgs);
+  virtual void CodeCompleteInitializer(Scope *S, DeclPtrTy D);
+  virtual void CodeCompleteReturn(Scope *S);
+  virtual void CodeCompleteAssignmentRHS(Scope *S, ExprTy *LHS);
+  
   virtual void CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                        bool EnteringContext);
   virtual void CodeCompleteUsing(Scope *S);
diff --git a/lib/Sema/SemaCodeComplete.cpp b/lib/Sema/SemaCodeComplete.cpp
index 876aecb..339c46f 100644
--- a/lib/Sema/SemaCodeComplete.cpp
+++ b/lib/Sema/SemaCodeComplete.cpp
@@ -119,6 +119,13 @@
     /// nested-name-specifiers that would otherwise be filtered out.
     bool AllowNestedNameSpecifiers;
 
+    /// \brief If set, the type that we would prefer our resulting value
+    /// declarations to have.
+    ///
+    /// Closely matching the preferred type gives a boost to a result's 
+    /// priority.
+    CanQualType PreferredType;
+    
     /// \brief A list of shadow maps, which is used to model name hiding at
     /// different levels of, e.g., the inheritance hierarchy.
     std::list<ShadowMap> ShadowMaps;
@@ -147,6 +154,11 @@
     unsigned size() const { return Results.size(); }
     bool empty() const { return Results.empty(); }
     
+    /// \brief Specify the preferred type; it is stored in canonical form.
+    void setPreferredType(QualType T) { 
+      PreferredType = SemaRef.Context.getCanonicalType(T); 
+    }
+    
     /// \brief Specify whether nested-name-specifiers are allowed.
     void allowNestedNameSpecifiers(bool Allow = true) {
       AllowNestedNameSpecifiers = Allow;
@@ -565,6 +577,118 @@
   Results.push_back(R);
 }
 
+enum SimplifiedTypeClass {
+  STC_Arithmetic,
+  STC_Array,
+  STC_Block,
+  STC_Function,
+  STC_ObjectiveC,
+  STC_Other,
+  STC_Pointer,
+  STC_Record,
+  STC_Void
+};
+
+/// \brief A simplified classification of types used to determine whether two
+/// types are "similar enough" when adjusting priorities.
+static SimplifiedTypeClass getSimplifiedTypeClass(CanQualType T) {
+  switch (T->getTypeClass()) {
+  case Type::Builtin:
+    switch (cast<BuiltinType>(T)->getKind()) {
+    case BuiltinType::Void:
+      return STC_Void;
+        
+    case BuiltinType::NullPtr:
+      return STC_Pointer;
+        
+    case BuiltinType::Overload:
+    case BuiltinType::Dependent:
+    case BuiltinType::UndeducedAuto:
+      return STC_Other;
+        
+    case BuiltinType::ObjCId:
+    case BuiltinType::ObjCClass:
+    case BuiltinType::ObjCSel:
+      return STC_ObjectiveC;
+    // Remaining builtin kinds (presumably integer, floating-point and bool).
+    default:
+      return STC_Arithmetic;
+    }
+    return STC_Other; // Not reached: every inner case returns.
+      
+  case Type::Complex:
+    return STC_Arithmetic;
+    
+  case Type::Pointer:
+    return STC_Pointer;
+    
+  case Type::BlockPointer:
+    return STC_Block;
+  // A reference is classified by the type it refers to.
+  case Type::LValueReference:
+  case Type::RValueReference:
+    return getSimplifiedTypeClass(T->getAs<ReferenceType>()->getPointeeType());
+      
+  case Type::ConstantArray:
+  case Type::IncompleteArray:
+  case Type::VariableArray:
+  case Type::DependentSizedArray:
+    return STC_Array;
+      
+  case Type::DependentSizedExtVector:
+  case Type::Vector:
+  case Type::ExtVector:
+    return STC_Arithmetic;
+      
+  case Type::FunctionProto:
+  case Type::FunctionNoProto:
+    return STC_Function;
+      
+  case Type::Record:
+    return STC_Record;
+    
+  case Type::Enum:
+    return STC_Arithmetic;
+      
+  case Type::ObjCObject:
+  case Type::ObjCInterface:
+  case Type::ObjCObjectPointer:
+    return STC_ObjectiveC;
+  // Any type class not listed above (e.g. member pointers) is "other".
+  default:
+    return STC_Other;
+  }
+}
+ 
+/// \brief Get the type that a given expression will have if this declaration
+/// is used as an expression in its "typical" code-completion form.
+static QualType getDeclUsageType(ASTContext &C, NamedDecl *ND) {
+  ND = cast<NamedDecl>(ND->getUnderlyingDecl());
+  // Type declarations (including Objective-C interfaces) yield their type.
+  if (TypeDecl *Type = dyn_cast<TypeDecl>(ND))
+    return C.getTypeDeclType(Type);
+  if (ObjCInterfaceDecl *Iface = dyn_cast<ObjCInterfaceDecl>(ND))
+    return C.getObjCInterfaceType(Iface);
+  // Callables yield their result type, enumerators their enum's type.
+  QualType T;
+  if (FunctionDecl *Function = dyn_cast<FunctionDecl>(ND))
+    T = Function->getResultType();
+  else if (ObjCMethodDecl *Method = dyn_cast<ObjCMethodDecl>(ND))
+    T = Method->getResultType();
+  else if (FunctionTemplateDecl *FunTmpl = dyn_cast<FunctionTemplateDecl>(ND))
+    T = FunTmpl->getTemplatedDecl()->getResultType();
+  else if (EnumConstantDecl *Enumerator = dyn_cast<EnumConstantDecl>(ND))
+    T = C.getTypeDeclType(cast<EnumDecl>(Enumerator->getDeclContext()));
+  else if (ObjCPropertyDecl *Property = dyn_cast<ObjCPropertyDecl>(ND))
+    T = Property->getType();
+  else if (ValueDecl *Value = dyn_cast<ValueDecl>(ND))
+    T = Value->getType();
+  else
+    return QualType(); // No usable expression type: return a null type.
+  // Expressions never have reference type, so strip any reference.
+  return T.getNonReferenceType();
+}
+
 void ResultBuilder::AddResult(Result R, DeclContext *CurContext, 
                               NamedDecl *Hiding, bool InBaseClass = false) {
   if (R.Kind != Result::RK_Declaration) {
@@ -618,6 +742,21 @@
   if (InBaseClass)
     R.Priority += CCD_InBaseClass;
   
+  if (!PreferredType.isNull()) {
+    if (ValueDecl *Value = dyn_cast<ValueDecl>(R.Declaration)) {
+      CanQualType T = SemaRef.Context.getCanonicalType(
+                                     getDeclUsageType(SemaRef.Context, Value));
+      // Check for exactly-matching types (modulo qualifiers).
+      if (SemaRef.Context.hasSameUnqualifiedType(PreferredType, T))
+        R.Priority /= CCF_ExactTypeMatch;
+      // Check for nearly-matching types, based on classification of each.
+      else if ((getSimplifiedTypeClass(PreferredType)
+                                                == getSimplifiedTypeClass(T)) &&
+               !(PreferredType->isEnumeralType() && T->isEnumeralType()))
+        R.Priority /= CCF_SimilarTypeMatch;
+    }
+  }
+  
   // Insert this result into the set of results.
   Results.push_back(R);
 }
@@ -755,35 +894,6 @@
     isa<ObjCPropertyDecl>(ND);
 }
 
-/// \brief Get the type that a given expression will have if this declaration
-/// is used as an expression in its "typical" code-completion form.
-static QualType getDeclUsageType(ASTContext &C, NamedDecl *ND) {
-  ND = cast<NamedDecl>(ND->getUnderlyingDecl());
-  
-  if (TypeDecl *Type = dyn_cast<TypeDecl>(ND))
-    return C.getTypeDeclType(Type);
-  if (ObjCInterfaceDecl *Iface = dyn_cast<ObjCInterfaceDecl>(ND))
-    return C.getObjCInterfaceType(Iface);
-    
-  QualType T;
-  if (FunctionDecl *Function = dyn_cast<FunctionDecl>(ND))
-    T = Function->getResultType();
-  else if (ObjCMethodDecl *Method = dyn_cast<ObjCMethodDecl>(ND))
-    T = Method->getResultType();
-  else if (FunctionTemplateDecl *FunTmpl = dyn_cast<FunctionTemplateDecl>(ND))
-    T = FunTmpl->getTemplatedDecl()->getResultType();
-  else if (EnumConstantDecl *Enumerator = dyn_cast<EnumConstantDecl>(ND))
-    T = C.getTypeDeclType(cast<EnumDecl>(Enumerator->getDeclContext()));
-  else if (ObjCPropertyDecl *Property = dyn_cast<ObjCPropertyDecl>(ND))
-    T = Property->getType();
-  else if (ValueDecl *Value = dyn_cast<ValueDecl>(ND))
-    T = Value->getType();
-  else
-    return QualType();
-  
-  return T.getNonReferenceType();
-}
-
 static bool isObjCReceiverType(ASTContext &C, QualType T) {
   T = C.getCanonicalType(T);
   switch (T->getTypeClass()) {
@@ -2131,6 +2241,31 @@
   HandleCodeCompleteResults(this, CodeCompleter, Results.data(),Results.size());
 }
 
+/// \brief Perform code-completion in an expression context when we know what
+/// type we're looking for. Results whose type matches or resembles \p T are
+/// given higher priority; none are discarded because of a type mismatch.
+void Sema::CodeCompleteExpression(Scope *S, QualType T) {
+  ResultBuilder Results(*this);
+  
+  if (WantTypesInContext(CCC_Expression, getLangOptions()))
+    Results.setFilter(&ResultBuilder::IsOrdinaryName);
+  else
+    Results.setFilter(&ResultBuilder::IsOrdinaryNonTypeName);
+  // Strip references: an expression never has reference type.
+  Results.setPreferredType(T.getNonReferenceType());
+  CodeCompletionDeclConsumer Consumer(Results, CurContext);
+  LookupVisibleDecls(S, LookupOrdinaryName, Consumer);
+  
+  Results.EnterNewScope();
+  AddOrdinaryNameResults(CCC_Expression, S, *this, Results);
+  Results.ExitScope();
+  
+  if (CodeCompleter->includeMacros())
+    AddMacroResults(PP, Results);
+  HandleCodeCompleteResults(this, CodeCompleter, Results.data(),Results.size());
+}
+
+
 static void AddObjCProperties(ObjCContainerDecl *Container, 
                               bool AllowCategories,
                               DeclContext *CurContext,
@@ -2447,6 +2582,8 @@
     }
   }
   
+  QualType ParamType;
+  
   if (!CandidateSet.empty()) {
     // Sort the overload candidate set by placing the best overloads first.
     std::stable_sort(CandidateSet.begin(), CandidateSet.end(),
@@ -2459,14 +2596,94 @@
       if (Cand->Viable)
         Results.push_back(ResultCandidate(Cand->Function));
     }
+
+    // From the viable candidates, try to determine the type of this parameter.
+    for (unsigned I = 0, N = Results.size(); I != N; ++I) {
+      if (const FunctionType *FType = Results[I].getFunctionType())
+        if (const FunctionProtoType *Proto = dyn_cast<FunctionProtoType>(FType))
+          if (NumArgs < Proto->getNumArgs()) {
+            if (ParamType.isNull())
+              ParamType = Proto->getArgType(NumArgs);
+            else if (!Context.hasSameUnqualifiedType(
+                                            ParamType.getNonReferenceType(),
+                           Proto->getArgType(NumArgs).getNonReferenceType())) {
+              ParamType = QualType();
+              break;
+            }
+          }
+    }
+  } else {
+    // Try to determine the parameter type from the type of the expression
+    // being called.
+    QualType FunctionType = Fn->getType();
+    if (const PointerType *Ptr = FunctionType->getAs<PointerType>())
+      FunctionType = Ptr->getPointeeType();
+    else if (const BlockPointerType *BlockPtr
+                                    = FunctionType->getAs<BlockPointerType>())
+      FunctionType = BlockPtr->getPointeeType();
+    else if (const MemberPointerType *MemPtr
+                                    = FunctionType->getAs<MemberPointerType>())
+      FunctionType = MemPtr->getPointeeType();
+    
+    if (const FunctionProtoType *Proto
+                                  = FunctionType->getAs<FunctionProtoType>()) {
+      if (NumArgs < Proto->getNumArgs())
+        ParamType = Proto->getArgType(NumArgs);
+    }
   }
 
-  CodeCompleteOrdinaryName(S, CCC_Expression);
+  if (ParamType.isNull())
+    CodeCompleteOrdinaryName(S, CCC_Expression);
+  else
+    CodeCompleteExpression(S, ParamType);
+  
   if (!Results.empty())
     CodeCompleter->ProcessOverloadCandidates(*this, NumArgs, Results.data(), 
                                              Results.size());
 }
 
+/// \brief Code completion at the start of the initializer of declaration \p D.
+/// If \p D is a value declaration, results are ranked against its type.
+void Sema::CodeCompleteInitializer(Scope *S, DeclPtrTy D) {
+  ValueDecl *VD = dyn_cast_or_null<ValueDecl>(D.getAs<Decl>());
+  if (!VD) {
+    CodeCompleteOrdinaryName(S, CCC_Expression);
+    return;
+  }
+  
+  CodeCompleteExpression(S, VD->getType());
+}
+
+/// \brief Code completion for the operand of a return statement. Results are
+/// ranked against the return type of the enclosing block, function, or
+/// Objective-C method; otherwise plain ordinary-name completion is used.
+void Sema::CodeCompleteReturn(Scope *S) {
+  QualType ResultType;
+  if (isa<BlockDecl>(CurContext)) {
+    if (BlockScopeInfo *BSI = getCurBlock())
+      ResultType = BSI->ReturnType;
+  } else if (FunctionDecl *Function = dyn_cast<FunctionDecl>(CurContext))
+    ResultType = Function->getResultType();
+  else if (ObjCMethodDecl *Method = dyn_cast<ObjCMethodDecl>(CurContext))
+    ResultType = Method->getResultType();
+  
+  if (ResultType.isNull())
+    CodeCompleteOrdinaryName(S, CCC_Expression);
+  else
+    CodeCompleteExpression(S, ResultType);
+}
+
+/// \brief Code completion for the right-hand side of an assignment; results
+/// are ranked against the type of \p LHS when the left-hand side is available.
+void Sema::CodeCompleteAssignmentRHS(Scope *S, ExprTy *LHS) {
+  // NOTE(review): if compound assignments (e.g. "p += ") also reach this
+  // hook, the LHS type only approximates the expected RHS type.
+  if (LHS)
+    CodeCompleteExpression(S, static_cast<Expr *>(LHS)->getType());
+  else
+    CodeCompleteOrdinaryName(S, CCC_Expression);
+}
+
 void Sema::CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                    bool EnteringContext) {
   if (!SS.getScopeRep() || !CodeCompleter)