Merged revisions 66174-66175,66177 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

................
  r66174 | benjamin.peterson | 2008-09-02 19:21:32 -0500 (Tue, 02 Sep 2008) | 15 lines

  Merged revisions 66173 via svnmerge from
  svn+ssh://pythondev@svn.python.org/sandbox/trunk/2to3/lib2to3

  ........
    r66173 | benjamin.peterson | 2008-09-02 18:57:48 -0500 (Tue, 02 Sep 2008) | 8 lines

    A little 2to3 refactoring #3637

    This moves command line logic from refactor.py to a new file called
    main.py.  RefactoringTool now merely deals with the actual fixers and
    refactoring; options processing for example is abstracted out.

    This patch was reviewed by Gregory P. Smith.
  ........
................
  r66175 | benjamin.peterson | 2008-09-02 20:53:28 -0500 (Tue, 02 Sep 2008) | 1 line

  update 2to3 script from 2to3 trunk
................
  r66177 | benjamin.peterson | 2008-09-02 21:14:03 -0500 (Tue, 02 Sep 2008) | 9 lines

  Merged revisions 66176 via svnmerge from
  svn+ssh://pythondev@svn.python.org/sandbox/trunk/2to3/lib2to3

  ........
    r66176 | benjamin.peterson | 2008-09-02 21:04:06 -0500 (Tue, 02 Sep 2008) | 1 line

    fix typo
  ........
................
diff --git a/Lib/lib2to3/fixer_base.py b/Lib/lib2to3/fixer_base.py
index 682b215..adbbe77 100644
--- a/Lib/lib2to3/fixer_base.py
+++ b/Lib/lib2to3/fixer_base.py
@@ -47,8 +47,8 @@
         """Initializer.  Subclass may override.
 
         Args:
-            options: an optparse.Values instance which can be used
-                to inspect the command line options.
+            options: an dict containing the options passed to RefactoringTool
+            that could be used to customize the fixer through the command line.
             log: a list to append warnings and other messages to.
         """
         self.options = options
diff --git a/Lib/lib2to3/main.py b/Lib/lib2to3/main.py
new file mode 100644
index 0000000..c092886
--- /dev/null
+++ b/Lib/lib2to3/main.py
@@ -0,0 +1,86 @@
+"""
+Main program for 2to3.
+"""
+
+import sys
+import os
+import logging
+import optparse
+
+from . import refactor
+
+
+def main(fixer_pkg, args=None):
+    """Main program.
+
+    Args:
+        fixer_pkg: the name of a package where the fixers are located.
+        args: optional; a list of command line arguments. If omitted,
+              sys.argv[1:] is used.
+
+    Returns a suggested exit status (0, 1, 2).
+    """
+    # Set up option parser
+    parser = optparse.OptionParser(usage="2to3 [options] file|dir ...")
+    parser.add_option("-d", "--doctests_only", action="store_true",
+                      help="Fix up doctests only")
+    parser.add_option("-f", "--fix", action="append", default=[],
+                      help="Each FIX specifies a transformation; default all")
+    parser.add_option("-l", "--list-fixes", action="store_true",
+                      help="List available transformations (fixes/fix_*.py)")
+    parser.add_option("-p", "--print-function", action="store_true",
+                      help="Modify the grammar so that print() is a function")
+    parser.add_option("-v", "--verbose", action="store_true",
+                      help="More verbose logging")
+    parser.add_option("-w", "--write", action="store_true",
+                      help="Write back modified files")
+
+    # Parse command line arguments
+    refactor_stdin = False
+    options, args = parser.parse_args(args)
+    if options.list_fixes:
+        print("Available transformations for the -f/--fix option:")
+        for fixname in refactor.get_all_fix_names(fixer_pkg):
+            print(fixname)
+        if not args:
+            return 0
+    if not args:
+        print("At least one file or directory argument required.", file=sys.stderr)
+        print("Use --help to show usage.", file=sys.stderr)
+        return 2
+    if "-" in args:
+        refactor_stdin = True
+        if options.write:
+            print("Can't write to stdin.", file=sys.stderr)
+            return 2
+
+    # Set up logging handler
+    level = logging.DEBUG if options.verbose else logging.INFO
+    logging.basicConfig(format='%(name)s: %(message)s', level=level)
+
+    # Initialize the refactoring tool
+    rt_opts = {"print_function" : options.print_function}
+    avail_names = refactor.get_fixers_from_package(fixer_pkg)
+    explicit = []
+    if options.fix:
+        explicit = [fixer_pkg + ".fix_" + fix
+                    for fix in options.fix if fix != "all"]
+        fixer_names = avail_names if "all" in options.fix else explicit
+    else:
+        fixer_names = avail_names
+    rt = refactor.RefactoringTool(fixer_names, rt_opts, explicit=explicit)
+
+    # Refactor stdin ("-") or all files and directories passed as arguments
+    if not rt.errors:
+        if refactor_stdin:
+            rt.refactor_stdin(options.doctests_only)
+        else:
+            rt.refactor(args, options.write, options.doctests_only)
+        rt.summarize()
+
+    # Return error status (0 if no errors were logged, else 1)
+    return int(bool(rt.errors))
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/Lib/lib2to3/refactor.py b/Lib/lib2to3/refactor.py
index f524351..c318045 100755
--- a/Lib/lib2to3/refactor.py
+++ b/Lib/lib2to3/refactor.py
@@ -16,8 +16,8 @@
 import os
 import sys
 import difflib
