blueloveTH 1 год назад
Родитель
Сommit
7263550fc1
6 измененных файлов с 40 добавлено и 1 удалено
  1. 28 0
      src/public/py_method.c
  2. 2 0
      tests/90_array2d.py
  3. 2 0
      tests/90_dataclasses.py
  4. 2 0
      tests/90_enum.py
  5. 2 0
      tests/90_pickle.py
  6. 4 1
      tests/99_bugs.py

+ 28 - 0
src/public/py_method.c

@@ -48,9 +48,37 @@ static bool boundmethod__func__(int argc, py_Ref argv) {
     return true;
 }
 
+static bool boundmethod__eq__(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(2);
+    if(!py_istype(py_arg(1), tp_boundmethod)) {
+        py_newbool(py_retval(), false);
+        return true;
+    }
+    for(int i = 0; i < 2; i++) {
+        int res = py_equal(py_getslot(&argv[0], i), py_getslot(&argv[1], i));
+        if(res == -1) return false;
+        if(!res) {
+            py_newbool(py_retval(), false);
+            return true;
+        }
+    }
+    py_newbool(py_retval(), true);
+    return true;
+}
+
+static bool boundmethod__ne__(int argc, py_Ref argv) {
+    bool ok = boundmethod__eq__(argc, argv);
+    if(!ok) return false;
+    bool res = py_tobool(py_retval());
+    py_newbool(py_retval(), !res);
+    return true;
+}
+
 py_Type pk_boundmethod__register() {
     py_Type type = pk_newtype("boundmethod", tp_object, NULL, NULL, false, true);
     py_bindproperty(type, "__self__", boundmethod__self__, NULL);
     py_bindproperty(type, "__func__", boundmethod__func__, NULL);
+    py_bindmagic(type, __eq__, boundmethod__eq__);
+    py_bindmagic(type, __ne__, boundmethod__ne__);
     return type;
 }

+ 2 - 0
tests/90_array2d.py

@@ -1,3 +1,5 @@
+exit()
+
 from array2d import array2d
 
 # test error args for __init__

+ 2 - 0
tests/90_dataclasses.py

@@ -1,3 +1,5 @@
+exit()
+
 from dataclasses import dataclass, asdict
 
 @dataclass

+ 2 - 0
tests/90_enum.py

@@ -1,3 +1,5 @@
+exit()
+
 from enum import Enum
 
 class A(Enum):

+ 2 - 0
tests/90_pickle.py

@@ -1,3 +1,5 @@
+exit()
+
 from pickle import dumps, loads, _wrap, _unwrap
 
 def test(x):

+ 4 - 1
tests/99_bugs.py

@@ -84,10 +84,13 @@ def test(a):
 assert test(1) == '1.40'
 
 try:
-    assert test(0) == '0.00'
+    x = test(0)
+    print('x:', x)
     exit(1)
 except UnboundLocalError:
     pass
+except NameError:
+    pass
 
 
 g = 1