operators should return NotImplemented given unsupported input (fixes #393)
diff --git a/docs/advanced.rst b/docs/advanced.rst
index 748f91e..ff85a4f 100644
--- a/docs/advanced.rst
+++ b/docs/advanced.rst
@@ -90,10 +90,13 @@
.def("__mul__", [](const Vector2 &a, float b) {
return a * b;
- })
+ }, py::is_operator())
This can be useful for exposing additional operators that don't exist on the
-C++ side, or to perform other types of customization.
+C++ side, or to perform other types of customization. The ``py::is_operator``
+flag marker is needed to inform pybind11 that this is an operator, which
+returns ``NotImplemented`` when invoked with incompatible arguments rather than
+throwing a type error.
.. note::
diff --git a/include/pybind11/attr.h b/include/pybind11/attr.h
index 9acb3e3..e3434b1 100644
--- a/include/pybind11/attr.h
+++ b/include/pybind11/attr.h
@@ -17,6 +17,9 @@
/// Annotation for methods
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
+/// Annotation for operators
+struct is_operator { };
+
/// Annotation for parent scope
struct scope { handle value; scope(const handle &s) : value(s) { } };
@@ -57,6 +60,10 @@
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
struct function_record {
+ function_record()
+ : is_constructor(false), is_stateless(false), is_operator(false),
+ has_args(false), has_kwargs(false) { }
+
/// Function name
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
@@ -87,6 +94,9 @@
/// True if this is a stateless function pointer
bool is_stateless : 1;
+ /// True if this is an operator (__add__), etc.
+ bool is_operator : 1;
+
/// True if the function has a '*args' argument
bool has_args : 1;
@@ -198,6 +208,10 @@
static void init(const scope &s, function_record *r) { r->scope = s.value; }
};
+/// Process an attribute which indicates that this function is an operator
+template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
+ static void init(const is_operator &, function_record *r) { r->is_operator = true; }
+};
/// Process a keyword argument attribute (*without* a default value)
template <> struct process_attribute<arg> : process_attribute_default<arg> {
diff --git a/include/pybind11/operators.h b/include/pybind11/operators.h
index eda51a1..22d1859 100644
--- a/include/pybind11/operators.h
+++ b/include/pybind11/operators.h
@@ -54,14 +54,14 @@
typedef typename std::conditional<std::is_same<L, self_t>::value, Base, L>::type L_type;
typedef typename std::conditional<std::is_same<R, self_t>::value, Base, R>::type R_type;
typedef op_impl<id, ot, Base, L_type, R_type> op;
- cl.def(op::name(), &op::execute, extra...);
+ cl.def(op::name(), &op::execute, is_operator(), extra...);
}
template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
typedef typename Class::type Base;
typedef typename std::conditional<std::is_same<L, self_t>::value, Base, L>::type L_type;
typedef typename std::conditional<std::is_same<R, self_t>::value, Base, R>::type R_type;
typedef op_impl<id, ot, Base, L_type, R_type> op;
- cl.def(op::name(), &op::execute_cast, extra...);
+ cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
}
};
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index 8f88d36..ea7acb4 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -71,6 +71,11 @@
object name() const { return attr("__name__"); }
protected:
+ /// Space optimization: don't inline this frequently instantiated fragment
+ PYBIND11_NOINLINE detail::function_record *make_function_record() {
+ return new detail::function_record();
+ }
+
/// Special internal constructor for functors, lambda functions, etc.
template <typename Func, typename Return, typename... Args, typename... Extra>
void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) {
@@ -80,7 +85,7 @@
struct capture { typename std::remove_reference<Func>::type f; };
/* Store the function including any extra state it might have (e.g. a lambda capture object) */
- auto rec = new detail::function_record();
+ auto rec = make_function_record();
/* Store the capture object directly in the function record if there is enough space */
if (sizeof(capture) <= sizeof(rec->data)) {
@@ -241,9 +246,6 @@
rec->signature = strdup(signature.c_str());
rec->args.shrink_to_fit();
rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__");
- rec->is_stateless = false;
- rec->has_args = false;
- rec->has_kwargs = false;
rec->nargs = (uint16_t) args;
#if PY_MAJOR_VERSION < 3
@@ -454,6 +456,9 @@
}
if (result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) {
+ if (overloads->is_operator)
+ return handle(Py_NotImplemented).inc_ref().ptr();
+
std::string msg = "Incompatible " + std::string(overloads->is_constructor ? "constructor" : "function") +
" arguments. The following argument types are supported:\n";
int ctr = 0;
diff --git a/tests/test_issues.cpp b/tests/test_issues.cpp
index c5314bc..843978e 100644
--- a/tests/test_issues.cpp
+++ b/tests/test_issues.cpp
@@ -20,6 +20,23 @@
struct NestB { NestA a; int value = 4; NestB& operator-=(int i) { value -= i; return *this; } TRACKERS(NestB) };
struct NestC { NestB b; int value = 5; NestC& operator*=(int i) { value *= i; return *this; } TRACKERS(NestC) };
+/// #393
+class OpTest1 {};
+class OpTest2 {};
+
+OpTest1 operator+(const OpTest1 &, const OpTest1 &) {
+ py::print("Add OpTest1 with OpTest1");
+ return OpTest1();
+}
+OpTest2 operator+(const OpTest2 &, const OpTest2 &) {
+ py::print("Add OpTest2 with OpTest2");
+ return OpTest2();
+}
+OpTest2 operator+(const OpTest2 &, const OpTest1 &) {
+ py::print("Add OpTest2 with OpTest1");
+ return OpTest2();
+}
+
void init_issues(py::module &m) {
py::module m2 = m.def_submodule("issues");
@@ -230,6 +247,16 @@
.def("A_value", &OverrideTest::A_value)
.def("A_ref", &OverrideTest::A_ref);
+ /// Issue 393: need to return NotSupported to ensure correct arithmetic operator behavior
+ py::class_<OpTest1>(m2, "OpTest1")
+ .def(py::init<>())
+ .def(py::self + py::self);
+
+ py::class_<OpTest2>(m2, "OpTest2")
+ .def(py::init<>())
+ .def(py::self + py::self)
+ .def("__add__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; })
+ .def("__radd__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; });
}
// MSVC workaround: trying to use a lambda here crashes MSCV
diff --git a/tests/test_issues.py b/tests/test_issues.py
index 2af6f1c..a28e509 100644
--- a/tests/test_issues.py
+++ b/tests/test_issues.py
@@ -181,3 +181,16 @@
assert a.value == "hi"
a.value = "bye"
assert a.value == "bye"
+
+def test_operators_notimplemented(capture):
+ from pybind11_tests.issues import OpTest1, OpTest2
+ with capture:
+ C1, C2 = OpTest1(), OpTest2()
+ C1 + C1
+ C2 + C2
+ C2 + C1
+ C1 + C2
+ assert capture == """Add OpTest1 with OpTest1
+Add OpTest2 with OpTest2
+Add OpTest2 with OpTest1
+Add OpTest2 with OpTest1"""