Merge pull request #453 from aldanor/feature/numpy-scalars

NumPy scalars to ctypes conversion support
diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h
index cf9fcf2..1b4f33b 100644
--- a/include/pybind11/cast.h
+++ b/include/pybind11/cast.h
@@ -26,6 +26,7 @@
     void (*init_holder)(PyObject *, const void *);
     std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
     std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
+    std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
     buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
     void *get_buffer_data = nullptr;
     /** A simple type never occurs as a (direct or indirect) parent
@@ -90,7 +91,8 @@
     } while (true);
 }
 
-PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_info &tp, bool throw_if_missing) {
+PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_info &tp,
+                                                          bool throw_if_missing = false) {
     auto &types = get_internals().registered_types_cpp;
 
     auto it = types.find(std::type_index(tp));
@@ -157,7 +159,7 @@
 class type_caster_generic {
 public:
     PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info)
-     : typeinfo(get_type_info(type_info, false)) { }
+     : typeinfo(get_type_info(type_info)) { }
 
     PYBIND11_NOINLINE bool load(handle src, bool convert) {
         if (!src)
@@ -215,6 +217,10 @@
                 if (load(temp, false))
                     return true;
             }
+            for (auto &converter : *typeinfo->direct_conversions) {
+                if (converter(src.ptr(), value))
+                    return true;
+            }
         }
         return false;
     }
diff --git a/include/pybind11/common.h b/include/pybind11/common.h
index 6f79f91..b5434d0 100644
--- a/include/pybind11/common.h
+++ b/include/pybind11/common.h
@@ -321,6 +321,7 @@
     std::unordered_map<const void *, void*> registered_types_py;       // PyTypeObject* -> type_info
     std::unordered_multimap<const void *, void*> registered_instances; // void * -> PyObject*
     std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
+    std::unordered_map<std::type_index, std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
     std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
 #if defined(WITH_THREAD)
     decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index 4111ccd..4ae3de8 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -63,6 +63,14 @@
     int flags;
 };
 
+struct PyVoidScalarObject_Proxy {  // field-layout proxy that lets direct_converter read `obval` from a NumPy void scalar
+    PyObject_VAR_HEAD
+    char *obval;  // data pointer handed out as the C++ value by direct_converter
+    PyArrayDescr_Proxy *descr;
+    int flags;
+    PyObject *base;
+};
+
 struct npy_api {
     enum constants {
         NPY_C_CONTIGUOUS_ = 0x0001,
@@ -103,7 +111,9 @@
     PyObject *(*PyArray_DescrNewFromType_)(int);
     PyObject *(*PyArray_NewCopy_)(PyObject *, int);
     PyTypeObject *PyArray_Type_;
+    PyTypeObject *PyVoidArrType_Type_;
     PyTypeObject *PyArrayDescr_Type_;
+    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
     PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
     int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
     bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
@@ -114,7 +124,9 @@
     enum functions {
         API_PyArray_Type = 2,
         API_PyArrayDescr_Type = 3,
+        API_PyVoidArrType_Type = 39,
         API_PyArray_DescrFromType = 45,
+        API_PyArray_DescrFromScalar = 57,
         API_PyArray_FromAny = 69,
         API_PyArray_NewCopy = 85,
         API_PyArray_NewFromDescr = 94,
@@ -136,8 +148,10 @@
         npy_api api;
 #define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
         DECL_NPY_API(PyArray_Type);
+        DECL_NPY_API(PyVoidArrType_Type);
         DECL_NPY_API(PyArrayDescr_Type);
         DECL_NPY_API(PyArray_DescrFromType);
+        DECL_NPY_API(PyArray_DescrFromScalar);
         DECL_NPY_API(PyArray_FromAny);
         DECL_NPY_API(PyArray_NewCopy);
         DECL_NPY_API(PyArray_NewFromDescr);
@@ -658,6 +672,9 @@
     }
 
     static void register_dtype(std::initializer_list<field_descriptor> fields) {
+        if (dtype_ptr)
+            pybind11_fail("NumPy: dtype is already registered");
+
         list names, formats, offsets;
         for (auto field : fields) {
             if (!field.descr)
@@ -700,11 +717,30 @@
         auto arr =  array(buffer_info(nullptr, sizeof(T), format(), 1));
         if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
             pybind11_fail("NumPy: invalid buffer descriptor!");
+
+        register_direct_converter();
     }
 
 private:
     static std::string format_str;
     static PyObject* dtype_ptr;
+
+    static bool direct_converter(PyObject *obj, void*& value) {  // on success sets `value` to obj's data pointer and returns true; otherwise leaves `value` untouched
+        auto& api = npy_api::get();
+        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))  // only instances of the NumPy void scalar type are considered
+            return false;
+        if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {  // NOTE(review): `false` presumably means "not borrowed", i.e. the wrapper owns the reference -- confirm against object's constructor
+            if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) {  // scalar's dtype must be equivalent to the dtype registered for T
+                value = ((PyVoidScalarObject_Proxy *) obj)->obval;  // no copy is made: value is the scalar's own data pointer
+                return true;
+            }
+        }
+        return false;
+    }
+
+    static void register_direct_converter() {  // appends direct_converter to the internals-wide list keyed by typeid(T)
+        get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter);  // operator[] creates the list on first use
+    }
 };
 
 template <typename T>
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index 83abe51..114ae97 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -180,8 +180,6 @@
                 a.descr = strdup(a.value.attr("__repr__")().cast<std::string>().c_str());
         }
 
-        auto const &registered_types = detail::get_internals().registered_types_cpp;
-
         /* Generate a proper function signature */
         std::string signature;
         size_t type_depth = 0, char_index = 0, type_index = 0, arg_index = 0;
