blob: 78656e7e337d2ac05c6805a2e4aeeb8c600b3151 [file] [log] [blame]
#-*- coding: ISO-8859-1 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
# aggregates.
#
# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
24
25import unittest
26import sqlite3 as sqlite
27
def func_returntext():
    """Return the plain ``str`` literal "foo" (text-result fixture)."""
    result = "foo"
    return result
def func_returnunicode():
    """Return the unicode literal u"bar" (unicode-result fixture)."""
    value = u"bar"
    return value
def func_returnint():
    """Return the integer 42 (integer-result fixture)."""
    answer = 42
    return answer
def func_returnfloat():
    """Return the float 3.14 (float-result fixture)."""
    approx_pi = 3.14
    return approx_pi
def func_returnnull():
    """Return None, which the NULL-result test expects back from SQL."""
    return
def func_returnblob():
    """Return a ``buffer`` wrapping "blob" (Python 2 ``buffer`` builtin).

    CheckFuncReturnBlob expects this value back as a ``buffer``.
    """
    payload = "blob"
    return buffer(payload)
def func_raiseexception():
    """Deliberately raise ZeroDivisionError to exercise error propagation."""
    numerator, denominator = 5, 0
    numerator / denominator
42
def func_isstring(v):
    """Return True iff *v* is exactly of type ``unicode`` (no subclasses)."""
    actual = type(v)
    return actual is unicode
def func_isint(v):
    """Return True iff *v* is exactly of type ``int`` (``bool`` does not count)."""
    actual = type(v)
    return actual is int
def func_isfloat(v):
    """Return True iff *v* is exactly of type ``float``."""
    actual = type(v)
    return actual is float
def func_isnone(v):
    """Return True iff *v* is None (the only instance of NoneType)."""
    return v is None
def func_isblob(v):
    """Return True iff *v* is exactly of type ``buffer`` (Python 2 builtin)."""
    actual = type(v)
    return actual is buffer
53
class AggrNoStep:
    """Aggregate class that deliberately defines neither step() nor finalize()."""

    def __init__(self):
        # No state to set up; the missing methods are what CheckAggrNoStep uses.
        return None
57
class AggrNoFinalize:
    """Aggregate class with a no-op step() but deliberately no finalize()."""

    def __init__(self):
        return None

    def step(self, x):
        # Accept and ignore each value.
        return None
64
class AggrExceptionInInit:
    """Aggregate whose constructor raises ZeroDivisionError.

    step() and finalize() are no-ops; they exist so only __init__ fails.
    """

    def __init__(self):
        zero = 0
        5 / zero

    def finalize(self):
        return None

    def step(self, x):
        return None
74
class AggrExceptionInStep:
    """Aggregate whose step() raises ZeroDivisionError but whose finalize() returns 42."""

    def __init__(self):
        return None

    def finalize(self):
        return 42

    def step(self, x):
        zero = 0
        5 / zero
84
class AggrExceptionInFinalize:
    """Aggregate whose step() is a no-op and whose finalize() raises ZeroDivisionError."""

    def __init__(self):
        return None

    def finalize(self):
        zero = 0
        5 / zero

    def step(self, x):
        return None
94
class AggrCheckType:
    """Aggregate that reports whether its argument has an expected Python type.

    step(whichType, val) stores 1 when type(val) is exactly the type named by
    whichType ("str", "int", "float", "None" or "blob"), otherwise 0; an
    unknown name raises KeyError. finalize() returns the value stored by the
    most recent step() call, or None if step() never ran.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        expected = {
            "str": unicode,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": buffer,
        }[whichType]
        if type(val) is expected:
            self.val = 1
        else:
            self.val = 0

    def finalize(self):
        return self.val
105
class AggrSum:
    """Aggregate that adds up every value passed to step(), starting from 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        return self.val
