blob: 97914ebc66858b06d98d5a29aef1e65eae72a97d [file] [log] [blame]
Georg Brandl0c77a822008-06-10 16:37:50 +00001"""
2 ast
3 ~~~
4
5 The `ast` module helps Python applications to process trees of the Python
6 abstract syntax grammar. The abstract syntax itself might change with
7 each Python release; this module helps to find out programmatically what
8 the current grammar looks like and allows modifications of it.
9
10 An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
11 a flag to the `compile()` builtin function or by using the `parse()`
12 function from this module. The result will be a tree of objects whose
13 classes all inherit from `ast.AST`.
14
15 A modified abstract syntax tree can be compiled into a Python code object
16 using the built-in `compile()` function.
17
18 Additionally various helper functions are provided that make working with
19 the trees simpler. The main intention of the helper functions and this
20 module in general is to provide an easy to use interface for libraries
21 that work tightly with the python syntax (template engines for example).
22
23
24 :copyright: Copyright 2008 by Armin Ronacher.
25 :license: Python License.
26"""
Pablo Galindo27fc3b62019-11-24 23:02:40 +000027import sys
Georg Brandl0c77a822008-06-10 16:37:50 +000028from _ast import *
Pablo Galindo27fc3b62019-11-24 23:02:40 +000029from contextlib import contextmanager
Georg Brandl0c77a822008-06-10 16:37:50 +000030
31
def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None):
    """
    Parse the source into an AST node.
    Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
    Pass type_comments=True to get back type comments where the syntax allows.

    *feature_version* may be None (the default), an int giving the minor
    version of Python 3.x, or a ``(3, minor)`` tuple.  A tuple whose major
    version is not 3 raises ValueError.
    """
    flags = PyCF_ONLY_AST
    if type_comments:
        flags |= PyCF_TYPE_COMMENTS
    if feature_version is None:
        # compile() is handed -1 when no feature version was requested.
        feature_version = -1
    elif isinstance(feature_version, tuple):
        major, minor = feature_version  # Should be a 2-tuple.
        # Explicit check instead of ``assert`` so that it survives ``-O``.
        if major != 3:
            raise ValueError(f"Unsupported major version: {major}")
        feature_version = minor
    # Else it should be an int giving the minor version for 3.x.
    return compile(source, filename, mode, flags,
                   _feature_version=feature_version)
Georg Brandl0c77a822008-06-10 16:37:50 +000051
52
def literal_eval(node_or_string):
    """
    Safely evaluate an expression node or a string containing a Python
    expression.  The string or node provided may only consist of the following
    Python literal structures: strings, bytes, numbers, tuples, lists, dicts,
    sets, booleans, and None.

    Leading spaces and tabs in a string argument are ignored.  Anything that
    is not one of the supported literal forms raises ValueError.
    """
    if isinstance(node_or_string, str):
        # Leading whitespace would otherwise be parsed as unexpected
        # indentation and rejected before any literal is seen.
        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _convert_num(node):
        # Exact type match on purpose: bool is a subclass of int but must
        # not be accepted as an operand of +/-.
        if isinstance(node, Constant):
            if type(node.value) in (int, float, complex):
                return node.value
        raise ValueError('malformed node or string: ' + repr(node))

    def _convert_signed_num(node):
        # Accept an optional unary + or - in front of a number.
        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
            operand = _convert_num(node.operand)
            if isinstance(node.op, UAdd):
                return + operand
            else:
                return - operand
        return _convert_num(node)

    def _convert(node):
        if isinstance(node, Constant):
            return node.value
        elif isinstance(node, Tuple):
            return tuple(map(_convert, node.elts))
        elif isinstance(node, List):
            return list(map(_convert, node.elts))
        elif isinstance(node, Set):
            return set(map(_convert, node.elts))
        elif isinstance(node, Dict):
            return dict(zip(map(_convert, node.keys),
                            map(_convert, node.values)))
        elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            # Only "<real> +/- <imaginary>" is allowed, i.e. complex literals.
            left = _convert_signed_num(node.left)
            right = _convert_num(node.right)
            if isinstance(left, (int, float)) and isinstance(right, complex):
                if isinstance(node.op, Add):
                    return left + right
                else:
                    return left - right
        return _convert_signed_num(node)

    return _convert(node_or_string)
99
100
def dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
    """
    Return a formatted dump of the tree in node.  This is mainly useful for
    debugging purposes.  With annotate_fields true (the default) every field
    is shown as ``name=value``; when false, field names are omitted unless a
    field is missing on a node (after which names are needed to stay
    unambiguous).  Attributes such as line numbers and column offsets are
    only included when include_attributes is true.  If indent is a
    non-negative integer or a string, the tree is pretty-printed using that
    indent; None (the default) produces the single-line representation.

    Raises TypeError if node is not an AST instance.
    """
    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    if indent is not None and not isinstance(indent, str):
        indent = ' ' * indent

    def _render(obj, depth=0):
        # Returns (text, is_simple); "simple" values may share one line.
        if indent is None:
            lead = ''
            joiner = ', '
        else:
            depth += 1
            lead = '\n' + indent * depth
            joiner = ',' + lead
        if isinstance(obj, AST):
            name = obj.__class__.__name__
            parts = []
            all_simple = True
            use_keywords = annotate_fields
            for field in obj._fields:
                try:
                    raw = getattr(obj, field)
                except AttributeError:
                    # A gap in positional output would be ambiguous.
                    use_keywords = True
                    continue
                text, simple = _render(raw, depth)
                all_simple = all_simple and simple
                parts.append(f'{field}={text}' if use_keywords else text)
            if include_attributes and obj._attributes:
                for attr in obj._attributes:
                    try:
                        raw = getattr(obj, attr)
                    except AttributeError:
                        continue
                    text, simple = _render(raw, depth)
                    all_simple = all_simple and simple
                    parts.append(f'{attr}={text}')
            if all_simple and len(parts) <= 3:
                return f"{name}({', '.join(parts)})", not parts
            return f'{name}({lead}{joiner.join(parts)})', False
        if isinstance(obj, list):
            if not obj:
                return '[]', True
            items = joiner.join(_render(item, depth)[0] for item in obj)
            return f'[{lead}{items}]', False
        return repr(obj), True

    return _render(node)[0]