-import optparse
 import logging
+import operator
 from collections import defaultdict
 from itertools import chain
 
@@ -30,68 +30,19 @@
 from . import fixes
 from . import pygram
 
-def main(fixer_dir, args=None):
-    """Main program.
 
-    Args:
-        fixer_dir: directory where fixer modules are located.
-        args: optional; a list of command line arguments. If omitted,
-              sys.argv[1:] is used.
-
-    Returns a suggested exit status (0, 1, 2).
-    """
-    # Set up option parser
-    parser = optparse.OptionParser(usage="refactor.py [options] file|dir ...")
-    parser.add_option("-d", "--doctests_only", action="store_true",
-                      help="Fix up doctests only")
-    parser.add_option("-f", "--fix", action="append", default=[],
-                      help="Each FIX specifies a transformation; default all")
-    parser.add_option("-l", "--list-fixes", action="store_true",
-                      help="List available transformations (fixes/fix_*.py)")
-    parser.add_option("-p", "--print-function", action="store_true",
-                      help="Modify the grammar so that print() is a function")
-    parser.add_option("-v", "--verbose", action="store_true",
-                      help="More verbose logging")
-    parser.add_option("-w", "--write", action="store_true",
-                      help="Write back modified files")
-
-    # Parse command line arguments
-    options, args = parser.parse_args(args)
-    if options.list_fixes:
-        print("Available transformations for the -f/--fix option:")
-        for fixname in get_all_fix_names(fixer_dir):
-            print(fixname)
-        if not args:
-            return 0
-    if not args:
-        print("At least one file or directory argument required.", file=sys.stderr)
-        print("Use --help to show usage.", file=sys.stderr)
-        return 2
-
-    # Set up logging handler
-    logging.basicConfig(format='%(name)s: %(message)s', level=logging.INFO)
-
-    # Initialize the refactoring tool
-    rt = RefactoringTool(fixer_dir, options)
-
-    # Refactor all files and directories passed as arguments
-    if not rt.errors:
-        rt.refactor_args(args)
-        rt.summarize()
-
-    # Return error status (0 if rt.errors is zero)
-    return int(bool(rt.errors))
-
-
-def get_all_fix_names(fixer_dir):
-    """Return a sorted list of all available fix names."""
+def get_all_fix_names(fixer_pkg, remove_prefix=True):
+    """Return a sorted list of all available fix names in the given package."""
+    pkg = __import__(fixer_pkg, [], [], ["*"])
+    fixer_dir = os.path.dirname(pkg.__file__)
     fix_names = []
     names = os.listdir(fixer_dir)
     names.sort()
     for name in names:
         if name.startswith("fix_") and name.endswith(".py"):
