Merged revisions 72494 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

................
  r72494 | benjamin.peterson | 2009-05-08 20:01:14 -0500 (Fri, 08 May 2009) | 21 lines

  Merged revisions 72491-72493 via svnmerge from
  svn+ssh://pythondev@svn.python.org/sandbox/trunk/2to3/lib2to3

  ........
    r72491 | benjamin.peterson | 2009-05-08 19:33:27 -0500 (Fri, 08 May 2009) | 7 lines

    make 2to3 use unicode internally on 2.x

    This started out as a fix for #2660, but became this large refactoring
    when I realized the dire state of 2to3's unicode handling. 2to3 now uses
    tokenize.detect_encoding to decode files correctly into unicode.
  ........
    r72492 | benjamin.peterson | 2009-05-08 19:35:38 -0500 (Fri, 08 May 2009) | 1 line

    remove compat code
  ........
    r72493 | benjamin.peterson | 2009-05-08 19:54:15 -0500 (Fri, 08 May 2009) | 1 line

    add a test for \r\n newlines
  ........
................
diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py
index b679db4..82a98d1 100755
--- a/Lib/lib2to3/refactor.py
+++ b/Lib/lib2to3/refactor.py
@@ -22,8 +22,7 @@
 from itertools import chain
 
 # Local imports
-from .pgen2 import driver
-from .pgen2 import tokenize
+from .pgen2 import driver, tokenize
 
 from . import pytree
 from . import patcomp
@@ -87,6 +86,25 @@
     return [pkg_name + "." + fix_name
             for fix_name in get_all_fix_names(pkg_name, False)]
 
+def _identity(obj):
+    return obj
+
+if sys.version_info < (3, 0):
+    import codecs
+    _open_with_encoding = codecs.open
+    # codecs.open doesn't translate newlines sadly.
+    def _from_system_newlines(input):
+        return input.replace("\r\n", "\n")
+    def _to_system_newlines(input):
+        if os.linesep != "\n":
+            return input.replace("\n", os.linesep)
+        else:
+            return input
+else:
+    _open_with_encoding = open
+    _from_system_newlines = _identity
+    _to_system_newlines = _identity
+
 
 class FixerError(Exception):
     """A fixer could not be loaded."""
@@ -213,29 +231,42 @@
             # Modify dirnames in-place to remove subdirs with leading dots
             dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
 
-    def refactor_file(self, filename, write=False, doctests_only=False):
-        """Refactors a file."""
+    def _read_python_source(self, filename):
+        """
+        Do our best to decode a Python source file correctly.
+        """
         try:
-            f = open(filename)
+            f = open(filename, "rb")
         except IOError as err:
             self.log_error("Can't open %s: %s", filename, err)
-            return
+            return None, None
         try:
-            input = f.read() + "\n" # Silence certain parse errors
+            encoding = tokenize.detect_encoding(f.readline)[0]
         finally:
             f.close()
+        with _open_with_encoding(filename, "r", encoding=encoding) as f:
+            return _from_system_newlines(f.read()), encoding
+
+    def refactor_file(self, filename, write=False, doctests_only=False):
+        """Refactors a file."""
+        input, encoding = self._read_python_source(filename)
+        if input is None:
+            # Reading the file failed.
+            return
+        input += "\n" # Silence certain parse errors
         if doctests_only:
             self.log_debug("Refactoring doctests in %s", filename)
             output = self.refactor_docstring(input, filename)
             if output != input:
-                self.processed_file(output, filename, input, write=write)
+                self.processed_file(output, filename, input, write, encoding)
             else:
                 self.log_debug("No doctest changes in %s", filename)
         else:
             tree = self.refactor_string(input, filename)
             if tree and tree.was_changed:
                 # The [:-1] is to take off the \n we added earlier
-                self.processed_file(str(tree)[:-1], filename, write=write)
+                self.processed_file(str(tree)[:-1], filename,
+                                    write=write, encoding=encoding)
             else:
                 self.log_debug("No changes in %s", filename)
 
@@ -321,31 +352,26 @@
                         node.replace(new)
                         node = new
 
-    def processed_file(self, new_text, filename, old_text=None, write=False):
+    def processed_file(self, new_text, filename, old_text=None, write=False,
+                       encoding=None):
         """
         Called when a file has been refactored, and there are changes.
         """
         self.files.append(filename)
         if old_text is None:
-            try:
-                f = open(filename, "r")
-            except IOError as err:
-                self.log_error("Can't read %s: %s", filename, err)
+            old_text = self._read_python_source(filename)[0]
+            if old_text is None:
                 return
-            try:
-                old_text = f.read()
-            finally:
-                f.close()
         if old_text == new_text:
             self.log_debug("No changes to %s", filename)
             return
         self.print_output(diff_texts(old_text, new_text, filename))
         if write:
-            self.write_file(new_text, filename, old_text)
+            self.write_file(new_text, filename, old_text, encoding)
         else:
             self.log_debug("Not writing changes to %s", filename)
 
-    def write_file(self, new_text, filename, old_text):
+    def write_file(self, new_text, filename, old_text, encoding=None):
         """Writes a string to a file.
 
         It first shows a unified diff between the old text and the new text, and
@@ -353,12 +379,12 @@
         set.
         """
         try:
-            f = open(filename, "w")
+            f = _open_with_encoding(filename, "w", encoding=encoding)
         except os.error as err:
             self.log_error("Can't create %s: %s", filename, err)
             return
         try:
-            f.write(new_text)
+            f.write(_to_system_newlines(new_text))
         except os.error as err:
             self.log_error("Can't write %s: %s", filename, err)
         finally: