blob: 0c8de2f59a8d643f28b73cdb8792a42bf110017f [file] [log] [blame]
Tobias Grosser75805372011-04-29 06:27:02 +00001from ctypes import *
2
# Load the isl shared library. The unversioned "libisl.so" is tried first;
# when it is not installed (e.g. only a versioned libisl.so.N, or a .dylib
# on macOS), fall back to the platform search of ctypes.util.find_library.
try:
    isl = cdll.LoadLibrary("libisl.so")
except OSError:
    import ctypes.util
    _islPath = ctypes.util.find_library("isl")
    if _islPath is None:
        # Nothing else to try: re-raise the original load failure.
        raise
    isl = cdll.LoadLibrary(_islPath)
4
class Context:
    """Python wrapper around an isl_ctx.

    Every Context created here is registered in ``Context.instances`` under
    its raw pointer value, so that ``from_ptr`` can map a pointer handed back
    by isl (e.g. from ``isl_*_get_ctx``) to the existing wrapper.
    """

    # Lazily created shared context returned by getDefaultInstance().
    defaultInstance = None
    # Raw isl_ctx pointer -> Context wrapper.
    # NOTE(review): this holds a strong reference to every Context, so
    # __del__ (and thus isl_ctx_free) will not run before interpreter
    # shutdown.
    instances = {}

    def __init__(self):
        ptr = isl.isl_ctx_alloc()
        self.ptr = ptr
        Context.instances[ptr] = self

    def __del__(self):
        # Relies on isl_ctx_free.argtypes = [Context] being declared at
        # module level, so ctypes converts self through from_param.
        isl.isl_ctx_free(self)

    def from_param(self):
        """ctypes hook: return the raw pointer when passed as an argument.

        ctypes invokes ``Context.from_param(obj)`` for arguments declared
        with ``Context`` in ``argtypes``.
        """
        return self.ptr

    @staticmethod
    def from_ptr(ptr):
        """Return the registered Context for ``ptr``; KeyError if unknown."""
        return Context.instances[ptr]

    @staticmethod
    def getDefaultInstance():
        """Return the shared default Context, creating it on first use."""
        if Context.defaultInstance is None:
            Context.defaultInstance = Context()

        return Context.defaultInstance
30
class IslObject:
    """Common base for the Python wrappers around isl objects.

    A subclass supplies ``isl_name()`` (e.g. ``"set"``); every isl entry
    point it needs is then looked up as ``isl_<isl_name>_<method>`` on the
    loaded library.  An instance either wraps an existing raw pointer
    (``ptr``) or parses ``string`` with ``isl_<isl_name>_read_from_str``
    inside ``ctx`` (the shared default context when ``ctx`` is None).  The
    wrapper releases its pointer with ``isl_<isl_name>_free`` when it is
    garbage collected.

    NOTE(review): raw pointers are carried as C ``int`` values here (``copy``
    is declared with ``restype = c_int``), which truncates them on 64-bit
    platforms -- confirm the intended target.
    """

    def __init__(self, string = "", ctx = None, ptr = None):
        self.initialize_isl_methods()
        if ptr is not None:
            # Wrap an existing isl object and ask isl which context owns it.
            self.ptr = ptr
            self.ctx = self.get_isl_method("get_ctx")(self)
            return

        if ctx is None:
            ctx = Context.getDefaultInstance()

        # read_from_str is declared with a c_char_p argument, which only
        # accepts bytes on Python 3; encode text so ordinary strings work.
        if isinstance(string, type(u"")):
            string = string.encode("utf-8")

        self.ctx = ctx
        self.ptr = self.get_isl_method("read_from_str")(ctx, string, -1)

    def __del__(self):
        # __init__ may have failed before a pointer was stored; only free
        # objects that actually hold one.
        if getattr(self, "ptr", None):
            self.get_isl_method("free")(self)

    def from_param(self):
        """ctypes hook: return the raw pointer when passed as an argument."""
        return self.ptr

    @property
    def context(self):
        """The Context this object belongs to."""
        return self.ctx

    def _to_string(self):
        """Render the object through a fresh string Printer."""
        p = Printer(self.ctx)
        self.to_printer(p)
        s = p.getString()
        # isl_printer_get_str is declared with restype c_char_p, which yields
        # bytes on Python 3, but __str__/__repr__ must return text.
        if isinstance(s, bytes) and not isinstance(s, str):
            s = s.decode("utf-8")
        return s

    def __repr__(self):
        return self._to_string()

    def __str__(self):
        return self._to_string()

    @staticmethod
    def isl_name():
        """Return the isl type-name infix; overridden by every subclass."""
        return "No isl name available"

    def initialize_isl_methods(self):
        """Declare the ctypes signatures of this class's isl functions once."""
        cls = self.__class__
        # Look only at the class's own __dict__: the argtypes below name the
        # concrete class, so every subclass needs its own setup.
        if cls.__dict__.get("initialized"):
            return

        method = self.get_isl_method
        method("read_from_str").argtypes = [Context, c_char_p, c_int]
        method("copy").argtypes = [cls]
        method("copy").restype = c_int
        method("free").argtypes = [cls]
        method("get_ctx").argtypes = [cls]
        method("get_ctx").restype = Context.from_ptr
        getattr(isl, "isl_printer_print_" + self.isl_name()).argtypes = [Printer, cls]
        # Set the flag only once every declaration above has succeeded.
        cls.initialized = True

    def get_isl_method(self, name):
        """Return the library function ``isl_<isl_name>_<name>``."""
        return getattr(isl, "isl_" + self.isl_name() + "_" + name)

    def to_printer(self, printer):
        """Print this object into ``printer`` via isl_printer_print_<isl_name>."""
        getattr(isl, "isl_printer_print_" + self.isl_name())(printer, self)
