Merge pull request #472 from aldanor/feature/shared-dtypes

Support for sharing dtypes across extensions + public shared data API
diff --git a/docs/advanced/misc.rst b/docs/advanced/misc.rst
index 2968f8a..b071906 100644
--- a/docs/advanced/misc.rst
+++ b/docs/advanced/misc.rst
@@ -149,6 +149,25 @@
         ...
     };
 
+Note also that it is possible (although it would rarely be required) to share arbitrary
+C++ objects between extension modules at runtime. Internal library data is shared
+between modules using capsule machinery [#f6]_, which can also be used for
+storing, modifying and accessing user-defined data. Note that an extension module
+will "see" other extensions' data if and only if they were built with the same
+pybind11 version. Consider the following example:
+
+.. code-block:: cpp
+
+    auto data = (MyData *) py::get_shared_data("mydata");
+    if (!data)
+        data = (MyData *) py::set_shared_data("mydata", new MyData(42));
+
+If the above snippet were used in several separately compiled extension modules,
+the first one to be imported would create a ``MyData`` instance and associate
+a ``"mydata"`` key with a pointer to it. Extensions that are imported later
+would then be able to access the data behind the same pointer.
+
+.. [#f6] https://docs.python.org/3/extending/extending.html#using-capsules
 
 
 Generating documentation using Sphinx
diff --git a/include/pybind11/common.h b/include/pybind11/common.h
index b5434d0..62198c3 100644
--- a/include/pybind11/common.h
+++ b/include/pybind11/common.h
@@ -323,6 +323,7 @@
     std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
     std::unordered_map<std::type_index, std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
     std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
+    std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
 #if defined(WITH_THREAD)
     decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
     PyInterpreterState *istate = nullptr;
@@ -427,6 +428,35 @@
 
 NAMESPACE_END(detail)
 
+/// Looks up `name` among the pointers shared by all extension modules (built with the same
+/// pybind11 version) running in the current interpreter. Names starting with underscores
+/// are reserved for internal usage. Yields the stored pointer, or `nullptr` when absent.
+inline PYBIND11_NOINLINE void* get_shared_data(const std::string& name) {
+    const auto& table = detail::get_internals().shared_data;
+    const auto entry = table.find(name);
+    return entry == table.end() ? nullptr : entry->second;
+}
+
+/// Stores `data` under `name` (overwriting any entry) for `get_shared_data()`; returns `data`.
+inline PYBIND11_NOINLINE void *set_shared_data(const std::string& name, void *data) {
+    auto& table = detail::get_internals().shared_data;
+    return table[name] = data;
+}
+
+/// Returns a reference to the `T` stored under `name` in the shared data table. When the
+/// name is absent (or maps to a null pointer), a `T` is default-constructed with `new`,
+/// stored under `name` and returned. The stored `void *` is cast to `T *` unchecked.
+template<typename T> T& get_or_create_shared_data(const std::string& name) {
+    auto& table = detail::get_internals().shared_data;
+    const auto found = table.find(name);
+    if (found != table.end() && found->second)
+        return *static_cast<T*>(found->second);
+    // No usable entry yet: allocate the object and publish it under `name`.
+    T* created = new T();
+    table[name] = created;
+    return *created;
+}
+
 /// Fetch and hold an error which was already set in Python
 class error_already_set : public std::runtime_error {
 public:
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index da04c62..af465a1 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -21,6 +21,7 @@
 #include <initializer_list>
 #include <functional>
 #include <utility>
+#include <typeindex>
 
 #if defined(_MSC_VER)
 #  pragma warning(push)
@@ -72,6 +73,39 @@
     PyObject *base;
 };
 
+struct numpy_type_info {  // per-type record kept in numpy_internals::registered_dtypes
+    PyObject* dtype_ptr;  // dtype object pointer stored by register_structured_dtype()
+    std::string format_str;  // buffer format string ("T{...}") generated at registration
+};
+
+struct numpy_internals {
+    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
+    // Record registered for `tinfo`; if absent: pybind11_fail() when throw_if_missing, else nullptr.
+    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
+        const auto entry = registered_dtypes.find(std::type_index(tinfo));
+        const bool known = entry != registered_dtypes.end();
+        if (!known && throw_if_missing)
+            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
+        return known ? &entry->second : nullptr;
+    }
+
+    // Same lookup, keyed by the typeid of `T` with cv-qualifiers removed.
+    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
+        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
+    }
+};
+
+inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {  // sets `ptr` to the shared instance
+    ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");  // created if not present yet
+}
+
+inline numpy_internals& get_numpy_internals() {
+    static numpy_internals* cached = nullptr;  // filled by the first call, reused afterwards
+    if (cached == nullptr)
+        load_numpy_internals(cached);
+    return *cached;
+}
+
 struct npy_api {
     enum constants {
         NPY_C_CONTIGUOUS_ = 0x0001,
@@ -656,99 +690,100 @@
     dtype descr;
 };
 
+/// Registers a structured NumPy dtype for the type identified by `tinfo`: builds the dtype and a
+/// buffer format string from `fields`, then records both plus `direct_converter`. Calls
+/// pybind11_fail() if already registered, a field lacks a dtype, or the two descriptions disagree.
+inline PYBIND11_NOINLINE void register_structured_dtype(
+    const std::initializer_list<field_descriptor>& fields, const std::type_info& tinfo,
+    size_t itemsize, bool (*direct_converter)(PyObject *, void *&)) {
+    auto& numpy_internals = get_numpy_internals();
+    if (numpy_internals.get_type_info(tinfo, false))
+        pybind11_fail("NumPy: dtype is already registered");
+
+    list names, formats, offsets;
+    for (const auto& field : fields) {  // by reference: no per-iteration descriptor copy
+        if (!field.descr)
+            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
+                            field.name + "` @ " + tinfo.name());
+        names.append(PYBIND11_STR_TYPE(field.name));
+        formats.append(field.descr);
+        offsets.append(pybind11::int_(field.offset));
+    }
+    // Owned until it is stored below, so a failing check releases the reference instead of leaking it.
+    auto dtype_obj = pybind11::dtype(names, formats, offsets, itemsize);
+
+    // There is an existing bug in NumPy (as of v1.11): trailing bytes are not encoded
+    // explicitly into the format string. This will supposedly get fixed in v1.12; see
+    // https://github.com/numpy/numpy/issues/7797 and https://github.com/numpy/numpy/pull/7798.
+    // Because of this, we won't use numpy's logic to generate buffer format strings
+    // and will just do it ourselves.
+    std::vector<field_descriptor> ordered_fields(fields);
+    std::sort(ordered_fields.begin(), ordered_fields.end(),
+        [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
+    size_t offset = 0;
+    std::ostringstream oss;
+    oss << "T{";
+    for (auto& field : ordered_fields) {
+        if (field.offset > offset)
+            oss << (field.offset - offset) << 'x';
+        // note that '=' is required to cover the case of unaligned fields
+        oss << '=' << field.format << ':' << field.name << ':';
+        offset = field.offset + field.size;
+    }
+    if (itemsize > offset)
+        oss << (itemsize - offset) << 'x';
+    oss << '}';
+    auto format_str = oss.str();
+
+    // Sanity check: verify that NumPy properly parses our buffer format string
+    auto& api = npy_api::get();
+    auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
+    if (!api.PyArray_EquivTypes_(dtype_obj.ptr(), arr.dtype().ptr()))
+        pybind11_fail("NumPy: invalid buffer descriptor!");
+
+    auto tindex = std::type_index(tinfo);
+    numpy_internals.registered_dtypes[tindex] = { dtype_obj.release().ptr(), format_str };
+    get_internals().direct_conversions[tindex].push_back(direct_converter);
+}
+
 template <typename T>
 struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
     static PYBIND11_DESCR name() { return _("struct"); }
 
     static pybind11::dtype dtype() {
-        if (!dtype_ptr)
-            pybind11_fail("NumPy: unsupported buffer format!");
-        return object(dtype_ptr, true);
+        return object(dtype_ptr(), true);
     }
 
     static std::string format() {
-        if (!dtype_ptr)
-            pybind11_fail("NumPy: unsupported buffer format!");
+        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
         return format_str;
     }
 
-    static void register_dtype(std::initializer_list<field_descriptor> fields) {
-        if (dtype_ptr)
-            pybind11_fail("NumPy: dtype is already registered");
-
-        list names, formats, offsets;
-        for (auto field : fields) {
-            if (!field.descr)
-                pybind11_fail("NumPy: unsupported field dtype");
-            names.append(PYBIND11_STR_TYPE(field.name));
-            formats.append(field.descr);
-            offsets.append(pybind11::int_(field.offset));
-        }
-        dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
-
-        // There is an existing bug in NumPy (as of v1.11): trailing bytes are
-        // not encoded explicitly into the format string. This will supposedly
-        // get fixed in v1.12; for further details, see these:
-        // - https://github.com/numpy/numpy/issues/7797
-        // - https://github.com/numpy/numpy/pull/7798
-        // Because of this, we won't use numpy's logic to generate buffer format
-        // strings and will just do it ourselves.
-        std::vector<field_descriptor> ordered_fields(fields);
-        std::sort(ordered_fields.begin(), ordered_fields.end(),
-                  [](const field_descriptor &a, const field_descriptor &b) {
-                      return a.offset < b.offset;
-                  });
-        size_t offset = 0;
-        std::ostringstream oss;
-        oss << "T{";
-        for (auto& field : ordered_fields) {
-            if (field.offset > offset)
-                oss << (field.offset - offset) << 'x';
-            // note that '=' is required to cover the case of unaligned fields
-            oss << '=' << field.format << ':' << field.name << ':';
-            offset = field.offset + field.size;
-        }
-        if (sizeof(T) > offset)
-            oss << (sizeof(T) - offset) << 'x';
-        oss << '}';
-        format_str = oss.str();
-
-        // Sanity check: verify that NumPy properly parses our buffer format string
-        auto& api = npy_api::get();
-        auto arr =  array(buffer_info(nullptr, sizeof(T), format(), 1));
-        if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
-            pybind11_fail("NumPy: invalid buffer descriptor!");
-
-        register_direct_converter();
+    static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
+        register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
+                                  sizeof(T), &direct_converter);  // forwards to the non-template helper
     }
 
 private:
-    static std::string format_str;
-    static PyObject* dtype_ptr;
+    static PyObject* dtype_ptr() {  // dtype pointer registered for T, fetched via get_numpy_internals()
+        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;  // `true`: fail if T is unregistered
+        return ptr;
+    }
 
     static bool direct_converter(PyObject *obj, void*& value) {
         auto& api = npy_api::get();
         if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
             return false;
         if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
-            if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) {
+            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
                 value = ((PyVoidScalarObject_Proxy *) obj)->obval;
                 return true;
             }
         }
         return false;
     }