Georg Brandl0c77a822008-06-10 16:37:50 +0000161
162
def copy_location(new_node, old_node):
    """
    Copy the source location attributes (`lineno`, `col_offset`,
    `end_lineno` and `end_col_offset`) from *old_node* to *new_node* where
    both node types support the attribute and *old_node* has it set.
    Return *new_node*.
    """
    for name in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset'):
        if name not in old_node._attributes:
            continue
        if name not in new_node._attributes:
            continue
        if hasattr(old_node, name):
            setattr(new_node, name, getattr(old_node, name))
    return new_node
173
174
def fix_missing_locations(node):
    """
    When you compile a node tree with compile(), the compiler expects lineno and
    col_offset attributes for every node that supports them.  This is rather
    tedious to fill in for generated nodes, so this helper adds these attributes
    recursively where not already set, by setting them to the values of the
    parent node.  It works recursively starting at *node* and returns *node*.

    The end position attributes are treated as missing both when they are
    absent and when they are None.
    """
    def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
        if 'lineno' in node._attributes:
            if not hasattr(node, 'lineno'):
                node.lineno = lineno
            else:
                lineno = node.lineno
        if 'end_lineno' in node._attributes:
            # An end position of None means "unknown": fill it in as well.
            if getattr(node, 'end_lineno', None) is None:
                node.end_lineno = end_lineno
            else:
                end_lineno = node.end_lineno
        if 'col_offset' in node._attributes:
            if not hasattr(node, 'col_offset'):
                node.col_offset = col_offset
            else:
                col_offset = node.col_offset
        if 'end_col_offset' in node._attributes:
            if getattr(node, 'end_col_offset', None) is None:
                node.end_col_offset = end_col_offset
            else:
                end_col_offset = node.end_col_offset
        # Children inherit whatever location is in effect for this node.
        for child in iter_child_nodes(node):
            _fix(child, lineno, col_offset, end_lineno, end_col_offset)
    _fix(node, 1, 0, 1, 0)
    return node
208
209
def increment_lineno(node, n=1):
    """
    Increment the line number and end line number of each node in the tree
    starting at *node* by *n*.  This is useful to "move code" to a different
    location in a file.  Return *node*.

    An end line number of None is left untouched.
    """
    for child in walk(node):
        # TypeIgnore keeps lineno as a regular field rather than as a
        # location attribute, so it needs to be handled separately.
        if isinstance(child, TypeIgnore):
            child.lineno = getattr(child, 'lineno', 0) + n
            continue
        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
        if 'end_lineno' in child._attributes:
            end_lineno = getattr(child, 'end_lineno', 0)
            if end_lineno is not None:
                child.end_lineno = end_lineno + n
    return node
222
223
def iter_fields(node):
    """
    Yield a ``(fieldname, value)`` pair for every name listed in
    ``node._fields`` that is actually set on *node*; names whose attribute
    is missing are skipped.
    """
    for name in node._fields:
        try:
            value = getattr(node, name)
        except AttributeError:
            continue
        yield name, value
234
235
def iter_child_nodes(node):
    """
    Yield the direct children of *node*: every field value that is itself
    a node, plus every node found inside a field that holds a list.
    """
    for _name, value in iter_fields(node):
        if isinstance(value, AST):
            yield value
        elif isinstance(value, list):
            yield from (item for item in value if isinstance(item, AST))
