blob: b96790c3902e50c3087976fcd2168b8342f6718c [file] [log] [blame]
Ivan Smirnov91b3d682016-08-29 02:41:05 +01001import pytest
2
3with pytest.suppress(ImportError):
4 import numpy as np
5
6
@pytest.fixture(scope='function')
def arr():
    """Build a fresh 2x3 little-endian uint16 array for every test function."""
    rows = [[1, 2, 3], [4, 5, 6]]
    return np.array(rows, dtype='<u2')
10
11
@pytest.requires_numpy
def test_array_attributes():
    """Check the attribute accessors on a 0-d array and on a read-only 2-d view."""
    from pybind11_tests.array import (
        ndim, shape, strides, writeable, size, itemsize, nbytes, owndata
    )

    # Zero-dimensional float64 array: there are no axes, so axis lookups must fail.
    scalar = np.array(0, 'f8')
    assert ndim(scalar) == 0
    assert all(shape(scalar) == [])
    assert all(strides(scalar) == [])
    for per_axis in (shape, strides):
        with pytest.raises(IndexError) as excinfo:
            per_axis(scalar, 0)
        assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
    assert writeable(scalar)
    assert size(scalar) == 1
    assert itemsize(scalar) == 8
    assert nbytes(scalar) == 8
    assert owndata(scalar)

    # Read-only view over a 2x3 uint16 array: the view does not own its data.
    view = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
    view.flags.writeable = False
    assert ndim(view) == 2
    assert all(shape(view) == [2, 3])
    for axis, extent in enumerate((2, 3)):
        assert shape(view, axis) == extent
    assert all(strides(view) == [6, 2])
    for axis, step in enumerate((6, 2)):
        assert strides(view, axis) == step
    for per_axis in (shape, strides):
        with pytest.raises(IndexError) as excinfo:
            per_axis(view, 2)
        assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
    assert not writeable(view)
    assert size(view) == 6
    assert itemsize(view) == 2
    assert nbytes(view) == 12
    assert not owndata(view)
54
55
@pytest.requires_numpy
@pytest.mark.parametrize('args, ret', [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)])
def test_index_offset(arr, args, ret):
    """The index helpers return ``ret``; the offset helpers return ``ret`` scaled by item size."""
    from pybind11_tests.array import index_at, index_at_t, offset_at, offset_at_t

    for index_fn in (index_at, index_at_t):
        assert index_fn(arr, *args) == ret
    for offset_fn in (offset_at, offset_at_t):
        assert offset_fn(arr, *args) == ret * arr.dtype.itemsize
64
65
@pytest.requires_numpy
def test_dim_check_fail(arr):
    """Passing three indices to the 2-d fixture array must raise ``IndexError`` everywhere."""
    from pybind11_tests.array import (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
                                      mutate_data, mutate_data_t)

    indexed_accessors = (index_at, index_at_t, offset_at, offset_at_t,
                         data, data_t, mutate_data, mutate_data_t)
    for accessor in indexed_accessors:
        with pytest.raises(IndexError) as excinfo:
            accessor(arr, 1, 2, 3)
        assert str(excinfo.value) == 'too many indices for an array: 3 (ndim = 2)'
75
76
@pytest.requires_numpy
@pytest.mark.parametrize('args, ret',
                         [([], [1, 2, 3, 4, 5, 6]),
                          ([1], [4, 5, 6]),
                          ([0, 1], [2, 3, 4, 5, 6]),
                          ([1, 2], [6])])
def test_data(arr, args, ret):
    """``data_t`` yields the remaining elements; ``data`` interleaves them with zeros."""
    from pybind11_tests.array import data, data_t

    typed_tail = data_t(arr, *args)
    assert all(typed_tail == ret)

    # NOTE(review): presumably ``data`` exposes raw bytes of the little-endian
    # uint16 fixture, which would explain value/zero alternation - confirm.
    even_slots = data(arr, *args)[::2]
    assert all(even_slots == ret)
    odd_slots = data(arr, *args)[1::2]
    assert all(odd_slots == 0)
88
89
@pytest.requires_numpy
def test_mutate_readonly(arr):
    """Every mutating helper must refuse a read-only array with ``RuntimeError``."""
    from pybind11_tests.array import mutate_data, mutate_data_t, mutate_at_t

    arr.flags.writeable = False
    mutators = (
        (mutate_data, ()),
        (mutate_data_t, ()),
        (mutate_at_t, (0, 0)),
    )
    for mutate, indices in mutators:
        with pytest.raises(RuntimeError) as excinfo:
            mutate(arr, *indices)
        assert str(excinfo.value) == 'array is not writeable'
