Allow module-local classes to be loaded externally
The main point of `py::module_local` is to make the C++ -> Python cast
unique so that returning/casting a C++ instance is well-defined.
Unfortunately it also makes loading unique, but this isn't particularly
desirable: when an instance contains `Type` instance there's no reason
it shouldn't be possible to pass that instance to a bound function
taking a `Type` parameter, even if that function is in another module.
This commit solves the issue by allowing foreign module (and global)
type loaders have a chance to load the value if the local module loader
fails. The implementation here does this by storing a module-local
loading function in a capsule in the python type, which we can then call
if the local (and possibly global, if the local type is masking a global
type) version doesn't work.
diff --git a/docs/advanced/classes.rst b/docs/advanced/classes.rst
index cdac77d..569102e 100644
--- a/docs/advanced/classes.rst
+++ b/docs/advanced/classes.rst
@@ -775,27 +775,29 @@
When creating a binding for a class, pybind by default makes that binding
"global" across modules. What this means is that a type defined in one module
-can be passed to functions of other modules that expect the same C++ type. For
+can be returned from any module resulting in the same Python type. For
example, this allows the following:
.. code-block:: cpp
// In the module1.cpp binding code for module1:
py::class_<Pet>(m, "Pet")
- .def(py::init<std::string>());
+ .def(py::init<std::string>())
+ .def_readonly("name", &Pet::name);
.. code-block:: cpp
// In the module2.cpp binding code for module2:
- m.def("pet_name", [](Pet &p) { return p.name(); });
+ m.def("create_pet", [](std::string name) { return new Pet(name); });
.. code-block:: pycon
>>> from module1 import Pet
- >>> from module2 import pet_name
- >>> mypet = Pet("Kitty")
- >>> pet_name(mypet)
- 'Kitty'
+ >>> from module2 import create_pet
+ >>> pet1 = Pet("Kitty")
+ >>> pet2 = create_pet("Doggy")
+ >>> pet2.name()
+ 'Doggy'
When writing binding code for a library, this is usually desirable: this
allows, for example, splitting up a complex library into multiple Python
@@ -855,39 +857,45 @@
py::class<pets::Pet>(m, "Pet", py::module_local())
.def("get_name", &pets::Pet::name);
-This makes the Python-side ``dogs.Pet`` and ``cats.Pet`` into distinct classes
-that can only be accepted as ``Pet`` arguments within those classes. This
-avoids the conflict and allows both modules to be loaded.
+This makes the Python-side ``dogs.Pet`` and ``cats.Pet`` into distinct classes,
+avoiding the conflict and allowing both modules to be loaded. C++ code in the
+``dogs`` module that casts or returns a ``Pet`` instance will result in a
+``dogs.Pet`` Python instance, while C++ code in the ``cats`` module will result
+in a ``cats.Pet`` Python instance.
-One limitation of this approach is that because ``py::module_local`` types are
-distinct on the Python side, it is not possible to pass such a module-local
-type as a C++ ``Pet``-taking function outside that module. For example, if the
-above ``cats`` and ``dogs`` module are each extended with a function:
+This does come with two caveats, however: First, external modules cannot return
+or cast a ``Pet`` instance to Python (unless they also provide their own local
+bindings). Second, from the Python point of view they are two distinct classes.
+
+Note that the locality only applies in the C++ -> Python direction. When
+passing such a ``py::module_local`` type into a C++ function, the module-local
+classes are still considered. This means that if the following function is
+added to any module (including but not limited to the ``cats`` and ``dogs``
+modules above) it will be callable with either a ``dogs.Pet`` or ``cats.Pet``
+argument:
.. code-block:: cpp
- m.def("petname", [](pets::Pet &p) { return p.name(); });
+ m.def("pet_name", [](const pets::Pet &pet) { return pet.name(); });
-you will only be able to call the function with the local module's class:
+For example, suppose the above function is added to each of ``cats.cpp``,
+``dogs.cpp`` and ``frogs.cpp`` (where ``frogs.cpp`` is some other module that
+does *not* bind ``Pets`` at all).
.. code-block:: pycon
- >>> import cats, dogs # No error because of the added py::module_local()
+ >>> import cats, dogs, frogs # No error because of the added py::module_local()
>>> mycat, mydog = cats.Cat("Fluffy"), dogs.Dog("Rover")
- >>> (cats.petname(mycat), dogs.petname(mydog))
+ >>> (cats.pet_name(mycat), dogs.pet_name(mydog))
('Fluffy', 'Rover')
- >>> cats.petname(mydog)
- Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- TypeError: petname(): incompatible function arguments. The following argument types are supported:
- 1. (arg0: cats.Pet) -> str
+ >>> (cats.pet_name(mydog), dogs.pet_name(mycat), frogs.pet_name(mycat))
+ ('Rover', 'Fluffy', 'Fluffy')
- Invoked with: <dogs.Dog object at 0x123>
-
-It is possible to use ``py::module_local()`` registrations in one module even if another module
-registers the same type globally: within the module with the module-local definition, all C++
-instances will be cast to the associated bound Python type. Outside the module, any such values
-are converted to the global Python type created elsewhere.
+It is possible to use ``py::module_local()`` registrations in one module even
+if another module registers the same type globally: within the module with the
+module-local definition, all C++ instances will be cast to the associated bound
+Python type. In other modules any such values are converted to the global
+Python type created elsewhere.
.. note::
diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h
index a0bc8d3..805fea6 100644
--- a/include/pybind11/cast.h
+++ b/include/pybind11/cast.h
@@ -51,6 +51,7 @@
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
void *get_buffer_data = nullptr;
+ void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
/* A simple type never occurs as a (direct or indirect) parent
* of a class that makes use of multiple inheritance */
bool simple_type : 1;
@@ -265,23 +266,30 @@
return bases.front();
}
-/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr.
-/// `check_global_types` can be specified as `false` to only check types registered locally to the
-/// current module.
-PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp,
- bool throw_if_missing = false,
- bool check_global_types = true) {
- std::type_index type_idx(tp);
+inline detail::type_info *get_local_type_info(const std::type_index &tp) {
auto &locals = registered_local_types_cpp();
- auto it = locals.find(type_idx);
+ auto it = locals.find(tp);
if (it != locals.end())
return (detail::type_info *) it->second;
- if (check_global_types) {
- auto &types = get_internals().registered_types_cpp;
- it = types.find(type_idx);
- if (it != types.end())
- return (detail::type_info *) it->second;
- }
+ return nullptr;
+}
+
+inline detail::type_info *get_global_type_info(const std::type_index &tp) {
+ auto &types = get_internals().registered_types_cpp;
+ auto it = types.find(tp);
+ if (it != types.end())
+ return (detail::type_info *) it->second;
+ return nullptr;
+}
+
+/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr.
+PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp,
+ bool throw_if_missing = false) {
+ if (auto ltype = get_local_type_info(tp))
+ return ltype;
+ if (auto gtype = get_global_type_info(tp))
+ return gtype;
+
if (throw_if_missing) {
std::string tname = tp.name();
detail::clean_type_id(tname);
@@ -578,6 +586,8 @@
PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info)
: typeinfo(get_type_info(type_info)) { }
+ type_caster_generic(const type_info *typeinfo) : typeinfo(typeinfo) { }
+
bool load(handle src, bool convert) {
return load_impl<type_caster_generic>(src, convert);
}
@@ -597,7 +607,7 @@
auto it_instances = get_internals().registered_instances.equal_range(src);
for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) {
for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) {
- if (instance_type && instance_type == tinfo)
+ if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype))
return handle((PyObject *) it_i->second).inc_ref();
}
}
@@ -655,8 +665,6 @@
return inst.release();
}
-protected:
-
// Base methods for generic caster; there are overridden in copyable_holder_caster
void load_value(value_and_holder &&v_h) {
auto *&vptr = v_h.value_ptr();
@@ -686,13 +694,41 @@
}
void check_holder_compat() {}
+ PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) {
+ auto caster = type_caster_generic(ti);
+ if (caster.load(src, false))
+ return caster.value;
+ return nullptr;
+ }
+
+ /// Try to load with foreign typeinfo, if available. Used when there is no
+ /// native typeinfo, or when the native one wasn't able to produce a value.
+ PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) {
+ constexpr auto *local_key = "_pybind11_module_local_typeinfo";
+ const auto pytype = src.get_type();
+ if (!hasattr(pytype, local_key))
+ return false;
+
+ type_info *foreign_typeinfo = reinterpret_borrow<capsule>(getattr(pytype, local_key));
+ // Only consider this foreign loader if actually foreign and is a loader of the correct cpp type
+ if (foreign_typeinfo->module_local_load == &local_load
+ || !same_type(*typeinfo->cpptype, *foreign_typeinfo->cpptype))
+ return false;
+
+ if (auto result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) {
+ value = result;
+ return true;
+ }
+ return false;
+ }
+
// Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant
// bits of code between here and copyable_holder_caster where the two classes need different
// logic (without having to resort to virtual inheritance).
template <typename ThisT>
PYBIND11_NOINLINE bool load_impl(handle src, bool convert) {
- if (!src || !typeinfo)
- return false;
+ if (!src) return false;
+ if (!typeinfo) return try_load_foreign_module_local(src);
if (src.is_none()) {
// Defer accepting None to other overloads (if we aren't in convert mode):
if (!convert) return false;
@@ -757,7 +793,17 @@
if (this_.try_direct_conversions(src))
return true;
}
- return false;
+
+ // Failed to match local typeinfo. Try again with global.
+ if (typeinfo->module_local) {
+ if (auto gtype = get_global_type_info(*typeinfo->cpptype)) {
+ typeinfo = gtype;
+ return load(src, false);
+ }
+ }
+
+ // Global typeinfo has precedence over foreign module_local
+ return try_load_foreign_module_local(src);
}
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index afa91ac..5c78abe 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -829,7 +829,7 @@
pybind11_fail("generic_type: cannot initialize type \"" + std::string(rec.name) +
"\": an object with that name is already defined");
- if (get_type_info(*rec.type, false /* don't throw */, !rec.module_local))
+ if (rec.module_local ? get_local_type_info(*rec.type) : get_global_type_info(*rec.type))
pybind11_fail("generic_type: type \"" + std::string(rec.name) +
"\" is already registered!");
@@ -866,6 +866,12 @@
auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr());
tinfo->simple_ancestors = parent_tinfo->simple_ancestors;
}
+
+ if (rec.module_local) {
+ // Stash the local typeinfo and loader so that external modules can access it.
+ tinfo->module_local_load = &type_caster_generic::local_load;
+ setattr(m_ptr, "_pybind11_module_local_typeinfo", capsule(tinfo));
+ }
}
/// Helper function which tags all parents of a type using mult. inheritance
diff --git a/tests/local_bindings.h b/tests/local_bindings.h
index 06a56fc..8f48ed9 100644
--- a/tests/local_bindings.h
+++ b/tests/local_bindings.h
@@ -18,7 +18,7 @@
using LocalExternal = LocalBase<3>;
/// Mixed: registered local first, then global
using MixedLocalGlobal = LocalBase<4>;
-/// Mixed: global first, then local (which fails)
+/// Mixed: global first, then local
using MixedGlobalLocal = LocalBase<5>;
using LocalVec = std::vector<LocalType>;
@@ -29,6 +29,15 @@
using NonLocalMap = std::unordered_map<std::string, NonLocalType>;
using NonLocalMap2 = std::unordered_map<std::string, uint8_t>;
+PYBIND11_MAKE_OPAQUE(LocalVec);
+PYBIND11_MAKE_OPAQUE(LocalVec2);
+PYBIND11_MAKE_OPAQUE(LocalMap);
+PYBIND11_MAKE_OPAQUE(NonLocalVec);
+//PYBIND11_MAKE_OPAQUE(NonLocalVec2); // same type as LocalVec2
+PYBIND11_MAKE_OPAQUE(NonLocalMap);
+PYBIND11_MAKE_OPAQUE(NonLocalMap2);
+
+
// Simple bindings (used with the above):
template <typename T, int Adjust, typename... Args>
py::class_<T> bind_local(Args && ...args) {
@@ -36,3 +45,16 @@
.def(py::init<int>())
.def("get", [](T &i) { return i.i + Adjust; });
};
+
+// Simulate a foreign library base class (to match the example in the docs):
+namespace pets {
+class Pet {
+public:
+ Pet(std::string name) : name_(name) {}
+ std::string name_;
+ const std::string &name() { return name_; }
+};
+}
+
+struct MixGL { int i; MixGL(int i) : i{i} {} };
+struct MixGL2 { int i; MixGL2(int i) : i{i} {} };
diff --git a/tests/pybind11_cross_module_tests.cpp b/tests/pybind11_cross_module_tests.cpp
index 252f893..2091624 100644
--- a/tests/pybind11_cross_module_tests.cpp
+++ b/tests/pybind11_cross_module_tests.cpp
@@ -87,4 +87,21 @@
m.def("load_vector_via_binding", [](std::vector<int> &v) {
return std::accumulate(v.begin(), v.end(), 0);
});
+
+ // test_cross_module_calls
+ m.def("return_self", [](LocalVec *v) { return v; });
+ m.def("return_copy", [](const LocalVec &v) { return LocalVec(v); });
+
+ class Dog : public pets::Pet { public: Dog(std::string name) : Pet(name) {}; };
+ py::class_<pets::Pet>(m, "Pet", py::module_local())
+ .def("name", &pets::Pet::name);
+ // Binding for local extending class:
+ py::class_<Dog, pets::Pet>(m, "Dog")
+ .def(py::init<std::string>());
+ m.def("pet_name", [](pets::Pet &p) { return p.name(); });
+
+ py::class_<MixGL>(m, "MixGL", py::module_local()).def(py::init<int>());
+ m.def("get_gl_value", [](MixGL &o) { return o.i + 100; });
+
+ py::class_<MixGL2>(m, "MixGL2", py::module_local()).def(py::init<int>());
}
diff --git a/tests/test_local_bindings.cpp b/tests/test_local_bindings.cpp
index fdb67a1..359d6c6 100644
--- a/tests/test_local_bindings.cpp
+++ b/tests/test_local_bindings.cpp
@@ -14,13 +14,6 @@
#include <pybind11/stl_bind.h>
#include <numeric>
-PYBIND11_MAKE_OPAQUE(LocalVec);
-PYBIND11_MAKE_OPAQUE(LocalVec2);
-PYBIND11_MAKE_OPAQUE(LocalMap);
-PYBIND11_MAKE_OPAQUE(NonLocalVec);
-PYBIND11_MAKE_OPAQUE(NonLocalMap);
-PYBIND11_MAKE_OPAQUE(NonLocalMap2);
-
TEST_SUBMODULE(local_bindings, m) {
// test_local_bindings
// Register a class with py::module_local:
@@ -84,4 +77,21 @@
m.def("load_vector_via_caster", [](std::vector<int> v) {
return std::accumulate(v.begin(), v.end(), 0);
});
+
+ // test_cross_module_calls
+ m.def("return_self", [](LocalVec *v) { return v; });
+ m.def("return_copy", [](const LocalVec &v) { return LocalVec(v); });
+
+ class Cat : public pets::Pet { public: Cat(std::string name) : Pet(name) {}; };
+ py::class_<pets::Pet>(m, "Pet", py::module_local())
+ .def("get_name", &pets::Pet::name);
+ // Binding for local extending class:
+ py::class_<Cat, pets::Pet>(m, "Cat")
+ .def(py::init<std::string>());
+ m.def("pet_name", [](pets::Pet &p) { return p.name(); });
+
+ py::class_<MixGL>(m, "MixGL").def(py::init<int>());
+ m.def("get_gl_value", [](MixGL &o) { return o.i + 10; });
+
+ py::class_<MixGL2>(m, "MixGL2").def(py::init<int>());
}
diff --git a/tests/test_local_bindings.py b/tests/test_local_bindings.py
index 3a6ad8b..e9a63ca 100644
--- a/tests/test_local_bindings.py
+++ b/tests/test_local_bindings.py
@@ -20,16 +20,14 @@
assert not hasattr(i1, 'get2')
assert not hasattr(i2, 'get3')
+ # Loading within the local module
assert m.local_value(i1) == 5
assert cm.local_value(i2) == 10
- with pytest.raises(TypeError) as excinfo:
- m.local_value(i2)
- assert "incompatible function arguments" in str(excinfo.value)
-
- with pytest.raises(TypeError) as excinfo:
- cm.local_value(i1)
- assert "incompatible function arguments" in str(excinfo.value)
+ # Cross-module loading works as well (on failure, the type loader looks for
+ # external module-local converters):
+ assert m.local_value(i2) == 10
+ assert cm.local_value(i1) == 5
def test_nonlocal_failure():
@@ -60,13 +58,12 @@
v2.append(cm.LocalType(1))
v2.append(cm.LocalType(2))
- with pytest.raises(TypeError):
- v1.append(cm.LocalType(3))
- with pytest.raises(TypeError):
- v2.append(m.LocalType(3))
+ # Cross module value loading:
+ v1.append(cm.LocalType(3))
+ v2.append(m.LocalType(3))
- assert [i.get() for i in v1] == [0, 1]
- assert [i.get() for i in v2] == [2, 3]
+ assert [i.get() for i in v1] == [0, 1, 2]
+ assert [i.get() for i in v2] == [2, 3, 4]
v3, v4 = m.NonLocalVec(), cm.NonLocalVec2()
v3.append(m.NonLocalType(1))
@@ -158,3 +155,56 @@
Invoked with: [1, 2, 3]
""" # noqa: E501 line too long
+
+
+def test_cross_module_calls():
+ import pybind11_cross_module_tests as cm
+
+ v1 = m.LocalVec()
+ v1.append(m.LocalType(1))
+ v2 = cm.LocalVec()
+ v2.append(cm.LocalType(2))
+
+ # Returning the self pointer should get picked up as returning an existing
+ # instance (even when that instance is of a foreign, non-local type).
+ assert m.return_self(v1) is v1
+ assert cm.return_self(v2) is v2
+ assert m.return_self(v2) is v2
+ assert cm.return_self(v1) is v1
+
+ assert m.LocalVec is not cm.LocalVec
+ # Returning a copy, on the other hand, always goes to the local type,
+ # regardless of where the source type came from.
+ assert type(m.return_copy(v1)) is m.LocalVec
+ assert type(m.return_copy(v2)) is m.LocalVec
+ assert type(cm.return_copy(v1)) is cm.LocalVec
+ assert type(cm.return_copy(v2)) is cm.LocalVec
+
+ # Test the example given in the documentation (which also tests inheritance casting):
+ mycat = m.Cat("Fluffy")
+ mydog = cm.Dog("Rover")
+ assert mycat.get_name() == "Fluffy"
+ assert mydog.name() == "Rover"
+ assert m.Cat.__base__.__name__ == "Pet"
+ assert cm.Dog.__base__.__name__ == "Pet"
+ assert m.Cat.__base__ is not cm.Dog.__base__
+ assert m.pet_name(mycat) == "Fluffy"
+ assert m.pet_name(mydog) == "Rover"
+ assert cm.pet_name(mycat) == "Fluffy"
+ assert cm.pet_name(mydog) == "Rover"
+
+ assert m.MixGL is not cm.MixGL
+ a = m.MixGL(1)
+ b = cm.MixGL(2)
+ assert m.get_gl_value(a) == 11
+ assert m.get_gl_value(b) == 12
+ assert cm.get_gl_value(a) == 101
+ assert cm.get_gl_value(b) == 102
+
+ c, d = m.MixGL2(3), cm.MixGL2(4)
+ with pytest.raises(TypeError) as excinfo:
+ m.get_gl_value(c)
+ assert "incompatible function arguments" in str(excinfo)
+ with pytest.raises(TypeError) as excinfo:
+ m.get_gl_value(d)
+ assert "incompatible function arguments" in str(excinfo)