operators should return NotImplemented given unsupported input (fixes #393)
diff --git a/tests/test_issues.cpp b/tests/test_issues.cpp
index c5314bc..843978e 100644
--- a/tests/test_issues.cpp
+++ b/tests/test_issues.cpp
@@ -20,6 +20,23 @@
struct NestB { NestA a; int value = 4; NestB& operator-=(int i) { value -= i; return *this; } TRACKERS(NestB) };
struct NestC { NestB b; int value = 5; NestC& operator*=(int i) { value *= i; return *this; } TRACKERS(NestC) };
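+/// Issue #393: empty operand types and operator+ overloads used to check that unsupported operand types yield NotImplemented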
+class OpTest1 {};  // empty operand type; the only operator+ with OpTest1 on the left is OpTest1 + OpTest1
+class OpTest2 {};  // empty operand type; OpTest2 + OpTest2 and OpTest2 + OpTest1 are both defined below
+
+OpTest1 operator+(const OpTest1 &, const OpTest1 &) {  // exposed as OpTest1.__add__ via py::self + py::self
+ py::print("Add OpTest1 with OpTest1");
+ return OpTest1();
+}
+OpTest2 operator+(const OpTest2 &, const OpTest2 &) {  // exposed as OpTest2.__add__ via py::self + py::self
+ py::print("Add OpTest2 with OpTest2");
+ return OpTest2();
+}
+OpTest2 operator+(const OpTest2 &, const OpTest1 &) {  // mixed operands; called from OpTest2's "__add__" and "__radd__" lambdas
+ py::print("Add OpTest2 with OpTest1");
+ return OpTest2();
+}
+
void init_issues(py::module &m) {
py::module m2 = m.def_submodule("issues");
@@ -230,6 +247,16 @@
.def("A_value", &OverrideTest::A_value)
.def("A_ref", &OverrideTest::A_ref);
+ /// Issue 393: operators must return NotImplemented for unsupported operands so Python can try the reflected operator (e.g. __radd__)
+ py::class_<OpTest1>(m2, "OpTest1")
+ .def(py::init<>())
+ .def(py::self + py::self);  // OpTest1 + OpTest1 only; any other right operand should make __add__ return NotImplemented
+
+ py::class_<OpTest2>(m2, "OpTest2")
+ .def(py::init<>())
+ .def(py::self + py::self)  // OpTest2 + OpTest2
+ .def("__add__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; })  // extra __add__ overload: OpTest2 + OpTest1
+ .def("__radd__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; });  // reflected: OpTest1 + OpTest2, self (c2) is the right operand
}
// MSVC workaround: trying to use a lambda here crashes MSCV
diff --git a/tests/test_issues.py b/tests/test_issues.py
index 2af6f1c..a28e509 100644
--- a/tests/test_issues.py
+++ b/tests/test_issues.py
@@ -181,3 +181,16 @@
assert a.value == "hi"
a.value = "bye"
assert a.value == "bye"
+
+def test_operators_notimplemented(capture):  # issue #393: unsupported operands must fall through to the reflected operator
+ from pybind11_tests.issues import OpTest1, OpTest2
+ with capture:
+ C1, C2 = OpTest1(), OpTest2()
+ C1 + C1  # OpTest1.__add__(OpTest1)
+ C2 + C2  # OpTest2.__add__(OpTest2)
+ C2 + C1  # OpTest2.__add__ overload taking OpTest1
+ C1 + C2  # OpTest1.__add__ returns NotImplemented, so Python calls OpTest2.__radd__(C2, C1)
+ assert capture == """Add OpTest1 with OpTest1
+Add OpTest2 with OpTest2
+Add OpTest2 with OpTest1
+Add OpTest2 with OpTest1"""