Added function for reloading module (#1040)
diff --git a/docs/advanced/embedding.rst b/docs/advanced/embedding.rst
index bdfc75e..3930316 100644
--- a/docs/advanced/embedding.rst
+++ b/docs/advanced/embedding.rst
@@ -133,6 +133,11 @@
int n = result.cast<int>();
assert(n == 3);
+Modules can be reloaded using `module::reload()` if the source is modified e.g.
+by an external process. This can be useful in scenarios where the application
+imports a user defined data processing script which needs to be updated after
+changes by the user. Note that this function does not reload modules recursively.
+
.. _embedding_modules:
Adding embedded modules
@@ -185,7 +190,7 @@
namespace py = pybind11;
PYBIND11_EMBEDDED_MODULE(cpp_module, m) {
- m.attr("a") = 1
+ m.attr("a") = 1;
}
int main() {
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index 3fcb99f..80102c5 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -836,6 +836,14 @@
return reinterpret_steal<module>(obj);
}
+ /// Reload the module or throws `error_already_set`.
+ void reload() {
+ PyObject *obj = PyImport_ReloadModule(ptr());
+ if (!obj)
+ throw error_already_set();
+ *this = reinterpret_steal<module>(obj);
+ }
+
// Adds an object to the module using the given name. Throws if an object with the given name
// already exists.
//
diff --git a/tests/test_embed/test_interpreter.cpp b/tests/test_embed/test_interpreter.cpp
index acbad6b..6b5f051 100644
--- a/tests/test_embed/test_interpreter.cpp
+++ b/tests/test_embed/test_interpreter.cpp
@@ -2,6 +2,8 @@
#include <catch.hpp>
#include <thread>
+#include <fstream>
+#include <functional>
namespace py = pybind11;
using namespace py::literals;
@@ -216,3 +218,52 @@
REQUIRE(locals["count"].cast<int>() == num_threads);
}
+
+// Scope exit utility https://stackoverflow.com/a/36644501/7255855
+struct scope_exit {
+ std::function<void()> f_;
+ explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
+ ~scope_exit() { if (f_) f_(); }
+};
+
+TEST_CASE("Reload module from file") {
+ // Disable generation of cached bytecode (.pyc files) for this test, otherwise
+ // Python might pick up an old version from the cache instead of the new versions
+ // of the .py files generated below
+ auto sys = py::module::import("sys");
+ bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
+ sys.attr("dont_write_bytecode") = true;
+ // Reset the value at scope exit
+ scope_exit reset_dont_write_bytecode([&]() {
+ sys.attr("dont_write_bytecode") = dont_write_bytecode;
+ });
+
+ std::string module_name = "test_module_reload";
+ std::string module_file = module_name + ".py";
+
+ // Create the module .py file
+ std::ofstream test_module(module_file);
+ test_module << "def test():\n";
+ test_module << " return 1\n";
+ test_module.close();
+ // Delete the file at scope exit
+ scope_exit delete_module_file([&]() {
+ std::remove(module_file.c_str());
+ });
+
+ // Import the module from file
+ auto module = py::module::import(module_name.c_str());
+ int result = module.attr("test")().cast<int>();
+ REQUIRE(result == 1);
+
+ // Update the module .py file with a small change
+ test_module.open(module_file);
+ test_module << "def test():\n";
+ test_module << " return 2\n";
+ test_module.close();
+
+ // Reload the module
+ module.reload();
+ result = module.attr("test")().cast<int>();
+ REQUIRE(result == 2);
+}