#!/usr/bin/env python2.7

3"""A test case update script.
4
5This script is a utility to update LLVM 'llvm-mca' based test cases with new
6FileCheck patterns.
7"""
8
9import argparse
10from collections import defaultdict
11import difflib
12import glob
13import os
14import sys
15import warnings
16
17from UpdateTestChecks import common
18
19
# Every generated FileCheck directive is emitted behind this comment leader.
COMMENT_CHAR = '#'
# Any line beginning with this prefix marks the file as autogenerated by an
# update script; such lines are stripped and re-emitted on each update.
ADVERT_PREFIX = COMMENT_CHAR + ' NOTE: Assertions have been autogenerated by '
# Full advert line, naming this script as the generator.
ADVERT = '{}utils/{}'.format(ADVERT_PREFIX, os.path.basename(__file__))


class Error(Exception):
  """Fatal, user-facing error.

  Caught at the top level so only the message is printed, without a
  traceback.
  """


def _warn(msg):
  """Emit *msg* through the warnings machinery (normally shown on stderr)."""
  # stacklevel=2 attributes the warning to _warn's caller rather than to
  # this helper, which makes the reported source location useful.
  warnings.warn(msg, Warning, stacklevel=2)


def _configure_warnings(args):
  """Install the warning filters requested on the command line.

  Clears any previously-installed filters, then applies -w (silence all
  warnings) and/or -Werror (promote warnings to errors).
  """
  warnings.resetwarnings()
  if args.w:
    warnings.simplefilter('ignore')
  if args.Werror:
    # Installed after 'ignore' and simplefilter prepends, so when both
    # flags are given the 'error' filter matches first and -Werror wins.
    warnings.simplefilter('error')


def _showwarning(message, category, filename, lineno, file=None, line=None):
  """Drop-in replacement for warnings.showwarning.

  Unlike the default implementation, it never tries to read the source
  file to echo the offending line: when no line text is supplied it
  substitutes an empty string.
  """
  dest = sys.stderr if file is None else file
  text = '' if line is None else line
  dest.write(warnings.formatwarning(message, category, filename, lineno, text))


def _parse_args():
  """Parse and validate the command line.

  Also installs the warning filters implied by -w / -Werror, and warns if
  the supplied tool binary is not actually named 'llvm-mca'.

  Returns the populated argparse namespace.
  """
  parser = argparse.ArgumentParser(description=__doc__)
  parser.add_argument(
      '-v', '--verbose', action='store_true', help='show verbose output')
  parser.add_argument('-w', action='store_true', help='suppress warnings')
  parser.add_argument(
      '-Werror', action='store_true', help='promote warnings to errors')
  parser.add_argument(
      '--llvm-mca-binary',
      metavar='<path>',
      default='llvm-mca',
      help='the binary to use to generate the test case (default: llvm-mca)')
  parser.add_argument('tests', metavar='<test-path>', nargs='+')
  args = parser.parse_args()

  _configure_warnings(args)

  # A differently-named binary is almost certainly not llvm-mca; tell the
  # user, but carry on regardless.
  if os.path.basename(args.llvm_mca_binary) != 'llvm-mca':
    _warn('unexpected binary name: {}'.format(args.llvm_mca_binary))

  return args


def _find_run_lines(input_lines, args):
  """Return the RUN commands found in input_lines, merging continuations.

  A RUN entry whose accumulated text ends in r'\\' is treated as continued
  and joined with the following RUN entry.
  """
  matches = (common.RUN_LINE_RE.match(l) for l in input_lines)
  raw_lines = [m.group(1) for m in matches if m]

  run_lines = []
  for raw in raw_lines:
    # NOTE(review): r'\\' is *two* backslash characters, so a line ending in
    # a single '\' is not treated as a continuation here -- confirm this
    # matches the RUN-line style of the tests being updated.
    if run_lines and run_lines[-1].endswith(r'\\'):
      run_lines[-1] = run_lines[-1].rstrip('\\') + ' ' + raw
    else:
      run_lines.append(raw)

  if args.verbose:
    sys.stderr.write('Found {} RUN line{}:\n'.format(
        len(run_lines), '' if len(run_lines) == 1 else 's'))
    for line in run_lines:
      sys.stderr.write('  RUN: {}\n'.format(line))

  return run_lines