-            fix_names.append(name[4:-3])
-    fix_names.sort()
+            if remove_prefix:
+                name = name[4:]
+            fix_names.append(name[:-3])
     return fix_names
 
 def get_head_types(pat):
@@ -131,22 +82,36 @@
             head_nodes[t].append(fixer)
     return head_nodes
 
+def get_fixers_from_package(pkg_name):
+    """
+    Return the fully qualified names for fixers in the package pkg_name.
+    """
+    return [pkg_name + "." + fix_name
+            for fix_name in get_all_fix_names(pkg_name, False)]
+
 
 class RefactoringTool(object):
 
-    def __init__(self, fixer_dir, options):
+    _default_options = {"print_function": False}
+
+    def __init__(self, fixer_names, options=None, explicit=[]):
         """Initializer.
 
         Args:
-            fixer_dir: directory in which to find fixer modules.
-            options: an optparse.Values instance.
+            fixer_names: a list of fixers to import
+            options: an dict with configuration.
+            explicit: a list of fixers to run even if they are explicit.
         """
-        self.fixer_dir = fixer_dir
-        self.options = options
+        self.fixers = fixer_names
+        self.explicit = explicit
+        self.options = self._default_options.copy()
+        if options is not None:
+            self.options.update(options)
         self.errors = []
         self.logger = logging.getLogger("RefactoringTool")
         self.fixer_log = []
