| #-*- coding: ISO-8859-1 -*- | 
 | # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks | 
 | # | 
 | # Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de> | 
 | # | 
 | # This file is part of pysqlite. | 
 | # | 
 | # This software is provided 'as-is', without any express or implied | 
 | # warranty.  In no event will the authors be held liable for any damages | 
 | # arising from the use of this software. | 
 | # | 
 | # Permission is granted to anyone to use this software for any purpose, | 
 | # including commercial applications, and to alter it and redistribute it | 
 | # freely, subject to the following restrictions: | 
 | # | 
 | # 1. The origin of this software must not be misrepresented; you must not | 
 | #    claim that you wrote the original software. If you use this software | 
 | #    in a product, an acknowledgment in the product documentation would be | 
 | #    appreciated but is not required. | 
 | # 2. Altered source versions must be plainly marked as such, and must not be | 
 | #    misrepresented as being the original software. | 
 | # 3. This notice may not be removed or altered from any source distribution. | 
 |  | 
 | import unittest | 
 | import sqlite3 as sqlite | 
 |  | 
 | class CollationTests(unittest.TestCase): | 
 |     def setUp(self): | 
 |         pass | 
 |  | 
 |     def tearDown(self): | 
 |         pass | 
 |  | 
 |     def CheckCreateCollationNotCallable(self): | 
 |         con = sqlite.connect(":memory:") | 
 |         try: | 
 |             con.create_collation("X", 42) | 
 |             self.fail("should have raised a TypeError") | 
 |         except TypeError as e: | 
 |             self.assertEqual(e.args[0], "parameter must be callable") | 
 |  | 
 |     def CheckCreateCollationNotAscii(self): | 
 |         con = sqlite.connect(":memory:") | 
 |         try: | 
 |             con.create_collation("collä", lambda x, y: (x > y) - (x < y)) | 
 |             self.fail("should have raised a ProgrammingError") | 
 |         except sqlite.ProgrammingError as e: | 
 |             pass | 
 |  | 
 |     def CheckCollationIsUsed(self): | 
 |         if sqlite.version_info < (3, 2, 1):  # old SQLite versions crash on this test | 
 |             return | 
 |         def mycoll(x, y): | 
 |             # reverse order | 
 |             return -((x > y) - (x < y)) | 
 |  | 
 |         con = sqlite.connect(":memory:") | 
 |         con.create_collation("mycoll", mycoll) | 
 |         sql = """ | 
 |             select x from ( | 
 |             select 'a' as x | 
 |             union | 
 |             select 'b' as x | 
 |             union | 
 |             select 'c' as x | 
 |             ) order by x collate mycoll | 
 |             """ | 
 |         result = con.execute(sql).fetchall() | 
 |         if result[0][0] != "c" or result[1][0] != "b" or result[2][0] != "a": | 
 |             self.fail("the expected order was not returned") | 
 |  | 
 |         con.create_collation("mycoll", None) | 
 |         try: | 
 |             result = con.execute(sql).fetchall() | 
 |             self.fail("should have raised an OperationalError") | 
 |         except sqlite.OperationalError as e: | 
 |             self.assertEqual(e.args[0].lower(), "no such collation sequence: mycoll") | 
 |  | 
 |     def CheckCollationRegisterTwice(self): | 
 |         """ | 
 |         Register two different collation functions under the same name. | 
 |         Verify that the last one is actually used. | 
 |         """ | 
 |         con = sqlite.connect(":memory:") | 
 |         con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) | 
 |         con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y))) | 
 |         result = con.execute(""" | 
 |             select x from (select 'a' as x union select 'b' as x) order by x collate mycoll | 
 |             """).fetchall() | 
 |         if result[0][0] != 'b' or result[1][0] != 'a': | 
 |             self.fail("wrong collation function is used") | 
 |  | 
 |     def CheckDeregisterCollation(self): | 
 |         """ | 
 |         Register a collation, then deregister it. Make sure an error is raised if we try | 
 |         to use it. | 
 |         """ | 
 |         con = sqlite.connect(":memory:") | 
 |         con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) | 
 |         con.create_collation("mycoll", None) | 
 |         try: | 
 |             con.execute("select 'a' as x union select 'b' as x order by x collate mycoll") | 
 |             self.fail("should have raised an OperationalError") | 
 |         except sqlite.OperationalError as e: | 
 |             if not e.args[0].startswith("no such collation sequence"): | 
 |                 self.fail("wrong OperationalError raised") | 
 |  | 
 | class ProgressTests(unittest.TestCase): | 
 |     def CheckProgressHandlerUsed(self): | 
 |         """ | 
 |         Test that the progress handler is invoked once it is set. | 
 |         """ | 
 |         con = sqlite.connect(":memory:") | 
 |         progress_calls = [] | 
 |         def progress(): | 
 |             progress_calls.append(None) | 
 |             return 0 | 
 |         con.set_progress_handler(progress, 1) | 
 |         con.execute(""" | 
 |             create table foo(a, b) | 
 |             """) | 
 |         self.assertTrue(progress_calls) | 
 |  | 
 |  | 
 |     def CheckOpcodeCount(self): | 
 |         """ | 
 |         Test that the opcode argument is respected. | 
 |         """ | 
 |         con = sqlite.connect(":memory:") | 
 |         progress_calls = [] | 
 |         def progress(): | 
 |             progress_calls.append(None) | 
 |             return 0 | 
 |         con.set_progress_handler(progress, 1) | 
 |         curs = con.cursor() | 
 |         curs.execute(""" | 
 |             create table foo (a, b) | 
 |             """) | 
 |         first_count = len(progress_calls) | 
 |         progress_calls = [] | 
 |         con.set_progress_handler(progress, 2) | 
 |         curs.execute(""" | 
 |             create table bar (a, b) | 
 |             """) | 
 |         second_count = len(progress_calls) | 
 |         self.assertTrue(first_count > second_count) | 
 |  | 
 |     def CheckCancelOperation(self): | 
 |         """ | 
 |         Test that returning a non-zero value stops the operation in progress. | 
 |         """ | 
 |         con = sqlite.connect(":memory:") | 
 |         progress_calls = [] | 
 |         def progress(): | 
 |             progress_calls.append(None) | 
 |             return 1 | 
 |         con.set_progress_handler(progress, 1) | 
 |         curs = con.cursor() | 
 |         self.assertRaises( | 
 |             sqlite.OperationalError, | 
 |             curs.execute, | 
 |             "create table bar (a, b)") | 
 |  | 
 |     def CheckClearHandler(self): | 
 |         """ | 
 |         Test that setting the progress handler to None clears the previously set handler. | 
 |         """ | 
 |         con = sqlite.connect(":memory:") | 
 |         action = 0 | 
 |         def progress(): | 
 |             action = 1 | 
 |             return 0 | 
 |         con.set_progress_handler(progress, 1) | 
 |         con.set_progress_handler(None, 1) | 
 |         con.execute("select 1 union select 2 union select 3").fetchall() | 
 |         self.assertEqual(action, 0, "progress handler was not cleared") | 
 |  | 
 | class TraceCallbackTests(unittest.TestCase): | 
 |     def CheckTraceCallbackUsed(self): | 
 |         """ | 
 |         Test that the trace callback is invoked once it is set. | 
 |         """ | 
 |         con = sqlite.connect(":memory:") | 
 |         traced_statements = [] | 
 |         def trace(statement): | 
 |             traced_statements.append(statement) | 
 |         con.set_trace_callback(trace) | 
 |         con.execute("create table foo(a, b)") | 
 |         self.assertTrue(traced_statements) | 
 |         self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) | 
 |  | 
 |     def CheckClearTraceCallback(self): | 
 |         """ | 
 |         Test that setting the trace callback to None clears the previously set callback. | 
 |         """ | 
 |         con = sqlite.connect(":memory:") | 
 |         traced_statements = [] | 
 |         def trace(statement): | 
 |             traced_statements.append(statement) | 
 |         con.set_trace_callback(trace) | 
 |         con.set_trace_callback(None) | 
 |         con.execute("create table foo(a, b)") | 
 |         self.assertFalse(traced_statements, "trace callback was not cleared") | 
 |  | 
 |     def CheckUnicodeContent(self): | 
 |         """ | 
 |         Test that the statement can contain unicode literals. | 
 |         """ | 
 |         unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac' | 
 |         con = sqlite.connect(":memory:") | 
 |         traced_statements = [] | 
 |         def trace(statement): | 
 |             traced_statements.append(statement) | 
 |         con.set_trace_callback(trace) | 
 |         con.execute("create table foo(x)") | 
 |         # Can't execute bound parameters as their values don't appear | 
 |         # in traced statements before SQLite 3.6.21 | 
 |         # (cf. http://www.sqlite.org/draft/releaselog/3_6_21.html) | 
 |         con.execute('insert into foo(x) values ("%s")' % unicode_value) | 
 |         con.commit() | 
 |         self.assertTrue(any(unicode_value in stmt for stmt in traced_statements), | 
 |                         "Unicode data %s garbled in trace callback: %s" | 
 |                         % (ascii(unicode_value), ', '.join(map(ascii, traced_statements)))) | 
 |  | 
 |  | 
 |  | 
 | def suite(): | 
 |     collation_suite = unittest.makeSuite(CollationTests, "Check") | 
 |     progress_suite = unittest.makeSuite(ProgressTests, "Check") | 
 |     trace_suite = unittest.makeSuite(TraceCallbackTests, "Check") | 
 |     return unittest.TestSuite((collation_suite, progress_suite, trace_suite)) | 
 |  | 
 | def test(): | 
 |     runner = unittest.TextTestRunner() | 
 |     runner.run(suite()) | 
 |  | 
 | if __name__ == "__main__": | 
 |     test() |