blob: 82dfba1b1b50cb46db7d72df7e2898599da4b849 [file] [log] [blame]
"""Utility functions, node construction macros, etc."""
# Author: Collin Winter
3
4# Local imports
5from ..pgen2 import token
6from ..pytree import Leaf, Node
7from ..pygram import python_symbols as syms
Martin v. Löwisab41b372008-03-19 05:22:42 +00008from .. import patcomp
Martin v. Löwis5e37bae2008-03-19 04:43:46 +00009
10
###########################################################
### Common node-construction "macros"
###########################################################
14
def KeywordArg(keyword, value):
    """Build an ``argument`` node of the form ``keyword=value``.

    ``keyword`` and ``value`` are used as given; the '=' leaf is created
    with no prefix.
    """
    equals = Leaf(token.EQUAL, '=')
    return Node(syms.argument, [keyword, equals, value])
18
def LParen():
    """Return a new '(' leaf (token type LPAR)."""
    paren = Leaf(token.LPAR, "(")
    return paren
21
def RParen():
    """Return a new ')' leaf (token type RPAR)."""
    paren = Leaf(token.RPAR, ")")
    return paren
24
def Assign(target, source):
    """Build an assignment statement ``target = source``.

    Either side may be a single node or a list of nodes. A single-node
    ``source`` gets its prefix set to one space so the output reads
    ``x = y``; a list ``source`` is used untouched. The '=' leaf always
    carries a one-space prefix. The node type used is ``syms.atom``.
    """
    targets = target if isinstance(target, list) else [target]
    if isinstance(source, list):
        sources = source
    else:
        source.set_prefix(" ")
        sources = [source]
    equals = Leaf(token.EQUAL, "=", prefix=" ")
    return Node(syms.atom, targets + [equals] + sources)
35
def Name(name, prefix=None):
    """Return a NAME leaf holding ``name``, with an optional ``prefix``."""
    leaf = Leaf(token.NAME, name, prefix=prefix)
    return leaf
39
def Attr(obj, attr):
    """Return the node sequence for ``obj.attr``.

    The result is a two-element *list*, not a single node: ``obj`` followed
    by a trailer node that holds a '.' leaf and ``attr``.
    """
    dotted = Node(syms.trailer, [Dot(), attr])
    return [obj, dotted]
43
def Comma():
    """Return a new ',' leaf (token type COMMA)."""
    comma = Leaf(token.COMMA, ",")
    return comma
47
def Dot():
    """Return a new '.' leaf (token type DOT)."""
    period = Leaf(token.DOT, ".")
    return period
51
def ArgList(args, lparen=LParen(), rparen=RParen()):
    """Return a parenthesised argument-list trailer ``(args)``; used by Call().

    The default ``lparen``/``rparen`` leaves are created once, when this
    function is defined. This function only ever inserts *clones* of them,
    so the shared default leaves are never placed into a tree here.
    """
    opening = lparen.clone()
    arguments = Node(syms.arglist, args)
    closing = rparen.clone()
    return Node(syms.trailer, [opening, arguments, closing])
58
def Call(func_name, args, prefix=None):
    """Return a function-call node ``func_name(args)``.

    When ``prefix`` is not None it is applied to the resulting node with
    set_prefix().
    """
    call = Node(syms.power, [func_name, ArgList(args)])
    if prefix is None:
        return call
    call.set_prefix(prefix)
    return call
65
def Newline():
    """Return a NEWLINE leaf whose text is a single line feed."""
    line_feed = Leaf(token.NEWLINE, "\n")
    return line_feed
69
def BlankLine():
    """Return a NEWLINE leaf with empty text (a blank line)."""
    empty = Leaf(token.NEWLINE, "")
    return empty
73
def Number(n, prefix=None):
    """Return a NUMBER leaf whose value is ``n``, with an optional ``prefix``.

    ``n`` is stored as the leaf's value unchanged; presumably callers pass
    the literal's source text (a string) -- confirm at call sites.
    """
    number = Leaf(token.NUMBER, n, prefix=prefix)
    return number