-
-    static void register_direct_converter() {
-        get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter);
-    }
 };
 
-template <typename T>
-std::string npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::format_str;
-template <typename T>
-PyObject* npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::dtype_ptr = nullptr;
-
 #define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name)                                          \
     ::pybind11::detail::field_descriptor {                                                    \
         Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)),                  \
diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py
index b4e6d71..2ef6f4d 100644
--- a/tests/test_numpy_dtypes.py
+++ b/tests/test_numpy_dtypes.py
@@ -1,11 +1,20 @@
+import re
 import pytest
+
 with pytest.suppress(ImportError):
     import numpy as np
 
-    simple_dtype = np.dtype({'names': ['x', 'y', 'z'],
-                             'formats': ['?', 'u4', 'f4'],
-                             'offsets': [0, 4, 8]})
-    packed_dtype = np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])
+
+@pytest.fixture(scope='module')
+def simple_dtype():  # structured dtype with fields x, y, z at explicit offsets 0, 4, 8
+    return np.dtype({'names': ['x', 'y', 'z'],
+                     'formats': ['?', 'u4', 'f4'],
+                     'offsets': [0, 4, 8]})
+
+
+@pytest.fixture(scope='module')
+def packed_dtype():  # same x, y, z fields given as (name, format) pairs, with no explicit offsets
+    return np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])
 
 
 def assert_equal(actual, expected_data, expected_dtype):
@@ -18,7 +27,7 @@
 
     with pytest.raises(RuntimeError) as excinfo:
         get_format_unbound()
-    assert 'unsupported buffer format' in str(excinfo.value)
+    assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))
 
     assert print_format_descriptors() == [
         "T{=?:x:3x=I:y:=f:z:}",
@@ -32,7 +41,7 @@
 
 
 @pytest.requires_numpy
-def test_dtype():
+def test_dtype(simple_dtype):
     from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods
 
     assert print_dtypes() == [
@@ -57,7 +66,7 @@
 
 
 @pytest.requires_numpy
-def test_recarray():
+def test_recarray(simple_dtype, packed_dtype):
     from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested,
                                 print_rec_simple, print_rec_packed, print_rec_nested,
                                 create_rec_partial, create_rec_partial_nested)