blob: c3a7a9e020e6bd482523b626c8f9b8e50ef0eda7 [file] [log] [blame]
Wenzel Jakob38bd7112015-07-05 20:05:44 +02001/*
    tests/test_buffers.cpp -- support for Python's buffer protocol
Wenzel Jakob38bd7112015-07-05 20:05:44 +02003
Wenzel Jakob8cb6cb32016-04-17 20:21:41 +02004 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
Wenzel Jakob38bd7112015-07-05 20:05:44 +02005
6 All rights reserved. Use of this source code is governed by a
7 BSD-style license that can be found in the LICENSE file.
8*/
9
#include "pybind11_tests.h"
#include "constructor_stats.h"

// Standard headers come after pybind11_tests.h: it pulls in Python.h, which must be
// included before any standard header.
#include <cstddef>
#include <cstring>
Wenzel Jakob38bd7112015-07-05 20:05:44 +020012
13class Matrix {
14public:
15 Matrix(size_t rows, size_t cols) : m_rows(rows), m_cols(cols) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040016 print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Wenzel Jakob38bd7112015-07-05 20:05:44 +020017 m_data = new float[rows*cols];
18 memset(m_data, 0, sizeof(float) * rows * cols);
19 }
20
21 Matrix(const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040022 print_copy_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Wenzel Jakob38bd7112015-07-05 20:05:44 +020023 m_data = new float[m_rows * m_cols];
24 memcpy(m_data, s.m_data, sizeof(float) * m_rows * m_cols);
25 }
26
27 Matrix(Matrix &&s) : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040028 print_move_created(this);
Wenzel Jakob38bd7112015-07-05 20:05:44 +020029 s.m_rows = 0;
30 s.m_cols = 0;
31 s.m_data = nullptr;
32 }
33
34 ~Matrix() {
Jason Rhinelander3f589372016-08-07 13:05:26 -040035 print_destroyed(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Wenzel Jakob38bd7112015-07-05 20:05:44 +020036 delete[] m_data;
37 }
38
39 Matrix &operator=(const Matrix &s) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040040 print_copy_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Wenzel Jakob38bd7112015-07-05 20:05:44 +020041 delete[] m_data;
42 m_rows = s.m_rows;
43 m_cols = s.m_cols;
44 m_data = new float[m_rows * m_cols];
45 memcpy(m_data, s.m_data, sizeof(float) * m_rows * m_cols);
46 return *this;
47 }
48
49 Matrix &operator=(Matrix &&s) {
Jason Rhinelander3f589372016-08-07 13:05:26 -040050 print_move_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
Wenzel Jakob38bd7112015-07-05 20:05:44 +020051 if (&s != this) {
52 delete[] m_data;
53 m_rows = s.m_rows; m_cols = s.m_cols; m_data = s.m_data;
54 s.m_rows = 0; s.m_cols = 0; s.m_data = nullptr;
55 }
56 return *this;
57 }
58
59 float operator()(size_t i, size_t j) const {
60 return m_data[i*m_cols + j];
61 }
62
63 float &operator()(size_t i, size_t j) {
64 return m_data[i*m_cols + j];
65 }
66
67 float *data() { return m_data; }
68
69 size_t rows() const { return m_rows; }
70 size_t cols() const { return m_cols; }
71private:
72 size_t m_rows;
73 size_t m_cols;
74 float *m_data;
75};
76
Jason Rhinelander52f4be82016-09-03 14:54:22 -040077test_initializer buffers([](py::module &m) {
Wenzel Jakob38bd7112015-07-05 20:05:44 +020078 py::class_<Matrix> mtx(m, "Matrix");
79
80 mtx.def(py::init<size_t, size_t>())
81 /// Construct from a buffer
82 .def("__init__", [](Matrix &v, py::buffer b) {
83 py::buffer_info info = b.request();
Ivan Smirnov5e71e172016-06-26 12:42:34 +010084 if (info.format != py::format_descriptor<float>::format() || info.ndim != 2)
Wenzel Jakob38bd7112015-07-05 20:05:44 +020085 throw std::runtime_error("Incompatible buffer format!");
86 new (&v) Matrix(info.shape[0], info.shape[1]);
87 memcpy(v.data(), info.ptr, sizeof(float) * v.rows() * v.cols());
88 })
89
90 .def("rows", &Matrix::rows)
91 .def("cols", &Matrix::cols)
92
93 /// Bare bones interface
94 .def("__getitem__", [](const Matrix &m, std::pair<size_t, size_t> i) {
95 if (i.first >= m.rows() || i.second >= m.cols())
96 throw py::index_error();
97 return m(i.first, i.second);
98 })
99 .def("__setitem__", [](Matrix &m, std::pair<size_t, size_t> i, float v) {
100 if (i.first >= m.rows() || i.second >= m.cols())
101 throw py::index_error();
102 m(i.first, i.second) = v;
103 })
104 /// Provide buffer access
105 .def_buffer([](Matrix &m) -> py::buffer_info {
106 return py::buffer_info(
Ivan Smirnov5e71e172016-06-26 12:42:34 +0100107 m.data(), /* Pointer to buffer */
108 sizeof(float), /* Size of one scalar */
109 py::format_descriptor<float>::format(), /* Python struct-style format descriptor */
110 2, /* Number of dimensions */
111 { m.rows(), m.cols() }, /* Buffer dimensions */
112 { sizeof(float) * m.rows(), /* Strides (in bytes) for each index */
Wenzel Jakob38bd7112015-07-05 20:05:44 +0200113 sizeof(float) }
114 );
Jason Rhinelander3f589372016-08-07 13:05:26 -0400115 })
116 ;
Jason Rhinelander52f4be82016-09-03 14:54:22 -0400117});