Add py::pickle() adaptor for safer __getstate__/__setstate__ bindings
This is analogous to `py::init()` vs `__init__` + placement-new.
`py::pickle()` reuses most of the implementation details of `py::init()`.
diff --git a/docs/advanced/classes.rst b/docs/advanced/classes.rst
index 7bcd038..be4bc2e 100644
--- a/docs/advanced/classes.rst
+++ b/docs/advanced/classes.rst
@@ -687,13 +687,15 @@
complete example that demonstrates how to work with overloaded operators in
more detail.
+.. _pickling:
+
Pickling support
================
Python's ``pickle`` module provides a powerful facility to serialize and
de-serialize a Python object graph into a binary data stream. To pickle and
-unpickle C++ classes using pybind11, two additional functions must be provided.
-Suppose the class in question has the following signature:
+unpickle C++ classes using pybind11, a ``py::pickle()`` definition must be
+provided. Suppose the class in question has the following signature:
.. code-block:: cpp
@@ -709,8 +711,9 @@
int m_extra = 0;
};
-The binding code including the requisite ``__setstate__`` and ``__getstate__`` methods [#f3]_
-looks as follows:
+Pickling support in Python is enable by defining the ``__setstate__`` and
+``__getstate__`` methods [#f3]_. For pybind11 classes, use ``py::pickle()``
+to bind these two functions:
.. code-block:: cpp
@@ -719,21 +722,28 @@
.def("value", &Pickleable::value)
.def("extra", &Pickleable::extra)
.def("setExtra", &Pickleable::setExtra)
- .def("__getstate__", [](const Pickleable &p) {
- /* Return a tuple that fully encodes the state of the object */
- return py::make_tuple(p.value(), p.extra());
- })
- .def("__setstate__", [](Pickleable &p, py::tuple t) {
- if (t.size() != 2)
- throw std::runtime_error("Invalid state!");
+ .def(py::pickle(
+ [](const Pickleable &p) { // __getstate__
+ /* Return a tuple that fully encodes the state of the object */
+ return py::make_tuple(p.value(), p.extra());
+ },
+ [](py::tuple t) { // __setstate__
+ if (t.size() != 2)
+ throw std::runtime_error("Invalid state!");
- /* Invoke the in-place constructor. Note that this is needed even
- when the object just has a trivial default constructor */
- new (&p) Pickleable(t[0].cast<std::string>());
+ /* Create a new C++ instance */
+ Pickleable p(t[0].cast<std::string>());
- /* Assign any additional state */
- p.setExtra(t[1].cast<int>());
- });
+ /* Assign any additional state */
+ p.setExtra(t[1].cast<int>());
+
+ return p;
+ }
+ ));
+
+The ``__setstate__`` part of the ``py::picke()`` definition follows the same
+rules as the single-argument version of ``py::init()``. The return type can be
+a value, pointer or holder type. See :ref:`custom_constructors` for details.
An instance can now be pickled as follows:
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 23d5563..478b7d7 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -91,6 +91,11 @@
return std::make_unique<Example>(std::to_string(n));
}));
+* Similarly to custom constructors, pickling support functions are now bound
+ using the ``py::pickle()`` adaptor which improves type safety. See the
+ :doc:`upgrade` and :ref:`pickling` for details.
+ `#1038 <https://github.com/pybind/pybind11/pull/1038>`_.
+
* Builtin support for converting C++17 standard library types and general
conversion improvements:
diff --git a/docs/upgrade.rst b/docs/upgrade.rst
index 2fe8470..bcbc6b1 100644
--- a/docs/upgrade.rst
+++ b/docs/upgrade.rst
@@ -162,6 +162,39 @@
}));
+New syntax for pickling support
+-------------------------------
+
+Mirroring the custom constructor changes, ``py::pickle()`` is now the preferred
+way to get and set object state. See :ref:`pickling` for details.
+
+.. code-block:: cpp
+
+ // old -- deprecated
+ py::class<Foo>(m, "Foo")
+ ...
+ .def("__getstate__", [](const Foo &self) {
+ return py::make_tuple(self.value1(), self.value2(), ...);
+ })
+ .def("__setstate__", [](Foo &self, py::tuple t) {
+ new (&self) Foo(t[0].cast<std::string>(), ...);
+ });
+
+ // new
+ py::class<Foo>(m, "Foo")
+ ...
+ .def(py::pickle(
+ [](const Foo &self) { // __getstate__
+ return py::make_tuple(f.value1(), f.value2(), ...); // unchanged
+ },
+ [](py::tuple t) { // __setstate__, note: no `self` argument
+ return new Foo(t[0].cast<std::string>(), ...);
+ // or: return std::make_unique<Foo>(...); // return by holder
+ // or: return Foo(...); // return by value (move constructor)
+ }
+ ));
+
+
Deprecation of some ``py::object`` APIs
---------------------------------------
diff --git a/include/pybind11/detail/init.h b/include/pybind11/detail/init.h
index ee2de8c..deace19 100644
--- a/include/pybind11/detail/init.h
+++ b/include/pybind11/detail/init.h
@@ -271,6 +271,55 @@
}
};
+/// Set just the C++ state. Same as `__init__`.
+template <typename Class, typename T>
+void setstate(value_and_holder &v_h, T &&result, bool need_alias) {
+ construct<Class>(v_h, std::forward<T>(result), need_alias);
+}
+
+/// Set both the C++ and Python states
+template <typename Class, typename T, typename O,
+ enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
+void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
+ construct<Class>(v_h, std::move(result.first), need_alias);
+ setattr((PyObject *) v_h.inst, "__dict__", result.second);
+}
+
+/// Implementation for py::pickle(GetState, SetState)
+template <typename Get, typename Set,
+ typename = function_signature_t<Get>, typename = function_signature_t<Set>>
+struct pickle_factory;
+
+template <typename Get, typename Set,
+ typename RetState, typename Self, typename NewInstance, typename ArgState>
+struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
+ static_assert(std::is_same<RetState, ArgState>::value,
+ "The type returned by `__getstate__` must be the same "
+ "as the argument accepted by `__setstate__`");
+
+ remove_reference_t<Get> get;
+ remove_reference_t<Set> set;
+
+ pickle_factory(Get get, Set set)
+ : get(std::forward<Get>(get)), set(std::forward<Set>(set)) { }
+
+ template <typename Class, typename... Extra>
+ void execute(Class &cl, const Extra &...extra) && {
+ cl.def("__getstate__", std::move(get));
+
+#if defined(PYBIND11_CPP14)
+ cl.def("__setstate__", [func = std::move(set)]
+#else
+ auto &func = set;
+ cl.def("__setstate__", [func]
+#endif
+ (value_and_holder &v_h, ArgState state) {
+ setstate<Class>(v_h, func(std::forward<ArgState>(state)),
+ Py_TYPE(v_h.inst) != v_h.type->type);
+ }, is_new_style_constructor(), extra...);
+ }
+};
+
NAMESPACE_END(initimpl)
NAMESPACE_END(detail)
NAMESPACE_END(pybind11)
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index 16e3fdc..0e67c40 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -1095,6 +1095,12 @@
return *this;
}
+ template <typename... Args, typename... Extra>
+ class_ &def(detail::initimpl::pickle_factory<Args...> &&pf, const Extra &...extra) {
+ std::move(pf).execute(*this, extra...);
+ return *this;
+ }
+
template <typename Func> class_& def_buffer(Func &&func) {
struct capture { Func func; };
capture *ptr = new capture { std::forward<Func>(func) };
@@ -1399,6 +1405,13 @@
return {std::forward<CFunc>(c), std::forward<AFunc>(a)};
}
+/// Binds pickling functions `__getstate__` and `__setstate__` and ensures that the type
+/// returned by `__getstate__` is the same as the argument accepted by `__setstate__`.
+template <typename GetState, typename SetState>
+detail::initimpl::pickle_factory<GetState, SetState> pickle(GetState &&g, SetState &&s) {
+ return {std::forward<GetState>(g), std::forward<SetState>(s)};
+};
+
NAMESPACE_BEGIN(detail)
diff --git a/tests/test_pickling.cpp b/tests/test_pickling.cpp
index 1e5f4ce..821462a 100644
--- a/tests/test_pickling.cpp
+++ b/tests/test_pickling.cpp
@@ -25,6 +25,12 @@
int m_extra1 = 0;
int m_extra2 = 0;
};
+
+ class PickleableNew : public Pickleable {
+ public:
+ using Pickleable::Pickleable;
+ };
+
py::class_<Pickleable>(m, "Pickleable")
.def(py::init<std::string>())
.def("value", &Pickleable::value)
@@ -49,6 +55,23 @@
p.setExtra2(t[2].cast<int>());
});
+ py::class_<PickleableNew, Pickleable>(m, "PickleableNew")
+ .def(py::init<std::string>())
+ .def(py::pickle(
+ [](const PickleableNew &p) {
+ return py::make_tuple(p.value(), p.extra1(), p.extra2());
+ },
+ [](py::tuple t) {
+ if (t.size() != 3)
+ throw std::runtime_error("Invalid state!");
+ auto p = PickleableNew(t[0].cast<std::string>());
+
+ p.setExtra1(t[1].cast<int>());
+ p.setExtra2(t[2].cast<int>());
+ return p;
+ }
+ ));
+
#if !defined(PYPY_VERSION)
// test_roundtrip_with_dict
class PickleableWithDict {
@@ -58,6 +81,12 @@
std::string value;
int extra;
};
+
+ class PickleableWithDictNew : public PickleableWithDict {
+ public:
+ using PickleableWithDict::PickleableWithDict;
+ };
+
py::class_<PickleableWithDict>(m, "PickleableWithDict", py::dynamic_attr())
.def(py::init<std::string>())
.def_readwrite("value", &PickleableWithDict::value)
@@ -79,5 +108,23 @@
/* Assign Python state */
self.attr("__dict__") = t[2];
});
+
+ py::class_<PickleableWithDictNew, PickleableWithDict>(m, "PickleableWithDictNew")
+ .def(py::init<std::string>())
+ .def(py::pickle(
+ [](py::object self) {
+ return py::make_tuple(self.attr("value"), self.attr("extra"), self.attr("__dict__"));
+ },
+ [](py::tuple t) {
+ if (t.size() != 3)
+ throw std::runtime_error("Invalid state!");
+
+ auto cpp_state = PickleableWithDictNew(t[0].cast<std::string>());
+ cpp_state.extra = t[1].cast<int>();
+
+ auto py_state = t[2].cast<py::dict>();
+ return std::make_pair(cpp_state, py_state);
+ }
+ ));
#endif
}
diff --git a/tests/test_pickling.py b/tests/test_pickling.py
index 6cbcdf5..707d347 100644
--- a/tests/test_pickling.py
+++ b/tests/test_pickling.py
@@ -7,8 +7,10 @@
import pickle
-def test_roundtrip():
- p = m.Pickleable("test_value")
+@pytest.mark.parametrize("cls_name", ["Pickleable", "PickleableNew"])
+def test_roundtrip(cls_name):
+ cls = getattr(m, cls_name)
+ p = cls("test_value")
p.setExtra1(15)
p.setExtra2(48)
@@ -20,8 +22,10 @@
@pytest.unsupported_on_pypy
-def test_roundtrip_with_dict():
- p = m.PickleableWithDict("test_value")
+@pytest.mark.parametrize("cls_name", ["PickleableWithDict", "PickleableWithDictNew"])
+def test_roundtrip_with_dict(cls_name):
+ cls = getattr(m, cls_name)
+ p = cls("test_value")
p.extra = 15
p.dynamic = "Attribute"