Insert #pragma once when rewriting a header file.


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@46155 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/Driver/RewriteTest.cpp b/Driver/RewriteTest.cpp
index 8d3035b..fb52dfe 100644
--- a/Driver/RewriteTest.cpp
+++ b/Driver/RewriteTest.cpp
@@ -65,6 +65,9 @@
     ObjCMethodDecl *CurMethodDecl;
     RecordDecl *SuperStructDecl;
     
+    // True when rewriting a header; Initialize() then emits "#pragma once".
+    bool IsHeader;
+    
     static const int OBJC_ABI_VERSION =7 ;
   public:
     void Initialize(ASTContext &context) {
@@ -96,7 +99,8 @@
       Rewrite.setSourceMgr(Context->getSourceManager());
       // declaring objc_selector outside the parameter list removes a silly
       // scope related warning...
-      const char *s = "struct objc_selector; struct objc_class;\n"
+      const char *s = "#pragma once\n" // dropped unless IsHeader
+                      "struct objc_selector; struct objc_class;\n"
                       "#ifndef OBJC_SUPER\n"
                       "struct objc_super { struct objc_object *o; "
                       "struct objc_object *superClass; };\n"
@@ -136,15 +140,23 @@
                       "unsigned long extra[5];\n};\n"
                       "#define __FASTENUMERATIONSTATE\n"
                       "#endif\n";
-                      
-      Rewrite.InsertText(SourceLocation::getFileLoc(MainFileID, 0), 
-                         s, strlen(s));
+      // The preamble starts with "#pragma once\n", which is only wanted when
+      // the main file is a header. For any other file skip it -- but only if
+      // it really is there, so that editing the preamble can never silently
+      // chop off unrelated text.
+      static const char PragmaOnce[] = "#pragma once\n";
+      const size_t PragmaLen = sizeof(PragmaOnce) - 1; // excludes the NUL
+      const char *Preamble = s;
+      if (!IsHeader && strncmp(s, PragmaOnce, PragmaLen) == 0)
+        Preamble += PragmaLen;
+      Rewrite.InsertText(SourceLocation::getFileLoc(MainFileID, 0),
+                         Preamble, strlen(Preamble));
     }
 
     // Top Level Driver code.
     virtual void HandleTopLevelDecl(Decl *D);
     void HandleDeclInMainFile(Decl *D);
-    RewriteTest(Diagnostic &D) : Diags(D) {}
+    RewriteTest(bool isHeader, Diagnostic &D) : Diags(D), IsHeader(isHeader) {}
     ~RewriteTest();
 
     // Syntactic Rewriting.
@@ -226,8 +238,26 @@
   };
 }
 
-ASTConsumer *clang::CreateCodeRewriterTest(Diagnostic &Diags) {
-  return new RewriteTest(Diags);
+/// Returns true if Filename ends in a C or C++ header extension. The match is
+/// case-sensitive, so ".H" (C++) is distinct from ".h" (C).
+static bool IsHeaderFile(const std::string &Filename) {
+  std::string::size_type DotPos = Filename.rfind('.');
+  if (DotPos == std::string::npos)
+    return false; // no file extension
+
+  std::string Ext = Filename.substr(DotPos + 1);
+  // C header: .h
+  // C++ headers: .hh, .H, .hpp, .hxx, .h++
+  return Ext == "h" ||
+         Ext == "hh" || Ext == "H" || Ext == "hpp" || Ext == "hxx" ||
+         Ext == "h++";
+}
+
+/// Creates the rewriter for InFile. Headers (see IsHeaderFile) are rewritten
+/// with a leading "#pragma once".
+ASTConsumer *clang::CreateCodeRewriterTest(const std::string &InFile,
+                                           Diagnostic &Diags) {
+  return new RewriteTest(IsHeaderFile(InFile), Diags);
 }
 
 //===----------------------------------------------------------------------===//