blob: 634812d49d6c4459a583c33dc9df7e1d4722b1ed [file] [log] [blame]
Anthony Baxterc51ee692006-04-01 00:57:31 +00001#-*- coding: ISO-8859-1 -*-
2# pysqlite2/test/userfunctions.py: tests for user-defined functions and
3# aggregates.
4#
Gerhard Häring2a11c052008-03-28 20:08:36 +00005# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
Anthony Baxterc51ee692006-04-01 00:57:31 +00006#
7# This file is part of pysqlite.
8#
9# This software is provided 'as-is', without any express or implied
10# warranty. In no event will the authors be held liable for any damages
11# arising from the use of this software.
12#
13# Permission is granted to anyone to use this software for any purpose,
14# including commercial applications, and to alter it and redistribute it
15# freely, subject to the following restrictions:
16#
17# 1. The origin of this software must not be misrepresented; you must not
18# claim that you wrote the original software. If you use this software
19# in a product, an acknowledgment in the product documentation would be
20# appreciated but is not required.
21# 2. Altered source versions must be plainly marked as such, and must not be
22# misrepresented as being the original software.
23# 3. This notice may not be removed or altered from any source distribution.
24
25import unittest
26import sqlite3 as sqlite
27
def func_returntext():
    """User function with no arguments: return the native string "foo"."""
    text = "foo"
    return text
def func_returnunicode():
    """User function with no arguments: return the unicode string u"bar"."""
    text = u"bar"
    return text
def func_returnint():
    """User function with no arguments: return the integer 42."""
    answer = 42
    return answer
def func_returnfloat():
    """User function with no arguments: return the float 3.14."""
    approx_pi = 3.14
    return approx_pi
def func_returnnull():
    """User function with no arguments: return None (SQL NULL)."""
    return
def func_returnblob():
    """User function with no arguments: return a buffer wrapping "blob"."""
    payload = "blob"
    return buffer(payload)
def func_returnlonglong():
    """User function with no arguments: return 2**31.

    The value is one past the largest signed 32-bit integer.
    """
    return 2 ** 31
def func_raiseexception():
    """User function that always fails by dividing by zero (ZeroDivisionError)."""
    divisor = 0
    5 // divisor
Anthony Baxterc51ee692006-04-01 00:57:31 +000044
def func_isstring(v):
    """Return True iff *v* is exactly of type unicode (subclasses do not count)."""
    return unicode is type(v)
def func_isint(v):
    """Return True iff *v* is exactly of type int (bool does not count)."""
    return int is type(v)
def func_isfloat(v):
    """Return True iff *v* is exactly of type float."""
    return float is type(v)
def func_isnone(v):
    """Return True iff *v* is None (the only instance of NoneType)."""
    return v is None
def func_isblob(v):
    """Return True iff *v* is exactly of type buffer."""
    return buffer is type(v)
def func_islonglong(v):
    """Return True iff *v* is an int/long whose value is at least 2**31.

    Non-integer arguments yield False.
    """
    if not isinstance(v, (int, long)):
        return False
    threshold = 1 << 31
    return v >= threshold
Anthony Baxterc51ee692006-04-01 00:57:31 +000057
class AggrNoStep:
    """Aggregate class that deliberately has no ``step`` method.

    Kept as an old-style class: the aggregate tests compare against the
    exact "<name> instance has no attribute 'step'" error text.
    """
    def __init__(self):
        """Nothing to initialise."""

    def finalize(self):
        """Always report 1."""
        return 1
64
class AggrNoFinalize:
    """Aggregate class that deliberately has no ``finalize`` method."""
    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        """Ignore every value."""
71
class AggrExceptionInInit:
    """Aggregate whose constructor fails with ZeroDivisionError."""
    def __init__(self):
        divisor = 0
        5 // divisor

    def step(self, x):
        """Ignore every value."""

    def finalize(self):
        """Return None."""
81
class AggrExceptionInStep:
    """Aggregate whose ``step`` fails with ZeroDivisionError."""
    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        divisor = 0
        5 // divisor

    def finalize(self):
        """Report 42 (only reachable if step never ran)."""
        return 42
91
class AggrExceptionInFinalize:
    """Aggregate whose ``finalize`` fails with ZeroDivisionError."""
    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        """Ignore every value."""

    def finalize(self):
        divisor = 0
        5 // divisor
