Rewrite method definitions in category implementations, in addition to those in class implementations.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@44062 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/Driver/RewriteTest.cpp b/Driver/RewriteTest.cpp
index ac1b6d3..09d58f2 100644
--- a/Driver/RewriteTest.cpp
+++ b/Driver/RewriteTest.cpp
@@ -86,7 +86,7 @@
     void RewriteTabs();
     void RewriteForwardClassDecl(ObjcClassDecl *Dcl);
     void RewriteInterfaceDecl(ObjcInterfaceDecl *Dcl);
-    void RewriteImplementationDecl(ObjcImplementationDecl *Dcl);
+    void RewriteImplementationDecl(NamedDecl *Dcl);
     void RewriteObjcMethodDecl(ObjcMethodDecl *MDecl, std::string &ResultStr);
     void RewriteCategoryDecl(ObjcCategoryDecl *Dcl);
     void RewriteProtocolDecl(ObjcProtocolDecl *Dcl);
@@ -382,7 +382,6 @@
     ResultStr += "#include <Objc/objc.h>\n";
     includeObjc = true;
   }
-  
   ResultStr += "\nstatic ";
   ResultStr += OMD->getResultType().getAsString();
   ResultStr += "\n_";
@@ -422,7 +421,8 @@
   if (OMD->isInstance()) {
     QualType selfTy = Context->getObjcInterfaceType(OMD->getClassInterface());
     selfTy = Context->getPointerType(selfTy);
-    ResultStr += "struct ";
+    if (ObjcSynthesizedStructs.count(OMD->getClassInterface()))
+      ResultStr += "struct ";
     ResultStr += selfTy.getAsString();
   }
   else
@@ -443,13 +443,25 @@
   ResultStr += ")";
   
 }
-void RewriteTest::RewriteImplementationDecl(ObjcImplementationDecl *OID) {
+void RewriteTest::RewriteImplementationDecl(NamedDecl *OID) {
+  ObjcImplementationDecl *IMD = dyn_cast<ObjcImplementationDecl>(OID);
+  ObjcCategoryImplDecl *CID = dyn_cast<ObjcCategoryImplDecl>(OID);
   
-  Rewrite.InsertText(OID->getLocStart(), "// ", 3);
+  if (IMD)
+    Rewrite.InsertText(IMD->getLocStart(), "// ", 3);
+  else
+    Rewrite.InsertText(CID->getLocStart(), "// ", 3);
   
-  for (int i = 0; i < OID->getNumInstanceMethods(); i++) {
+  int numMethods = IMD ? IMD->getNumInstanceMethods() 
+                       : CID->getNumInstanceMethods();
+  
+  for (int i = 0; i < numMethods; i++) {
     std::string ResultStr;
-    ObjcMethodDecl *OMD = OID->getInstanceMethods()[i];
+    ObjcMethodDecl *OMD;
+    if (IMD)
+      OMD = IMD->getInstanceMethods()[i];
+    else
+      OMD = CID->getInstanceMethods()[i];
     RewriteObjcMethodDecl(OMD, ResultStr);
     SourceLocation LocStart = OMD->getLocStart();
     SourceLocation LocEnd = OMD->getBody()->getLocStart();
@@ -460,9 +472,14 @@
                         ResultStr.c_str(), ResultStr.size());
   }
   
-  for (int i = 0; i < OID->getNumClassMethods(); i++) {
+  numMethods = IMD ? IMD->getNumClassMethods() : CID->getNumClassMethods();
+  for (int i = 0; i < numMethods; i++) {
     std::string ResultStr;
-    ObjcMethodDecl *OMD = OID->getClassMethods()[i];
+    ObjcMethodDecl *OMD;
+    if (IMD)
+      OMD = IMD->getClassMethods()[i];
+    else
+      OMD = CID->getClassMethods()[i];
     RewriteObjcMethodDecl(OMD, ResultStr);
     SourceLocation LocStart = OMD->getLocStart();
     SourceLocation LocEnd = OMD->getBody()->getLocStart();
@@ -472,7 +489,10 @@
     Rewrite.ReplaceText(LocStart, endBuf-startBuf,
                         ResultStr.c_str(), ResultStr.size());    
   }
-  Rewrite.InsertText(OID->getLocEnd(), "// ", 3);
+  if (IMD)
+    Rewrite.InsertText(IMD->getLocEnd(), "// ", 3);
+  else
+   Rewrite.InsertText(CID->getLocEnd(), "// ", 3); 
 }
 
 void RewriteTest::RewriteInterfaceDecl(ObjcInterfaceDecl *ClassDecl) {
@@ -1723,6 +1743,9 @@
   for (int i = 0; i < ClsDefCount; i++)
     RewriteImplementationDecl(ClassImplementation[i]);
   
+  for (int i = 0; i < CatDefCount; i++)
+    RewriteImplementationDecl(CategoryImplementation[i]);
+  
   // TODO: This is temporary until we decide how to access objc types in a
   // c program
   Result += "#include <Objc/objc.h>\n";