@@ -216,9 +214,8 @@
                 const std::type_info *t = types[type_index++];
                 if (!t)
                     pybind11_fail("Internal error while parsing type signature (1)");
-                auto it = registered_types.find(std::type_index(*t));
-                if (it != registered_types.end()) {
-                    signature += ((const detail::type_info *) it->second)->type->tp_name;
+                if (auto tinfo = detail::get_type_info(*t)) {
+                    signature += tinfo->type->tp_name;
                 } else {
                     std::string tname(t->name());
                     detail::clean_type_id(tname);
@@ -610,8 +607,7 @@
         auto &internals = get_internals();
         auto tindex = std::type_index(*(rec->type));
 
-        if (internals.registered_types_cpp.find(tindex) !=
-            internals.registered_types_cpp.end())
+        if (get_type_info(*(rec->type)))
             pybind11_fail("generic_type: type \"" + std::string(rec->name) +
                           "\" is already registered!");
 
@@ -672,6 +668,7 @@
         tinfo->type = (PyTypeObject *) type;
         tinfo->type_size = rec->type_size;
         tinfo->init_holder = rec->init_holder;
+        tinfo->direct_conversions = &internals.direct_conversions[tindex];
         internals.registered_types_cpp[tindex] = tinfo;
         internals.registered_types_py[type] = tinfo;
 
@@ -1333,11 +1330,11 @@
             PyErr_Clear();
         return result;
     };
-    auto &registered_types = detail::get_internals().registered_types_cpp;
-    auto it = registered_types.find(std::type_index(typeid(OutputType)));
-    if (it == registered_types.end())
+
+    if (auto tinfo = detail::get_type_info(typeid(OutputType)))
+        tinfo->implicit_conversions.push_back(implicit_caster);
+    else
         pybind11_fail("implicitly_convertible: Unable to find type " + type_id<OutputType>());
-    ((detail::type_info *) it->second)->implicit_conversions.push_back(implicit_caster);
 }
 
 template <typename ExceptionTranslator>
@@ -1589,11 +1586,8 @@
 }
 
 template <class T> function get_overload(const T *this_ptr, const char *name) {
-    auto &cpp_types = detail::get_internals().registered_types_cpp;
-    auto it = cpp_types.find(typeid(T));
-    if (it == cpp_types.end())
-        return function();
-    return get_type_overload(this_ptr, (const detail::type_info *) it->second, name);
+    auto tinfo = detail::get_type_info(typeid(T));  // throw_if_missing defaults to false, so a missing type gives a null tinfo
+    return tinfo ? get_type_overload(this_ptr, tinfo, name) : function();  // empty function when T has no registered type_info
 }
 
 #define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \
diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp
index 86e6e68..40aca0c 100644
--- a/tests/test_numpy_dtypes.cpp
+++ b/tests/test_numpy_dtypes.cpp
@@ -298,6 +298,9 @@
         return;
     }
 
+    // for scalar casts to work, typeinfo may be registered before the dtype descriptor...
+    py::class_<SimpleStruct>(m, "SimpleStruct");
+
     PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z);
     PYBIND11_NUMPY_DTYPE(PackedStruct, x, y, z);
     PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
@@ -306,6 +309,9 @@
     PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
     PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
 
+    // ... or after
+    py::class_<PackedStruct>(m, "PackedStruct");
+
     m.def("create_rec_simple", &create_recarray<SimpleStruct>);
     m.def("create_rec_packed", &create_recarray<PackedStruct>);
     m.def("create_rec_nested", &create_nested);
@@ -324,6 +330,10 @@
     m.def("test_array_ctors", &test_array_ctors);
     m.def("test_dtype_ctors", &test_dtype_ctors);
     m.def("test_dtype_methods", &test_dtype_methods);
+    m.def("f_simple", [](SimpleStruct s) { return s.y * 10; });
+    m.def("f_packed", [](PackedStruct s) { return s.y * 10; });
+    m.def("f_nested", [](NestedStruct s) { return s.a.y * 10; });
+    m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z); });
 });
 
 #undef PYBIND11_PACKED
diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py
index 22f5c66..47d7c3b 100644
--- a/tests/test_numpy_dtypes.py
+++ b/tests/test_numpy_dtypes.py
@@ -174,3 +174,34 @@
     from pybind11_tests import create_rec_nested
 
     assert doc(create_rec_nested) == "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
+
+
+@pytest.requires_numpy
+def test_scalar_conversion():  # each f_* must accept elements of its own record array only, and reject every other array's elements
+    from pybind11_tests import (create_rec_simple, f_simple,
+                                create_rec_packed, f_packed,
+                                create_rec_nested, f_nested,
+                                create_enum_array)
+
+    n = 3
+    arrays = [create_rec_simple(n), create_rec_packed(n),
+              create_rec_nested(n), create_enum_array(n)]
+    funcs = [f_simple, f_packed, f_nested]  # same order as the first three entries of `arrays`
+
+    for i, func in enumerate(funcs):
+        for j, arr in enumerate(arrays):
+            if i == j and i < 2:  # success is expected only for f_simple / f_packed (indices 0 and 1) on their matching array
+                assert [func(arr[k]) for k in range(n)] == [k * 10 for k in range(n)]
+            else:  # every other pairing, including f_nested on its own array, must be rejected
+                with pytest.raises(TypeError) as excinfo:
+                    func(arr[0])
+                assert 'incompatible function arguments' in str(excinfo.value)
+
+
+@pytest.requires_numpy
+def test_register_dtype():  # register_dtype() is expected to fail with RuntimeError and the message checked below
+    from pybind11_tests import register_dtype
+
+    with pytest.raises(RuntimeError) as excinfo:
+        register_dtype()
+    assert 'dtype is already registered' in str(excinfo.value)