Add some plumbing to rewrite message expressions (still under construction).



git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@43274 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/Driver/RewriteTest.cpp b/Driver/RewriteTest.cpp
index 7cf1342..42b5940 100644
--- a/Driver/RewriteTest.cpp
+++ b/Driver/RewriteTest.cpp
@@ -16,6 +16,7 @@
 #include "clang/AST/AST.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/Basic/SourceManager.h"
+#include "clang/Basic/IdentifierTable.h"
 using namespace clang;
 
 
@@ -28,12 +29,17 @@
     SourceLocation LastIncLoc;
     llvm::SmallVector<ObjcImplementationDecl *, 8> ClassImplementation;
     llvm::SmallVector<ObjcCategoryImplDecl *, 8> CategoryImplementation;
+    
+    FunctionDecl *MsgSendFunctionDecl;
+    FunctionDecl *GetClassFunctionDecl;
+    
     static const int OBJC_ABI_VERSION =7 ;
   public:
     void Initialize(ASTContext &context, unsigned mainFileID) {
       Context = &context;
       SM = &Context->SourceMgr;
       MainFileID = mainFileID;
+      MsgSendFunctionDecl = GetClassFunctionDecl = 0; // Cached when seen.
       Rewrite.setSourceMgr(Context->SourceMgr);
     }
     
@@ -45,6 +51,7 @@
     void RewriteFunctionBody(Stmt *S);
     void RewriteAtEncode(ObjCEncodeExpr *Exp);
     void RewriteForwardClassDecl(ObjcClassDecl *Dcl);
+    void RewriteMessageExpr(ObjCMessageExpr *Exp);
     
     void WriteObjcClassMetaData(ObjcImplementationDecl *IDecl);
     void WriteObjcMetaData();
@@ -65,6 +72,13 @@
   // If this is for a builtin, ignore it.
   if (Loc.isInvalid()) return;
 
+  // Look for built-in declarations that we need to refer during the rewrite.
+  if (FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {
+    if (FD->getIdentifier() == &Context->Idents.get("objc_msgSend"))
+      MsgSendFunctionDecl = FD;
+    else if (FD->getIdentifier() == &Context->Idents.get("objc_getClass"))
+      GetClassFunctionDecl = FD;
+  }
   if (SM->getDecomposedFileLoc(Loc).first == MainFileID)
     return HandleDeclInMainFile(D);
 
@@ -113,15 +127,19 @@
 
 
 void RewriteTest::RewriteFunctionBody(Stmt *S) {
-  // Handle specific things.
-  if (ObjCEncodeExpr *AtEncode = dyn_cast<ObjCEncodeExpr>(S))
-    return RewriteAtEncode(AtEncode);
-  
-  // Otherwise, just rewrite all children.
+  // Rewrite all children first, so nested expressions are handled bottom-up.
   for (Stmt::child_iterator CI = S->child_begin(), E = S->child_end();
        CI != E; ++CI)
     if (*CI)
       RewriteFunctionBody(*CI);
+
+  // Then rewrite this node itself if it is one of the kinds we handle.
+  if (ObjCEncodeExpr *AtEncode = dyn_cast<ObjCEncodeExpr>(S))
+    return RewriteAtEncode(AtEncode);
+
+  if (ObjCMessageExpr *MessExpr = dyn_cast<ObjCMessageExpr>(S))
+    return RewriteMessageExpr(MessExpr);
+  // Any other statement kind needs no rewriting of its own.
 }
  
 void RewriteTest::RewriteAtEncode(ObjCEncodeExpr *Exp) {
@@ -133,6 +151,33 @@
   delete Replacement;
 }
 
+
+void RewriteTest::RewriteMessageExpr(ObjCMessageExpr *Exp) {
+  assert(MsgSendFunctionDecl && "Can't find objc_msgSend() decl");
+  // Rewrite the message expression Exp as a call to objc_msgSend().
+  // NOTE(review): the CallExpr appears to get no arguments yet (0, 0).
+  
+  // Synthesize a call to objc_msgSend().
+  
+  // Get the type, we will need to reference it in a couple spots.
+  QualType msgSendType = MsgSendFunctionDecl->getType();
+  
+  // Create a reference to the objc_msgSend() declaration.
+  DeclRefExpr *DRE = new DeclRefExpr(MsgSendFunctionDecl, msgSendType,
+                                     SourceLocation());
+                                     
+  // Now, we cast the reference to a pointer to the objc_msgSend type.
+  QualType pToFunc = Context->getPointerType(msgSendType);                                  
+  ImplicitCastExpr *ICE = new ImplicitCastExpr(pToFunc, DRE);
+  
+  const FunctionType *FT = msgSendType->getAsFunctionType();
+  CallExpr *CE = new CallExpr(ICE, 0, 0, FT->getResultType(), 
+                              SourceLocation());
+  Rewrite.ReplaceStmt(Exp, CE);
+  // NOTE(review): CE (and DRE/ICE) are never freed here, unlike the
+  // replacement in RewriteAtEncode; confirm who owns them.
+}
+
 void RewriteTest::RewriteForwardClassDecl(ObjcClassDecl *ClassDecl) {
   int numDecls = ClassDecl->getNumForwardDecls();
   ObjcInterfaceDecl **ForwardDecls = ClassDecl->getForwardDecls();