blob: bd95f3d87a1657f44b9861dac99ff80f32e9ce88 [file] [log] [blame]
Martin v. Löwis5e37bae2008-03-19 04:43:46 +00001#!/usr/bin/env python2.5
2# Copyright 2006 Google, Inc. All Rights Reserved.
3# Licensed to PSF under a Contributor Agreement.
4
"""Refactoring framework.

Run as a main program, this module refactors any number of files and
recursively descends into directories. Imported as a module, it provides
the infrastructure for writing your own refactoring tool.
"""

__author__ = "Guido van Rossum <guido@python.org>"
13
14
15# Python imports
16import os
17import sys
18import difflib
19import optparse
20import logging
Martin v. Löwis6780a9d2008-05-02 21:30:20 +000021from collections import defaultdict
22from itertools import chain
Martin v. Löwis5e37bae2008-03-19 04:43:46 +000023
24# Local imports
25from .pgen2 import driver
26from .pgen2 import tokenize
27
28from . import pytree
29from . import patcomp
30from . import fixes
31from . import pygram
32
def main(fixer_dir, args=None):
    """Command-line entry point for the refactoring tool.

    Args:
        fixer_dir: directory holding the fixer modules (fix_*.py).
        args: optional list of command-line arguments; when None,
            optparse falls back to sys.argv[1:].

    Returns:
        A suggested exit status: 0 when no errors were recorded, 1 when the
        tool recorded errors, 2 when no file/directory argument was given.
    """
    parser = optparse.OptionParser(usage="refactor.py [options] file|dir ...")
    option_table = [
        (("-d", "--doctests_only"),
         dict(action="store_true", help="Fix up doctests only")),
        (("-f", "--fix"),
         dict(action="append", default=[],
              help="Each FIX specifies a transformation; default all")),
        (("-l", "--list-fixes"),
         dict(action="store_true",
              help="List available transformations (fixes/fix_*.py)")),
        (("-p", "--print-function"),
         dict(action="store_true",
              help="Modify the grammar so that print() is a function")),
        (("-v", "--verbose"),
         dict(action="store_true", help="More verbose logging")),
        (("-w", "--write"),
         dict(action="store_true", help="Write back modified files")),
    ]
    for flags, settings in option_table:
        parser.add_option(*flags, **settings)

    options, args = parser.parse_args(args)

    if options.list_fixes:
        print("Available transformations for the -f/--fix option:")
        for fixname in get_all_fix_names(fixer_dir):
            print(fixname)
        # Listing alone is a complete, successful run.
        if not args:
            return 0

    if not args:
        sys.stderr.write("At least one file or directory argument required.\n")
        sys.stderr.write("Use --help to show usage.\n")
        return 2

    log_format = '%(name)s: %(message)s'
    if sys.version_info >= (2, 4):
        logging.basicConfig(format=log_format, level=logging.INFO)
    else:
        # basicConfig() takes no keyword arguments before Python 2.4.
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(log_format))
        logging.root.addHandler(handler)

    tool = RefactoringTool(fixer_dir, options)

    # Only touch files when the fixers loaded cleanly.
    if not tool.errors:
        tool.refactor_args(args)
    tool.summarize()

    return 1 if tool.errors else 0
90
91
def get_all_fix_names(fixer_dir):
    """Return the sorted names of the fixers available in fixer_dir.

    Every file named ``fix_<name>.py`` counts as a fixer module; the
    returned list holds the ``<name>`` parts in sorted order.
    """
    prefix, suffix = "fix_", ".py"
    return sorted(entry[len(prefix):-len(suffix)]
                  for entry in os.listdir(fixer_dir)
                  if entry.startswith(prefix) and entry.endswith(suffix))
102
def get_head_types(pat):
    """Return the set of node types on which a pattern's match can start.

    Accepts a pytree pattern. A None member in the result means "any node
    type"; callers file such fixers under the None key so they are tried
    on every node.

    Raises:
        Exception: if pat is not one of the pattern classes handled here.
    """
    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        # or a type and content -- so they don't get any farther.
        # Always return leaves.
        return set([pat.type])

    if isinstance(pat, pytree.NegatedPattern):
        if pat.content:
            return get_head_types(pat.content)
        return set([None])  # Negated patterns don't have a type

    if isinstance(pat, pytree.WildcardPattern):
        if pat.content is None:
            # Iterating None would raise TypeError. A wildcard without
            # content constrains no node type, so treat it as "any node".
            # NOTE(review): pytree's WildcardPattern documents content as
            # optional -- confirm against the pytree in use.
            return set([None])
        # Recurse on every pattern of every alternative.
        head_types = set()
        for alternative in pat.content:
            for sub_pattern in alternative:
                head_types.update(get_head_types(sub_pattern))
        return head_types

    raise Exception("Oh no! I don't understand pattern %s" % (pat,))
