blob: 2b1f73c0aff74581e6efef133f6ca090c6166191 [file] [log] [blame]
Tim Petersbe4f0a72001-06-29 02:41:16 +00001from __future__ import nested_scopes
2
# Doctest suite: introductory generator examples (run via the __test__ dict).
# NOTE(review): reconstructed from a blame-mangled view; doctest output is
# whitespace-sensitive, so confirm indentation against the pristine file.
tutorial_tests = """
Let's try a simple generator:

    >>> def f():
    ...    yield 1
    ...    yield 2

    >>> for i in f():
    ...     print i
    1
    2
    >>> g = f()
    >>> g.next()
    1
    >>> g.next()
    2

"Falling off the end" stops the generator:

    >>> g.next()
    Traceback (most recent call last):
      File "<stdin>", line 1, in ?
      File "<stdin>", line 2, in g
    StopIteration

"return" also stops the generator:

    >>> def f():
    ...     yield 1
    ...     return
    ...     yield 2 # never reached
    ...
    >>> g = f()
    >>> g.next()
    1
    >>> g.next()
    Traceback (most recent call last):
      File "<stdin>", line 1, in ?
      File "<stdin>", line 3, in f
    StopIteration
    >>> g.next() # once stopped, can't be resumed
    Traceback (most recent call last):
      File "<stdin>", line 1, in ?
    StopIteration

"raise StopIteration" stops the generator too:

    >>> def f():
    ...     yield 1
    ...     return
    ...     yield 2 # never reached
    ...
    >>> g = f()
    >>> g.next()
    1
    >>> g.next()
    Traceback (most recent call last):
      File "<stdin>", line 1, in ?
    StopIteration
    >>> g.next()
    Traceback (most recent call last):
      File "<stdin>", line 1, in ?
    StopIteration

However, they are not exactly equivalent:

    >>> def g1():
    ...     try:
    ...         return
    ...     except:
    ...         yield 1
    ...
    >>> list(g1())
    []

    >>> def g2():
    ...     try:
    ...         raise StopIteration
    ...     except:
    ...         yield 42
    >>> print list(g2())
    [42]

This may be surprising at first:

    >>> def g3():
    ...     try:
    ...         return
    ...     finally:
    ...         yield 1
    ...
    >>> list(g3())
    [1]

Let's create an alternate range() function implemented as a generator:

    >>> def yrange(n):
    ...     for i in range(n):
    ...         yield i
    ...
    >>> list(yrange(5))
    [0, 1, 2, 3, 4]

Generators always return to the most recent caller:

    >>> def creator():
    ...     r = yrange(5)
    ...     print "creator", r.next()
    ...     return r
    ...
    >>> def caller():
    ...     r = creator()
    ...     for i in r:
    ...             print "caller", i
    ...
    >>> caller()
    creator 0
    caller 1
    caller 2
    caller 3
    caller 4

Generators can call other generators:

    >>> def zrange(n):
    ...     for i in yrange(n):
    ...         yield i
    ...
    >>> list(zrange(5))
    [0, 1, 2, 3, 4]

"""
135
# The examples from PEP 255.
# Doctest suite: the worked examples from the "Simple Generators" PEP,
# including Guido's recursive/non-recursive binary-tree traversals.
# NOTE(review): reconstructed from a blame-mangled view; doctest output is
# whitespace-sensitive, so confirm indentation against the pristine file.

