blob: 0c485f08d32e86884bc1a6001ed6af80632cd33a [file] [log] [blame]
Martin v. Löwis5e37bae2008-03-19 04:43:46 +00001"""Utility functions, node construction macros, etc."""
2# Author: Collin Winter
3
4# Local imports
Benjamin Petersone6078232008-06-15 02:31:05 +00005from .pgen2 import token
6from .pytree import Leaf, Node
7from .pygram import python_symbols as syms
8from . import patcomp
Martin v. Löwis5e37bae2008-03-19 04:43:46 +00009
10
11###########################################################
12### Common node-construction "macros"
13###########################################################
14
def KeywordArg(keyword, value):
    """Build an ``argument`` node spelling ``keyword=value``.

    *keyword* and *value* are inserted as given; the ``=`` leaf itself is
    created without a prefix.
    """
    equals = Leaf(token.EQUAL, '=')
    return Node(syms.argument, [keyword, equals, value])
18
def LParen():
    """Return a new ``(`` leaf."""
    lpar = Leaf(token.LPAR, "(")
    return lpar
21
def RParen():
    """Return a new ``)`` leaf."""
    rpar = Leaf(token.RPAR, ")")
    return rpar
24
def Assign(target, source):
    """Build an assignment statement ``target = source``.

    Either argument may be a single node or a list of nodes. A single
    *source* node has its prefix set to one space so the result reads
    ``a = b``; a list of source nodes is used unchanged. The ``=`` leaf
    gets a one-space prefix. The returned node has type ``atom``.
    """
    targets = target if isinstance(target, list) else [target]
    if isinstance(source, list):
        sources = source
    else:
        source.set_prefix(" ")
        sources = [source]

    equals = Leaf(token.EQUAL, "=", prefix=" ")
    return Node(syms.atom, targets + [equals] + sources)
35
def Name(name, prefix=None):
    """Return a NAME leaf for the identifier *name*.

    *prefix*, when given, becomes the leaf's prefix (the text preceding it).
    """
    leaf = Leaf(token.NAME, name, prefix=prefix)
    return leaf
39
def Attr(obj, attr):
    """Return the pieces of ``obj.attr`` as a list ``[obj, trailer<'.' attr>]``.

    The result is a plain list, not a node; presumably callers splice it
    into the children of a ``power`` node.
    """
    dotted_trailer = Node(syms.trailer, [Dot(), attr])
    return [obj, dotted_trailer]
43
def Comma():
    """Return a new ``,`` leaf."""
    comma = Leaf(token.COMMA, ",")
    return comma
47
def Dot():
    """Return a new period (``.``) leaf."""
    period = Leaf(token.DOT, ".")
    return period
51
def ArgList(args, lparen=LParen(), rparen=RParen()):
    """Build a parenthesised argument-list trailer, as used by Call().

    *lparen* and *rparen* are cloned, never inserted directly. The default
    leaves are created once, when the function is defined, and are
    therefore never attached to a tree. When *args* is truthy its nodes
    are wrapped in an ``arglist`` node between the parentheses; otherwise
    the trailer is just ``()``.
    """
    trailer = Node(syms.trailer, [lparen.clone(), rparen.clone()])
    if not args:
        return trailer
    trailer.insert_child(1, Node(syms.arglist, args))
    return trailer
Martin v. Löwis5e37bae2008-03-19 04:43:46 +000058
def Call(func_name, args=None, prefix=None):
    """Build a function call node ``func_name(args)``.

    *args* is handed to ArgList, so None or an empty list yields ``()``.
    When *prefix* is not None it is set as the prefix of the whole call.
    """
    call = Node(syms.power, [func_name, ArgList(args)])
    if prefix is None:
        return call
    call.set_prefix(prefix)
    return call
65
def Newline():
    """Return a NEWLINE leaf whose value is a single line break."""
    line_break = Leaf(token.NEWLINE, "\n")
    return line_break
69
def BlankLine():
    """Return a NEWLINE leaf with an empty value."""
    empty_newline = Leaf(token.NEWLINE, "")
    return empty_newline
73
def Number(n, prefix=None):
    """Return a NUMBER leaf whose value is *n*, used verbatim.

    Presumably *n* is the literal's source text (e.g. ``"1"``). *prefix*,
    when given, becomes the leaf's prefix.
    """
    leaf = Leaf(token.NUMBER, n, prefix=prefix)
    return leaf
76
def Subscript(index_node):
    """A numeric or string subscript: the trailer ``[index_node]``.

    The bracket leaves use the LSQB/RSQB token types, which are the types
    for ``[`` and ``]``. LBRACE/RBRACE denote ``{`` and ``}`` and must not
    be used here, or the leaves would carry a different type than the
    parser assigns to square brackets.
    """
    return Node(syms.trailer, [Leaf(token.LSQB, '['),
                               index_node,
                               Leaf(token.RSQB, ']')])
