#-*- coding: ISO-8859-1 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
# aggregates.
#
# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty.  In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.

import unittest
import sqlite3 as sqlite

def func_returntext():
    """Return a fixed ``str`` value ("foo")."""
    return "foo"
def func_returnunicode():
    """Return a fixed ``str`` value ("bar")."""
    return "bar"
def func_returnint():
    """Return a fixed ``int`` value (42)."""
    return 42
def func_returnfloat():
    """Return a fixed ``float`` value (3.14)."""
    return 3.14
def func_returnnull():
    """Return ``None``."""
    return None
def func_returnblob():
    """Return a fixed ``bytes`` value (b"blob")."""
    return b"blob"
def func_raiseexception():
    """Always raise ``ZeroDivisionError``; never returns a value."""
    # Deliberate division by zero: the simplest way to raise from a callback.
    5/0

def func_isstring(v):
    """Return True if *v* is exactly a ``str`` (subclasses do not count)."""
    return type(v) is str
def func_isint(v):
    """Return True if *v* is exactly an ``int`` (``bool`` does not count)."""
    return type(v) is int
def func_isfloat(v):
    """Return True if *v* is exactly a ``float``."""
    return type(v) is float
def func_isnone(v):
    """Return True if *v* is ``None``."""
    # None is the only instance of its type, so an identity test is equivalent
    # to comparing types.
    return v is None
def func_isblob(v):
    """Return True if *v* is a ``bytes`` or ``memoryview`` instance."""
    return isinstance(v, (bytes, memoryview))

class AggrNoStep:
    """Aggregate class that deliberately lacks a ``step`` method."""

    def finalize(self):
        return 1

class AggrNoFinalize:
    """Aggregate class that deliberately lacks a ``finalize`` method."""

    def step(self, x):
        pass

class AggrExceptionInInit:
    """Aggregate class whose constructor always raises ``ZeroDivisionError``."""

    def __init__(self):
        # Deliberate failure during construction.
        5/0

    def step(self, x):
        pass

    def finalize(self):
        pass

class AggrExceptionInStep:
    """Aggregate class whose ``step`` always raises ``ZeroDivisionError``."""

    def step(self, x):
        # Deliberate failure while accumulating.
        5/0

    def finalize(self):
        return 42

class AggrExceptionInFinalize:
    """Aggregate class whose ``finalize`` always raises ``ZeroDivisionError``."""

    def step(self, x):
        pass

    def finalize(self):
        # Deliberate failure while producing the result.
        5/0

class AggrCheckType:
    """Aggregate that records whether a value has the type named alongside it.

    ``step(whichType, val)`` stores 1 if ``type(val)`` is exactly the type
    registered under the name *whichType*, else 0.  ``finalize`` returns the
    value stored by the most recent ``step`` call, or ``None`` if ``step`` was
    never called.  An unknown *whichType* raises ``KeyError``.
    """

    # Type names accepted as the first argument of step().
    _TYPES = {
        "str": str,
        "int": int,
        "float": float,
        "None": type(None),
        "blob": bytes,
    }

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        self.val = int(self._TYPES[whichType] is type(val))

    def finalize(self):
        return self.val

class AggrSum:
    """Aggregate that adds up every value passed to ``step``.

    The running total starts at 0.0, so the result is a float when the
    inputs are ints or floats.
    """

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val += val

    def finalize(self):
        return self.val

