blob: a2f2574e7c55e70343622928cd1cd032868b2a5a [file] [log] [blame]
#-*- coding: ISO-8859-1 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
# aggregates.
#
# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
24
25import unittest
26import sqlite3 as sqlite
27
def func_returntext():
    """Return a fixed text value, used as a zero-argument SQL function."""
    text = "foo"
    return text
def func_returnunicode():
    """Return a fixed str value, used as a zero-argument SQL function."""
    text = "bar"
    return text
def func_returnint():
    """Return a fixed integer, used as a zero-argument SQL function."""
    number = 42
    return number
def func_returnfloat():
    """Return a fixed float, used as a zero-argument SQL function."""
    number = 3.14
    return number
def func_returnnull():
    """Return None, used as a zero-argument SQL function."""
    nothing = None
    return nothing
def func_returnblob():
    """Return a fixed binary value, used as a zero-argument SQL function.

    The value is a ``bytes`` object: the ``buffer`` builtin no longer
    exists in Python 3, so calling it here would raise NameError.
    """
    return b"blob"
def func_raiseexception():
    """Always raise ZeroDivisionError; used to test error propagation."""
    numerator, denominator = 5, 0
    numerator / denominator
42
def func_isstring(v):
    """Return True when *v* is exactly a ``str`` (subclasses do not count)."""
    actual_type = type(v)
    return actual_type is str
def func_isint(v):
    """Return True when *v* is exactly an ``int`` (subclasses do not count)."""
    actual_type = type(v)
    return actual_type is int
def func_isfloat(v):
    """Return True when *v* is exactly a ``float`` (subclasses do not count)."""
    actual_type = type(v)
    return actual_type is float
def func_isnone(v):
    """Return True when *v* is None."""
    # None is the only instance of its type, so an identity test is
    # equivalent to comparing type(v) with type(None).
    return v is None
def func_isblob(v):
    """Return True when *v* is exactly a ``bytes`` object.

    The check is against ``bytes`` because the ``buffer`` builtin does not
    exist in Python 3; referencing it would raise NameError on every call.
    """
    return type(v) is bytes
53
class AggrNoStep:
    """Aggregate that deliberately lacks a ``step`` method."""

    def __init__(self):
        """Hold no state."""

    def finalize(self):
        """Return the constant 1."""
        result = 1
        return result
60
class AggrNoFinalize:
    """Aggregate that deliberately lacks a ``finalize`` method."""

    def __init__(self):
        """Hold no state."""

    def step(self, x):
        """Accept one value and ignore it."""
        return None
67
class AggrExceptionInInit:
    """Aggregate whose constructor always raises ZeroDivisionError."""

    def __init__(self):
        numerator, denominator = 5, 0
        numerator / denominator

    def step(self, x):
        """Accept one value and ignore it."""
        return None

    def finalize(self):
        """Return None."""
        return None
77
class AggrExceptionInStep:
    """Aggregate whose ``step`` always raises ZeroDivisionError."""

    def __init__(self):
        """Hold no state."""

    def step(self, x):
        numerator, denominator = 5, 0
        numerator / denominator

    def finalize(self):
        """Return the constant 42."""
        result = 42
        return result
87
class AggrExceptionInFinalize:
    """Aggregate whose ``finalize`` always raises ZeroDivisionError."""

    def __init__(self):
        """Hold no state."""

    def step(self, x):
        """Accept one value and ignore it."""
        return None

    def finalize(self):
        numerator, denominator = 5, 0
        numerator / denominator