-        if self.options.print_function:
+        self.wrote = False
+        if self.options["print_function"]:
             del pygram.python_grammar.keywords["print"]
         self.driver = driver.Driver(pygram.python_grammar,
                                     convert=pytree.convert,
@@ -166,30 +131,24 @@
           want a pre-order AST traversal, and post_order is the list that want
           post-order traversal.
         """
-        if os.path.isabs(self.fixer_dir):
-            fixer_pkg = os.path.relpath(self.fixer_dir, os.path.join(os.path.dirname(__file__), '..'))
-        else:
-            fixer_pkg = self.fixer_dir
-        fixer_pkg = fixer_pkg.replace(os.path.sep, ".")
-        if os.path.altsep:
-            fixer_pkg = self.fixer_dir.replace(os.path.altsep, ".")
         pre_order_fixers = []
         post_order_fixers = []
-        fix_names = self.options.fix
-        if not fix_names or "all" in fix_names:
-            fix_names = get_all_fix_names(self.fixer_dir)
-        for fix_name in fix_names:
+        for fix_mod_path in self.fixers:
             try:
-                mod = __import__(fixer_pkg + ".fix_" + fix_name, {}, {}, ["*"])
+                mod = __import__(fix_mod_path, {}, {}, ["*"])
             except ImportError:
-                self.log_error("Can't find transformation %s", fix_name)
+                self.log_error("Can't load transformation module %s",
+                               fix_mod_path)
                 continue
+            fix_name = fix_mod_path.rsplit(".", 1)[-1]
+            if fix_name.startswith("fix_"):
+                fix_name = fix_name[4:]
             parts = fix_name.split("_")
             class_name = "Fix" + "".join([p.title() for p in parts])
             try:
                 fix_class = getattr(mod, class_name)
             except AttributeError:
-                self.log_error("Can't find fixes.fix_%s.%s",
+                self.log_error("Can't find %s.%s",
                                fix_name, class_name)
                 continue
             try:
@@ -198,12 +157,12 @@
                 self.log_error("Can't instantiate fixes.fix_%s.%s()",
                                fix_name, class_name, exc_info=True)
                 continue
-            if fixer.explicit and fix_name not in self.options.fix:
+            if fixer.explicit and self.explicit is not True and \
+                    fix_mod_path not in self.explicit:
                 self.log_message("Skipping implicit fixer: %s", fix_name)
                 continue
 
-            if self.options.verbose:
-                self.log_message("Adding transformation: %s", fix_name)
+            self.log_debug("Adding transformation: %s", fix_name)
             if fixer.order == "pre":
                 pre_order_fixers.append(fixer)
             elif fixer.order == "post":
@@ -211,8 +170,9 @@
             else:
                 raise ValueError("Illegal fixer order: %r" % fixer.order)
 
-        pre_order_fixers.sort(key=lambda x: x.run_order)
-        post_order_fixers.sort(key=lambda x: x.run_order)
+        key_func = operator.attrgetter("run_order")
+        pre_order_fixers.sort(key=key_func)
+        post_order_fixers.sort(key=key_func)
         return (pre_order_fixers, post_order_fixers)
 
     def log_error(self, msg, *args, **kwds):
@@ -226,36 +186,38 @@
             msg = msg % args
         self.logger.info(msg)
 
-    def refactor_args(self, args):
-        """Refactors files and directories from an argument list."""
-        for arg in args:
-            if arg == "-":
-                self.refactor_stdin()
-            elif os.path.isdir(arg):
-                self.refactor_dir(arg)
-            else:
-                self.refactor_file(arg)
+    def log_debug(self, msg, *args):
+        if args:
+            msg = msg % args
+        self.logger.debug(msg)
 
-    def refactor_dir(self, arg):
+    def refactor(self, items, write=False, doctests_only=False):
+        """Refactor each file/directory in items, honoring write and doctests_only."""
+        for dir_or_file in items:
+            if os.path.isdir(dir_or_file):
+                self.refactor_dir(dir_or_file, write, doctests_only)
+            else:
+                self.refactor_file(dir_or_file, write, doctests_only)
+
+    def refactor_dir(self, dir_name, write=False, doctests_only=False):
         """Descends down a directory and refactor every Python file found.
 
         Python files are assumed to have a .py extension.
 
         Files and subdirectories starting with '.' are skipped.
         """
-        for dirpath, dirnames, filenames in os.walk(arg):
-            if self.options.verbose:
-                self.log_message("Descending into %s", dirpath)
+        for dirpath, dirnames, filenames in os.walk(dir_name):
+            self.log_debug("Descending into %s", dirpath)
             dirnames.sort()
             filenames.sort()
             for name in filenames:
                 if not name.startswith(".") and name.endswith("py"):
                     fullname = os.path.join(dirpath, name)
-                    self.refactor_file(fullname)
+                    self.refactor_file(fullname, write, doctests_only)
             # Modify dirnames in-place to remove subdirs with leading dots
             dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
 
-    def refactor_file(self, filename):
+    def refactor_file(self, filename, write=False, doctests_only=False):
         """Refactors a file."""
         try:
             f = open(filename)
@@ -266,21 +228,20 @@
             input = f.read() + "\n" # Silence certain parse errors
         finally:
             f.close()
-        if self.options.doctests_only:
-            if self.options.verbose:
-                self.log_message("Refactoring doctests in %s", filename)
+        if doctests_only:
+            self.log_debug("Refactoring doctests in %s", filename)
             output = self.refactor_docstring(input, filename)
             if output != input:
-                self.write_file(output, filename, input)
-            elif self.options.verbose:
-                self.log_message("No doctest changes in %s", filename)
+                self.processed_file(output, filename, input, write=write)
+            else:
+                self.log_debug("No doctest changes in %s", filename)
         else:
             tree = self.refactor_string(input, filename)
             if tree and tree.was_changed:
                 # The [:-1] is to take off the \n we added earlier
-                self.write_file(str(tree)[:-1], filename)
-            elif self.options.verbose:
-                self.log_message("No changes in %s", filename)
+                self.processed_file(str(tree)[:-1], filename, write=write)
+            else:
+                self.log_debug("No changes in %s", filename)
 
     def refactor_string(self, data, name):
         """Refactor a given input string.
@@ -299,30 +260,25 @@
             self.log_error("Can't parse %s: %s: %s",
                            name, err.__class__.__name__, err)
             return
-        if self.options.verbose:
-            self.log_message("Refactoring %s", name)
+        self.log_debug("Refactoring %s", name)
         self.refactor_tree(tree, name)
         return tree
 
-    def refactor_stdin(self):
-        if self.options.write:
-            self.log_error("Can't write changes back to stdin")
-            return
+    def refactor_stdin(self, doctests_only=False):
         input = sys.stdin.read()
-        if self.options.doctests_only:
-            if self.options.verbose:
-                self.log_message("Refactoring doctests in stdin")
+        if doctests_only:
+            self.log_debug("Refactoring doctests in stdin")
             output = self.refactor_docstring(input, "<stdin>")
             if output != input:
-                self.write_file(output, "<stdin>", input)
-            elif self.options.verbose:
-                self.log_message("No doctest changes in stdin")
+                self.processed_file(output, "<stdin>", input)
+            else:
+                self.log_debug("No doctest changes in stdin")
         else:
             tree = self.refactor_string(input, "<stdin>")
             if tree and tree.was_changed:
-                self.write_file(str(tree), "<stdin>", input)
-            elif self.options.verbose:
-                self.log_message("No changes in stdin")
+                self.processed_file(str(tree), "<stdin>", input)
+            else:
+                self.log_debug("No changes in stdin")
 
     def refactor_tree(self, tree, name):
         """Refactors a parse tree (modifying the tree in place).
@@ -374,14 +330,9 @@
                         node.replace(new)
                         node = new
 
-    def write_file(self, new_text, filename, old_text=None):
-        """Writes a string to a file.
-
-        If there are no changes, this is a no-op.
-
-        Otherwise, it first shows a unified diff between the old text
-        and the new text, and then rewrites the file; the latter is
-        only done if the write option is set.
+    def processed_file(self, new_text, filename, old_text=None, write=False):
+        """
+        Called when a file has been refactored, and there are changes.
         """
         self.files.append(filename)
         if old_text is None:
@@ -395,14 +346,22 @@
             finally:
                 f.close()
         if old_text == new_text:
-            if self.options.verbose:
-                self.log_message("No changes to %s", filename)
+            self.log_debug("No changes to %s", filename)
             return
         diff_texts(old_text, new_text, filename)
-        if not self.options.write:
-            if self.options.verbose:
-                self.log_message("Not writing changes to %s", filename)
+        if not write:
+            self.log_debug("Not writing changes to %s", filename)
             return
+        if write:
+            self.write_file(new_text, filename, old_text)
+
+    def write_file(self, new_text, filename, old_text=None):
+        """Writes a string to a file.
+
+        It first shows a unified diff between the old text and the new text, and
+        then rewrites the file; the latter is only done if the write option is
+        set.
+        """
         backup = filename + ".bak"
         if os.path.lexists(backup):
             try:
@@ -425,8 +384,8 @@
                 self.log_error("Can't write %s: %s", filename, err)
         finally:
             f.close()
-        if self.options.verbose:
-            self.log_message("Wrote changes to %s", filename)
+        self.log_debug("Wrote changes to %s", filename)
+        self.wrote = True
 
     PS1 = ">>> "
     PS2 = "... "
@@ -485,9 +444,9 @@
         try:
             tree = self.parse_block(block, lineno, indent)
         except Exception as err:
-            if self.options.verbose:
+            if self.logger.isEnabledFor(logging.DEBUG):
                 for line in block:
-                    self.log_message("Source: %s", line.rstrip("\n"))
+                    self.log_debug("Source: %s", line.rstrip("\n"))
             self.log_error("Can't parse docstring in %s line %s: %s: %s",
                            filename, lineno, err.__class__.__name__, err)
             return block
@@ -504,7 +463,7 @@
         return block
 
     def summarize(self):
-        if self.options.write:
+        if self.wrote:
             were = "were"
         else:
             were = "need to be"
@@ -576,7 +535,3 @@
                                      "(original)", "(refactored)",
                                      lineterm=""):
         print(line)
