Explorar o código

fix a bug of dict

blueloveTH hai 1 ano
pai
achega
6eb785144e
Modificáronse 3 ficheiros con 136 adicións e 55 borrados
  1. 36 0
      scripts/gen_primes.py
  2. 19 10
      src/modules/linalg.c
  3. 81 45
      src/public/py_dict.c

+ 36 - 0
scripts/gen_primes.py

@@ -0,0 +1,36 @@
+import numba
+from typing import List
+
+@numba.jit(nopython=True)
+def sieve_of_eratosthenes(n: int) -> List[int]:
+    assert n >= 2
+    is_prime = [True] * (n + 1)
+    is_prime[0] = is_prime[1] = False  # 0 和 1 不是素数
+    
+    for start in range(2, int(n**0.5) + 1):
+        if is_prime[start]:
+            for multiple in range(start*start, n + 1, start):
+                is_prime[multiple] = False
+    
+    primes = [num for num, prime in enumerate(is_prime) if prime]
+    return primes
+
+all_primes = sieve_of_eratosthenes(2**31)
+print(len(all_primes), all_primes[:10], all_primes[-10:])
+
+index = 3
+caps = [all_primes[index]]
+
+while True:
+    for i in range(index+1, len(all_primes)):
+        last_cap = caps[-1]
+        min_cap = last_cap * 2
+        if all_primes[i] >= min_cap:
+            caps.append(all_primes[i])
+            index = i
+            break
+    else:
+        break
+
+print('-'*20)
+print(caps)

+ 19 - 10
src/modules/linalg.c

@@ -298,21 +298,30 @@ DEF_VECTOR_OPS(3)
             sum += a.data[i] * b.data[i];                                                          \
         py_newint(py_retval(), sum);                                                               \
         return true;                                                                               \
-    }                                                                                              \
-    static bool vec##D##i##__hash__(int argc, py_Ref argv) {                                       \
-        PY_CHECK_ARGC(1);                                                                          \
-        const uint32_t C = 2654435761;                                                             \
-        c11_vec##D##i v = py_tovec##D##i(argv);                                                    \
-        uint64_t hash = 0;                                                                         \
-        for(int i = 0; i < D; i++)                                                                 \
-            hash = hash * 31 + (uint32_t)v.data[i] * C;                                            \
-        py_newint(py_retval(), (py_i64)hash);                                                      \
-        return true;                                                                               \
     }
 
 DEF_VECTOR_INT_OPS(2)
 DEF_VECTOR_INT_OPS(3)
 
+static bool vec2i__hash__(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    c11_vec2i v = py_tovec2i(argv);
+    uint64_t hash = ((uint64_t)v.x << 32) | (uint64_t)v.y;
+    py_newint(py_retval(), (py_i64)hash);
+    return true;
+}
+
+static bool vec3i__hash__(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    c11_vec3i v = py_tovec3i(argv);
+    uint64_t x_part = (uint64_t)(v.x & 0xFFFFFF);
+    uint64_t y_part = (uint64_t)(v.y & 0xFFFFFF);
+    uint64_t z_part = (uint64_t)(v.z & 0xFFFF);
+    uint64_t hash = (x_part << 40) | (y_part << 16) | z_part;
+    py_newint(py_retval(), (py_i64)hash);
+    return true;
+}
+
 static bool vec2__repr__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     char buf[64];

+ 81 - 45
src/public/py_dict.c

@@ -5,10 +5,43 @@
 #include "pocketpy/objects/object.h"
 #include "pocketpy/interpreter/vm.h"
 
-#define PK_DICT_MAX_COLLISION 4
+#define PK_DICT_MAX_COLLISION 3
+
+static uint32_t Dict__next_cap(uint32_t cap) {
+    switch(cap) {
+        case 7: return 17;
+        case 17: return 37;
+        case 37: return 79;
+        case 79: return 163;
+        case 163: return 331;
+        case 331: return 673;
+        case 673: return 1361;
+        case 1361: return 2729;
+        case 2729: return 5471;
+        case 5471: return 10949;
+        case 10949: return 21911;
+        case 21911: return 43853;
+        case 43853: return 87719;
+        case 87719: return 175447;
+        case 175447: return 350899;
+        case 350899: return 701819;
+        case 701819: return 1403641;
+        case 1403641: return 2807303;
+        case 2807303: return 5614657;
+        case 5614657: return 11229331;
+        case 11229331: return 22458671;
+        case 22458671: return 44917381;
+        case 44917381: return 89834777;
+        case 89834777: return 179669557;
+        case 179669557: return 359339171;
+        case 359339171: return 718678369;
+        case 718678369: return 1437356741;
+        default: c11__unreachedable();
+    }
+}
 
 typedef struct {
-    py_i64 hash;
+    uint64_t hash;
     py_TValue key;
     py_TValue val;
 } DictEntry;
