Add a new expression node, CXXOperatorCallExpr, which expresses a
function call created in response to the use of operator syntax that
resolves to an overloaded operator in C++, e.g., "str1 + str2" that
resolves to "std::operator+(str1, str2)". We now build a
CXXOperatorCallExpr in C++ when we pick an overloaded operator (but
only for binary operators, which is where we actually implement
overloading).

I decided *not* to refactor the current CallExpr to make it abstract
(with FunctionCallExpr and CXXOperatorCallExpr as derived
classes). Doing so would allow us to make CXXOperatorCallExpr a little
bit smaller, at the cost of making the argument and callee accessors
virtual. We won't know if this is going to be a win until we can parse
lots of C++ code to determine how much memory we'll save by making
this change vs. the performance penalty due to the extra virtual
calls.



git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@59306 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/AST/Expr.cpp b/lib/AST/Expr.cpp
index 249bebc..871dc4b 100644
--- a/lib/AST/Expr.cpp
+++ b/lib/AST/Expr.cpp
@@ -101,6 +101,15 @@
 // Postfix Operators.
 //===----------------------------------------------------------------------===//
 
+CallExpr::CallExpr(StmtClass SC, Expr *fn, Expr **args, unsigned numargs,
+                   QualType t, SourceLocation rparenloc)
+  : Expr(SC, t), NumArgs(numargs) { // SC: statement class of the new node
+  RParenLoc = rparenloc;
+  SubExprs = new Stmt*[numargs+1];  // callee at FN, args from ARGS_START
+  SubExprs[FN] = fn;
+  for (unsigned ArgNo = 0; ArgNo < numargs; ++ArgNo)
+    SubExprs[ARGS_START + ArgNo] = args[ArgNo];
+}
 
 CallExpr::CallExpr(Expr *fn, Expr **args, unsigned numargs, QualType t,
                    SourceLocation rparenloc)
@@ -285,6 +294,7 @@
     return getType().isVolatileQualified();
 
   case CallExprClass:
+  case CXXOperatorCallExprClass:
     // TODO: check attributes for pure/const.   "void foo() { strlen("bar"); }"
     // should warn.
     return true;
@@ -410,7 +420,8 @@
     //   An assignment expression [...] is not an lvalue.
     return LV_InvalidExpression;
   }