115
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions registered with create_function()."""

    def setUp(self):
        # Each test gets a fresh in-memory database with every helper
        # function registered under the name its SQL uses.
        self.con = sqlite.connect(":memory:")

        # Zero-argument functions that each return one kind of value.
        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        # One-argument predicates that report the Python type they received.
        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)

    def tearDown(self):
        self.con.close()

    def CheckFuncErrorOnCreate(self):
        # -100 is not an accepted argument count, so registration is
        # expected to raise OperationalError.
        try:
            self.con.create_function("bla", -100, lambda x: 2*x)
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError:
            pass

    def CheckFuncRefCount(self):
        # The callable produced by getfunc() has no reference other than the
        # one handed to create_function(), so invoking it afterwards checks
        # that the connection keeps the function object alive.
        def getfunc():
            def f():
                # Must return a concrete value: referencing an undefined name
                # here would raise NameError, so the call never succeeds.
                return 1
            return f
        self.con.create_function("reftest", 0, getfunc())
        cur = self.con.cursor()
        cur.execute("select reftest()")
        self.assertEqual(cur.fetchone()[0], 1)

    def CheckFuncReturnText(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, u"bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        # Tolerance window around 3.14 instead of exact float equality.
        if val < 3.139 or val > 3.141:
            self.fail("wrong value")

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), buffer)
        self.assertEqual(val, buffer("blob"))

    def CheckFuncException(self):
        # With the behavior this test expects, an exception raised inside the
        # user function surfaces as a NULL result rather than an error.
        cur = self.con.cursor()
        cur.execute("select raiseexception()")
        val = cur.fetchone()[0]
        self.assertEqual(val, None)

    def CheckParamString(self):
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select isblob(?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)
230 self.failUnlessEqual(val, 1)
231
class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates registered with create_aggregate()."""

    def setUp(self):
        # Fresh in-memory database holding one row (text, integer, float,
        # NULL, blob) plus every helper aggregate registered by name.
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, buffer("blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Close the connection opened in setUp so it is not leaked per test.
        self.con.close()

    def CheckAggrErrorOnCreate(self):
        # -100 is not an accepted argument count, so registering an aggregate
        # with it is expected to raise OperationalError. This goes through
        # create_aggregate() so the aggregate registration path is tested.
        try:
            self.con.create_aggregate("bla", -100, AggrSum)
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError:
            pass

    def CheckAggrNoStep(self):
        # Only checks that executing the query does not raise; no result is
        # asserted.
        cur = self.con.cursor()
        cur.execute("select nostep(t) from test")

    def CheckAggrNoFinalize(self):
        cur = self.con.cursor()
        cur.execute("select nofinalize(t) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, None)

    def CheckAggrExceptionInInit(self):
        cur = self.con.cursor()
        cur.execute("select excInit(t) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, None)

    def CheckAggrExceptionInStep(self):
        # step() raises, yet finalize()'s 42 is still expected as the result.
        cur = self.con.cursor()
        cur.execute("select excStep(t) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 42)

    def CheckAggrExceptionInFinalize(self):
        cur = self.con.cursor()
        cur.execute("select excFinalize(t) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, None)

    def CheckAggrCheckParamStr(self):
        cur = self.con.cursor()
        cur.execute("select checkType('str', ?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        # Replace the fixture row with three integers and sum them.
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)
332 self.failUnlessEqual(val, 60)
333
def suite():
    """Return a TestSuite containing every ``Check*`` method of both test classes.

    Uses TestLoader directly instead of unittest.makeSuite, which is deprecated
    in Python 3.11 and removed in 3.13; the result is equivalent.
    """
    loader = unittest.TestLoader()
    # Test methods here are prefixed "Check" rather than unittest's default
    # "test", so tell the loader which prefix to collect.
    loader.testMethodPrefix = "Check"
    function_suite = loader.loadTestsFromTestCase(FunctionTests)
    aggregate_suite = loader.loadTestsFromTestCase(AggregateTests)
    return unittest.TestSuite((function_suite, aggregate_suite))
337 return unittest.TestSuite((function_suite, aggregate_suite))
338
def test():
    """Run the combined suite from suite() with a plain-text test runner."""
    unittest.TextTestRunner().run(suite())
341 runner.run(suite())
342
if __name__ == "__main__":
    # Allow running this module directly as a script.
    test()