pep_tests = """

Specification: Return

    Note that return isn't always equivalent to raising StopIteration: the
    difference lies in how enclosing try/except constructs are treated.
    For example,

        >>> def f1():
        ...     try:
        ...         return
        ...     except:
        ...        yield 1
        >>> print list(f1())
        []

    because, as in any function, return simply exits, but

        >>> def f2():
        ...     try:
        ...         raise StopIteration
        ...     except:
        ...         yield 42
        >>> print list(f2())
        [42]

    because StopIteration is captured by a bare "except", as is any
    exception.

Specification: Generators and Exception Propagation

    >>> def f():
    ...     return 1/0
    >>> def g():
    ...     yield f()  # the zero division exception propagates
    ...     yield 42   # and we'll never get here
    >>> k = g()
    >>> k.next()
    Traceback (most recent call last):
      File "<stdin>", line 1, in ?
      File "<stdin>", line 2, in g
      File "<stdin>", line 2, in f
    ZeroDivisionError: integer division or modulo by zero
    >>> k.next()  # and the generator cannot be resumed
    Traceback (most recent call last):
      File "<stdin>", line 1, in ?
    StopIteration
    >>>

Specification: Try/Except/Finally

    >>> def f():
    ...     try:
    ...         yield 1
    ...         try:
    ...             yield 2
    ...             1/0
    ...             yield 3  # never get here
    ...         except ZeroDivisionError:
    ...             yield 4
    ...             yield 5
    ...             raise
    ...         except:
    ...             yield 6
    ...         yield 7     # the "raise" above stops this
    ...     except:
    ...         yield 8
    ...     yield 9
    ...     try:
    ...         x = 12
    ...     finally:
    ...         yield 10
    ...     yield 11
    >>> print list(f())
    [1, 2, 4, 5, 8, 9, 10, 11]
    >>>

Guido's binary tree example.

    >>> # A binary tree class.
    >>> class Tree:
    ...
    ...     def __init__(self, label, left=None, right=None):
    ...         self.label = label
    ...         self.left = left
    ...         self.right = right
    ...
    ...     def __repr__(self, level=0, indent="    "):
    ...         s = level*indent + `self.label`
    ...         if self.left:
    ...             s = s + "\\n" + self.left.__repr__(level+1, indent)
    ...         if self.right:
    ...             s = s + "\\n" + self.right.__repr__(level+1, indent)
    ...         return s
    ...
    ...     def __iter__(self):
    ...         return inorder(self)

    >>> # Create a Tree from a list.
    >>> def tree(list):
    ...     n = len(list)
    ...     if n == 0:
    ...         return []
    ...     i = n / 2
    ...     return Tree(list[i], tree(list[:i]), tree(list[i+1:]))

    >>> # Show it off: create a tree.
    >>> t = tree("ABCDEFGHIJKLMNOPQRSTUVWXYZ")

    >>> # A recursive generator that generates Tree leaves in in-order.
    >>> def inorder(t):
    ...     if t:
    ...         for x in inorder(t.left):
    ...             yield x
    ...         yield t.label
    ...         for x in inorder(t.right):
    ...             yield x

    >>> # Show it off: create a tree.
    ... t = tree("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    ... # Print the nodes of the tree in in-order.
    ... for x in t:
    ...     print x,
    A B C D E F G H I J K L M N O P Q R S T U V W X Y Z

    >>> # A non-recursive generator.
    >>> def inorder(node):
    ...     stack = []
    ...     while node:
    ...         while node.left:
    ...             stack.append(node)
    ...             node = node.left
    ...         yield node.label
    ...         while not node.right:
    ...             try:
    ...                 node = stack.pop()
    ...             except IndexError:
    ...                 return
    ...             yield node.label
    ...         node = node.right

    >>> # Exercise the non-recursive generator.
    >>> for x in t:
    ...     print x,
    A B C D E F G H I J K L M N O P Q R S T U V W X Y Z

"""
285
# Examples from Iterator-List and Python-Dev and c.l.py.
# Doctest suite: yield-vs-return semantics, re-entrancy errors, gcomb
# combinations, and introspection of generator/function/frame types.
# NOTE(review): reconstructed from a blame-mangled view; doctest output is
# whitespace-sensitive, so confirm indentation against the pristine file.