87
class BSet(IslObject):
    """Wrapper for an isl_basic_set."""

    @staticmethod
    def isl_name():
        return "basic_set"

    @staticmethod
    def from_ptr(ptr):
        # A NULL result from isl maps to None rather than to a wrapper.
        return BSet(ptr = ptr) if ptr else None
98
class Set(IslObject):
    """Wrapper for an isl_set."""

    @staticmethod
    def isl_name():
        return "set"

    @staticmethod
    def from_ptr(ptr):
        # A NULL result from isl maps to None rather than to a wrapper.
        return Set(ptr = ptr) if ptr else None
109
class USet(IslObject):
    """Wrapper for an isl_union_set."""

    @staticmethod
    def isl_name():
        return "union_set"

    @staticmethod
    def from_ptr(ptr):
        # A NULL result from isl maps to None rather than to a wrapper.
        return USet(ptr = ptr) if ptr else None
120
121
class BMap(IslObject):
    """Wrapper for an isl_basic_map."""

    @staticmethod
    def isl_name():
        return "basic_map"

    @staticmethod
    def from_ptr(ptr):
        # A NULL result from isl maps to None rather than to a wrapper.
        return BMap(ptr = ptr) if ptr else None

    def __mul__(self, set):
        # ``bmap * bset`` is shorthand for ``bmap.intersect_domain(bset)``.
        return self.intersect_domain(set)
135
class Map(IslObject):
    """Wrapper for an isl_map, plus constructors for lexicographic orders."""

    @staticmethod
    def isl_name():
        return "map"

    @staticmethod
    def from_ptr(ptr):
        # A NULL result from isl maps to None rather than to a wrapper.
        return Map(ptr = ptr) if ptr else None

    def __mul__(self, set):
        # ``map * set`` is shorthand for ``map.intersect_domain(set)``.
        return self.intersect_domain(set)

    @staticmethod
    def _lex_order(suffix, dim):
        """Call isl_map_lex_<suffix> on a fresh copy of ``dim``."""
        owned = isl.isl_dim_copy(dim)
        return getattr(isl, "isl_map_lex_" + suffix)(owned)

    @staticmethod
    def lex_lt(dim):
        return Map._lex_order("lt", dim)

    @staticmethod
    def lex_le(dim):
        return Map._lex_order("le", dim)

    @staticmethod
    def lex_gt(dim):
        return Map._lex_order("gt", dim)

    @staticmethod
    def lex_ge(dim):
        return Map._lex_order("ge", dim)
169
class UMap(IslObject):
    """Wrapper for an isl_union_map."""

    @staticmethod
    def isl_name():
        return "union_map"

    @staticmethod
    def from_ptr(ptr):
        # A NULL result from isl maps to None rather than to a wrapper.
        return UMap(ptr = ptr) if ptr else None
