Преглед изворни кода

allow `__eq__` returns non-bool

blueloveTH пре 1 година
родитељ
комит
29a989f09a
5 измењених фајлова са 35 додато и 37 уклоњено
  1. 9 6
      include/pocketpy/pocketpy.h
  2. 1 1
      src/public/py_array.c
  3. 4 4
      src/public/py_dict.c
  4. 6 6
      src/public/py_list.c
  5. 15 20
      src/public/py_ops.c

+ 9 - 6
include/pocketpy/pocketpy.h

@@ -288,12 +288,15 @@ bool KeyError(py_Ref key);
 /// Returns -1 if an error occurred.
 int py_bool(const py_Ref val);
 
-int py_eq(const py_Ref, const py_Ref);
-int py_ne(const py_Ref, const py_Ref);
-int py_le(const py_Ref, const py_Ref);
-int py_lt(const py_Ref, const py_Ref);
-int py_ge(const py_Ref, const py_Ref);
-int py_gt(const py_Ref, const py_Ref);
+#define py_eq(lhs, rhs) py_binaryop(lhs, rhs, __eq__, __eq__)
+#define py_ne(lhs, rhs) py_binaryop(lhs, rhs, __ne__, __ne__)
+#define py_lt(lhs, rhs) py_binaryop(lhs, rhs, __lt__, __gt__)
+#define py_le(lhs, rhs) py_binaryop(lhs, rhs, __le__, __ge__)
+#define py_gt(lhs, rhs) py_binaryop(lhs, rhs, __gt__, __lt__)
+#define py_ge(lhs, rhs) py_binaryop(lhs, rhs, __ge__, __le__)
+
+int py_equal(const py_Ref lhs, const py_Ref rhs);
+int py_less(const py_Ref lhs, const py_Ref rhs);
 
 bool py_hash(const py_Ref, py_i64* out);
 

+ 1 - 1
src/public/py_array.c