def _get_run_infos(run_lines, args):
  """Convert RUN lines into (check_prefixes, tool_cmd_args) tuples.

  RUN lines that cannot be split into '<tool> | FileCheck' form, or that
  do not invoke the configured llvm-mca binary followed by FileCheck, are
  skipped with a warning.
  """
  tool_basename = os.path.basename(args.llvm_mca_binary)

  run_infos = []
  for run_line in run_lines:
    parts = [cmd.strip() for cmd in run_line.split('|', 1)]
    if len(parts) != 2:
      _warn('could not split tool and filecheck commands: {}'.format(run_line))
      continue
    tool_cmd, filecheck_cmd = parts

    if not tool_cmd.startswith(tool_basename + ' '):
      _warn('skipping non-{} RUN line: {}'.format(tool_basename, run_line))
      continue

    if not filecheck_cmd.startswith('FileCheck '):
      _warn('skipping non-FileCheck RUN line: {}'.format(run_line))
      continue

    # Drop the tool name and any input-file placeholders; what remains are
    # the arguments that select the tool's behaviour for this RUN line.
    tool_cmd_args = tool_cmd[len(tool_basename):].strip()
    tool_cmd_args = tool_cmd_args.replace('< %s', '').replace('%s', '').strip()

    # Collect every FileCheck prefix; default to CHECK when none are given.
    check_prefixes = [prefix
                      for m in common.CHECK_PREFIX_RE.finditer(filecheck_cmd)
                      for prefix in m.group(1).split(',')]
    if not check_prefixes:
      check_prefixes = ['CHECK']

    run_infos.append((check_prefixes, tool_cmd_args))

  return run_infos