180
class Dim(IslObject):
    """Wrapper for an isl_dim describing the dimensions of a set or map.

    Dim objects are obtained from isl through ``from_ptr``; unlike the other
    wrappers, no read_from_str / printer signatures are set up for them.
    """

    # isl_dim_type values passed to isl_dim_size: parameters, input
    # dimensions and output (set) dimensions.
    _TYPE_PARAM = 1
    _TYPE_IN = 2
    _TYPE_OUT = 3

    @staticmethod
    def from_ptr(ptr):
        if not ptr:
            return
        return Dim(ptr = ptr)

    @staticmethod
    def isl_name():
        return "dim"

    def initialize_isl_methods(self):
        """Declare the ctypes signatures of the isl_dim functions once."""
        cls = self.__class__
        # The flag is looked up on the class itself; it is set only after the
        # declarations below succeed.
        if cls.__dict__.get("initialized"):
            return

        method = self.get_isl_method
        method("copy").argtypes = [cls]
        method("copy").restype = c_int
        method("free").argtypes = [cls]
        method("get_ctx").argtypes = [cls]
        method("get_ctx").restype = Context.from_ptr
        cls.initialized = True

    def __repr__(self):
        return str(self)

    def __str__(self):
        dimParam = isl.isl_dim_size(self, self._TYPE_PARAM)
        dimIn = isl.isl_dim_size(self, self._TYPE_IN)
        dimOut = isl.isl_dim_size(self, self._TYPE_OUT)

        # NOTE(review): a map space with zero input dimensions is rendered in
        # the set form, since only the input count is inspected here.
        if dimIn:
            return "<dim In:%s, Out:%s, Param:%s>" % (dimIn, dimOut, dimParam)

        return "<dim Set:%s, Param:%s>" % (dimOut, dimParam)
216
class Printer:
    """Wrapper around an isl_printer that prints into an in-memory string."""

    # Output format codes accepted by setFormat(); passed unchanged to
    # isl_printer_set_output_format.
    FORMAT_ISL = 0
    FORMAT_POLYLIB = 1
    FORMAT_POLYLIB_CONSTRAINTS = 2
    FORMAT_OMEGA = 3
    FORMAT_C = 4
    FORMAT_LATEX = 5
    FORMAT_EXT_POLYLIB = 6

    def __init__(self, ctx = None):
        if ctx is None:
            ctx = Context.getDefaultInstance()

        self.ctx = ctx
        self.ptr = isl.isl_printer_to_str(ctx)

    def setFormat(self, format):
        """Switch the output format to one of the FORMAT_* codes."""
        # The call returns the printer to keep using, so the stored pointer
        # is replaced with its result.
        self.ptr = isl.isl_printer_set_output_format(self, format)

    def from_param(self):
        """ctypes hook: return the raw pointer when passed as an argument."""
        return self.ptr

    def __del__(self):
        # Skip the free when __init__ never stored a printer pointer.
        if getattr(self, "ptr", None):
            isl.isl_printer_free(self)

    def getString(self):
        """Return everything printed so far.

        The value is what ctypes produces for isl_printer_get_str's declared
        restype (c_char_p: bytes on Python 3, str on Python 2).
        NOTE(review): isl presumably hands back a newly allocated buffer
        here; with a c_char_p restype ctypes copies it and the C buffer is
        never freed -- confirm and consider c_void_p plus an explicit free.
        """
        return isl.isl_printer_get_str(self)
