| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- #include "test.h"
- #include "pybind11/operators.h"
namespace {

/// Minimal integer wrapper used to exercise operator bindings.
///
/// The arithmetic, bitwise and shift operators return a plain `int`. They are
/// unconstrained templates, so each side may be an `Int` or any type that
/// converts implicitly to `int`. The comparison operators are members and
/// compare two `Int` values by their wrapped `x`.
struct Int {
    int x;

    Int(int x) : x(x) {}

    /// Returns `operand.x` when `T` is `Int`; otherwise returns the operand
    /// implicitly converted to `int`.
    template <typename T>
    static int value_of(const T& operand) {
        if constexpr (std::is_same_v<T, Int>) {
            return operand.x;
        } else {
            return operand;
        }
    }

// Generates one binary operator as a friend defined inside the class, so it is
// only reachable through argument-dependent lookup. Both operands are reduced
// to `int` first, then the built-in operator is applied.
#define OPERATOR_IMPL(op) \
    template <typename LHS, typename RHS> \
    friend int operator op (const LHS& lhs, const RHS& rhs) { \
        const int l = Int::value_of(lhs); \
        const int r = Int::value_of(rhs); \
        return l op r; \
    }

    OPERATOR_IMPL(+)
    OPERATOR_IMPL(-)
    OPERATOR_IMPL(*)
    OPERATOR_IMPL(/)
    OPERATOR_IMPL(%)
    OPERATOR_IMPL(&)
    OPERATOR_IMPL(|)
    OPERATOR_IMPL(^)
    OPERATOR_IMPL(<<)
    OPERATOR_IMPL(>>)

#undef OPERATOR_IMPL

    // Comparisons between two Int values, delegating to the wrapped int.
    bool operator== (const Int& other) const { return x == other.x; }
    bool operator!= (const Int& other) const { return x != other.x; }
    bool operator< (const Int& other) const { return x < other.x; }
    bool operator<= (const Int& other) const { return x <= other.x; }
    bool operator> (const Int& other) const { return x > other.x; }
    bool operator>= (const Int& other) const { return x >= other.x; }

    // True when the wrapped value is zero.
    bool operator!() const { return !x; }
};

}  // namespace
- TEST_F(PYBIND11_TEST, arithmetic_operators) {
- py::module m = py::module::import("__main__");
- py::class_<Int>(m, "Int")
- .def(py::init<int>())
- .def(py::self + py::self)
- .def(py::self + int())
- .def(int() + py::self)
- .def(py::self - py::self)
- .def(py::self - int())
- .def(int() - py::self)
- .def(py::self * py::self)
- .def(py::self * int())
- .def(int() * py::self)
- .def(py::self / py::self)
- .def(py::self / int())
- .def(int() / py::self)
- .def(py::self % py::self)
- .def(py::self % int())
- .def(int() % py::self)
- .def(py::self & py::self)
- .def(py::self & int())
- .def(int() & py::self)
- .def(py::self | py::self)
- .def(py::self | int())
- .def(int() | py::self)
- .def(py::self ^ py::self)
- .def(py::self ^ int())
- .def(int() ^ py::self)
- .def(py::self << py::self)
- .def(py::self << int())
- .def(int() << py::self)
- .def(py::self >> py::self)
- .def(py::self >> int())
- .def(int() >> py::self);
- auto a = py::cast(Int(1));
- auto ai = py::cast(1);
- auto b = py::cast(Int(2));
- auto bi = py::cast(2);
- EXPECT_CAST_EQ(a + b, 3);
- EXPECT_CAST_EQ(a + bi, 3);
- EXPECT_CAST_EQ(ai + b, 3);
- EXPECT_CAST_EQ(a - b, -1);
- EXPECT_CAST_EQ(a - bi, -1);
- EXPECT_CAST_EQ(ai - b, -1);
- EXPECT_CAST_EQ(a * b, 2);
- EXPECT_CAST_EQ(a * bi, 2);
- EXPECT_CAST_EQ(ai * b, 2);
- EXPECT_CAST_EQ(a / b, 0);
- EXPECT_CAST_EQ(a / bi, 0);
- // EXPECT_CAST_EQ(ai / b, 0);
- EXPECT_CAST_EQ(a % b, 1);
- EXPECT_CAST_EQ(a % bi, 1);
- // EXPECT_CAST_EQ(ai % b, 1);
- EXPECT_CAST_EQ(a & b, 0);
- EXPECT_CAST_EQ(a & bi, 0);
- // EXPECT_CAST_EQ(ai & b, 0);
- EXPECT_CAST_EQ(a | b, 3);
- EXPECT_CAST_EQ(a | bi, 3);
- // EXPECT_CAST_EQ(ai | b, 3);
- EXPECT_CAST_EQ(a ^ b, 3);
- EXPECT_CAST_EQ(a ^ bi, 3);
- // EXPECT_CAST_EQ(ai ^ b, 3);
- EXPECT_CAST_EQ(a << b, 4);
- EXPECT_CAST_EQ(a << bi, 4);
- // EXPECT_CAST_EQ(ai << b, 4);
- EXPECT_CAST_EQ(a >> b, 0);
- EXPECT_CAST_EQ(a >> bi, 0);
- // EXPECT_CAST_EQ(ai >> b, 0);
- }
- TEST_F(PYBIND11_TEST, logic_operators) {
- py::module m = py::module::import("__main__");
- py::class_<Int>(m, "Int")
- .def(py::init<int>())
- .def_readwrite("x", &Int::x)
- .def(py::self == py::self)
- .def(py::self != py::self)
- .def(py::self < py::self)
- .def(py::self <= py::self)
- .def(py::self > py::self)
- .def(py::self >= py::self);
- auto a = py::cast(Int(1));
- auto b = py::cast(Int(2));
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
- EXPECT_TRUE(a < b);
- EXPECT_TRUE(a <= b);
- EXPECT_FALSE(a > b);
- EXPECT_FALSE(a >= b);
- }
- TEST_F(PYBIND11_TEST, item_operators) {
- py::module m = py::module::import("__main__");
- py::class_<std::vector<int>>(m, "vector")
- .def(py::init<>())
- .def("__getitem__",
- [](std::vector<int>& v, int i) {
- return v[i];
- })
- .def("__setitem__",
- [](std::vector<int>& v, int i, int x) {
- v[i] = x;
- })
- .def("push_back",
- [](std::vector<int>& v, int x) {
- v.push_back(x);
- })
- .def("__str__", [](const std::vector<int>& v) {
- std::ostringstream os;
- os << "[";
- for(size_t i = 0; i < v.size(); i++) {
- if(i > 0) os << ", ";
- os << v[i];
- }
- os << "]";
- return os.str();
- });
- py::exec(R"(
- v = vector()
- v.push_back(1)
- v.push_back(2)
- v.push_back(3)
- print(v)
- assert v[0] == 1
- assert v[1] == 2
- assert v[2] == 3
- v[1] = 4
- assert v[1] == 4
- )");
- }
|