# Simple example presenting how persistent IDs can be used to pickle
# external objects by reference.
| 3 | |
| 4 | import pickle |
| 5 | import sqlite3 |
| 6 | from collections import namedtuple |
| 7 | |
# Lightweight record type mirroring one row of the database: the row's
# key and the text of the task.
MemoRecord = namedtuple("MemoRecord", ["key", "task"])
| 10 | |
class DBPickler(pickle.Pickler):
    """Pickler that stores MemoRecord objects by reference, not by value."""

    def persistent_id(self, obj):
        """Return the persistent ID for *obj*, or None to pickle it normally.

        A MemoRecord is not serialized as a regular class instance; it is
        replaced in the stream by a ``(tag, key)`` tuple whose key refers to
        a specific record in the database.
        """
        if not isinstance(obj, MemoRecord):
            # No persistent ID for this object: pickle it as usual.
            return None
        # The tag names the kind of record; the key identifies the record.
        return ("MemoRecord", obj.key)
| 24 | |
| 25 | |
class DBUnpickler(pickle.Unpickler):
    """Unpickler that resolves MemoRecord persistent IDs against a database.

    NOTE: as with any unpickler, only feed it pickle data from a trusted
    source.
    """

    def __init__(self, file, connection):
        """Read pickle data from *file*; look records up through *connection*.

        *connection* must provide ``cursor()``; it is queried for rows of
        the ``memos`` table each time a persistent ID is loaded.
        """
        super().__init__(file)
        self.connection = connection

    def persistent_load(self, pid):
        """Return the object referenced by the persistent ID *pid*.

        This method is invoked whenever a persistent ID is encountered in
        the pickle stream.  *pid* is expected to be the ``(tag, key)`` tuple
        produced by DBPickler.

        Raises:
            pickle.UnpicklingError: if *pid* is not a two-item sequence, if
                its tag is not supported, or if the referenced record no
                longer exists in the database.
        """
        try:
            type_tag, key_id = pid
        except (TypeError, ValueError) as err:
            # Report a malformed ID as an unpickling failure rather than
            # leaking an unrelated unpacking error to the caller.
            raise pickle.UnpicklingError(
                "malformed persistent ID: {!r}".format(pid)) from err

        if type_tag != "MemoRecord":
            # Always raise an error if the correct object cannot be
            # returned.  Otherwise, the unpickler will think None is the
            # object referenced by the persistent ID.
            raise pickle.UnpicklingError("unsupported persistent object")

        # Fetch the referenced record from the database and return it.
        cursor = self.connection.cursor()
        cursor.execute("SELECT * FROM memos WHERE key=?", (str(key_id),))
        row = cursor.fetchone()
        if row is None:
            # fetchone() returns None when no row matches, e.g. because the
            # record was deleted after the pickle was written.
            raise pickle.UnpicklingError(
                "no MemoRecord with key {!r} in the database".format(key_id))
        key, task = row
        return MemoRecord(key, task)
| 47 | |
| 48 | |
def main(verbose=True):
    """Demonstrate pickling database records by reference.

    Creates an in-memory SQLite ``memos`` table, pickles its rows with
    DBPickler, updates one row, and then unpickles the stream with
    DBUnpickler, which fetches each record from the database again.

    When *verbose* is true, the records are printed before pickling and
    after unpickling.
    """
    import io
    import pprint

    conn = sqlite3.connect(":memory:")
    try:
        # Initialize and populate our database.
        cursor = conn.cursor()
        cursor.execute("CREATE TABLE memos(key INTEGER PRIMARY KEY, task TEXT)")
        tasks = (
            'give food to fish',
            'prepare group meeting',
            'fight with a zebra',
        )
        for task in tasks:
            cursor.execute("INSERT INTO memos VALUES(NULL, ?)", (task,))

        # Fetch the records to be pickled.
        cursor.execute("SELECT * FROM memos")
        memos = [MemoRecord(key, task) for key, task in cursor]
        # Save the records using our custom DBPickler.
        file = io.BytesIO()
        DBPickler(file).dump(memos)

        if verbose:
            print("Records to be pickled:")
            pprint.pprint(memos)

        # Update a record, just for good measure.
        cursor.execute("UPDATE memos SET task='learn italian' WHERE key=1")

        # Load the records from the pickle data stream.
        file.seek(0)
        memos = DBUnpickler(file, conn).load()

        if verbose:
            print("Unpickled records:")
            pprint.pprint(memos)
    finally:
        # Release the connection even if any step above fails.
        conn.close()
| 85 | |
| 86 | |
# Run the demonstration only when this file is executed as a script.
if __name__ == "__main__":
    main()