244
# Table of isl functions exposed as methods on the wrapper classes.
# Each entry is (operation, class, operand classes, return type):
#   * the C function is isl_<class.isl_name()>_<operation>;
#   * only the number of operand classes is used (one or two operands);
#   * a return type of c_int yields the raw integer result; any other return
#     type is a wrapper class whose from_ptr wraps the returned pointer, so it
#     must match the C return type exactly.
# Every operand is copied before the call (see islFunctionOneOp /
# islFunctionTwoOp), i.e. the wrapped functions are expected to take
# ownership of their arguments.
#
# NOTE(review): the is_* property functions are, as far as I know, declared
# with __isl_keep arguments in isl, so the copy made for them is never
# released (one reference leaked per call). Confirm against isl's headers.
functions = [
    # Unary properties
    ("is_empty", BSet, [BSet], c_int),
    ("is_empty", Set, [Set], c_int),
    ("is_empty", USet, [USet], c_int),
    ("is_empty", BMap, [BMap], c_int),
    ("is_empty", Map, [Map], c_int),
    ("is_empty", UMap, [UMap], c_int),

    # ("is_universe", Set, [Set], c_int),
    # ("is_universe", Map, [Map], c_int),

    ("is_single_valued", Map, [Map], c_int),

    ("is_bijective", Map, [Map], c_int),

    ("is_wrapping", BSet, [BSet], c_int),
    ("is_wrapping", Set, [Set], c_int),

    # Binary properties
    ("is_equal", BSet, [BSet, BSet], c_int),
    ("is_equal", Set, [Set, Set], c_int),
    ("is_equal", USet, [USet, USet], c_int),
    ("is_equal", BMap, [BMap, BMap], c_int),
    ("is_equal", Map, [Map, Map], c_int),
    ("is_equal", UMap, [UMap, UMap], c_int),

    # is_disjoint missing

    # ("is_subset", BSet, [BSet, BSet], c_int),
    ("is_subset", Set, [Set, Set], c_int),
    ("is_subset", USet, [USet, USet], c_int),
    ("is_subset", BMap, [BMap, BMap], c_int),
    ("is_subset", Map, [Map, Map], c_int),
    ("is_subset", UMap, [UMap, UMap], c_int),
    # ("is_strict_subset", BSet, [BSet, BSet], c_int),
    ("is_strict_subset", Set, [Set, Set], c_int),
    ("is_strict_subset", USet, [USet, USet], c_int),
    ("is_strict_subset", BMap, [BMap, BMap], c_int),
    ("is_strict_subset", Map, [Map, Map], c_int),
    ("is_strict_subset", UMap, [UMap, UMap], c_int),

    # Unary Operations
    ("complement", Set, [Set], Set),
    ("reverse", BMap, [BMap], BMap),
    ("reverse", Map, [Map], Map),
    ("reverse", UMap, [UMap], UMap),

    # Projection missing
    ("range", BMap, [BMap], BSet),
    ("range", Map, [Map], Set),
    ("range", UMap, [UMap], USet),
    ("domain", BMap, [BMap], BSet),
    ("domain", Map, [Map], Set),
    ("domain", UMap, [UMap], USet),

    ("identity", Set, [Set], Map),
    ("identity", USet, [USet], UMap),

    ("deltas", BMap, [BMap], BSet),
    ("deltas", Map, [Map], Set),
    ("deltas", UMap, [UMap], USet),

    ("coalesce", Set, [Set], Set),
    ("coalesce", USet, [USet], USet),
    ("coalesce", Map, [Map], Map),
    ("coalesce", UMap, [UMap], UMap),

    ("detect_equalities", BSet, [BSet], BSet),
    ("detect_equalities", Set, [Set], Set),
    ("detect_equalities", USet, [USet], USet),
    ("detect_equalities", BMap, [BMap], BMap),
    ("detect_equalities", Map, [Map], Map),
    ("detect_equalities", UMap, [UMap], UMap),

    # The hulls of a set/map are single basic sets/maps.
    ("convex_hull", Set, [Set], BSet),
    ("convex_hull", Map, [Map], BMap),

    ("simple_hull", Set, [Set], BSet),
    ("simple_hull", Map, [Map], BMap),

    ("affine_hull", BSet, [BSet], BSet),
    ("affine_hull", Set, [Set], BSet),
    ("affine_hull", USet, [USet], USet),
    ("affine_hull", BMap, [BMap], BMap),
    ("affine_hull", Map, [Map], BMap),
    ("affine_hull", UMap, [UMap], UMap),

    ("polyhedral_hull", Set, [Set], BSet),
    ("polyhedral_hull", USet, [USet], USet),
    ("polyhedral_hull", Map, [Map], BMap),
    ("polyhedral_hull", UMap, [UMap], UMap),

    # Power missing
    # Transitive closure missing
    # Reaching path lengths missing

    ("wrap", BMap, [BMap], BSet),
    ("wrap", Map, [Map], Set),
    ("wrap", UMap, [UMap], USet),
    ("unwrap", BSet, [BSet], BMap),
    ("unwrap", Set, [Set], Map),
    ("unwrap", USet, [USet], UMap),

    ("flatten", Set, [Set], Set),
    ("flatten", Map, [Map], Map),
    ("flatten_map", Set, [Set], Map),

    # Dimension manipulation missing

    # Binary Operations
    ("intersect", BSet, [BSet, BSet], BSet),
    ("intersect", Set, [Set, Set], Set),
    ("intersect", USet, [USet, USet], USet),
    ("intersect", BMap, [BMap, BMap], BMap),
    ("intersect", Map, [Map, Map], Map),
    ("intersect", UMap, [UMap, UMap], UMap),
    ("intersect_domain", BMap, [BMap, BSet], BMap),
    ("intersect_domain", Map, [Map, Set], Map),
    ("intersect_domain", UMap, [UMap, USet], UMap),
    ("intersect_range", BMap, [BMap, BSet], BMap),
    ("intersect_range", Map, [Map, Set], Map),
    ("intersect_range", UMap, [UMap, USet], UMap),

    ("union", BSet, [BSet, BSet], Set),
    ("union", Set, [Set, Set], Set),
    ("union", USet, [USet, USet], USet),
    ("union", BMap, [BMap, BMap], Map),
    ("union", Map, [Map, Map], Map),
    ("union", UMap, [UMap, UMap], UMap),

    ("subtract", Set, [Set, Set], Set),
    ("subtract", Map, [Map, Map], Map),
    ("subtract", USet, [USet, USet], USet),
    ("subtract", UMap, [UMap, UMap], UMap),

    ("apply", BSet, [BSet, BMap], BSet),
    ("apply", Set, [Set, Map], Set),
    ("apply", USet, [USet, UMap], USet),
    ("apply_domain", BMap, [BMap, BMap], BMap),
    ("apply_domain", Map, [Map, Map], Map),
    ("apply_domain", UMap, [UMap, UMap], UMap),
    ("apply_range", BMap, [BMap, BMap], BMap),
    ("apply_range", Map, [Map, Map], Map),
    ("apply_range", UMap, [UMap, UMap], UMap),

    ("gist", BSet, [BSet, BSet], BSet),
    ("gist", Set, [Set, Set], Set),
    ("gist", USet, [USet, USet], USet),
    ("gist", BMap, [BMap, BMap], BMap),
    ("gist", Map, [Map, Map], Map),
    ("gist", UMap, [UMap, UMap], UMap),

    # Lexicographic Optimizations
    # partial_lexmin missing
    # The lexicographic optimum of a basic set/map is a general set/map.
    ("lexmin", BSet, [BSet], Set),
    ("lexmin", Set, [Set], Set),
    ("lexmin", USet, [USet], USet),
    ("lexmin", BMap, [BMap], Map),
    ("lexmin", Map, [Map], Map),
    ("lexmin", UMap, [UMap], UMap),

    ("lexmax", BSet, [BSet], Set),
    ("lexmax", Set, [Set], Set),
    ("lexmax", USet, [USet], USet),
    ("lexmax", BMap, [BMap], Map),
    ("lexmax", Map, [Map], Map),
    ("lexmax", UMap, [UMap], UMap),

    # Undocumented
    ("lex_lt_union_set", USet, [USet, USet], UMap),
    ("lex_le_union_set", USet, [USet, USet], UMap),
    ("lex_gt_union_set", USet, [USet, USet], UMap),
    ("lex_ge_union_set", USet, [USet, USet], UMap),

]
# Functions whose operands are passed straight through without a copy
# (see islFunctionOneOpKeep); presumably these only borrow their argument.
# Same tuple layout as ``functions``: (operation, class, operands, return).
keep_functions = [
    ("get_dim", _kind, [_kind], Dim)
    for _kind in (BSet, Set, USet, BMap, Map, UMap)
]
430
def addIslFunction(object, name):
    """Attach ``isl_<isl_name>_<name>`` to class ``object`` as method ``name``.

    The generated method copies each operand before the call (through
    islFunctionOneOp / islFunctionTwoOp). Only functions whose declared
    ``argtypes`` has one or two entries are supported.

    Raises:
        ValueError: if the isl function has a different number of argtypes.
    """
    functionName = "isl_" + object.isl_name() + "_" + name
    islFunction = getattr(isl, functionName)
    arity = len(islFunction.argtypes)
    if arity == 1:
        def method(a):
            return islFunctionOneOp(islFunction, a)
    elif arity == 2:
        def method(a, b):
            return islFunctionTwoOp(islFunction, a, b)
    else:
        raise ValueError("%s: unsupported number of operands (%d)"
                         % (functionName, arity))
    # A class __dict__ is a read-only mappingproxy for new-style classes
    # (every class on Python 3), so assign through setattr instead.
    setattr(object, name, method)
