blob: 47e47026db0862fb6d27abf7a5234d5235107562 [file] [log] [blame]
"""Utilities for enumeration of finite and countably infinite sets.
"""
###
# Countable iteration

# Simplifies some calculations
class Aleph0(int):
    """Singleton standing in for the size of a countably infinite set.

    The class derives from int so the value can be mixed with ordinary
    integers in bound arithmetic:

    * it orders above whatever it is compared with (itself included,
      exactly as the original __cmp__ did);
    * addition, multiplication by a non-zero value, division by a
      non-zero value and raising to a non-zero power all give back the
      singleton;
    * multiplication by zero gives zero, raising to the power zero gives
      one, and division by zero raises ZeroDivisionError;
    * subtraction in either direction raises ValueError.
    """
    _singleton = None

    def __new__(cls):
        # Always hand out the same object so callers can test `is aleph0`.
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self): return '<aleph0>'
    def __str__(self): return 'inf'

    # Python 2 consults __cmp__.  Python 3 ignores it and would fall back
    # to int's comparisons (where this object behaves like 0), so the rich
    # comparisons below spell out the same "always greater" ordering.
    def __cmp__(self, b):
        return 1

    def __lt__(self, b): return False
    def __le__(self, b): return False
    def __gt__(self, b): return True
    def __ge__(self, b): return True
    def __eq__(self, b): return False
    def __ne__(self, b): return True

    # Defining __eq__ disables the inherited hash on Python 3; restore it.
    __hash__ = int.__hash__

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")
    __rsub__ = __sub__

    def __add__(self, b):
        return self
    __radd__ = __add__

    def __mul__(self, b):
        if b == 0: return b
        return self
    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0: raise ZeroDivisionError
        return self
    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    __rtruediv__ = __floordiv__
    # Python 2 classic division.
    __div__ = __floordiv__
    __rdiv__ = __floordiv__

    def __pow__(self, b):
        if b == 0: return 1
        return self
aleph0 = Aleph0()
45
def base(line):
    """Return line*(line+1)//2, the index at which diagonal `line` starts."""
    doubled = line * (line + 1)
    return doubled // 2
48
def pairToN(pair):
    """pairToN((x, y)) -> N

    Return the index of the pair (x, y) in the diagonal enumeration: the
    pair lies on diagonal x+y, at position y along it.

    The pair is unpacked inside the body because tuple parameters in the
    signature are a syntax error on Python 3; callers still pass a single
    two-element sequence.
    """
    x,y = pair
    line,index = x+y,y
    return base(line)+index
52
def getNthPairInfo(N):
    """getNthPairInfo(N) -> (line, index)

    Locate N in the diagonal enumeration.  For N >= 1, `line` is the
    largest value with base(line) <= N and `index` is N - base(line).
    N == 0 is special-cased to (0, 0).
    """
    if N == 0:
        return (0, 0)

    # Keep doubling until base() of the next candidate overshoots N.
    lower, probe = 1, 2
    while base(probe) <= N:
        lower, probe = probe, probe << 1

    # Bisect [lower, 2*lower), maintaining base(lower) <= N < base(upper).
    upper = lower << 1
    while upper - lower != 1:
        middle = (lower + upper) >> 1
        if N < base(middle):
            upper = middle
        else:
            lower = middle

    return lower, N - base(lower)
78
def getNthPair(N):
    """getNthPair(N) -> (x, y)

    Return the N-th pair of the diagonal enumeration: with (line, index)
    taken from getNthPairInfo(N), the pair is (line - index, index).
    """
    diagonal, offset = getNthPairInfo(N)
    x = diagonal - offset
    return (x, offset)
