blob: 55a0871095cddec20d254c5d0481f409c2ad34de [file] [log] [blame]
Jason Rhinelanderb3f3d792016-07-18 16:43:18 -04001/*
2 example/example-virtual-functions.cpp -- overriding virtual functions from Python
3
4 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
5
6 All rights reserved. Use of this source code is governed by a
7 BSD-style license that can be found in the LICENSE file.
8*/
9
10#include "example.h"
11#include <pybind11/functional.h>
12
/* This is an example class that we'll want to be able to extend from Python.
   run() has a C++ default implementation; run_bool() and pure_virtual() are
   pure, so Python subclasses are dispatched through the PyExampleVirt
   trampoline declared below. */
class ExampleVirt {
public:
    ExampleVirt(int state) : state(state) {
        std::cout << "Constructing ExampleVirt.." << std::endl;
    }

    /* Virtual: objects are owned through std::unique_ptr<ExampleVirt> but may
       actually be PyExampleVirt (trampoline) instances, and deleting a derived
       object through a base pointer whose destructor is non-virtual is
       undefined behaviour. */
    virtual ~ExampleVirt() {
        std::cout << "Destructing ExampleVirt.." << std::endl;
    }

    /* Default implementation: logs the call and returns state + value. */
    virtual int run(int value) {
        std::cout << "Original implementation of ExampleVirt::run(state=" << state
                  << ", value=" << value << ")" << std::endl;
        return state + value;
    }

    virtual bool run_bool() = 0;
    virtual void pure_virtual() = 0;

private:
    int state;
};
35
36/* This is a wrapper class that must be generated */
37class PyExampleVirt : public ExampleVirt {
38public:
39 using ExampleVirt::ExampleVirt; /* Inherit constructors */
40
41 virtual int run(int value) {
42 /* Generate wrapping code that enables native function overloading */
43 PYBIND11_OVERLOAD(
44 int, /* Return type */
45 ExampleVirt, /* Parent class */
46 run, /* Name of function */
47 value /* Argument(s) */
48 );
49 }
50
51 virtual bool run_bool() {
52 PYBIND11_OVERLOAD_PURE(
53 bool, /* Return type */
54 ExampleVirt, /* Parent class */
55 run_bool, /* Name of function */
56 /* This function has no arguments. The trailing comma
57 in the previous line is needed for some compilers */
58 );
59 }
60
61 virtual void pure_virtual() {
62 PYBIND11_OVERLOAD_PURE(
63 void, /* Return type */
64 ExampleVirt, /* Parent class */
65 pure_virtual, /* Name of function */
66 /* This function has no arguments. The trailing comma
67 in the previous line is needed for some compilers */
68 );
69 }
70};
71
/* Move-only type owning a heap-allocated int equal to a * b. Copying, default
   construction and both assignment operators are deleted; only move
   construction is allowed. */
class NonCopyable {
public:
    NonCopyable(int a, int b) : value{new int(a*b)} {}

    // The defaulted move transfers the owned int and leaves the source empty.
    NonCopyable(NonCopyable &&) = default;
    NonCopyable() = delete;
    NonCopyable(const NonCopyable &) = delete;
    void operator=(const NonCopyable &) = delete;
    void operator=(NonCopyable &&) = delete;

    ~NonCopyable() {
        std::cout << "NonCopyable destructor @ " << this
                  << "; value = " << get_value() << std::endl;
    }

    // Decimal text of the stored int, or "(null)" once moved from.
    std::string get_value() const {
        return value ? std::to_string(*value) : std::string("(null)");
    }

private:
    std::unique_ptr<int> value;
};
88
// This is like the above, but is both copyable and movable. In effect this means it should get
// moved when it is not referenced elsewhere, but copied if it is still referenced.
// Copyable and movable value holding a + b; each special member logs which
// one ran so callers can tell whether a copy or a move happened.
class Movable {
public:
    Movable(int a, int b) : value{a+b} {}

    Movable(const Movable &m) : value{m.value} {
        std::cout << "Movable @ " << this << " copy constructor" << std::endl;
    }