439
440
def islFunctionOneOp(islFunction, ops):
    """Call ``islFunction`` on a fresh copy of ``ops`` and return its result.

    The copy is made with ``isl_<ops.isl_name()>_copy``.
    """
    copyFunction = getattr(isl, "isl_%s_copy" % ops.isl_name())
    return islFunction(copyFunction(ops))
444
def islFunctionTwoOp(islFunction, opOne, opTwo):
    """Call ``islFunction`` on fresh copies of both operands.

    ``opOne`` is copied before ``opTwo``, each with its own
    ``isl_<isl_name()>_copy``.
    """
    copies = [getattr(isl, "isl_%s_copy" % operand.isl_name())(operand)
              for operand in (opOne, opTwo)]
    return islFunction(*copies)
449
# Declare the ctypes signature of every entry of ``functions`` and attach it
# to its wrapper class as a method (operands are copied before the call).
for (operation, base, operands, ret) in functions:
    functionName = "isl_%s_%s" % (base.isl_name(), operation)
    islFunction = getattr(isl, functionName)
    # Operands travel as raw pointer values; only unary and binary
    # functions get argtypes.
    if len(operands) in (1, 2):
        islFunction.argtypes = [c_int] * len(operands)

    islFunction.restype = ret if ret is c_int else ret.from_ptr

    addIslFunction(base, operation)