-
-
-if __name__ == "__main__":
-    sys.exit(main())
diff --git a/Lib/lib2to3/tests/support.py b/Lib/lib2to3/tests/support.py
index 7789033..7abf2ef 100644
--- a/Lib/lib2to3/tests/support.py
+++ b/Lib/lib2to3/tests/support.py
@@ -13,6 +13,7 @@
 
 # Local imports
 from .. import pytree
+from .. import refactor
 from ..pgen2 import driver
 
 test_dir = os.path.dirname(__file__)
@@ -38,6 +39,21 @@
 def reformat(string):
     return dedent(string) + "\n\n"
 
+def get_refactorer(fixers=None, options=None):
+    """
+    A convenience function for creating a RefactoringTool for tests.
+
+    fixers is a list of fixers for the RefactoringTool to use. By default
+    "lib2to3.fixes.*" is used. options is an optional dictionary of options to
+    be passed to the RefactoringTool.
+    """
+    if fixers is not None:
+        fixers = ["lib2to3.fixes.fix_" + fix for fix in fixers]
+    else:
+        fixers = refactor.get_fixers_from_package("lib2to3.fixes")
+    options = options or {}
+    return refactor.RefactoringTool(fixers, options, explicit=True)
+
 def all_project_files():
     for dirpath, dirnames, filenames in os.walk(proj_dir):
         for filename in filenames:
