#!/usr/bin/env python2.5
# Copyright 2006 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Refactoring framework.

Used as a main program, this can refactor any number of files and/or
recursively descend down directories.  Imported as a module, this
provides infrastructure to write your own refactoring tool.
"""
11
12__author__ = "Guido van Rossum <guido@python.org>"
13
14
15# Python imports
16import os
17import sys
18import difflib
19import optparse
20import logging
Martin v. Löwis6780a9d2008-05-02 21:30:20 +000021from collections import defaultdict
22from itertools import chain
Martin v. Löwis5e37bae2008-03-19 04:43:46 +000023
24# Local imports
25from .pgen2 import driver
26from .pgen2 import tokenize
27
28from . import pytree
29from . import patcomp
30from . import fixes
31from . import pygram
32
def main(fixer_dir, args=None):
    """Command-line entry point of the refactoring tool.

    Args:
        fixer_dir: directory where fixer modules are located.
        args: optional; a list of command line arguments.  If omitted,
            sys.argv[1:] is used.

    Returns a suggested exit status: 0 on success (or when only the list
    of fixes was requested), 1 if any error was recorded, 2 when no file
    or directory argument was given.
    """
    # Declare the options as data; they are registered in this order so
    # the generated --help output lists them the same way.
    option_specs = (
        ("-d", "--doctests_only",
         dict(action="store_true", help="Fix up doctests only")),
        ("-f", "--fix",
         dict(action="append", default=[],
              help="Each FIX specifies a transformation; default all")),
        ("-l", "--list-fixes",
         dict(action="store_true",
              help="List available transformations (fixes/fix_*.py)")),
        ("-p", "--print-function",
         dict(action="store_true",
              help="Modify the grammar so that print() is a function")),
        ("-v", "--verbose",
         dict(action="store_true", help="More verbose logging")),
        ("-w", "--write",
         dict(action="store_true", help="Write back modified files")),
    )
    parser = optparse.OptionParser(usage="refactor.py [options] file|dir ...")
    for short_flag, long_flag, settings in option_specs:
        parser.add_option(short_flag, long_flag, **settings)

    options, args = parser.parse_args(args)

    if options.list_fixes:
        print("Available transformations for the -f/--fix option:")
        for fixname in get_all_fix_names(fixer_dir):
            print(fixname)
        # Listing alone is a complete, successful run.
        if not args:
            return 0
    if not args:
        sys.stderr.write(
            "At least one file or directory argument required.\n")
        sys.stderr.write("Use --help to show usage.\n")
        return 2

    # Route log records to the console, prefixed with the logger name.
    log_format = '%(name)s: %(message)s'
    if sys.version_info >= (2, 4):
        logging.basicConfig(format=log_format, level=logging.INFO)
    else:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(log_format))
        logging.root.addHandler(handler)

    tool = RefactoringTool(fixer_dir, options)

    # Only refactor when every requested fixer loaded cleanly.
    if not tool.errors:
        tool.refactor_args(args)
        tool.summarize()

    if tool.errors:
        return 1
    return 0
90
91
def get_all_fix_names(fixer_dir):
    """Return a sorted list of all available fix names.

    A fix name is the part of a ``fix_<name>.py`` file name found in
    fixer_dir that sits between the ``fix_`` prefix and the ``.py`` suffix.
    """
    prefix = "fix_"
    suffix = ".py"
    return sorted(entry[len(prefix):-len(suffix)]
                  for entry in os.listdir(fixer_dir)
                  if entry.startswith(prefix) and entry.endswith(suffix))
102
def get_head_types(pat):
    """Return the set of node types a pytree pattern can start matching on.

    Accepts a pytree Pattern node.  The result may contain None, meaning
    the pattern is not tied to one particular node type.
    """
    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # Node and leaf patterns carry their own type (possibly None), so
        # there is nothing to descend into.
        return set([pat.type])

    if isinstance(pat, pytree.NegatedPattern):
        if not pat.content:
            # A bare negation has no type of its own.
            return set([None])
        return get_head_types(pat.content)

    if isinstance(pat, pytree.WildcardPattern):
        # Union the head types of every pattern in every alternative.
        heads = set()
        for alternative in pat.content:
            for sub_pattern in alternative:
                heads |= get_head_types(sub_pattern)
        return heads

    raise Exception("Oh no! I don't understand pattern %s" % (pat))
127
def get_headnode_dict(fixer_list):
    """Group fixers by the node types their patterns can start with.

    Accepts a list of fixers and returns a defaultdict mapping
    head node type --> list of fixers.  Fixers without a pattern are
    filed under the key None.
    """
    fixers_by_head = defaultdict(list)
    for fixer in fixer_list:
        if fixer.pattern:
            head_types = get_head_types(fixer.pattern)
        else:
            head_types = [None]
        for head_type in head_types:
            fixers_by_head[head_type].append(fixer)
    return fixers_by_head
139
Martin v. Löwis5e37bae2008-03-19 04:43:46 +0000140
class RefactoringTool(object):
    """Loads fixers and applies them to files, directories, strings or stdin.

    The fixers found in ``fixer_dir`` are instantiated according to the
    options, grouped by the node types their patterns can start with, and
    run over parse trees produced by the pgen2 driver.
    """

    def __init__(self, fixer_dir, options):
        """Initializer.

        Args:
            fixer_dir: directory in which to find fixer modules.
            options: an optparse.Values instance.
        """
        self.fixer_dir = fixer_dir
        self.options = options
        # (msg, args, kwds) tuples recorded by log_error().
        self.errors = []
        self.logger = logging.getLogger("RefactoringTool")
        # Handed to every fixer; summarize() prints whatever ends up here.
        self.fixer_log = []
        if self.options.print_function:
            # NOTE(review): this mutates the module-level grammar object,
            # which is presumably shared by every tool in the process; a
            # second instance with print_function set would then hit a
            # missing key here -- confirm against pygram.
            del pygram.python_grammar.keywords["print"]
        self.driver = driver.Driver(pygram.python_grammar,
                                    convert=pytree.convert,
                                    logger=self.logger)
        pre_order, post_order = self.get_fixers()

        # Index the fixers by head node type so traverse_by() only tries
        # the fixers that can possibly match a given node.
        self.pre_order = get_headnode_dict(pre_order)
        self.post_order = get_headnode_dict(post_order)

        self.files = []  # List of files that were or should be modified

    def get_fixers(self):
        """Inspects the options to load the requested patterns and handlers.

        Fixers that cannot be imported, found or instantiated are reported
        through log_error() and skipped.

        Returns:
          (pre_order, post_order), where pre_order is the list of fixers that
          want a pre-order AST traversal, and post_order is the list that want
          post-order traversal.  Both are sorted by the fixers' run_order.

        Raises:
          ValueError: a fixer's order is neither "pre" nor "post".
        """
        # Turn the fixer directory into a dotted package name.  An absolute
        # path is first made relative to the parent of this file's directory.
        if os.path.isabs(self.fixer_dir):
            package_root = os.path.join(os.path.dirname(__file__), '..')
            fixer_pkg = os.path.relpath(self.fixer_dir, package_root)
        else:
            fixer_pkg = self.fixer_dir
        fixer_pkg = fixer_pkg.replace(os.path.sep, ".")
        pre_order_fixers = []
        post_order_fixers = []
        fix_names = self.options.fix
        if not fix_names or "all" in fix_names:
            fix_names = get_all_fix_names(self.fixer_dir)
        for fix_name in fix_names:
            try:
                mod = __import__(fixer_pkg + ".fix_" + fix_name, {}, {}, ["*"])
            except ImportError:
                self.log_error("Can't find transformation %s", fix_name)
                continue
            # fix_has_key -> FixHasKey
            parts = fix_name.split("_")
            class_name = "Fix" + "".join([p.title() for p in parts])
            try:
                fix_class = getattr(mod, class_name)
            except AttributeError:
                self.log_error("Can't find fixes.fix_%s.%s",
                               fix_name, class_name)
                continue
            try:
                fixer = fix_class(self.options, self.fixer_log)
            except Exception:
                self.log_error("Can't instantiate fixes.fix_%s.%s()",
                               fix_name, class_name, exc_info=True)
                continue
            # Explicit fixers only run when named with -f.
            if fixer.explicit and fix_name not in self.options.fix:
                self.log_message("Skipping implicit fixer: %s", fix_name)
                continue

            if self.options.verbose:
                self.log_message("Adding transformation: %s", fix_name)
            if fixer.order == "pre":
                pre_order_fixers.append(fixer)
            elif fixer.order == "post":
                post_order_fixers.append(fixer)
            else:
                raise ValueError("Illegal fixer order: %r" % fixer.order)

        pre_order_fixers.sort(key=lambda x: x.run_order)
        post_order_fixers.sort(key=lambda x: x.run_order)
        return (pre_order_fixers, post_order_fixers)

    def log_error(self, msg, *args, **kwds):
        """Records an error and logs it.

        The (msg, args, kwds) triple is kept in self.errors so summarize()
        can repeat it; kwds are passed on to the logger (e.g. exc_info).
        """
        self.errors.append((msg, args, kwds))
        self.logger.error(msg, *args, **kwds)

    def log_message(self, msg, *args, **kwds):
        """Hook to log a message.

        msg is %-formatted with args when args are given.  Keyword
        arguments are accepted and ignored so that errors recorded by
        log_error() -- which may carry logger keywords such as exc_info --
        can be replayed through this method.
        """
        if args:
            msg = msg % args
        self.logger.info(msg)

    def refactor_args(self, args):
        """Refactors files and directories from an argument list.

        "-" stands for standard input.
        """
        for arg in args:
            if arg == "-":
                self.refactor_stdin()
            elif os.path.isdir(arg):
                self.refactor_dir(arg)
            else:
                self.refactor_file(arg)

    def refactor_dir(self, arg):
        """Descends down a directory and refactor every Python file found.

        Python files are assumed to have a .py extension.

        Files and subdirectories starting with '.' are skipped.
        """
        for dirpath, dirnames, filenames in os.walk(arg):
            if self.options.verbose:
                self.log_message("Descending into %s", dirpath)
            dirnames.sort()
            filenames.sort()
            for name in filenames:
                # Require the full ".py" suffix; a bare "py" suffix would
                # also pick up names such as "happy".
                if not name.startswith(".") and name.endswith(".py"):
                    fullname = os.path.join(dirpath, name)
                    self.refactor_file(fullname)
            # Modify dirnames in-place to remove subdirs with leading dots
            dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]

    def refactor_file(self, filename):
        """Refactors a file.

        Read failures are reported through log_error().  Changes are
        handed to write_file(), which shows the diff and, when the write
        option is set, rewrites the file.
        """
        try:
            f = open(filename)
        except IOError as err:
            self.log_error("Can't open %s: %s", filename, err)
            return
        try:
            source = f.read() + "\n"  # Silence certain parse errors
        finally:
            f.close()
        if self.options.doctests_only:
            if self.options.verbose:
                self.log_message("Refactoring doctests in %s", filename)
            output = self.refactor_docstring(source, filename)
            if output != source:
                # The [:-1] takes off the \n we added earlier, so it is
                # not written back to the file.
                self.write_file(output[:-1], filename, source[:-1])
            elif self.options.verbose:
                self.log_message("No doctest changes in %s", filename)
        else:
            tree = self.refactor_string(source, filename)
            if tree and tree.was_changed:
                # The [:-1] is to take off the \n we added earlier
                self.write_file(str(tree)[:-1], filename)
            elif self.options.verbose:
                self.log_message("No changes in %s", filename)

    def refactor_string(self, data, name):
        """Refactor a given input string.

        Args:
            data: a string holding the code to be refactored.
            name: a human-readable name for use in error/log messages.

        Returns:
            An AST corresponding to the refactored input stream; None if
            there were errors during the parse.
        """
        try:
            tree = self.driver.parse_string(data, 1)
        except Exception as err:
            self.log_error("Can't parse %s: %s: %s",
                           name, err.__class__.__name__, err)
            return
        if self.options.verbose:
            self.log_message("Refactoring %s", name)
        self.refactor_tree(tree, name)
        return tree

    def refactor_stdin(self):
        """Refactors the text read from standard input.

        Writing back is impossible, so the write option is reported as an
        error; otherwise the changes are only shown as a diff.
        """
        if self.options.write:
            self.log_error("Can't write changes back to stdin")
            return
        source = sys.stdin.read()
        if self.options.doctests_only:
            if self.options.verbose:
                self.log_message("Refactoring doctests in stdin")
            output = self.refactor_docstring(source, "<stdin>")
            if output != source:
                self.write_file(output, "<stdin>", source)
            elif self.options.verbose:
                self.log_message("No doctest changes in stdin")
        else:
            tree = self.refactor_string(source, "<stdin>")
            if tree and tree.was_changed:
                self.write_file(str(tree), "<stdin>", source)
            elif self.options.verbose:
                self.log_message("No changes in stdin")

    def refactor_tree(self, tree, name):
        """Refactors a parse tree (modifying the tree in place).

        Every fixer gets start_tree() before and finish_tree() after the
        two traversals, exactly once each.

        Args:
            tree: a pytree.Node instance representing the root of the tree
                  to be refactored.
            name: a human-readable name for this tree.

        Returns:
            True if the tree was modified, False otherwise.
        """
        # pre_order/post_order map head node type -> list of fixers, so the
        # values are lists of lists and need a nested chain().  The result
        # is materialized because a bare chain() iterator would be used up
        # by the start_tree() loop and finish_tree() would never run.  A
        # fixer with several head types sits in several lists; keep only
        # its first occurrence.
        all_fixers = []
        seen_ids = set()
        for fixer in chain(chain(*self.pre_order.values()),
                           chain(*self.post_order.values())):
            if id(fixer) not in seen_ids:
                seen_ids.add(id(fixer))
                all_fixers.append(fixer)

        for fixer in all_fixers:
            fixer.start_tree(tree, name)

        self.traverse_by(self.pre_order, tree.pre_order())
        self.traverse_by(self.post_order, tree.post_order())

        for fixer in all_fixers:
            fixer.finish_tree(tree, name)
        return tree.was_changed

    def traverse_by(self, fixers, traversal):
        """Traverse an AST, applying a set of fixers to each node.

        This is a helper method for refactor_tree().

        Args:
            fixers: a mapping of head node type -> list of fixer instances,
                as built by get_headnode_dict(); the None key holds the
                fixers that must be tried on every node.
            traversal: a generator that yields AST nodes.

        Returns:
            None
        """
        if not fixers:
            return
        for node in traversal:
            for fixer in fixers[node.type] + fixers[None]:
                results = fixer.match(node)
                if results:
                    new = fixer.transform(node, results)
                    if new is not None and (new != node or
                                            str(new) != str(node)):
                        node.replace(new)
                        # Later fixers see the replacement node.
                        node = new

    def write_file(self, new_text, filename, old_text=None):
        """Shows the changes to a file and optionally writes them.

        The file name is recorded in self.files first.  If old_text is not
        given it is read from filename.  When the texts are equal nothing
        further happens.

        Otherwise, it first shows a unified diff between the old text
        and the new text, and then rewrites the file; the latter is
        only done if the write option is set.  The previous contents are
        kept in filename + ".bak".
        """
        self.files.append(filename)
        if old_text is None:
            try:
                f = open(filename, "r")
            except IOError as err:
                self.log_error("Can't read %s: %s", filename, err)
                return
            try:
                old_text = f.read()
            finally:
                f.close()
        if old_text == new_text:
            if self.options.verbose:
                self.log_message("No changes to %s", filename)
            return
        diff_texts(old_text, new_text, filename)
        if not self.options.write:
            if self.options.verbose:
                self.log_message("Not writing changes to %s", filename)
            return
        # Backup handling is best effort: failures are logged, not fatal.
        backup = filename + ".bak"
        if os.path.lexists(backup):
            try:
                os.remove(backup)
            except os.error:
                self.log_message("Can't remove backup %s", backup)
        try:
            os.rename(filename, backup)
        except os.error:
            self.log_message("Can't rename %s to %s", filename, backup)
        # open() and write() raise IOError, which on Python 2 is not an
        # os.error (OSError); EnvironmentError covers both.
        try:
            f = open(filename, "w")
        except EnvironmentError as err:
            self.log_error("Can't create %s: %s", filename, err)
            return
        try:
            f.write(new_text)
        except EnvironmentError as err:
            self.log_error("Can't write %s: %s", filename, err)
            # Do not claim success below.
            return
        finally:
            f.close()
        if self.options.verbose:
            self.log_message("Wrote changes to %s", filename)

    # Doctest prompts: first line and continuation lines of an example.
    PS1 = ">>> "
    PS2 = "... "

    def refactor_docstring(self, input, filename):
        """Refactors a docstring, looking for doctests.

        This returns a modified version of the input string.  It looks
        for doctests, which start with a ">>>" prompt, and may be
        continued with "..." prompts, as long as the "..." is indented
        the same as the ">>>".

        (Unfortunately we can't use the doctest module's parser,
        since, like most parsers, it is not geared towards preserving
        the original source.)
        """
        result = []
        block = None         # lines of the doctest being collected
        block_lineno = None  # 1-based line number of the block's first line
        indent = None        # text preceding ">>>" on the first line
        lineno = 0
        for line in input.splitlines(True):
            lineno += 1
            if line.lstrip().startswith(self.PS1):
                # A new example starts; flush the previous one first.
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block_lineno = lineno
                block = [line]
                i = line.find(self.PS1)
                indent = line[:i]
            elif (indent is not None and
                  (line.startswith(indent + self.PS2) or
                   line == indent + self.PS2.rstrip() + "\n")):
                # Continuation line (possibly a bare "..." prompt).
                block.append(line)
            else:
                # Ordinary text ends any example in progress.
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block = None
                indent = None
                result.append(line)
        if block is not None:
            result.extend(self.refactor_doctest(block, block_lineno,
                                                indent, filename))
        return "".join(result)

    def refactor_doctest(self, block, lineno, indent, filename):
        """Refactors one doctest.

        A doctest is given as a block of lines, the first of which starts
        with ">>>" (possibly indented), while the remaining lines start
        with "..." (identically indented).

        Returns the refactored list of lines, or the block unchanged when
        it cannot be parsed (reported through log_error()) or no fixer
        modified it.
        """
        try:
            tree = self.parse_block(block, lineno, indent)
        except Exception as err:
            if self.options.verbose:
                for line in block:
                    self.log_message("Source: %s", line.rstrip("\n"))
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
                           filename, lineno, err.__class__.__name__, err)
            return block
        if self.refactor_tree(tree, filename):
            new = str(tree).splitlines(True)
            # Undo the adjustment of the line numbers in wrap_toks() below.
            clipped, new = new[:lineno-1], new[lineno-1:]
            assert clipped == ["\n"] * (lineno-1), clipped
            if not new[-1].endswith("\n"):
                new[-1] += "\n"
            # Put the prompts back in front of the refactored lines.
            block = [indent + self.PS1 + new.pop(0)]
            if new:
                block += [indent + self.PS2 + line for line in new]
        return block

    def summarize(self):
        """Logs the files touched, the fixer messages and the errors."""
        if self.options.write:
            were = "were"
        else:
            were = "need to be"
        if not self.files:
            self.log_message("No files %s modified.", were)
        else:
            self.log_message("Files that %s modified:", were)
            for file in self.files:
                self.log_message(file)
        if self.fixer_log:
            self.log_message("Warnings/messages while refactoring:")
            for message in self.fixer_log:
                self.log_message(message)
        if self.errors:
            if len(self.errors) == 1:
                self.log_message("There was 1 error:")
            else:
                self.log_message("There were %d errors:", len(self.errors))
            for msg, args, kwds in self.errors:
                # kwds may hold logger keywords such as exc_info;
                # log_message() accepts and drops them.
                self.log_message(msg, *args, **kwds)

    def parse_block(self, block, lineno, indent):
        """Parses a block into a tree.

        This is necessary to get correct line number / offset information
        in the parser diagnostics and embedded into the parse tree.
        """
        return self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))

    def wrap_toks(self, block, lineno, indent):
        """Wraps a tokenize stream to systematically modify start/end.

        Line numbers are shifted by lineno - 1 so they refer to the
        enclosing text instead of the block.
        """
        tokens = tokenize.generate_tokens(self.gen_lines(block, indent).next)
        for type, value, (line0, col0), (line1, col1), line_text in tokens:
            line0 += lineno - 1
            line1 += lineno - 1
            # Don't bother updating the columns; this is too complicated
            # since line_text would also have to be updated and it would
            # still break for tokens spanning lines.  Let the user guess
            # that the column numbers for doctests are relative to the
            # end of the prompt string (PS1 or PS2).
            yield type, value, (line0, col0), (line1, col1), line_text

    def gen_lines(self, block, indent):
        """Generates lines as expected by tokenize from a list of lines.

        This strips the first len(indent + self.PS1) characters off each line.
        A line holding only the stripped prompt yields "\\n".  After the
        block is exhausted, "" is yielded forever to signal end of input.

        Raises:
            AssertionError: a line does not start with the expected prompt.
        """
        prefix1 = indent + self.PS1
        prefix2 = indent + self.PS2
        prefix = prefix1
        for line in block:
            if line.startswith(prefix):
                yield line[len(prefix):]
            elif line == prefix.rstrip() + "\n":
                yield "\n"
            else:
                raise AssertionError("line=%r, prefix=%r" % (line, prefix))
            # Every line after the first uses the continuation prompt.
            prefix = prefix2
        while True:
            yield ""
573
574
def diff_texts(a, b, filename):
    """Prints a unified diff of two strings.

    a is labelled "(original)" and b "(refactored)"; both sides use
    filename as the file name in the diff header.
    """
    old_lines = a.splitlines()
    new_lines = b.splitlines()
    diff = difflib.unified_diff(old_lines, new_lines, filename, filename,
                                "(original)", "(refactored)",
                                lineterm="")
    for diff_line in diff:
        print(diff_line)
583
584
if __name__ == "__main__":
    # main() requires the directory holding the fixer modules; calling it
    # without one raised TypeError.  Use the directory of the `fixes`
    # package imported above.  It is made absolute so that get_fixers()
    # derives the package name relative to this file, independent of the
    # current working directory.
    sys.exit(main(os.path.dirname(os.path.abspath(fixes.__file__))))