97
class AggrCheckType:
    """Aggregate that reports whether a value has the named Python type.

    ``step`` takes a type name and a value; ``finalize`` returns 1 when the
    value passed to the most recent ``step`` was exactly of the named type,
    0 when it was not, and None if ``step`` was never called.
    """

    # Maps the names accepted by step() to the exact types they stand for.
    # "blob" maps to bytes: the ``buffer`` builtin does not exist in
    # Python 3, and naming it made every step() call raise NameError.
    _TYPES = {
        "str": str,
        "int": int,
        "float": float,
        "None": type(None),
        "blob": bytes,
    }

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        """Record 1 if type(val) is the type named by *whichType*, else 0.

        Raises KeyError when *whichType* is not a known type name.
        """
        self.val = int(self._TYPES[whichType] is type(val))

    def finalize(self):
        """Return the result recorded by the last ``step`` call."""
        return self.val
108
class AggrSum:
    """Aggregate that adds up every value it is given, starting from 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        """Add *val* to the running total."""
        running_total = self.val
        running_total += val
        self.val = running_total

    def finalize(self):
        """Return the accumulated total."""
        total = self.val
        return total
118
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined SQL functions.

    Test methods carry the ``Check`` prefix (not ``test``); they are
    collected by a loader configured with that prefix.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        # Zero-argument functions, one per return type under test.
        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        # One-argument predicates, one per parameter type under test.
        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)

    def tearDown(self):
        self.con.close()

    def _scalar(self, sql, params=()):
        # Run a query and return the first column of its first row.
        cur = self.con.cursor()
        cur.execute(sql, params)
        return cur.fetchone()[0]

    def CheckFuncErrorOnCreate(self):
        # -100 is not a valid argument count, so registration must fail.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def CheckFuncRefCount(self):
        def getfunc():
            def f():
                return 1
            return f
        f = getfunc()
        # NOTE(review): publishing f in the module globals looks like a
        # debugging leftover; kept so the test's side effects are unchanged.
        globals()["foo"] = f
        self.con.create_function("reftest", 0, f)
        cur = self.con.cursor()
        cur.execute("select reftest()")

    def CheckFuncReturnText(self):
        val = self._scalar("select returntext()")
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        val = self._scalar("select returnunicode()")
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def CheckFuncReturnInt(self):
        val = self._scalar("select returnint()")
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        val = self._scalar("select returnfloat()")
        self.assertEqual(type(val), float)
        # Allow a small tolerance around 3.14 rather than exact equality.
        self.assertTrue(3.139 <= val <= 3.141, "wrong value")

    def CheckFuncReturnNull(self):
        val = self._scalar("select returnnull()")
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def CheckFuncReturnBlob(self):
        # Binary values are bytes; the ``buffer`` builtin is gone in Python 3.
        val = self._scalar("select returnblob()")
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def CheckFuncException(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as caught:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(caught.exception.args[0],
                         'user-defined function raised exception')

    def CheckParamString(self):
        self.assertEqual(self._scalar("select isstring(?)", ("foo",)), 1)

    def CheckParamInt(self):
        self.assertEqual(self._scalar("select isint(?)", (42,)), 1)

    def CheckParamFloat(self):
        self.assertEqual(self._scalar("select isfloat(?)", (3.14,)), 1)

    def CheckParamNone(self):
        self.assertEqual(self._scalar("select isnone(?)", (None,)), 1)

    def CheckParamBlob(self):
        self.assertEqual(self._scalar("select isblob(?)", (b"blob",)), 1)
240
class AggregateTests(unittest.TestCase):
    """Tests for user-defined SQL aggregates.

    Test methods carry the ``Check`` prefix (not ``test``); they are
    collected by a loader configured with that prefix.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        # One row holding one value of each storage type. The blob is
        # bytes: the ``buffer`` builtin does not exist in Python 3.
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, b"blob",))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Release the connection opened in setUp so it does not leak
        # from one test to the next.
        self.con.close()

    def _scalar(self, sql, params=()):
        # Run a query and return the first column of its first row.
        cur = self.con.cursor()
        cur.execute(sql, params)
        return cur.fetchone()[0]

    def _check_operational_error(self, sql, expected_message):
        # Running *sql* must raise OperationalError with exactly this text.
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as caught:
            cur.execute(sql)
            cur.fetchone()
        self.assertEqual(caught.exception.args[0], expected_message)

    def CheckAggrErrorOnCreate(self):
        # -100 is not a valid argument count. This registers an aggregate
        # (create_aggregate); calling create_function here would not test
        # aggregate registration at all.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    def CheckAggrNoStep(self):
        # NOTE(review): the exception type and text for a missing step()
        # may differ between interpreter versions -- confirm on the target.
        cur = self.con.cursor()
        with self.assertRaises(AttributeError) as caught:
            cur.execute("select nostep(t) from test")
        self.assertEqual(caught.exception.args[0],
                         "'AggrNoStep' object has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        self._check_operational_error(
            "select nofinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        self._check_operational_error(
            "select excInit(t) from test",
            "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        self._check_operational_error(
            "select excStep(t) from test",
            "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        self._check_operational_error(
            "select excFinalize(t) from test",
            "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        val = self._scalar("select checkType('str', ?)", ("foo",))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        val = self._scalar("select checkType('int', ?)", (42,))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamFloat(self):
        val = self._scalar("select checkType('float', ?)", (3.14,))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        val = self._scalar("select checkType('None', ?)", (None,))
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        val = self._scalar("select checkType('blob', ?)", (b"blob",))
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        # Start from an empty table so only the three rows below are summed.
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)
358
def authorizer_cb(action, arg1, arg2, dbname, source):
    """Authorizer callback that only permits a restricted set of SELECTs.

    Returns ``sqlite.SQLITE_DENY`` when *action* is not
    ``sqlite.SQLITE_SELECT``, when *arg2* equals 'c2', or when *arg1*
    equals 't2'; otherwise returns ``sqlite.SQLITE_OK``. *dbname* and
    *source* are accepted but not inspected.
    """
    denied = (
        action != sqlite.SQLITE_SELECT
        or arg2 == 'c2'
        or arg1 == 't2'
    )
    return sqlite.SQLITE_DENY if denied else sqlite.SQLITE_OK
