Keyword argument support; removed the last traces of std::function<> usage
diff --git a/include/pybind/numpy.h b/include/pybind/numpy.h
index 0336794..3c06854 100644
--- a/include/pybind/numpy.h
+++ b/include/pybind/numpy.h
@@ -10,6 +10,8 @@
#pragma once
#include <pybind/pybind.h>
+#include <functional>
+
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
@@ -132,6 +134,15 @@
}
};
+/* Explicit specializations of npy_format_descriptor that map each supported C++
+   scalar type to its NumPy type number, exposed as the enum constant `value`. */
+#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
+/* signed and unsigned integers */
+DECL_FMT(int8_t,   NPY_BYTE);
+DECL_FMT(uint8_t,  NPY_UBYTE);
+DECL_FMT(int16_t,  NPY_SHORT);
+DECL_FMT(uint16_t, NPY_USHORT);
+DECL_FMT(int32_t,  NPY_INT);
+DECL_FMT(uint32_t, NPY_UINT);
+DECL_FMT(int64_t,  NPY_LONGLONG);
+DECL_FMT(uint64_t, NPY_ULONGLONG);
+/* floating point, boolean and complex */
+DECL_FMT(float,    NPY_FLOAT);
+DECL_FMT(double,   NPY_DOUBLE);
+DECL_FMT(bool,     NPY_BOOL);
+DECL_FMT(std::complex<float>,  NPY_CFLOAT);
+DECL_FMT(std::complex<double>, NPY_CDOUBLE);
+#undef DECL_FMT
+
+
NAMESPACE_BEGIN(detail)
PYBIND_TYPE_CASTER_PYTYPE(array)
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int8_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint8_t>)
@@ -142,24 +153,20 @@
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<float>>)
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<double>>)
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<bool>)
-NAMESPACE_END(detail)
-#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
-DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT);
-DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT);
-DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT);
-DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex<float>, NPY_CFLOAT);
-DECL_FMT(std::complex<double>, NPY_CDOUBLE);
-#undef DECL_FMT
+template <typename Func, typename Return, typename... Args>
+struct vectorize_helper {
+ typename std::remove_reference<Func>::type f;
-template <typename func_type, typename return_type, typename... args_type, size_t... Index>
- std::function<object(array_dtype<args_type>...)>
- vectorize(func_type &&f, return_type (*) (args_type ...),
- detail::index_sequence<Index...>) {
+ vectorize_helper(const Func &f) : f(f) { }
- return [f](array_dtype<args_type>... args) -> array {
+ object operator()(array_dtype<Args>... args) {
+ return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
+ }
+
+ template <size_t ... Index> object run(array_dtype<Args>&... args, index_sequence<Index...>) {
/* Request buffers from all parameters */
- const size_t N = sizeof...(args_type);
+ const size_t N = sizeof...(Args);
std::array<buffer_info, N> buffers {{ args.request()... }};
/* Determine dimensions parameters of output array */
@@ -174,7 +181,7 @@
}
std::vector<size_t> strides(ndim);
if (ndim > 0) {
- strides[ndim-1] = sizeof(return_type);
+ strides[ndim-1] = sizeof(Return);
for (int i=ndim-1; i>0; --i)
strides[i-1] = strides[i] * shape[i];
}
@@ -186,31 +193,32 @@
}
/* Call the function */
- std::vector<return_type> result(count);
+ std::vector<Return> result(count);
for (size_t i=0; i<count; ++i)
result[i] = f((buffers[Index].count == 1
- ? *((args_type *) buffers[Index].ptr)
- : ((args_type *) buffers[Index].ptr)[i])...);
+ ? *((Args *) buffers[Index].ptr)
+ : ((Args *) buffers[Index].ptr)[i])...);
if (count == 1)
return cast(result[0]);
/* Return the result */
- return array(buffer_info(result.data(), sizeof(return_type),
- format_descriptor<return_type>::value(),
+ return array(buffer_info(result.data(), sizeof(Return),
+ format_descriptor<Return>::value(),
ndim, shape, strides));
- };
+ }
+};
+
+NAMESPACE_END(detail)
+
+/// Wrap the callable `f` in a detail::vectorize_helper function object whose
+/// call operator takes array_dtype<Args>... and applies `f` element by element.
+/// The second (unnamed) parameter is only used to deduce Return and Args;
+/// its value is never read.
+template <typename Func, typename Return, typename... Args>
+detail::vectorize_helper<Func, Return, Args...> vectorize(const Func &f, Return (*) (Args ...)) {
+    using helper_type = detail::vectorize_helper<Func, Return, Args...>;
+    return helper_type(f);
}
-template <typename func_type, typename return_type, typename... args_type>
- std::function<object(array_dtype<args_type>...)>
- vectorize(func_type &&f, return_type (*f_) (args_type ...) = nullptr) {
- return vectorize(f, f_, typename detail::make_index_sequence<sizeof...(args_type)>::type());
-}
-
-template <typename return_type, typename... args_type>
-std::function<object(array_dtype<args_type>...)> vectorize(return_type (*f) (args_type ...)) {
- return vectorize(f, f);
+/// Convenience overload for plain function pointers: the pointer serves both as
+/// the callable to wrap and as the signature used to deduce Return and Args.
+template <typename Return, typename... Args>
+detail::vectorize_helper<Return (*) (Args ...), Return, Args...> vectorize(Return (*f) (Args ...)) {
+    typedef Return (*fn_ptr) (Args ...);
+    return vectorize<fn_ptr, Return, Args...>(f, f);
}
template <typename func> auto vectorize(func &&f) -> decltype(