blob: 01e848e03976deb78657be7a1ebf9187fec4fe03 [file] [log] [blame]
Brian Gesiak5e0a9462017-06-29 18:56:25 +00001import sys
2import multiprocessing
3
4
# Shared progress counters for pmap(). They start unset; pmap() binds them in
# the calling process and _init() binds them inside each pool worker.
_current, _total = None, None
7
8
def _init(current, total):
    """Bind the shared progress counters as this module's globals.

    pmap() passes this function as the pool `initializer`, with the counters
    as `initargs`, so every worker process sees the same `_current` and
    `_total` objects that the parent created.
    """
    global _current, _total
    _current, _total = current, total
14
15
def _wrapped_func(func_and_args):
    """Run one pmap() task and, if requested, advance the progress line.

    `func_and_args` is a ``(func, argument, should_print_progress)`` tuple as
    built by pmap(). Returns ``func(argument)``.

    The shared counter is bumped only after ``func`` has returned, so the
    printed 'N of M' counts tasks that have finished (as pmap() documents),
    not tasks that have merely started.
    """
    func, argument, should_print_progress = func_and_args

    result = func(argument)

    if should_print_progress:
        # Capture the incremented value while still holding the lock so two
        # concurrent workers can never print the same count.
        with _current.get_lock():
            _current.value += 1
            done = _current.value
        sys.stdout.write('\r\t{} of {}'.format(done, _total.value))
        # No newline is written, so flush explicitly; otherwise the progress
        # text can sit in the buffer until the process exits.
        sys.stdout.flush()

    return result
25
26
def pmap(func, iterable, processes, should_print_progress, *args, **kwargs):
    """
    A parallel map function that reports on its progress.

    Applies `func` to every item of `iterable` and returns a list of the
    results, in input order. If `processes` is greater than one, a process
    pool is used to run the functions in parallel; otherwise they run
    serially in the current process. `should_print_progress` is a boolean
    value that indicates whether a string 'N of M' should be printed to
    indicate how many of the functions have finished being run.

    `iterable` may be any iterable, including a generator; it is consumed
    once up front. Extra `args`/`kwargs` are forwarded to
    `multiprocessing.Pool.map` (e.g. `chunksize`). They have no meaning for
    the serial path and are ignored there.
    """
    global _current
    global _total

    # Materialize first so the total can be computed even for generators and
    # other unsized iterables.
    func_and_args = [(func, arg, should_print_progress) for arg in iterable]

    _current = multiprocessing.Value('i', 0)
    _total = multiprocessing.Value('i', len(func_and_args))

    if processes <= 1:
        # list() forces evaluation here: on Python 3 map() is lazy. Without it
        # we would return an iterator and run all the work (and print all the
        # progress) after the trailing '\r' below.
        result = list(map(_wrapped_func, func_and_args))
    else:
        pool = multiprocessing.Pool(initializer=_init,
                                    initargs=(_current, _total),
                                    processes=processes)
        try:
            result = pool.map(_wrapped_func, func_and_args, *args, **kwargs)
            # Normal completion: let idle workers exit gracefully.
            pool.close()
        except BaseException:
            # A task failed or we were interrupted: stop outstanding work.
            pool.terminate()
            raise
        finally:
            # Reap the worker processes instead of leaking them.
            pool.join()

    if should_print_progress:
        sys.stdout.write('\r')
        sys.stdout.flush()
    return result
53 return result