Browse Source

`pickle` supports cyclic references

blueloveTH 2 years ago
parent
commit
6874141b62
2 changed files with 91 additions and 59 deletions
  1. 79 54
      python/pickle.py
  2. 12 5
      tests/81_pickle.py

+ 79 - 54
python/pickle.py

@@ -21,7 +21,8 @@ def _find__new__(cls):
     assert False
 
 class _Pickler:
-    def __init__(self) -> None:
+    def __init__(self, obj) -> None:
+        self.obj = obj
         self.raw_memo = {}  # id -> int
         self.memo = []      # int -> object
 
@@ -31,81 +32,106 @@ class _Pickler:
         
         index = self.raw_memo.get(id(o), None)
         if index is not None:
-            return ["$", index]
+            return [index]
         
         ret = []
         index = len(self.memo)
         self.memo.append(ret)
         self.raw_memo[id(o)] = index
 
-        if type(o) is list:
-            ret.append("list")
-            ret.append([self.wrap(i) for i in o])
-            return ["$", index]
-
         if type(o) is tuple:
             ret.append("tuple")
             ret.append([self.wrap(i) for i in o])
-            return ["$", index]
-        
-        if type(o) is dict:
-            ret.append("dict")
-            ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()])
-            return ["$", index]
-        
+            return [index]
         if type(o) is bytes:
             ret.append("bytes")
             ret.append([o[j] for j in range(len(o))])
-            return ["$", index]
+            return [index]
+        if type(o) is list:
+            ret.append("list")
+            ret.append([self.wrap(i) for i in o])
+            return [index]
+        if type(o) is dict:
+            ret.append("dict")
+            ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()])
+            return [index]
         
         _0 = o.__class__.__name__
+
         if hasattr(o, "__getnewargs__"):
             _1 = o.__getnewargs__()     # an iterable
             _1 = [self.wrap(i) for i in _1]
         else:
             _1 = None
-        if hasattr(o, "__getstate__"):
-            _2 = o.__getstate__()
+
+        if o.__dict__ is None:
+            _2 = None
         else:
-            if o.__dict__ is None:
-                _2 = None
-            else:
-                _2 = {}
-                for k,v in o.__dict__.items():
-                    _2[k] = self.wrap(v)
+            _2 = {}
+            for k,v in o.__dict__.items():
+                _2[k] = self.wrap(v)
+
         ret.append(_0)
         ret.append(_1)
         ret.append(_2)
-        return ["$", index]
+        return [index]
+    
+    def run_pipe(self):
+        o = self.wrap(self.obj)
+        return [o, self.memo]
+
+
 
 class _Unpickler:
-    def __init__(self, memo: list) -> None:
+    def __init__(self, obj, memo: list) -> None:
+        self.obj = obj
         self.memo = memo
-        self.unwrapped = [None] * len(memo)
+        self._unwrapped = [None] * len(memo)
 
-    def unwrap_ref(self, i: int):
-        if self.unwrapped[i] is None:
-            o = self.memo[i]
-            assert type(o) is list
-            assert o[0] != '$'
-            self.unwrapped[i] = self.unwrap(o)
-        return self.unwrapped[i]
+    def tag(self, index, o):
+        assert self._unwrapped[index] is None
+        self._unwrapped[index] = o
 
-    def unwrap(self, o):
+    def unwrap(self, o, index=None):
         if type(o) in _BASIC_TYPES:
             return o
         assert type(o) is list
-        if o[0] == '$':
-            index = o[1]
-            return self.unwrap_ref(index)
-        if o[0] == "list":
-            return [self.unwrap(i) for i in o[1]]
+
+        # reference
+        if type(o[0]) is int:
+            assert index is None    # index should be None
+            index = o[0]
+            if self._unwrapped[index] is None:
+                o = self.memo[index]
+                assert type(o) is list
+                assert type(o[0]) is str
+                self.unwrap(o, index)
+                assert self._unwrapped[index] is not None
+            return self._unwrapped[index]
+        
+        # concrete reference type
         if o[0] == "tuple":
-            return tuple([self.unwrap(i) for i in o[1]])
-        if o[0] == "dict":
-            return {self.unwrap(k): self.unwrap(v) for k,v in o[1]}
+            ret = tuple([self.unwrap(i) for i in o[1]])
+            self.tag(index, ret)
+            return ret
         if o[0] == "bytes":
-            return bytes(o[1])
+            ret = bytes(o[1])
+            self.tag(index, ret)
+            return ret
+
+        if o[0] == "list":
+            ret = []
+            self.tag(index, ret)
+            for i in o[1]:
+                ret.append(self.unwrap(i))
+            return ret
+        if o[0] == "dict":
+            ret = {}
+            self.tag(index, ret)
+            for k,v in o[1]:
+                ret[self.unwrap(k)] = self.unwrap(v)
+            return ret
+        
         # generic object
         cls, newargs, state = o
         cls = _find_class(o[0])
@@ -116,23 +142,22 @@ class _Unpickler:
             inst = new_f(cls, *newargs)
         else:
             inst = new_f(cls)
+        self.tag(index, inst)
         # restore state
-        if hasattr(inst, "__setstate__"):
-            inst.__setstate__(state)
-        else:
-            if state is not None:
-                for k,v in state.items():
-                    setattr(inst, k, self.unwrap(v))
+        if state is not None:
+            for k,v in state.items():
+                setattr(inst, k, self.unwrap(v))
         return inst
 
+    def run_pipe(self):
+        return self.unwrap(self.obj)
+
+
 def _wrap(o):
-    p = _Pickler()
-    o = p.wrap(o)
-    return [o, p.memo]
+    return _Pickler(o).run_pipe()
 
 def _unwrap(packed: list):
-    o, memo = packed
-    return _Unpickler(memo).unwrap(o)
+    return _Unpickler(*packed).run_pipe()
 
 def dumps(o) -> bytes:
     o = _wrap(o)

+ 12 - 5
tests/81_pickle.py

@@ -4,9 +4,12 @@ def test(x):
     ok = x == loads(dumps(x))
     if not ok:
         _0 = _wrap(x)
-        _1 = _unwrap(0)
+        _1 = _unwrap(_0)
+        print('='*50)
         print(_0)
+        print('-'*50)
         print(_1)
+        print('='*50)
         assert False
 
 test(1)
@@ -34,9 +37,7 @@ class Foo:
         return f"Foo({self.x}, {self.y})"
     
 test(Foo(1, 2))
-
-a = [1,2]
-test(Foo([1, 2], a))
+test(Foo([1, True], 'c'))
 
 from linalg import vec2
 
@@ -53,4 +54,10 @@ d = {'k': a, 'j': a}
 c = loads(dumps(d))
 
 assert c['k'] is c['j']
-assert c == d
+assert c == d
+
+# test circular references
+from collections import deque
+
+a = deque([1, 2, 3])
+test(a)