blob: 9e92e5dd9aad4389f67416f297634c16bf1e4738 [file] [log] [blame]
Wenzel Jakob38bd7112015-07-05 20:05:44 +02001/*
Dean Moldovana0c1ccf2016-08-12 13:50:00 +02002 tests/test_buffers.cpp -- supporting Python's buffer protocol
Wenzel Jakob38bd7112015-07-05 20:05:44 +02003
Wenzel Jakob8cb6cb32016-04-17 20:21:41 +02004 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
Wenzel Jakob38bd7112015-07-05 20:05:44 +02005
6 All rights reserved. Use of this source code is governed by a
7 BSD-style license that can be found in the LICENSE file.
8*/
9
Dean Moldovana0c1ccf2016-08-12 13:50:00 +020010#include "pybind11_tests.h"
11#include "constructor_stats.h"
Wenzel Jakob38bd7112015-07-05 20:05:44 +020012
13class Matrix {
14public:
Cris Luengo30d43c42017-04-14 14:33:44 -060015 Matrix(ssize_t rows, ssize_t cols) : m_rows(rows), m_cols(cols) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040016 print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Cris Luengo30d43c42017-04-14 14:33:44 -060017 m_data = new float[(size_t) (rows*cols)];
18 memset(m_data, 0, sizeof(float) * (size_t) (rows * cols));
Wenzel Jakob38bd7112015-07-05 20:05:44 +020019 }
20
21 Matrix(const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040022 print_copy_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Cris Luengo30d43c42017-04-14 14:33:44 -060023 m_data = new float[(size_t) (m_rows * m_cols)];
24 memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols));
Wenzel Jakob38bd7112015-07-05 20:05:44 +020025 }
26
27 Matrix(Matrix &&s) : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040028 print_move_created(this);
Wenzel Jakob38bd7112015-07-05 20:05:44 +020029 s.m_rows = 0;
30 s.m_cols = 0;
31 s.m_data = nullptr;
32 }
33
34 ~Matrix() {
Jason Rhinelander3f589372016-08-07 13:05:26 -040035 print_destroyed(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Wenzel Jakob38bd7112015-07-05 20:05:44 +020036 delete[] m_data;
37 }
38
39 Matrix &operator=(const Matrix &s) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040040 print_copy_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Wenzel Jakob38bd7112015-07-05 20:05:44 +020041 delete[] m_data;
42 m_rows = s.m_rows;
43 m_cols = s.m_cols;
Cris Luengo30d43c42017-04-14 14:33:44 -060044 m_data = new float[(size_t) (m_rows * m_cols)];
45 memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols));
Wenzel Jakob38bd7112015-07-05 20:05:44 +020046 return *this;
47 }
48
49 Matrix &operator=(Matrix &&s) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040050 print_move_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Wenzel Jakob38bd7112015-07-05 20:05:44 +020051 if (&s != this) {
52 delete[] m_data;
53 m_rows = s.m_rows; m_cols = s.m_cols; m_data = s.m_data;
54 s.m_rows = 0; s.m_cols = 0; s.m_data = nullptr;
55 }
56 return *this;
57 }
58
Cris Luengo30d43c42017-04-14 14:33:44 -060059 float operator()(ssize_t i, ssize_t j) const {
60 return m_data[(size_t) (i*m_cols + j)];
Wenzel Jakob38bd7112015-07-05 20:05:44 +020061 }
62
Cris Luengo30d43c42017-04-14 14:33:44 -060063 float &operator()(ssize_t i, ssize_t j) {
64 return m_data[(size_t) (i*m_cols + j)];
Wenzel Jakob38bd7112015-07-05 20:05:44 +020065 }
66
67 float *data() { return m_data; }
68
Cris Luengo30d43c42017-04-14 14:33:44 -060069 ssize_t rows() const { return m_rows; }
70 ssize_t cols() const { return m_cols; }
Wenzel Jakob38bd7112015-07-05 20:05:44 +020071private:
Cris Luengo30d43c42017-04-14 14:33:44 -060072 ssize_t m_rows;
73 ssize_t m_cols;
Wenzel Jakob38bd7112015-07-05 20:05:44 +020074 float *m_data;
75};
76
Dean Moldovan427e4af2017-05-28 16:35:02 +020077class SquareMatrix : public Matrix {
78public:
79 SquareMatrix(ssize_t n) : Matrix(n, n) { }
80};
81
Bruce Merryfe0cf8b2017-05-17 10:52:33 +020082struct PTMFBuffer {
83 int32_t value = 0;
84
85 py::buffer_info get_buffer_info() {
86 return py::buffer_info(&value, sizeof(value),
87 py::format_descriptor<int32_t>::format(), 1);
88 }
89};
90
91class ConstPTMFBuffer {
92 std::unique_ptr<int32_t> value;
93
94public:
95 int32_t get_value() const { return *value; }
96 void set_value(int32_t v) { *value = v; }
97
98 py::buffer_info get_buffer_info() const {
99 return py::buffer_info(value.get(), sizeof(*value),
100 py::format_descriptor<int32_t>::format(), 1);
101 }
102
103 ConstPTMFBuffer() : value(new int32_t{0}) { };
104};
105
106struct DerivedPTMFBuffer : public PTMFBuffer { };
107
Jason Rhinelander52f4be82016-09-03 14:54:22 -0400108test_initializer buffers([](py::module &m) {
Wenzel Jakob1d1f81b2016-12-16 15:00:46 +0100109 py::class_<Matrix> mtx(m, "Matrix", py::buffer_protocol());
Wenzel Jakob38bd7112015-07-05 20:05:44 +0200110
Cris Luengo30d43c42017-04-14 14:33:44 -0600111 mtx.def(py::init<ssize_t, ssize_t>())
Wenzel Jakob38bd7112015-07-05 20:05:44 +0200112 /// Construct from a buffer
113 .def("__init__", [](Matrix &v, py::buffer b) {
114 py::buffer_info info = b.request();
Ivan Smirnov5e71e172016-06-26 12:42:34 +0100115 if (info.format != py::format_descriptor<float>::format() || info.ndim != 2)
Wenzel Jakob38bd7112015-07-05 20:05:44 +0200116 throw std::runtime_error("Incompatible buffer format!");
117 new (&v) Matrix(info.shape[0], info.shape[1]);
Cris Luengo30d43c42017-04-14 14:33:44 -0600118 memcpy(v.data(), info.ptr, sizeof(float) * (size_t) (v.rows() * v.cols()));
Wenzel Jakob38bd7112015-07-05 20:05:44 +0200119 })
120
121 .def("rows", &Matrix::rows)
122 .def("cols", &Matrix::cols)
123
124 /// Bare bones interface
Cris Luengo30d43c42017-04-14 14:33:44 -0600125 .def("__getitem__", [](const Matrix &m, std::pair<ssize_t, ssize_t> i) {
Wenzel Jakob38bd7112015-07-05 20:05:44 +0200126 if (i.first >= m.rows() || i.second >= m.cols())
127 throw py::index_error();
128 return m(i.first, i.second);
129 })
Cris Luengo30d43c42017-04-14 14:33:44 -0600130 .def("__setitem__", [](Matrix &m, std::pair<ssize_t, ssize_t> i, float v) {
Wenzel Jakob38bd7112015-07-05 20:05:44 +0200131 if (i.first >= m.rows() || i.second >= m.cols())
132 throw py::index_error();
133 m(i.first, i.second) = v;
134 })
135 /// Provide buffer access
136 .def_buffer([](Matrix &m) -> py::buffer_info {
137 return py::buffer_info(
Ivan Smirnov5e71e172016-06-26 12:42:34 +0100138 m.data(), /* Pointer to buffer */
Ivan Smirnov5e71e172016-06-26 12:42:34 +0100139 { m.rows(), m.cols() }, /* Buffer dimensions */
Cris Luengo30d43c42017-04-14 14:33:44 -0600140 { sizeof(float) * size_t(m.rows()), /* Strides (in bytes) for each index */
141 sizeof(float) }
Wenzel Jakob38bd7112015-07-05 20:05:44 +0200142 );
Jason Rhinelander3f589372016-08-07 13:05:26 -0400143 })
144 ;
Bruce Merryfe0cf8b2017-05-17 10:52:33 +0200145
Dean Moldovan427e4af2017-05-28 16:35:02 +0200146 // Derived classes inherit the buffer protocol and the buffer access function
147 py::class_<SquareMatrix, Matrix>(m, "SquareMatrix")
148 .def(py::init<ssize_t>());
149
Bruce Merryfe0cf8b2017-05-17 10:52:33 +0200150 py::class_<PTMFBuffer>(m, "PTMFBuffer", py::buffer_protocol())
151 .def(py::init<>())
152 .def_readwrite("value", &PTMFBuffer::value)
153 .def_buffer(&PTMFBuffer::get_buffer_info);
154
155 py::class_<ConstPTMFBuffer>(m, "ConstPTMFBuffer", py::buffer_protocol())
156 .def(py::init<>())
157 .def_property("value", &ConstPTMFBuffer::get_value, &ConstPTMFBuffer::set_value)
158 .def_buffer(&ConstPTMFBuffer::get_buffer_info);
159
160 // Tests that passing a pointer to member to the base class works in
161 // the derived class.
162 py::class_<DerivedPTMFBuffer>(m, "DerivedPTMFBuffer", py::buffer_protocol())
163 .def(py::init<>())
164 .def_readwrite("value", (int32_t DerivedPTMFBuffer::*) &DerivedPTMFBuffer::value)
165 .def_buffer(&DerivedPTMFBuffer::get_buffer_info);
Jason Rhinelander52f4be82016-09-03 14:54:22 -0400166});