blob: 1d1915187965dea7d8ecb65630f96d6bed575798 [file] [log] [blame]
Anthony Baxterc51ee692006-04-01 00:57:31 +00001#-*- coding: ISO-8859-1 -*-
2# pysqlite2/test/userfunctions.py: tests for user-defined functions and
3# aggregates.
4#
Gerhard Häring2a11c052008-03-28 20:08:36 +00005# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
Anthony Baxterc51ee692006-04-01 00:57:31 +00006#
7# This file is part of pysqlite.
8#
9# This software is provided 'as-is', without any express or implied
10# warranty. In no event will the authors be held liable for any damages
11# arising from the use of this software.
12#
13# Permission is granted to anyone to use this software for any purpose,
14# including commercial applications, and to alter it and redistribute it
15# freely, subject to the following restrictions:
16#
17# 1. The origin of this software must not be misrepresented; you must not
18# claim that you wrote the original software. If you use this software
19# in a product, an acknowledgment in the product documentation would be
20# appreciated but is not required.
21# 2. Altered source versions must be plainly marked as such, and must not be
22# misrepresented as being the original software.
23# 3. This notice may not be removed or altered from any source distribution.
24
25import unittest
26import sqlite3 as sqlite
Victor Stinnera3acea32014-09-05 21:05:05 +020027from test import test_support
Anthony Baxterc51ee692006-04-01 00:57:31 +000028
def func_returntext():
    """Return the plain str literal "foo"; registered as SQL returntext()."""
    text = "foo"
    return text
def func_returnunicode():
    """Return the unicode literal u"bar"; registered as SQL returnunicode()."""
    text = u"bar"
    return text
def func_returnint():
    """Return the int 42; registered as SQL returnint()."""
    number = 42
    return number
def func_returnfloat():
    """Return the float 3.14; registered as SQL returnfloat()."""
    number = 3.14
    return number
def func_returnnull():
    """Return None; registered as SQL returnnull()."""
    return
39def func_returnblob():
Victor Stinnera3acea32014-09-05 21:05:05 +020040 with test_support.check_py3k_warnings():
41 return buffer("blob")
def func_returnlonglong():
    """Return 2**31 (2147483648), one more than the largest signed 32-bit int."""
    return 2 ** 31
def func_raiseexception():
    """Always fail with ZeroDivisionError; registered as SQL raiseexception().

    The error comes from a real integer division by zero rather than an
    explicit ``raise``.
    """
    5 // 0
Anthony Baxterc51ee692006-04-01 00:57:31 +000046
47def func_isstring(v):
48 return type(v) is unicode
def func_isint(v):
    """Return True only if *v* is exactly an int (bool and float give False)."""
    actual = type(v)
    return actual is int
def func_isfloat(v):
    """Return True only if *v* is exactly a float."""
    actual = type(v)
    return actual is float
def func_isnone(v):
    """Return True only if *v* is None (None is the sole NoneType instance)."""
    return v is None
55def func_isblob(v):
56 return type(v) is buffer
Petri Lehtinen4ab701b2012-02-21 13:58:40 +020057def func_islonglong(v):
58 return isinstance(v, (int, long)) and v >= 1<<31
Anthony Baxterc51ee692006-04-01 00:57:31 +000059
class AggrNoStep:
    """Aggregate that deliberately has no step() method.

    Kept as a classic class (no base): the aggregate test compares against
    the classic-instance AttributeError text
    "AggrNoStep instance has no attribute 'step'".
    """

    def __init__(self):
        """No state is needed."""

    def finalize(self):
        result = 1
        return result
66
class AggrNoFinalize:
    """Aggregate that accepts rows but deliberately has no finalize() method."""

    def __init__(self):
        """No state is needed."""

    def step(self, x):
        """Ignore every row."""
73
class AggrExceptionInInit:
    """Aggregate whose constructor fails with ZeroDivisionError."""

    def __init__(self):
        5 // 0

    def step(self, x):
        """Ignore the row (unreachable in practice: construction fails)."""

    def finalize(self):
        """Return None (unreachable in practice: construction fails)."""
83
class AggrExceptionInStep:
    """Aggregate whose step() fails with ZeroDivisionError on every row."""

    def __init__(self):
        """No state is needed."""

    def step(self, x):
        5 // 0

    def finalize(self):
        result = 42
        return result
93
class AggrExceptionInFinalize:
    """Aggregate that accepts rows but whose finalize() raises ZeroDivisionError."""

    def __init__(self):
        """No state is needed."""

    def step(self, x):
        """Ignore every row."""

    def finalize(self):
        5 // 0
