Add check for matching holder_type when inheriting (#588)
diff --git a/include/pybind11/attr.h b/include/pybind11/attr.h
index 0676d5d..740d3be 100644
--- a/include/pybind11/attr.h
+++ b/include/pybind11/attr.h
@@ -185,6 +185,9 @@
/// Does the class require its own metaclass?
bool metaclass : 1;
+ /// Is the default (unique_ptr) holder type used?
+ bool default_holder : 1;
+
PYBIND11_NOINLINE void add_base(const std::type_info *base, void *(*caster)(void *)) {
auto base_info = detail::get_type_info(*base, false);
if (!base_info) {
@@ -194,6 +197,15 @@
"\" referenced unknown base type \"" + tname + "\"");
}
+ if (default_holder != base_info->default_holder) {
+ std::string tname(base->name());
+ detail::clean_type_id(tname);
+ pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
+ (default_holder ? "does not have" : "has") +
+ " a non-default holder type while its base \"" + tname + "\" " +
+ (base_info->default_holder ? "does not" : "does"));
+ }
+
bases.append((PyObject *) base_info->type);
if (base_info->type->tp_dictoffset != 0)
diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h
index b953cc8..9077dbb 100644
--- a/include/pybind11/cast.h
+++ b/include/pybind11/cast.h
@@ -32,6 +32,8 @@
/** A simple type never occurs as a (direct or indirect) parent
* of a class that makes use of multiple inheritance */
bool simple_type = true;
+ /* for base vs derived holder_type checks */
+ bool default_holder = true;
};
PYBIND11_NOINLINE inline internals &get_internals() {
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index addcce7..99b1f72 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -741,6 +741,7 @@
tinfo->type_size = rec->type_size;
tinfo->init_holder = rec->init_holder;
tinfo->direct_conversions = &internals.direct_conversions[tindex];
+ tinfo->default_holder = rec->default_holder;
internals.registered_types_cpp[tindex] = tinfo;
internals.registered_types_py[type] = tinfo;
@@ -1006,6 +1007,7 @@
record.instance_size = sizeof(instance_type);
record.init_holder = init_holder;
record.dealloc = dealloc;
+ record.default_holder = std::is_same<holder_type, std::unique_ptr<type>>::value;
/* Register base classes specified via template arguments to class_, if any */
bool unused[] = { (add_base<options>(record), false)..., false };
diff --git a/tests/test_inheritance.cpp b/tests/test_inheritance.cpp
index 2ec0b4a..914b7a8 100644
--- a/tests/test_inheritance.cpp
+++ b/tests/test_inheritance.cpp
@@ -49,6 +49,12 @@
struct DerivedClass1 : BaseClass { };
struct DerivedClass2 : BaseClass { };
+struct MismatchBase1 { };
+struct MismatchDerived1 : MismatchBase1 { };
+
+struct MismatchBase2 { };
+struct MismatchDerived2 : MismatchBase2 { };
+
test_initializer inheritance([](py::module &m) {
py::class_<Pet> pet_class(m, "Pet");
pet_class
@@ -97,4 +103,15 @@
py::isinstance<Unregistered>(l[6])
);
});
+
+ m.def("test_mismatched_holder_type_1", []() {
+ auto m = py::module::import("__main__");
+ py::class_<MismatchBase1, std::shared_ptr<MismatchBase1>>(m, "MismatchBase1");
+ py::class_<MismatchDerived1, MismatchBase1>(m, "MismatchDerived1");
+ });
+ m.def("test_mismatched_holder_type_2", []() {
+ auto m = py::module::import("__main__");
+ py::class_<MismatchBase2>(m, "MismatchBase2");
+ py::class_<MismatchDerived2, std::shared_ptr<MismatchDerived2>, MismatchBase2>(m, "MismatchDerived2");
+ });
});
diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py
index 7bb52be..e4ab202 100644
--- a/tests/test_inheritance.py
+++ b/tests/test_inheritance.py
@@ -37,7 +37,8 @@
assert type(return_class_1()).__name__ == "DerivedClass1"
assert type(return_class_2()).__name__ == "DerivedClass2"
assert type(return_none()).__name__ == "NoneType"
- # Repeat these a few times in a random order to ensure no invalid caching is applied
+ # Repeat these a few times in a random order to ensure no invalid caching
+ # is applied
assert type(return_class_n(1)).__name__ == "DerivedClass1"
assert type(return_class_n(2)).__name__ == "DerivedClass2"
assert type(return_class_n(0)).__name__ == "BaseClass"
@@ -53,3 +54,21 @@
objects = [tuple(), dict(), Pet("Polly", "parrot")] + [Dog("Molly")] * 4
expected = (True, True, True, True, True, False, False)
assert test_isinstance(objects) == expected
+
+
+def test_holder():
+ from pybind11_tests import test_mismatched_holder_type_1, test_mismatched_holder_type_2
+
+ with pytest.raises(RuntimeError) as excinfo:
+ test_mismatched_holder_type_1()
+
+ assert str(excinfo.value) == ("generic_type: type \"MismatchDerived1\" does not have "
+ "a non-default holder type while its base "
+ "\"MismatchBase1\" does")
+
+ with pytest.raises(RuntimeError) as excinfo:
+ test_mismatched_holder_type_2()
+
+ assert str(excinfo.value) == ("generic_type: type \"MismatchDerived2\" has a "
+ "non-default holder type while its base "
+ "\"MismatchBase2\" does not")