email_tests = """

The difference between yielding None and returning it.

>>> def g():
...     for i in range(3):
...         yield None
...     yield None
...     return
>>> list(g())
[None, None, None, None]

Ensure that explicitly raising StopIteration acts like any other exception
in try/except, not like a return.

>>> def g():
...     yield 1
...     try:
...         raise StopIteration
...     except:
...         yield 2
...     yield 3
>>> list(g())
[1, 2, 3]

A generator can't be resumed while it's already running.

>>> def g():
...     i = me.next()
...     yield i
>>> me = g()
>>> me.next()
Traceback (most recent call last):
 ...
 File "<string>", line 2, in g
ValueError: generator already executing

Next one was posted to c.l.py.

>>> def gcomb(x, k):
...     "Generate all combinations of k elements from list x."
...
...     if k > len(x):
...         return
...     if k == 0:
...         yield []
...     else:
...         first, rest = x[0], x[1:]
...         # A combination does or doesn't contain first.
...         # If it does, the remainder is a k-1 comb of rest.
...         for c in gcomb(rest, k-1):
...             c.insert(0, first)
...             yield c
...         # If it doesn't contain first, it's a k comb of rest.
...         for c in gcomb(rest, k):
...             yield c

>>> seq = range(1, 5)
>>> for k in range(len(seq) + 2):
...     print "%d-combs of %s:" % (k, seq)
...     for c in gcomb(seq, k):
...         print "   ", c
0-combs of [1, 2, 3, 4]:
    []
1-combs of [1, 2, 3, 4]:
    [1]
    [2]
    [3]
    [4]
2-combs of [1, 2, 3, 4]:
    [1, 2]
    [1, 3]
    [1, 4]
    [2, 3]
    [2, 4]
    [3, 4]
3-combs of [1, 2, 3, 4]:
    [1, 2, 3]
    [1, 2, 4]
    [1, 3, 4]
    [2, 3, 4]
4-combs of [1, 2, 3, 4]:
    [1, 2, 3, 4]
5-combs of [1, 2, 3, 4]:

From the Iterators list, about the types of these things.

>>> def g():
...     yield 1
...
>>> type(g)
<type 'function'>
>>> i = g()
>>> type(i)
<type 'generator'>
>>> dir(i)
['gi_frame', 'gi_running', 'next']
>>> print i.next.__doc__
next() -- get the next value, or raise StopIteration
>>> iter(i) is i
1
>>> import types
>>> isinstance(i, types.GeneratorType)
1

And more, added later.

>>> i.gi_running
0
>>> type(i.gi_frame)
<type 'frame'>
>>> i.gi_running = 42
Traceback (most recent call last):
  ...
TypeError: object has read-only attributes
>>> def g():
...     yield me.gi_running
>>> me = g()
>>> me.gi_running
0
>>> me.next()
1
>>> me.gi_running
0
"""
413
# Fun tests (for sufficiently warped notions of "fun").
# Doctest suite: a recursive Sieve of Eratosthenes, the 2**i*3**j*5**k
# ("Hamming numbers") problem, and LazyList-based sharing/Fibonacci.
# NOTE(review): reconstructed from a blame-mangled view; doctest output is
# whitespace-sensitive, so confirm indentation against the pristine file.