diff --git a/Lib/lib2to3/tests/test_all_fixers.py b/Lib/lib2to3/tests/test_all_fixers.py
index c36f61d..68d6306 100644
--- a/Lib/lib2to3/tests/test_all_fixers.py
+++ b/Lib/lib2to3/tests/test_all_fixers.py
@@ -19,17 +19,11 @@
 from .. import pytree
 from .. import refactor
 
-class Options:
-    def __init__(self, **kwargs):
-        for k, v in list(kwargs.items()):
-            setattr(self, k, v)
-        self.verbose = False
 
 class Test_all(support.TestCase):
     def setUp(self):
-        options = Options(fix=["all", "idioms", "ws_comma", "buffer"],
-                          print_function=False)
-        self.refactor = refactor.RefactoringTool("lib2to3/fixes", options)
+        options = {"print_function" : False}
+        self.refactor = support.get_refactorer(options=options)
 
     def test_all_project_files(self):
         for filepath in support.all_project_files():
diff --git a/Lib/lib2to3/tests/test_fixers.py b/Lib/lib2to3/tests/test_fixers.py
index ecc7f02..eb5a939 100755
--- a/Lib/lib2to3/tests/test_fixers.py
+++ b/Lib/lib2to3/tests/test_fixers.py
@@ -21,19 +21,12 @@
 from .. import fixer_util
 
 
-class Options:
-    def __init__(self, **kwargs):
-        for k, v in list(kwargs.items()):
-            setattr(self, k, v)
-
-        self.verbose = False
-
 class FixerTestCase(support.TestCase):
     def setUp(self, fix_list=None):
-        if not fix_list:
+        if fix_list is None:
             fix_list = [self.fixer]
-        options = Options(fix=fix_list, print_function=False)
-        self.refactor = refactor.RefactoringTool("lib2to3/fixes", options)
+        options = {"print_function" : False}
+        self.refactor = support.get_refactorer(fix_list, options)
         self.fixer_log = []
         self.filename = "<string>"
 
@@ -70,10 +63,10 @@
             self.failUnlessEqual(self.fixer_log, [])
 
     def assert_runs_after(self, *names):
-        fix = [self.fixer]
-        fix.extend(names)
-        options = Options(fix=fix, print_function=False)
-        r = refactor.RefactoringTool("lib2to3/fixes", options)
+        fixes = [self.fixer]
+        fixes.extend(names)
+        options = {"print_function" : False}
+        r = support.get_refactorer(fixes, options)
         (pre, post) = r.get_fixers()
         n = "fix_" + self.fixer
         if post and post[-1].__class__.__module__.endswith(n):