98
99
@pytest.requires_numpy
@pytest.mark.parametrize('dim', [0, 1, 3])
def test_at_fail(arr, dim):
    """``at_t`` / ``mutate_at_t`` reject an index count that differs from the array's ndim (2)."""
    from pybind11_tests.array import at_t, mutate_at_t

    zero_indices = (0,) * dim
    for accessor in (at_t, mutate_at_t):
        with pytest.raises(IndexError) as excinfo:
            accessor(arr, *zero_indices)
        assert str(excinfo.value) == 'index dimension mismatch: {} (ndim = 2)'.format(dim)
108
109
@pytest.requires_numpy
def test_at(arr):
    """Read single elements with ``at_t`` and bump single elements with ``mutate_at_t``."""
    from pybind11_tests.array import at_t, mutate_at_t

    # Plain element reads.
    for (row, col), value in (((0, 2), 3), ((1, 0), 4)):
        assert at_t(arr, row, col) == value

    # Each call changes one element in place, so the expectations accumulate
    # and the order of the steps matters.
    mutation_steps = (
        ((0, 2), [1, 2, 4, 4, 5, 6]),
        ((1, 0), [1, 2, 4, 5, 5, 6]),
    )
    for (row, col), flattened in mutation_steps:
        assert all(mutate_at_t(arr, row, col).ravel() == flattened)
119
120
@pytest.requires_numpy
def test_mutate_data(arr):
    """Apply ``mutate_data`` / ``mutate_data_t`` repeatedly and check the running contents.

    The fixture array is modified in place, so each expected list builds on
    the result of the previous call; the steps must run in this order.
    """
    from pybind11_tests.array import mutate_data, mutate_data_t

    # Per the expectations, the touched elements are doubled on each call.
    doubling_steps = (
        ((), [2, 4, 6, 8, 10, 12]),
        ((), [4, 8, 12, 16, 20, 24]),
        ((1,), [4, 8, 12, 32, 40, 48]),
        ((0, 1), [4, 16, 24, 64, 80, 96]),
        ((1, 2), [4, 16, 24, 64, 80, 192]),
    )
    for indices, flattened in doubling_steps:
        assert all(mutate_data(arr, *indices).ravel() == flattened)

    # Per the expectations, the touched elements are incremented on each call.
    increment_steps = (
        ((), [5, 17, 25, 65, 81, 193]),
        ((), [6, 18, 26, 66, 82, 194]),
        ((1,), [6, 18, 26, 67, 83, 195]),
        ((0, 1), [6, 19, 27, 68, 84, 196]),
        ((1, 2), [6, 19, 27, 68, 84, 197]),
    )
    for indices, flattened in increment_steps:
        assert all(mutate_data_t(arr, *indices).ravel() == flattened)
136
137
@pytest.requires_numpy
def test_bounds_check(arr):
    """Out-of-range indices on either axis raise ``IndexError`` with a NumPy-style message."""
    from pybind11_tests.array import (index_at, index_at_t, data, data_t,
                                      mutate_data, mutate_data_t, at_t, mutate_at_t)

    funcs = (index_at, index_at_t, data, data_t,
             mutate_data, mutate_data_t, at_t, mutate_at_t)
    # The fixture array is 2x3: row 2 and column 4 are both past the end.
    out_of_range = (
        ((2, 0), 'index 2 is out of bounds for axis 0 with size 2'),
        ((0, 4), 'index 4 is out of bounds for axis 1 with size 3'),
    )
    for func in funcs:
        for indices, message in out_of_range:
            with pytest.raises(IndexError) as excinfo:
                func(arr, *indices)
            assert str(excinfo.value) == message
Wenzel Jakob43f6aa62016-10-12 23:34:06 +0200151
Wenzel Jakob369e9b32016-10-13 00:57:42 +0200152
@pytest.requires_numpy
def test_make_c_f_array():
    """``make_c_array`` is C- but not F-contiguous; ``make_f_array`` is the reverse."""
    from pybind11_tests.array import (
        make_c_array, make_f_array
    )

    layouts = (
        (make_c_array, 'c_contiguous', 'f_contiguous'),
        (make_f_array, 'f_contiguous', 'c_contiguous'),
    )
    for factory, expected_flag, excluded_flag in layouts:
        # A fresh array is created for each flag check.
        assert getattr(factory().flags, expected_flag)
        assert not getattr(factory().flags, excluded_flag)