fun_tests = """

Build up to a recursive Sieve of Eratosthenes generator.

>>> def firstn(g, n):
...     return [g.next() for i in range(n)]

>>> def intsfrom(i):
...     while 1:
...         yield i
...         i += 1

>>> firstn(intsfrom(5), 7)
[5, 6, 7, 8, 9, 10, 11]

>>> def exclude_multiples(n, ints):
...     for i in ints:
...         if i % n:
...             yield i

>>> firstn(exclude_multiples(3, intsfrom(1)), 6)
[1, 2, 4, 5, 7, 8]

>>> def sieve(ints):
...     prime = ints.next()
...     yield prime
...     not_divisible_by_prime = exclude_multiples(prime, ints)
...     for p in sieve(not_divisible_by_prime):
...         yield p

>>> primes = sieve(intsfrom(2))
>>> firstn(primes, 20)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]


Another famous problem: generate all integers of the form
    2**i * 3**j * 5**k
in increasing order, where i,j,k >= 0. Trickier than it may look at first!
Try writing it without generators, and correctly, and without generating
3 internal results for each result output.

>>> def times(n, g):
...     for i in g:
...         yield n * i
>>> firstn(times(10, intsfrom(1)), 10)
[10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

>>> def merge(g, h):
...     ng = g.next()
...     nh = h.next()
...     while 1:
...         if ng < nh:
...             yield ng
...             ng = g.next()
...         elif ng > nh:
...             yield nh
...             nh = h.next()
...         else:
...             yield ng
...             ng = g.next()
...             nh = h.next()

The following works, but is doing a whale of a lot of redundant work --
it's not clear how to get the internal uses of m235 to share a single
generator. Note that me_times2 (etc) each need to see every element in the
result sequence. So this is an example where lazy lists are more natural
(you can look at the head of a lazy list any number of times).

>>> def m235():
...     yield 1
...     me_times2 = times(2, m235())
...     me_times3 = times(3, m235())
...     me_times5 = times(5, m235())
...     for i in merge(merge(me_times2,
...                          me_times3),
...                    me_times5):
...         yield i

Don't print "too many" of these -- the implementation above is extremely
inefficient: each call of m235() leads to 3 recursive calls, and in
turn each of those 3 more, and so on, and so on, until we've descended
enough levels to satisfy the print stmts. Very odd: when I printed 5
lines of results below, this managed to screw up Win98's malloc in "the
usual" way, i.e. the heap grew over 4Mb so Win98 started fragmenting
address space, and it *looked* like a very slow leak.

>>> result = m235()
>>> for i in range(3):
...     print firstn(result, 15)
[1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16, 18, 20, 24]
[25, 27, 30, 32, 36, 40, 45, 48, 50, 54, 60, 64, 72, 75, 80]
[81, 90, 96, 100, 108, 120, 125, 128, 135, 144, 150, 160, 162, 180, 192]

Heh. Here's one way to get a shared list, complete with an excruciating
namespace renaming trick. The *pretty* part is that the times() and merge()
functions can be reused as-is, because they only assume their stream
arguments are iterable -- a LazyList is the same as a generator to times().

>>> class LazyList:
...     def __init__(self, g):
...         self.sofar = []
...         self.fetch = g.next
...
...     def __getitem__(self, i):
...         sofar, fetch = self.sofar, self.fetch
...         while i >= len(sofar):
...             sofar.append(fetch())
...         return sofar[i]
...
...     def clear(self):
...         self.__dict__.clear()

>>> def m235():
...     yield 1
...     # Gack: m235 below actually refers to a LazyList.
...     me_times2 = times(2, m235)
...     me_times3 = times(3, m235)
...     me_times5 = times(5, m235)
...     for i in merge(merge(me_times2,
...                          me_times3),
...                    me_times5):
...         yield i

Print as many of these as you like -- *this* implementation is memory-
efficient. XXX Except that it leaks unless you clear the dict!

>>> m235 = LazyList(m235())
>>> for i in range(5):
...     print [m235[j] for j in range(15*i, 15*(i+1))]
[1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16, 18, 20, 24]
[25, 27, 30, 32, 36, 40, 45, 48, 50, 54, 60, 64, 72, 75, 80]
[81, 90, 96, 100, 108, 120, 125, 128, 135, 144, 150, 160, 162, 180, 192]
[200, 216, 225, 240, 243, 250, 256, 270, 288, 300, 320, 324, 360, 375, 384]
[400, 405, 432, 450, 480, 486, 500, 512, 540, 576, 600, 625, 640, 648, 675]

>>> m235.clear()  # XXX memory leak without this


Ye olde Fibonacci generator, LazyList style.

>>> def fibgen(a, b):
...
...     def sum(g, h):
...         while 1:
...             yield g.next() + h.next()
...
...     def tail(g):
...         g.next()  # throw first away
...         for x in g:
...             yield x
...
...     yield a
...     yield b
...     for s in sum(iter(fib),
...                  tail(iter(fib))):
...         yield s

>>> fib = LazyList(fibgen(1, 2))
>>> firstn(iter(fib), 17)
[1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584]

>>> fib.clear()  # XXX memory leak without this
"""
579
# syntax_tests mostly provokes SyntaxErrors. Also fiddling with #if 0
# hackery.
# Doctest suite: compile-time rules for 'return'/'yield' inside generators,
# including statically-dead ("if 0:") code that still affects them.
# NOTE(review): reconstructed from a blame-mangled view; doctest output is
# whitespace-sensitive, so confirm indentation against the pristine file.