@@ -19,7 +52,7 @@ typedef struct {
 
 typedef struct {
     int length;
-    int capacity;
+    uint32_t capacity;
     DictIndex* indices;
     c11_vector /*T=DictEntry*/ entries;
 } Dict;
@@ -29,13 +62,13 @@ typedef struct {
     DictEntry* end;
 } DictIterator;
 
-static void Dict__ctor(Dict* self, int capacity) {
+static void Dict__ctor(Dict* self, uint32_t capacity, int entries_capacity) {
     self->length = 0;
     self->capacity = capacity;
     self->indices = malloc(self->capacity * sizeof(DictIndex));
     memset(self->indices, -1, self->capacity * sizeof(DictIndex));
     c11_vector__ctor(&self->entries, sizeof(DictEntry));
-    c11_vector__reserve(&self->entries, capacity);
+    c11_vector__reserve(&self->entries, entries_capacity);
 }
 
 static void Dict__dtor(Dict* self) {
@@ -48,7 +81,7 @@ static void Dict__dtor(Dict* self) {
 static bool Dict__try_get(Dict* self, py_TValue* key, DictEntry** out) {
     py_i64 hash;
     if(!py_hash(key, &hash)) return false;
-    int idx = hash & (self->capacity - 1);
+    int idx = (uint64_t)hash % self->capacity;
     for(int i = 0; i < PK_DICT_MAX_COLLISION; i++) {
         int idx2 = self->indices[idx]._[i];
         if(idx2 == -1) continue;
@@ -72,38 +105,37 @@ static void Dict__clear(Dict* self) {
 
 static void Dict__rehash_2x(Dict* self) {
     Dict old_dict = *self;
-
-    int new_capacity = self->capacity * 2;
-
-    do {
-        Dict__ctor(self, new_capacity);
-
-        for(int i = 0; i < old_dict.entries.length; i++) {
-            DictEntry* entry = c11__at(DictEntry, &old_dict.entries, i);
-            if(py_isnil(&entry->key)) continue;
-            int idx = entry->hash & (new_capacity - 1);
-            bool success = false;
-            for(int i = 0; i < PK_DICT_MAX_COLLISION; i++) {
-                int idx2 = self->indices[idx]._[i];
-                if(idx2 == -1) {
-                    // insert new entry (empty slot)
-                    c11_vector__push(DictEntry, &self->entries, *entry);
-                    self->indices[idx]._[i] = self->entries.length - 1;
-                    self->length++;
-                    success = true;
-                    break;
-                }
-            }
-            if(!success) {
-                Dict__dtor(self);
-                new_capacity *= 2;
-                continue;
+    uint32_t new_capacity = self->capacity;
+
+__RETRY:
+    // use next capacity
+    new_capacity = Dict__next_cap(new_capacity);
+    // create a new dict with new capacity
+    Dict__ctor(self, new_capacity, old_dict.entries.capacity);
+    // move entries from old dict to new dict
+    for(int i = 0; i < old_dict.entries.length; i++) {
+        DictEntry* entry = c11__at(DictEntry, &old_dict.entries, i);
+        if(py_isnil(&entry->key)) continue;
+        int idx = entry->hash % new_capacity;
+        bool success = false;
+        for(int i = 0; i < PK_DICT_MAX_COLLISION; i++) {
+            int idx2 = self->indices[idx]._[i];
+            if(idx2 == -1) {
+                // insert new entry (empty slot)
+                c11_vector__push(DictEntry, &self->entries, *entry);
+                self->indices[idx]._[i] = self->entries.length - 1;
+                self->length++;
+                success = true;
+                break;
             }
         }
-        // resize complete
-        Dict__dtor(&old_dict);
-        return;
-    } while(1);
+        if(!success) {
+            Dict__dtor(self);
+            goto __RETRY;
+        }
+    }
+    // done
+    Dict__dtor(&old_dict);
 }
 
 static void Dict__compact_entries(Dict* self) {
@@ -135,13 +167,13 @@ static void Dict__compact_entries(Dict* self) {
 static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
     py_i64 hash;
     if(!py_hash(key, &hash)) return false;
-    int idx = hash & (self->capacity - 1);
+    int idx = (uint64_t)hash % self->capacity;
     for(int i = 0; i < PK_DICT_MAX_COLLISION; i++) {
         int idx2 = self->indices[idx]._[i];
         if(idx2 == -1) {
             // insert new entry
             DictEntry* new_entry = c11_vector__emplace(&self->entries);
-            new_entry->hash = hash;
+            new_entry->hash = (uint64_t)hash;
             new_entry->key = *key;
             new_entry->val = *val;
             self->indices[idx]._[i] = self->entries.length - 1;
@@ -159,7 +191,11 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
     }
     // no empty slot found
     if(self->capacity >= self->entries.length * 10) {
-        return RuntimeError("dict has too much collision: %d/%d", self->entries.length, self->capacity);
+        // raise error if we reach the minimum load factor (0.1)
+        return RuntimeError("dict has too much collision: %d/%d/%d",
+                            self->entries.length,
+                            self->entries.capacity,
+                            self->capacity);
     }
     Dict__rehash_2x(self);
     return Dict__set(self, key, val);
@@ -170,7 +206,7 @@ static bool Dict__set(Dict* self, py_TValue* key, py_TValue* val) {
 static int Dict__pop(Dict* self, py_Ref key) {
     py_i64 hash;
     if(!py_hash(key, &hash)) return -1;
-    int idx = hash & (self->capacity - 1);
+    int idx = (uint64_t)hash % self->capacity;
     for(int i = 0; i < PK_DICT_MAX_COLLISION; i++) {
         int idx2 = self->indices[idx]._[i];
         if(idx2 == -1) continue;
@@ -208,13 +244,13 @@ static bool dict__new__(int argc, py_Ref argv) {
     py_Type cls = py_totype(argv);
     int slots = cls == tp_dict ? 0 : -1;
     Dict* ud = py_newobject(py_retval(), cls, slots, sizeof(Dict));
-    Dict__ctor(ud, 8);
+    Dict__ctor(ud, 7, 8);
     return true;
 }
 
 void py_newdict(py_Ref out) {
     Dict* ud = py_newobject(out, tp_dict, 0, sizeof(Dict));
-    Dict__ctor(ud, 8);
+    Dict__ctor(ud, 7, 8);
 }
 
 static bool dict__init__(int argc, py_Ref argv) {
@@ -535,7 +571,7 @@ int py_dict_delitem(py_Ref self, py_Ref key) {
     return Dict__pop(ud, key);
 }
 
-int py_dict_getitem_by_str(py_Ref self, const char *key){
+int py_dict_getitem_by_str(py_Ref self, const char* key) {
     py_Ref tmp = py_pushtmp();
     py_newstr(tmp, key);
     int res = py_dict_getitem(self, tmp);
@@ -543,7 +579,7 @@ int py_dict_getitem_by_str(py_Ref self, const char *key){
     return res;
 }
 
-bool py_dict_setitem_by_str(py_Ref self, const char *key, py_Ref val){
+bool py_dict_setitem_by_str(py_Ref self, const char* key, py_Ref val) {
     py_Ref tmp = py_pushtmp();
     py_newstr(tmp, key);
     bool res = py_dict_setitem(self, tmp, val);
@@ -551,7 +587,7 @@ bool py_dict_setitem_by_str(py_Ref self, const char *key, py_Ref val){
     return res;
 }
 
-int py_dict_delitem_by_str(py_Ref self, const char *key){
+int py_dict_delitem_by_str(py_Ref self, const char* key) {
     py_Ref tmp = py_pushtmp();
     py_newstr(tmp, key);
     int res = py_dict_delitem(self, tmp);