blob: 587d39c876ba5d6c5bd90e56c5da3c57382bb75f [file] [log] [blame]
Anthony Baxterc51ee692006-04-01 00:57:31 +00001#-*- coding: ISO-8859-1 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
# aggregates.
#
# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
# appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
24
25import unittest
26import sqlite3 as sqlite
27
def func_returntext():
    """Return the plain string "foo" (registered as SQL function returntext())."""
    text = "foo"
    return text
def func_returnunicode():
    """Return the unicode string u"bar" (registered as returnunicode())."""
    value = u"bar"
    return value
def func_returnint():
    """Return the integer 42 (registered as returnint())."""
    answer = 42
    return answer
def func_returnfloat():
    """Return the float 3.14 (registered as returnfloat())."""
    approx_pi = 3.14
    return approx_pi
def func_returnnull():
    """Return None, which SQLite receives as NULL (registered as returnnull())."""
    return
def func_returnblob():
    """Return a buffer wrapping "blob" (registered as returnblob())."""
    payload = "blob"
    return buffer(payload)
def func_raiseexception():
    """Always raise ZeroDivisionError (registered as raiseexception()).

    Used to check how errors raised inside a user-defined function are
    surfaced to the caller of the SQL statement.
    """
    # The division raises before anything can be returned.
    return 5 / 0
42
def func_isstring(v):
    """Return True iff v's exact type is unicode (subclasses do not count)."""
    return unicode is type(v)
def func_isint(v):
    """Return True iff v's exact type is int (bool and other subclasses excluded)."""
    return int is type(v)
def func_isfloat(v):
    """Return True iff v's exact type is float."""
    return float is type(v)
def func_isnone(v):
    """Return True iff v is None (the only instance of NoneType)."""
    return v is None
def func_isblob(v):
    """Return True iff v's exact type is buffer (how blobs are passed in here)."""
    return buffer is type(v)
53
class AggrNoStep:
    """Aggregate that deliberately has no step() method.

    Registered as the "nostep" aggregate. Kept as a classic class because
    the test suite compares against the classic-class AttributeError text.
    """

    def __init__(self):
        """Nothing to initialise."""

    def finalize(self):
        """Return the constant 1."""
        result = 1
        return result
60
class AggrNoFinalize:
    """Aggregate that deliberately has no finalize() method ("nofinalize")."""

    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        """Ignore each row value."""
        return None
67
class AggrExceptionInInit:
    """Aggregate whose constructor always raises ZeroDivisionError ("excInit")."""

    def __init__(self):
        # Deliberate failure: raised before any state is created.
        5 / 0

    def step(self, x):
        """Ignore each row value (never reached in practice)."""

    def finalize(self):
        """Return None (never reached in practice)."""
77
class AggrExceptionInStep:
    """Aggregate whose step() always raises ZeroDivisionError ("excStep")."""

    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        # Deliberate failure on every row.
        return 5 / 0

    def finalize(self):
        """Return 42; only reachable if step() was never called."""
        answer = 42
        return answer
87
class AggrExceptionInFinalize:
    """Aggregate whose finalize() always raises ZeroDivisionError ("excFinalize")."""

    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        """Ignore each row value."""

    def finalize(self):
        # Deliberate failure when the aggregate result is requested.
        return 5 / 0
