improved error handling at module import time
diff --git a/include/pybind11/common.h b/include/pybind11/common.h
index d26c84f..bebcd07 100644
--- a/include/pybind11/common.h
+++ b/include/pybind11/common.h
@@ -67,12 +67,24 @@
#endif
#if PY_MAJOR_VERSION >= 3
-#define PYBIND11_PLUGIN(name) \
+#define PYBIND11_PLUGIN_IMPL(name) \
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
#else
-#define PYBIND11_PLUGIN(name) \
+#define PYBIND11_PLUGIN_IMPL(name) \
extern "C" PYBIND11_EXPORT PyObject *init##name()
#endif
+#define PYBIND11_PLUGIN(name) \
+ static PyObject *pybind11_init(); \
+ PYBIND11_PLUGIN_IMPL(name) { \
+ try { \
+ return pybind11_init(); \
+ } catch (const std::exception &e) { \
+ PyErr_SetString(PyExc_ImportError, e.what()); \
+ return nullptr; \
+ } \
+ } \
+ PyObject *pybind11_init()
+
NAMESPACE_BEGIN(pybind11)
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index eb0bf20..c159b11 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -471,7 +471,10 @@
}
static module import(const char *name) {
- return module(PyImport_ImportModule(name), false);
+ PyObject *obj = PyImport_ImportModule(name);
+ if (!obj)
+ throw std::runtime_error("Module \"" + std::string(name) + "\" not found!");
+ return module(obj, false);
}
};