support for overriding virtual functions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2e80bd0..96bb795 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,6 +61,7 @@
example/example9.cpp
example/example10.cpp
example/example11.cpp
+ example/example12.cpp
)
set_target_properties(example PROPERTIES PREFIX "")
diff --git a/README.md b/README.md
index ead5b37..f802d7d 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@
- STL data structures
- Smart pointers with reference counting like `std::shared_ptr`
- Internal references with correct reference counting
+- C++ classes with virtual (and pure virtual) methods can be extended in Python
## Goodies
In addition to the core functionality, pybind11 provides some extra goodies:
diff --git a/example/example.cpp b/example/example.cpp
index 2cc7e50..9eea762 100644
--- a/example/example.cpp
+++ b/example/example.cpp
@@ -20,6 +20,7 @@
void init_ex9(py::module &);
void init_ex10(py::module &);
void init_ex11(py::module &);
+void init_ex12(py::module &);
PYTHON_PLUGIN(example) {
py::module m("example", "pybind example plugin");
@@ -35,6 +36,7 @@
init_ex9(m);
init_ex10(m);
init_ex11(m);
+ init_ex12(m);
return m.ptr();
}
diff --git a/example/example12.cpp b/example/example12.cpp
new file mode 100644
index 0000000..274edf8
--- /dev/null
+++ b/example/example12.cpp
@@ -0,0 +1,82 @@
+/*
+ example/example12.cpp -- overriding virtual functions from Python
+
+ Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#include "example.h"
+#include <pybind/functional.h>
+
+/* This is an example class that we'll want to be able to extend from Python */
+class Example12 {
+public:
+ Example12(int state) : state(state) {
+ cout << "Constructing Example12.." << endl;
+ }
+
+ ~Example12() {
+ cout << "Destructing Example12.." << endl;
+ }
+
+ virtual int run(int value) {
+ std::cout << "Original implementation of Example12::run(state=" << state
+ << ", value=" << value << ")" << std::endl;
+ return state + value;
+ }
+
+ virtual void pure_virtual() = 0;
+private:
+ int state;
+};
+
+/* This is a wrapper class that must be generated */
+class PyExample12 : public Example12 {
+public:
+ using Example12::Example12; /* Inherit constructors */
+
+ virtual int run(int value) {
+ /* Generate wrapping code that enables native function overloading */
+ PYBIND_OVERLOAD(
+ int, /* Return type */
+ Example12, /* Parent class */
+ run, /* Name of function */
+ value /* Argument(s) */
+ );
+ }
+
+ virtual void pure_virtual() {
+ PYBIND_OVERLOAD_PURE(
+ void, /* Return type */
+ Example12, /* Parent class */
+ pure_virtual /* Name of function */
+ /* This function has no arguments */
+ );
+ }
+};
+
+int runExample12(Example12 *ex, int value) {
+ return ex->run(value);
+}
+
+void runExample12Virtual(Example12 *ex) {
+ ex->pure_virtual();
+}
+
+void init_ex12(py::module &m) {
+ /* Important: use the wrapper type as a template
+ argument to class_<>, but use the original name
+ to denote the type */
+ py::class_<PyExample12>(m, "Example12")
+ /* Declare that 'PyExample12' is really an alias for the original type 'Example12' */
+ .alias<Example12>()
+ .def(py::init<int>())
+ /* Reference original class in function definitions */
+ .def("run", &Example12::run)
+ .def("pure_virtual", &Example12::pure_virtual);
+
+ m.def("runExample12", &runExample12);
+ m.def("runExample12Virtual", &runExample12Virtual);
+}
diff --git a/example/example12.py b/example/example12.py
new file mode 100644
index 0000000..4f78575
--- /dev/null
+++ b/example/example12.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+from __future__ import print_function
+import sys
+sys.path.append('.')
+
+from example import Example12, runExample12, runExample12Virtual
+
+
+class ExtendedExample12(Example12):
+ def __init__(self, state):
+ super(ExtendedExample12, self).__init__(state + 1)
+ self.data = "Hello world"
+
+ def run(self, value):
+ print('ExtendedExample12::run(%i), calling parent..' % value)
+ return super(ExtendedExample12, self).run(value + 1)
+
+ def pure_virtual(self):
+ print('ExtendedExample12::pure_virtual(): %s' % self.data)
+
+
+ex12 = Example12(10)
+print(runExample12(ex12, 20))
+try:
+ runExample12Virtual(ex12)
+except Exception as e:
+ print("Caught expected exception: " + str(e))
+
+ex12p = ExtendedExample12(10)
+print(runExample12(ex12p, 20))
+runExample12Virtual(ex12p)
diff --git a/example/example5.cpp b/example/example5.cpp
index f6de5ba..91e0b4f 100644
--- a/example/example5.cpp
+++ b/example/example5.cpp
@@ -37,28 +37,6 @@
dog.bark();
}
-class Example5 {
-public:
- Example5(py::handle self, int state)
- : self(self), state(state) {
- cout << "Constructing Example5.." << endl;
- }
-
- ~Example5() {
- cout << "Destructing Example5.." << endl;
- }
-
- void callback(int value) {
- py::gil_scoped_acquire gil;
- cout << "In Example5::callback() " << endl;
- py::object method = self.attr("callback");
- method.call(state, value);
- }
-private:
- py::handle self;
- int state;
-};
-
bool test_callback1(py::object func) {
func.call();
return false;
@@ -69,16 +47,11 @@
return result.cast<int>();
}
-void test_callback3(Example5 *ex, int value) {
- py::gil_scoped_release gil;
- ex->callback(value);
-}
-
-void test_callback4(const std::function<int(int)> &func) {
+void test_callback3(const std::function<int(int)> &func) {
cout << "func(43) = " << func(43)<< std::endl;
}
-std::function<int(int)> test_callback5() {
+std::function<int(int)> test_callback4() {
return [](int i) { return i+1; };
}
@@ -99,8 +72,4 @@
m.def("test_callback2", &test_callback2);
m.def("test_callback3", &test_callback3);
m.def("test_callback4", &test_callback4);
- m.def("test_callback5", &test_callback5);
-
- py::class_<Example5>(m, "Example5")
- .def(py::init<py::object, int>());
}
diff --git a/example/example5.py b/example/example5.py
index 4e75e17..5aaaae7 100755
--- a/example/example5.py
+++ b/example/example5.py
@@ -24,29 +24,17 @@
from example import test_callback2
from example import test_callback3
from example import test_callback4
-from example import test_callback5
-from example import Example5
def func1():
print('Callback function 1 called!')
def func2(a, b, c, d):
print('Callback function 2 called : ' + str(a) + ", " + str(b) + ", " + str(c) + ", "+ str(d))
- return c
-
-class MyCallback(Example5):
- def __init__(self, value):
- Example5.__init__(self, self, value)
-
- def callback(self, value1, value2):
- print('got callback: %i %i' % (value1, value2))
+ return d
print(test_callback1(func1))
print(test_callback2(func2))
-callback = MyCallback(3)
-test_callback3(callback, 4)
-
-test_callback4(lambda i: i+1)
-f = test_callback5()
+test_callback3(lambda i: i + 1)
+f = test_callback4()
print("func(43) = %i" % f(43))
diff --git a/include/pybind/cast.h b/include/pybind/cast.h
index 2feaf84..b8eee92 100644
--- a/include/pybind/cast.h
+++ b/include/pybind/cast.h
@@ -601,6 +601,7 @@
}
template <typename T> inline T handle::cast() { return pybind::cast<T>(m_ptr); }
+template <> inline void handle::cast() { return; }
template <typename... Args> inline object handle::call(Args&&... args_) {
const size_t size = sizeof...(Args);
@@ -624,6 +625,8 @@
PyTuple_SetItem(tuple, counter++, result);
PyObject *result = PyObject_CallObject(m_ptr, tuple);
Py_DECREF(tuple);
+ if (result == nullptr && PyErr_Occurred())
+ throw error_already_set();
return object(result, false);
}
diff --git a/include/pybind/common.h b/include/pybind/common.h
index 3a95210..05750c6 100644
--- a/include/pybind/common.h
+++ b/include/pybind/common.h
@@ -27,6 +27,7 @@
#include <vector>
#include <string>
#include <stdexcept>
+#include <unordered_set>
#include <unordered_map>
#include <memory>
@@ -114,13 +115,6 @@
}
};
-// C++ bindings of core Python exceptions
-struct stop_iteration : public std::runtime_error { public: stop_iteration(const std::string &w="") : std::runtime_error(w) {} };
-struct index_error : public std::runtime_error { public: index_error(const std::string &w="") : std::runtime_error(w) {} };
-struct error_already_set : public std::exception { public: error_already_set() {} };
-/// Thrown when pybind::cast or handle::call fail due to a type casting error
-struct cast_error : public std::runtime_error { public: cast_error(const std::string &w = "") : std::runtime_error(w) {} };
-
NAMESPACE_BEGIN(detail)
inline std::string error_string();
@@ -145,10 +139,19 @@
void *get_buffer_data = nullptr;
};
+struct overload_hash {
+ inline std::size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
+ size_t value = std::hash<const void *>()(v.first);
+ value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
+ return value;
+ }
+};
+
/// Internal data struture used to track registered instances and types
struct internals {
std::unordered_map<const std::type_info *, type_info> registered_types;
- std::unordered_map<void *, PyObject *> registered_instances;
+ std::unordered_map<const void *, PyObject *> registered_instances;
+ std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
};
/// Return a reference to the current 'internals' information
@@ -176,5 +179,20 @@
/// Helper type to replace 'void' in some expressions
struct void_type { };
+/// to_string variant which also accepts strings
+template <typename T> inline typename std::enable_if<!std::is_enum<T>::value, std::string>::type
+to_string(const T &value) { return std::to_string(value); }
+template <> inline std::string to_string(const std::string &value) { return value; }
+template <typename T> inline typename std::enable_if<std::is_enum<T>::value, std::string>::type
+to_string(T value) { return std::to_string((int) value); }
+
NAMESPACE_END(detail)
+
+// C++ bindings of core Python exceptions
+struct stop_iteration : public std::runtime_error { public: stop_iteration(const std::string &w="") : std::runtime_error(w) {} };
+struct index_error : public std::runtime_error { public: index_error(const std::string &w="") : std::runtime_error(w) {} };
+struct error_already_set : public std::runtime_error { public: error_already_set() : std::runtime_error(detail::error_string()) {} };
+/// Thrown when pybind::cast or handle::call fail due to a type casting error
+struct cast_error : public std::runtime_error { public: cast_error(const std::string &w = "") : std::runtime_error(w) {} };
+
NAMESPACE_END(pybind)
diff --git a/include/pybind/functional.h b/include/pybind/functional.h
index f300d4d..7c3c9a0 100644
--- a/include/pybind/functional.h
+++ b/include/pybind/functional.h
@@ -25,8 +25,6 @@
object src(src_, true);
value = [src](Args... args) -> Return {
object retval(pybind::handle(src).call(std::move(args)...));
- if (retval.ptr() == nullptr && PyErr_Occurred())
- throw error_already_set();
/* Visual studio 2015 parser issue: need parentheses around this expression */
return (retval.template cast<Return>());
};
diff --git a/include/pybind/pybind.h b/include/pybind/pybind.h
index 156e570..699489c 100644
--- a/include/pybind/pybind.h
+++ b/include/pybind/pybind.h
@@ -24,7 +24,6 @@
#endif
#include <pybind/cast.h>
-#include <iostream>
NAMESPACE_BEGIN(pybind)
@@ -46,12 +45,8 @@
/// Annotation for methods
struct is_method {
-#if PY_MAJOR_VERSION < 3
PyObject *class_;
is_method(object *o) : class_(o->ptr()) { }
-#else
- is_method(object *) { }
-#endif
};
/// Annotation for documentation
@@ -76,9 +71,7 @@
short keywords = 0;
return_value_policy policy = return_value_policy::automatic;
std::string signature;
-#if PY_MAJOR_VERSION < 3
PyObject *class_ = nullptr;
-#endif
PyObject *sibling = nullptr;
const char *doc = nullptr;
function_entry *next = nullptr;
@@ -126,21 +119,18 @@
kw[entry->keywords++] = "self";
kw[entry->keywords++] = a.name;
}
+
template <typename T>
static void process_extra(const pybind::arg_t<T> &a, function_entry *entry, const char **kw, const char **def) {
if (entry->is_method && entry->keywords == 0)
kw[entry->keywords++] = "self";
kw[entry->keywords] = a.name;
- def[entry->keywords++] = strdup(std::to_string(a.value).c_str());
+ def[entry->keywords++] = strdup(detail::to_string(a.value).c_str());
}
static void process_extra(const pybind::is_method &m, function_entry *entry, const char **, const char **) {
entry->is_method = true;
-#if PY_MAJOR_VERSION < 3
entry->class_ = m.class_;
-#else
- (void) m;
-#endif
}
static void process_extra(const pybind::return_value_policy p, function_entry *entry, const char **, const char **) { entry->policy = p; }
static void process_extra(pybind::sibling s, function_entry *entry, const char **, const char **) { entry->sibling = s.value; }
@@ -366,35 +356,38 @@
m_entry->sibling = PyMethod_GET_FUNCTION(m_entry->sibling);
#endif
- function_entry *entry = m_entry;
- bool overloaded = false;
- if (!entry->sibling || !PyCFunction_Check(entry->sibling)) {
- entry->def = new PyMethodDef();
- memset(entry->def, 0, sizeof(PyMethodDef));
- entry->def->ml_name = entry->name;
- entry->def->ml_meth = reinterpret_cast<PyCFunction>(*dispatcher);
- entry->def->ml_flags = METH_VARARGS | METH_KEYWORDS;
- capsule entry_capsule(entry, [](PyObject *o) { destruct((function_entry *) PyCapsule_GetPointer(o, nullptr)); });
- m_ptr = PyCFunction_New(entry->def, entry_capsule.ptr());
+ function_entry *s_entry = nullptr, *entry = m_entry;
+ if (m_entry->sibling && PyCFunction_Check(m_entry->sibling)) {
+ capsule entry_capsule(PyCFunction_GetSelf(m_entry->sibling), true);
+ s_entry = (function_entry *) entry_capsule;
+ if (s_entry->class_ != m_entry->class_)
+ s_entry = nullptr; /* Method override */
+ }
+
+ if (!s_entry) {
+ m_entry->def = new PyMethodDef();
+ memset(m_entry->def, 0, sizeof(PyMethodDef));
+ m_entry->def->ml_name = m_entry->name;
+ m_entry->def->ml_meth = reinterpret_cast<PyCFunction>(*dispatcher);
+ m_entry->def->ml_flags = METH_VARARGS | METH_KEYWORDS;
+ capsule entry_capsule(m_entry, [](PyObject *o) { destruct((function_entry *) PyCapsule_GetPointer(o, nullptr)); });
+ m_ptr = PyCFunction_New(m_entry->def, entry_capsule.ptr());
if (!m_ptr)
throw std::runtime_error("cpp_function::cpp_function(): Could not allocate function object");
} else {
- m_ptr = entry->sibling;
+ m_ptr = m_entry->sibling;
inc_ref();
- capsule entry_capsule(PyCFunction_GetSelf(m_ptr), true);
- function_entry *parent = (function_entry *) entry_capsule, *backup = parent;
- while (parent->next)
- parent = parent->next;
- parent->next = entry;
- entry = backup;
- overloaded = true;
+ entry = s_entry;
+ while (s_entry->next)
+ s_entry = s_entry->next;
+ s_entry->next = m_entry;
}
std::string signatures;
int index = 0;
function_entry *it = entry;
while (it) { /* Create pydoc it */
- if (overloaded)
+ if (s_entry)
signatures += std::to_string(++index) + ". ";
signatures += "Signature : " + std::string(it->signature) + "\n";
if (it->doc && strlen(it->doc) > 0)
@@ -783,6 +776,12 @@
metaclass().attr(name) = property;
return *this;
}
+
+ template <typename target> class_ alias() {
+ auto &instances = pybind::detail::get_internals().registered_types;
+ instances[&typeid(target)] = instances[&typeid(type)];
+ return *this;
+ }
private:
static void init_holder(PyObject *inst_) {
instance_type *inst = (instance_type *) inst_;
@@ -882,6 +881,43 @@
inline ~gil_scoped_release() { PyEval_RestoreThread(state); }
};
+inline function get_overload(const void *this_ptr, const char *name) {
+ handle py_object = detail::get_object_handle(this_ptr);
+ handle type = py_object.get_type();
+ auto key = std::make_pair(type.ptr(), name);
+
+ /* Cache functions that aren't overloaded in python to avoid
+ many costly dictionary lookups in Python */
+ auto &cache = detail::get_internals().inactive_overload_cache;
+ if (cache.find(key) != cache.end())
+ return function();
+
+ function overload = (function) py_object.attr(name);
+ if (overload.is_cpp_function()) {
+ cache.insert(key);
+ return function();
+ }
+ PyFrameObject *frame = PyThreadState_Get()->frame;
+ pybind::str caller = pybind::handle(frame->f_code->co_name).str();
+ if (strcmp((const char *) caller, name) == 0)
+ return function();
+ return overload;
+}
+
+#define PYBIND_OVERLOAD_INT(ret_type, class_name, name, ...) { \
+ pybind::gil_scoped_acquire gil; \
+ pybind::function overload = pybind::get_overload(this, #name); \
+ if (overload) \
+ return overload.call(__VA_ARGS__).cast<ret_type>(); }
+
+#define PYBIND_OVERLOAD(ret_type, class_name, name, ...) \
+ PYBIND_OVERLOAD_INT(ret_type, class_name, name, __VA_ARGS__) \
+ return class_name::name(__VA_ARGS__)
+
+#define PYBIND_OVERLOAD_PURE(ret_type, class_name, name, ...) \
+ PYBIND_OVERLOAD_INT(ret_type, class_name, name, __VA_ARGS__) \
+ throw std::runtime_error("Tried to call pure virtual function \"" #name "\"");
+
NAMESPACE_END(pybind)
#if defined(_MSC_VER)
diff --git a/include/pybind/pytypes.h b/include/pybind/pytypes.h
index 4559ffb..8afc088 100644
--- a/include/pybind/pytypes.h
+++ b/include/pybind/pytypes.h
@@ -331,10 +331,12 @@
PyObject *ptr = m_ptr;
if (ptr == nullptr)
return false;
-#if PY_MAJOR_VERSION < 3
+#if PY_MAJOR_VERSION >= 3
+ if (PyInstanceMethod_Check(ptr))
+ ptr = PyInstanceMethod_GET_FUNCTION(ptr);
+#endif
if (PyMethod_Check(ptr))
ptr = PyMethod_GET_FUNCTION(ptr);
-#endif
return PyCFunction_Check(ptr);
}
};