Use `PyGILState_GetThisThreadState` when using gil_scoped_acquire. (#1211)
This avoids GIL deadlocking when pybind11 tries to acquire the GIL in a thread that already acquired it using standard Python API (e.g. when running from a Python thread).
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index c50ba89..7fa0f0e 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -1872,6 +1872,15 @@
tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate);
if (!tstate) {
+ /* Check if the GIL was acquired using the PyGILState_* API instead (e.g. if
+ calling from a Python thread). Since we use a different key, this ensures
+ we don't create a new thread state and deadlock in PyEval_AcquireThread
+ below. Note we don't save this state with internals.tstate, since we don't
+ create it we would fail to clear it (its reference count should be > 0). */
+ tstate = PyGILState_GetThisThreadState();
+ }
+
+ if (!tstate) {
tstate = PyThreadState_New(internals.istate);
#if !defined(NDEBUG)
if (!tstate)
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index b5e0b52..a31d5b8 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -40,6 +40,7 @@
test_eval.cpp
test_exceptions.cpp
test_factory_constructors.cpp
+ test_gil_scoped.cpp
test_iostream.cpp
test_kwargs_and_defaults.cpp
test_local_bindings.cpp
diff --git a/tests/test_gil_scoped.cpp b/tests/test_gil_scoped.cpp
new file mode 100644
index 0000000..a94b7a2
--- /dev/null
+++ b/tests/test_gil_scoped.cpp
@@ -0,0 +1,43 @@
+/*
+ tests/test_gil_scoped.cpp -- acquire and release gil
+
+ Copyright (c) 2017 Borja Zarco (Google LLC) <bzarco@google.com>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#include "pybind11_tests.h"
+#include <pybind11/functional.h>
+
+
+class VirtClass {
+public:
+ virtual void virtual_func() {}
+ virtual void pure_virtual_func() = 0;
+};
+
+class PyVirtClass : public VirtClass {
+ void virtual_func() override {
+ PYBIND11_OVERLOAD(void, VirtClass, virtual_func,);
+ }
+ void pure_virtual_func() override {
+ PYBIND11_OVERLOAD_PURE(void, VirtClass, pure_virtual_func,);
+ }
+};
+
+TEST_SUBMODULE(gil_scoped, m) {
+ py::class_<VirtClass, PyVirtClass>(m, "VirtClass")
+ .def(py::init<>())
+ .def("virtual_func", &VirtClass::virtual_func)
+ .def("pure_virtual_func", &VirtClass::pure_virtual_func);
+
+ m.def("test_callback_py_obj",
+ [](py::object func) { func(); });
+ m.def("test_callback_std_func",
+ [](const std::function<void()> &func) { func(); });
+ m.def("test_callback_virtual_func",
+ [](VirtClass &virt) { virt.virtual_func(); });
+ m.def("test_callback_pure_virtual_func",
+ [](VirtClass &virt) { virt.pure_virtual_func(); });
+}
diff --git a/tests/test_gil_scoped.py b/tests/test_gil_scoped.py
new file mode 100644
index 0000000..5e70243
--- /dev/null
+++ b/tests/test_gil_scoped.py
@@ -0,0 +1,80 @@
+import multiprocessing
+import threading
+from pybind11_tests import gil_scoped as m
+
+
+def _run_in_process(target, *args, **kwargs):
+ """Runs target in process and returns its exitcode after 1s (None if still alive)."""
+ process = multiprocessing.Process(target=target, args=args, kwargs=kwargs)
+ process.daemon = True
+ try:
+ process.start()
+ # Do not need to wait much, 1s should be more than enough.
+ process.join(timeout=1)
+ return process.exitcode
+ finally:
+ if process.is_alive():
+ process.terminate()
+
+
+def _python_to_cpp_to_python():
+ """Calls different C++ functions that come back to Python."""
+ class ExtendedVirtClass(m.VirtClass):
+ def virtual_func(self):
+ pass
+
+ def pure_virtual_func(self):
+ pass
+
+ extended = ExtendedVirtClass()
+ m.test_callback_py_obj(lambda: None)
+ m.test_callback_std_func(lambda: None)
+ m.test_callback_virtual_func(extended)
+ m.test_callback_pure_virtual_func(extended)
+
+
+def _python_to_cpp_to_python_from_threads(num_threads, parallel=False):
+ """Calls different C++ functions that come back to Python, from Python threads."""
+ threads = []
+ for _ in range(num_threads):
+ thread = threading.Thread(target=_python_to_cpp_to_python)
+ thread.daemon = True
+ thread.start()
+ if parallel:
+ threads.append(thread)
+ else:
+ thread.join()
+ for thread in threads:
+ thread.join()
+
+
+def test_python_to_cpp_to_python_from_thread():
+ """Makes sure there is no GIL deadlock when running in a thread.
+
+ It runs in a separate process to be able to stop and assert if it deadlocks.
+ """
+ assert _run_in_process(_python_to_cpp_to_python_from_threads, 1) == 0
+
+
+def test_python_to_cpp_to_python_from_thread_multiple_parallel():
+ """Makes sure there is no GIL deadlock when running in a thread multiple times in parallel.
+
+ It runs in a separate process to be able to stop and assert if it deadlocks.
+ """
+ assert _run_in_process(_python_to_cpp_to_python_from_threads, 8, parallel=True) == 0
+
+
+def test_python_to_cpp_to_python_from_thread_multiple_sequential():
+ """Makes sure there is no GIL deadlock when running in a thread multiple times sequentially.
+
+ It runs in a separate process to be able to stop and assert if it deadlocks.
+ """
+ assert _run_in_process(_python_to_cpp_to_python_from_threads, 8, parallel=False) == 0
+
+
+def test_python_to_cpp_to_python_from_process():
+ """Makes sure there is no GIL deadlock when using processes.
+
+ This test is for completion, but it was never an issue.
+ """
+ assert _run_in_process(_python_to_cpp_to_python) == 0