- Add location info to category/protocol AST's
- Rewrite categories.



git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@43501 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/AST/Decl.cpp b/AST/Decl.cpp
index 106537f..5187982 100644
--- a/AST/Decl.cpp
+++ b/AST/Decl.cpp
@@ -345,7 +345,7 @@
                                   unsigned numInsMembers,
                                   ObjcMethodDecl **clsMethods,
                                   unsigned numClsMembers,
-                                  SourceLocation AtEndLoc) {
+                                  SourceLocation endLoc) {
   NumInstanceMethods = numInsMembers;
   if (numInsMembers) {
     InstanceMethods = new ObjcMethodDecl*[numInsMembers];
@@ -356,6 +356,7 @@
     ClassMethods = new ObjcMethodDecl*[numClsMembers];
     memcpy(ClassMethods, clsMethods, numClsMembers*sizeof(ObjcMethodDecl*));
   }
+  AtEndLoc = endLoc;
 }
 
 /// addMethods - Insert instance and methods declarations into
@@ -365,7 +366,7 @@
                                   unsigned numInsMembers,
                                   ObjcMethodDecl **clsMethods,
                                   unsigned numClsMembers,
-                                  SourceLocation AtEndLoc) {
+                                  SourceLocation endLoc) {
   NumInstanceMethods = numInsMembers;
   if (numInsMembers) {
     InstanceMethods = new ObjcMethodDecl*[numInsMembers];
@@ -376,6 +377,7 @@
     ClassMethods = new ObjcMethodDecl*[numClsMembers];
     memcpy(ClassMethods, clsMethods, numClsMembers*sizeof(ObjcMethodDecl*));
   }
+  AtEndLoc = endLoc;
 }
 
 /// addMethods - Insert instance and methods declarations into
diff --git a/Driver/RewriteTest.cpp b/Driver/RewriteTest.cpp
index be09eed..8904210 100644
--- a/Driver/RewriteTest.cpp
+++ b/Driver/RewriteTest.cpp
@@ -60,6 +60,8 @@
     void RewriteTabs();
     void RewriteForwardClassDecl(ObjcClassDecl *Dcl);
     void RewriteInterfaceDecl(ObjcInterfaceDecl *Dcl);
+    void RewriteCategoryDecl(ObjcCategoryDecl *Dcl);
+    void RewriteMethods(int nMethods, ObjcMethodDecl **Methods);
     
     // Expression Rewriting.
     Stmt *RewriteFunctionBody(Stmt *S);
@@ -122,6 +124,8 @@
       SelGetUidFunctionDecl = FD;
   } else if (ObjcInterfaceDecl *MD = dyn_cast<ObjcInterfaceDecl>(D)) {
     RewriteInterfaceDecl(MD);
+  } else if (ObjcCategoryDecl *CD = dyn_cast<ObjcCategoryDecl>(D)) {
+    RewriteCategoryDecl(CD);
   }
   // If we have a decl in the main file, see if we should rewrite it.
   if (SM->getDecomposedFileLoc(Loc).first == MainFileID)
@@ -264,6 +268,31 @@
                       typedefString.c_str(), typedefString.size());
 }
 
+void RewriteTest::RewriteMethods(int nMethods, ObjcMethodDecl **Methods) {
+  for (int i = 0; i < nMethods; i++) {
+    ObjcMethodDecl *Method = Methods[i];
+    SourceLocation Loc = Method->getLocStart();
+
+    Rewrite.ReplaceText(Loc, 0, "// ", 3);
+    
+    // FIXME: handle methods that are declared across multiple lines.
+  }
+}
+
+void RewriteTest::RewriteCategoryDecl(ObjcCategoryDecl *CatDecl) {
+  SourceLocation LocStart = CatDecl->getLocStart();
+  
+  // FIXME: handle category headers that are declared across multiple lines.
+  Rewrite.ReplaceText(LocStart, 0, "// ", 3);
+  
+  RewriteMethods(CatDecl->getNumInstanceMethods(),
+                 CatDecl->getInstanceMethods());
+  RewriteMethods(CatDecl->getNumClassMethods(),
+                 CatDecl->getClassMethods());
+  // Lastly, comment out the @end.
+  Rewrite.ReplaceText(CatDecl->getAtEndLoc(), 0, "// ", 3);
+}
+
 void RewriteTest::RewriteInterfaceDecl(ObjcInterfaceDecl *ClassDecl) {
 
   SourceLocation LocStart = ClassDecl->getLocStart();
@@ -280,28 +309,11 @@
   Rewrite.ReplaceText(LocStart, endBuf-startBuf, 
                       ResultStr.c_str(), ResultStr.size());
   
-  int nInstanceMethods = ClassDecl->getNumInstanceMethods();
-  ObjcMethodDecl **instanceMethods = ClassDecl->getInstanceMethods();
+  RewriteMethods(ClassDecl->getNumInstanceMethods(),
+                 ClassDecl->getInstanceMethods());
+  RewriteMethods(ClassDecl->getNumClassMethods(),
+                 ClassDecl->getClassMethods());
   
-  for (int i = 0; i < nInstanceMethods; i++) {
-    ObjcMethodDecl *instanceMethod = instanceMethods[i];
-    SourceLocation Loc = instanceMethod->getLocStart();
-
-    Rewrite.ReplaceText(Loc, 0, "// ", 3);
-    
-    // FIXME: handle methods that are declared across multiple lines.
-  }
-  int nClassMethods = ClassDecl->getNumClassMethods();
-  ObjcMethodDecl **classMethods = ClassDecl->getClassMethods();
-  
-  for (int i = 0; i < nClassMethods; i++) {
-    ObjcMethodDecl *classMethod = classMethods[i];
-    SourceLocation Loc = classMethod->getLocStart();
-
-    Rewrite.ReplaceText(Loc, 0, "// ", 3);
-    
-    // FIXME: handle methods that are declared across multiple lines.
-  }
   // Lastly, comment out the @end.
   Rewrite.ReplaceText(ClassDecl->getAtEndLoc(), 0, "// ", 3);
 }