365
class AuthorizerTests(unittest.TestCase):
    """Tests that an installed authorizer callback blocks denied access.

    Test methods carry the ``Check`` prefix (not ``test``); they are
    collected by a loader configured with that prefix.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test:
        self.con.execute("select c2 from t2")

        # Installed last, after the tables are created and queried once.
        self.con.set_authorizer(authorizer_cb)

    def tearDown(self):
        # Release the connection opened in setUp so it does not leak
        # from one test to the next.
        self.con.close()

    def _check_prohibited(self, sql):
        # Running *sql* must raise a DatabaseError whose message ends
        # with "prohibited".
        try:
            self.con.execute(sql)
        except sqlite.DatabaseError as e:
            message = e.args[0]
        else:
            self.fail("should have raised an exception due to missing privileges")
        if not message.endswith("prohibited"):
            self.fail("wrong exception text: %s" % message)

    def CheckTableAccess(self):
        self._check_prohibited("select * from t2")

    def CheckColumnAccess(self):
        self._check_prohibited("select c2 from t1")
401
def suite():
    """Return a TestSuite holding every ``Check*`` test in this module.

    Uses a ``unittest.TestLoader`` with ``testMethodPrefix`` set to
    "Check" instead of ``unittest.makeSuite``, which is deprecated and
    no longer available in recent Python versions.
    """
    loader = unittest.TestLoader()
    loader.testMethodPrefix = "Check"
    test_cases = (FunctionTests, AggregateTests, AuthorizerTests)
    return unittest.TestSuite(
        loader.loadTestsFromTestCase(case) for case in test_cases
    )
Thomas Wouters49fd7fa2006-04-21 10:40:58 +0000407
def test():
    """Run the module's full suite with a plain-text test runner."""
    unittest.TextTestRunner().run(suite())
411
# Run the whole suite when this file is executed as a script.
if __name__ == "__main__":
    test()