76
def Subscript(index_node):
    """A numeric or string subscript: a trailer node ``[index_node]``.

    The bracket leaves use the LSQB/RSQB token types, which are the token
    types for '[' and ']'. LBRACE/RBRACE, used previously, are the types for
    '{' and '}'. A pattern such as ``trailer< '[' any ']' >`` compares leaf
    types, so it would not match brackets built with the wrong type.
    """
    return Node(syms.trailer, [Leaf(token.LSQB, '['),
                               index_node,
                               Leaf(token.RSQB, ']')])
82
def String(string, prefix=None):
    """Return a STRING leaf whose value is ``string``, with an optional ``prefix``."""
    literal = Leaf(token.STRING, string, prefix=prefix)
    return literal
86
def ListComp(xp, fp, it, test=None):
    """A list comprehension of the form [xp for fp in it if test].

    If test is None, the "if test" part is omitted.

    The prefixes of xp, fp, it (and test, when given) are overwritten so the
    result renders with single spaces between its parts. The outer brackets
    use the LSQB/RSQB token types, the types for '[' and ']' (LBRACE/RBRACE
    are the '{'/'}' types).
    """
    xp.set_prefix("")
    fp.set_prefix(" ")
    it.set_prefix(" ")
    for_leaf = Leaf(token.NAME, "for")
    for_leaf.set_prefix(" ")
    in_leaf = Leaf(token.NAME, "in")
    in_leaf.set_prefix(" ")
    inner_args = [for_leaf, fp, in_leaf, it]
    # Compare against None explicitly: the docstring contract is about None,
    # not about the truthiness of the node passed in.
    if test is not None:
        test.set_prefix(" ")
        if_leaf = Leaf(token.NAME, "if")
        if_leaf.set_prefix(" ")
        inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    return Node(syms.atom,
                [Leaf(token.LSQB, "["),
                 inner,
                 Leaf(token.RSQB, "]")])
110
def FromImport(package_name, name_leafs):
    """ Return an import statement in the form:
        from package import name_leafs

    Each leaf in name_leafs is first detached from its current tree with
    remove(), then placed under a new import_as_names node.

    Raises AssertionError (when assertions are enabled) if package_name
    contains a '.' but is not exactly '.'.
    """
    # XXX: May not handle dotted imports properly (eg, package_name='foo.bar')
    # The check previously referenced the undefined name `package.name`,
    # which raised NameError for every package_name other than '.'.
    assert package_name == '.' or '.' not in package_name, "FromImport has "\
           "not been tested with dotted package names -- use at your own "\
           "peril!"

    for leaf in name_leafs:
        # Pull the leaves out of their old tree
        leaf.remove()

    children = [Leaf(token.NAME, 'from'),
                Leaf(token.NAME, package_name, prefix=" "),
                Leaf(token.NAME, 'import', prefix=" "),
                Node(syms.import_as_names, name_leafs)]
    imp = Node(syms.import_from, children)
    return imp
129
130
###########################################################
### Determine whether a node represents a given literal
###########################################################
134
def is_tuple(node):
    """Does the node represent a tuple literal?

    True for a Node whose children equal a fresh '(' ')' pair (the empty
    tuple). Also true for any Node with exactly three children where the
    outer two are leaves valued '(' and ')' and the middle one is a Node.
    Note that the second shape also matches a parenthesised expression
    that is not a tuple.
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    if kids == [LParen(), RParen()]:
        return True
    if len(kids) != 3:
        return False
    opening, middle, closing = kids
    return (isinstance(opening, Leaf)
            and isinstance(middle, Node)
            and isinstance(closing, Leaf)
            and opening.value == "("
            and closing.value == ")")
146
def is_list(node):
    """Does the node represent a list literal?

    True when ``node`` is a Node with at least two children, and its first
    and last children are leaves valued '[' and ']'.
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    if len(kids) <= 1:
        return False
    first, last = kids[0], kids[-1]
    return (isinstance(first, Leaf) and isinstance(last, Leaf)
            and first.value == "[" and last.value == "]")
155
###########################################################
### Common portability code. This lets fixers do, e.g.,
### "from .util import set" and forget about it.
###########################################################
160
try:
    any = any
except NameError:
    # Fallback when there is no builtin any(): stops at the first truthy item.
    def any(l):
        for element in l:
            if element:
                break
        else:
            return False
        return True

