blob: b88ee87d872ab4580500c7d132a372f5dbb61183 [file] [log] [blame]
# Simple example presenting how persistent ID can be used to pickle
# external objects by reference.
3
4import pickle
5import sqlite3
6from collections import namedtuple
7
# Lightweight immutable record type used by this example: ``key`` holds the
# primary key of a row in the ``memos`` table and ``task`` holds its text.
MemoRecord = namedtuple("MemoRecord", ["key", "task"])
10
class DBPickler(pickle.Pickler):
    """Pickler that writes MemoRecord instances by reference.

    Each MemoRecord in the object graph is replaced in the pickle stream by
    the persistent ID ``("MemoRecord", key)``; every other object is pickled
    in the normal way.
    """

    def persistent_id(self, obj):
        # Objects that are not MemoRecords have no persistent ID. Returning
        # None tells the pickler to serialize them as usual.
        if not isinstance(obj, MemoRecord):
            return None
        # A MemoRecord is emitted as a (tag, key) pair; the key names the
        # database record that the unpickler will fetch later.
        return ("MemoRecord", obj.key)
24
25
class DBUnpickler(pickle.Unpickler):
    """Unpickler that resolves DBPickler's persistent IDs from a database.

    *connection* is the database connection (an ``sqlite3`` connection in
    ``main``) queried for the records referenced by ``("MemoRecord", key)``
    persistent IDs.
    """

    def __init__(self, file, connection):
        super().__init__(file)
        self.connection = connection

    def persistent_load(self, pid):
        """Return the object referenced by the persistent ID *pid*.

        This method is invoked whenever a persistent ID is encountered.
        *pid* is the ``(type_tag, key)`` tuple returned by
        ``DBPickler.persistent_id``.

        Raises:
            pickle.UnpicklingError: if the tag is not supported, or if the
                database holds no record with the given key.
        """
        type_tag, key_id = pid
        if type_tag != "MemoRecord":
            # Always raise an error if you cannot return the correct object.
            # Otherwise, the unpickler would think None is the object
            # referenced by the persistent ID.
            raise pickle.UnpicklingError("unsupported persistent object")

        # Fetch the referenced record from the database. The columns are
        # named explicitly so the result does not depend on the table's
        # column order, and the cursor is closed once the row is read.
        cursor = self.connection.cursor()
        try:
            cursor.execute("SELECT key, task FROM memos WHERE key=?",
                           (str(key_id),))
            row = cursor.fetchone()
        finally:
            cursor.close()

        if row is None:
            # The record is gone (e.g. deleted after pickling). Report it
            # with the exception type unpickling callers expect, instead of
            # a TypeError from unpacking None.
            raise pickle.UnpicklingError(
                "no MemoRecord with key %r in the database" % (key_id,))
        key, task = row
        return MemoRecord(key, task)
47
48
def main():
    """Demonstrate pickling MemoRecords by reference through a database.

    Populates an in-memory SQLite table, pickles its records with
    DBPickler, updates one row, then unpickles with DBUnpickler. The
    printed output shows that the loaded records are re-fetched from the
    database, so they reflect the update made after pickling.
    """
    import contextlib
    import io
    import pprint

    # Initialize and populate our database. sqlite3's own ``with conn:``
    # only manages transactions, so closing() is used to make sure the
    # connection is actually closed, even if the demo fails part-way.
    with contextlib.closing(sqlite3.connect(":memory:")) as conn:
        cursor = conn.cursor()
        cursor.execute("CREATE TABLE memos(key INTEGER PRIMARY KEY, task TEXT)")
        tasks = (
            'give food to fish',
            'prepare group meeting',
            'fight with a zebra',
        )
        cursor.executemany("INSERT INTO memos VALUES(NULL, ?)",
                           ((task,) for task in tasks))

        # Fetch the records to be pickled.
        cursor.execute("SELECT key, task FROM memos")
        memos = [MemoRecord(key, task) for key, task in cursor]
        # Save the records using our custom DBPickler.
        buf = io.BytesIO()
        DBPickler(buf).dump(memos)

        print("Pickled records:")
        pprint.pprint(memos)

        # Update a record, just for good measure.
        cursor.execute("UPDATE memos SET task='learn italian' WHERE key=1")

        # Load the records from the pickle data stream.
        buf.seek(0)
        memos = DBUnpickler(buf, conn).load()

        print("Unpickled records:")
        pprint.pprint(memos)
Alexandre Vassalotti5f3b63a2008-10-18 20:47:58 +000084
85
if __name__ == '__main__':
    # Executed as a script rather than imported: run the demonstration.
    main()