class FunctionTests(unittest.TestCase):
    """Tests for scalar SQL functions registered with ``create_function``.

    Test methods use the ``Check`` prefix (see ``suite()``), not ``test``.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        # Zero-argument functions: one per return type, plus one that raises.
        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        # One-argument functions that report the type of their parameter.
        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)

    def tearDown(self):
        self.con.close()

    def _scalar(self, sql, params=()):
        # Run *sql* and return the first column of the first result row.
        cur = self.con.cursor()
        cur.execute(sql, params)
        return cur.fetchone()[0]

    def CheckFuncErrorOnCreate(self):
        # -100 is not a valid argument count.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def CheckFuncRefCount(self):
        def getfunc():
            def f():
                return 1
            return f
        f = getfunc()
        # NOTE(review): storing f in module globals keeps it alive for the
        # rest of the process; kept as in the original -- confirm whether a
        # temporary (create_function("reftest", 0, getfunc())) was intended.
        globals()["foo"] = f
        self.con.create_function("reftest", 0, f)
        cur = self.con.cursor()
        cur.execute("select reftest()")

    def CheckFuncReturnText(self):
        val = self._scalar("select returntext()")
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        val = self._scalar("select returnunicode()")
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def CheckFuncReturnInt(self):
        val = self._scalar("select returnint()")
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        val = self._scalar("select returnfloat()")
        self.assertEqual(type(val), float)
        # Range check rather than equality, to tolerate float round-tripping.
        self.assertTrue(3.139 <= val <= 3.141, "wrong value")

    def CheckFuncReturnNull(self):
        val = self._scalar("select returnnull()")
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def CheckFuncReturnBlob(self):
        val = self._scalar("select returnblob()")
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def CheckFuncException(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(cm.exception.args[0],
                         'user-defined function raised exception')

    def CheckParamString(self):
        self.assertEqual(self._scalar("select isstring(?)", ("foo",)), 1)

    def CheckParamInt(self):
        self.assertEqual(self._scalar("select isint(?)", (42,)), 1)

    def CheckParamFloat(self):
        self.assertEqual(self._scalar("select isfloat(?)", (3.14,)), 1)

    def CheckParamNone(self):
        self.assertEqual(self._scalar("select isnone(?)", (None,)), 1)

    def CheckParamBlob(self):
        val = self._scalar("select isblob(?)", (memoryview(b"blob"),))
        self.assertEqual(val, 1)

class AggregateTests(unittest.TestCase):
    """Tests for SQL aggregates registered with ``create_aggregate``.

    Test methods use the ``Check`` prefix (see ``suite()``), not ``test``.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
            )
            """)
        # A single row so that each aggregate's step() runs exactly once.
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, memoryview(b"blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Release the in-memory database created in setUp.
        self.con.close()

    def _scalar(self, sql, params=()):
        # Run *sql* and return the first column of the first result row.
        cur = self.con.cursor()
        cur.execute(sql, params)
        return cur.fetchone()[0]

    def _assert_operational_error(self, sql, message):
        # Executing and fetching *sql* must raise OperationalError whose first
        # argument equals *message*.
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute(sql)
            cur.fetchone()
        self.assertEqual(cm.exception.args[0], message)

    def CheckAggrErrorOnCreate(self):
        # -100 is not a valid argument count.  This exercises
        # create_aggregate; the original called create_function here.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    def CheckAggrNoStep(self):
        cur = self.con.cursor()
        with self.assertRaises(AttributeError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(cm.exception.args[0],
                         "'AggrNoStep' object has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        self._assert_operational_error(
            "select nofinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        self._assert_operational_error(
            "select excInit(t) from test",
            "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        self._assert_operational_error(
            "select excStep(t) from test",
            "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        self._assert_operational_error(
            "select excFinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        val = self._scalar("select checkType('str', ?)", ("foo",))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        val = self._scalar("select checkType('int', ?)", (42,))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamFloat(self):
        val = self._scalar("select checkType('float', ?)", (3.14,))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        val = self._scalar("select checkType('None', ?)", (None,))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        val = self._scalar("select checkType('blob', ?)",
                           (memoryview(b"blob"),))
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)

def authorizer_cb(action, arg1, arg2, dbname, source):
    """Authorizer callback that denies everything except plain SELECTs.

    Returns ``sqlite.SQLITE_DENY`` for any *action* other than
    ``sqlite.SQLITE_SELECT``, and also when *arg2* is ``'c2'`` or *arg1* is
    ``'t2'``; otherwise returns ``sqlite.SQLITE_OK``.  *dbname* and *source*
    are accepted but not inspected.
    """
    if action != sqlite.SQLITE_SELECT:
        return sqlite.SQLITE_DENY
    # NOTE(review): column reads presumably arrive as a separate action code
    # and are already denied above, so this check may never fire -- confirm
    # against the SQLite authorizer documentation.
    if arg2 == 'c2' or arg1 == 't2':
        return sqlite.SQLITE_DENY
    return sqlite.SQLITE_OK

class AuthorizerTests(unittest.TestCase):
    """Tests that ``set_authorizer`` can block access to tables and columns.

    Test methods use the ``Check`` prefix (see ``suite()``), not ``test``.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test:
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(authorizer_cb)

    def tearDown(self):
        # Release the in-memory database created in setUp.
        self.con.close()

    def _assert_prohibited(self, sql):
        # Executing *sql* must raise DatabaseError with a message ending in
        # "prohibited".
        with self.assertRaises(
                sqlite.DatabaseError,
                msg="should have raised an exception due to missing privileges"
        ) as cm:
            self.con.execute(sql)
        text = cm.exception.args[0]
        if not text.endswith("prohibited"):
            self.fail("wrong exception text: %s" % text)

    def CheckTableAccess(self):
        self._assert_prohibited("select * from t2")

    def CheckColumnAccess(self):
        self._assert_prohibited("select c2 from t1")

def suite():
    """Build a ``TestSuite`` with every ``Check*`` method of the test classes.

    Uses a ``TestLoader`` with a custom method prefix rather than
    ``unittest.makeSuite``, which was removed in Python 3.13.
    """
    loader = unittest.TestLoader()
    # The test methods in this file are named Check*, not test*.
    loader.testMethodPrefix = "Check"
    function_suite = loader.loadTestsFromTestCase(FunctionTests)
    aggregate_suite = loader.loadTestsFromTestCase(AggregateTests)
    authorizer_suite = loader.loadTestsFromTestCase(AuthorizerTests)
    return unittest.TestSuite((function_suite, aggregate_suite, authorizer_suite))

def test():
    """Run the whole suite with a plain-text test runner."""
    runner = unittest.TextTestRunner()
    runner.run(suite())

if __name__ == "__main__":
    # Run the suite when this file is executed as a script.
    test()