try:
    set = set
except NameError:
    # Fallback when there is no builtin set type.
    from sets import Set as set

try:
    reversed = reversed
except NameError:
    # Fallback when there is no builtin reversed(). Unlike the builtin, this
    # returns a reversed *copy* via slicing, so l must support slicing.
    def reversed(l):
        return l[slice(None, None, -1)]
180
###########################################################
### Misc
###########################################################
184
def attr_chain(obj, attr):
    """Follow an attribute chain.

    If you have a chain of objects where a.foo -> b, b.foo -> c, etc,
    use this to iterate over all objects in the chain. The starting object
    itself is not yielded. Iteration stops at the first link whose value
    is falsy (e.g. None).

    Args:
        obj: the starting object
        attr: the name of the chaining attribute

    Yields:
        Each successive object in the chain.
    """
    current = obj
    while True:
        current = getattr(current, attr)
        if not current:
            return
        yield current
203
# Pattern sources, compiled lazily on the first in_special_context() call.
# The globals are rebound to the compiled patterns at that point.
# In p0, `node` is a direct child of the matched for_stmt / comp_for.
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
# In p1, `node` sits in a trailer that is a child of the matched power node
# (matched against node's grandparent).
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
# In p2, `node` sits in an arglist, inside a trailer, inside the matched power
# node (matched against node's great-grandparent).
p2 = """
power<
    'sorted'
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
pats_built = False
def in_special_context(node):
    """ Returns true if node is in an environment where all that is required
        of it is being iterable (ie, it doesn't matter if it returns a list
        or an iterator).

        That covers three cases:
        - the iterable of a for loop or comprehension;
        - the sole argument of iter/list/tuple/sorted/set/sum/any/all or of
          a ``.join`` call;
        - the first argument of a multi-argument sorted() call.

        See test_map_nochange in test_fixers.py for some examples and tests.
    """
    global p0, p1, p2, pats_built
    if not pats_built:
        p0, p1, p2 = [patcomp.compile_pattern(src) for src in (p0, p1, p2)]
        pats_built = True
    # zip pairs each pattern with the ancestor at the depth it expects:
    # p0 with the parent, p1 with the grandparent, p2 with the
    # great-grandparent (see the nesting of `node` in each pattern).
    ancestors = attr_chain(node, "parent")
    for pattern, ancestor in zip((p0, p1, p2), ancestors):
        captured = {}
        if pattern.match(ancestor, captured) and captured["node"] is node:
            return True
    return False
241
###########################################################
### Functions for finding the bindings of a name in a suite
###########################################################
245
def make_suite(node):
    """Return ``node`` if it is already a suite; otherwise wrap a clone of it.

    The original ``node`` is not modified. The returned suite contains a
    clone of ``node`` as its only child.
    """
    if node.type == syms.suite:
        return node
    detached = node.clone()
    # NOTE(review): the parent is read from the *clone*, not from `node`.
    # If clone() returns a parentless copy, the wrapper's parent is always
    # None -- confirm whether node.parent was intended.
    original_parent = detached.parent
    detached.parent = None
    wrapper = Node(syms.suite, [detached])
    wrapper.parent = original_parent
    return wrapper
254
def does_tree_import(package, name, node):
    """ Returns true if name is imported from package at the
        top level of the tree which node belongs to.
        To cover the case of an import like 'import foo', use
        None for the package and 'foo' for the name. """
    root = node
    # Climb parent links until the file_input node at the top of the tree.
    while root.type != syms.file_input:
        assert root.parent, "Tree is insane! root found before "\
               "file_input node was found."
        root = root.parent
    return bool(find_binding(name, root, package))
268
# Node types whose defined name (children[1]) counts as a binding.
_def_syms = set([syms.classdef, syms.funcdef])
def find_binding(name, node, package=None):
    """ Returns the node which binds variable name, otherwise None.
        If optional argument package is supplied, only imports will
        be returned.
        See test cases for examples.

        Only the children of node are scanned. for/if/while/try statements
        and simple_stmt children are searched by recursing into them.
        For class and function definitions only the defined name
        (children[1]) is compared; their bodies are not searched.
    """
    for child in node.children:
        ret = None
        if child.type == syms.for_stmt:
            # The loop target (children[1]) binds the name. This returns
            # immediately, even when package is given.
            # NOTE(review): that bypasses the "imports only" filter applied
            # below when package is set -- confirm this is intended.
            if _find(name, child.children[1]):
                return child
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type in (syms.if_stmt, syms.while_stmt):
            # Only the last child's suite is searched (presumably the else
            # suite when one exists). NOTE(review): earlier if/elif suites
            # are not examined -- confirm intended.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type == syms.try_stmt:
            # children[2] is presumably the try body; if nothing is found
            # there, every suite that follows a ':' leaf is searched.
            n = find_binding(name, make_suite(child.children[2]), package)
            if n:
                ret = n
            else:
                for i, kid in enumerate(child.children[3:]):
                    if kid.type == token.COLON and kid.value == ":":
                        # i+3 is the colon, i+4 is the suite
                        n = find_binding(name, make_suite(child.children[i+4]), package)
                        # A later match overwrites an earlier one.
                        if n: ret = n
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            # Only the leftmost target (children[0]) is examined.
            if _find(name, child.children[0]):
                ret = child

        if ret:
            if not package:
                return ret
            # With package set, only import statements count as bindings.
            if ret.type in (syms.import_name, syms.import_from):
                return ret
    return None
311
# Subtrees that _find() does not descend into.
_block_syms = set([syms.funcdef, syms.classdef, syms.trailer])
def _find(name, node):
    """Return a NAME leaf valued ``name`` from ``node``'s subtree, or None.

    The search uses an explicit stack. It descends into nodes whose type is
    greater than 256, presumably grammar symbols as opposed to tokens. It
    does not enter funcdef, classdef or trailer nodes, so names inside
    nested definitions and trailers (e.g. ``x.name`` or ``f(name)``) are
    not reported.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.type > 256 and current.type not in _block_syms:
            pending.extend(current.children)
            continue
        if current.type == token.NAME and current.value == name:
            return current
    return None
