Răsfoiți Sursa

raise error on mismatched eq/ne

blueloveTH 1 an în urmă
părinte
comite
5ca7abed5c

+ 8 - 1
src/interpreter/ceval.c

@@ -954,7 +954,6 @@ FrameResult VM__run_top_frame(VM* self) {
             }
             case OP_END_CLASS: {
                 // [cls or decorated]
-                // TODO: if __eq__ is defined, check __ne__ and provide a default implementation
                 py_Name name = byte.arg;
                 // set into f_globals
                 py_setdict(frame->module, name, TOP());
@@ -966,7 +965,15 @@ FrameResult VM__run_top_frame(VM* self) {
                         py_TypeInfo* base_ti = ti->base_ti;
                         if(base_ti->on_end_subclass) base_ti->on_end_subclass(ti);
                     }
+                    if(!py_isnil(&ti->magic[__eq__])) {
+                        if(py_isnil(&ti->magic[__ne__])) {
+                            TypeError("'%n' implements '__eq__' but not '__ne__'", ti->name);
+                            goto __ERROR;
+                        }
+                    }
                 }
+                // class with decorator is unsafe currently
+                // it skips the above check
                 POP();
                 self->__curr_class = NULL;
                 DISPATCH();

+ 5 - 6
src/public/py_dict.c

@@ -129,15 +129,15 @@ __RETRY:
     Dict__ctor(self, new_capacity, old_dict.entries.capacity);
     // move entries from old dict to new dict
     for(int i = 0; i < old_dict.entries.length; i++) {
-        DictEntry* entry = c11__at(DictEntry, &old_dict.entries, i);
-        if(py_isnil(&entry->key)) continue;
-        int idx = entry->hash % new_capacity;
+        DictEntry* old_entry = c11__at(DictEntry, &old_dict.entries, i);
+        if(py_isnil(&old_entry->key)) continue;
+        int idx = old_entry->hash % new_capacity;
         bool success = false;
         for(int i = 0; i < PK_DICT_MAX_COLLISION; i++) {
             int idx2 = self->indices[idx]._[i];
             if(idx2 == -1) {
                 // insert new entry (empty slot)
-                c11_vector__push(DictEntry, &self->entries, *entry);
+                c11_vector__push(DictEntry, &self->entries, *old_entry);
                 self->indices[idx]._[i] = self->entries.length - 1;
                 self->length++;
                 success = true;
@@ -210,8 +210,7 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
     }
     // no empty slot found
     if(self->capacity >= (uint32_t)self->entries.length * 10) {
-        // raise error if we reach the minimum load factor (10%)
-        return RuntimeError("dict has too much collision: %d/%d/%d",
+        return RuntimeError("dict: %d/%d/%d: minimum load factor reached",
                             self->entries.length,
                             self->entries.capacity,
                             self->capacity);

+ 2 - 0
tests/72_collections.py

@@ -209,6 +209,8 @@ assertEqual((n+1) not in d, True)
 class BadCmp:
     def __eq__(self, other):
         raise RuntimeError
+    def __ne__(self, other):
+        raise RuntimeError
 
 
 # # Test detection of comparison exceptions

+ 1 - 0
tests/77_builtin_func.py

@@ -382,6 +382,7 @@ a = hash(object())  # object is hashable
 a = hash(A())       # A is hashable
 class B:
     def __eq__(self, o): return True
+    def __ne__(self, o): return False
 
 try:
     hash(B())

+ 10 - 0
tests/99_extras.py

@@ -103,3 +103,13 @@ class Context:
 for _ in range(5):
     with Context() as x:
         assert x == 1
+
+# bad dict hash
+class A:
+    def __eq__(self, o): return False
+    def __ne__(self, o): return True
+    def __hash__(self): return 1
+
+bad_dict = {A(): 1, A(): 2, A(): 3, A(): 4}
+assert len(bad_dict) == 4
+bad_dict[A()] = 5   # error