Improve py::array_t scalar type information (#724)
* Add value_type member alias to py::array_t (resolve #632)
* Use the numpy scalar type name in py::array_t function signatures (e.g. float32/float64 instead of just float)
diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp
index 58a2052..8899644 100644
--- a/tests/test_numpy_array.cpp
+++ b/tests/test_numpy_array.cpp
@@ -17,6 +17,7 @@
using arr = py::array;
using arr_t = py::array_t<uint16_t, 0>;
+static_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
template<typename... Ix> arr data(const arr& a, Ix... index) {
return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py
index b58aa1b..7109ff3 100644
--- a/tests/test_numpy_array.py
+++ b/tests/test_numpy_array.py
@@ -279,6 +279,21 @@
# No exact match, should call first convertible version:
assert overloaded(np.array([1], dtype='uint8')) == 'double'
+ with pytest.raises(TypeError) as excinfo:
+ overloaded("not an array")
+ assert msg(excinfo.value) == """
+ overloaded(): incompatible function arguments. The following argument types are supported:
+ 1. (arg0: numpy.ndarray[float64]) -> str
+ 2. (arg0: numpy.ndarray[float32]) -> str
+ 3. (arg0: numpy.ndarray[int32]) -> str
+ 4. (arg0: numpy.ndarray[uint16]) -> str
+ 5. (arg0: numpy.ndarray[int64]) -> str
+ 6. (arg0: numpy.ndarray[complex128]) -> str
+ 7. (arg0: numpy.ndarray[complex64]) -> str
+
+ Invoked with: 'not an array'
+ """
+
assert overloaded2(np.array([1], dtype='float64')) == 'double'
assert overloaded2(np.array([1], dtype='float32')) == 'float'
assert overloaded2(np.array([1], dtype='complex64')) == 'float complex'
@@ -289,8 +304,8 @@
assert overloaded3(np.array([1], dtype='intc')) == 'int'
expected_exc = """
overloaded3(): incompatible function arguments. The following argument types are supported:
- 1. (arg0: numpy.ndarray[int]) -> str
- 2. (arg0: numpy.ndarray[float]) -> str
+ 1. (arg0: numpy.ndarray[int32]) -> str
+ 2. (arg0: numpy.ndarray[float64]) -> str
Invoked with:"""
diff --git a/tests/test_numpy_vectorize.py b/tests/test_numpy_vectorize.py
index e4cbf02..271241c 100644
--- a/tests/test_numpy_vectorize.py
+++ b/tests/test_numpy_vectorize.py
@@ -71,5 +71,5 @@
from pybind11_tests import vectorized_func
assert doc(vectorized_func) == """
- vectorized_func(arg0: numpy.ndarray[int], arg1: numpy.ndarray[float], arg2: numpy.ndarray[float]) -> object
+ vectorized_func(arg0: numpy.ndarray[int32], arg1: numpy.ndarray[float32], arg2: numpy.ndarray[float64]) -> object
""" # noqa: E501 line too long