Fix /= operator under Python 3
The Python method for /= was set as `__idiv__`, which should be
`__itruediv__` under Python 3.
This wasn't totally broken in that without it defined, Python constructs
a new object by calling __truediv__. The operator tests, however,
didn't actually test the /= operator: when I added it, I saw an extra
construction, leading to the problem. This commit also includes tests
for the previously untested *= operator, and adds some element-wise
vector multiplication and division operators.
diff --git a/include/pybind11/operators.h b/include/pybind11/operators.h
index 2e78c01..7d40fb5 100644
--- a/include/pybind11/operators.h
+++ b/include/pybind11/operators.h
@@ -25,7 +25,7 @@
op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le,
op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift,
op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero,
- op_repr, op_truediv
+ op_repr, op_truediv, op_itruediv
};
enum op_type : int {
@@ -129,7 +129,11 @@
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
+#if PY_MAJOR_VERSION >= 3
+PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
+#else
PYBIND11_INPLACE_OPERATOR(idiv, operator/=, l /= r)
+#endif
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
diff --git a/tests/test_operator_overloading.cpp b/tests/test_operator_overloading.cpp
index 93aea80..4e868d9 100644
--- a/tests/test_operator_overloading.cpp
+++ b/tests/test_operator_overloading.cpp
@@ -39,10 +39,14 @@
Vector2 operator+(float value) const { return Vector2(x + value, y + value); }
Vector2 operator*(float value) const { return Vector2(x * value, y * value); }
Vector2 operator/(float value) const { return Vector2(x / value, y / value); }
+ Vector2 operator*(const Vector2 &v) const { return Vector2(x * v.x, y * v.y); }
+ Vector2 operator/(const Vector2 &v) const { return Vector2(x / v.x, y / v.y); }
Vector2& operator+=(const Vector2 &v) { x += v.x; y += v.y; return *this; }
Vector2& operator-=(const Vector2 &v) { x -= v.x; y -= v.y; return *this; }
Vector2& operator*=(float v) { x *= v; y *= v; return *this; }
Vector2& operator/=(float v) { x /= v; y /= v; return *this; }
+ Vector2& operator*=(const Vector2 &v) { x *= v.x; y *= v.y; return *this; }
+ Vector2& operator/=(const Vector2 &v) { x /= v.x; y /= v.y; return *this; }
friend Vector2 operator+(float f, const Vector2 &v) { return Vector2(f + v.x, f + v.y); }
friend Vector2 operator-(float f, const Vector2 &v) { return Vector2(f - v.x, f - v.y); }
@@ -61,10 +65,14 @@
.def(py::self - float())
.def(py::self * float())
.def(py::self / float())
+ .def(py::self * py::self)
+ .def(py::self / py::self)
.def(py::self += py::self)
.def(py::self -= py::self)
.def(py::self *= float())
.def(py::self /= float())
+ .def(py::self *= py::self)
+ .def(py::self /= py::self)
.def(float() + py::self)
.def(float() - py::self)
.def(float() * py::self)
diff --git a/tests/test_operator_overloading.py b/tests/test_operator_overloading.py
index 02ccb96..dd37c34 100644
--- a/tests/test_operator_overloading.py
+++ b/tests/test_operator_overloading.py
@@ -16,10 +16,21 @@
assert str(8 + v1) == "[9.000000, 10.000000]"
assert str(8 * v1) == "[8.000000, 16.000000]"
assert str(8 / v1) == "[8.000000, 4.000000]"
+ assert str(v1 * v2) == "[3.000000, -2.000000]"
+ assert str(v2 / v1) == "[3.000000, -0.500000]"
- v1 += v2
+ v1 += 2 * v2
+ assert str(v1) == "[7.000000, 0.000000]"
+ v1 -= v2
+ assert str(v1) == "[4.000000, 1.000000]"
v1 *= 2
assert str(v1) == "[8.000000, 2.000000]"
+ v1 /= 16
+ assert str(v1) == "[0.500000, 0.125000]"
+ v1 *= v2
+ assert str(v1) == "[1.500000, -0.125000]"
+ v2 /= v1
+ assert str(v2) == "[2.000000, 8.000000]"
cstats = ConstructorStats.get(Vector2)
assert cstats.alive() == 2
@@ -32,7 +43,9 @@
'[-7.000000, -6.000000]', '[9.000000, 10.000000]',
'[8.000000, 16.000000]', '[0.125000, 0.250000]',
'[7.000000, 6.000000]', '[9.000000, 10.000000]',
- '[8.000000, 16.000000]', '[8.000000, 4.000000]']
+ '[8.000000, 16.000000]', '[8.000000, 4.000000]',
+ '[3.000000, -2.000000]', '[3.000000, -0.500000]',
+ '[6.000000, -2.000000]']
assert cstats.default_constructions == 0
assert cstats.copy_constructions == 0
assert cstats.move_constructions >= 10