1
0

operators.cpp 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  #include "test.h"
  #include "pybind11/operators.h"

  #include <cstddef>
  #include <sstream>
  #include <type_traits>
  #include <vector>
  namespace {

  /// Minimal integer wrapper used to exercise pybind11 operator bindings.
  ///
  /// The binary arithmetic/bitwise operators accept any mix of `Int` and plain
  /// integer operands and always return a plain `int`; the comparison operators
  /// compare two `Int`s by their `x` value.
  struct Int {
      int x;

      // `explicit`: an `int` must not silently turn into an `Int`. The Python
      // binding constructs it directly through `py::init<int>()`.
      explicit Int(int x) : x(x) {}

      /// Returns the underlying integer of an operator operand: `v.x` for an
      /// `Int`, otherwise `v` implicitly converted to `int`.
      template <typename T>
      static int value_of(const T& v) {
          if constexpr (std::is_same_v<T, Int>) {
              return v.x;
          } else {
              return v;
          }
      }

  // Defines a hidden-friend binary operator `op` that unwraps both operands with
  // `value_of` and applies the built-in `int` operator.
  // The template is constrained to calls where at least one operand is an `Int`:
  // without the constraint, ADL through template arguments (e.g.
  // `std::vector<Int>::iterator + 1`) selects this exact-match template over the
  // iterator's own operator, and the body then fails to compile.
  #define OPERATOR_IMPL(op) \
      template <typename LHS, \
                typename RHS, \
                typename = std::enable_if_t<(std::is_same_v<LHS, Int> || \
                                             std::is_same_v<RHS, Int>)>> \
      friend int operator op(const LHS& lhs, const RHS& rhs) { \
          return Int::value_of(lhs) op Int::value_of(rhs); \
      }
      OPERATOR_IMPL(+)
      OPERATOR_IMPL(-)
      OPERATOR_IMPL(*)
      OPERATOR_IMPL(/)
      OPERATOR_IMPL(%)
      OPERATOR_IMPL(&)
      OPERATOR_IMPL(|)
      OPERATOR_IMPL(^)
      OPERATOR_IMPL(<<)
      OPERATOR_IMPL(>>)
  #undef OPERATOR_IMPL

      // Comparisons between two Int values, delegating to the wrapped ints.
      bool operator==(const Int& other) const { return x == other.x; }
      bool operator!=(const Int& other) const { return x != other.x; }
      bool operator<(const Int& other) const { return x < other.x; }
      bool operator<=(const Int& other) const { return x <= other.x; }
      bool operator>(const Int& other) const { return x > other.x; }
      bool operator>=(const Int& other) const { return x >= other.x; }

      // Logical negation: true exactly when the wrapped value is zero.
      bool operator!() const { return !x; }
  };

  }  // namespace
  41. TEST_F(PYBIND11_TEST, arithmetic_operators) {
  42. py::module m = py::module::import("__main__");
  43. py::class_<Int>(m, "Int")
  44. .def(py::init<int>())
  45. .def(py::self + py::self)
  46. .def(py::self + int())
  47. .def(int() + py::self)
  48. .def(py::self - py::self)
  49. .def(py::self - int())
  50. .def(int() - py::self)
  51. .def(py::self * py::self)
  52. .def(py::self * int())
  53. .def(int() * py::self)
  54. .def(py::self / py::self)
  55. .def(py::self / int())
  56. .def(int() / py::self)
  57. .def(py::self % py::self)
  58. .def(py::self % int())
  59. .def(int() % py::self)
  60. .def(py::self & py::self)
  61. .def(py::self & int())
  62. .def(int() & py::self)
  63. .def(py::self | py::self)
  64. .def(py::self | int())
  65. .def(int() | py::self)
  66. .def(py::self ^ py::self)
  67. .def(py::self ^ int())
  68. .def(int() ^ py::self)
  69. .def(py::self << py::self)
  70. .def(py::self << int())
  71. .def(int() << py::self)
  72. .def(py::self >> py::self)
  73. .def(py::self >> int())
  74. .def(int() >> py::self);
  75. auto a = py::cast(Int(1));
  76. auto ai = py::cast(1);
  77. auto b = py::cast(Int(2));
  78. auto bi = py::cast(2);
  79. EXPECT_CAST_EQ(a + b, 3);
  80. EXPECT_CAST_EQ(a + bi, 3);
  81. EXPECT_CAST_EQ(ai + b, 3);
  82. EXPECT_CAST_EQ(a - b, -1);
  83. EXPECT_CAST_EQ(a - bi, -1);
  84. EXPECT_CAST_EQ(ai - b, -1);
  85. EXPECT_CAST_EQ(a * b, 2);
  86. EXPECT_CAST_EQ(a * bi, 2);
  87. EXPECT_CAST_EQ(ai * b, 2);
  88. EXPECT_CAST_EQ(a / b, 0);
  89. EXPECT_CAST_EQ(a / bi, 0);
  90. // EXPECT_CAST_EQ(ai / b, 0);
  91. EXPECT_CAST_EQ(a % b, 1);
  92. EXPECT_CAST_EQ(a % bi, 1);
  93. // EXPECT_CAST_EQ(ai % b, 1);
  94. EXPECT_CAST_EQ(a & b, 0);
  95. EXPECT_CAST_EQ(a & bi, 0);
  96. // EXPECT_CAST_EQ(ai & b, 0);
  97. EXPECT_CAST_EQ(a | b, 3);
  98. EXPECT_CAST_EQ(a | bi, 3);
  99. // EXPECT_CAST_EQ(ai | b, 3);
  100. EXPECT_CAST_EQ(a ^ b, 3);
  101. EXPECT_CAST_EQ(a ^ bi, 3);
  102. // EXPECT_CAST_EQ(ai ^ b, 3);
  103. EXPECT_CAST_EQ(a << b, 4);
  104. EXPECT_CAST_EQ(a << bi, 4);
  105. // EXPECT_CAST_EQ(ai << b, 4);
  106. EXPECT_CAST_EQ(a >> b, 0);
  107. EXPECT_CAST_EQ(a >> bi, 0);
  108. // EXPECT_CAST_EQ(ai >> b, 0);
  109. }
  110. TEST_F(PYBIND11_TEST, logic_operators) {
  111. py::module m = py::module::import("__main__");
  112. py::class_<Int>(m, "Int")
  113. .def(py::init<int>())
  114. .def_readwrite("x", &Int::x)
  115. .def(py::self == py::self)
  116. .def(py::self != py::self)
  117. .def(py::self < py::self)
  118. .def(py::self <= py::self)
  119. .def(py::self > py::self)
  120. .def(py::self >= py::self);
  121. auto a = py::cast(Int(1));
  122. auto b = py::cast(Int(2));
  123. EXPECT_FALSE(a == b);
  124. EXPECT_TRUE(a != b);
  125. EXPECT_TRUE(a < b);
  126. EXPECT_TRUE(a <= b);
  127. EXPECT_FALSE(a > b);
  128. EXPECT_FALSE(a >= b);
  129. }
  130. TEST_F(PYBIND11_TEST, item_operators) {
  131. py::module m = py::module::import("__main__");
  132. py::class_<std::vector<int>>(m, "vector")
  133. .def(py::init<>())
  134. .def("__getitem__",
  135. [](std::vector<int>& v, int i) {
  136. return v[i];
  137. })
  138. .def("__setitem__",
  139. [](std::vector<int>& v, int i, int x) {
  140. v[i] = x;
  141. })
  142. .def("push_back",
  143. [](std::vector<int>& v, int x) {
  144. v.push_back(x);
  145. })
  146. .def("__str__", [](const std::vector<int>& v) {
  147. std::ostringstream os;
  148. os << "[";
  149. for(size_t i = 0; i < v.size(); i++) {
  150. if(i > 0) os << ", ";
  151. os << v[i];
  152. }
  153. os << "]";
  154. return os.str();
  155. });
  156. py::exec(R"(
  157. v = vector()
  158. v.push_back(1)
  159. v.push_back(2)
  160. v.push_back(3)
  161. print(v)
  162. assert v[0] == 1
  163. assert v[1] == 2
  164. assert v[2] == 3
  165. v[1] = 4
  166. assert v[1] == 4
  167. )");
  168. }