blob: 77eb24971ed24cb25508d55199b264d3880207f0 [file] [log] [blame]
Georg Brandl0c77a822008-06-10 16:37:50 +00001"""
2 ast
3 ~~~
4
5 The `ast` module helps Python applications to process trees of the Python
6 abstract syntax grammar. The abstract syntax itself might change with
7 each Python release; this module helps to find out programmatically what
8 the current grammar looks like and allows modifications of it.
9
10 An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
11 a flag to the `compile()` builtin function or by using the `parse()`
12 function from this module. The result will be a tree of objects whose
13 classes all inherit from `ast.AST`.
14
15 A modified abstract syntax tree can be compiled into a Python code object
16 using the built-in `compile()` function.
17
18 Additionally various helper functions are provided that make working with
19 the trees simpler. The main intention of the helper functions and this
20 module in general is to provide an easy to use interface for libraries
21 that work tightly with the python syntax (template engines for example).
22
23
24 :copyright: Copyright 2008 by Armin Ronacher.
25 :license: Python License.
26"""
Pablo Galindo27fc3b62019-11-24 23:02:40 +000027import sys
Georg Brandl0c77a822008-06-10 16:37:50 +000028from _ast import *
29
30
Guido van Rossum495da292019-03-07 12:38:08 -080031def parse(source, filename='<unknown>', mode='exec', *,
Guido van Rossum10b55c12019-06-11 17:23:12 -070032 type_comments=False, feature_version=None):
Georg Brandl0c77a822008-06-10 16:37:50 +000033 """
Terry Reedyfeac6242011-01-24 21:36:03 +000034 Parse the source into an AST node.
35 Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
Guido van Rossumdcfcd142019-01-31 03:40:27 -080036 Pass type_comments=True to get back type comments where the syntax allows.
Georg Brandl0c77a822008-06-10 16:37:50 +000037 """
Guido van Rossumdcfcd142019-01-31 03:40:27 -080038 flags = PyCF_ONLY_AST
39 if type_comments:
40 flags |= PyCF_TYPE_COMMENTS
Guido van Rossum10b55c12019-06-11 17:23:12 -070041 if isinstance(feature_version, tuple):
42 major, minor = feature_version # Should be a 2-tuple.
43 assert major == 3
44 feature_version = minor
45 elif feature_version is None:
46 feature_version = -1
47 # Else it should be an int giving the minor version for 3.x.
Guido van Rossum495da292019-03-07 12:38:08 -080048 return compile(source, filename, mode, flags,
Victor Stinnerefdf6ca2019-06-12 02:52:16 +020049 _feature_version=feature_version)
Georg Brandl0c77a822008-06-10 16:37:50 +000050
51
52def literal_eval(node_or_string):
53 """
54 Safely evaluate an expression node or a string containing a Python
55 expression. The string or node provided may only consist of the following
Éric Araujo2a83cc62011-04-17 19:10:27 +020056 Python literal structures: strings, bytes, numbers, tuples, lists, dicts,
57 sets, booleans, and None.
Georg Brandl0c77a822008-06-10 16:37:50 +000058 """
Georg Brandl0c77a822008-06-10 16:37:50 +000059 if isinstance(node_or_string, str):
60 node_or_string = parse(node_or_string, mode='eval')
61 if isinstance(node_or_string, Expression):
62 node_or_string = node_or_string.body
Serhiy Storchakad8ac4d12018-01-04 11:15:39 +020063 def _convert_num(node):
64 if isinstance(node, Constant):
Serhiy Storchaka3f228112018-09-27 17:42:37 +030065 if type(node.value) in (int, float, complex):
Serhiy Storchakad8ac4d12018-01-04 11:15:39 +020066 return node.value
Serhiy Storchakad8ac4d12018-01-04 11:15:39 +020067 raise ValueError('malformed node or string: ' + repr(node))
68 def _convert_signed_num(node):
69 if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
70 operand = _convert_num(node.operand)
71 if isinstance(node.op, UAdd):
72 return + operand
73 else:
74 return - operand
75 return _convert_num(node)
Georg Brandl0c77a822008-06-10 16:37:50 +000076 def _convert(node):
Victor Stinnerf2c1aa12016-01-26 00:40:57 +010077 if isinstance(node, Constant):
78 return node.value
Georg Brandl0c77a822008-06-10 16:37:50 +000079 elif isinstance(node, Tuple):
80 return tuple(map(_convert, node.elts))
81 elif isinstance(node, List):
82 return list(map(_convert, node.elts))
Georg Brandl492f3fc2010-07-11 09:41:21 +000083 elif isinstance(node, Set):
84 return set(map(_convert, node.elts))
Georg Brandl0c77a822008-06-10 16:37:50 +000085 elif isinstance(node, Dict):
Serhiy Storchakad8ac4d12018-01-04 11:15:39 +020086 return dict(zip(map(_convert, node.keys),
87 map(_convert, node.values)))
Victor Stinnerf2c1aa12016-01-26 00:40:57 +010088 elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
Serhiy Storchakad8ac4d12018-01-04 11:15:39 +020089 left = _convert_signed_num(node.left)
90 right = _convert_num(node.right)
91 if isinstance(left, (int, float)) and isinstance(right, complex):
Victor Stinnerf2c1aa12016-01-26 00:40:57 +010092 if isinstance(node.op, Add):
93 return left + right
94 else:
95 return left - right
Serhiy Storchakad8ac4d12018-01-04 11:15:39 +020096 return _convert_signed_num(node)
Georg Brandl0c77a822008-06-10 16:37:50 +000097 return _convert(node_or_string)
98
99
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300100def dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
Georg Brandl0c77a822008-06-10 16:37:50 +0000101 """
Serhiy Storchakae64f9482019-08-29 09:30:23 +0300102 Return a formatted dump of the tree in node. This is mainly useful for
103 debugging purposes. If annotate_fields is true (by default),
104 the returned string will show the names and the values for fields.
105 If annotate_fields is false, the result string will be more compact by
106 omitting unambiguous field names. Attributes such as line
Benjamin Petersondcf97b92008-07-02 17:30:14 +0000107 numbers and column offsets are not dumped by default. If this is wanted,
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300108 include_attributes can be set to true. If indent is a non-negative
109 integer or string, then the tree will be pretty-printed with that indent
110 level. None (the default) selects the single line representation.
Georg Brandl0c77a822008-06-10 16:37:50 +0000111 """
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300112 def _format(node, level=0):
113 if indent is not None:
114 level += 1
115 prefix = '\n' + indent * level
116 sep = ',\n' + indent * level
117 else:
118 prefix = ''
119 sep = ', '
Georg Brandl0c77a822008-06-10 16:37:50 +0000120 if isinstance(node, AST):
Serhiy Storchakae64f9482019-08-29 09:30:23 +0300121 args = []
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300122 allsimple = True
Serhiy Storchakae64f9482019-08-29 09:30:23 +0300123 keywords = annotate_fields
124 for field in node._fields:
125 try:
126 value = getattr(node, field)
127 except AttributeError:
128 keywords = True
129 else:
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300130 value, simple = _format(value, level)
131 allsimple = allsimple and simple
Serhiy Storchakae64f9482019-08-29 09:30:23 +0300132 if keywords:
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300133 args.append('%s=%s' % (field, value))
Serhiy Storchakae64f9482019-08-29 09:30:23 +0300134 else:
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300135 args.append(value)
Georg Brandl0c77a822008-06-10 16:37:50 +0000136 if include_attributes and node._attributes:
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300137 for attr in node._attributes:
Serhiy Storchakae64f9482019-08-29 09:30:23 +0300138 try:
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300139 value = getattr(node, attr)
Serhiy Storchakae64f9482019-08-29 09:30:23 +0300140 except AttributeError:
141 pass
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300142 else:
143 value, simple = _format(value, level)
144 allsimple = allsimple and simple
145 args.append('%s=%s' % (attr, value))
146 if allsimple and len(args) <= 3:
147 return '%s(%s)' % (node.__class__.__name__, ', '.join(args)), not args
148 return '%s(%s%s)' % (node.__class__.__name__, prefix, sep.join(args)), False
Georg Brandl0c77a822008-06-10 16:37:50 +0000149 elif isinstance(node, list):
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300150 if not node:
151 return '[]', True
152 return '[%s%s]' % (prefix, sep.join(_format(x, level)[0] for x in node)), False
153 return repr(node), True
154
Georg Brandl0c77a822008-06-10 16:37:50 +0000155 if not isinstance(node, AST):
156 raise TypeError('expected AST, got %r' % node.__class__.__name__)
Serhiy Storchaka850573b2019-09-09 19:33:13 +0300157 if indent is not None and not isinstance(indent, str):
158 indent = ' ' * indent
159 return _format(node)[0]
Georg Brandl0c77a822008-06-10 16:37:50 +0000160
161
162def copy_location(new_node, old_node):
163 """
Ivan Levkivskyi9932a222019-01-22 11:18:22 +0000164 Copy source location (`lineno`, `col_offset`, `end_lineno`, and `end_col_offset`
165 attributes) from *old_node* to *new_node* if possible, and return *new_node*.
Georg Brandl0c77a822008-06-10 16:37:50 +0000166 """
Ivan Levkivskyi9932a222019-01-22 11:18:22 +0000167 for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset':
Georg Brandl0c77a822008-06-10 16:37:50 +0000168 if attr in old_node._attributes and attr in new_node._attributes \
169 and hasattr(old_node, attr):
170 setattr(new_node, attr, getattr(old_node, attr))
171 return new_node
172
173
174def fix_missing_locations(node):
175 """
176 When you compile a node tree with compile(), the compiler expects lineno and
177 col_offset attributes for every node that supports them. This is rather
178 tedious to fill in for generated nodes, so this helper adds these attributes
179 recursively where not already set, by setting them to the values of the
180 parent node. It works recursively starting at *node*.
181 """
Ivan Levkivskyi9932a222019-01-22 11:18:22 +0000182 def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
Georg Brandl0c77a822008-06-10 16:37:50 +0000183 if 'lineno' in node._attributes:
184 if not hasattr(node, 'lineno'):
185 node.lineno = lineno
186 else:
187 lineno = node.lineno
Ivan Levkivskyi9932a222019-01-22 11:18:22 +0000188 if 'end_lineno' in node._attributes:
189 if not hasattr(node, 'end_lineno'):
190 node.end_lineno = end_lineno
191 else:
192 end_lineno = node.end_lineno
Georg Brandl0c77a822008-06-10 16:37:50 +0000193 if 'col_offset' in node._attributes:
194 if not hasattr(node, 'col_offset'):
195 node.col_offset = col_offset
196 else:
197 col_offset = node.col_offset
Ivan Levkivskyi9932a222019-01-22 11:18:22 +0000198 if 'end_col_offset' in node._attributes:
199 if not hasattr(node, 'end_col_offset'):
200 node.end_col_offset = end_col_offset
201 else:
202 end_col_offset = node.end_col_offset
Georg Brandl0c77a822008-06-10 16:37:50 +0000203 for child in iter_child_nodes(node):
Ivan Levkivskyi9932a222019-01-22 11:18:22 +0000204 _fix(child, lineno, col_offset, end_lineno, end_col_offset)
205 _fix(node, 1, 0, 1, 0)
Georg Brandl0c77a822008-06-10 16:37:50 +0000206 return node
207
208
209def increment_lineno(node, n=1):
210 """
Ivan Levkivskyi9932a222019-01-22 11:18:22 +0000211 Increment the line number and end line number of each node in the tree
212 starting at *node* by *n*. This is useful to "move code" to a different
213 location in a file.
Georg Brandl0c77a822008-06-10 16:37:50 +0000214 """
Georg Brandl0c77a822008-06-10 16:37:50 +0000215 for child in walk(node):
216 if 'lineno' in child._attributes:
217 child.lineno = getattr(child, 'lineno', 0) + n
Ivan Levkivskyi9932a222019-01-22 11:18:22 +0000218 if 'end_lineno' in child._attributes:
219 child.end_lineno = getattr(child, 'end_lineno', 0) + n
Georg Brandl0c77a822008-06-10 16:37:50 +0000220 return node
221
222
223def iter_fields(node):
224 """
225 Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
226 that is present on *node*.
227 """
228 for field in node._fields:
229 try:
230 yield field, getattr(node, field)
231 except AttributeError:
232 pass
233
234
235def iter_child_nodes(node):
236 """
237 Yield all direct child nodes of *node*, that is, all fields that are nodes
238 and all items of fields that are lists of nodes.
239 """
240 for name, field in iter_fields(node):
241 if isinstance(field, AST):
242 yield field
243 elif isinstance(field, list):
244 for item in field:
245 if isinstance(item, AST):
246 yield item
247
248
249def get_docstring(node, clean=True):
250 """
251 Return the docstring for the given node or None if no docstring can
252 be found. If the node provided does not have docstrings a TypeError
253 will be raised.
Matthias Bussonnier41cea702017-02-23 22:44:19 -0800254
255 If *clean* is `True`, all tabs are expanded to spaces and any whitespace
256 that can be uniformly removed from the second line onwards is removed.
Georg Brandl0c77a822008-06-10 16:37:50 +0000257 """
Yury Selivanov2f07a662015-07-23 08:54:35 +0300258 if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
Georg Brandl0c77a822008-06-10 16:37:50 +0000259 raise TypeError("%r can't have docstrings" % node.__class__.__name__)
Serhiy Storchaka08f127a2018-06-15 11:05:15 +0300260 if not(node.body and isinstance(node.body[0], Expr)):
Serhiy Storchaka73cbe7a2018-05-29 12:04:55 +0300261 return None
262 node = node.body[0].value
263 if isinstance(node, Str):
264 text = node.s
265 elif isinstance(node, Constant) and isinstance(node.value, str):
266 text = node.value
267 else:
268 return None
Serhiy Storchaka08f127a2018-06-15 11:05:15 +0300269 if clean:
Victor Stinnerf2c1aa12016-01-26 00:40:57 +0100270 import inspect
271 text = inspect.cleandoc(text)
272 return text
Georg Brandl0c77a822008-06-10 16:37:50 +0000273
274
Ivan Levkivskyi9932a222019-01-22 11:18:22 +0000275def _splitlines_no_ff(source):
276 """Split a string into lines ignoring form feed and other chars.
277
278 This mimics how the Python parser splits source code.
279 """
280 idx = 0
281 lines = []
282 next_line = ''
283 while idx < len(source):
284 c = source[idx]
285 next_line += c
286 idx += 1
287 # Keep \r\n together
288 if c == '\r' and idx < len(source) and source[idx] == '\n':
289 next_line += '\n'
290 idx += 1
291 if c in '\r\n':
292 lines.append(next_line)
293 next_line = ''
294
295 if next_line:
296 lines.append(next_line)
297 return lines
298
299
300def _pad_whitespace(source):
301 """Replace all chars except '\f\t' in a line with spaces."""
302 result = ''
303 for c in source:
304 if c in '\f\t':
305 result += c
306 else:
307 result += ' '
308 return result
309
310
311def get_source_segment(source, node, *, padded=False):
312 """Get source code segment of the *source* that generated *node*.
313
314 If some location information (`lineno`, `end_lineno`, `col_offset`,
315 or `end_col_offset`) is missing, return None.
316
317 If *padded* is `True`, the first line of a multi-line statement will
318 be padded with spaces to match its original position.
319 """
320 try:
321 lineno = node.lineno - 1
322 end_lineno = node.end_lineno - 1
323 col_offset = node.col_offset
324 end_col_offset = node.end_col_offset
325 except AttributeError:
326 return None
327
328 lines = _splitlines_no_ff(source)
329 if end_lineno == lineno:
330 return lines[lineno].encode()[col_offset:end_col_offset].decode()
331
332 if padded:
333 padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode())
334 else:
335 padding = ''
336
337 first = padding + lines[lineno].encode()[col_offset:].decode()
338 last = lines[end_lineno].encode()[:end_col_offset].decode()
339 lines = lines[lineno+1:end_lineno]
340
341 lines.insert(0, first)
342 lines.append(last)
343 return ''.join(lines)
344
345
Georg Brandl0c77a822008-06-10 16:37:50 +0000346def walk(node):
347 """
Georg Brandl619e7ba2011-01-09 07:38:51 +0000348 Recursively yield all descendant nodes in the tree starting at *node*
349 (including *node* itself), in no specified order. This is useful if you
350 only want to modify nodes in place and don't care about the context.
Georg Brandl0c77a822008-06-10 16:37:50 +0000351 """
352 from collections import deque
353 todo = deque([node])
354 while todo:
355 node = todo.popleft()
356 todo.extend(iter_child_nodes(node))
357 yield node
358
359
360class NodeVisitor(object):
361 """
362 A node visitor base class that walks the abstract syntax tree and calls a
363 visitor function for every node found. This function may return a value
364 which is forwarded by the `visit` method.
365
366 This class is meant to be subclassed, with the subclass adding visitor
367 methods.
368
369 Per default the visitor functions for the nodes are ``'visit_'`` +
370 class name of the node. So a `TryFinally` node visit function would
371 be `visit_TryFinally`. This behavior can be changed by overriding
372 the `visit` method. If no visitor function exists for a node
373 (return value `None`) the `generic_visit` visitor is used instead.
374
375 Don't use the `NodeVisitor` if you want to apply changes to nodes during
376 traversing. For this a special visitor exists (`NodeTransformer`) that
377 allows modifications.
378 """
379
380 def visit(self, node):
381 """Visit a node."""
382 method = 'visit_' + node.__class__.__name__
383 visitor = getattr(self, method, self.generic_visit)
384 return visitor(node)
385
386 def generic_visit(self, node):
387 """Called if no explicit visitor function exists for a node."""
388 for field, value in iter_fields(node):
389 if isinstance(value, list):
390 for item in value:
391 if isinstance(item, AST):
392 self.visit(item)
393 elif isinstance(value, AST):
394 self.visit(value)
395
Serhiy Storchakac3ea41e2019-08-26 10:13:19 +0300396 def visit_Constant(self, node):
397 value = node.value
398 type_name = _const_node_type_names.get(type(value))
399 if type_name is None:
400 for cls, name in _const_node_type_names.items():
401 if isinstance(value, cls):
402 type_name = name
403 break
404 if type_name is not None:
405 method = 'visit_' + type_name
406 try:
407 visitor = getattr(self, method)
408 except AttributeError:
409 pass
410 else:
411 import warnings
412 warnings.warn(f"{method} is deprecated; add visit_Constant",
413 DeprecationWarning, 2)
414 return visitor(node)
415 return self.generic_visit(node)
416
Georg Brandl0c77a822008-06-10 16:37:50 +0000417
418class NodeTransformer(NodeVisitor):
419 """
420 A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
421 allows modification of nodes.
422
423 The `NodeTransformer` will walk the AST and use the return value of the
424 visitor methods to replace or remove the old node. If the return value of
425 the visitor method is ``None``, the node will be removed from its location,
426 otherwise it is replaced with the return value. The return value may be the
427 original node in which case no replacement takes place.
428
429 Here is an example transformer that rewrites all occurrences of name lookups
430 (``foo``) to ``data['foo']``::
431
432 class RewriteName(NodeTransformer):
433
434 def visit_Name(self, node):
435 return copy_location(Subscript(
436 value=Name(id='data', ctx=Load()),
437 slice=Index(value=Str(s=node.id)),
438 ctx=node.ctx
439 ), node)
440
441 Keep in mind that if the node you're operating on has child nodes you must
442 either transform the child nodes yourself or call the :meth:`generic_visit`
443 method for the node first.
444
445 For nodes that were part of a collection of statements (that applies to all
446 statement nodes), the visitor may also return a list of nodes rather than
447 just a single node.
448
449 Usually you use the transformer like this::
450
451 node = YourTransformer().visit(node)
452 """
453
454 def generic_visit(self, node):
455 for field, old_value in iter_fields(node):
Georg Brandl0c77a822008-06-10 16:37:50 +0000456 if isinstance(old_value, list):
457 new_values = []
458 for value in old_value:
459 if isinstance(value, AST):
460 value = self.visit(value)
461 if value is None:
462 continue
463 elif not isinstance(value, AST):
464 new_values.extend(value)
465 continue
466 new_values.append(value)
467 old_value[:] = new_values
468 elif isinstance(old_value, AST):
469 new_node = self.visit(old_value)
470 if new_node is None:
471 delattr(node, field)
472 else:
473 setattr(node, field, new_node)
474 return node
Serhiy Storchaka3f228112018-09-27 17:42:37 +0300475
476
477# The following code is for backward compatibility.
478# It will be removed in future.
479
480def _getter(self):
481 return self.value
482
483def _setter(self, value):
484 self.value = value
485
486Constant.n = property(_getter, _setter)
487Constant.s = property(_getter, _setter)
488
489class _ABC(type):
490
491 def __instancecheck__(cls, inst):
492 if not isinstance(inst, Constant):
493 return False
494 if cls in _const_types:
495 try:
496 value = inst.value
497 except AttributeError:
498 return False
499 else:
Anthony Sottile74176222019-01-18 11:30:28 -0800500 return (
501 isinstance(value, _const_types[cls]) and
502 not isinstance(value, _const_types_not.get(cls, ()))
503 )
Serhiy Storchaka3f228112018-09-27 17:42:37 +0300504 return type.__instancecheck__(cls, inst)
505
506def _new(cls, *args, **kwargs):
507 if cls in _const_types:
508 return Constant(*args, **kwargs)
509 return Constant.__new__(cls, *args, **kwargs)
510
511class Num(Constant, metaclass=_ABC):
512 _fields = ('n',)
513 __new__ = _new
514
515class Str(Constant, metaclass=_ABC):
516 _fields = ('s',)
517 __new__ = _new
518
519class Bytes(Constant, metaclass=_ABC):
520 _fields = ('s',)
521 __new__ = _new
522
523class NameConstant(Constant, metaclass=_ABC):
524 __new__ = _new
525
526class Ellipsis(Constant, metaclass=_ABC):
527 _fields = ()
528
529 def __new__(cls, *args, **kwargs):
530 if cls is Ellipsis:
531 return Constant(..., *args, **kwargs)
532 return Constant.__new__(cls, *args, **kwargs)
533
534_const_types = {
535 Num: (int, float, complex),
536 Str: (str,),
537 Bytes: (bytes,),
538 NameConstant: (type(None), bool),
539 Ellipsis: (type(...),),
540}
Anthony Sottile74176222019-01-18 11:30:28 -0800541_const_types_not = {
542 Num: (bool,),
543}
Serhiy Storchakac3ea41e2019-08-26 10:13:19 +0300544_const_node_type_names = {
545 bool: 'NameConstant', # should be before int
546 type(None): 'NameConstant',
547 int: 'Num',
548 float: 'Num',
549 complex: 'Num',
550 str: 'Str',
551 bytes: 'Bytes',
552 type(...): 'Ellipsis',
553}
Serhiy Storchaka832e8642019-09-09 23:36:13 +0300554
Pablo Galindo27fc3b62019-11-24 23:02:40 +0000555# Large float and imaginary literals get turned into infinities in the AST.
556# We unparse those infinities to INFSTR.
557_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
558
559class _Unparser(NodeVisitor):
560 """Methods in this class recursively traverse an AST and
561 output source code for the abstract syntax; original formatting
562 is disregarded."""
563
564 def __init__(self):
565 self._source = []
566 self._buffer = []
567 self._indent = 0
568
569 def interleave(self, inter, f, seq):
570 """Call f on each item in seq, calling inter() in between."""
571 seq = iter(seq)
572 try:
573 f(next(seq))
574 except StopIteration:
575 pass
576 else:
577 for x in seq:
578 inter()
579 f(x)
580
581 def fill(self, text=""):
582 """Indent a piece of text and append it, according to the current
583 indentation level"""
584 self.write("\n" + " " * self._indent + text)
585
586 def write(self, text):
587 """Append a piece of text"""
588 self._source.append(text)
589
590 def buffer_writer(self, text):
591 self._buffer.append(text)
592
593 @property
594 def buffer(self):
595 value = "".join(self._buffer)
596 self._buffer.clear()
597 return value
598
Pablo Galindoded88882019-11-25 11:49:17 +0000599 class _Block:
Pablo Galindo27fc3b62019-11-24 23:02:40 +0000600 """A context manager for preparing the source for blocks. It adds
601 the character':', increases the indentation on enter and decreases
602 the indentation on exit."""
Pablo Galindoded88882019-11-25 11:49:17 +0000603 def __init__(self, unparser):
604 self.unparser = unparser
605
606 def __enter__(self):
607 self.unparser.write(":")
608 self.unparser._indent += 1
609
610 def __exit__(self, exc_type, exc_value, traceback):
611 self.unparser._indent -= 1
612
613 def block(self):
614 return self._Block(self)
Pablo Galindo27fc3b62019-11-24 23:02:40 +0000615
616 def traverse(self, node):
617 if isinstance(node, list):
618 for item in node:
619 self.traverse(item)
620 else:
621 super().visit(node)
622
623 def visit(self, node):
624 """Outputs a source code string that, if converted back to an ast
625 (using ast.parse) will generate an AST equivalent to *node*"""
626 self._source = []
627 self.traverse(node)
628 return "".join(self._source)
629
630 def visit_Module(self, node):
631 for subnode in node.body:
632 self.traverse(subnode)
633
634 def visit_Expr(self, node):
635 self.fill()
636 self.traverse(node.value)
637
638 def visit_NamedExpr(self, node):
639 self.write("(")
640 self.traverse(node.target)
641 self.write(" := ")
642 self.traverse(node.value)
643 self.write(")")
644
645 def visit_Import(self, node):
646 self.fill("import ")
647 self.interleave(lambda: self.write(", "), self.traverse, node.names)
648
649 def visit_ImportFrom(self, node):
650 self.fill("from ")
651 self.write("." * node.level)
652 if node.module:
653 self.write(node.module)
654 self.write(" import ")
655 self.interleave(lambda: self.write(", "), self.traverse, node.names)
656
657 def visit_Assign(self, node):
658 self.fill()
659 for target in node.targets:
660 self.traverse(target)
661 self.write(" = ")
662 self.traverse(node.value)
663
664 def visit_AugAssign(self, node):
665 self.fill()
666 self.traverse(node.target)
667 self.write(" " + self.binop[node.op.__class__.__name__] + "= ")
668 self.traverse(node.value)
669
670 def visit_AnnAssign(self, node):
671 self.fill()
672 if not node.simple and isinstance(node.target, Name):
673 self.write("(")
674 self.traverse(node.target)
675 if not node.simple and isinstance(node.target, Name):
676 self.write(")")
677 self.write(": ")
678 self.traverse(node.annotation)
679 if node.value:
680 self.write(" = ")
681 self.traverse(node.value)
682
683 def visit_Return(self, node):
684 self.fill("return")
685 if node.value:
686 self.write(" ")
687 self.traverse(node.value)
688
689 def visit_Pass(self, node):
690 self.fill("pass")
691
692 def visit_Break(self, node):
693 self.fill("break")
694
695 def visit_Continue(self, node):
696 self.fill("continue")
697
698 def visit_Delete(self, node):
699 self.fill("del ")
700 self.interleave(lambda: self.write(", "), self.traverse, node.targets)
701
702 def visit_Assert(self, node):
703 self.fill("assert ")
704 self.traverse(node.test)
705 if node.msg:
706 self.write(", ")
707 self.traverse(node.msg)
708
709 def visit_Global(self, node):
710 self.fill("global ")
711 self.interleave(lambda: self.write(", "), self.write, node.names)
712
713 def visit_Nonlocal(self, node):
714 self.fill("nonlocal ")
715 self.interleave(lambda: self.write(", "), self.write, node.names)
716
717 def visit_Await(self, node):
718 self.write("(")
719 self.write("await")
720 if node.value:
721 self.write(" ")
722 self.traverse(node.value)
723 self.write(")")
724
725 def visit_Yield(self, node):
726 self.write("(")
727 self.write("yield")
728 if node.value:
729 self.write(" ")
730 self.traverse(node.value)
731 self.write(")")
732
733 def visit_YieldFrom(self, node):
734 self.write("(")
735 self.write("yield from")
736 if node.value:
737 self.write(" ")
738 self.traverse(node.value)
739 self.write(")")
740
741 def visit_Raise(self, node):
742 self.fill("raise")
743 if not node.exc:
744 if node.cause:
745 raise ValueError(f"Node can't use cause without an exception.")
746 return
747 self.write(" ")
748 self.traverse(node.exc)
749 if node.cause:
750 self.write(" from ")
751 self.traverse(node.cause)
752
753 def visit_Try(self, node):
754 self.fill("try")
755 with self.block():
756 self.traverse(node.body)
757 for ex in node.handlers:
758 self.traverse(ex)
759 if node.orelse:
760 self.fill("else")
761 with self.block():
762 self.traverse(node.orelse)
763 if node.finalbody:
764 self.fill("finally")
765 with self.block():
766 self.traverse(node.finalbody)
767
768 def visit_ExceptHandler(self, node):
769 self.fill("except")
770 if node.type:
771 self.write(" ")
772 self.traverse(node.type)
773 if node.name:
774 self.write(" as ")
775 self.write(node.name)
776 with self.block():
777 self.traverse(node.body)
778
779 def visit_ClassDef(self, node):
780 self.write("\n")
781 for deco in node.decorator_list:
782 self.fill("@")
783 self.traverse(deco)
784 self.fill("class " + node.name)
785 self.write("(")
786 comma = False
787 for e in node.bases:
788 if comma:
789 self.write(", ")
790 else:
791 comma = True
792 self.traverse(e)
793 for e in node.keywords:
794 if comma:
795 self.write(", ")
796 else:
797 comma = True
798 self.traverse(e)
799 self.write(")")
800
801 with self.block():
802 self.traverse(node.body)
803
804 def visit_FunctionDef(self, node):
805 self.__FunctionDef_helper(node, "def")
806
807 def visit_AsyncFunctionDef(self, node):
808 self.__FunctionDef_helper(node, "async def")
809
810 def __FunctionDef_helper(self, node, fill_suffix):
811 self.write("\n")
812 for deco in node.decorator_list:
813 self.fill("@")
814 self.traverse(deco)
815 def_str = fill_suffix + " " + node.name + "("
816 self.fill(def_str)
817 self.traverse(node.args)
818 self.write(")")
819 if node.returns:
820 self.write(" -> ")
821 self.traverse(node.returns)
822 with self.block():
823 self.traverse(node.body)
824
825 def visit_For(self, node):
826 self.__For_helper("for ", node)
827
828 def visit_AsyncFor(self, node):
829 self.__For_helper("async for ", node)
830
831 def __For_helper(self, fill, node):
832 self.fill(fill)
833 self.traverse(node.target)
834 self.write(" in ")
835 self.traverse(node.iter)
836 with self.block():
837 self.traverse(node.body)
838 if node.orelse:
839 self.fill("else")
840 with self.block():
841 self.traverse(node.orelse)
842
843 def visit_If(self, node):
844 self.fill("if ")
845 self.traverse(node.test)
846 with self.block():
847 self.traverse(node.body)
848 # collapse nested ifs into equivalent elifs.
849 while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
850 node = node.orelse[0]
851 self.fill("elif ")
852 self.traverse(node.test)
853 with self.block():
854 self.traverse(node.body)
855 # final else
856 if node.orelse:
857 self.fill("else")
858 with self.block():
859 self.traverse(node.orelse)
860
861 def visit_While(self, node):
862 self.fill("while ")
863 self.traverse(node.test)
864 with self.block():
865 self.traverse(node.body)
866 if node.orelse:
867 self.fill("else")
868 with self.block():
869 self.traverse(node.orelse)
870
871 def visit_With(self, node):
872 self.fill("with ")
873 self.interleave(lambda: self.write(", "), self.traverse, node.items)
874 with self.block():
875 self.traverse(node.body)
876
877 def visit_AsyncWith(self, node):
878 self.fill("async with ")
879 self.interleave(lambda: self.write(", "), self.traverse, node.items)
880 with self.block():
881 self.traverse(node.body)
882
883 def visit_JoinedStr(self, node):
884 self.write("f")
885 self._fstring_JoinedStr(node, self.buffer_writer)
886 self.write(repr(self.buffer))
887
888 def visit_FormattedValue(self, node):
889 self.write("f")
890 self._fstring_FormattedValue(node, self.buffer_writer)
891 self.write(repr(self.buffer))
892
893 def _fstring_JoinedStr(self, node, write):
894 for value in node.values:
895 meth = getattr(self, "_fstring_" + type(value).__name__)
896 meth(value, write)
897
898 def _fstring_Constant(self, node, write):
899 if not isinstance(node.value, str):
900 raise ValueError("Constants inside JoinedStr should be a string.")
901 value = node.value.replace("{", "{{").replace("}", "}}")
902 write(value)
903
904 def _fstring_FormattedValue(self, node, write):
905 write("{")
906 expr = type(self)().visit(node.value).rstrip("\n")
907 if expr.startswith("{"):
908 write(" ") # Separate pair of opening brackets as "{ {"
909 write(expr)
910 if node.conversion != -1:
911 conversion = chr(node.conversion)
912 if conversion not in "sra":
913 raise ValueError("Unknown f-string conversion.")
914 write(f"!{conversion}")
915 if node.format_spec:
916 write(":")
917 meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
918 meth(node.format_spec, write)
919 write("}")
920
921 def visit_Name(self, node):
922 self.write(node.id)
923
924 def _write_constant(self, value):
925 if isinstance(value, (float, complex)):
926 # Substitute overflowing decimal literal for AST infinities.
927 self.write(repr(value).replace("inf", _INFSTR))
928 else:
929 self.write(repr(value))
930
931 def visit_Constant(self, node):
932 value = node.value
933 if isinstance(value, tuple):
934 self.write("(")
935 if len(value) == 1:
936 self._write_constant(value[0])
937 self.write(",")
938 else:
939 self.interleave(lambda: self.write(", "), self._write_constant, value)
940 self.write(")")
941 elif value is ...:
942 self.write("...")
943 else:
944 if node.kind == "u":
945 self.write("u")
946 self._write_constant(node.value)
947
948 def visit_List(self, node):
949 self.write("[")
950 self.interleave(lambda: self.write(", "), self.traverse, node.elts)
951 self.write("]")
952
953 def visit_ListComp(self, node):
954 self.write("[")
955 self.traverse(node.elt)
956 for gen in node.generators:
957 self.traverse(gen)
958 self.write("]")
959
960 def visit_GeneratorExp(self, node):
961 self.write("(")
962 self.traverse(node.elt)
963 for gen in node.generators:
964 self.traverse(gen)
965 self.write(")")
966
967 def visit_SetComp(self, node):
968 self.write("{")
969 self.traverse(node.elt)
970 for gen in node.generators:
971 self.traverse(gen)
972 self.write("}")
973
974 def visit_DictComp(self, node):
975 self.write("{")
976 self.traverse(node.key)
977 self.write(": ")
978 self.traverse(node.value)
979 for gen in node.generators:
980 self.traverse(gen)
981 self.write("}")
982
983 def visit_comprehension(self, node):
984 if node.is_async:
985 self.write(" async for ")
986 else:
987 self.write(" for ")
988 self.traverse(node.target)
989 self.write(" in ")
990 self.traverse(node.iter)
991 for if_clause in node.ifs:
992 self.write(" if ")
993 self.traverse(if_clause)
994
995 def visit_IfExp(self, node):
996 self.write("(")
997 self.traverse(node.body)
998 self.write(" if ")
999 self.traverse(node.test)
1000 self.write(" else ")
1001 self.traverse(node.orelse)
1002 self.write(")")
1003
1004 def visit_Set(self, node):
1005 if not node.elts:
1006 raise ValueError("Set node should has at least one item")
1007 self.write("{")
1008 self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1009 self.write("}")
1010
1011 def visit_Dict(self, node):
1012 self.write("{")
1013
1014 def write_key_value_pair(k, v):
1015 self.traverse(k)
1016 self.write(": ")
1017 self.traverse(v)
1018
1019 def write_item(item):
1020 k, v = item
1021 if k is None:
1022 # for dictionary unpacking operator in dicts {**{'y': 2}}
1023 # see PEP 448 for details
1024 self.write("**")
1025 self.traverse(v)
1026 else:
1027 write_key_value_pair(k, v)
1028
1029 self.interleave(
1030 lambda: self.write(", "), write_item, zip(node.keys, node.values)
1031 )
1032 self.write("}")
1033
1034 def visit_Tuple(self, node):
1035 self.write("(")
1036 if len(node.elts) == 1:
1037 elt = node.elts[0]
1038 self.traverse(elt)
1039 self.write(",")
1040 else:
1041 self.interleave(lambda: self.write(", "), self.traverse, node.elts)
1042 self.write(")")
1043
1044 unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
1045
1046 def visit_UnaryOp(self, node):
1047 self.write("(")
1048 self.write(self.unop[node.op.__class__.__name__])
1049 self.write(" ")
1050 self.traverse(node.operand)
1051 self.write(")")
1052
1053 binop = {
1054 "Add": "+",
1055 "Sub": "-",
1056 "Mult": "*",
1057 "MatMult": "@",
1058 "Div": "/",
1059 "Mod": "%",
1060 "LShift": "<<",
1061 "RShift": ">>",
1062 "BitOr": "|",
1063 "BitXor": "^",
1064 "BitAnd": "&",
1065 "FloorDiv": "//",
1066 "Pow": "**",
1067 }
1068
1069 def visit_BinOp(self, node):
1070 self.write("(")
1071 self.traverse(node.left)
1072 self.write(" " + self.binop[node.op.__class__.__name__] + " ")
1073 self.traverse(node.right)
1074 self.write(")")
1075
1076 cmpops = {
1077 "Eq": "==",
1078 "NotEq": "!=",
1079 "Lt": "<",
1080 "LtE": "<=",
1081 "Gt": ">",
1082 "GtE": ">=",
1083 "Is": "is",
1084 "IsNot": "is not",
1085 "In": "in",
1086 "NotIn": "not in",
1087 }
1088
1089 def visit_Compare(self, node):
1090 self.write("(")
1091 self.traverse(node.left)
1092 for o, e in zip(node.ops, node.comparators):
1093 self.write(" " + self.cmpops[o.__class__.__name__] + " ")
1094 self.traverse(e)
1095 self.write(")")
1096
1097 boolops = {And: "and", Or: "or"}
1098
1099 def visit_BoolOp(self, node):
1100 self.write("(")
1101 s = " %s " % self.boolops[node.op.__class__]
1102 self.interleave(lambda: self.write(s), self.traverse, node.values)
1103 self.write(")")
1104
1105 def visit_Attribute(self, node):
1106 self.traverse(node.value)
1107 # Special case: 3.__abs__() is a syntax error, so if node.value
1108 # is an integer literal then we need to either parenthesize
1109 # it or add an extra space to get 3 .__abs__().
1110 if isinstance(node.value, Constant) and isinstance(node.value.value, int):
1111 self.write(" ")
1112 self.write(".")
1113 self.write(node.attr)
1114
1115 def visit_Call(self, node):
1116 self.traverse(node.func)
1117 self.write("(")
1118 comma = False
1119 for e in node.args:
1120 if comma:
1121 self.write(", ")
1122 else:
1123 comma = True
1124 self.traverse(e)
1125 for e in node.keywords:
1126 if comma:
1127 self.write(", ")
1128 else:
1129 comma = True
1130 self.traverse(e)
1131 self.write(")")
1132
1133 def visit_Subscript(self, node):
1134 self.traverse(node.value)
1135 self.write("[")
1136 self.traverse(node.slice)
1137 self.write("]")
1138
1139 def visit_Starred(self, node):
1140 self.write("*")
1141 self.traverse(node.value)
1142
1143 def visit_Ellipsis(self, node):
1144 self.write("...")
1145
1146 def visit_Index(self, node):
1147 self.traverse(node.value)
1148
1149 def visit_Slice(self, node):
1150 if node.lower:
1151 self.traverse(node.lower)
1152 self.write(":")
1153 if node.upper:
1154 self.traverse(node.upper)
1155 if node.step:
1156 self.write(":")
1157 self.traverse(node.step)
1158
1159 def visit_ExtSlice(self, node):
1160 self.interleave(lambda: self.write(", "), self.traverse, node.dims)
1161
1162 def visit_arg(self, node):
1163 self.write(node.arg)
1164 if node.annotation:
1165 self.write(": ")
1166 self.traverse(node.annotation)
1167
1168 def visit_arguments(self, node):
1169 first = True
1170 # normal arguments
1171 all_args = node.posonlyargs + node.args
1172 defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults
1173 for index, elements in enumerate(zip(all_args, defaults), 1):
1174 a, d = elements
1175 if first:
1176 first = False
1177 else:
1178 self.write(", ")
1179 self.traverse(a)
1180 if d:
1181 self.write("=")
1182 self.traverse(d)
1183 if index == len(node.posonlyargs):
1184 self.write(", /")
1185
1186 # varargs, or bare '*' if no varargs but keyword-only arguments present
1187 if node.vararg or node.kwonlyargs:
1188 if first:
1189 first = False
1190 else:
1191 self.write(", ")
1192 self.write("*")
1193 if node.vararg:
1194 self.write(node.vararg.arg)
1195 if node.vararg.annotation:
1196 self.write(": ")
1197 self.traverse(node.vararg.annotation)
1198
1199 # keyword-only arguments
1200 if node.kwonlyargs:
1201 for a, d in zip(node.kwonlyargs, node.kw_defaults):
1202 if first:
1203 first = False
1204 else:
1205 self.write(", ")
1206 self.traverse(a),
1207 if d:
1208 self.write("=")
1209 self.traverse(d)
1210
1211 # kwargs
1212 if node.kwarg:
1213 if first:
1214 first = False
1215 else:
1216 self.write(", ")
1217 self.write("**" + node.kwarg.arg)
1218 if node.kwarg.annotation:
1219 self.write(": ")
1220 self.traverse(node.kwarg.annotation)
1221
1222 def visit_keyword(self, node):
1223 if node.arg is None:
1224 self.write("**")
1225 else:
1226 self.write(node.arg)
1227 self.write("=")
1228 self.traverse(node.value)
1229
1230 def visit_Lambda(self, node):
1231 self.write("(")
1232 self.write("lambda ")
1233 self.traverse(node.args)
1234 self.write(": ")
1235 self.traverse(node.body)
1236 self.write(")")
1237
1238 def visit_alias(self, node):
1239 self.write(node.name)
1240 if node.asname:
1241 self.write(" as " + node.asname)
1242
1243 def visit_withitem(self, node):
1244 self.traverse(node.context_expr)
1245 if node.optional_vars:
1246 self.write(" as ")
1247 self.traverse(node.optional_vars)
1248
1249def unparse(ast_obj):
1250 unparser = _Unparser()
1251 return unparser.visit(ast_obj)
1252
Serhiy Storchaka832e8642019-09-09 23:36:13 +03001253
1254def main():
1255 import argparse
1256
1257 parser = argparse.ArgumentParser(prog='python -m ast')
1258 parser.add_argument('infile', type=argparse.FileType(mode='rb'), nargs='?',
1259 default='-',
1260 help='the file to parse; defaults to stdin')
1261 parser.add_argument('-m', '--mode', default='exec',
1262 choices=('exec', 'single', 'eval', 'func_type'),
1263 help='specify what kind of code must be parsed')
1264 parser.add_argument('-a', '--include-attributes', action='store_true',
1265 help='include attributes such as line numbers and '
1266 'column offsets')
1267 args = parser.parse_args()
1268
1269 with args.infile as infile:
1270 source = infile.read()
1271 tree = parse(source, args.infile.name, args.mode, type_comments=True)
1272 print(dump(tree, include_attributes=args.include_attributes, indent=3))
1273
1274if __name__ == '__main__':
1275 main()