82
def String(string, prefix=None):
    """Return a STRING leaf whose value is *string*, used verbatim.

    Presumably *string* is the literal's source text including its quotes.
    *prefix*, when given, becomes the leaf's prefix.
    """
    leaf = Leaf(token.STRING, string, prefix=prefix)
    return leaf
86
def ListComp(xp, fp, it, test=None):
    """A list comprehension of the form [xp for fp in it if test].

    If test is None, the "if test" part is omitted.

    The prefixes of the passed-in nodes are overwritten (xp gets none; fp,
    it and test get one space) so the result renders with standard
    spacing.
    """
    xp.set_prefix("")
    fp.set_prefix(" ")
    it.set_prefix(" ")
    for_leaf = Leaf(token.NAME, "for", prefix=" ")
    in_leaf = Leaf(token.NAME, "in", prefix=" ")
    inner_args = [for_leaf, fp, in_leaf, it]
    # Compare against None explicitly, as the docstring promises.
    if test is not None:
        test.set_prefix(" ")
        if_leaf = Leaf(token.NAME, "if", prefix=" ")
        inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    # LSQB/RSQB are the token types for "[" and "]"; LBRACE/RBRACE are
    # "{" and "}".
    return Node(syms.atom,
                [Leaf(token.LSQB, "["),
                 inner,
                 Leaf(token.RSQB, "]")])
110
def FromImport(package_name, name_leafs):
    """Return an import statement of the form ``from package import name_leafs``.

    Every leaf in *name_leafs* is first detached from whatever tree it
    belongs to, then becomes a child of the new ``import_as_names`` node.

    XXX: dotted package names (e.g. 'foo.bar') have not been tested and
    may not be handled properly; *package_name* is emitted as one NAME
    leaf. Use at your own peril.
    """
    for leaf in name_leafs:
        leaf.remove()  # pull the leaf out of its old tree

    from_kw = Leaf(token.NAME, 'from')
    package = Leaf(token.NAME, package_name, prefix=" ")
    import_kw = Leaf(token.NAME, 'import', prefix=" ")
    names = Node(syms.import_as_names, name_leafs)
    return Node(syms.import_from, [from_kw, package, import_kw, names])
129
130
Martin v. Löwis5e37bae2008-03-19 04:43:46 +0000131###########################################################
132### Determine whether a node represents a given literal
133###########################################################
134
def is_tuple(node):
    """Does the node represent a tuple literal?

    True for an empty ``()`` node, and for any Node with exactly three
    children shaped ``Leaf('(')``, ``Node``, ``Leaf(')')``.
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    if kids == [LParen(), RParen()]:
        return True
    if len(kids) != 3:
        return False
    opening, middle, closing = kids
    return (isinstance(opening, Leaf)
            and isinstance(middle, Node)
            and isinstance(closing, Leaf)
            and opening.value == "("
            and closing.value == ")")
146
def is_list(node):
    """Does the node represent a list literal?

    True for a Node with at least two children whose first and last
    children are the leaves ``[`` and ``]``.
    """
    if not isinstance(node, Node) or len(node.children) < 2:
        return False
    opening, closing = node.children[0], node.children[-1]
    if not (isinstance(opening, Leaf) and isinstance(closing, Leaf)):
        return False
    return opening.value == "[" and closing.value == "]"
155
Martin v. Löwis5e37bae2008-03-19 04:43:46 +0000156
157###########################################################
158### Misc
159###########################################################
160
def parenthesize(node):
    """Wrap *node* in parentheses, returning ``atom< '(' node ')' >``."""
    wrapped = [LParen(), node, RParen()]
    return Node(syms.atom, wrapped)
163
Martin v. Löwis60a819d2008-04-10 02:48:01 +0000164
# Names of builtins that accept an arbitrary iterable argument (a list is
# not required). Not referenced in this section; presumably fixers consult
# it to decide whether a list() wrapper can be skipped -- confirm with callers.
consuming_calls = set("sorted list set any all tuple sum min max".split())
167
def attr_chain(obj, attr):
    """Follow an attribute chain.

    If you have a chain of objects where a.foo -> b, b.foo -> c, etc,
    use this to iterate over all objects in the chain, starting with
    getattr(obj, attr); obj itself is not yielded. Iteration stops at the
    first falsy link (typically None).

    Args:
        obj: the starting object
        attr: the name of the chaining attribute

    Yields:
        Each successive object in the chain.
    """
    cursor = obj
    while True:
        cursor = getattr(cursor, attr)
        if not cursor:
            return
        yield cursor
186
# Patterns used by in_special_context(). Each one binds the examined node
# as "node"; they are kept as source text here and compiled on first use.
#
# p0: the iterable of a for statement or of a comprehension's for clause.
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
# p1: the sole argument of a call to one of the listed builtins, or of a
# call to some ``<expr>.join``.
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
# p2: the first entry of an argument list passed to ``sorted``.
p2 = """
power<
    'sorted'
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
# Set to True once in_special_context() has compiled p0, p1 and p2.
pats_built = False
205pats_built = False
def in_special_context(node):
    """ Returns true if node is in an environment where all that is required
    of it is being iterable (ie, it doesn't matter if it returns a list
    or an iterator).

    The contexts are described by the module-level patterns, each matched
    against a different ancestor of node: p0 against its parent (the
    iterable of a for loop / comprehension), p1 against its grandparent
    (sole argument of iter/list/tuple/sorted/set/sum/any/all or of a
    .join call) and p2 against its great-grandparent (first entry of an
    argument list passed to sorted). The patterns are compiled lazily on
    the first call and the compiled forms replace the source strings.
    See test_map_nochange in test_fixers.py for some examples and tests.
    """
    global p0, p1, p2, pats_built
    if not pats_built:
        p0, p1, p2 = [patcomp.compile_pattern(src) for src in (p0, p1, p2)]
        pats_built = True

    ancestors = attr_chain(node, "parent")
    for pattern, ancestor in zip((p0, p1, p2), ancestors):
        results = {}
        if not pattern.match(ancestor, results):
            continue
        if results["node"] is node:
            return True
    return False
