浏览代码

more fix...

blueloveTH 2 年之前
父节点
当前提交
da62a1906b
共有 4 个文件被更改,包括 39 次插入20 次删除
  1. 2 1
      include/pocketpy/dict.h
  2. 22 14
      include/pocketpy/namedict.h
  3. 4 4
      src/dict.cpp
  4. 11 1
      src/vm.cpp

+ 2 - 1
include/pocketpy/dict.h

@@ -41,7 +41,8 @@ struct Dict{
 
     int size() const { return _size; }
 
-    void _probe(PyObject* key, bool& ok, int& i) const;
+    void _probe_0(PyObject* key, bool& ok, int& i) const;
+    void _probe_1(PyObject* key, bool& ok, int& i) const;
 
     void set(PyObject* key, PyObject* val);
     void _rehash();

+ 22 - 14
include/pocketpy/namedict.h

@@ -29,11 +29,11 @@ struct NameDictImpl {
     uint16_t _mask;
     Item* _items;
 
-#define HASH_PROBE(key, ok, i)          \
-ok = false;                             \
-i = _hash(key, _mask, _hash_seed);      \
-for(int _j=0; _j<_capacity; _j++) {       \
-    if(!_items[i].first.empty()){       \
+#define HASH_PROBE_0(key, ok, i)            \
+ok = false;                                 \
+i = _hash(key, _mask, _hash_seed);          \
+for(int _j=0; _j<_capacity; _j++) {         \
+    if(!_items[i].first.empty()){           \
         if(_items[i].first == (key)) { ok = true; break; }  \
     }else{                                                  \
         if(_items[i].second == 0) break;                    \
@@ -41,6 +41,14 @@ for(int _j=0; _j<_capacity; _j++) {       \
     i = (i + 1) & _mask;                                    \
 }
 
+#define HASH_PROBE_1(key, ok, i)            \
+ok = false;                                 \
+i = _hash(key, _mask, _hash_seed);          \
+while(!_items[i].first.empty()) {           \
+    if(_items[i].first == (key)) { ok = true; break; }  \
+    i = (i + 1) & _mask;                                \
+}
+
 #define NAMEDICT_ALLOC()                \
     _items = (Item*)pool128_alloc(_capacity * sizeof(Item));    \
     memset(_items, 0, _capacity * sizeof(Item));                \
@@ -73,19 +81,19 @@ for(int _j=0; _j<_capacity; _j++) {       \
 
     T operator[](StrName key) const {
         bool ok; uint16_t i;
-        HASH_PROBE(key, ok, i);
+        HASH_PROBE_0(key, ok, i);
         if(!ok) throw std::out_of_range(fmt("NameDict key not found: ", key));
         return _items[i].second;
     }
 
     void set(StrName key, T val){
         bool ok; uint16_t i;
-        HASH_PROBE(key, ok, i);
+        HASH_PROBE_1(key, ok, i);
         if(!ok) {
             _size++;
             if(_size > _capacity*_load_factor){
                 _rehash(true);
-                HASH_PROBE(key, ok, i);
+                HASH_PROBE_1(key, ok, i);
             }
             _items[i].first = key;
         }
@@ -103,7 +111,7 @@ for(int _j=0; _j<_capacity; _j++) {       \
         for(uint16_t i=0; i<old_capacity; i++){
             if(old_items[i].first.empty()) continue;
             bool ok; uint16_t j;
-            HASH_PROBE(old_items[i].first, ok, j);
+            HASH_PROBE_1(old_items[i].first, ok, j);
             if(ok) FATAL_ERROR();
             _items[j] = old_items[i];
         }
@@ -117,7 +125,7 @@ for(int _j=0; _j<_capacity; _j++) {       \
 
     T try_get(StrName key) const{
         bool ok; uint16_t i;
-        HASH_PROBE(key, ok, i);
+        HASH_PROBE_0(key, ok, i);
         if(!ok){
             if constexpr(std::is_pointer_v<T>) return nullptr;
             else if constexpr(std::is_same_v<int, T>) return -1;
@@ -128,14 +136,14 @@ for(int _j=0; _j<_capacity; _j++) {       \
 
     T* try_get_2(StrName key) {
         bool ok; uint16_t i;
-        HASH_PROBE(key, ok, i);
+        HASH_PROBE_0(key, ok, i);
         if(!ok) return nullptr;
         return &_items[i].second;
     }
 
     bool try_set(StrName key, T val){
         bool ok; uint16_t i;
-        HASH_PROBE(key, ok, i);
+        HASH_PROBE_1(key, ok, i);
         if(!ok) return false;
         _items[i].second = val;
         return true;
@@ -143,7 +151,7 @@ for(int _j=0; _j<_capacity; _j++) {       \
 
     bool contains(StrName key) const {
         bool ok; uint16_t i;
-        HASH_PROBE(key, ok, i);
+        HASH_PROBE_0(key, ok, i);
         return ok;
     }
 
@@ -156,7 +164,7 @@ for(int _j=0; _j<_capacity; _j++) {       \
 
     void erase(StrName key){
         bool ok; uint16_t i;
-        HASH_PROBE(key, ok, i);
+        HASH_PROBE_0(key, ok, i);
         if(!ok) throw std::out_of_range(fmt("NameDict key not found: ", key));
         _items[i].first = StrName();
         // _items[i].second = PY_DELETED_SLOT;      // do not change .second if it is not zero, it means the slot is occupied by a deleted item

+ 4 - 4
src/dict.cpp

@@ -43,7 +43,7 @@ namespace pkpy{
         // do possible rehash
         if(_size+1 > _critical_size) _rehash();
         bool ok; int i;
-        _probe(key, ok, i);
+        _probe_1(key, ok, i);
         if(!ok) {
             _size++;
             _items[i].first = key;
@@ -91,20 +91,20 @@ namespace pkpy{
 
     PyObject* Dict::try_get(PyObject* key) const{
         bool ok; int i;
-        _probe(key, ok, i);
+        _probe_0(key, ok, i);
         if(!ok) return nullptr;
         return _items[i].second;
     }
 
     bool Dict::contains(PyObject* key) const{
         bool ok; int i;
-        _probe(key, ok, i);
+        _probe_0(key, ok, i);
         return ok;
     }
 
     bool Dict::erase(PyObject* key){
         bool ok; int i;
-        _probe(key, ok, i);
+        _probe_0(key, ok, i);
         if(!ok) return false;
         _items[i].first = nullptr;
         // _items[i].second = PY_DELETED_SLOT;  // do not change .second if it is not NULL, it means the slot is occupied by a deleted item

+ 11 - 1
src/vm.cpp

@@ -1031,7 +1031,7 @@ void VM::bind__len__(Type type, i64 (*f)(VM*, PyObject*)){
     PK_OBJ_GET(NativeFunc, nf).set_userdata(f);
 }
 
-void Dict::_probe(PyObject *key, bool &ok, int &i) const{
+void Dict::_probe_0(PyObject *key, bool &ok, int &i) const{
     ok = false;
     i64 hash = vm->py_hash(key);
     i = hash & _mask;
@@ -1048,6 +1048,16 @@ void Dict::_probe(PyObject *key, bool &ok, int &i) const{
     }
 }
 
+void Dict::_probe_1(PyObject *key, bool &ok, int &i) const{
+    ok = false;
+    i = vm->py_hash(key) & _mask;
+    while(_items[i].first != nullptr) {
+        if(vm->py_equals(_items[i].first, key)) { ok = true; break; }
+        // https://github.com/python/cpython/blob/3.8/Objects/dictobject.c#L166
+        i = ((5*i) + 1) & _mask;
+    }
+}
+
 void CodeObjectSerializer::write_object(VM *vm, PyObject *obj){
     if(is_int(obj)) write_int(_CAST(i64, obj));
     else if(is_float(obj)) write_float(_CAST(f64, obj));