Allow references to objects held by smart pointers (#533)
diff --git a/docs/advanced/smart_ptrs.rst b/docs/advanced/smart_ptrs.rst
index 6e8c9de..23072b6 100644
--- a/docs/advanced/smart_ptrs.rst
+++ b/docs/advanced/smart_ptrs.rst
@@ -53,8 +53,6 @@
.. code-block:: cpp
- PYBIND11_DECLARE_HOLDER_TYPE(T, std::shared_ptr<T>);
-
class Child { };
class Parent {
diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h
index 8a3ccb4..6cdc482 100644
--- a/include/pybind11/cast.h
+++ b/include/pybind11/cast.h
@@ -860,20 +860,12 @@
if (typeinfo->simple_type) { /* Case 1: no multiple inheritance etc. involved */
/* Check if we can safely perform a reinterpret-style cast */
- if (PyType_IsSubtype(tobj, typeinfo->type)) {
- auto inst = (instance<type, holder_type> *) src.ptr();
- value = (void *) inst->value;
- holder = inst->holder;
- return true;
- }
+ if (PyType_IsSubtype(tobj, typeinfo->type))
+ return load_value_and_holder(src);
} else { /* Case 2: multiple inheritance */
/* Check if we can safely perform a reinterpret-style cast */
- if (tobj == typeinfo->type) {
- auto inst = (instance<type, holder_type> *) src.ptr();
- value = (void *) inst->value;
- holder = inst->holder;
- return true;
- }
+ if (tobj == typeinfo->type)
+ return load_value_and_holder(src);
/* If this is a python class, also check the parents recursively */
auto const &type_dict = get_internals().registered_types_py;
@@ -902,6 +894,22 @@
return false;
}
+ bool load_value_and_holder(handle src) {
+ auto inst = (instance<type, holder_type> *) src.ptr();
+ value = (void *) inst->value;
+ if (inst->holder_constructed) {
+ holder = inst->holder;
+ return true;
+ } else {
+ throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
+#if defined(NDEBUG)
+ "(compile in debug mode for type information)");
+#else
+ "of type '" + type_id<holder_type>() + "''");
+#endif
+ }
+ }
+
template <typename T = holder_type, detail::enable_if_t<!std::is_constructible<T, const T &, type*>::value, int> = 0>
bool try_implicit_casts(handle, bool) { return false; }
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index b737dc7..1db9efb 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -1138,21 +1138,26 @@
static void init_holder_helper(instance_type *inst, const holder_type * /* unused */, const std::enable_shared_from_this<T> * /* dummy */) {
try {
new (&inst->holder) holder_type(std::static_pointer_cast<typename holder_type::element_type>(inst->value->shared_from_this()));
+ inst->holder_constructed = true;
} catch (const std::bad_weak_ptr &) {
- new (&inst->holder) holder_type(inst->value);
+ if (inst->owned) {
+ new (&inst->holder) holder_type(inst->value);
+ inst->holder_constructed = true;
+ }
}
- inst->holder_constructed = true;
}
/// Initialize holder object, variant 2: try to construct from existing holder object, if possible
template <typename T = holder_type,
detail::enable_if_t<std::is_copy_constructible<T>::value, int> = 0>
static void init_holder_helper(instance_type *inst, const holder_type *holder_ptr, const void * /* dummy */) {
- if (holder_ptr)
+ if (holder_ptr) {
new (&inst->holder) holder_type(*holder_ptr);
- else
+ inst->holder_constructed = true;
+ } else if (inst->owned) {
new (&inst->holder) holder_type(inst->value);
- inst->holder_constructed = true;
+ inst->holder_constructed = true;
+ }
}
/// Initialize holder object, variant 3: holder is not copy constructible (e.g. unique_ptr), always initialize from raw pointer
diff --git a/tests/test_smart_ptr.cpp b/tests/test_smart_ptr.cpp
index 7d50f0d..07c3cb0 100644
--- a/tests/test_smart_ptr.cpp
+++ b/tests/test_smart_ptr.cpp
@@ -165,3 +165,60 @@
// Expose constructor stats for the ref type
m.def("cstats_ref", &ConstructorStats::get<ref_tag>);
});
+
+struct SharedPtrRef {
+ struct A {
+ A() { print_created(this); }
+ A(const A &) { print_copy_created(this); }
+ A(A &&) { print_move_created(this); }
+ ~A() { print_destroyed(this); }
+ };
+
+ A value = {};
+ std::shared_ptr<A> shared = std::make_shared<A>();
+};
+
+struct SharedFromThisRef {
+ struct B : std::enable_shared_from_this<B> {
+ B() { print_created(this); }
+ B(const B &) : std::enable_shared_from_this<B>() { print_copy_created(this); }
+ B(B &&) : std::enable_shared_from_this<B>() { print_move_created(this); }
+ ~B() { print_destroyed(this); }
+ };
+
+ B value = {};
+ std::shared_ptr<B> shared = std::make_shared<B>();
+};
+
+test_initializer smart_ptr_and_references([](py::module &pm) {
+ auto m = pm.def_submodule("smart_ptr");
+
+ using A = SharedPtrRef::A;
+ py::class_<A, std::shared_ptr<A>>(m, "A");
+
+ py::class_<SharedPtrRef>(m, "SharedPtrRef")
+ .def(py::init<>())
+ .def_readonly("ref", &SharedPtrRef::value)
+ .def_property_readonly("copy", [](const SharedPtrRef &s) { return s.value; },
+ py::return_value_policy::copy)
+ .def_readonly("holder_ref", &SharedPtrRef::shared)
+ .def_property_readonly("holder_copy", [](const SharedPtrRef &s) { return s.shared; },
+ py::return_value_policy::copy)
+ .def("set_ref", [](SharedPtrRef &, const A &) { return true; })
+ .def("set_holder", [](SharedPtrRef &, std::shared_ptr<A>) { return true; });
+
+ using B = SharedFromThisRef::B;
+ py::class_<B, std::shared_ptr<B>>(m, "B");
+
+ py::class_<SharedFromThisRef>(m, "SharedFromThisRef")
+ .def(py::init<>())
+ .def_readonly("bad_wp", &SharedFromThisRef::value)
+ .def_property_readonly("ref", [](const SharedFromThisRef &s) -> const B & { return *s.shared; })
+ .def_property_readonly("copy", [](const SharedFromThisRef &s) { return s.value; },
+ py::return_value_policy::copy)
+ .def_readonly("holder_ref", &SharedFromThisRef::shared)
+ .def_property_readonly("holder_copy", [](const SharedFromThisRef &s) { return s.shared; },
+ py::return_value_policy::copy)
+ .def("set_ref", [](SharedFromThisRef &, const B &) { return true; })
+ .def("set_holder", [](SharedFromThisRef &, std::shared_ptr<B>) { return true; });
+});
diff --git a/tests/test_smart_ptr.py b/tests/test_smart_ptr.py
index 8af2ae8..3a33e17 100644
--- a/tests/test_smart_ptr.py
+++ b/tests/test_smart_ptr.py
@@ -1,3 +1,4 @@
+import pytest
from pybind11_tests import ConstructorStats
@@ -124,3 +125,74 @@
del o
cstats = ConstructorStats.get(MyObject4)
assert cstats.alive() == 1 # Leak, but that's intentional
+
+
+def test_shared_ptr_and_references():
+ from pybind11_tests.smart_ptr import SharedPtrRef, A
+
+ s = SharedPtrRef()
+ stats = ConstructorStats.get(A)
+ assert stats.alive() == 2
+
+ ref = s.ref # init_holder_helper(holder_ptr=false, owned=false)
+ assert stats.alive() == 2
+ assert s.set_ref(ref)
+ with pytest.raises(RuntimeError) as excinfo:
+ assert s.set_holder(ref)
+ assert "Unable to cast from non-held to held instance" in str(excinfo.value)
+
+ copy = s.copy # init_holder_helper(holder_ptr=false, owned=true)
+ assert stats.alive() == 3
+ assert s.set_ref(copy)
+ assert s.set_holder(copy)
+
+ holder_ref = s.holder_ref # init_holder_helper(holder_ptr=true, owned=false)
+ assert stats.alive() == 3
+ assert s.set_ref(holder_ref)
+ assert s.set_holder(holder_ref)
+
+ holder_copy = s.holder_copy # init_holder_helper(holder_ptr=true, owned=true)
+ assert stats.alive() == 3
+ assert s.set_ref(holder_copy)
+ assert s.set_holder(holder_copy)
+
+ del ref, copy, holder_ref, holder_copy, s
+ assert stats.alive() == 0
+
+
+def test_shared_ptr_from_this_and_references():
+ from pybind11_tests.smart_ptr import SharedFromThisRef, B
+
+ s = SharedFromThisRef()
+ stats = ConstructorStats.get(B)
+ assert stats.alive() == 2
+
+ ref = s.ref # init_holder_helper(holder_ptr=false, owned=false, bad_wp=false)
+ assert stats.alive() == 2
+ assert s.set_ref(ref)
+ assert s.set_holder(ref) # std::enable_shared_from_this can create a holder from a reference
+
+ bad_wp = s.bad_wp # init_holder_helper(holder_ptr=false, owned=false, bad_wp=true)
+ assert stats.alive() == 2
+ assert s.set_ref(bad_wp)
+ with pytest.raises(RuntimeError) as excinfo:
+ assert s.set_holder(bad_wp)
+ assert "Unable to cast from non-held to held instance" in str(excinfo.value)
+
+ copy = s.copy # init_holder_helper(holder_ptr=false, owned=true, bad_wp=false)
+ assert stats.alive() == 3
+ assert s.set_ref(copy)
+ assert s.set_holder(copy)
+
+ holder_ref = s.holder_ref # init_holder_helper(holder_ptr=true, owned=false, bad_wp=false)
+ assert stats.alive() == 3
+ assert s.set_ref(holder_ref)
+ assert s.set_holder(holder_ref)
+
+ holder_copy = s.holder_copy # init_holder_helper(holder_ptr=true, owned=true, bad_wp=false)
+ assert stats.alive() == 3
+ assert s.set_ref(holder_copy)
+ assert s.set_holder(holder_copy)
+
+ del ref, bad_wp, copy, holder_ref, holder_copy, s
+ assert stats.alive() == 0