248
249
def get_docstring(node, clean=True):
    """
    Return the docstring for the given node or None if no docstring can
    be found.  If the node provided does not have docstrings a TypeError
    will be raised.

    If *clean* is `True`, all tabs are expanded to spaces and any whitespace
    that can be uniformly removed from the second line onwards is removed.
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if not (node.body and isinstance(node.body[0], Expr)):
        return None
    first = node.body[0].value
    # A docstring is a string constant in first position.  The deprecated
    # Str alias matches exactly the same nodes, so one check is enough.
    if not (isinstance(first, Constant) and isinstance(first.value, str)):
        return None
    text = first.value
    if clean:
        # Imported lazily; only needed when cleaning is requested.
        import inspect
        text = inspect.cleandoc(text)
    return text
Georg Brandl0c77a822008-06-10 16:37:50 +0000274
275
def _splitlines_no_ff(source):
    """Split a string into lines ignoring form feed and other chars.

    This mimics how the Python parser splits source code: only ``\\n``,
    ``\\r`` and ``\\r\\n`` end a line.  Line endings are kept, and a
    trailing piece without a line ending is returned as the last item.
    """
    lines = []
    end = len(source)
    start = pos = 0
    while pos < end:
        ch = source[pos]
        pos += 1
        if ch == '\r':
            # Keep \r\n together
            if pos < end and source[pos] == '\n':
                pos += 1
        elif ch != '\n':
            continue
        # Slice whole lines instead of growing a string char by char.
        lines.append(source[start:pos])
        start = pos
    if start < end:
        lines.append(source[start:])
    return lines
299
300
def _pad_whitespace(source):
    """Replace all chars except '\\f\\t' in a line with spaces."""
    return ''.join(ch if ch in '\f\t' else ' ' for ch in source)
310
311
def get_source_segment(source, node, *, padded=False):
    """Get source code segment of the *source* that generated *node*.

    If some location information (`lineno`, `end_lineno`, `col_offset`,
    or `end_col_offset`) is missing, return None.  An end position of
    None counts as missing.

    If *padded* is `True`, the first line of a multi-line statement will
    be padded with spaces to match its original position.
    """
    try:
        if node.end_lineno is None or node.end_col_offset is None:
            return None
        lineno = node.lineno - 1
        end_lineno = node.end_lineno - 1
        col_offset = node.col_offset
        end_col_offset = node.end_col_offset
    except AttributeError:
        return None

    lines = _splitlines_no_ff(source)
    # The column offsets are applied to the encoded line, hence the
    # encode/slice/decode round trips below.
    if end_lineno == lineno:
        return lines[lineno].encode()[col_offset:end_col_offset].decode()

    first_line = lines[lineno].encode()
    if padded:
        padding = _pad_whitespace(first_line[:col_offset].decode())
    else:
        padding = ''

    first = padding + first_line[col_offset:].decode()
    last = lines[end_lineno].encode()[:end_col_offset].decode()
    return ''.join([first, *lines[lineno + 1:end_lineno], last])
345
346
def walk(node):
    """
    Yield *node* and, recursively, every node below it.  The order is
    unspecified, which makes this convenient when nodes are only modified
    in place and their context does not matter.
    """
    from collections import deque
    queue = deque((node,))
    while queue:
        current = queue.popleft()
        # Children are queued before the node is handed to the caller.
        queue.extend(iter_child_nodes(current))
        yield current
359
360
class NodeVisitor(object):
    """
    Base class for walking an abstract syntax tree and calling a visitor
    method for every node found.  Whatever the visitor method returns is
    passed on by `visit`.

    Subclass it and add visitor methods.  The method for a node is looked
    up as ``'visit_'`` + the node's class name, so a `Name` node is handled
    by `visit_Name`.  Override `visit` to change that.  Nodes without a
    matching method are handled by `generic_visit`.

    Do not use `NodeVisitor` to change nodes while traversing; use
    `NodeTransformer`, which supports modifications, instead.
    """

    def visit(self, node):
        """Visit a node."""
        handler = getattr(self, 'visit_' + node.__class__.__name__,
                          self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for _name, value in iter_fields(node):
            children = value if isinstance(value, list) else [value]
            for child in children:
                if isinstance(child, AST):
                    self.visit(child)

    def visit_Constant(self, node):
        # Dispatch to a legacy visit_Num/visit_Str/... method when the
        # subclass still defines one, warning that it is deprecated.
        value = node.value
        type_name = _const_node_type_names.get(type(value))
        if type_name is None:
            # No exact type match: fall back to the first isinstance match.
            for cls, name in _const_node_type_names.items():
                if isinstance(value, cls):
                    type_name = name
                    break
        if type_name is None:
            return self.generic_visit(node)
        method = 'visit_' + type_name
        try:
            visitor = getattr(self, method)
        except AttributeError:
            return self.generic_visit(node)
        import warnings
        warnings.warn(f"{method} is deprecated; add visit_Constant",
                      DeprecationWarning, 2)
        return visitor(node)
417
Georg Brandl0c77a822008-06-10 16:37:50 +0000418
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    allows modification of nodes.

    The return value of each visitor method decides what happens to the
    visited node: ``None`` removes it from its location, any other value
    replaces it.  Returning the original node leaves the tree unchanged.

    Example transformer rewriting every name lookup (``foo``) into
    ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return copy_location(Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Index(value=Constant(value=node.id)),
                   ctx=node.ctx
               ), node)

    If the node you are operating on has child nodes, either transform them
    yourself or call :meth:`generic_visit` for the node first.

    For nodes that are part of a list (which applies to all statement
    nodes) a visitor may also return a list of nodes instead of one node.

    Typical usage::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        for field, old_value in iter_fields(node):
            if isinstance(old_value, list):
                kept = []
                for item in old_value:
                    if not isinstance(item, AST):
                        # Non-node items are carried over untouched.
                        kept.append(item)
                        continue
                    result = self.visit(item)
                    if result is None:
                        continue
                    if isinstance(result, AST):
                        kept.append(result)
                    else:
                        kept.extend(result)
                # Update in place so the node keeps the same list object.
                old_value[:] = kept
            elif isinstance(old_value, AST):
                replacement = self.visit(old_value)
                if replacement is None:
                    delattr(node, field)
                else:
                    setattr(node, field, replacement)
        return node
Serhiy Storchaka3f228112018-09-27 17:42:37 +0300476
477
478# The following code is for backward compatibility.
479# It will be removed in future.
480
def _getter(self):
    """Read accessor shared by the legacy ``n`` and ``s`` properties."""
    return self.value

def _setter(self, value):
    """Write accessor shared by the legacy ``n`` and ``s`` properties."""
    self.value = value

# Expose ``value`` under the legacy attribute names ``n`` and ``s``.
Constant.n = property(_getter, _setter)
Constant.s = property(_getter, _setter)
489
class _ABC(type):
    """Metaclass giving the legacy constant classes a value-based isinstance()."""

    def __instancecheck__(cls, inst):
        if not isinstance(inst, Constant):
            return False
        if cls not in _const_types:
            # Subclasses of the legacy classes use the normal check.
            return type.__instancecheck__(cls, inst)
        try:
            value = inst.value
        except AttributeError:
            return False
        if not isinstance(value, _const_types[cls]):
            return False
        # Reject explicitly excluded value types (e.g. bool for Num).
        return not isinstance(value, _const_types_not.get(cls, ()))
Serhiy Storchaka3f228112018-09-27 17:42:37 +0300505 return type.__instancecheck__(cls, inst)
506
def _new(cls, *args, **kwargs):
    """Shared ``__new__``: the legacy classes themselves build a Constant."""
    if cls not in _const_types:
        # A user-defined subclass: create an instance of that subclass.
        return Constant.__new__(cls, *args, **kwargs)
    return Constant(*args, **kwargs)
511
class Num(Constant, metaclass=_ABC):
    # Backward-compatibility class; the value is exposed as ``n``.
    _fields = ('n',)
    __new__ = _new
515
class Str(Constant, metaclass=_ABC):
    # Backward-compatibility class; the value is exposed as ``s``.
    _fields = ('s',)
    __new__ = _new
519
class Bytes(Constant, metaclass=_ABC):
    # Backward-compatibility class; the value is exposed as ``s``.
    _fields = ('s',)
    __new__ = _new
523
class NameConstant(Constant, metaclass=_ABC):
    # Backward-compatibility class; keeps the ``_fields`` inherited from Constant.
    __new__ = _new
526
class Ellipsis(Constant, metaclass=_ABC):
    # Backward-compatibility class.  Note that within this module the name
    # shadows the built-in ``Ellipsis``; the code uses ``...`` instead.
    _fields = ()

    def __new__(cls, *args, **kwargs):
        # Ellipsis itself builds a Constant whose value is ``...``;
        # subclasses are instantiated as themselves.
        if cls is Ellipsis:
            return Constant(..., *args, **kwargs)
        return Constant.__new__(cls, *args, **kwargs)
534
# Value types accepted by isinstance() checks against each legacy class.
_const_types = {
    Num: (int, float, complex),
    Str: (str,),
    Bytes: (bytes,),
    NameConstant: (type(None), bool),
    Ellipsis: (type(...),),
}
# Value types rejected even though they match _const_types
# (bool is a subclass of int but is not a Num).
_const_types_not = {
    Num: (bool,),
}
# Maps a constant's value type to the legacy node class name; used by
# NodeVisitor.visit_Constant to find legacy visit_* methods.  The insertion
# order matters for the isinstance() fallback there.
_const_node_type_names = {
    bool: 'NameConstant',  # should be before int
    type(None): 'NameConstant',
    int: 'Num',
    float: 'Num',
    complex: 'Num',
    str: 'Str',
    bytes: 'Bytes',
    type(...): 'Ellipsis',
}
Serhiy Storchaka832e8642019-09-09 23:36:13 +0300555
# Large float and imaginary literals get turned into infinities in the AST.
# We unparse those infinities to INFSTR, a decimal literal whose exponent is
# one more than sys.float_info.max_10_exp.
_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
559
560class _Unparser(NodeVisitor):
561 """Methods in this class recursively traverse an AST and
562 output source code for the abstract syntax; original formatting
563 is disregarded."""
564
def __init__(self):
    # Pieces of generated source; joined by visit().
    self._source = []
    # Scratch pieces used while assembling f-string bodies.
    self._buffer = []
    # Current block nesting depth, used by fill().
    self._indent = 0
569
def interleave(self, inter, f, seq):
    """Call f on each item in seq, calling inter() in between.

    Nothing is called for an empty *seq*.  Exceptions raised by *f* or
    *inter* (including StopIteration) propagate to the caller.
    """
    items = iter(seq)
    try:
        # Only the fetch is guarded, so a StopIteration coming out of f()
        # is not mistaken for an exhausted sequence.
        first = next(items)
    except StopIteration:
        return
    f(first)
    for item in items:
        inter()
        f(item)
581
def fill(self, text=""):
    """Indent a piece of text and append it, according to the current
    indentation level"""
    # Starts a new line, then one indent unit per nesting level.
    # NOTE(review): the indent unit shown here is a single space - confirm
    # this is the intended width.
    self.write("\n" + " " * self._indent + text)
586
def write(self, text):
    """Append a piece of text"""
    # Text is collected in self._source and joined by visit().
    self._source.append(text)
590
def buffer_writer(self, text):
    """Append *text* to the f-string scratch buffer instead of the output."""
    self._buffer.append(text)
593
@property
def buffer(self):
    """Return the joined scratch buffer contents and empty the buffer.

    Note that reading this property has a side effect: the buffer is
    cleared on every access.
    """
    value = "".join(self._buffer)
    self._buffer.clear()
    return value
599
@contextmanager
def block(self):
    """A context manager for preparing the source for blocks. It adds
    the character':', increases the indentation on enter and decreases
    the indentation on exit.

    The indentation level is restored even if the body raises.
    """
    self.write(":")
    self._indent += 1
    try:
        yield
    finally:
        self._indent -= 1
609
def traverse(self, node):
    """Unparse *node*; a list is unparsed item by item, in order."""
    if not isinstance(node, list):
        # Use NodeVisitor's dispatch, not this class's visit() override.
        super().visit(node)
        return
    for item in node:
        self.traverse(item)
616
def visit(self, node):
    """Outputs a source code string that, if converted back to an ast
    (using ast.parse) will generate an AST equivalent to *node*"""
    # Every call starts from an empty output list.
    self._source = []
    self.traverse(node)
    return "".join(self._source)
623
def visit_Module(self, node):
    """Unparse each statement of the module body in order."""
    for subnode in node.body:
        self.traverse(subnode)
627
def visit_Expr(self, node):
    """Emit an expression statement on a new, indented line."""
    self.fill()
    self.traverse(node.value)
631
def visit_NamedExpr(self, node):
    """Emit ``(target := value)``, always parenthesized."""
    self.write("(")
    self.traverse(node.target)
    self.write(" := ")
    self.traverse(node.value)
    self.write(")")
638
def visit_Import(self, node):
    """Emit ``import`` followed by the comma-separated names."""
    self.fill("import ")
    self.interleave(lambda: self.write(", "), self.traverse, node.names)
642
def visit_ImportFrom(self, node):
    """Emit ``from <dots><module> import <names>``."""
    self.fill("from ")
    # One leading dot per relative import level.
    self.write("." * node.level)
    if node.module:
        self.write(node.module)
    self.write(" import ")
    self.interleave(lambda: self.write(", "), self.traverse, node.names)
650
def visit_Assign(self, node):
    """Emit ``t1 = t2 = ... = value`` for all targets of the assignment."""
    self.fill()
    for target in node.targets:
        self.traverse(target)
        self.write(" = ")
    self.traverse(node.value)
657
def visit_AugAssign(self, node):
    """Emit ``target <op>= value``."""
    self.fill()
    self.traverse(node.target)
    # self.binop is looked up by the operator's class name; it is defined
    # outside this excerpt (presumably a name -> symbol mapping - confirm).
    self.write(" " + self.binop[node.op.__class__.__name__] + "= ")
    self.traverse(node.value)
663
def visit_AnnAssign(self, node):
    """Emit ``target: annotation`` with an optional ``= value``."""
    self.fill()
    # A Name target on a non-simple annotated assignment is parenthesized.
    parenthesize = not node.simple and isinstance(node.target, Name)
    if parenthesize:
        self.write("(")
    self.traverse(node.target)
    if parenthesize:
        self.write(")")
    self.write(": ")
    self.traverse(node.annotation)
    if node.value:
        self.write(" = ")
        self.traverse(node.value)
676
def visit_Return(self, node):
    """Emit ``return``, followed by the value when there is one."""
    self.fill("return")
    if node.value:
        self.write(" ")
        self.traverse(node.value)
682
def visit_Pass(self, node):
    """Emit a ``pass`` statement."""
    self.fill("pass")
685
def visit_Break(self, node):
    """Emit a ``break`` statement."""
    self.fill("break")
688
def visit_Continue(self, node):
    """Emit a ``continue`` statement."""
    self.fill("continue")
691
def visit_Delete(self, node):
    """Emit ``del`` followed by the comma-separated targets."""
    self.fill("del ")
    self.interleave(lambda: self.write(", "), self.traverse, node.targets)
695
def visit_Assert(self, node):
    """Emit ``assert test`` with an optional ``, msg``."""
    self.fill("assert ")
    self.traverse(node.test)
    if node.msg:
        self.write(", ")
        self.traverse(node.msg)
702
def visit_Global(self, node):
    """Emit ``global`` followed by the comma-separated names."""
    self.fill("global ")
    # The names are written directly rather than traversed.
    self.interleave(lambda: self.write(", "), self.write, node.names)
706
def visit_Nonlocal(self, node):
    """Emit ``nonlocal`` followed by the comma-separated names."""
    self.fill("nonlocal ")
    # The names are written directly rather than traversed.
    self.interleave(lambda: self.write(", "), self.write, node.names)
710
def visit_Await(self, node):
    """Emit ``(await value)``, always parenthesized."""
    self.write("(")
    self.write("await")
    if node.value:
        self.write(" ")
        self.traverse(node.value)
    self.write(")")
718
def visit_Yield(self, node):
    """Emit ``(yield)`` or ``(yield value)``, always parenthesized."""
    self.write("(")
    self.write("yield")
    if node.value:
        self.write(" ")
        self.traverse(node.value)
    self.write(")")
726
def visit_YieldFrom(self, node):
    """Emit ``(yield from value)``, always parenthesized."""
    self.write("(")
    self.write("yield from")
    if node.value:
        self.write(" ")
        self.traverse(node.value)
    self.write(")")
734
def visit_Raise(self, node):
    """Emit ``raise``, ``raise exc`` or ``raise exc from cause``.

    Raises ValueError for a node that has a cause but no exception.
    """
    self.fill("raise")
    if not node.exc:
        if node.cause:
            raise ValueError("Node can't use cause without an exception.")
        return
    self.write(" ")
    self.traverse(node.exc)
    if not node.cause:
        return
    self.write(" from ")
    self.traverse(node.cause)
746
def visit_Try(self, node):
    """Emit a ``try`` statement with its handlers and the optional
    ``else`` and ``finally`` blocks."""
    self.fill("try")
    with self.block():
        self.traverse(node.body)
    for ex in node.handlers:
        self.traverse(ex)
    if node.orelse:
        self.fill("else")
        with self.block():
            self.traverse(node.orelse)
    if node.finalbody:
        self.fill("finally")
        with self.block():
            self.traverse(node.finalbody)
761
def visit_ExceptHandler(self, node):
    """Emit ``except``, ``except type`` or ``except type as name`` and
    the handler body."""
    self.fill("except")
    if node.type:
        self.write(" ")
        self.traverse(node.type)
    if node.name:
        self.write(" as ")
        self.write(node.name)
    with self.block():
        self.traverse(node.body)
772
def visit_ClassDef(self, node):
    """Emit a class definition: decorators, header with bases and
    keywords in parentheses, then the body."""
    self.write("\n")
    for deco in node.decorator_list:
        self.fill("@")
        self.traverse(deco)
    self.fill("class " + node.name)
    self.write("(")
    # Bases come first, then keywords, all separated by ", ".
    for position, argument in enumerate([*node.bases, *node.keywords]):
        if position:
            self.write(", ")
        self.traverse(argument)
    self.write(")")

    with self.block():
        self.traverse(node.body)
797
def visit_FunctionDef(self, node):
    """Emit a ``def`` function definition."""
    self.__FunctionDef_helper(node, "def")
800
def visit_AsyncFunctionDef(self, node):
    """Emit an ``async def`` function definition."""
    self.__FunctionDef_helper(node, "async def")
803
def __FunctionDef_helper(self, node, fill_suffix):
    """Emit a function definition; *fill_suffix* is the leading keyword
    (``def`` or ``async def``)."""
    self.write("\n")
    for decorator in node.decorator_list:
        self.fill("@")
        self.traverse(decorator)
    self.fill(f"{fill_suffix} {node.name}(")
    self.traverse(node.args)
    self.write(")")
    if node.returns:
        self.write(" -> ")
        self.traverse(node.returns)
    with self.block():
        self.traverse(node.body)
818
def visit_For(self, node):
    """Emit a ``for`` loop."""
    self.__For_helper("for ", node)
821
def visit_AsyncFor(self, node):
    """Emit an ``async for`` loop."""
    self.__For_helper("async for ", node)
824
def __For_helper(self, fill, node):
    """Emit a loop header ``<fill>target in iter``, its body and the
    optional ``else`` block; *fill* is the leading keyword text."""
    self.fill(fill)
    self.traverse(node.target)
    self.write(" in ")
    self.traverse(node.iter)
    with self.block():
        self.traverse(node.body)
    if node.orelse:
        self.fill("else")
        with self.block():
            self.traverse(node.orelse)
836
def visit_If(self, node):
    """Emit an ``if`` statement, writing an ``else`` that consists of a
    single nested ``if`` as ``elif``."""
    self.fill("if ")
    self.traverse(node.test)
    with self.block():
        self.traverse(node.body)
    # collapse nested ifs into equivalent elifs.
    while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
        # Descend into the nested If; the final else below then belongs
        # to the innermost one.
        node = node.orelse[0]
        self.fill("elif ")
        self.traverse(node.test)
        with self.block():
            self.traverse(node.body)
    # final else
    if node.orelse:
        self.fill("else")
        with self.block():
            self.traverse(node.orelse)
854
def visit_While(self, node):
    """Emit a ``while`` loop with its optional ``else`` block."""
    self.fill("while ")
    self.traverse(node.test)
    with self.block():
        self.traverse(node.body)
    if node.orelse:
        self.fill("else")
        with self.block():
            self.traverse(node.orelse)
864
def visit_With(self, node):
    """Emit a ``with`` statement with comma-separated items."""
    self.fill("with ")
    self.interleave(lambda: self.write(", "), self.traverse, node.items)
    with self.block():
        self.traverse(node.body)
870
def visit_AsyncWith(self, node):
    """Emit an ``async with`` statement with comma-separated items."""
    self.fill("async with ")
    self.interleave(lambda: self.write(", "), self.traverse, node.items)
    with self.block():
        self.traverse(node.body)
876
def visit_JoinedStr(self, node):
    """Emit an f-string literal for a JoinedStr node."""
    self.write("f")
    # The body is assembled in the scratch buffer, then written as the
    # repr() of the collected text; reading self.buffer empties it.
    self._fstring_JoinedStr(node, self.buffer_writer)
    self.write(repr(self.buffer))
881
def visit_FormattedValue(self, node):
    """Emit an f-string literal for a stand-alone FormattedValue node."""
    self.write("f")
    # The body is assembled in the scratch buffer, then written as the
    # repr() of the collected text; reading self.buffer empties it.
    self._fstring_FormattedValue(node, self.buffer_writer)
    self.write(repr(self.buffer))
886
def _fstring_JoinedStr(self, node, write):
    """Pass each part of the f-string to its ``_fstring_<type>`` method,
    handing over *write* as the output callable."""
    for value in node.values:
        # Raises AttributeError for part types without a matching method.
        meth = getattr(self, "_fstring_" + type(value).__name__)
        meth(value, write)
891
def _fstring_Constant(self, node, write):
    """Write the literal text of an f-string part with braces doubled.

    Raises ValueError if the constant is not a string.
    """
    text = node.value
    if not isinstance(text, str):
        raise ValueError("Constants inside JoinedStr should be a string.")
    write(text.replace("{", "{{").replace("}", "}}"))
897
def _fstring_FormattedValue(self, node, write):
    """Write one ``{expr!conv:spec}`` replacement field using *write*.

    Raises ValueError for a conversion other than ``s``, ``r`` or ``a``.
    """
    write("{")
    # A fresh unparser instance is used because visit() resets the
    # output collected so far.
    expr = type(self)().visit(node.value).rstrip("\n")
    if expr.startswith("{"):
        write(" ")  # Separate pair of opening brackets as "{ {"
    write(expr)
    # -1 means that no conversion was given.
    if node.conversion != -1:
        conversion = chr(node.conversion)
        if conversion not in "sra":
            raise ValueError("Unknown f-string conversion.")
        write(f"!{conversion}")
    if node.format_spec:
        write(":")
        meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
        meth(node.format_spec, write)
    write("}")
914
def visit_Name(self, node):
    """Write the identifier stored in ``node.id``."""
    identifier = node.id
    self.write(identifier)
917
def _write_constant(self, value):
    """Write ``repr(value)``, replacing "inf" in float/complex reprs with ``_INFSTR``."""
    text = repr(value)
    if isinstance(value, (float, complex)):
        # Substitute overflowing decimal literal for AST infinities.
        text = text.replace("inf", _INFSTR)
    self.write(text)
924
def visit_Constant(self, node):
    """Write a ``Constant`` node.

    Tuples are written element by element through ``_write_constant``
    (a single element gets a trailing comma), ``Ellipsis`` is written as
    ``...``, and any other value goes to ``_write_constant`` with a ``u``
    prefix when ``node.kind`` is "u".
    """
    value = node.value
    if value is ...:
        self.write("...")
        return

    if not isinstance(value, tuple):
        if node.kind == "u":
            self.write("u")
        self._write_constant(node.value)
        return

    self.write("(")
    if len(value) == 1:
        (only,) = value
        self._write_constant(only)
        self.write(",")
    else:
        self.interleave(lambda: self.write(", "), self._write_constant, value)
    self.write(")")
941
def visit_List(self, node):
    """Write a list display: ``[``, the elements joined by ", ", then ``]``."""

    def write_separator():
        self.write(", ")

    self.write("[")
    self.interleave(write_separator, self.traverse, node.elts)
    self.write("]")
946
def visit_ListComp(self, node):
    """Write a list comprehension: the element, then each generator clause."""
    self.write("[")
    for part in (node.elt, *node.generators):
        self.traverse(part)
    self.write("]")
953
def visit_GeneratorExp(self, node):
    """Write a generator expression wrapped in parentheses."""
    emit = self.traverse
    self.write("(")
    emit(node.elt)
    for clause in node.generators:
        emit(clause)
    self.write(")")
960
def visit_SetComp(self, node):
    """Write a set comprehension: the element, then each generator clause."""
    self.write("{")
    element = node.elt
    self.traverse(element)
    clauses = node.generators
    for clause in clauses:
        self.traverse(clause)
    self.write("}")
967
def visit_DictComp(self, node):
    """Write a dict comprehension: ``key: value``, then each generator clause."""
    self.write("{")
    # The key/value pair, with the literal separator between the two.
    for lead, child in ((None, node.key), (": ", node.value)):
        if lead is not None:
            self.write(lead)
        self.traverse(child)
    for clause in node.generators:
        self.traverse(clause)
    self.write("}")
976
def visit_comprehension(self, node):
    """Write one ``for ... in ...`` clause followed by its ``if`` filters."""
    keyword = " async for " if node.is_async else " for "
    self.write(keyword)
    self.traverse(node.target)
    self.write(" in ")
    self.traverse(node.iter)
    for condition in node.ifs:
        self.write(" if ")
        self.traverse(condition)
988
def visit_IfExp(self, node):
    """Write a parenthesized conditional expression ``(body if test else orelse)``."""
    self.write("(")
    self.traverse(node.body)
    for keyword, operand in ((" if ", node.test), (" else ", node.orelse)):
        self.write(keyword)
        self.traverse(operand)
    self.write(")")
997
def visit_Set(self, node):
    """Write a set display: ``{``, the elements joined by ", ", then ``}``.

    Raises ValueError when ``node.elts`` is empty: ``{}`` is an empty
    dict, so there is no literal to write for an empty ``Set`` node.
    """
    if not node.elts:
        raise ValueError("Set node should have at least one item")
    self.write("{")
    self.interleave(lambda: self.write(", "), self.traverse, node.elts)
    self.write("}")
1004
def visit_Dict(self, node):
    """Write a dict display, pairing ``node.keys`` with ``node.values``.

    A ``None`` key marks a ``**mapping`` unpacking entry.
    """

    def write_entry(pair):
        key, value = pair
        if key is None:
            # for dictionary unpacking operator in dicts {**{'y': 2}}
            # see PEP 448 for details
            self.write("**")
        else:
            self.traverse(key)
            self.write(": ")
        self.traverse(value)

    def write_separator():
        self.write(", ")

    self.write("{")
    self.interleave(write_separator, write_entry, zip(node.keys, node.values))
    self.write("}")
1027
def visit_Tuple(self, node):
    """Write a parenthesized tuple; a single element gets a trailing comma."""
    elements = node.elts
    self.write("(")
    if len(elements) != 1:
        self.interleave(lambda: self.write(", "), self.traverse, elements)
    else:
        # ``(x,)``: without the comma this would just be a parenthesized ``x``.
        self.traverse(elements[0])
        self.write(",")
    self.write(")")
1037
# Source spelling of each unary operator, keyed by AST operator class name.
unop = dict(Invert="~", Not="not", UAdd="+", USub="-")

def visit_UnaryOp(self, node):
    """Write a parenthesized unary operation as ``(<op> <operand>)``."""
    self.write("(")
    symbol = self.unop[type(node.op).__name__]
    self.write(symbol)
    self.write(" ")
    self.traverse(node.operand)
    self.write(")")
1046
# Source spelling of each binary operator, keyed by AST operator class name.
binop = dict(
    Add="+",
    Sub="-",
    Mult="*",
    MatMult="@",
    Div="/",
    Mod="%",
    LShift="<<",
    RShift=">>",
    BitOr="|",
    BitXor="^",
    BitAnd="&",
    FloorDiv="//",
    Pow="**",
)

def visit_BinOp(self, node):
    """Write a parenthesized binary operation as ``(left <op> right)``."""
    self.write("(")
    self.traverse(node.left)
    symbol = self.binop[type(node.op).__name__]
    self.write(f" {symbol} ")
    self.traverse(node.right)
    self.write(")")
1069
# Source spelling of each comparison operator, keyed by AST operator class name.
cmpops = dict(
    Eq="==",
    NotEq="!=",
    Lt="<",
    LtE="<=",
    Gt=">",
    GtE=">=",
    Is="is",
    IsNot="is not",
    In="in",
    NotIn="not in",
)

def visit_Compare(self, node):
    """Write a parenthesized comparison chain ``(left <op> c1 <op> c2 ...)``."""
    self.write("(")
    self.traverse(node.left)
    for operator, operand in zip(node.ops, node.comparators):
        symbol = self.cmpops[type(operator).__name__]
        self.write(f" {symbol} ")
        self.traverse(operand)
    self.write(")")
1090
# Source spelling of each boolean operator, keyed by AST operator class.
boolops = {And: "and", Or: "or"}

def visit_BoolOp(self, node):
    """Write a parenthesized boolean operation, joining the values with the operator."""
    self.write("(")
    keyword = self.boolops[type(node.op)]
    separator = f" {keyword} "

    def write_separator():
        self.write(separator)

    self.interleave(write_separator, self.traverse, node.values)
    self.write(")")
1098
def visit_Attribute(self, node):
    """Write an attribute access ``value.attr``."""
    owner = node.value
    self.traverse(owner)
    # Special case: 3.__abs__() is a syntax error, so if the owner is an
    # integer literal we add an extra space to get 3 .__abs__().
    is_int_literal = isinstance(owner, Constant) and isinstance(owner.value, int)
    if is_int_literal:
        self.write(" ")
    self.write(".")
    self.write(node.attr)
1108
def visit_Call(self, node):
    """Write a call: the callee, then positional and keyword arguments in parentheses."""
    self.traverse(node.func)
    self.write("(")
    emitted_any = False
    # Positional arguments first, keywords second, sharing one comma state.
    for group in (node.args, node.keywords):
        for argument in group:
            if emitted_any:
                self.write(", ")
            emitted_any = True
            self.traverse(argument)
    self.write(")")
1126
def visit_Subscript(self, node):
    """Write a subscription ``value[slice]``."""
    for child, closer in ((node.value, "["), (node.slice, "]")):
        self.traverse(child)
        self.write(closer)
1132
def visit_Starred(self, node):
    """Write a starred expression ``*value``."""
    self.write("*")
    inner = node.value
    self.traverse(inner)
1136
def visit_Ellipsis(self, node):
    """Write the ``...`` literal; *node* itself is not inspected."""
    literal = "..."
    self.write(literal)
1139
def visit_Index(self, node):
    """Write an ``Index`` node by traversing the expression in ``node.value``."""
    wrapped = node.value
    self.traverse(wrapped)
1142
def visit_Slice(self, node):
    """Write a slice ``lower:upper[:step]``, leaving out any part that is unset."""
    lower, upper, step = node.lower, node.upper, node.step
    if lower:
        self.traverse(lower)
    # The first colon is always written, even for a bare ``[:]``.
    self.write(":")
    if upper:
        self.traverse(upper)
    if not step:
        return
    self.write(":")
    self.traverse(step)
1152
def visit_ExtSlice(self, node):
    """Write the dimensions of an extended slice, joined by ", ".

    A single dimension is followed by a trailing comma, the same way
    ``visit_Tuple`` handles a one-element tuple.
    """
    dims = node.dims
    if len(dims) == 1:
        # ``a[1:2,]`` yields an ExtSlice with one dimension; without the
        # comma the output would be ``a[1:2]``, which is a plain slice.
        self.traverse(dims[0])
        self.write(",")
    else:
        self.interleave(lambda: self.write(", "), self.traverse, dims)
1155
def visit_arg(self, node):
    """Write a parameter name, followed by ``: annotation`` when one is set."""
    self.write(node.arg)
    annotation = node.annotation
    if not annotation:
        return
    self.write(": ")
    self.traverse(annotation)
1161
def visit_arguments(self, node):
    """Write a parameter list for an ``arguments`` node.

    Parameters are written in source order: positional-only and regular
    parameters (with ``, /`` after the last positional-only one), then
    ``*`` or ``*name``, keyword-only parameters, and finally ``**name``.
    """
    first = True

    def write_separator():
        # Write ", " before every parameter except the very first one.
        nonlocal first
        if first:
            first = False
        else:
            self.write(", ")

    # normal arguments
    all_args = node.posonlyargs + node.args
    # Defaults belong to the last parameters; pad the front with None.
    defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults
    for index, (arg, default) in enumerate(zip(all_args, defaults), 1):
        write_separator()
        self.traverse(arg)
        if default:
            self.write("=")
            self.traverse(default)
        if index == len(node.posonlyargs):
            self.write(", /")

    # varargs, or bare '*' if no varargs but keyword-only arguments present
    if node.vararg or node.kwonlyargs:
        write_separator()
        self.write("*")
        if node.vararg:
            self.write(node.vararg.arg)
            if node.vararg.annotation:
                self.write(": ")
                self.traverse(node.vararg.annotation)

    # keyword-only arguments; a None entry in kw_defaults means "no default"
    for arg, default in zip(node.kwonlyargs, node.kw_defaults):
        write_separator()
        self.traverse(arg)
        if default:
            self.write("=")
            self.traverse(default)

    # kwargs
    if node.kwarg:
        write_separator()
        self.write("**" + node.kwarg.arg)
        if node.kwarg.annotation:
            self.write(": ")
            self.traverse(node.kwarg.annotation)
1215
def visit_keyword(self, node):
    """Write a call keyword as ``name=value``, or ``**value`` when the name is None."""
    name = node.arg
    prefix = ("**",) if name is None else (name, "=")
    for piece in prefix:
        self.write(piece)
    self.traverse(node.value)
1223
def visit_Lambda(self, node):
    """Write a parenthesized lambda expression ``(lambda args: body)``."""
    self.write("(")
    for lead, child in (("lambda ", node.args), (": ", node.body)):
        self.write(lead)
        self.traverse(child)
    self.write(")")
1231
def visit_alias(self, node):
    """Write an import alias: the name, plus `` as asname`` when one is set."""
    self.write(node.name)
    rename = node.asname
    if not rename:
        return
    self.write(" as " + rename)
1236
def visit_withitem(self, node):
    """Write one ``with`` item: the context expression and its optional ``as`` target."""
    self.traverse(node.context_expr)
    target = node.optional_vars
    if not target:
        return
    self.write(" as ")
    self.traverse(target)
1242
def unparse(ast_obj):
    """Visit *ast_obj* with a fresh ``_Unparser`` and return what ``visit`` returns."""
    return _Unparser().visit(ast_obj)
1246
Serhiy Storchaka832e8642019-09-09 23:36:13 +03001247
def main():
    """Command-line entry point: parse a file (or stdin) and print its AST dump."""
    import argparse

    cli = argparse.ArgumentParser(prog="python -m ast")
    cli.add_argument(
        "infile",
        type=argparse.FileType(mode="rb"),
        nargs="?",
        default="-",
        help="the file to parse; defaults to stdin",
    )
    cli.add_argument(
        "-m", "--mode",
        default="exec",
        choices=("exec", "single", "eval", "func_type"),
        help="specify what kind of code must be parsed",
    )
    cli.add_argument(
        "-a", "--include-attributes",
        action="store_true",
        help="include attributes such as line numbers and column offsets",
    )
    options = cli.parse_args()

    with options.infile as stream:
        source = stream.read()
    # Only the stream's name is used from here on, as the filename for parse().
    tree = parse(source, options.infile.name, options.mode, type_comments=True)
    rendered = dump(tree, include_attributes=options.include_attributes, indent=3)
    print(rendered)
1267
# Run the command-line interface when this file is executed as a script.
if __name__ == "__main__":
    main()