82
def getNthPairBounded(N,W=aleph0,H=aleph0,useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    W and H may each be a positive integer or aleph0.  When both are
    aleph0 the result is simply getNthPair(N).  Otherwise, with
    useDivmod the pairs are numbered with divmod along the smaller
    bound; without it they are numbered along diagonals swept across the
    rectangle.

    Raises ValueError if a bound is not positive or if N is outside
    0 <= N < W*H.
    """

    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    elif N < 0 or N >= W*H:
        # Negative N used to slip through and yield negative coordinates.
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W < H
    if H < W:
        x,y = getNthPairBounded(N,H,W,useDivmod=useDivmod)
        return y,x

    if useDivmod:
        return N%W,N//W
    else:
        # Conceptually we want to slide a diagonal line across a
        # rectangle. This gives more interesting results for large
        # bounds than using divmod.

        # If in lower left, just return as usual
        cornerSize = base(W)
        if N < cornerSize:
            return getNthPair(N)

        # Otherwise if in upper right, subtract from corner
        if H is not aleph0:
            # M counts backwards from the last cell of the rectangle.
            M = W*H - N - 1
            if M < cornerSize:
                x,y = getNthPair(M)
                return (W-1-x,H-1-y)

        # Otherwise, compile line and index from number of times we
        # wrap.
        N = N - cornerSize
        index,offset = N%W,N//W
        # p = (W-1, 1+offset) + (-1,1)*index
        return (W-1-index, 1+offset+index)
def getNthPairBoundedChecked(N,W=aleph0,H=aleph0,useDivmod=False,GNP=getNthPairBounded):
    """Call GNP(N, W, H, useDivmod) and assert the pair lies inside W x H.

    GNP defaults to the getNthPairBounded that is bound when this
    wrapper is defined.
    """
    col,row = GNP(N,W,H,useDivmod)
    assert 0 <= col < W
    assert 0 <= row < H
    return col,row
131
def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple; elements are intended to satisfy
    0 <= x_i < H.

    With useLeftToRight the tuple is produced one element at a time by
    repeatedly splitting N with getNthPairBounded(N, H).  Otherwise the
    width is halved recursively and the two halves are enumerated as a
    bounded pair.
    """

    if useLeftToRight:
        # Each split yields the next element and the index carried forward;
        # the index left over after the last split is not used.
        items = []
        for _ in range(W):
            head,N = getNthPairBounded(N, H)
            items.append(head)
        return tuple(items)

    if W==0:
        return ()
    if W==1:
        return (N,)
    if W==2:
        return getNthPairBounded(N, H, H)

    # Three or more elements: split into a left and a right sub-tuple.
    leftWidth = W//2
    rightWidth = W - leftWidth
    leftIndex,rightIndex = getNthPairBounded(N, H**leftWidth, H**rightWidth)
    leftPart = getNthNTuple(leftIndex,leftWidth,H=H,useLeftToRight=useLeftToRight)
    rightPart = getNthNTuple(rightIndex,rightWidth,H=H,useLeftToRight=useLeftToRight)
    return leftPart + rightPart
def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Call GNT(N, W, H, useLeftToRight) and assert the result's shape.

    Asserts that the tuple has exactly W elements and that each element
    is below H, then returns the tuple unchanged.
    """
    result = GNT(N,W,H,useLeftToRight)
    assert len(result) == W
    for element in result:
        assert element < H
    return result
160
def getNthTuple(N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize (the bound is
    inclusive) and, for y in x, 0 <= y < maxElement.

    N == 0 is the empty tuple.  For the remaining indices a (size,
    index) pair is chosen first and then expanded with getNthNTuple.
    useDivmod only affects the case where maxElement is aleph0;
    useLeftToRight is forwarded to getNthNTuple.

    Raises NotImplementedError when maxElement is finite but maxSize is
    aleph0.
    """

    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()
    N -= 1
    if maxElement is not aleph0:
        if maxSize is aleph0:
            raise NotImplementedError('Max element size without max size unhandled')
        # bounds[S] is the number of tuples of length S+1.
        bounds = [maxElement**i for i in range(1, maxSize+1)]
        S,M = getNthPairVariableBounds(N, bounds)
    else:
        S,M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
    return getNthNTuple(M, S+1, maxElement, useLeftToRight=useLeftToRight)
def getNthTupleChecked(N, maxSize=aleph0, maxElement=aleph0,
                       useDivmod=False, useLeftToRight=False, GNT=getNthTuple):
    """Call GNT with the given arguments and assert the result's bounds.

    Asserts that the tuple has at most maxSize elements and that every
    element is below maxElement, then returns the tuple unchanged.
    """
    # FIXME: maxsize is inclusive
    result = GNT(N,maxSize,maxElement,useDivmod,useLeftToRight)
    assert len(result) <= maxSize
    for element in result:
        assert element < maxElement
    return result
187
def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    Raises ValueError if bounds is empty or if N is outside
    0 <= N < sum(bounds).
    """

    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Column indices ordered by increasing bound.  sorted() is used rather
    # than sorting range() in place, which only works where range() returns
    # a list.
    active = sorted(range(len(bounds)), key=lambda i: bounds[i])
    prevLevel = 0
    for i,index in enumerate(active):
        # Each step covers the band of rows [prevLevel, level) across the
        # columns whose bound is at least `level`.
        level = bounds[index]
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W*H
        if N<levelSize: # Found the level
            idelta,delta = getNthPairBounded(N, W, H)
            return active[i+idelta],prevLevel+delta
        else:
            N -= levelSize
            prevLevel = level
    else:
        raise RuntimeError("Unexpected loop completion")
220
def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Call GNVP(N, bounds) and assert the pair respects the bounds.

    Asserts 0 <= x < len(bounds) and then 0 <= y < bounds[x] before
    returning (x, y).
    """
    column,row = GNVP(N,bounds)
    assert 0 <= column < len(bounds)
    assert 0 <= row < bounds[column]
    return (column,row)
225
226###
227
def testPairs():
    """Print the enumeration of a 3 x 6 rectangle in both orderings.

    For each index the diagonal pair and the divmod pair are printed,
    followed by two grids ('a' for diagonal, 'b' for divmod) showing
    which index landed on which cell.  Rows are printed top-down from
    the highest y, and rows with no entries are skipped.
    """
    W = 3
    H = 6
    a = [[' ' for _ in range(10)] for _ in range(10)]
    b = [[' ' for _ in range(10)] for _ in range(10)]
    for i in range(min(W*H,40)):
        x,y = getNthPairBounded(i,W,H)
        x2,y2 = getNthPairBounded(i,W,H,useDivmod=True)
        # Single-argument print gives the same text on Python 2 and 3.
        print('%s %s %s' % (i, (x,y), (x2,y2)))
        a[y][x] = '%2d'%i
        b[y2][x2] = '%2d'%i

    for title,grid in (('-- a --', a), ('-- b --', b)):
        print(title)
        for ln in grid[::-1]:
            if ''.join(ln).strip():
                print(' '.join(ln))
248
def testPairsVB():
    """Print the first 40 pairs for a mixed finite/aleph0 bounds list.

    Each index is printed with its pair, followed by a grid showing
    which index landed on which cell.  Rows are printed top-down from
    the highest y, and rows with no entries are skipped.
    """
    bounds = [2,2,4,aleph0,5,aleph0]
    a = [[' ' for _ in range(15)] for _ in range(15)]
    for i in range(min(sum(bounds),40)):
        x,y = getNthPairVariableBounds(i, bounds)
        # Single-argument print gives the same text on Python 2 and 3.
        print('%s %s' % (i, (x,y)))
        a[y][x] = '%2d'%i

    print('-- a --')
    for ln in a[::-1]:
        if ''.join(ln).strip():
            print(' '.join(ln))
262
263###
264
# Change the condition to True to rebind each enumeration routine to its
# assertion-checking wrapper.
if False:
    getNthTuple = getNthTupleChecked
    getNthNTuple = getNthNTupleChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
271
# When run as a script, print both demonstrations in order.
if __name__ == '__main__':
    for runDemo in (testPairs, testPairsVB):
        runDemo()
276