Anthony Baxterc51ee692006-04-01 00:57:31 +0000101
class AggrCheckType:
    """Aggregate that checks the exact Python type of the last value stepped.

    ``step(whichType, val)`` maps *whichType* ("str", "int", "float",
    "None" or "blob") to a type. It stores 1 if ``type(val)`` is exactly
    that type, else 0. ``finalize`` returns the stored value, or None if
    ``step`` was never called.
    """
    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        # The mapping is built per call so that the Python 2-only names
        # (unicode, buffer) are only looked up when step actually runs.
        expected = {
            "str": unicode,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": buffer,
        }[whichType]
        matched = type(val) is expected
        self.val = int(matched)

    def finalize(self):
        return self.val
112
class AggrSum:
    """Aggregate that sums every stepped value, starting from 0.0."""
    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        return self.val
122
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions registered with create_function."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        # Zero-argument functions, one per result type under test.
        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        # One-argument predicates reporting the Python type of their argument.
        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)

    def tearDown(self):
        self.con.close()

    def _fetch_scalar(self, sql, params=()):
        """Execute *sql* with *params* and return column 0 of the first row."""
        cur = self.con.cursor()
        cur.execute(sql, params)
        return cur.fetchone()[0]

    def CheckFuncErrorOnCreate(self):
        # An argument count of -100 must be rejected with OperationalError.
        self.assertRaises(sqlite.OperationalError,
                          self.con.create_function, "bla", -100, lambda x: 2*x)

    def CheckFuncRefCount(self):
        # Register a function object nothing else refers to. Calling it
        # afterwards only works if the connection itself keeps it alive.
        # Storing it in globals() as well would hide a missing reference.
        def getfunc():
            def f():
                return 1
            return f
        self.con.create_function("reftest", 0, getfunc())
        self.assertEqual(self._fetch_scalar("select reftest()"), 1)

    def CheckFuncReturnText(self):
        val = self._fetch_scalar("select returntext()")
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        val = self._fetch_scalar("select returnunicode()")
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, u"bar")

    def CheckFuncReturnInt(self):
        val = self._fetch_scalar("select returnint()")
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        val = self._fetch_scalar("select returnfloat()")
        self.assertEqual(type(val), float)
        if val < 3.139 or val > 3.141:
            self.fail("wrong value")

    def CheckFuncReturnNull(self):
        val = self._fetch_scalar("select returnnull()")
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def CheckFuncReturnBlob(self):
        val = self._fetch_scalar("select returnblob()")
        self.assertEqual(type(val), buffer)
        self.assertEqual(val, buffer("blob"))

    def CheckFuncReturnLongLong(self):
        val = self._fetch_scalar("select returnlonglong()")
        self.assertEqual(val, 1<<31)

    def CheckFuncException(self):
        cur = self.con.cursor()
        try:
            cur.execute("select raiseexception()")
            cur.fetchone()
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], 'user-defined function raised exception')
        else:
            self.fail("should have raised OperationalError")

    def CheckParamString(self):
        self.assertEqual(self._fetch_scalar("select isstring(?)", ("foo",)), 1)

    def CheckParamInt(self):
        self.assertEqual(self._fetch_scalar("select isint(?)", (42,)), 1)

    def CheckParamFloat(self):
        self.assertEqual(self._fetch_scalar("select isfloat(?)", (3.14,)), 1)

    def CheckParamNone(self):
        self.assertEqual(self._fetch_scalar("select isnone(?)", (None,)), 1)

    def CheckParamBlob(self):
        self.assertEqual(
            self._fetch_scalar("select isblob(?)", (buffer("blob"),)), 1)

    def CheckParamLongLong(self):
        self.assertEqual(self._fetch_scalar("select islonglong(?)", (1<<42,)), 1)
