ability to prevent force casts in numpy arguments
diff --git a/docs/advanced.rst b/docs/advanced.rst
index 99af039..c986921 100644
--- a/docs/advanced.rst
+++ b/docs/advanced.rst
@@ -1100,10 +1100,12 @@
.. code-block:: cpp
- void f(py::array_t<double, py::array::c_style> array);
+ void f(py::array_t<double, py::array::c_style | py::array::forcecast> array);
-As before, the implementation will attempt to convert non-conforming arguments
-into an array satisfying the specified requirements.
+The ``py::array::forcecast`` argument is the default value of the second
+template paramenter, and it ensures that non-conforming arguments are converted
+into an array satisfying the specified requirements instead of trying the next
+function overload.
Vectorizing functions
=====================
diff --git a/example/example10.cpp b/example/example10.cpp
index 09769fe..cbe737e 100644
--- a/example/example10.cpp
+++ b/example/example10.cpp
@@ -33,4 +33,9 @@
// Vectorize a complex-valued function
m.def("vectorized_func3", py::vectorize(my_func3));
+
+ /// Numpy function which only accepts specific data types
+ m.def("selective_func", [](py::array_t<int, py::array::c_style>) { std::cout << "Int branch taken. "<< std::endl; });
+ m.def("selective_func", [](py::array_t<float, py::array::c_style>) { std::cout << "Float branch taken. "<< std::endl; });
+ m.def("selective_func", [](py::array_t<std::complex<float>, py::array::c_style>) { std::cout << "Complex float branch taken. "<< std::endl; });
}
diff --git a/example/example10.py b/example/example10.py
index 0d49fca..b18e729 100755
--- a/example/example10.py
+++ b/example/example10.py
@@ -27,3 +27,8 @@
print(f(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2))
print(np.array([[1, 2, 3], [4, 5, 6]])* np.array([[2], [3]])* 2)
+from example import selective_func
+
+selective_func(np.array([1], dtype=np.int32))
+selective_func(np.array([1.0], dtype=np.float32))
+selective_func(np.array([1.0j], dtype=np.complex64))
diff --git a/example/example10.ref b/example/example10.ref
index 9d48d7c..4885fc1 100644
--- a/example/example10.ref
+++ b/example/example10.ref
@@ -73,3 +73,6 @@
[ 24. 30. 36.]]
[[ 4 8 12]
[24 30 36]]
+Int branch taken.
+Float branch taken.
+Complex float branch taken.
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index 0b0e0ee..f97c790 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -78,7 +78,8 @@
enum {
c_style = API::NPY_C_CONTIGUOUS_,
- f_style = API::NPY_F_CONTIGUOUS_
+ f_style = API::NPY_F_CONTIGUOUS_,
+ forcecast = API::NPY_ARRAY_FORCECAST_
};
template <typename Type> array(size_t size, const Type *ptr) {
@@ -124,7 +125,7 @@
}
};
-template <typename T, int ExtraFlags = 0> class array_t : public array {
+template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
public:
PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr));
array_t() : array() { }
@@ -135,10 +136,9 @@
return nullptr;
API &api = lookup_api();
PyObject *descr = api.PyArray_DescrFromType_(detail::npy_format_descriptor<T>::value);
- PyObject *result = api.PyArray_FromAny_(
- ptr, descr, 0, 0,
- API::NPY_ENSURE_ARRAY_ | API::NPY_ARRAY_FORCECAST_ | ExtraFlags,
- nullptr);
+ PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
+ if (!result)
+ PyErr_Clear();
Py_DECREF(ptr);
return result;
}
@@ -318,11 +318,11 @@
template <typename T>
vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
- object operator()(array_t<Args>... args) {
+ object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
}
- template <size_t ... Index> object run(array_t<Args>&... args, index_sequence<Index...> index) {
+ template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
/* Request buffers from all parameters */
const size_t N = sizeof...(Args);
@@ -332,7 +332,7 @@
int ndim = 0;
std::vector<size_t> shape(0);
bool trivial_broadcast = broadcast(buffers, ndim, shape);
-
+
size_t size = 1;
std::vector<size_t> strides(ndim);
if (ndim > 0) {
@@ -384,7 +384,7 @@
}
};
-template <typename T> struct handle_type_name<array_t<T>> {
+template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
static PYBIND11_DESCR name() { return _("numpy.ndarray[dtype=") + type_caster<T>::name() + _("]"); }
};