464
def addIslFunctionKeep(object, name):
    """Attach ``isl_<isl_name>_<name>`` to class ``object`` as method ``name``.

    Unlike addIslFunction, the generated method passes its operands through
    unchanged (no copy; see islFunctionOneOpKeep / islFunctionTwoOpKeep).

    Raises:
        ValueError: if the isl function's argtypes has neither one nor two
            entries.
    """
    functionName = "isl_" + object.isl_name() + "_" + name
    islFunction = getattr(isl, functionName)
    arity = len(islFunction.argtypes)
    if arity == 1:
        def method(a):
            return islFunctionOneOpKeep(islFunction, a)
    elif arity == 2:
        def method(a, b):
            return islFunctionTwoOpKeep(islFunction, a, b)
    else:
        raise ValueError("%s: unsupported number of operands (%d)"
                         % (functionName, arity))
    # A class __dict__ is a read-only mappingproxy for new-style classes
    # (every class on Python 3), so assign through setattr instead.
    setattr(object, name, method)
473
def islFunctionOneOpKeep(islFunction, ops):
    """Call ``islFunction`` on ``ops`` as-is (no copy) and return the result."""
    result = islFunction(ops)
    return result
476
def islFunctionTwoOpKeep(islFunction, opOne, opTwo):
    """Call ``islFunction`` on both operands as-is (no copies)."""
    operands = (opOne, opTwo)
    return islFunction(*operands)
479
# Declare the ctypes signature of every entry of ``keep_functions`` and
# attach it to its wrapper class as a method (operands are not copied).
for (operation, base, operands, ret) in keep_functions:
    functionName = "isl_%s_%s" % (base.isl_name(), operation)
    islFunction = getattr(isl, functionName)
    # Only unary and binary functions get argtypes here.
    if len(operands) in (1, 2):
        islFunction.argtypes = [c_int] * len(operands)

    islFunction.restype = ret if ret is c_int else ret.from_ptr

    addIslFunctionKeep(base, operation)
494
# Explicit ctypes signatures for the isl entry points used by the wrappers.
# These run after the registration loops, so they override the generic
# [c_int] argtypes those loops gave to the get_dim functions.

# Context
isl.isl_ctx_free.argtypes = [Context]

# Basic sets, sets and union sets
isl.isl_basic_set_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_set_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_basic_set_free.argtypes = [BSet]
isl.isl_basic_set_copy.argtypes = [BSet]
isl.isl_basic_set_copy.restype = c_int
isl.isl_set_copy.argtypes = [Set]
isl.isl_set_copy.restype = c_int
isl.isl_set_free.argtypes = [Set]
isl.isl_basic_set_get_ctx.argtypes = [BSet]
isl.isl_basic_set_get_ctx.restype = Context.from_ptr
isl.isl_set_get_ctx.argtypes = [Set]
isl.isl_set_get_ctx.restype = Context.from_ptr
isl.isl_basic_set_get_dim.argtypes = [BSet]
isl.isl_basic_set_get_dim.restype = Dim.from_ptr
isl.isl_set_get_dim.argtypes = [Set]
isl.isl_set_get_dim.restype = Dim.from_ptr
isl.isl_union_set_get_dim.argtypes = [USet]
isl.isl_union_set_get_dim.restype = Dim.from_ptr

