blob: 8d2e919593b84d989f4aab21d445e33530b49f61 [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import io
Frank Barchard1f83cf92021-09-07 14:13:03 -070010import os
XNNPACK Teamb455b122019-09-27 18:10:33 -070011import re
12import sys
13from itertools import chain
14
15
def key_value_pair(line):
  """Parse one ``KEY=VALUE`` define given on the command line.

  Only the first "=" separates key from value, so the value may itself
  contain "=". The value is returned as an int when ``int()`` accepts it,
  otherwise as the original string.

  Raises:
    ValueError: if `line` contains no "=" (argparse reports this as an
      invalid argument value).
  """
  key, raw_value = line.split("=", 1)
  try:
    return key, int(raw_value)
  except ValueError:
    # Not an integer literal: keep the text as-is.
    return key, raw_value
24
25
# Command-line interface: one positional template path, any number of -D
# options (each -D may carry several KEY=VALUE pairs, collected per option
# into a list of lists), and an output path.
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "input",
    metavar="FILE",
    nargs=1,
    help="Input file",
)
parser.add_argument(
    "-D",
    dest="defines",
    metavar="KEY=VALUE",
    nargs="*",
    type=key_value_pair,
    action="append",
    help="Predefined variables",
)
parser.add_argument(
    "-o", "--output",
    help="Output file",
)
# With no -D at all, `defines` is an empty list rather than None.
parser.set_defaults(defines=[])
35
36
# Matches the (possibly empty) run of whitespace at the start of a line.
LEADING_WHITESPACE_REGEX = re.compile(r"^\s*")


def extract_leading_whitespace(line):
  """Return the leading whitespace of `line`, or "" if it has none.

  Uses the precompiled LEADING_WHITESPACE_REGEX. The pattern can match the
  empty string, so the match always succeeds and needs no None check.
  """
  return LEADING_WHITESPACE_REGEX.match(line).group(0)
43
44
def _quote(text):
  """Return `text` as a double-quoted Python string literal.

  Backslashes are escaped before double quotes. Text such as a C escape
  sequence or a trailing line-continuation backslash therefore survives
  verbatim, instead of being reinterpreted (or breaking the literal) when the
  generated code runs. `text` must not contain line breaks.
  """
  return "\"" + text.replace("\\", "\\\\").replace("\"", "\\\"") + "\""


def escape(line):
  """Translate one template text line into a Python string expression.

  Literal text becomes string literals and every ``${expr}`` becomes
  ``str(expr)``; the pieces are joined with ``+``. For example ``a${x}b``
  becomes ``"a" + str(x) + "b"``. An empty `line` yields ``""`` so that the
  caller's ``print(...)`` stays syntactically valid.

  Raises:
    ValueError: if a "${" has no closing "}".
  """
  output_parts = []
  while "${" in line:
    start_pos = line.index("${")
    end_pos = line.index("}", start_pos + 2)
    if start_pos != 0:
      output_parts.append(_quote(line[:start_pos]))
    output_parts.append("str(" + line[start_pos + 2:end_pos] + ")")
    line = line[end_pos + 1:]
  # Emit the trailing text, or an empty literal if nothing was produced.
  if line or not output_parts:
    output_parts.append(_quote(line))
  return " + ".join(output_parts)