def _get_block_infos(run_infos, test_path, args):  # noqa
  """ For each run line, run the tool with the specified args and collect the
      output. We use the concept of 'blocks' for uniquing, where a block is
      a series of lines of text with no more than one newline character between
      each one. For example:

      This
      is
      one
      block

      This is
      another block

      This is yet another block

      We then build up a 'block_infos' structure containing a dict where the
      text of each block is the key and a list of the sets of prefixes that may
      generate that particular block. This then goes through a series of
      transformations to minimise the amount of CHECK lines that need to be
      written by taking advantage of common prefixes.

      Returns: dict of block number -> (dict of block text -> sorted list of
      the prefixes under which that text should be emitted).
  """

  def _block_key(tool_args, prefixes):
    """ Get a hashable key based on the current tool_args and prefixes.
    """
    return ' '.join([tool_args] + prefixes)

  all_blocks = {}
  max_block_len = 0

  # Run the tool for each run line to generate all of the blocks.
  for prefixes, tool_args in run_infos:
    key = _block_key(tool_args, prefixes)
    raw_tool_output = common.invoke_tool(args.llvm_mca_binary,
                                         tool_args,
                                         test_path)

    # Replace any lines consisting of purely whitespace with empty lines.
    raw_tool_output = '\n'.join(line if line.strip() else ''
                                for line in raw_tool_output.splitlines())

    # Split blocks, stripping all trailing whitespace, but keeping preceding
    # whitespace except for newlines so that columns will line up visually.
    all_blocks[key] = [b.lstrip('\n').rstrip()
                       for b in raw_tool_output.split('\n\n')]
    max_block_len = max(max_block_len, len(all_blocks[key]))

  # If necessary, pad the lists of blocks with empty blocks so that they are
  # all the same length.
  for key in all_blocks:
    len_to_pad = max_block_len - len(all_blocks[key])
    all_blocks[key] += [''] * len_to_pad

  # Create the block_infos structure where it is a nested dict in the form of:
  #   block number -> block text -> list of prefix sets
  block_infos = defaultdict(lambda: defaultdict(list))
  for prefixes, tool_args in run_infos:
    key = _block_key(tool_args, prefixes)
    for block_num, block_text in enumerate(all_blocks[key]):
      block_infos[block_num][block_text].append(set(prefixes))

  # Now go through the block_infos structure and attempt to smartly prune the
  # number of prefixes per block to the minimal set possible to output.
  # (block_infos keys are the contiguous ints 0..max_block_len-1 coming from
  # enumerate above, so iterating range(len(...)) visits every block.)
  for block_num in range(len(block_infos)):

    # When there are multiple block texts for a block num, remove any
    # prefixes that are common to more than one of them.
    # E.g. [ [{ALL,FOO}] , [{ALL,BAR}] ] -> [ [{FOO}] , [{BAR}] ]
    all_sets = [s for s in block_infos[block_num].values()]
    pruned_sets = []

    for i, setlist in enumerate(all_sets):
      # Union of every prefix that appears under any *other* block text; a
      # prefix in that union cannot uniquely identify this text.
      other_set_values = set([elem for j, setlist2 in enumerate(all_sets)
                              for set_ in setlist2 for elem in set_
                              if i != j])
      pruned_sets.append([s - other_set_values for s in setlist])

    # NOTE: this relies on dict iteration order matching .values() order
    # above, so pruned_sets[i] corresponds to the i-th block text.
    for i, block_text in enumerate(block_infos[block_num]):

      # When a block text matches multiple sets of prefixes, try removing any
      # prefixes that aren't common to all of them.
      # E.g. [ {ALL,FOO} , {ALL,BAR} ] -> [{ALL}]
      common_values = pruned_sets[i][0].copy()
      for s in pruned_sets[i][1:]:
        common_values &= s
      if common_values:
        pruned_sets[i] = [common_values]

      # Everything should be uniqued as much as possible by now. Apply the
      # newly pruned sets to the block_infos structure.
      # If there are any blocks of text that still match multiple prefixes,
      # output a warning.
      current_set = set()
      for s in pruned_sets[i]:
        # Keep only the alphabetically-first prefix of each surviving set.
        s = sorted(list(s))
        if s:
          current_set.add(s[0])
          if len(s) > 1:
            _warn('Multiple prefixes generating same output: {} '
                  '(discarding {})'.format(','.join(s), ','.join(s[1:])))

      block_infos[block_num][block_text] = sorted(list(current_set))

  return block_infos


