blob: f7a3b15989ec92a602208ec935037016a526dc6f [file] [log] [blame]
Martin v. Löwisef04c442008-03-19 05:04:44 +00001#!/usr/bin/env python2.5
2# Copyright 2006 Google, Inc. All Rights Reserved.
3# Licensed to PSF under a Contributor Agreement.
4
5"""Refactoring framework.
6
7Used as a main program, this can refactor any number of files and/or
8recursively descend down directories. Imported as a module, this
9provides infrastructure to write your own refactoring tool.
10"""
11
12__author__ = "Guido van Rossum <guido@python.org>"
13
14
15# Python imports
16import os
17import sys
18import difflib
19import optparse
20import logging
21
22# Local imports
23from .pgen2 import driver
24from .pgen2 import tokenize
25
26from . import pytree
27from . import patcomp
28from . import fixes
29from . import pygram
30
31if sys.version_info < (2, 4):
32 hdlr = logging.StreamHandler()
33 fmt = logging.Formatter('%(name)s: %(message)s')
34 hdlr.setFormatter(fmt)
35 logging.root.addHandler(hdlr)
36else:
37 logging.basicConfig(format='%(name)s: %(message)s', level=logging.INFO)
38
39
40def main(args=None):
41 """Main program.
42
43 Call without arguments to use sys.argv[1:] as the arguments; or
44 call with a list of arguments (excluding sys.argv[0]).
45
46 Returns a suggested exit status (0, 1, 2).
47 """
48 # Set up option parser
49 parser = optparse.OptionParser(usage="refactor.py [options] file|dir ...")
50 parser.add_option("-d", "--doctests_only", action="store_true",
51 help="Fix up doctests only")
52 parser.add_option("-f", "--fix", action="append", default=[],
53 help="Each FIX specifies a transformation; default all")
54 parser.add_option("-l", "--list-fixes", action="store_true",
55 help="List available transformations (fixes/fix_*.py)")
56 parser.add_option("-p", "--print-function", action="store_true",
57 help="Modify the grammar so that print() is a function")
58 parser.add_option("-v", "--verbose", action="store_true",
59 help="More verbose logging")
60 parser.add_option("-w", "--write", action="store_true",
61 help="Write back modified files")
62
63 # Parse command line arguments
64 options, args = parser.parse_args(args)
65 if options.list_fixes:
66 print "Available transformations for the -f/--fix option:"
67 for fixname in get_all_fix_names():
68 print fixname
69 if not args:
70 return 0
71 if not args:
72 print >>sys.stderr, "At least one file or directory argument required."
73 print >>sys.stderr, "Use --help to show usage."
74 return 2
75
76 # Initialize the refactoring tool
77 rt = RefactoringTool(options)
78
79 # Refactor all files and directories passed as arguments
80 if not rt.errors:
81 rt.refactor_args(args)
82 rt.summarize()
83
84 # Return error status (0 if rt.errors is zero)
85 return int(bool(rt.errors))
86
87
88def get_all_fix_names():
89 """Return a sorted list of all available fix names."""
90 fix_names = []
91 names = os.listdir(os.path.dirname(fixes.__file__))
92 names.sort()
93 for name in names:
94 if name.startswith("fix_") and name.endswith(".py"):
95 fix_names.append(name[4:-3])
96 fix_names.sort()
97 return fix_names
98
99
100class RefactoringTool(object):
101
102 def __init__(self, options):
103 """Initializer.
104
105 The argument is an optparse.Values instance.
106 """
107 self.options = options
108 self.errors = []
109 self.logger = logging.getLogger("RefactoringTool")
110 self.fixer_log = []
111 if self.options.print_function:
112 del pygram.python_grammar.keywords["print"]
113 self.driver = driver.Driver(pygram.python_grammar,
114 convert=pytree.convert,
115 logger=self.logger)
116 self.pre_order, self.post_order = self.get_fixers()
117 self.files = [] # List of files that were or should be modified
118
119 def get_fixers(self):
120 """Inspects the options to load the requested patterns and handlers.
121
122 Returns:
123 (pre_order, post_order), where pre_order is the list of fixers that
124 want a pre-order AST traversal, and post_order is the list that want
125 post-order traversal.
126 """
127 pre_order_fixers = []
128 post_order_fixers = []
129 fix_names = self.options.fix
130 if not fix_names or "all" in fix_names:
131 fix_names = get_all_fix_names()
132 for fix_name in fix_names:
133 try:
134 mod = __import__("lib2to3.fixes.fix_" + fix_name, {}, {}, ["*"])
135 except ImportError:
136 self.log_error("Can't find transformation %s", fix_name)
137 continue
138 parts = fix_name.split("_")
139 class_name = "Fix" + "".join([p.title() for p in parts])
140 try:
141 fix_class = getattr(mod, class_name)
142 except AttributeError:
143 self.log_error("Can't find fixes.fix_%s.%s",
144 fix_name, class_name)
145 continue
146 try:
147 fixer = fix_class(self.options, self.fixer_log)
148 except Exception, err:
149 self.log_error("Can't instantiate fixes.fix_%s.%s()",
150 fix_name, class_name, exc_info=True)
151 continue
152 if fixer.explicit and fix_name not in self.options.fix:
153 self.log_message("Skipping implicit fixer: %s", fix_name)
154 continue
155
156 if self.options.verbose:
157 self.log_message("Adding transformation: %s", fix_name)
158 if fixer.order == "pre":
159 pre_order_fixers.append(fixer)
160 elif fixer.order == "post":
161 post_order_fixers.append(fixer)
162 else:
163 raise ValueError("Illegal fixer order: %r" % fixer.order)
164 return (pre_order_fixers, post_order_fixers)
165
166 def log_error(self, msg, *args, **kwds):
167 """Increments error count and log a message."""
168 self.errors.append((msg, args, kwds))
169 self.logger.error(msg, *args, **kwds)
170
171 def log_message(self, msg, *args):
172 """Hook to log a message."""
173 if args:
174 msg = msg % args
175 self.logger.info(msg)
176
177 def refactor_args(self, args):
178 """Refactors files and directories from an argument list."""
179 for arg in args:
180 if arg == "-":
181 self.refactor_stdin()
182 elif os.path.isdir(arg):
183 self.refactor_dir(arg)
184 else:
185 self.refactor_file(arg)
186
187 def refactor_dir(self, arg):
188 """Descends down a directory and refactor every Python file found.
189
190 Python files are assumed to have a .py extension.
191
192 Files and subdirectories starting with '.' are skipped.
193 """
194 for dirpath, dirnames, filenames in os.walk(arg):
195 if self.options.verbose:
196 self.log_message("Descending into %s", dirpath)
197 dirnames.sort()
198 filenames.sort()
199 for name in filenames:
200 if not name.startswith(".") and name.endswith("py"):
201 fullname = os.path.join(dirpath, name)
202 self.refactor_file(fullname)
203 # Modify dirnames in-place to remove subdirs with leading dots
204 dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
205
206 def refactor_file(self, filename):
207 """Refactors a file."""
208 try:
209 f = open(filename)
210 except IOError, err:
211 self.log_error("Can't open %s: %s", filename, err)
212 return
213 try:
214 input = f.read() + "\n" # Silence certain parse errors
215 finally:
216 f.close()
217 if self.options.doctests_only:
218 if self.options.verbose:
219 self.log_message("Refactoring doctests in %s", filename)
220 output = self.refactor_docstring(input, filename)
221 if output != input:
222 self.write_file(output, filename, input)
223 elif self.options.verbose:
224 self.log_message("No doctest changes in %s", filename)
225 else:
226 tree = self.refactor_string(input, filename)
227 if tree and tree.was_changed:
228 # The [:-1] is to take off the \n we added earlier
229 self.write_file(str(tree)[:-1], filename)
230 elif self.options.verbose:
231 self.log_message("No changes in %s", filename)
232
233 def refactor_string(self, data, name):
234 """Refactor a given input string.
235
236 Args:
237 data: a string holding the code to be refactored.
238 name: a human-readable name for use in error/log messages.
239
240 Returns:
241 An AST corresponding to the refactored input stream; None if
242 there were errors during the parse.
243 """
244 try:
245 tree = self.driver.parse_string(data,1)
246 except Exception, err:
247 self.log_error("Can't parse %s: %s: %s",
248 name, err.__class__.__name__, err)
249 return
250 if self.options.verbose:
251 self.log_message("Refactoring %s", name)
252 self.refactor_tree(tree, name)
253 return tree
254
255 def refactor_stdin(self):
256 if self.options.write:
257 self.log_error("Can't write changes back to stdin")
258 return
259 input = sys.stdin.read()
260 if self.options.doctests_only:
261 if self.options.verbose:
262 self.log_message("Refactoring doctests in stdin")
263 output = self.refactor_docstring(input, "<stdin>")
264 if output != input:
265 self.write_file(output, "<stdin>", input)
266 elif self.options.verbose:
267 self.log_message("No doctest changes in stdin")
268 else:
269 tree = self.refactor_string(input, "<stdin>")
270 if tree and tree.was_changed:
271 self.write_file(str(tree), "<stdin>", input)
272 elif self.options.verbose:
273 self.log_message("No changes in stdin")
274
275 def refactor_tree(self, tree, name):
276 """Refactors a parse tree (modifying the tree in place).
277
278 Args:
279 tree: a pytree.Node instance representing the root of the tree
280 to be refactored.
281 name: a human-readable name for this tree.
282
283 Returns:
284 True if the tree was modified, False otherwise.
285 """
286 all_fixers = self.pre_order + self.post_order
287 for fixer in all_fixers:
288 fixer.start_tree(tree, name)
289
290 self.traverse_by(self.pre_order, tree.pre_order())
291 self.traverse_by(self.post_order, tree.post_order())
292
293 for fixer in all_fixers:
294 fixer.finish_tree(tree, name)
295 return tree.was_changed
296
297 def traverse_by(self, fixers, traversal):
298 """Traverse an AST, applying a set of fixers to each node.
299
300 This is a helper method for refactor_tree().
301
302 Args:
303 fixers: a list of fixer instances.
304 traversal: a generator that yields AST nodes.
305
306 Returns:
307 None
308 """
309 if not fixers:
310 return
311 for node in traversal:
312 for fixer in fixers:
313 results = fixer.match(node)
314 if results:
315 new = fixer.transform(node, results)
316 if new is not None and (new != node or
317 str(new) != str(node)):
318 node.replace(new)
319 node = new
320
321 def write_file(self, new_text, filename, old_text=None):
322 """Writes a string to a file.
323
324 If there are no changes, this is a no-op.
325
326 Otherwise, it first shows a unified diff between the old text
327 and the new text, and then rewrites the file; the latter is
328 only done if the write option is set.
329 """
330 self.files.append(filename)
331 if old_text is None:
332 try:
333 f = open(filename, "r")
334 except IOError, err:
335 self.log_error("Can't read %s: %s", filename, err)
336 return
337 try:
338 old_text = f.read()
339 finally:
340 f.close()
341 if old_text == new_text:
342 if self.options.verbose:
343 self.log_message("No changes to %s", filename)
344 return
345 diff_texts(old_text, new_text, filename)
346 if not self.options.write:
347 if self.options.verbose:
348 self.log_message("Not writing changes to %s", filename)
349 return
350 backup = filename + ".bak"
351 if os.path.lexists(backup):
352 try:
353 os.remove(backup)
354 except os.error, err:
355 self.log_message("Can't remove backup %s", backup)
356 try:
357 os.rename(filename, backup)
358 except os.error, err:
359 self.log_message("Can't rename %s to %s", filename, backup)
360 try:
361 f = open(filename, "w")
362 except os.error, err:
363 self.log_error("Can't create %s: %s", filename, err)
364 return
365 try:
366 try:
367 f.write(new_text)
368 except os.error, err:
369 self.log_error("Can't write %s: %s", filename, err)
370 finally:
371 f.close()
372 if self.options.verbose:
373 self.log_message("Wrote changes to %s", filename)
374
375 PS1 = ">>> "
376 PS2 = "... "
377
378 def refactor_docstring(self, input, filename):
379 """Refactors a docstring, looking for doctests.
380
381 This returns a modified version of the input string. It looks
382 for doctests, which start with a ">>>" prompt, and may be
383 continued with "..." prompts, as long as the "..." is indented
384 the same as the ">>>".
385
386 (Unfortunately we can't use the doctest module's parser,
387 since, like most parsers, it is not geared towards preserving
388 the original source.)
389 """
390 result = []
391 block = None
392 block_lineno = None
393 indent = None
394 lineno = 0
395 for line in input.splitlines(True):
396 lineno += 1
397 if line.lstrip().startswith(self.PS1):
398 if block is not None:
399 result.extend(self.refactor_doctest(block, block_lineno,
400 indent, filename))
401 block_lineno = lineno
402 block = [line]
403 i = line.find(self.PS1)
404 indent = line[:i]
405 elif (indent is not None and
406 (line.startswith(indent + self.PS2) or
407 line == indent + self.PS2.rstrip() + "\n")):
408 block.append(line)
409 else:
410 if block is not None:
411 result.extend(self.refactor_doctest(block, block_lineno,
412 indent, filename))
413 block = None
414 indent = None
415 result.append(line)
416 if block is not None:
417 result.extend(self.refactor_doctest(block, block_lineno,
418 indent, filename))
419 return "".join(result)
420
421 def refactor_doctest(self, block, lineno, indent, filename):
422 """Refactors one doctest.
423
424 A doctest is given as a block of lines, the first of which starts
425 with ">>>" (possibly indented), while the remaining lines start
426 with "..." (identically indented).
427
428 """
429 try:
430 tree = self.parse_block(block, lineno, indent)
431 except Exception, err:
432 if self.options.verbose:
433 for line in block:
434 self.log_message("Source: %s", line.rstrip("\n"))
435 self.log_error("Can't parse docstring in %s line %s: %s: %s",
436 filename, lineno, err.__class__.__name__, err)
437 return block
438 if self.refactor_tree(tree, filename):
439 new = str(tree).splitlines(True)
440 # Undo the adjustment of the line numbers in wrap_toks() below.
441 clipped, new = new[:lineno-1], new[lineno-1:]
442 assert clipped == ["\n"] * (lineno-1), clipped
443 if not new[-1].endswith("\n"):
444 new[-1] += "\n"
445 block = [indent + self.PS1 + new.pop(0)]
446 if new:
447 block += [indent + self.PS2 + line for line in new]
448 return block
449
450 def summarize(self):
451 if self.options.write:
452 were = "were"
453 else:
454 were = "need to be"
455 if not self.files:
456 self.log_message("No files %s modified.", were)
457 else:
458 self.log_message("Files that %s modified:", were)
459 for file in self.files:
460 self.log_message(file)
461 if self.fixer_log:
462 self.log_message("Warnings/messages while refactoring:")
463 for message in self.fixer_log:
464 self.log_message(message)
465 if self.errors:
466 if len(self.errors) == 1:
467 self.log_message("There was 1 error:")
468 else:
469 self.log_message("There were %d errors:", len(self.errors))
470 for msg, args, kwds in self.errors:
471 self.log_message(msg, *args, **kwds)
472
473 def parse_block(self, block, lineno, indent):
474 """Parses a block into a tree.
475
476 This is necessary to get correct line number / offset information
477 in the parser diagnostics and embedded into the parse tree.
478 """
479 return self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))
480
481 def wrap_toks(self, block, lineno, indent):
482 """Wraps a tokenize stream to systematically modify start/end."""
483 tokens = tokenize.generate_tokens(self.gen_lines(block, indent).next)
484 for type, value, (line0, col0), (line1, col1), line_text in tokens:
485 line0 += lineno - 1
486 line1 += lineno - 1
487 # Don't bother updating the columns; this is too complicated
488 # since line_text would also have to be updated and it would
489 # still break for tokens spanning lines. Let the user guess
490 # that the column numbers for doctests are relative to the
491 # end of the prompt string (PS1 or PS2).
492 yield type, value, (line0, col0), (line1, col1), line_text
493
494
495 def gen_lines(self, block, indent):
496 """Generates lines as expected by tokenize from a list of lines.
497
498 This strips the first len(indent + self.PS1) characters off each line.
499 """
500 prefix1 = indent + self.PS1
501 prefix2 = indent + self.PS2
502 prefix = prefix1
503 for line in block:
504 if line.startswith(prefix):
505 yield line[len(prefix):]
506 elif line == prefix.rstrip() + "\n":
507 yield "\n"
508 else:
509 raise AssertionError("line=%r, prefix=%r" % (line, prefix))
510 prefix = prefix2
511 while True:
512 yield ""
513
514
515def diff_texts(a, b, filename):
516 """Prints a unified diff of two strings."""
517 a = a.splitlines()
518 b = b.splitlines()
519 for line in difflib.unified_diff(a, b, filename, filename,
520 "(original)", "(refactored)",
521 lineterm=""):
522 print line
523
524
525if __name__ == "__main__":
526 sys.exit(main())