diff --git a/Parse/ParseObjc.cpp b/Parse/ParseObjc.cpp
index b2323a2..ea48704 100644
--- a/Parse/ParseObjc.cpp
+++ b/Parse/ParseObjc.cpp
@@ -157,7 +157,8 @@
     
     DeclTy *CategoryType = Actions.ActOnStartCategoryInterface(atLoc, 
                                      nameId, nameLoc, categoryId, categoryLoc,
-                                     &ProtocolRefs[0], ProtocolRefs.size());
+                                     &ProtocolRefs[0], ProtocolRefs.size(),
+                                     endProtoLoc);
     
     ParseObjCInterfaceDeclList(CategoryType, tok::objc_not_keyword);
 
diff --git a/Sema/Sema.h b/Sema/Sema.h
index 2d82ab1..9db00fb 100644
--- a/Sema/Sema.h
+++ b/Sema/Sema.h
@@ -489,7 +489,8 @@
 		    SourceLocation AtInterfaceLoc,
                     IdentifierInfo *ClassName, SourceLocation ClassLoc,
                     IdentifierInfo *CategoryName, SourceLocation CategoryLoc,
-                    IdentifierInfo **ProtoRefNames, unsigned NumProtoRefs);
+                    IdentifierInfo **ProtoRefNames, unsigned NumProtoRefs,
+                    SourceLocation EndProtoLoc);
   
   virtual DeclTy *ActOnStartClassImplementation(
 		    SourceLocation AtClassImplLoc,
diff --git a/Sema/SemaDecl.cpp b/Sema/SemaDecl.cpp
index 8d3887a..e125ce5 100644
--- a/Sema/SemaDecl.cpp
+++ b/Sema/SemaDecl.cpp
@@ -1081,16 +1081,18 @@
     ObjcProtocols[ProtocolName] = PDecl;
   }    
   
-  /// Check then save referenced protocols
-  for (unsigned int i = 0; i != NumProtoRefs; i++) {
-    ObjcProtocolDecl* RefPDecl = ObjcProtocols[ProtoRefNames[i]];
-    if (!RefPDecl || RefPDecl->isForwardDecl())
-      Diag(ProtocolLoc, diag::err_undef_protocolref,
-           ProtoRefNames[i]->getName(),
-           ProtocolName->getName());
-    PDecl->setReferencedProtocols((int)i, RefPDecl);
+  if (NumProtoRefs) {
+    /// Check then save referenced protocols
+    for (unsigned int i = 0; i != NumProtoRefs; i++) {
+      ObjcProtocolDecl* RefPDecl = ObjcProtocols[ProtoRefNames[i]];
+      if (!RefPDecl || RefPDecl->isForwardDecl())
+        Diag(ProtocolLoc, diag::err_undef_protocolref,
+             ProtoRefNames[i]->getName(),
+             ProtocolName->getName());
+      PDecl->setReferencedProtocols((int)i, RefPDecl);
+    }
+    PDecl->setLocEnd(EndProtoLoc);
   }
-
   return PDecl;
 }
 
@@ -1137,7 +1139,8 @@
                       SourceLocation AtInterfaceLoc,
                       IdentifierInfo *ClassName, SourceLocation ClassLoc,
                       IdentifierInfo *CategoryName, SourceLocation CategoryLoc,