224
Martin v. Löwis5e37bae2008-03-19 04:43:46 +0000225###########################################################
226### The following functions are to find bindings in a suite
227###########################################################
228
def make_suite(node):
    """Return *node* as a ``suite`` node so statement bodies can be
    traversed uniformly.

    A node that already is a suite is returned unchanged. Anything else
    (e.g. the single statement of a one-line ``if x: y = 1`` body) is
    cloned and the clone is wrapped in a new suite, leaving the original
    tree untouched. The new suite's ``parent`` is set to *node*'s parent so
    upward navigation from it still reaches the enclosing tree; note that
    this parent does not list the suite among its children.
    """
    if node.type == syms.suite:
        return node
    # Take the parent from the original node. The clone is a fresh,
    # detached node, so reading its parent (as the old code did) always
    # gave None.
    parent = node.parent
    suite = Node(syms.suite, [node.clone()])
    suite.parent = parent
    return suite
237
def find_root(node):
    """Find the top level namespace.

    Walks up the ``parent`` links starting at *node* and returns the first
    node of type ``file_input``. If the chain ends before one is found, the
    assertion fails (when assertions are enabled).
    """
    current = node
    while current.type != syms.file_input:
        parent = current.parent
        assert parent, ("Tree is insane! root found before "
                        "file_input node was found.")
        current = parent
    return current
Martin v. Löwis5e37bae2008-03-19 04:43:46 +0000246
Benjamin Peterson43caaa02008-12-16 03:35:28 +0000247def does_tree_import(package, name, node):
248 """ Returns true if name is imported from package at the
249 top level of the tree which node belongs to.
250 To cover the case of an import like 'import foo', use
251 None for the package and 'foo' for the name. """
252 binding = find_binding(name, find_root(node), package)
Martin v. Löwis5e37bae2008-03-19 04:43:46 +0000253 return bool(binding)
254
Benjamin Peterson43caaa02008-12-16 03:35:28 +0000255def is_import(node):
256 """Returns true if the node is an import statement."""
257 return node.type in (syms.import_name, syms.import_from)
258
def touch_import(package, name, node):
    """Make sure *name* is imported at the top level of *node*'s tree.

    Works like `does_tree_import`, but if the import is missing, a new
    statement is inserted: ``import name`` when *package* is None,
    otherwise ``from package import name``. It is placed

    * right after the first run of consecutive top-level import statements;
    * otherwise right after the module docstring, if there is one;
    * otherwise at the very beginning of the file.
    """
    def is_import_stmt(stmt):
        return (stmt.type == syms.simple_stmt and stmt.children and
                is_import(stmt.children[0]))

    root = find_root(node)

    if does_tree_import(package, name, root):
        return

    statements = root.children

    # Figure out where to insert the new import: find the first import
    # statement, then skip to just past the last import of that run (this
    # also works when the run extends to the end of the children list).
    insert_pos = 0
    for idx, stmt in enumerate(statements):
        if is_import_stmt(stmt):
            end = idx + 1
            while end < len(statements) and is_import_stmt(statements[end]):
                end += 1
            insert_pos = end
            break

    # No imports: go right after the docstring, if any. Only the first
    # statement can be a docstring; a bare string expression further down
    # must not push the import below code that may already need it.
    if insert_pos == 0 and statements:
        first = statements[0]
        if (first.type == syms.simple_stmt and first.children and
                first.children[0].type == token.STRING):
            insert_pos = 1

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, 'import'),
            Leaf(token.NAME, name, prefix=' ')
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name, prefix=' ')])

    # The trailing NEWLINE leaf terminates the new simple statement.
    root.insert_child(insert_pos, Node(syms.simple_stmt, [import_, Newline()]))