-  case CallExprClass: {
+  case CallExprClass: 
+  case CXXOperatorCallExprClass: {
     // C++ [expr.call]p10:
     //   A function call is an lvalue if and only if the result type
     //   is a reference.
@@ -586,7 +597,8 @@
   case CXXBoolLiteralExprClass:
   case AddrLabelExprClass:
     return true;
-  case CallExprClass: {
+  case CallExprClass: 
+  case CXXOperatorCallExprClass: {
     const CallExpr *CE = cast<CallExpr>(this);
 
     // Allow any constant foldable calls to builtins.
@@ -777,7 +789,8 @@
                                     T1.getUnqualifiedType());
     break;
   }
-  case CallExprClass: {
+  case CallExprClass: 
+  case CXXOperatorCallExprClass: {
     const CallExpr *CE = cast<CallExpr>(this);
     Result.zextOrTrunc(static_cast<uint32_t>(Ctx.getTypeSize(getType())));
     
diff --git a/lib/AST/ExprCXX.cpp b/lib/AST/ExprCXX.cpp
index ff97e68..22f30f4 100644
--- a/lib/AST/ExprCXX.cpp
+++ b/lib/AST/ExprCXX.cpp
@@ -11,6 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "clang/Basic/IdentifierTable.h"
+#include "clang/AST/DeclCXX.h"
 #include "clang/AST/ExprCXX.h"
 using namespace clang;
 
@@ -75,6 +77,52 @@
   return child_iterator();
 }
 
+OverloadedOperatorKind CXXOperatorCallExpr::getOperator() const {
+  // Simple calls (e.g. func()) have their callee implicitly cast to a
+  // pointer to function, so look through that cast for the DeclRefExpr.
+  // FIXME: deal with more complex calls (e.g. (func)(), (*func)()).
+  const ImplicitCastExpr *ICE = dyn_cast<ImplicitCastExpr>(getCallee());
+  const DeclRefExpr *DRE = ICE ? dyn_cast<DeclRefExpr>(ICE->getSubExpr()) : 0;
+  if (!DRE)
+    return OO_None;
+
+  // The operator kind is read from the name of either a single FunctionDecl
+  // or an OverloadedFunctionDecl; a missing identifier yields OO_None.
+  IdentifierInfo *II = 0;
+  if (const FunctionDecl *FDecl = dyn_cast<FunctionDecl>(DRE->getDecl()))
+    II = FDecl->getIdentifier();
+  else if (const OverloadedFunctionDecl *Ovl
+             = dyn_cast<OverloadedFunctionDecl>(DRE->getDecl()))
+    II = Ovl->getIdentifier();
+
+  return II ? II->getOverloadedOperatorID() : OO_None;
+}
+
+// Returns the range covered by the operator syntax. Prefix/unary forms begin
+// at the operator; postfix, binary, call and subscript forms at argument 0.
+SourceRange CXXOperatorCallExpr::getSourceRange() const {
+  OverloadedOperatorKind Kind = getOperator();
+  if (Kind == OO_PlusPlus || Kind == OO_MinusMinus) {
+    if (getNumArgs() == 1) // Prefix, e.g. "++x".
+      return SourceRange(getOperatorLoc(),
+                         getArg(0)->getSourceRange().getEnd());
+    // Postfix, e.g. "x++": begin where the operand begins; starting at its
+    // end would leave most of the operand outside the range.
+    return SourceRange(getArg(0)->getSourceRange().getBegin(),
+                       getOperatorLoc());
+  }
+
+  if (Kind == OO_Call || Kind == OO_Subscript) // "x(...)" and "x[...]".
+    return SourceRange(getArg(0)->getSourceRange().getBegin(), getRParenLoc());
+  if (getNumArgs() == 1) // Other unary operators, e.g. "-x".
+    return SourceRange(getOperatorLoc(), getArg(0)->getSourceRange().getEnd());
+  if (getNumArgs() == 2) // Binary operators, e.g. "x + y".
+    return SourceRange(getArg(0)->getSourceRange().getBegin(),
+                       getArg(1)->getSourceRange().getEnd());
+  // Any other argument count: no meaningful range.
+  return SourceRange();
+}
+
 //===----------------------------------------------------------------------===//
 //  Named casts
 //===----------------------------------------------------------------------===//
diff --git a/lib/AST/StmtPrinter.cpp b/lib/AST/StmtPrinter.cpp
index bbd53d0..cf6d4c0 100644
--- a/lib/AST/StmtPrinter.cpp
+++ b/lib/AST/StmtPrinter.cpp
@@ -811,6 +811,49 @@
 }
 
 // C++
+// Prints an overloaded operator call using the operator's own syntax.
+void StmtPrinter::VisitCXXOperatorCallExpr(CXXOperatorCallExpr *Node) {
+  static const char *const OpStrings[NUM_OVERLOADED_OPERATORS] = {
+    "",
+#define OVERLOADED_OPERATOR(Name,Spelling,Token,Unary,Binary,MemberOnly) \
+    Spelling,
+#include "clang/Basic/OperatorKinds.def"
+  }; // Built once; indexed by OverloadedOperatorKind, "" first (OO_None).
+  OverloadedOperatorKind Kind = Node->getOperator();
+  if (Kind == OO_PlusPlus || Kind == OO_MinusMinus) {
+    if (Node->getNumArgs() == 1)   // Prefix form: "++ x".
+      OS << OpStrings[Kind] << ' ';
+    PrintExpr(Node->getArg(0));
+    if (Node->getNumArgs() != 1)   // Postfix form: "x ++".
+      OS << ' ' << OpStrings[Kind];
+  } else if (Kind == OO_Call) {
+    // Default args are trailing; stop at the first so no dangling ", ".
+    PrintExpr(Node->getArg(0));
+    OS << '(';
+    for (unsigned ArgIdx = 1; ArgIdx < Node->getNumArgs(); ++ArgIdx) {
+      if (isa<CXXDefaultArgExpr>(Node->getArg(ArgIdx)))
+        break;
+      if (ArgIdx > 1)
+        OS << ", ";
+      PrintExpr(Node->getArg(ArgIdx));
+    }
+    OS << ')';
+  } else if (Kind == OO_Subscript) {
+    PrintExpr(Node->getArg(0));
+    OS << '[';
+    PrintExpr(Node->getArg(1));
+    OS << ']';
+  } else if (Node->getNumArgs() == 1) {
+    OS << OpStrings[Kind] << ' ';
+    PrintExpr(Node->getArg(0));
+  } else if (Node->getNumArgs() == 2) {
+    PrintExpr(Node->getArg(0));
+    OS << ' ' << OpStrings[Kind] << ' ';
+    PrintExpr(Node->getArg(1));
+  } else {
+    assert(false && "unknown overloaded operator");
+  }
+}
 
 void StmtPrinter::VisitCXXNamedCastExpr(CXXNamedCastExpr *Node) {
   OS << Node->getCastName() << '<';
diff --git a/lib/AST/StmtSerialization.cpp b/lib/AST/StmtSerialization.cpp
index 667f597..268922e 100644
--- a/lib/AST/StmtSerialization.cpp
+++ b/lib/AST/StmtSerialization.cpp
@@ -53,7 +53,7 @@
       return BreakStmt::CreateImpl(D, C);
      
     case CallExprClass:
-      return CallExpr::CreateImpl(D, C);
+      return CallExpr::CreateImpl(D, C, CallExprClass);
       
     case CaseStmtClass:
       return CaseStmt::CreateImpl(D, C);
@@ -198,6 +198,9 @@
     //    C++
     //==--------------------------------------==//
       
+    case CXXOperatorCallExprClass:
+      return CXXOperatorCallExpr::CreateImpl(D, C, CXXOperatorCallExprClass);
+
     case CXXDefaultArgExprClass:
       return CXXDefaultArgExpr::CreateImpl(D, C);      
 
@@ -361,14 +364,14 @@
   S.BatchEmitOwnedPtrs(NumArgs+1, SubExprs);  
 }
 
-CallExpr* CallExpr::CreateImpl(Deserializer& D, ASTContext& C) {
+CallExpr* CallExpr::CreateImpl(Deserializer& D, ASTContext& C, StmtClass SC) { // SC: class stamped on the rebuilt node
   QualType t = QualType::ReadVal(D);
   SourceLocation L = SourceLocation::ReadVal(D);
   unsigned NumArgs = D.ReadInt();
   Stmt** SubExprs = new Stmt*[NumArgs+1];
   D.BatchReadOwnedPtrs(NumArgs+1, SubExprs, C);
 
-  return new CallExpr(SubExprs,NumArgs,t,L);  
+  return new CallExpr(SC, SubExprs,NumArgs,t,L);  // NOTE(review): allocates a plain CallExpr even when SC names a subclass (e.g. CXXOperatorCallExprClass); confirm this is intended.
 }
 
 void CaseStmt::EmitImpl(Serializer& S) const {