def _write_output(test_path, input_lines, prefix_list, block_infos,  # noqa
                  args):
  """ Rewrite the test file at test_path: keep the original non-CHECK body,
      strip stale autogenerated CHECK lines, and append fresh CHECK /
      CHECK-NEXT lines derived from block_infos.

      prefix_list is the run_infos list of (prefixes, tool_cmd_args) tuples;
      only its prefixes are used here. A progress summary is written to
      stderr; the file is only rewritten when the content actually changed.
  """
  prefix_set = set([prefix for prefixes, _ in prefix_list
                    for prefix in prefixes])
  # Prefixes for which the input contains an explicit '<PREFIX>-NOT:' line;
  # no new checks are generated for these.
  not_prefix_set = set()

  output_lines = []
  for input_line in input_lines:
    # Drop any previous autogeneration advert; it is re-inserted below.
    if input_line.startswith(ADVERT_PREFIX):
      continue

    if input_line.startswith(COMMENT_CHAR):
      m = common.CHECK_RE.match(input_line)
      try:
        prefix = m.group(1)
      except AttributeError:
        # Not a CHECK-style comment line (m is None).
        prefix = None

      if '{}-NOT:'.format(prefix) in input_line:
        not_prefix_set.add(prefix)

      # Keep comment lines that are not ours to regenerate: either their
      # prefix is not managed by this run, or the user opted out via -NOT.
      if prefix not in prefix_set or prefix in not_prefix_set:
        output_lines.append(input_line)
        continue

    if common.should_add_line_to_output(input_line, prefix_set):
      # This input line of the function body will go as-is into the output.
      # Except make leading whitespace uniform: 2 spaces.
      input_line = common.SCRUB_LEADING_WHITESPACE_RE.sub(r'  ', input_line)

      # Skip empty lines if the previous output line is also empty.
      # NOTE(review): output_lines[-1] raises IndexError when output_lines
      # is still empty at this point -- presumably every test starts with a
      # kept RUN line, but confirm.
      if input_line or output_lines[-1]:
        output_lines.append(input_line)
    else:
      continue

  # Add a blank line before the new checks if required.
  if output_lines[-1]:
    output_lines.append('')

  output_check_lines = []
  for block_num in range(len(block_infos)):
    # Emit block texts in sorted order so regeneration is deterministic.
    for block_text in sorted(block_infos[block_num]):
      if not block_text:
        continue

      if block_infos[block_num][block_text]:
        lines = block_text.split('\n')
        for prefix in block_infos[block_num][block_text]:
          if prefix in not_prefix_set:
            _warn('not writing for prefix {0} due to presence of "{0}-NOT:" '
                  'in input file.'.format(prefix))
            continue

          # First line of the block gets 'PREFIX:', the rest 'PREFIX-NEXT:'.
          output_check_lines.append(
              '{} {}: {}'.format(COMMENT_CHAR, prefix, lines[0]).rstrip())
          for line in lines[1:]:
            output_check_lines.append(
                '{} {}-NEXT: {}'.format(COMMENT_CHAR, prefix, line).rstrip())
          output_check_lines.append('')

  if output_check_lines:
    # Only advertise autogeneration when we actually emitted checks.
    output_lines.insert(0, ADVERT)
    output_lines.extend(output_check_lines)

  if input_lines == output_lines:
    sys.stderr.write('  [unchanged]\n')
    return

  # Report how much the file changed before rewriting it.
  diff = list(difflib.Differ().compare(input_lines, output_lines))
  sys.stderr.write(
      '  [{} lines total ({} added, {} removed)]\n'.format(
          len(output_lines),
          len([l for l in diff if l[0] == '+']),
          len([l for l in diff if l[0] == '-'])))

  if args.verbose:
    sys.stderr.write(
        'Writing {} lines to {}...\n\n'.format(len(output_lines), test_path))

  # Binary mode + explicit encode keeps the script py2/py3 compatible and
  # normalizes trailing whitespace on every line.
  with open(test_path, 'wb') as f:
    for line in output_lines:
      f.write('{}\n'.format(line.rstrip()).encode())


def main():
  """Entry point: update every test file matched by the given glob patterns.

  Returns 0 on success. Raises Error when a matched path is not a file.
  """
  args = _parse_args()
  for pattern in args.tests:
    for test_path in glob.glob(pattern):
      sys.stderr.write('Test: {}\n'.format(test_path))

      # Re-install the warning filters for every test. By default a warning
      # is only shown once per source location; resetting the filters makes
      # it once per source location *per test* instead.
      _configure_warnings(args)

      if args.verbose:
        sys.stderr.write(
            'Scanning for RUN lines in test file: {}\n'.format(test_path))

      if not os.path.isfile(test_path):
        raise Error('could not find test file: {}'.format(test_path))

      with open(test_path) as f:
        input_lines = [l.rstrip() for l in f]

      run_lines = _find_run_lines(input_lines, args)
      run_infos = _get_run_infos(run_lines, args)
      block_infos = _get_block_infos(run_infos, test_path, args)
      _write_output(test_path, input_lines, run_infos, block_infos, args)

  return 0


if __name__ == '__main__':
  try:
    # Use our showwarning so warnings never try to read source lines.
    warnings.showwarning = _showwarning
    sys.exit(main())
  except Error as e:
    # Print fatal errors without a traceback.
    # NOTE(review): the error goes to stdout while all other diagnostics in
    # this script go to stderr -- confirm this is intentional.
    sys.stdout.write('error: {}\n'.format(e))
    sys.exit(1)