322
def _is_import_binding(node, name, package=None):
    """ Will return node if node will import name, or node
        will import * from package. None is returned otherwise.
        See test cases for examples.

        For ``from ... import`` statements, the imported names are located
        relative to the 'import' keyword rather than at a fixed child index.
        This covers relative imports (``from .. import x``,
        ``from .a import x``) and parenthesised name lists
        (``from a import (x, y)``). """

    if node.type == syms.import_name and not package:
        imp = node.children[1]
        if imp.type == syms.dotted_as_names:
            for child in imp.children:
                if child.type == syms.dotted_as_name:
                    # import a.b as name, ...
                    if child.children[2].value == name:
                        return node
                elif child.type == token.NAME and child.value == name:
                    return node
        elif imp.type == syms.dotted_as_name:
            # import a.b as name
            last = imp.children[-1]
            if last.type == token.NAME and last.value == name:
                return node
        elif imp.type == token.NAME and imp.value == name:
            # import name
            return node
    elif node.type == syms.import_from:
        # Find the 'import' keyword. Its position varies: relative-import
        # dots are, presumably, separate children before it.
        import_index = None
        for i, kid in enumerate(node.children):
            if kid.type == token.NAME and kid.value == 'import':
                import_index = i
                break
        if import_index is None:
            return None
        # unicode(...) is used to make life easier here, because
        # from a.b import parses to ['import', ['a', '.', 'b'], ...].
        # Joining every child between 'from' and 'import' also covers the
        # separate dots of a relative import.
        if package:
            module = "".join(unicode(kid)
                             for kid in node.children[1:import_index])
            if module.strip() != package:
                return None
        n = node.children[import_index + 1]
        if n.type == token.LPAR:
            # from a import (x, y): the names sit inside the parentheses.
            n = node.children[import_index + 2]
        if package and _find('as', n):
            # See test_from_import_as for explanation
            return None
        elif n.type == syms.import_as_names and _find(name, n):
            return node
        elif n.type == syms.import_as_name:
            # from a import b as name
            child = n.children[2]
            if child.type == token.NAME and child.value == name:
                return node
        elif n.type == token.NAME and n.value == name:
            return node
        elif package and n.type == token.STAR:
            return node
    return None
362 return None