258
class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates registered with create_aggregate."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, buffer("blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Close the connection opened in setUp instead of leaking it.
        self.con.close()

    def _fetch_scalar(self, sql, params=()):
        """Execute *sql* with *params* and return column 0 of the first row."""
        cur = self.con.cursor()
        cur.execute(sql, params)
        return cur.fetchone()[0]

    def _assert_aggregate_error(self, sql, message):
        """Assert that running *sql* raises OperationalError with *message*."""
        cur = self.con.cursor()
        try:
            cur.execute(sql)
            cur.fetchone()
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], message)
        else:
            self.fail("should have raised an OperationalError")

    def CheckAggrErrorOnCreate(self):
        # Exercise create_aggregate itself (not create_function) with an
        # invalid argument count.
        self.assertRaises(sqlite.OperationalError,
                          self.con.create_aggregate, "bla", -100, AggrSum)

    def CheckAggrNoStep(self):
        cur = self.con.cursor()
        try:
            cur.execute("select nostep(t) from test")
        except AttributeError as e:
            self.assertEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")
        else:
            self.fail("should have raised an AttributeError")

    def CheckAggrNoFinalize(self):
        self._assert_aggregate_error(
            "select nofinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        self._assert_aggregate_error(
            "select excInit(t) from test",
            "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        self._assert_aggregate_error(
            "select excStep(t) from test",
            "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        self._assert_aggregate_error(
            "select excFinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        self.assertEqual(
            self._fetch_scalar("select checkType('str', ?)", ("foo",)), 1)

    def CheckAggrCheckParamInt(self):
        self.assertEqual(
            self._fetch_scalar("select checkType('int', ?)", (42,)), 1)

    def CheckAggrCheckParamFloat(self):
        self.assertEqual(
            self._fetch_scalar("select checkType('float', ?)", (3.14,)), 1)

    def CheckAggrCheckParamNone(self):
        self.assertEqual(
            self._fetch_scalar("select checkType('None', ?)", (None,)), 1)

    def CheckAggrCheckParamBlob(self):
        self.assertEqual(
            self._fetch_scalar("select checkType('blob', ?)", (buffer("blob"),)), 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        self.assertEqual(self._fetch_scalar("select mysum(i) from test"), 60)
Anthony Baxterc51ee692006-04-01 00:57:31 +0000376
class AuthorizerTests(unittest.TestCase):
    """Check that an authorizer callback can block table and column access.

    ``authorizer_cb`` allows only SELECT-type actions and refuses anything
    touching table t2 or column c2. Subclasses override ``authorizer_cb``
    and rerun the same tests.
    """
    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test: t2 is queried once *before* the
        # authorizer is installed.
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        # Close the connection opened in setUp (shared by all subclasses).
        self.con.close()

    def _assert_prohibited(self, sql):
        """Assert that *sql* fails with a DatabaseError ending in "prohibited"."""
        try:
            self.con.execute(sql)
        except sqlite.DatabaseError as e:
            if not e.args[0].endswith("prohibited"):
                self.fail("wrong exception text: %s" % e.args[0])
        else:
            self.fail("should have raised an exception due to missing privileges")

    def test_table_access(self):
        self._assert_prohibited("select * from t2")

    def test_column_access(self):
        self._assert_prohibited("select c2 from t1")
420
class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Rerun the authorizer tests with a callback that refuses by raising."""
    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Any non-SELECT action, or any access to column c2 / table t2,
        # is refused by raising ValueError instead of returning SQLITE_DENY.
        denied = (action != sqlite.SQLITE_SELECT
                  or arg2 == 'c2'
                  or arg1 == 't2')
        if denied:
            raise ValueError
        return sqlite.SQLITE_OK
429
class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Rerun the authorizer tests with a callback returning a float to refuse."""
    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Refusal is signalled with 0.0, a value of the wrong type.
        denied = (action != sqlite.SQLITE_SELECT
                  or arg2 == 'c2'
                  or arg1 == 't2')
        if denied:
            return 0.0
        return sqlite.SQLITE_OK
438
class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Rerun the authorizer tests with a callback returning 2**32 to refuse."""
    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Refusal is signalled with 2**32, an integer wider than 32 bits.
        denied = (action != sqlite.SQLITE_SELECT
                  or arg2 == 'c2'
                  or arg1 == 't2')
        if denied:
            return 2**32
        return sqlite.SQLITE_OK
447
448
def suite():
    """Build one TestSuite covering every test case class in this module.

    FunctionTests and AggregateTests use the "Check" method prefix; the
    authorizer classes use unittest's default "test" prefix.
    """
    suites = [
        unittest.makeSuite(FunctionTests, "Check"),
        unittest.makeSuite(AggregateTests, "Check"),
        unittest.makeSuite(AuthorizerTests),
    ]
    for case_class in (AuthorizerRaiseExceptionTests,
                       AuthorizerIllegalTypeTests,
                       AuthorizerLargeIntegerTests):
        suites.append(unittest.makeSuite(case_class))
    return unittest.TestSuite(suites)
Anthony Baxterc51ee692006-04-01 00:57:31 +0000461
def test():
    """Run the module's full suite with a plain text runner."""
    unittest.TextTestRunner().run(suite())
465
if __name__ == "__main__":
    # Only when executed as a script (not on import): run the suite.
    test()