Wenzel Jakob369e9b32016-10-13 00:57:42 +0200162
163
@pytest.requires_numpy
def test_wrap():
    """Check that ``wrap`` returns a new array object referencing the input's memory."""
    from pybind11_tests.array import wrap

    def assert_references(a, b):
        """Assert that ``b`` is a distinct, non-owning array aliasing the data of ``a``."""
        assert a is not b
        assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0]
        assert a.shape == b.shape
        assert a.strides == b.strides
        assert a.flags.c_contiguous == b.flags.c_contiguous
        assert a.flags.f_contiguous == b.flags.f_contiguous
        assert a.flags.writeable == b.flags.writeable
        assert a.flags.aligned == b.flags.aligned
        # ``updateifcopy`` was superseded by ``writebackifcopy`` and is no longer
        # available in recent NumPy releases; compare whichever flag this NumPy
        # actually exposes so the check works on both old and new versions.
        if hasattr(a.flags, 'writebackifcopy'):
            copy_flag = 'writebackifcopy'
        else:
            copy_flag = 'updateifcopy'
        assert getattr(a.flags, copy_flag) == getattr(b.flags, copy_flag)
        assert np.all(a == b)
        assert not b.flags.owndata
        assert b.base is a
        # Writes through the original must be visible through the wrapper.
        if a.flags.writeable and a.ndim == 2:
            a[0, 0] = 1234
            assert b[0, 0] == 1234

    # 1-d array that owns its data.
    a1 = np.array([1, 2], dtype=np.int16)
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    # 2-d Fortran-ordered array that owns its data.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F')
    assert a1.flags.owndata and a1.base is None
    a2 = wrap(a1)
    assert_references(a1, a2)

    # 2-d C-ordered array marked read-only.
    a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C')
    a1.flags.writeable = False
    a2 = wrap(a1)
    assert_references(a1, a2)

    # 3-d array of random values.
    a1 = np.random.random((4, 4, 4))
    a2 = wrap(a1)
    assert_references(a1, a2)

    # Transposed view of the 3-d array.
    a1 = a1.transpose()
    a2 = wrap(a1)
    assert_references(a1, a2)

    # Diagonal of the transposed view.
    a1 = a1.diagonal()
    a2 = wrap(a1)
    assert_references(a1, a2)
Wenzel Jakobfac7c092016-10-13 10:37:52 +0200211
212
@pytest.requires_numpy
def test_numpy_view(capture):
    """Views from ``ArrayClass.numpy_view`` share data and outlive the Python handle.

    The captured output shows the destructor running only after both views
    have been deleted, not when the ``ArrayClass`` name itself is deleted.
    """
    from pybind11_tests.array import ArrayClass

    with capture:
        owner = ArrayClass()
        first_view = owner.numpy_view()
        second_view = owner.numpy_view()
        assert np.all(first_view == np.array([1, 2], dtype=np.int32))
        del owner
        pytest.gc_collect()
    # No destructor line is expected yet.
    assert capture == """
        ArrayClass()
        ArrayClass::numpy_view()
        ArrayClass::numpy_view()
    """

    # Writes through one view are visible through the other.
    first_view[0] = 4
    first_view[1] = 3
    assert second_view[0] == 4
    assert second_view[1] == 3

    with capture:
        del first_view
        del second_view
        pytest.gc_collect()
        pytest.gc_collect()
    assert capture == """
        ~ArrayClass()
    """
Wenzel Jakob496feac2016-10-28 00:37:07 +0200240
241
@pytest.unsupported_on_pypy
@pytest.requires_numpy
def test_cast_numpy_int64_to_uint64():
    """``function_taking_uint64`` accepts both a Python int and a ``np.uint64`` scalar."""
    from pybind11_tests.array import function_taking_uint64

    # The test passes as long as neither call raises.
    for value in (123, np.uint64(123)):
        function_taking_uint64(value)
Dean Moldovan4de27102016-11-16 01:35:22 +0100248
249
@pytest.requires_numpy
def test_isinstance():
    """Both isinstance helpers must return a truthy result for these inputs."""
    from pybind11_tests.array import isinstance_untyped, isinstance_typed

    int_array = np.array([1, 2, 3])
    untyped_result = isinstance_untyped(int_array, "not an array")
    assert untyped_result

    float_array = np.array([1.0, 2.0, 3.0])
    typed_result = isinstance_typed(float_array)
    assert typed_result
256
257
@pytest.requires_numpy
def test_constructors():
    """Check the size, contents and dtype of default- and converting-constructed arrays."""
    from pybind11_tests.array import default_constructors, converting_constructors

    # Default construction: every array is empty, with a kind-specific dtype.
    defaults = default_constructors()
    for empty in defaults.values():
        assert empty.size == 0
    default_dtypes = (
        ("array", np.array([]).dtype),
        ("array_t<int32>", np.int32),
        ("array_t<double>", np.float64),
    )
    for key, dtype in default_dtypes:
        assert defaults[key].dtype == dtype

    # Conversion from a list: contents are kept, dtype depends on the target.
    results = converting_constructors([1, 2, 3])
    for converted in results.values():
        np.testing.assert_array_equal(converted, [1, 2, 3])
    converted_dtypes = (
        ("array", np.int_),
        ("array_t<int32>", np.int32),
        ("array_t<double>", np.float64),
    )
    for key, dtype in converted_dtypes:
        assert results[key].dtype == dtype