307
308
_def_syms = set([syms.classdef, syms.funcdef])

def _compound_suites(stmt, start=0):
    """Yield, wrapped by make_suite, each child of *stmt* that directly
    follows a ':' leaf located at child index >= *start*."""
    kids = stmt.children
    for i in range(start, len(kids) - 1):
        if kids[i].type == token.COLON and kids[i].value == ":":
            yield make_suite(kids[i + 1])

def find_binding(name, node, package=None):
    """ Returns the node which binds variable name, otherwise None.
    If optional argument package is supplied, only imports will
    be returned.
    See test cases for examples.

    Only the children of *node* are scanned, recursing into simple
    statements and into every clause body of if/while/for/try statements.
    Recognised bindings are def/class names, imports (see
    _is_import_binding), names in the leftmost target of an expression
    statement, and for-loop targets (the latter only when no package is
    given). Function and class bodies are not searched. A binding found in
    a one-line clause body is returned from the clone made by make_suite,
    not from the original tree.
    """
    for child in node.children:
        ret = None
        if child.type in (syms.for_stmt, syms.if_stmt, syms.while_stmt):
            if (child.type == syms.for_stmt and not package and
                    _find(name, child.children[1])):
                # The loop target binds name. That is not an import, so it
                # only counts when no package filter was requested.
                return child
            # Search every clause body (if/elif/else, loop body and else).
            # When several bind name, the last one wins, matching the old
            # behaviour of only looking at the final clause.
            for suite in _compound_suites(child):
                n = find_binding(name, suite, package)
                if n:
                    ret = n
        elif child.type == syms.try_stmt:
            n = find_binding(name, make_suite(child.children[2]), package)
            if n:
                ret = n
            else:
                # Each ':' from index 3 on opens an except/else/finally body.
                for suite in _compound_suites(child, 3):
                    n = find_binding(name, suite, package)
                    if n:
                        ret = n
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            if _find(name, child.children[0]):
                ret = child

        if ret:
            if not package:
                return ret
            if is_import(ret):
                return ret
    return None
351
_block_syms = set([syms.funcdef, syms.classdef, syms.trailer])

def _find(name, node):
    """Return a NAME leaf whose value is *name* inside *node*, or None.

    The subtree is searched iteratively, but function/class definitions
    and trailers (``.attr``, call parentheses, subscripts) are not entered.
    As a result, names in those constructs are not reported.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        # Interior (symbol) nodes are numbered from 256 up and tokens below
        # 256; ">=" is needed so the lowest symbol number is descended too.
        if current.type >= 256 and current.type not in _block_syms:
            stack.extend(current.children)
        elif current.type == token.NAME and current.value == name:
            return current
    return None
362
def _is_import_binding(node, name, package=None):
    """ Will return node if node will import name, or node
    will import * from package. None is returned otherwise.
    See test cases for examples.

    ``import ...`` statements are only considered when package is falsy.
    For ``from ... import ...`` with a package, the module text must equal
    package, and any ``as`` clause makes the result None.
    """
    if node.type == syms.import_name and not package:
        if _import_name_binds(node.children[1], name):
            return node
        return None
    if node.type == syms.import_from:
        if _import_from_binds(node, name, package):
            return node
    return None

def _import_name_binds(imp, name):
    """True if the target part *imp* of an ``import ...`` statement binds name."""
    if imp.type == syms.dotted_as_names:
        for item in imp.children:
            if item.type == syms.dotted_as_name:
                if item.children[2].value == name:
                    return True
            elif item.type == token.NAME and item.value == name:
                return True
        return False
    if imp.type == syms.dotted_as_name:
        alias = imp.children[-1]
        return alias.type == token.NAME and alias.value == name
    return imp.type == token.NAME and imp.value == name

def _import_from_binds(node, name, package):
    """True if the ``from ... import ...`` statement *node* binds name
    (or is a star import from package)."""
    # unicode(...) is used to make life easier here, because
    # from a.b import parses to ['import', ['a', '.', 'b'], ...]
    if package and unicode(node.children[1]).strip() != package:
        return False
    imported = node.children[3]
    if package and _find('as', imported):
        # See test_from_import_as for explanation
        return False
    kind = imported.type
    if kind == syms.import_as_names:
        return bool(_find(name, imported))
    if kind == syms.import_as_name:
        alias = imported.children[2]
        return alias.type == token.NAME and alias.value == name
    if kind == token.NAME:
        return imported.value == name
    return bool(package) and kind == token.STAR