-                      IdentifierInfo **ProtoRefNames, unsigned NumProtoRefs) {
+                      IdentifierInfo **ProtoRefNames, unsigned NumProtoRefs,
+                      SourceLocation EndProtoLoc) {
   ObjcInterfaceDecl *IDecl = getObjCInterfaceDecl(ClassName);
   
   /// Check that class of this category is already completely declared.
@@ -1161,17 +1164,19 @@
   if (!CDeclChain)
     CDecl->insertNextClassCategory();
 
-  /// Check then save referenced protocols
-  for (unsigned int i = 0; i != NumProtoRefs; i++) {
-    ObjcProtocolDecl* RefPDecl = ObjcProtocols[ProtoRefNames[i]];
-    if (!RefPDecl || RefPDecl->isForwardDecl()) {
-      Diag(CategoryLoc, diag::err_undef_protocolref,
-           ProtoRefNames[i]->getName(),
-           CategoryName->getName());
+  if (NumProtoRefs) {
+    /// Check then save referenced protocols
+    for (unsigned int i = 0; i != NumProtoRefs; i++) {
+      ObjcProtocolDecl* RefPDecl = ObjcProtocols[ProtoRefNames[i]];
+      if (!RefPDecl || RefPDecl->isForwardDecl()) {
+        Diag(CategoryLoc, diag::err_undef_protocolref,
+             ProtoRefNames[i]->getName(),
+             CategoryName->getName());
+      }
+      CDecl->setCatReferencedProtocols((int)i, RefPDecl);
     }
-    CDecl->setCatReferencedProtocols((int)i, RefPDecl);
+    CDecl->setLocEnd(EndProtoLoc);
   }
-  
   return CDecl;
 }
 
diff --git a/include/clang/AST/DeclObjC.h b/include/clang/AST/DeclObjC.h
index dd6518c..9cc9e5a 100644
--- a/include/clang/AST/DeclObjC.h
+++ b/include/clang/AST/DeclObjC.h
@@ -307,6 +307,9 @@
   int NumClassMethods;  // -1 if not defined
 
   bool isForwardProtoDecl; // declared with @protocol.
+  
+  SourceLocation EndLoc; // marks the '>' or identifier.
+  SourceLocation AtEndLoc; // marks the end of the entire interface.
 public:
   ObjcProtocolDecl(SourceLocation L, unsigned numRefProtos,
                    IdentifierInfo *Id, bool FD = false)
@@ -348,6 +351,14 @@
   bool isForwardDecl() const { return isForwardProtoDecl; }
   void setForwardDecl(bool val) { isForwardProtoDecl = val; }
 
+  // Location information, modeled after the Stmt API. 
+  SourceLocation getLocStart() const { return getLocation(); } // '@'protocol
+  SourceLocation getLocEnd() const { return EndLoc; }
+  void setLocEnd(SourceLocation LE) { EndLoc = LE; };
+  
+  // We also need to record the @end location.
+  SourceLocation getAtEndLoc() const { return AtEndLoc; }
+
   static bool classof(const Decl *D) { return D->getKind() == ObjcProtocol; }
   static bool classof(const ObjcProtocolDecl *D) { return true; }
 };
@@ -459,6 +470,8 @@
   /// Next category belonging to this class
   ObjcCategoryDecl *NextClassCategory;
   
+  SourceLocation EndLoc; // marks the '>' or identifier.
+  SourceLocation AtEndLoc; // marks the end of the entire interface.
 public:
   ObjcCategoryDecl(SourceLocation L, unsigned numRefProtocol,IdentifierInfo *Id)
     : NamedDecl(ObjcCategory, L, Id),
@@ -502,6 +515,13 @@
     NextClassCategory = ClassInterface->getCategoryList();
     ClassInterface->setCategoryList(this);
   }
+  // Location information, modeled after the Stmt API. 
+  SourceLocation getLocStart() const { return getLocation(); } // '@'interface
+  SourceLocation getLocEnd() const { return EndLoc; }
+  void setLocEnd(SourceLocation LE) { EndLoc = LE; };
+  
+  // We also need to record the @end location.
+  SourceLocation getAtEndLoc() const { return AtEndLoc; }
   
   static bool classof(const Decl *D) { return D->getKind() == ObjcCategory; }
   static bool classof(const ObjcCategoryDecl *D) { return true; }
diff --git a/include/clang/Parse/Action.h b/include/clang/Parse/Action.h
index 1ea1643..2097601 100644
--- a/include/clang/Parse/Action.h
+++ b/include/clang/Parse/Action.h
@@ -499,7 +499,8 @@
     IdentifierInfo *CategoryName, 
     SourceLocation CategoryLoc,
     IdentifierInfo **ProtoRefNames, 
-    unsigned NumProtoRefs) {
+    unsigned NumProtoRefs,
+    SourceLocation EndProtoLoc) {
     return 0;
   }
   // ActOnStartClassImplementation - this action is called immdiately after