127
def get_headnode_dict(fixer_list):
    """Index fixers by the node types their patterns can start on.

    Returns a defaultdict(list) mapping head node type -> list of fixers,
    in fixer_list order. Fixers without a pattern go under the None key,
    which traverse_by() consults for every node. A missing key therefore
    yields an empty list.
    """
    by_head = defaultdict(list)
    for fx in fixer_list:
        head_types = get_head_types(fx.pattern) if fx.pattern else [None]
        for head in head_types:
            by_head[head].append(fx)
    return by_head
139
Martin v. Löwis5e37bae2008-03-19 04:43:46 +0000140
class RefactoringTool(object):
    """Applies a set of fixers to files, directories, stdin and doctests.

    Fixers are loaded from fixer_dir by get_fixers(). Most problems (I/O
    failures, parse failures, fixers that cannot be loaded) are recorded
    through log_error() in self.errors rather than raised.
    """

    # Doctest prompts recognized by refactor_docstring() and gen_lines().
    PS1 = ">>> "
    PS2 = "... "

    def __init__(self, fixer_dir, options):
        """Initializer.

        Args:
            fixer_dir: directory in which to find fixer modules.
            options: an optparse.Values instance.
        """
        self.fixer_dir = fixer_dir
        self.options = options
        self.errors = []  # (msg, args, kwds) triples from log_error()
        self.logger = logging.getLogger("RefactoringTool")
        self.fixer_log = []  # passed to every fixer instance
        if self.options.print_function:
            # pop() instead of del: the grammar is a module-level object, so
            # a second tool in the same process would otherwise hit KeyError.
            pygram.python_grammar.keywords.pop("print", None)
        self.driver = driver.Driver(pygram.python_grammar,
                                    convert=pytree.convert,
                                    logger=self.logger)
        pre_order, post_order = self.get_fixers()

        # Index by head node type so traverse_by() only tries fixers that
        # can possibly match a given node.
        self.pre_order = get_headnode_dict(pre_order)
        self.post_order = get_headnode_dict(post_order)

        self.files = []  # List of files that were or should be modified

    def get_fixers(self):
        """Inspects the options to load the requested patterns and handlers.

        Returns:
            (pre_order, post_order), where pre_order is the list of fixers that
            want a pre-order AST traversal, and post_order is the list that want
            post-order traversal. Each list is sorted by the fixers' run_order.

        Raises:
            ValueError: if a fixer declares an order other than "pre"/"post".
        """
        if os.path.isabs(self.fixer_dir):
            # Express the directory relative to the parent of this package
            # so it can be turned into a dotted package name.
            fixer_pkg = os.path.relpath(
                self.fixer_dir, os.path.join(os.path.dirname(__file__), '..'))
        else:
            fixer_pkg = self.fixer_dir
        fixer_pkg = fixer_pkg.replace(os.path.sep, ".")
        if os.path.altsep:
            # Build on fixer_pkg (not self.fixer_dir) so the relpath and
            # os.sep conversions above are kept.
            fixer_pkg = fixer_pkg.replace(os.path.altsep, ".")
        pre_order_fixers = []
        post_order_fixers = []
        fix_names = self.options.fix
        if not fix_names or "all" in fix_names:
            fix_names = get_all_fix_names(self.fixer_dir)
        for fix_name in fix_names:
            try:
                mod = __import__(fixer_pkg + ".fix_" + fix_name, {}, {}, ["*"])
            except ImportError:
                self.log_error("Can't find transformation %s", fix_name)
                continue
            # fix_foo_bar is expected to define class FixFooBar.
            parts = fix_name.split("_")
            class_name = "Fix" + "".join([p.title() for p in parts])
            try:
                fix_class = getattr(mod, class_name)
            except AttributeError:
                self.log_error("Can't find fixes.fix_%s.%s",
                               fix_name, class_name)
                continue
            try:
                fixer = fix_class(self.options, self.fixer_log)
            except Exception:
                self.log_error("Can't instantiate fixes.fix_%s.%s()",
                               fix_name, class_name, exc_info=True)
                continue
            if fixer.explicit and fix_name not in self.options.fix:
                # Explicit fixers only run when named via options.fix.
                self.log_message("Skipping implicit fixer: %s", fix_name)
                continue

            if self.options.verbose:
                self.log_message("Adding transformation: %s", fix_name)
            if fixer.order == "pre":
                pre_order_fixers.append(fixer)
            elif fixer.order == "post":
                post_order_fixers.append(fixer)
            else:
                raise ValueError("Illegal fixer order: %r" % fixer.order)

        pre_order_fixers.sort(key=lambda x: x.run_order)
        post_order_fixers.sort(key=lambda x: x.run_order)
        return (pre_order_fixers, post_order_fixers)

    def log_error(self, msg, *args, **kwds):
        """Records an error in self.errors and logs it at ERROR level."""
        self.errors.append((msg, args, kwds))
        self.logger.error(msg, *args, **kwds)

    def log_message(self, msg, *args):
        """Hook to log an informational message (%-formatted with args)."""
        if args:
            msg = msg % args
        self.logger.info(msg)

    def refactor_args(self, args):
        """Refactors files and directories from an argument list.

        "-" means stdin; directories are walked recursively.
        """
        for arg in args:
            if arg == "-":
                self.refactor_stdin()
            elif os.path.isdir(arg):
                self.refactor_dir(arg)
            else:
                self.refactor_file(arg)

    def refactor_dir(self, arg):
        """Descends down a directory and refactor every Python file found.

        Python files are assumed to have a .py extension.

        Files and subdirectories starting with '.' are skipped.
        """
        for dirpath, dirnames, filenames in os.walk(arg):
            if self.options.verbose:
                self.log_message("Descending into %s", dirpath)
            dirnames.sort()
            filenames.sort()
            for name in filenames:
                # Match the ".py" extension, not any name ending in "py".
                if not name.startswith(".") and name.endswith(".py"):
                    fullname = os.path.join(dirpath, name)
                    self.refactor_file(fullname)
            # Modify dirnames in-place to remove subdirs with leading dots
            dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]

    def refactor_file(self, filename):
        """Refactors a file (or only its doctests, with doctests_only)."""
        try:
            f = open(filename)
        except IOError as err:
            self.log_error("Can't open %s: %s", filename, err)
            return
        try:
            input = f.read() + "\n"  # Silence certain parse errors
        finally:
            f.close()
        if self.options.doctests_only:
            if self.options.verbose:
                self.log_message("Refactoring doctests in %s", filename)
            output = self.refactor_docstring(input, filename)
            if output != input:
                # The [:-1] takes off the \n added above, so it is neither
                # diffed nor written back to the file.
                self.write_file(output[:-1], filename, input[:-1])
            elif self.options.verbose:
                self.log_message("No doctest changes in %s", filename)
        else:
            tree = self.refactor_string(input, filename)
            if tree and tree.was_changed:
                # The [:-1] is to take off the \n we added earlier
                self.write_file(str(tree)[:-1], filename)
            elif self.options.verbose:
                self.log_message("No changes in %s", filename)

    def refactor_string(self, data, name):
        """Refactor a given input string.

        Args:
            data: a string holding the code to be refactored.
            name: a human-readable name for use in error/log messages.

        Returns:
            An AST corresponding to the refactored input stream; None if
            there were errors during the parse.
        """
        try:
            tree = self.driver.parse_string(data, 1)
        except Exception as err:
            self.log_error("Can't parse %s: %s: %s",
                           name, err.__class__.__name__, err)
            return
        if self.options.verbose:
            self.log_message("Refactoring %s", name)
        self.refactor_tree(tree, name)
        return tree

    def refactor_stdin(self):
        """Refactors stdin and shows the diff; writing back is refused."""
        if self.options.write:
            self.log_error("Can't write changes back to stdin")
            return
        input = sys.stdin.read()
        if self.options.doctests_only:
            if self.options.verbose:
                self.log_message("Refactoring doctests in stdin")
            output = self.refactor_docstring(input, "<stdin>")
            if output != input:
                self.write_file(output, "<stdin>", input)
            elif self.options.verbose:
                self.log_message("No doctest changes in stdin")
        else:
            tree = self.refactor_string(input, "<stdin>")
            if tree and tree.was_changed:
                self.write_file(str(tree), "<stdin>", input)
            elif self.options.verbose:
                self.log_message("No changes in stdin")

    def refactor_tree(self, tree, name):
        """Refactors a parse tree (modifying the tree in place).

        Args:
            tree: a pytree.Node instance representing the root of the tree
                to be refactored.
            name: a human-readable name for this tree.

        Returns:
            True if the tree was modified, False otherwise.
        """
        # pre_order/post_order map head type -> list of fixers, and the
        # same fixer may be filed under several head types. Build a
        # de-duplicated *list*: a chain() iterator would be exhausted by the
        # start_tree() loop, so finish_tree() would never run.
        all_fixers = []
        seen = set()
        for fixer in chain(chain(*self.pre_order.values()),
                           chain(*self.post_order.values())):
            if id(fixer) not in seen:
                seen.add(id(fixer))
                all_fixers.append(fixer)

        for fixer in all_fixers:
            fixer.start_tree(tree, name)

        self.traverse_by(self.pre_order, tree.pre_order())
        self.traverse_by(self.post_order, tree.post_order())

        for fixer in all_fixers:
            fixer.finish_tree(tree, name)
        return tree.was_changed

    def traverse_by(self, fixers, traversal):
        """Traverse an AST, applying a set of fixers to each node.

        This is a helper method for refactor_tree().

        Args:
            fixers: a defaultdict(list) of head node type -> fixer instances,
                as built by get_headnode_dict().
            traversal: a generator that yields AST nodes.

        Returns:
            None
        """
        if not fixers:
            return
        for node in traversal:
            # Fixers under None apply to every node type.
            for fixer in fixers[node.type] + fixers[None]:
                results = fixer.match(node)
                if results:
                    new = fixer.transform(node, results)
                    if new is not None and (new != node or
                                            str(new) != str(node)):
                        node.replace(new)
                        # Later fixers in this loop see the replacement.
                        node = new

    def write_file(self, new_text, filename, old_text=None):
        """Writes a string to a file.

        filename is always recorded in self.files. If there are no
        changes, nothing else happens.

        Otherwise, it first shows a unified diff between the old text
        and the new text, and then rewrites the file; the latter is
        only done if the write option is set. The previous contents are
        renamed to filename + ".bak" first when possible.

        Args:
            new_text: the refactored contents.
            filename: path to write (also used in messages and the diff).
            old_text: current contents; read from filename when None.
        """
        self.files.append(filename)
        if old_text is None:
            try:
                f = open(filename, "r")
            except IOError as err:
                self.log_error("Can't read %s: %s", filename, err)
                return
            try:
                old_text = f.read()
            finally:
                f.close()
        if old_text == new_text:
            if self.options.verbose:
                self.log_message("No changes to %s", filename)
            return
        diff_texts(old_text, new_text, filename)
        if not self.options.write:
            if self.options.verbose:
                self.log_message("Not writing changes to %s", filename)
            return
        backup = filename + ".bak"
        if os.path.lexists(backup):
            try:
                os.remove(backup)
            except os.error:
                self.log_message("Can't remove backup %s", backup)
        try:
            os.rename(filename, backup)
        except os.error:
            self.log_message("Can't rename %s to %s", filename, backup)
        # open()/write() raise IOError on Python 2, which is not an
        # os.error (OSError) there, so catch both.
        try:
            f = open(filename, "w")
        except (IOError, OSError) as err:
            self.log_error("Can't create %s: %s", filename, err)
            return
        try:
            try:
                f.write(new_text)
            except (IOError, OSError) as err:
                self.log_error("Can't write %s: %s", filename, err)
        finally:
            f.close()
        if self.options.verbose:
            self.log_message("Wrote changes to %s", filename)

    def refactor_docstring(self, input, filename):
        """Refactors a docstring, looking for doctests.

        This returns a modified version of the input string. It looks
        for doctests, which start with a ">>>" prompt, and may be
        continued with "..." prompts, as long as the "..." is indented
        the same as the ">>>".

        (Unfortunately we can't use the doctest module's parser,
        since, like most parsers, it is not geared towards preserving
        the original source.)
        """
        result = []
        block = None
        block_lineno = None
        indent = None
        lineno = 0
        for line in input.splitlines(True):
            lineno += 1
            if line.lstrip().startswith(self.PS1):
                # A new ">>>" line starts a new block; flush the previous one.
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block_lineno = lineno
                block = [line]
                i = line.find(self.PS1)
                indent = line[:i]
            elif (indent is not None and
                  (line.startswith(indent + self.PS2) or
                   line == indent + self.PS2.rstrip() + "\n")):
                block.append(line)
            else:
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block = None
                indent = None
                result.append(line)
        if block is not None:
            result.extend(self.refactor_doctest(block, block_lineno,
                                                indent, filename))
        return "".join(result)

    def refactor_doctest(self, block, lineno, indent, filename):
        """Refactors one doctest.

        A doctest is given as a block of lines, the first of which starts
        with ">>>" (possibly indented), while the remaining lines start
        with "..." (identically indented).

        Returns the (possibly rewritten) list of lines; the original block
        is returned unchanged when it cannot be parsed.
        """
        try:
            tree = self.parse_block(block, lineno, indent)
        except Exception as err:
            if self.options.verbose:
                for line in block:
                    self.log_message("Source: %s", line.rstrip("\n"))
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
                           filename, lineno, err.__class__.__name__, err)
            return block
        if self.refactor_tree(tree, filename):
            new = str(tree).splitlines(True)
            # Undo the adjustment of the line numbers in wrap_toks().
            clipped, new = new[:lineno-1], new[lineno-1:]
            assert clipped == ["\n"] * (lineno-1), clipped
            if not new[-1].endswith("\n"):
                new[-1] += "\n"
            block = [indent + self.PS1 + new.pop(0)]
            if new:
                block += [indent + self.PS2 + line for line in new]
        return block

    def summarize(self):
        """Logs the modified files, fixer messages and recorded errors."""
        if self.options.write:
            were = "were"
        else:
            were = "need to be"
        if not self.files:
            self.log_message("No files %s modified.", were)
        else:
            self.log_message("Files that %s modified:", were)
            for filename in self.files:
                self.log_message(filename)
        if self.fixer_log:
            self.log_message("Warnings/messages while refactoring:")
            for message in self.fixer_log:
                self.log_message(message)
        if self.errors:
            if len(self.errors) == 1:
                self.log_message("There was 1 error:")
            else:
                self.log_message("There were %d errors:", len(self.errors))
            for msg, args, kwds in self.errors:
                # kwds (e.g. exc_info=True from get_fixers) were meant for
                # logger.error() in log_error(); log_message() takes no
                # keyword arguments, so they are not forwarded.
                self.log_message(msg, *args)

    def parse_block(self, block, lineno, indent):
        """Parses a block into a tree.

        This is necessary to get correct line number / offset information
        in the parser diagnostics and embedded into the parse tree.
        """
        return self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))

    def wrap_toks(self, block, lineno, indent):
        """Wraps a tokenize stream to systematically modify start/end."""
        lines = self.gen_lines(block, indent)
        # next() works on generators in Python 2.6+ and 3, unlike .next.
        tokens = tokenize.generate_tokens(lambda: next(lines))
        for tok_type, value, (line0, col0), (line1, col1), line_text in tokens:
            line0 += lineno - 1
            line1 += lineno - 1
            # Don't bother updating the columns; this is too complicated
            # since line_text would also have to be updated and it would
            # still break for tokens spanning lines. Let the user guess
            # that the column numbers for doctests are relative to the
            # end of the prompt string (PS1 or PS2).
            yield tok_type, value, (line0, col0), (line1, col1), line_text

    def gen_lines(self, block, indent):
        """Generates lines as expected by tokenize from a list of lines.

        This strips the first len(indent + self.PS1) characters off each line.
        After the block is exhausted it yields "" forever, which tokenize
        treats as end of input.
        """
        prefix1 = indent + self.PS1
        prefix2 = indent + self.PS2
        prefix = prefix1
        for line in block:
            if line.startswith(prefix):
                yield line[len(prefix):]
            elif line == prefix.rstrip() + "\n":
                # A bare prompt (no trailing space) is an empty line.
                yield "\n"
            else:
                raise AssertionError("line=%r, prefix=%r" % (line, prefix))
            prefix = prefix2
        while True:
            yield ""
575
576
def diff_texts(a, b, filename):
    """Print a unified diff of two strings to stdout.

    Both headers name filename; a is labelled "(original)" and b
    "(refactored)". Identical inputs print nothing.
    """
    diff_lines = difflib.unified_diff(a.splitlines(), b.splitlines(),
                                      filename, filename,
                                      "(original)", "(refactored)",
                                      lineterm="")
    for diff_line in diff_lines:
        print(diff_line)
585
586
if __name__ == "__main__":
    # main() has a required fixer_dir argument (calling main() bare raises
    # TypeError); use the directory of the "fixes" package imported at the
    # top of this module. abspath() sends get_fixers() down its relpath
    # branch, which derives the dotted package name.
    sys.exit(main(os.path.abspath(os.path.dirname(fixes.__file__))))