    // noexcept so containers and std::move_if_noexcept pick the move over the
    // copy; the payload is a plain int, so std::move on it would be a no-op.
    Movable(Movable &&m) noexcept : value{m.value} {
        std::cout << "Movable @ " << this << " move constructor" << std::endl;
    }

    int get_value() const { return value; }

    ~Movable() {
        std::cout << "Movable destructor @ " << this << "; value = " << get_value() << std::endl;
    }

private:
    int value;
};
101
102class NCVirt {
103public:
104 virtual NonCopyable get_noncopyable(int a, int b) { return NonCopyable(a, b); }
105 virtual Movable get_movable(int a, int b) = 0;
106
107 void print_nc(int a, int b) { std::cout << get_noncopyable(a, b).get_value() << std::endl; }
108 void print_movable(int a, int b) { std::cout << get_movable(a, b).get_value() << std::endl; }
109};
110class NCVirtTrampoline : public NCVirt {
111 virtual NonCopyable get_noncopyable(int a, int b) {
112 PYBIND11_OVERLOAD(NonCopyable, NCVirt, get_noncopyable, a, b);
113 }
114 virtual Movable get_movable(int a, int b) {
115 PYBIND11_OVERLOAD_PURE(Movable, NCVirt, get_movable, a, b);
116 }
117};
118
Jason Rhinelanderb3f3d792016-07-18 16:43:18 -0400119int runExampleVirt(ExampleVirt *ex, int value) {
120 return ex->run(value);
121}
122
123bool runExampleVirtBool(ExampleVirt* ex) {
124 return ex->run_bool();
125}
126
127void runExampleVirtVirtual(ExampleVirt *ex) {
128 ex->pure_virtual();
129}
130
Jason Rhinelander0ca96e22016-08-05 17:02:33 -0400131
Jason Rhinelanderd6c365b2016-08-05 17:44:28 -0400132// Inheriting virtual methods. We do two versions here: the repeat-everything version and the
133// templated trampoline versions mentioned in docs/advanced.rst.
Jason Rhinelander0ca96e22016-08-05 17:02:33 -0400134//
Jason Rhinelanderd6c365b2016-08-05 17:44:28 -0400135// These base classes are exactly the same, but we technically need distinct
136// classes for this example code because we need to be able to bind them
137// properly (pybind11, sensibly, doesn't allow us to bind the same C++ class to
Jason Rhinelander0ca96e22016-08-05 17:02:33 -0400138// multiple python classes).
class A_Repeat {
// A_METHODS is also expanded inside A_Tpl, so both hierarchies get identical
// method definitions.
#define A_METHODS \
public: \
    virtual int unlucky_number() = 0; \
    virtual void say_something(unsigned times) { \
        for (unsigned i = 0; i < times; i++) std::cout << "hi"; \
        std::cout << std::endl; \
    }
A_METHODS
    // Declared outside A_METHODS because a destructor's name is class-specific.
    // Virtual: objects are owned through std::unique_ptr<A_Repeat> but may be
    // trampoline instances, and deleting those through A_Repeat* needs this.
    virtual ~A_Repeat() = default;
};
149class B_Repeat : public A_Repeat {
150#define B_METHODS \
151public: \
152 int unlucky_number() override { return 13; } \
153 void say_something(unsigned times) override { \
154 std::cout << "B says hi " << times << " times" << std::endl; \
155 } \
156 virtual double lucky_number() { return 7.0; }
157B_METHODS
158};
159class C_Repeat : public B_Repeat {
160#define C_METHODS \
161public: \
162 int unlucky_number() override { return 4444; } \
163 double lucky_number() override { return 888; }
164C_METHODS
165};
166class D_Repeat : public C_Repeat {
167#define D_METHODS // Nothing overridden.
168D_METHODS
169};
170
Jason Rhinelander0ca96e22016-08-05 17:02:33 -0400171// Base classes for templated inheritance trampolines. Identical to the repeat-everything version:
172class A_Tpl { A_METHODS };
173class B_Tpl : public A_Tpl { B_METHODS };
174class C_Tpl : public B_Tpl { C_METHODS };
175class D_Tpl : public C_Tpl { D_METHODS };
176
177
178// Inheritance approach 1: each trampoline gets every virtual method (11 in total)
179class PyA_Repeat : public A_Repeat {
180public:
181 using A_Repeat::A_Repeat;
182 int unlucky_number() override { PYBIND11_OVERLOAD_PURE(int, A_Repeat, unlucky_number, ); }
183 void say_something(unsigned times) override { PYBIND11_OVERLOAD(void, A_Repeat, say_something, times); }
184};
185class PyB_Repeat : public B_Repeat {
186public:
187 using B_Repeat::B_Repeat;
188 int unlucky_number() override { PYBIND11_OVERLOAD(int, B_Repeat, unlucky_number, ); }
189 void say_something(unsigned times) override { PYBIND11_OVERLOAD(void, B_Repeat, say_something, times); }
190 double lucky_number() override { PYBIND11_OVERLOAD(double, B_Repeat, lucky_number, ); }
191};
192class PyC_Repeat : public C_Repeat {
193public:
194 using C_Repeat::C_Repeat;
195 int unlucky_number() override { PYBIND11_OVERLOAD(int, C_Repeat, unlucky_number, ); }
196 void say_something(unsigned times) override { PYBIND11_OVERLOAD(void, C_Repeat, say_something, times); }
197 double lucky_number() override { PYBIND11_OVERLOAD(double, C_Repeat, lucky_number, ); }
198};
199class PyD_Repeat : public D_Repeat {
200public:
201 using D_Repeat::D_Repeat;
202 int unlucky_number() override { PYBIND11_OVERLOAD(int, D_Repeat, unlucky_number, ); }
203 void say_something(unsigned times) override { PYBIND11_OVERLOAD(void, D_Repeat, say_something, times); }
204 double lucky_number() override { PYBIND11_OVERLOAD(double, D_Repeat, lucky_number, ); }
205};
206
// Inheritance approach 2: templated trampoline classes.
//
// Advantages:
// - we have only 2 (template) classes and 4 method declarations (one per virtual method, plus one
// for any override of a pure virtual method), versus 4 classes and 6 methods for MI (the
// multiple-inheritance trampoline approach, not shown in this file) or 4 classes and 11 methods
// for the repeat approach.
// - Compared to MI, we also don't have to change the non-trampoline inheritance to virtual, and can
// properly inherit constructors.
//
// Disadvantage:
// - the compiler must still generate and compile 14 different methods (more, even, than the 11
// required for the repeat approach) instead of the 6 required for MI. (If there were no pure
// method (or no pure method override), the number would drop to the same 11 as the repeat
// approach.)
221template <class Base = A_Tpl>
222class PyA_Tpl : public Base {
223public:
224 using Base::Base; // Inherit constructors
225 int unlucky_number() override { PYBIND11_OVERLOAD_PURE(int, Base, unlucky_number, ); }
226 void say_something(unsigned times) override { PYBIND11_OVERLOAD(void, Base, say_something, times); }
227};
228template <class Base = B_Tpl>
229class PyB_Tpl : public PyA_Tpl<Base> {
230public:
231 using PyA_Tpl<Base>::PyA_Tpl; // Inherit constructors (via PyA_Tpl's inherited constructors)
232 int unlucky_number() override { PYBIND11_OVERLOAD(int, Base, unlucky_number, ); }
233 double lucky_number() { PYBIND11_OVERLOAD(double, Base, lucky_number, ); }
234};
235// Since C_Tpl and D_Tpl don't declare any new virtual methods, we don't actually need these (we can
236// use PyB_Tpl<C_Tpl> and PyB_Tpl<D_Tpl> for the trampoline classes instead):
237/*
238template <class Base = C_Tpl> class PyC_Tpl : public PyB_Tpl<Base> {
239public:
240 using PyB_Tpl<Base>::PyB_Tpl;
241};
242template <class Base = D_Tpl> class PyD_Tpl : public PyC_Tpl<Base> {
243public:
244 using PyC_Tpl<Base>::PyC_Tpl;
245};
246*/
247
Jason Rhinelander0ca96e22016-08-05 17:02:33 -0400248
249void initialize_inherited_virtuals(py::module &m) {
250 // Method 1: repeat
251 py::class_<A_Repeat, std::unique_ptr<A_Repeat>, PyA_Repeat>(m, "A_Repeat")
252 .def(py::init<>())
253 .def("unlucky_number", &A_Repeat::unlucky_number)
254 .def("say_something", &A_Repeat::say_something);
255 py::class_<B_Repeat, std::unique_ptr<B_Repeat>, PyB_Repeat>(m, "B_Repeat", py::base<A_Repeat>())
256 .def(py::init<>())
257 .def("lucky_number", &B_Repeat::lucky_number);
258 py::class_<C_Repeat, std::unique_ptr<C_Repeat>, PyC_Repeat>(m, "C_Repeat", py::base<B_Repeat>())
259 .def(py::init<>());
260 py::class_<D_Repeat, std::unique_ptr<D_Repeat>, PyD_Repeat>(m, "D_Repeat", py::base<C_Repeat>())
261 .def(py::init<>());
262
263 // Method 2: Templated trampolines
264 py::class_<A_Tpl, std::unique_ptr<A_Tpl>, PyA_Tpl<>>(m, "A_Tpl")
265 .def(py::init<>())
266 .def("unlucky_number", &A_Tpl::unlucky_number)
267 .def("say_something", &A_Tpl::say_something);
268 py::class_<B_Tpl, std::unique_ptr<B_Tpl>, PyB_Tpl<>>(m, "B_Tpl", py::base<A_Tpl>())
269 .def(py::init<>())
270 .def("lucky_number", &B_Tpl::lucky_number);
271 py::class_<C_Tpl, std::unique_ptr<C_Tpl>, PyB_Tpl<C_Tpl>>(m, "C_Tpl", py::base<B_Tpl>())
272 .def(py::init<>());
273 py::class_<D_Tpl, std::unique_ptr<D_Tpl>, PyB_Tpl<D_Tpl>>(m, "D_Tpl", py::base<C_Tpl>())
274 .def(py::init<>());
275
Jason Rhinelander0ca96e22016-08-05 17:02:33 -0400276};
277
278
Jason Rhinelanderb3f3d792016-07-18 16:43:18 -0400279void init_ex_virtual_functions(py::module &m) {
280 /* Important: indicate the trampoline class PyExampleVirt using the third
281 argument to py::class_. The second argument with the unique pointer
282 is simply the default holder type used by pybind11. */
283 py::class_<ExampleVirt, std::unique_ptr<ExampleVirt>, PyExampleVirt>(m, "ExampleVirt")
284 .def(py::init<int>())
285 /* Reference original class in function definitions */
286 .def("run", &ExampleVirt::run)
287 .def("run_bool", &ExampleVirt::run_bool)
288 .def("pure_virtual", &ExampleVirt::pure_virtual);
289
Jason Rhinelandered148792016-07-21 21:31:05 -0400290 py::class_<NonCopyable>(m, "NonCopyable")
291 .def(py::init<int, int>())
292 ;
293 py::class_<Movable>(m, "Movable")
294 .def(py::init<int, int>())
295 ;
296 py::class_<NCVirt, std::unique_ptr<NCVirt>, NCVirtTrampoline>(m, "NCVirt")
297 .def(py::init<>())
298 .def("get_noncopyable", &NCVirt::get_noncopyable)
299 .def("get_movable", &NCVirt::get_movable)
300 .def("print_nc", &NCVirt::print_nc)
301 .def("print_movable", &NCVirt::print_movable)
302 ;
303
Jason Rhinelanderb3f3d792016-07-18 16:43:18 -0400304 m.def("runExampleVirt", &runExampleVirt);
305 m.def("runExampleVirtBool", &runExampleVirtBool);
306 m.def("runExampleVirtVirtual", &runExampleVirtVirtual);
Jason Rhinelander0ca96e22016-08-05 17:02:33 -0400307
308 initialize_inherited_virtuals(m);
Jason Rhinelanderb3f3d792016-07-18 16:43:18 -0400309}