97
class AggrCheckType:
    """Aggregate reporting whether a value has the Python type named by a tag.

    step(whichType, val) maps whichType ("str", "int", "float", "None" or
    "blob") to a Python type. It records 1 if type(val) is exactly that
    type and 0 otherwise. An unknown tag raises KeyError. Each call
    overwrites the previous result. finalize() returns the last recorded
    value, or None if step() never ran.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        expected = {
            "str": unicode,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": buffer,
        }[whichType]
        if type(val) is expected:
            self.val = 1
        else:
            self.val = 0

    def finalize(self):
        return self.val
108
class AggrSum:
    """Aggregate that adds up its inputs ("mysum"), starting from 0.0."""

    def __init__(self):
        # Float seed: the running total starts as 0.0, not 0.
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        return self.val
118
119class FunctionTests(unittest.TestCase):
120 def setUp(self):
121 self.con = sqlite.connect(":memory:")
122
123 self.con.create_function("returntext", 0, func_returntext)
124 self.con.create_function("returnunicode", 0, func_returnunicode)
125 self.con.create_function("returnint", 0, func_returnint)
126 self.con.create_function("returnfloat", 0, func_returnfloat)
127 self.con.create_function("returnnull", 0, func_returnnull)
128 self.con.create_function("returnblob", 0, func_returnblob)
129 self.con.create_function("raiseexception", 0, func_raiseexception)
130
131 self.con.create_function("isstring", 1, func_isstring)
132 self.con.create_function("isint", 1, func_isint)
133 self.con.create_function("isfloat", 1, func_isfloat)
134 self.con.create_function("isnone", 1, func_isnone)
135 self.con.create_function("isblob", 1, func_isblob)
136
137 def tearDown(self):
138 self.con.close()
139
Gerhard Häring3e99c0a2006-04-23 15:24:26 +0000140 def CheckFuncErrorOnCreate(self):
141 try:
142 self.con.create_function("bla", -100, lambda x: 2*x)
143 self.fail("should have raised an OperationalError")
144 except sqlite.OperationalError:
145 pass
146
Anthony Baxterc51ee692006-04-01 00:57:31 +0000147 def CheckFuncRefCount(self):
148 def getfunc():
149 def f():
Gerhard Häring1541ef02006-06-13 22:24:47 +0000150 return 1
Anthony Baxterc51ee692006-04-01 00:57:31 +0000151 return f
Gerhard Häring1541ef02006-06-13 22:24:47 +0000152 f = getfunc()
153 globals()["foo"] = f
154 # self.con.create_function("reftest", 0, getfunc())
155 self.con.create_function("reftest", 0, f)
Anthony Baxterc51ee692006-04-01 00:57:31 +0000156 cur = self.con.cursor()
157 cur.execute("select reftest()")
158
159 def CheckFuncReturnText(self):
160 cur = self.con.cursor()
161 cur.execute("select returntext()")
162 val = cur.fetchone()[0]
163 self.failUnlessEqual(type(val), unicode)
164 self.failUnlessEqual(val, "foo")
165
166 def CheckFuncReturnUnicode(self):
167 cur = self.con.cursor()
168 cur.execute("select returnunicode()")
169 val = cur.fetchone()[0]
170 self.failUnlessEqual(type(val), unicode)
171 self.failUnlessEqual(val, u"bar")
172
173 def CheckFuncReturnInt(self):
174 cur = self.con.cursor()
175 cur.execute("select returnint()")
176 val = cur.fetchone()[0]
177 self.failUnlessEqual(type(val), int)
178 self.failUnlessEqual(val, 42)
179
180 def CheckFuncReturnFloat(self):
181 cur = self.con.cursor()
182 cur.execute("select returnfloat()")
183 val = cur.fetchone()[0]
184 self.failUnlessEqual(type(val), float)
185 if val < 3.139 or val > 3.141:
186 self.fail("wrong value")
187
188 def CheckFuncReturnNull(self):
189 cur = self.con.cursor()
190 cur.execute("select returnnull()")
191 val = cur.fetchone()[0]
192 self.failUnlessEqual(type(val), type(None))
193 self.failUnlessEqual(val, None)
194
195 def CheckFuncReturnBlob(self):
196 cur = self.con.cursor()
197 cur.execute("select returnblob()")
198 val = cur.fetchone()[0]
199 self.failUnlessEqual(type(val), buffer)
200 self.failUnlessEqual(val, buffer("blob"))
201
202 def CheckFuncException(self):
Gerhard Häringb2e88162006-06-14 22:28:37 +0000203 if sqlite.version_info < (3, 3, 3): # don't raise bug in earlier SQLite versions
204 return
Anthony Baxterc51ee692006-04-01 00:57:31 +0000205 cur = self.con.cursor()
Gerhard Häring1541ef02006-06-13 22:24:47 +0000206 try:
207 cur.execute("select raiseexception()")
208 cur.fetchone()
209 self.fail("should have raised OperationalError")
210 except sqlite.OperationalError, e:
211 self.failUnlessEqual(e.args[0], 'user-defined function raised exception')
Anthony Baxterc51ee692006-04-01 00:57:31 +0000212
213 def CheckParamString(self):
214 cur = self.con.cursor()
215 cur.execute("select isstring(?)", ("foo",))
216 val = cur.fetchone()[0]
217 self.failUnlessEqual(val, 1)
218
219 def CheckParamInt(self):
220 cur = self.con.cursor()
221 cur.execute("select isint(?)", (42,))
222 val = cur.fetchone()[0]
223 self.failUnlessEqual(val, 1)
224
225 def CheckParamFloat(self):
226 cur = self.con.cursor()
227 cur.execute("select isfloat(?)", (3.14,))
228 val = cur.fetchone()[0]
229 self.failUnlessEqual(val, 1)
230
231 def CheckParamNone(self):
232 cur = self.con.cursor()
233 cur.execute("select isnone(?)", (None,))
234 val = cur.fetchone()[0]
235 self.failUnlessEqual(val, 1)
236
237 def CheckParamBlob(self):
238 cur = self.con.cursor()
239 cur.execute("select isblob(?)", (buffer("blob"),))
240 val = cur.fetchone()[0]
241 self.failUnlessEqual(val, 1)
242
243class AggregateTests(unittest.TestCase):
244 def setUp(self):
245 self.con = sqlite.connect(":memory:")
246 cur = self.con.cursor()
247 cur.execute("""
248 create table test(
249 t text,
250 i integer,
251 f float,
252 n,
253 b blob
254 )
255 """)
256 cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
257 ("foo", 5, 3.14, None, buffer("blob"),))
258
259 self.con.create_aggregate("nostep", 1, AggrNoStep)
260 self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
261 self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
262 self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
263 self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
264 self.con.create_aggregate("checkType", 2, AggrCheckType)
265 self.con.create_aggregate("mysum", 1, AggrSum)
266
267 def tearDown(self):
268 #self.cur.close()
269 #self.con.close()
270 pass
271
Gerhard Häring3e99c0a2006-04-23 15:24:26 +0000272 def CheckAggrErrorOnCreate(self):
273 try:
274 self.con.create_function("bla", -100, AggrSum)
275 self.fail("should have raised an OperationalError")
276 except sqlite.OperationalError:
277 pass
278
Anthony Baxterc51ee692006-04-01 00:57:31 +0000279 def CheckAggrNoStep(self):
280 cur = self.con.cursor()
Gerhard Häring1541ef02006-06-13 22:24:47 +0000281 try:
282 cur.execute("select nostep(t) from test")
283 self.fail("should have raised an AttributeError")
284 except AttributeError, e:
285 self.failUnlessEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")
Anthony Baxterc51ee692006-04-01 00:57:31 +0000286
287 def CheckAggrNoFinalize(self):
Gerhard Häringb2e88162006-06-14 22:28:37 +0000288 if sqlite.version_info < (3, 3, 3): # don't raise bug in earlier SQLite versions
289 return
Anthony Baxterc51ee692006-04-01 00:57:31 +0000290 cur = self.con.cursor()
Gerhard Häring1541ef02006-06-13 22:24:47 +0000291 try:
292 cur.execute("select nofinalize(t) from test")
293 val = cur.fetchone()[0]
294 self.fail("should have raised an OperationalError")
295 except sqlite.OperationalError, e:
296 self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
Anthony Baxterc51ee692006-04-01 00:57:31 +0000297
298 def CheckAggrExceptionInInit(self):
Gerhard Häringb2e88162006-06-14 22:28:37 +0000299 if sqlite.version_info < (3, 3, 3): # don't raise bug in earlier SQLite versions
300 return
Anthony Baxterc51ee692006-04-01 00:57:31 +0000301 cur = self.con.cursor()
Gerhard Häring1541ef02006-06-13 22:24:47 +0000302 try:
303 cur.execute("select excInit(t) from test")
304 val = cur.fetchone()[0]
305 self.fail("should have raised an OperationalError")
306 except sqlite.OperationalError, e:
307 self.failUnlessEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")
Anthony Baxterc51ee692006-04-01 00:57:31 +0000308
309 def CheckAggrExceptionInStep(self):
Gerhard Häringb2e88162006-06-14 22:28:37 +0000310 if sqlite.version_info < (3, 3, 3): # don't raise bug in earlier SQLite versions
311 return
Anthony Baxterc51ee692006-04-01 00:57:31 +0000312 cur = self.con.cursor()
Gerhard Häring1541ef02006-06-13 22:24:47 +0000313 try:
314 cur.execute("select excStep(t) from test")
315 val = cur.fetchone()[0]
316 self.fail("should have raised an OperationalError")
317 except sqlite.OperationalError, e:
318 self.failUnlessEqual(e.args[0], "user-defined aggregate's 'step' method raised error")
Anthony Baxterc51ee692006-04-01 00:57:31 +0000319
320 def CheckAggrExceptionInFinalize(self):
Gerhard Häringb2e88162006-06-14 22:28:37 +0000321 if sqlite.version_info < (3, 3, 3): # don't raise bug in earlier SQLite versions
322 return
Anthony Baxterc51ee692006-04-01 00:57:31 +0000323 cur = self.con.cursor()
Gerhard Häring1541ef02006-06-13 22:24:47 +0000324 try:
325 cur.execute("select excFinalize(t) from test")
326 val = cur.fetchone()[0]
327 self.fail("should have raised an OperationalError")
328 except sqlite.OperationalError, e:
329 self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
Anthony Baxterc51ee692006-04-01 00:57:31 +0000330
331 def CheckAggrCheckParamStr(self):
332 cur = self.con.cursor()
333 cur.execute("select checkType('str', ?)", ("foo",))
334 val = cur.fetchone()[0]
335 self.failUnlessEqual(val, 1)
336
337 def CheckAggrCheckParamInt(self):
338 cur = self.con.cursor()
339 cur.execute("select checkType('int', ?)", (42,))
340 val = cur.fetchone()[0]
341 self.failUnlessEqual(val, 1)
342
343 def CheckAggrCheckParamFloat(self):
344 cur = self.con.cursor()
345 cur.execute("select checkType('float', ?)", (3.14,))
346 val = cur.fetchone()[0]
347 self.failUnlessEqual(val, 1)
348
349 def CheckAggrCheckParamNone(self):
350 cur = self.con.cursor()
351 cur.execute("select checkType('None', ?)", (None,))
352 val = cur.fetchone()[0]
353 self.failUnlessEqual(val, 1)
354
355 def CheckAggrCheckParamBlob(self):
356 cur = self.con.cursor()
357 cur.execute("select checkType('blob', ?)", (buffer("blob"),))
358 val = cur.fetchone()[0]
359 self.failUnlessEqual(val, 1)
360
361 def CheckAggrCheckAggrSum(self):
362 cur = self.con.cursor()
363 cur.execute("delete from test")
364 cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
365 cur.execute("select mysum(i) from test")
366 val = cur.fetchone()[0]
367 self.failUnlessEqual(val, 60)
368
def authorizer_cb(action, arg1, arg2, dbname, source):
    """Authorizer callback used by AuthorizerTests.

    Returns SQLITE_DENY for any action other than SQLITE_SELECT, and for
    any call where arg2 is 'c2' or arg1 is 't2'. Every other call gets
    SQLITE_OK. dbname and source are ignored.
    """
    denied = (action != sqlite.SQLITE_SELECT
              or arg2 == 'c2'
              or arg1 == 't2')
    if denied:
        return sqlite.SQLITE_DENY
    return sqlite.SQLITE_OK
375
376class AuthorizerTests(unittest.TestCase):
377 def setUp(self):
378 sqlite.enable_callback_tracebacks(1)
379 self.con = sqlite.connect(":memory:")
380 self.con.executescript("""
381 create table t1 (c1, c2);
382 create table t2 (c1, c2);
383 insert into t1 (c1, c2) values (1, 2);
384 insert into t2 (c1, c2) values (4, 5);
385 """)
386
387 # For our security test:
388 self.con.execute("select c2 from t2")
389
390 self.con.set_authorizer(authorizer_cb)
391
392 def tearDown(self):
393 pass
394
395 def CheckTableAccess(self):
396 try:
397 self.con.execute("select * from t2")
398 except sqlite.DatabaseError, e:
399 if not e.args[0].endswith("prohibited"):
400 self.fail("wrong exception text: %s" % e.args[0])
401 return
402 self.fail("should have raised an exception due to missing privileges")
403
404 def CheckColumnAccess(self):
405 try:
406 self.con.execute("select c2 from t1")
407 except sqlite.DatabaseError, e:
408 if not e.args[0].endswith("prohibited"):
409 self.fail("wrong exception text: %s" % e.args[0])
410 return
411 self.fail("should have raised an exception due to missing privileges")
412
def suite():
    """Return a TestSuite containing every Check* method of the three test classes."""
    test_classes = (FunctionTests, AggregateTests, AuthorizerTests)
    return unittest.TestSuite(
        tuple(unittest.makeSuite(cls, "Check") for cls in test_classes))
Anthony Baxterc51ee692006-04-01 00:57:31 +0000418
def test():
    """Run suite() with a plain-text test runner."""
    unittest.TextTestRunner().run(suite())
422
if "__main__" == __name__:
    # Executed as a script rather than imported: run the whole suite.
    test()