syntax_tests = """

>>> def f():
...     return 22
...     yield 1
Traceback (most recent call last):
  ...
SyntaxError: 'return' with argument inside generator (<string>, line 2)

>>> def f():
...     yield 1
...     return 22
Traceback (most recent call last):
  ...
SyntaxError: 'return' with argument inside generator (<string>, line 3)

"return None" is not the same as "return" in a generator:

>>> def f():
...     yield 1
...     return None
Traceback (most recent call last):
  ...
SyntaxError: 'return' with argument inside generator (<string>, line 3)

This one is fine:

>>> def f():
...     yield 1
...     return

>>> def f():
...     try:
...         yield 1
...     finally:
...         pass
Traceback (most recent call last):
  ...
SyntaxError: 'yield' not allowed in a 'try' block with a 'finally' clause (<string>, line 3)

>>> def f():
...     try:
...         try:
...             1/0
...         except ZeroDivisionError:
...             yield 666  # bad because *outer* try has finally
...         except:
...             pass
...     finally:
...         pass
Traceback (most recent call last):
  ...
SyntaxError: 'yield' not allowed in a 'try' block with a 'finally' clause (<string>, line 6)

But this is fine:

>>> def f():
...     try:
...         try:
...             yield 12
...             1/0
...         except ZeroDivisionError:
...             yield 666
...         except:
...             try:
...                 x = 12
...             finally:
...                 yield 12
...     except:
...         return
>>> list(f())
[12, 666]

>>> def f():
...    yield
Traceback (most recent call last):
SyntaxError: invalid syntax

>>> def f():
...    if 0:
...        yield
Traceback (most recent call last):
SyntaxError: invalid syntax

>>> def f():
...     if 0:
...         yield 1
>>> type(f())
<type 'generator'>

>>> def f():
...    if "":
...        yield None
>>> type(f())
<type 'generator'>

>>> def f():
...     return
...     try:
...         if x==4:
...             pass
...         elif 0:
...             try:
...                 1/0
...             except SyntaxError:
...                 pass
...             else:
...                 if 0:
...                     while 12:
...                         x += 1
...                         yield 2 # don't blink
...                         f(a, b, c, d, e)
...         else:
...             pass
...     except:
...         x = 1
...     return
>>> type(f())
<type 'generator'>

>>> def f():
...     if 0:
...         def g():
...             yield 1
...
>>> type(f())
<type 'None'>

>>> def f():
...     if 0:
...         class C:
...             def __init__(self):
...                 yield 1
...             def f(self):
...                 yield 2
>>> type(f())
<type 'None'>

>>> def f():
...     if 0:
...         return
...     if 0:
...         yield 2
>>> type(f())
<type 'generator'>


>>> def f():
...     if 0:
...         lambda x:  x  # shouldn't trigger here
...         return        # or here
...         def f(i):
...             return 2*i  # or here
...         if 0:
...             return 3    # but *this* sucks (line 8)
...     if 0:
...         yield 2  # because it's a generator
Traceback (most recent call last):
SyntaxError: 'return' with argument inside generator (<string>, line 8)
"""
743
# conjoin is a simple backtracking generator, named in honor of Icon's
# "conjunction" control structure.  Pass a list of no-argument functions
# that return iterable objects.  Easiest to explain by example: assume the
# function list [x, y, z] is passed.  Then conjoin acts like:
#
#     def g():
#         values = [None] * 3
#         for values[0] in x():
#             for values[1] in y():
#                 for values[2] in z():
#                     yield values
#
# So some 3-lists of values *may* be generated, each time we successfully
# get into the innermost loop.  If an iterator fails (is exhausted) before
# then, it "backtracks" to get the next value from the nearest enclosing
# iterator (the one "to the left"), and starts all over again at the next
# slot (pumps a fresh iterator).  Of course this is most useful when the
# iterators have side-effects, so that which values *can* be generated at
# each slot depend on the values iterated at previous slots.

