blueloveTH há 1 ano atrás
pai
commit
86dc516791
3 ficheiros alterados com 40 adições e 8 exclusões
  1. 23 1
      src/modules/pickle.c
  2. 0 2
      src/public/cast.c
  3. 17 5
      tests/90_pickle.py

+ 23 - 1
src/modules/pickle.c

@@ -372,9 +372,22 @@ static bool pkl__write_object(PickleObject* buf, py_TValue* obj) {
                 return true;
             }
             if(ti->is_python) {
+                NameDict* dict = PyObject__dict(obj->_obj);
+                for(int i = dict->length - 1; i >= 0; i--) {
+                    NameDict_KV* kv = c11__at(NameDict_KV, dict, i);
+                    if(!pkl__write_object(buf, &kv->value)) return false;
+                }
                 pkl__emit_op(buf, PKL_OBJECT);
                 pkl__emit_int(buf, obj->type);
                 buf->used_types[obj->type] = true;
+                pkl__emit_int(buf, dict->length);
+                for(int i = 0; i < dict->length; i++) {
+                    NameDict_KV* kv = c11__at(NameDict_KV, dict, i);
+                    c11_sv field = py_name2sv(kv->key);
+                    // include '\0'
+                    PickleObject__write_bytes(buf, field.data, field.size + 1);
+                }
+
                 // store memo
                 pkl__store_memo(buf, obj->_obj);
                 return true;
@@ -662,7 +675,16 @@ bool py_pickle_loads_body(const unsigned char* p, int memo_length, c11_smallmap_
             case PKL_OBJECT: {
                 py_Type type = (py_Type)pkl__read_int(&p);
                 type = pkl__fix_type(type, type_mapping);
-                if(!py_tpcall(type, 0, NULL)) return false;
+                py_newobject(py_retval(), type, -1, 0);
+                NameDict* dict = PyObject__dict(py_retval()->_obj);
+                int dict_length = pkl__read_int(&p);
+                for(int i = 0; i < dict_length; i++) {
+                    py_StackRef value = py_peek(-1);
+                    c11_sv field = {(const char*)p, strlen((const char*)p)}; 
+                    NameDict__set(dict, py_namev(field), *value);
+                    py_pop();
+                    p += field.size + 1;
+                }
                 py_push(py_retval());
                 break;
             }

+ 0 - 2
src/public/cast.c

@@ -1,8 +1,6 @@
-#include "pocketpy/common/str.h"
 #include "pocketpy/objects/base.h"
 #include "pocketpy/pocketpy.h"
 
-#include "pocketpy/common/utils.h"
 #include "pocketpy/objects/object.h"
 #include "pocketpy/interpreter/vm.h"
 

+ 17 - 5
tests/90_pickle.py

@@ -122,12 +122,24 @@ class A:
 test([A(1)]*10)
 
 class Simple:
-    def __init__(self): pass
-    def __eq__(self, other): return True
-    def __ne__(self, other): return False
+    def __init__(self, x):
+        self.field1 = x
+        self.field2 = [...]
+    def __eq__(self, other): return self.field1 == other.field1
+    def __ne__(self, other): return self.field1 != other.field1
 
-test(Simple())
-test([Simple()]*10)
+test(Simple(1))
+test([Simple(2)]*10)
+
+from dataclasses import dataclass
+
+@dataclass
+class Data:
+    a: int
+    b: str = '2'
+    c: float = 3.0
+
+test(Data(1))
 
 exit()