Anthony Baxterc51ee692006-04-01 00:57:31 +0000103
104class AggrCheckType:
105 def __init__(self):
106 self.val = None
107
108 def step(self, whichType, val):
109 theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer}
110 self.val = int(theType[whichType] is type(val))
111
112 def finalize(self):
113 return self.val
114
class AggrSum:
    """Aggregate that adds up its argument, starting from the float 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        return self.val
124
class FunctionTests(unittest.TestCase):
    """Scalar user-defined functions: return types, argument types, errors.

    setUp registers the module-level func_* helpers on a fresh in-memory
    connection; each Check* method invokes one of them from SQL.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        # Zero-argument functions, one per kind of return value.
        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        # One-argument predicates reporting the Python type they receive.
        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)

    def tearDown(self):
        self.con.close()

    def _fetch_scalar(self, sql, params=()):
        """Execute *sql* with *params*; return column 0 of the first row."""
        cur = self.con.cursor()
        cur.execute(sql, params)
        return cur.fetchone()[0]

    def CheckFuncErrorOnCreate(self):
        # An argument count of -100 must be refused at registration time.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def CheckFuncRefCount(self):
        # The closure from getfunc() is handed straight to create_function
        # and is not referenced anywhere else in Python. The query below
        # therefore only works if the connection keeps its own reference.
        # (Stashing the function in module globals would hide a missing
        # reference.)
        def getfunc():
            def f():
                return 1
            return f
        self.con.create_function("reftest", 0, getfunc())
        self.assertEqual(self._fetch_scalar("select reftest()"), 1)

    def CheckFuncReturnText(self):
        # A plain str result comes back as unicode.
        val = self._fetch_scalar("select returntext()")
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        val = self._fetch_scalar("select returnunicode()")
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, u"bar")

    def CheckFuncReturnInt(self):
        val = self._fetch_scalar("select returnint()")
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        val = self._fetch_scalar("select returnfloat()")
        self.assertEqual(type(val), float)
        # Same tolerance as a 3.139 <= val <= 3.141 bounds check.
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def CheckFuncReturnNull(self):
        self.assertIsNone(self._fetch_scalar("select returnnull()"))

    def CheckFuncReturnBlob(self):
        val = self._fetch_scalar("select returnblob()")
        with test_support.check_py3k_warnings():
            self.assertEqual(type(val), buffer)
            self.assertEqual(val, buffer("blob"))

    def CheckFuncReturnLongLong(self):
        # 1<<31 does not fit in a signed 32-bit integer.
        self.assertEqual(self._fetch_scalar("select returnlonglong()"), 1<<31)

    def CheckFuncException(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(cm.exception.args[0],
                         'user-defined function raised exception')

    def CheckParamString(self):
        self.assertEqual(self._fetch_scalar("select isstring(?)", ("foo",)), 1)

    def CheckParamInt(self):
        self.assertEqual(self._fetch_scalar("select isint(?)", (42,)), 1)

    def CheckParamFloat(self):
        self.assertEqual(self._fetch_scalar("select isfloat(?)", (3.14,)), 1)

    def CheckParamNone(self):
        self.assertEqual(self._fetch_scalar("select isnone(?)", (None,)), 1)

    def CheckParamBlob(self):
        with test_support.check_py3k_warnings():
            val = self._fetch_scalar("select isblob(?)", (buffer("blob"),))
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        self.assertEqual(self._fetch_scalar("select islonglong(?)", (1<<42,)), 1)
262
class AggregateTests(unittest.TestCase):
    """User-defined aggregates: error reporting and argument types.

    setUp creates a one-row table ``test`` with a text, integer, float,
    NULL and blob value. It then registers the module-level Aggr* classes
    as aggregates.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        with test_support.check_py3k_warnings():
            cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
                        ("foo", 5, 3.14, None, buffer("blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Close the connection opened in setUp so it is not leaked.
        self.con.close()

    def _fetch_scalar(self, sql, params=()):
        """Execute *sql* with *params*; return column 0 of the first row."""
        cur = self.con.cursor()
        cur.execute(sql, params)
        return cur.fetchone()[0]

    def _assert_aggregate_error(self, sql, message):
        """Check that running *sql* raises OperationalError with *message*."""
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute(sql)
            cur.fetchone()
        self.assertEqual(cm.exception.args[0], message)

    def CheckAggrErrorOnCreate(self):
        # Exercise create_aggregate itself (create_function is covered by the
        # scalar-function tests): an argument count of -100 must be refused.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    def CheckAggrNoStep(self):
        # The missing method surfaces as the Python AttributeError itself.
        cur = self.con.cursor()
        with self.assertRaises(AttributeError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(cm.exception.args[0],
                         "AggrNoStep instance has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        self._assert_aggregate_error(
            "select nofinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        self._assert_aggregate_error(
            "select excInit(t) from test",
            "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        self._assert_aggregate_error(
            "select excStep(t) from test",
            "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        self._assert_aggregate_error(
            "select excFinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        val = self._fetch_scalar("select checkType('str', ?)", ("foo",))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        val = self._fetch_scalar("select checkType('int', ?)", (42,))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamFloat(self):
        val = self._fetch_scalar("select checkType('float', ?)", (3.14,))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        val = self._fetch_scalar("select checkType('None', ?)", (None,))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        with test_support.check_py3k_warnings():
            val = self._fetch_scalar("select checkType('blob', ?)",
                                     (buffer("blob"),))
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        # Replace the fixture row with three integers and sum them.
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        self.assertEqual(self._fetch_scalar("select mysum(i) from test"), 60)
Anthony Baxterc51ee692006-04-01 00:57:31 +0000382
class AuthorizerTests(unittest.TestCase):
    """Check that an authorizer callback can block table and column reads.

    Subclasses override ``authorizer_cb`` with other ways of refusing access
    and inherit the test methods unchanged.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Deny every action other than SELECT, and deny any access whose
        # arguments name column c2 or table t2.
        # NOTE(review): because every non-SELECT action is denied, column
        # reads of any table are refused too, so the tests below would still
        # pass without the c2/t2 checks.
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test: this query touches t2.c2 before any
        # authorizer is installed, so it must succeed.
        # NOTE(review): presumably it shows that earlier unrestricted access
        # does not weaken the authorizer installed next; confirm intent.
        self.con.execute("select c2 from t2")

        # Looked up through self so subclasses can substitute their callback.
        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        # Close the connection opened in setUp so it is not leaked.
        self.con.close()

    def _assert_prohibited(self, sql):
        """Check that *sql* fails with a DatabaseError ending in "prohibited"."""
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute(sql)
        message = cm.exception.args[0]
        self.assertTrue(message.endswith("prohibited"),
                        "wrong exception text: %s" % message)

    def test_table_access(self):
        self._assert_prohibited("select * from t2")

    def test_column_access(self):
        self._assert_prohibited("select c2 from t1")
426
class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Rerun the authorizer tests with a callback that refuses by raising."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Same decisions as the base callback, but a refusal is signalled by
        # raising ValueError instead of returning SQLITE_DENY.
        refused = (action != sqlite.SQLITE_SELECT
                   or arg2 == 'c2' or arg1 == 't2')
        if refused:
            raise ValueError
        return sqlite.SQLITE_OK
435
class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Rerun the authorizer tests with a callback returning an ill-typed value.

    Refusals are signalled by returning the float 0.0 instead of an integer
    code. The inherited tests expect access to still be refused.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        refused = (action != sqlite.SQLITE_SELECT
                   or arg2 == 'c2' or arg1 == 't2')
        if refused:
            return 0.0
        return sqlite.SQLITE_OK
444
class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Rerun the authorizer tests with a callback returning an oversized int.

    Refusals are signalled by returning 2**32. The inherited tests expect
    access to still be refused.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        refused = (action != sqlite.SQLITE_SELECT
                   or arg2 == 'c2' or arg1 == 't2')
        if refused:
            return 1 << 32
        return sqlite.SQLITE_OK
453
454
def suite():
    """Collect every test case class in this module into one TestSuite.

    FunctionTests and AggregateTests name their tests Check*, so they are
    loaded with that prefix. The authorizer classes use the default prefix.
    """
    check_prefixed = [unittest.makeSuite(case, "Check")
                      for case in (FunctionTests, AggregateTests)]
    default_prefixed = [unittest.makeSuite(case)
                        for case in (AuthorizerTests,
                                     AuthorizerRaiseExceptionTests,
                                     AuthorizerIllegalTypeTests,
                                     AuthorizerLargeIntegerTests)]
    return unittest.TestSuite(check_prefixed + default_prefixed)
Anthony Baxterc51ee692006-04-01 00:57:31 +0000467
def test():
    """Run suite() with a default unittest.TextTestRunner (result not returned)."""
    unittest.TextTestRunner().run(suite())
471
if __name__ == "__main__":
    # Run the whole module's suite when executed as a script.
    test()