def conjoin(gs):
    # The single `slots` list is reused and yielded every time; a caller
    # that wants to keep a result must copy it before advancing.
    total = len(gs)
    slots = [None] * total

    def backtrack(pos):
        # Bind slot `pos` from a fresh iterator, recursing to the right;
        # when every slot is bound, emit the shared list.
        if pos >= total:
            yield slots
        else:
            for slots[pos] in gs[pos]():
                for result in backtrack(pos + 1):
                    yield result

    for result in backtrack(0):
        yield result
778
# That works fine, but recursing a level and checking i against len(gs) for
# each item produced is inefficient.  By doing manual loop unrolling across
# generator boundaries, it's possible to eliminate most of that overhead.
# This isn't worth the bother *in general* for generators, but conjoin() is
# a core building block for some CPU-intensive generator applications.

def conjoin(gs):
    # Same contract as the simple conjoin: yields one shared list whose
    # slots are rebound in odometer order; copy it if you keep it.
    total = len(gs)
    slots = [None] * total

    def peel(pos):
        # Peel off one loop nest per recursion until the number of nests
        # still to run is a multiple of 3, then hand off to _by_threes.
        if pos >= total:
            yield slots
        elif (total - pos) % 3:
            nxt = pos + 1
            for slots[pos] in gs[pos]():
                for out in peel(nxt):
                    yield out
        else:
            for out in _by_threes(pos):
                yield out

    def _by_threes(pos):
        # Run three loop nests at a time, recursing only when at least
        # three more remain.  Internal helper for peel()'s use only.
        assert pos < total and (total - pos) % 3 == 0
        p1, p2, p3 = pos + 1, pos + 2, pos + 3
        f0, f1, f2 = gs[pos:p3]

        if p3 >= total:
            # The last three nests -- yield the finished slots directly.
            for slots[pos] in f0():
                for slots[p1] in f1():
                    for slots[p2] in f2():
                        yield slots
        else:
            # Six or more nests remain: unroll three, recurse for the rest.
            for slots[pos] in f0():
                for slots[p1] in f1():
                    for slots[p2] in f2():
                        for out in _by_threes(p3):
                            yield out

    for out in peel(0):
        yield out
834
# A conjoin-based N-Queens solver.

class Queens:
    """N-Queens solver built on conjoin(): one backtracking generator per
    board row, each yielding the columns still legal for that row."""

    def __init__(self, n):
        self.n = n
        rangen = range(n)

        # Assign a unique int to each column and diagonal.
        # columns: n of those, range(n).
        # NW-SE diagonals: 2n-1 of these, i-j unique and invariant along
        # each, smallest i-j is 0-(n-1) = 1-n, so add n-1 to shift to 0-
        # based.
        # NE-SW diagonals: 2n-1 of these, i+j unique and invariant along
        # each, smallest i+j is 0, largest is 2n-2.

        # For each square, compute a bit vector of the columns and
        # diagonals it covers, and for each row compute a function that
        # generates the possibilities for the columns in that row.
        self.rowgenerators = []
        for i in rangen:
            rowuses = [(1L << j) |                  # column ordinal
                       (1L << (n + i-j + n-1)) |    # NW-SE ordinal
                       (1L << (n + 2*n-1 + i+j))    # NE-SW ordinal
                       for j in rangen]

            def rowgen(rowuses=rowuses):
                # Yield each column j whose covered squares are all free;
                # self.used is marked while the choice is "live" and
                # unmarked again on backtrack (after the yield resumes).
                for j in rangen:
                    uses = rowuses[j]
                    if uses & self.used == 0:
                        self.used |= uses
                        yield j
                        self.used &= ~uses

            self.rowgenerators.append(rowgen)

    # Generate solutions.
    def solve(self):
        # Each yielded row2col is conjoin's shared list: row index -> the
        # column holding that row's queen.
        self.used = 0
        for row2col in conjoin(self.rowgenerators):
            yield row2col

    def printsolution(self, row2col):
        # Print one solution as an ASCII chessboard.
        n = self.n
        assert n == len(row2col)
        sep = "+" + "-+" * n
        print sep
        for i in range(n):
            squares = [" " for j in range(n)]
            squares[row2col[i]] = "Q"
            print "|" + "|".join(squares) + "|"
        print sep