@@ -26,7 +26,7 @@ py_TValue* pk_arrayview(py_Ref self, int* length) {
 int pk_arrayeq(py_TValue* lhs, int lhs_length, py_TValue* rhs, int rhs_length) {
     if(lhs_length != rhs_length) return false;
     for(int i = 0; i < lhs_length; i++) {
-        int res = py_eq(lhs + i, rhs + i);
+        int res = py_equal(lhs + i, rhs + i);
         if(res == -1) return -1;
         if(!res) return false;
     }

+ 4 - 4
src/public/py_dict.c

@@ -52,7 +52,7 @@ static bool Dict__try_get(Dict* self, py_TValue* key, DictEntry** out) {
         int idx2 = self->indices[idx]._[i];
         if(idx2 == -1) continue;
         DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
-        int res = py_eq(&entry->key, key);
+        int res = py_equal(&entry->key, key);
         if(res == 1) {
             *out = entry;
             return true;
@@ -150,7 +150,7 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
         }
         // update existing entry
         DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
-        int res = py_eq(&entry->key, key);
+        int res = py_equal(&entry->key, key);
         if(res == 1) {
             entry->val = *val;
             return true;
@@ -174,7 +174,7 @@ static bool Dict__pop(Dict* self, py_Ref key) {
         int idx2 = self->indices[idx]._[i];
         if(idx2 == -1) continue;
         DictEntry* entry = c11__at(DictEntry, &self->entries, idx2);
-        int res = py_eq(&entry->key, key);
+        int res = py_equal(&entry->key, key);
         if(res == 1) {
             *py_retval() = entry->val;
             py_newnil(&entry->key);
@@ -318,7 +318,7 @@ static bool _py_dict__eq__(int argc, py_Ref argv) {
             py_newbool(py_retval(), false);
             return true;
         }
-        int res = py_eq(&entry->val, &other_entry->val);
+        int res = py_equal(&entry->val, &other_entry->val);
         if(res == -1) return false;
         if(!res) {
             py_newbool(py_retval(), false);

+ 6 - 6
src/public/py_list.c

@@ -258,7 +258,7 @@ static bool _py_list__count(int argc, py_Ref argv) {
     PY_CHECK_ARGC(2);
     int count = 0;
     for(int i = 0; i < py_list__len(py_arg(0)); i++) {
-        int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1));
+        int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1));
         if(res == -1) return false;
         if(res) count++;
     }
@@ -290,7 +290,7 @@ static bool _py_list__index(int argc, py_Ref argv) {
         start = py_toint(py_arg(2));
     }
     for(int i = start; i < py_list__len(py_arg(0)); i++) {
-        int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1));
+        int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1));
         if(res == -1) return false;
         if(res) {
             py_newint(py_retval(), i);
@@ -311,7 +311,7 @@ static bool _py_list__reverse(int argc, py_Ref argv) {
 static bool _py_list__remove(int argc, py_Ref argv) {
     PY_CHECK_ARGC(2);
     for(int i = 0; i < py_list__len(py_arg(0)); i++) {
-        int res = py_eq(py_list__getitem(py_arg(0), i), py_arg(1));
+        int res = py_equal(py_list__getitem(py_arg(0), i), py_arg(1));
         if(res == -1) return false;
         if(res) {
             py_list__delitem(py_arg(0), i);
@@ -354,7 +354,7 @@ static bool _py_list__insert(int argc, py_Ref argv) {
 }
 
 static int _py_lt_with_key(py_TValue* a, py_TValue* b, py_TValue* key) {
-    if(!key) return py_lt(a, b);
+    if(!key) return py_less(a, b);
     pk_VM* vm = pk_current_vm;
     // project a
     py_push(key);
@@ -372,7 +372,7 @@ static int _py_lt_with_key(py_TValue* a, py_TValue* b, py_TValue* key) {
     bool ok = pk_stack_binaryop(vm, __lt__, __gt__);
     if(!ok) return -1;
     py_shrink(2);
-    return py_tobool(py_retval());
+    return py_bool(py_retval());
 }
 
 // sort(self, key=None, reverse=False)
@@ -436,7 +436,7 @@ py_Type pk_list__register() {
     return type;
 }
 
-void pk_list__mark(void* ud, void (*marker)(py_TValue*)){
+void pk_list__mark(void* ud, void (*marker)(py_TValue*)) {
     List* self = ud;
     for(int i = 0; i < self->count; i++) {
         marker(c11__at(py_TValue, self, i));

+ 15 - 20
src/public/py_ops.c

@@ -26,14 +26,14 @@ int py_bool(const py_Ref val) {
         default: {
             py_Ref tmp = py_tpfindmagic(val->type, __bool__);
             if(tmp) {
-                bool ok = py_call(tmp, 1, val);
-                if(!ok) return -1;
+                if(!py_call(tmp, 1, val)) return -1;
+                if(!py_checkbool(py_retval())) return -1;
                 return py_tobool(py_retval());
             } else {
                 tmp = py_tpfindmagic(val->type, __len__);
                 if(tmp) {
-                    bool ok = py_call(tmp, 1, val);
-                    if(!ok) return -1;
+                    if(!py_call(tmp, 1, val)) return -1;
+                    if(!py_checkint(py_retval())) return -1;
                     return py_toint(py_retval());
                 } else {
                     return 1;  // True
@@ -51,8 +51,8 @@ bool py_hash(const py_Ref val, int64_t* out) {
         if(py_isnone(_hash)) break;
         py_Ref _eq = &types[t].magic[__eq__];
         if(!py_isnil(_hash) && !py_isnil(_eq)) {
-            bool ok = py_call(_hash, 1, val);
-            if(!ok) return false;
+            if(!py_call(_hash, 1, val)) return false;
+            if(!py_checkint(py_retval())) return false;
             *out = py_toint(py_retval());
             return true;
         }
@@ -72,8 +72,7 @@ int py_next(const py_Ref val) {
     vm->is_stopiteration = false;
     py_Ref tmp = py_tpfindmagic(val->type, __next__);
     if(!tmp) return TypeError("'%t' object is not an iterator", val->type);
-    bool ok = py_call(tmp, 1, val);
-    if(ok) return true;
+    if(py_call(tmp, 1, val)) return true;
     return vm->is_stopiteration ? 0 : -1;
 }
 
@@ -201,16 +200,12 @@ bool py_delitem(py_Ref self, const py_Ref key) {
     return ok;
 }
 
-#define COMPARE_OP_IMPL(name, op, rop)                                                             \
-    int py_##name(const py_Ref lhs, const py_Ref rhs) {                                            \
-        bool ok = py_binaryop(lhs, rhs, op, rop);                                                  \
-        if(!ok) return -1;                                                                         \
-        return py_tobool(py_retval());                                                             \
-    }
+int py_equal(const py_Ref lhs, const py_Ref rhs){
+    if(!py_eq(lhs, rhs)) return -1;
+    return py_bool(py_retval());
+}
 
-COMPARE_OP_IMPL(eq, __eq__, __eq__)
-COMPARE_OP_IMPL(ne, __ne__, __ne__)
-COMPARE_OP_IMPL(lt, __lt__, __gt__)
-COMPARE_OP_IMPL(le, __le__, __ge__)
-COMPARE_OP_IMPL(gt, __gt__, __lt__)
-COMPARE_OP_IMPL(ge, __ge__, __le__)
+int py_less(const py_Ref lhs, const py_Ref rhs){
+    if(!py_lt(lhs, rhs)) return -1;
+    return py_bool(py_retval());
+}