57
58
def preprocess(input_text, input_globals, input_path="codegen"):
  """Expand a template into its generated text.

  Template syntax, as implemented below:

    * A line whose first non-blank character is "$" (and which does not start
      with "${") is Python code. Its "$" characters are removed and the rest
      is emitted as a Python statement. A statement ending in ":" opens a
      block whose body is the following, more deeply indented, lines.
    * Any other non-empty line is output text; each "${expr}" inside it is
      replaced by str(expr).
    * Empty lines are reproduced in the output.

  The template is translated into a Python program that prints into an
  in-memory stream. That program is then compiled and executed.

  Args:
    input_text: template source.
    input_globals: variables visible to the template code. The mapping is
      shallow-copied, so the caller's dict is not modified.
    input_path: file name reported in indentation errors and in tracebacks
      from the template code.

  Returns:
    The generated text.

  Raises:
    ValueError: if the template's indentation is inconsistent.
    Any exception raised while compiling or running the template's code.
  """
  input_lines = input_text.splitlines()
  python_lines = ["from __future__ import print_function"]

  # Number of empty template lines not yet emitted. They are flushed at the
  # indentation of the next non-empty line (or of the last one, at the end).
  blank_lines = 0

  last_indent = ""
  # Initialized so that a template consisting only of empty lines still has
  # an indentation to emit them at (otherwise: UnboundLocalError).
  python_indent = ""

  # Stack of (template_indent, python_indent) pairs, one per open block.
  indent_stack = [("", "")]

  # True when the current line is the first line inside a Python code block
  # (i.e. after for, while, if, elif, else, ...).
  python_block_start = True
  for line_number, input_line in enumerate(input_lines, 1):
    if input_line == "":
      blank_lines += 1
      continue

    input_indent = extract_leading_whitespace(input_line)
    if python_block_start:
      # The first line of a block fixes how much extra indentation the block
      # body uses relative to the line that opened it.
      if not input_indent.startswith(last_indent):
        raise ValueError(
            "%s:%d: block body is not indented under the line that opened it"
            % (input_path, line_number))
      extra_python_indent = input_indent[len(last_indent):]
      python_indent = indent_stack[-1][1] + extra_python_indent
      indent_stack.append((input_indent, python_indent))
    else:
      # Close every block whose indentation this line no longer shares. The
      # bottom entry has an empty indent, so it is never popped.
      while not input_indent.startswith(indent_stack[-1][0]):
        del indent_stack[-1]
    python_block_start = False

    python_indent = indent_stack[-1][1]
    stripped_input_line = input_line.strip()
    is_python_line = (stripped_input_line.startswith("$") and
                      not stripped_input_line.startswith("${"))
    if not is_python_line and not input_line.startswith(python_indent):
      raise ValueError(
          "%s:%d: indentation is inconsistent with the enclosing block"
          % (input_path, line_number))

    python_lines.extend(
        [python_indent + "print(file=OUT_STREAM)"] * blank_lines)
    blank_lines = 0

    if is_python_line:
      # NOTE(review): this removes every "$" on the line, not only the leading
      # marker; kept unchanged since existing templates may rely on it.
      python_lines.append(python_indent + stripped_input_line.replace("$", ""))
      python_block_start = stripped_input_line.endswith(":")
    else:
      # Drop the indentation consumed by enclosing Python blocks; the rest of
      # the line (including any further indentation) is output text.
      python_lines.append(
          python_indent + "print(%s, file=OUT_STREAM)"
          % escape(input_line[len(python_indent):]))
    last_indent = input_indent

  python_lines.extend([python_indent + "print(file=OUT_STREAM)"] * blank_lines)

  exec_globals = dict(input_globals)
  if sys.version_info > (3, 0):
    output_stream = io.StringIO()
  else:
    output_stream = io.BytesIO()
  exec_globals["OUT_STREAM"] = output_stream
  python_bytecode = compile("\n".join(python_lines), input_path, "exec")
  exec(python_bytecode, exec_globals)

  return output_stream.getvalue()
123
124
# Header placed at the top of every generated file; {template} and
# {generator} are str.format() placeholders.
PREAMBLE = (
    "// Auto-generated file. Do not edit!\n"
    "// Template: {template}\n"
    "// Generator: {generator}\n"
    "//\n"
)
131
132
def main(args):
  """Command-line entry point: expand a template and write the result.

  Args:
    args: command-line arguments, without the program name.

  The output file is rewritten only when its content would change, which
  leaves its modification time untouched otherwise (presumably to avoid
  needless rebuilds of dependents). When no -o/--output is given, the result
  is written to standard output as UTF-8 instead of failing.
  """
  options = parser.parse_args(args)
  input_path = options.input[0]

  # Read through a context manager so the input file is always closed.
  with codecs.open(input_path, "r", encoding="utf-8") as input_file:
    input_text = input_file.read()
  # Each -D occurrence contributes a list of (key, value) pairs; flatten them.
  python_globals = dict(chain(*options.defines))
  output_text = (PREAMBLE.format(template=input_path, generator=sys.argv[0]) +
                 preprocess(input_text, python_globals, input_path))

  if options.output is None:
    # Python 3 exposes the binary layer as sys.stdout.buffer; Python 2's
    # sys.stdout accepts byte strings directly.
    stdout = getattr(sys.stdout, "buffer", sys.stdout)
    if isinstance(output_text, bytes):
      stdout.write(output_text)
    else:
      stdout.write(output_text.encode("utf-8"))
    stdout.flush()
    return

  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != output_text

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(output_text)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700148
if __name__ == "__main__":
  # Hand everything after the program name to the entry point.
  main(args=sys.argv[1:])