886
# Doctest suite exercising conjoin() and the Queens solver above.
# NOTE(review): reconstructed from a blame-mangled view; doctest output is
# whitespace-sensitive, so confirm indentation against the pristine file.
conjoin_tests = """

Generate the 3-bit binary numbers in order. This illustrates dumbest-
possible use of conjoin, just to generate the full cross-product.

>>> for c in conjoin([lambda: (0, 1)] * 3):
...     print c
[0, 0, 0]
[0, 0, 1]
[0, 1, 0]
[0, 1, 1]
[1, 0, 0]
[1, 0, 1]
[1, 1, 0]
[1, 1, 1]

For efficiency in typical backtracking apps, conjoin() yields the same list
object each time. So if you want to save away a full account of its
generated sequence, you need to copy its results.

>>> def gencopy(iterator):
...     for x in iterator:
...         yield x[:]

>>> for n in range(10):
...     all = list(gencopy(conjoin([lambda: (0, 1)] * n)))
...     print n, len(all), all[0] == [0] * n, all[-1] == [1] * n
0 1 1 1
1 2 1 1
2 4 1 1
3 8 1 1
4 16 1 1
5 32 1 1
6 64 1 1
7 128 1 1
8 256 1 1
9 512 1 1

And run an 8-queens solver.

>>> q = Queens(8)
>>> LIMIT = 2
>>> count = 0
>>> for row2col in q.solve():
...     count += 1
...     if count <= LIMIT:
...         print "Solution", count
...         q.printsolution(row2col)
Solution 1
+-+-+-+-+-+-+-+-+
|Q| | | | | | | |
+-+-+-+-+-+-+-+-+
| | | | |Q| | | |
+-+-+-+-+-+-+-+-+
| | | | | | | |Q|
+-+-+-+-+-+-+-+-+
| | | | | |Q| | |
+-+-+-+-+-+-+-+-+
| | |Q| | | | | |
+-+-+-+-+-+-+-+-+
| | | | | | |Q| |
+-+-+-+-+-+-+-+-+
| |Q| | | | | | |
+-+-+-+-+-+-+-+-+
| | | |Q| | | | |
+-+-+-+-+-+-+-+-+
Solution 2
+-+-+-+-+-+-+-+-+
|Q| | | | | | | |
+-+-+-+-+-+-+-+-+
| | | | | |Q| | |
+-+-+-+-+-+-+-+-+
| | | | | | | |Q|
+-+-+-+-+-+-+-+-+
| | |Q| | | | | |
+-+-+-+-+-+-+-+-+
| | | | | | |Q| |
+-+-+-+-+-+-+-+-+
| | | |Q| | | | |
+-+-+-+-+-+-+-+-+
| |Q| | | | | | |
+-+-+-+-+-+-+-+-+
| | | | |Q| | | |
+-+-+-+-+-+-+-+-+

>>> print count, "solutions in all."
92 solutions in all.
"""
975
# doctest scans this module-level mapping; each value is an independent
# doctest suite, keyed by the short name shown in failure reports.
__test__ = {"tut": tutorial_tests,
            "pep": pep_tests,
            "email": email_tests,
            "fun": fun_tests,
            "syntax": syntax_tests,
            "conjoin": conjoin_tests}
Tim Peters1def3512001-06-23 20:27:04 +0000982
# Magic test name that regrtest.py invokes *after* importing this module.
# This worms around a bootstrap problem.
# Note that doctest and regrtest both look in sys.argv for a "-v" argument,
# so this works as expected in both ways of running regrtest.
def test_main():
    """Run all the doctest suites in __test__ (regrtest's entry point)."""
    import doctest, test_generators
    # Leak-hunting scaffold: set to 1 to rerun the suite many times.
    # Historically most leaks were doctest's; the remaining ones are in
    # fun_tests, and only if the two LazyList.clear() calls are removed.
    hunt_leaks = 0
    if hunt_leaks:
        for i in range(10000):
            doctest.master = None
            doctest.testmod(test_generators)
    else:
        doctest.testmod(test_generators)
Tim Peters1def3512001-06-23 20:27:04 +0000999
# This part isn't needed for regrtest, but for running the test directly.
if __name__ == "__main__":
    test_main()