# Basic maps, maps and union maps
isl.isl_basic_map_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_map_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_basic_map_free.argtypes = [BMap]
isl.isl_map_free.argtypes = [Map]
isl.isl_basic_map_copy.argtypes = [BMap]
isl.isl_basic_map_copy.restype = c_int
isl.isl_map_copy.argtypes = [Map]
isl.isl_map_copy.restype = c_int
isl.isl_basic_map_get_ctx.argtypes = [BMap]
isl.isl_basic_map_get_ctx.restype = Context.from_ptr
isl.isl_map_get_ctx.argtypes = [Map]
isl.isl_map_get_ctx.restype = Context.from_ptr
isl.isl_basic_map_get_dim.argtypes = [BMap]
isl.isl_basic_map_get_dim.restype = Dim.from_ptr
isl.isl_map_get_dim.argtypes = [Map]
isl.isl_map_get_dim.restype = Dim.from_ptr
isl.isl_union_map_get_dim.argtypes = [UMap]
isl.isl_union_map_get_dim.restype = Dim.from_ptr

# Printer
isl.isl_printer_free.argtypes = [Printer]
isl.isl_printer_to_str.argtypes = [Context]
isl.isl_printer_print_basic_set.argtypes = [Printer, BSet]
isl.isl_printer_print_set.argtypes = [Printer, Set]
isl.isl_printer_print_basic_map.argtypes = [Printer, BMap]
isl.isl_printer_print_map.argtypes = [Printer, Map]
isl.isl_printer_get_str.argtypes = [Printer]
isl.isl_printer_get_str.restype = c_char_p
isl.isl_printer_set_output_format.argtypes = [Printer, c_int]
isl.isl_printer_set_output_format.restype = c_int

# Dim
isl.isl_dim_size.argtypes = [Dim, c_int]
isl.isl_dim_size.restype = c_int
547
# isl_map_lex_{lt,le,gt,ge}: one raw dim pointer in, a new Map out.
for _lexOrder in ("lt", "le", "gt", "ge"):
    _lexFunction = getattr(isl, "isl_map_lex_" + _lexOrder)
    _lexFunction.argtypes = [c_int]
    _lexFunction.restype = Map.from_ptr
del _lexOrder, _lexFunction

# isl_union_map_compute_flow: four input pointers followed by four
# out-parameters, which dependences() passes via byref.
isl.isl_union_map_compute_flow.argtypes = [c_int] * 4 + [c_void_p] * 4
559
def dependences(sink, must_source, may_source, schedule):
    """Run isl dataflow analysis through isl_union_map_compute_flow.

    All four inputs are copied first, so the caller's objects stay valid
    (the callee presumably consumes its inputs).

    Returns:
        A tuple ``(must_dep, may_dep, must_no_source, may_no_source)`` of
        UMap, UMap, USet and USet wrappers; an entry is None when isl left
        the corresponding out-parameter NULL.
    """
    def _copy(obj):
        return getattr(isl, "isl_" + obj.isl_name() + "_copy")(obj)

    sink = _copy(sink)
    must_source = _copy(must_source)
    may_source = _copy(may_source)
    schedule = _copy(schedule)

    # isl stores a full pointer into each out-parameter; c_void_p is
    # pointer-sized, whereas a 4-byte c_int would be overrun on 64-bit
    # platforms.
    must_dep = c_void_p()
    may_dep = c_void_p()
    must_no_source = c_void_p()
    may_no_source = c_void_p()
    isl.isl_union_map_compute_flow(sink, must_source, may_source, schedule,
                                   byref(must_dep), byref(may_dep),
                                   byref(must_no_source),
                                   byref(may_no_source))

    # .value is None for NULL, which from_ptr maps to None.
    return (UMap.from_ptr(must_dep.value), UMap.from_ptr(may_dep.value),
            USet.from_ptr(must_no_source.value),
            USet.from_ptr(may_no_source.value))
576
577
# Public names exported by ``from <module> import *``: every wrapper class
# plus the dataflow entry point (which returns UMap/USet objects).
__all__ = ['Set', 'Map', 'Printer', 'Context',
           'BSet', 'USet', 'BMap', 'UMap', 'Dim', 'dependences']