Merge pull request #452 from aldanor/feature/numpy-enum

Auto-implement format/numpy descriptors for enum types
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index cee40c8..04001d6 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -552,6 +552,14 @@
     static std::string format() { return std::to_string(N) + "s"; }
 };
 
+template <typename T>
+struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> { // enums reuse the buffer format of their underlying integer type
+    static std::string format() {
+        using underlying = typename std::remove_cv<typename std::underlying_type<T>::type>::type;
+        return format_descriptor<underlying>::format();
+    }
+};
+
 NAMESPACE_BEGIN(detail)
 template <typename T> struct is_std_array : std::false_type { };
 template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
@@ -563,6 +571,7 @@
            !std::is_array<T>::value &&
            !is_std_array<T>::value &&
            !std::is_integral<T>::value &&
+           !std::is_enum<T>::value &&
            !std::is_same<typename std::remove_cv<T>::type, float>::value &&
            !std::is_same<typename std::remove_cv<T>::type, double>::value &&
            !std::is_same<typename std::remove_cv<T>::type, bool>::value &&
@@ -612,6 +621,14 @@
 template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
 #undef DECL_CHAR_FMT
 
+template <typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> { // enum dtype is the dtype of its underlying integer
+private:
+    using underlying_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
+public:
+    static PYBIND11_DESCR name() { return underlying_descr::name(); } // forwarded unchanged from the integer descriptor
+    static pybind11::dtype dtype() { return underlying_descr::dtype(); }
+};
+
 struct field_descriptor {
     const char *name;
     size_t offset;
diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp
index 3041e55..86e6e68 100644
--- a/tests/test_numpy_dtypes.cpp
+++ b/tests/test_numpy_dtypes.cpp
@@ -67,6 +67,14 @@
     std::array<char, 3> b;
 };
 
+enum class E1 : int64_t { A = -1, B = 1 }; // scoped enum, int64_t underlying type; expected to map to format 'q' / dtype '<i8'
+enum E2 : uint8_t { X = 1, Y = 2 }; // unscoped enum, uint8_t underlying type; expected to map to format 'B' / dtype 'u1'
+
+PYBIND11_PACKED(struct EnumStruct { // record whose fields are both enums, exercising enum format/dtype support
+    E1 e1;
+    E2 e2;
+});
+
 std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
     os << "a='";
     for (size_t i = 0; i < 3 && v.a[i]; i++) os << v.a[i];
@@ -75,6 +83,10 @@
     return os << "'";
 }
 
+std::ostream& operator<<(std::ostream& os, const EnumStruct& v) { // prints "e1=<A|B|?>,e2=<X|Y|?>"; '?' flags a value that is not a declared enumerator
+    return os << "e1=" << (v.e1 == E1::A ? "A" : v.e1 == E1::B ? "B" : "?") << ",e2=" << (v.e2 == E2::X ? "X" : v.e2 == E2::Y ? "Y" : "?");
+}
+
 template <typename T>
 py::array mkarray_via_buffer(size_t n) {
     return py::array(py::buffer_info(nullptr, sizeof(T),
@@ -137,6 +149,16 @@
     return arr;
 }
 
+py::array_t<EnumStruct, 0> create_enum_array(size_t n) { // n records alternating (A, X), (B, Y), (A, X), ...
+    auto arr = mkarray_via_buffer<EnumStruct>(n);
+    auto ptr = static_cast<EnumStruct *>(arr.mutable_data());
+    for (size_t i = 0; i < n; i++) {
+        ptr[i].e1 = (i % 2 == 0) ? E1::A : E1::B; // pick enumerators directly rather than casting computed integers
+        ptr[i].e2 = (i % 2 == 0) ? E2::X : E2::Y;
+    }
+    return arr;
+}
+
 template <typename S>
 py::list print_recarray(py::array_t<S, 0> arr) {
     const auto req = arr.request();
@@ -157,7 +179,8 @@
         py::format_descriptor<NestedStruct>::format(),
         py::format_descriptor<PartialStruct>::format(),
         py::format_descriptor<PartialNestedStruct>::format(),
-        py::format_descriptor<StringStruct>::format()
+        py::format_descriptor<StringStruct>::format(),
+        py::format_descriptor<EnumStruct>::format()
     };
     auto l = py::list();
     for (const auto &fmt : fmts) {
@@ -173,7 +196,8 @@
         py::dtype::of<NestedStruct>().str(),
         py::dtype::of<PartialStruct>().str(),
         py::dtype::of<PartialNestedStruct>().str(),
-        py::dtype::of<StringStruct>().str()
+        py::dtype::of<StringStruct>().str(),
+        py::dtype::of<EnumStruct>().str()
     };
     auto l = py::list();
     for (const auto &s : dtypes) {
@@ -280,6 +304,7 @@
     PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z);
     PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
     PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
+    PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
 
     m.def("create_rec_simple", &create_recarray<SimpleStruct>);
     m.def("create_rec_packed", &create_recarray<PackedStruct>);
@@ -294,6 +319,8 @@
     m.def("get_format_unbound", &get_format_unbound);
     m.def("create_string_array", &create_string_array);
     m.def("print_string_array", &print_recarray<StringStruct>);
+    m.def("create_enum_array", &create_enum_array);
+    m.def("print_enum_array", &print_recarray<EnumStruct>);
     m.def("test_array_ctors", &test_array_ctors);
     m.def("test_dtype_ctors", &test_dtype_ctors);
     m.def("test_dtype_methods", &test_dtype_methods);
diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py
index 2f4cab0..22f5c66 100644
--- a/tests/test_numpy_dtypes.py
+++ b/tests/test_numpy_dtypes.py
@@ -26,7 +26,8 @@
         "T{=T{=?:x:3x=I:y:=f:z:}:a:=T{=?:x:=I:y:=f:z:}:b:}",
         "T{=?:x:3x=I:y:=f:z:12x}",
         "T{8x=T{=?:x:3x=I:y:=f:z:12x}:a:8x}",
-        "T{=3s:a:=3s:b:}"
+        "T{=3s:a:=3s:b:}",
+        'T{=q:e1:=B:e2:}'
     ]
 
 
@@ -40,7 +41,8 @@
         "[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]",
         "{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}",
         "{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}",
-        "[('a', 'S3'), ('b', 'S3')]"
+        "[('a', 'S3'), ('b', 'S3')]",
+        "[('e1', '<i8'), ('e2', 'u1')]"
     ]
 
     d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
@@ -151,6 +153,23 @@
 
 
 @pytest.requires_numpy
+def test_enum_array():
+    from pybind11_tests import create_enum_array, print_enum_array
+
+    arr = create_enum_array(3)
+    dtype = arr.dtype
+    assert dtype == np.dtype([('e1', '<i8'), ('e2', 'u1')])
+    assert print_enum_array(arr) == [
+        "e1=A,e2=X",
+        "e1=B,e2=Y",
+        "e1=A,e2=X"
+    ]
+    assert arr['e1'].tolist() == [-1, 1, -1]
+    assert arr['e2'].tolist() == [1, 2, 1]
+    assert create_enum_array(0).dtype == dtype
+
+
+@pytest.requires_numpy
 def test_signature(doc):
     from pybind11_tests import create_rec_nested